# Copyright 2019 The Texar Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Mono text data class that defines data reading, parsing, batching, and other
preprocessing operations.
"""
from enum import Enum
from typing import List, Optional

import torch

from texar.torch.data.data.data_base import DataSource
from texar.torch.data.data.dataset_utils import Batch, padded_batch
from texar.torch.data.data.text_data_base import (
    TextDataBase, TextLineDataSource)
from texar.torch.data.embedding import Embedding
from texar.torch.data.vocabulary import SpecialTokens, Vocab
from texar.torch.hyperparams import HParams
from texar.torch.utils import utils

__all__ = [
    "_default_mono_text_dataset_hparams",
    "MonoTextData",
]


class _LengthFilterMode(Enum):
    r"""Options of length filter mode.
    """
    TRUNC = "truncate"
    DISCARD = "discard"


def _default_mono_text_dataset_hparams():
    r"""Returns hyperparameters of a mono text dataset with default values.

    See :meth:`texar.torch.MonoTextData.default_hparams` for details.
    """
    return {
        "files": [],
        "compression_type": None,
        "vocab_file": "",
        "embedding_init": Embedding.default_hparams(),
        "delimiter": None,
        "max_seq_length": None,
        "length_filter_mode": "truncate",
        "pad_to_max_seq_length": False,
        "bos_token": SpecialTokens.BOS,
        "eos_token": SpecialTokens.EOS,
        "other_transformations": [],
        "variable_utterance": False,
        "utterance_delimiter": "|||",
        "max_utterance_cnt": 5,
        "data_name": None,
        "@no_typecheck": ["files"]
    }
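
# Note (illustrative): the defaults above populate the ``"dataset"`` field of
# :meth:`MonoTextData.default_hparams`. A user-supplied ``hparams`` only needs
# to override entries that differ; the rest are filled in by ``HParams``.
# ``"files"`` is exempt from type checking and may be a single path or a list
# of paths. File names below are hypothetical:
#
#     dataset_hparams = {
#         "files": ["train.txt", "extra.txt"],  # or simply "train.txt"
#         "vocab_file": "vocab.txt",
#         "max_seq_length": 30,                 # length before BOS/EOS
#         "length_filter_mode": "discard",      # drop over-long samples
#     }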


# todo(avinash): Add variable utterance logic
class MonoTextData(TextDataBase[List[str], List[str]]):
    r"""Text data processor that reads a single set of text files. This can
    be used for, e.g., language models, auto-encoders, etc.

    Args:
        hparams: A `dict` or instance of :class:`~texar.torch.HParams`
            containing hyperparameters. See :meth:`default_hparams` for the
            defaults.
        device: The device of the produced batches. For GPU training, set to
            current CUDA device.

    By default, the processor reads raw data files, performs tokenization,
    batching and other pre-processing steps, and results in a Dataset whose
    element is a python `dict` including three fields:

    "text":
        A list of ``[batch_size]`` elements each containing a list of
        **raw** text tokens of the sequences. Short sequences in the batch
        are padded with **empty string**. By default, ``BOS`` and ``EOS``
        tokens are added to each sequence. Out-of-vocabulary tokens are
        **NOT** replaced with ``UNK``.

    "text_ids":
        A list of ``[batch_size]`` elements each containing a list of token
        indexes of source sequences in the batch.

    "length":
        A list of ``[batch_size]`` elements of integers containing the
        length of each source sequence in the batch (including ``BOS`` and
        ``EOS`` if added).

    The above field names can be accessed through :attr:`text_name`,
    :attr:`text_id_name`, :attr:`length_name`.

    Example:

        .. code-block:: python

            hparams = {
                'dataset': {
                    'files': 'data.txt',
                    'vocab_file': 'vocab.txt'
                },
                'batch_size': 1
            }
            data = MonoTextData(hparams)
            iterator = DataIterator(data)
            for batch in iterator:
                # batch contains the following
                # batch == {
                #    'text': [['<BOS>', 'example', 'sequence', '<EOS>']],
                #    'text_ids': [[1, 5, 10, 2]],
                #    'length': [4]
                # }
    """

    _delimiter: Optional[str]
    _bos: Optional[str]
    _eos: Optional[str]
    _max_seq_length: Optional[int]
    _should_pad: bool

    def __init__(self, hparams, device: Optional[torch.device] = None,
                 vocab: Optional[Vocab] = None,
                 embedding: Optional[Embedding] = None,
                 data_source: Optional[DataSource] = None):
        self._hparams = HParams(hparams, self.default_hparams())
        if self._hparams.dataset.variable_utterance:
            raise NotImplementedError

        # Create vocabulary
        self._bos_token = self._hparams.dataset.bos_token
        self._eos_token = self._hparams.dataset.eos_token
        self._other_transforms = self._hparams.dataset.other_transformations
        bos = utils.default_str(self._bos_token, SpecialTokens.BOS)
        eos = utils.default_str(self._eos_token, SpecialTokens.EOS)
        if vocab is None:
            self._vocab = Vocab(self._hparams.dataset.vocab_file,
                                bos_token=bos, eos_token=eos)
        else:
            self._vocab = vocab

        # Create embedding
        if embedding is not None:
            self._embedding = self.make_embedding(
                self._hparams.dataset.embedding_init,
                self._vocab.token_to_id_map_py)
        else:
            self._embedding = embedding

        self._delimiter = self._hparams.dataset.delimiter
        self._max_seq_length = self._hparams.dataset.max_seq_length
        self._length_filter_mode = _LengthFilterMode(
            self._hparams.dataset.length_filter_mode)
        self._pad_length = self._max_seq_length
        if self._pad_length is not None:
            self._pad_length += sum(int(x != '')
                                    for x in [self._bos_token,
                                              self._eos_token])

        if data_source is None:
            if (self._length_filter_mode is _LengthFilterMode.DISCARD and
                    self._max_seq_length is not None):
                data_source = TextLineDataSource(
                    self._hparams.dataset.files,
                    compression_type=self._hparams.dataset.compression_type,
                    delimiter=self._delimiter,
                    max_length=self._max_seq_length)
            else:
                data_source = TextLineDataSource(
                    self._hparams.dataset.files,
                    compression_type=self._hparams.dataset.compression_type)

        super().__init__(data_source, hparams, device=device)
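
    # Illustrative arithmetic for ``_pad_length`` above: with
    # ``max_seq_length=10`` and the default ``"<BOS>"`` / ``"<EOS>"`` tokens,
    # ``_pad_length`` is 10 + 2 = 12, leaving room for the special tokens
    # added in :meth:`process`; setting ``bos_token`` (or ``eos_token``) to
    # an empty string removes the corresponding +1.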

    @staticmethod
    def default_hparams():
        r"""Returns a dictionary of default hyperparameters:

        .. code-block:: python

            {
                # (1) Hyperparameters specific to text dataset
                "dataset": {
                    "files": [],
                    "compression_type": None,
                    "vocab_file": "",
                    "embedding_init": {},
                    "delimiter": None,
                    "max_seq_length": None,
                    "length_filter_mode": "truncate",
                    "pad_to_max_seq_length": False,
                    "bos_token": "<BOS>",
                    "eos_token": "<EOS>",
                    "other_transformations": [],
                    "variable_utterance": False,
                    "utterance_delimiter": "|||",
                    "max_utterance_cnt": 5,
                    "data_name": None,
                },
                # (2) General hyperparameters
                "num_epochs": 1,
                "batch_size": 64,
                "allow_smaller_final_batch": True,
                "shuffle": True,
                "shuffle_buffer_size": None,
                "shard_and_shuffle": False,
                "num_parallel_calls": 1,
                "prefetch_buffer_size": 0,
                "max_dataset_size": -1,
                "seed": None,
                "name": "mono_text_data",
                # (3) Bucketing
                "bucket_boundaries": [],
                "bucket_batch_sizes": None,
                "bucket_length_fn": None,
            }

        Here:

        1. For the hyperparameters in the :attr:`"dataset"` field:

            `"files"`: str or list
                A (list of) text file path(s). Each line contains a single
                text sequence.

            `"compression_type"`: str, optional
                One of `None` (no compression), ``"ZLIB"``, or ``"GZIP"``.

            `"vocab_file"`: str
                Path to vocabulary file. Each line of the file should contain
                one vocabulary token. Used to create an instance of
                :class:`~texar.torch.data.Vocab`.

            `"embedding_init"`: dict
                The hyperparameters for pre-trained embedding loading and
                initialization. The structure and default values are defined
                in :meth:`texar.torch.data.Embedding.default_hparams`.

            `"delimiter"`: str, optional
                The delimiter to split each line of the text files into
                tokens. If `None` (default), behavior will be equivalent to
                `str.split()`, i.e. split on any blank character.

            `"max_seq_length"`: int, optional
                Maximum length of output sequences. Data samples exceeding
                the length will be truncated or discarded according to
                :attr:`"length_filter_mode"`. The length does not include
                any added :attr:`"bos_token"` or :attr:`"eos_token"`. If
                `None` (default), no filtering is performed.

            `"length_filter_mode"`: str
                Either ``"truncate"`` or ``"discard"``. If ``"truncate"``
                (default), tokens exceeding :attr:`"max_seq_length"` will be
                truncated. If ``"discard"``, data samples longer than
                :attr:`"max_seq_length"` will be discarded.

            `"pad_to_max_seq_length"`: bool
                If `True`, pad all data instances to length
                :attr:`"max_seq_length"`. Raises an error if
                :attr:`"max_seq_length"` is not provided.

            `"bos_token"`: str
                The Begin-Of-Sequence token prepended to each sequence. Set
                to an empty string to avoid prepending.

            `"eos_token"`: str
                The End-Of-Sequence token appended to each sequence. Set to
                an empty string to avoid appending.

            `"other_transformations"`: list
                A list of transformation functions or function names/paths to
                further transform each single data instance. (More
                documentation to be added.)

            `"variable_utterance"`: bool
                If `True`, each line of the text file is considered to
                contain multiple sequences (utterances) separated by
                :attr:`"utterance_delimiter"`. For example, in dialog data,
                each line can contain a series of dialog history utterances.
                See the example in `examples/hierarchical_dialog` for a use
                case.

                .. warning::
                    Variable utterances are not yet supported. This option
                    (and related ones below) will be ignored.

            `"utterance_delimiter"`: str
                The delimiter used to split each line at the utterance level.
                Should not be the same as :attr:`"delimiter"`. Used only when
                :attr:`"variable_utterance"` is `True`.
`"max_utterance_cnt"`: int Maximally allowed number of utterances in a data instance. Extra utterances are truncated out. `"data_name"`: str Name of the dataset. 2. For the **general** hyperparameters, see :meth:`texar.torch.data.DatasetBase.default_hparams` for details. 3. **Bucketing** is to group elements of the dataset together by length and then pad and batch. For bucketing hyperparameters: `"bucket_boundaries"`: list An int list containing the upper length boundaries of the buckets. Set to an empty list (default) to disable bucketing. `"bucket_batch_sizes"`: list An int list containing batch size per bucket. Length should be `len(bucket_boundaries) + 1`. If `None`, every bucket will have the same batch size specified in :attr:`batch_size`. `"bucket_length_fn"`: str or callable Function maps dataset element to ``int``, determines the length of the element. This can be a function, or the name or full module path to the function. If function name is given, the function must be in the :mod:`texar.torch.custom` module. If `None` (default), length is determined by the number of tokens (including BOS and EOS if added) of the element. .. warning:: Bucketing is not yet supported. These options will be ignored. """ hparams = TextDataBase.default_hparams() hparams["name"] = "mono_text_data" hparams.update({ "dataset": _default_mono_text_dataset_hparams() }) return hparams

    @staticmethod
    def make_embedding(emb_hparams, token_to_id_map):
        r"""Optionally loads embedding from file (if provided), and returns
        an instance of :class:`texar.torch.data.Embedding`.
        """
        embedding = None
        if emb_hparams["file"] is not None and len(emb_hparams["file"]) > 0:
            embedding = Embedding(token_to_id_map, emb_hparams)
        return embedding

    def process(self, raw_example: List[str]) -> List[str]:
        # Truncates sentences and appends BOS/EOS tokens.
        words = raw_example
        if (self._max_seq_length is not None and
                len(words) > self._max_seq_length):
            if self._length_filter_mode is _LengthFilterMode.TRUNC:
                words = words[:self._max_seq_length]

        if self._hparams.dataset["bos_token"] != '':
            words.insert(0, self._hparams.dataset["bos_token"])
        if self._hparams.dataset["eos_token"] != '':
            words.append(self._hparams.dataset["eos_token"])

        # Apply the "other transformations".
        for transform in self._other_transforms:
            words = transform(words)

        return words

    def collate(self, examples: List[List[str]]) -> Batch:
        # For `MonoTextData`, each example is represented as a list of
        # strings. `_collate` takes care of padding and numericalization.

        # If `pad_length` is `None`, pad to the longest sentence in the
        # batch.
        text_ids = [self._vocab.map_tokens_to_ids_py(sent)
                    for sent in examples]
        text_ids, lengths = padded_batch(text_ids, self._pad_length,
                                         pad_value=self._vocab.pad_token_id)
        # Also pad the examples
        pad_length = self._pad_length or max(lengths)
        examples = [
            sent + [''] * (pad_length - len(sent))
            if len(sent) < pad_length else sent
            for sent in examples
        ]

        text_ids = torch.from_numpy(text_ids)
        lengths = torch.tensor(lengths, dtype=torch.long)
        batch = {self.text_name: examples, self.text_id_name: text_ids,
                 self.length_name: lengths}
        return Batch(len(examples), batch=batch)
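
    # Worked example (illustrative) of :meth:`process` and :meth:`collate`
    # with the default BOS/EOS tokens and a batch of two raw examples:
    #
    #     process(["all", "good"])  ->  ['<BOS>', 'all', 'good', '<EOS>']
    #     process(["hi"])           ->  ['<BOS>', 'hi', '<EOS>']
    #
    # ``collate`` pads the shorter example with empty strings, maps tokens to
    # (hypothetical) vocabulary ids, and pads the ids with the PAD id, so the
    # resulting batch looks like:
    #
    #     batch["text"]     == [['<BOS>', 'all', 'good', '<EOS>'],
    #                           ['<BOS>', 'hi', '<EOS>', '']]
    #     batch["text_ids"] == LongTensor of shape `[2, 4]`
    #     batch["length"]   == tensor([4, 3])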

    def list_items(self) -> List[str]:
        r"""Returns the list of item names that the data can produce.

        Returns:
            A list of strings.
        """
        items = ['text', 'text_ids', 'length']
        data_name = self._hparams.dataset.data_name
        if data_name is not None:
            items = [data_name + '_' + item for item in items]
        return items

    @property
    def vocab(self) -> Vocab:
        r"""The vocabulary, an instance of :class:`~texar.torch.data.Vocab`.
        """
        return self._vocab

    @property
    def text_name(self):
        r"""The name for the text field"""
        if self.hparams.dataset["data_name"]:
            name = "{}_text".format(self.hparams.dataset["data_name"])
        else:
            name = "text"
        return name

    @property
    def text_id_name(self):
        r"""The name for text ids"""
        if self.hparams.dataset["data_name"]:
            name = "{}_text_ids".format(self.hparams.dataset["data_name"])
        else:
            name = "text_ids"
        return name

    @property
    def length_name(self):
        r"""The name for text length"""
        if self.hparams.dataset["data_name"]:
            name = "{}_length".format(self.hparams.dataset["data_name"])
        else:
            name = "length"
        return name

    @property
    def embedding_init_value(self):
        r"""The `Tensor` containing the embedding value loaded from file.
        `None` if embedding is not specified.
        """
        if self._embedding is None:
            return None
        return self._embedding.word_vecs
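
# Illustrative note on field naming: when ``hparams["dataset"]["data_name"]``
# is set, the field names above (and those returned by ``list_items``) are
# prefixed with it. For example, assuming ``data_name="source"``:
#
#     data.list_items()  # ['source_text', 'source_text_ids', 'source_length']
#
# and each batch produced by ``collate`` uses the keys ``"source_text"``,
# ``"source_text_ids"`` and ``"source_length"`` instead of the unprefixed
# names.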