Source code for texar.torch.data.data.paired_text_data

# Copyright 2019 The Texar Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Paired text data that consists of source text and target text.
"""
import math
from typing import List, Optional, Tuple

import torch

from texar.torch.data.data.data_base import (
    DataSource, FilterDataSource, ZipDataSource)
from texar.torch.data.data.dataset_utils import Batch, padded_batch
from texar.torch.data.data.mono_text_data import (
    MonoTextData, _LengthFilterMode, _default_mono_text_dataset_hparams)
from texar.torch.data.data.text_data_base import (
    TextDataBase, TextLineDataSource)
from texar.torch.data.embedding import Embedding
from texar.torch.data.vocabulary import SpecialTokens, Vocab
from texar.torch.hyperparams import HParams
from texar.torch.utils import utils

__all__ = [
    "_default_paired_text_dataset_hparams",
    "PairedTextData",
]


def _default_paired_text_dataset_hparams():
    r"""Returns hyperparameters of a paired text dataset with default values.

    See :meth:`texar.torch.data.PairedTextData.default_hparams` for details.
    """
    source_hparams = _default_mono_text_dataset_hparams()
    source_hparams["bos_token"] = None
    source_hparams["data_name"] = "source"
    target_hparams = _default_mono_text_dataset_hparams()
    target_hparams.update(
        {
            "vocab_share": False,
            "embedding_init_share": False,
            "processing_share": False,
            "data_name": "target"
        }
    )
    return {
        "source_dataset": source_hparams,
        "target_dataset": target_hparams
    }
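
# Illustrative sketch: the defaults returned above differ between the two
# sides roughly as follows.
#
#     defaults = _default_paired_text_dataset_hparams()
#     defaults["source_dataset"]["bos_token"]              # None
#     defaults["source_dataset"]["data_name"]              # "source"
#     defaults["target_dataset"]["vocab_share"]            # False
#     defaults["target_dataset"]["embedding_init_share"]   # False
#     defaults["target_dataset"]["processing_share"]       # False
#     defaults["target_dataset"]["data_name"]              # "target"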


class PairedTextData(TextDataBase[Tuple[List[str], List[str]],
                                  Tuple[List[str], List[str]]]):
    r"""Text data processor that reads parallel source and target text.
    This can be used in, e.g., seq2seq models.

    Args:
        hparams (dict): Hyperparameters. See :meth:`default_hparams`
            for the defaults.
        device: The device of the produced batches. For GPU training,
            set to current CUDA device.

    By default, the processor reads raw data files, performs tokenization,
    batching and other pre-processing steps, and results in a Dataset
    whose element is a python `dict` including six fields:

    "source_text":
        A list of ``[batch_size]`` elements each containing a list of
        **raw** text tokens of source sequences. Short sequences in the
        batch are padded with **empty string**. By default only the
        ``EOS`` token is appended to each sequence. Out-of-vocabulary
        tokens are **NOT** replaced with ``UNK``.
    "source_text_ids":
        A list of ``[batch_size]`` elements each containing a list of
        token indexes of source sequences in the batch.
    "source_length":
        A list of ``[batch_size]`` integers containing the length of
        each source sequence in the batch.
    "target_text":
        A list same as "source_text" but for target sequences. By
        default both ``BOS`` and ``EOS`` are added.
    "target_text_ids":
        A list same as "source_text_ids" but for target sequences.
    "target_length":
        A list same as "source_length" but for target sequences.

    The above field names can be accessed through
    :attr:`source_text_name`, :attr:`source_text_id_name`,
    :attr:`source_length_name`, and those prefixed with ``target_``,
    respectively.

    Example:

        .. code-block:: python

            hparams={
                'source_dataset': {'files': 's', 'vocab_file': 'vs'},
                'target_dataset': {'files': ['t1', 't2'], 'vocab_file': 'vt'},
                'batch_size': 1
            }
            data = PairedTextData(hparams)
            iterator = DataIterator(data)

            for batch in iterator:
                # batch contains the following
                # batch == {
                #    'source_text': [['source', 'sequence', '<EOS>']],
                #    'source_text_ids': [[5, 10, 2]],
                #    'source_length': [3],
                #    'target_text': [['<BOS>', 'target', 'sequence', '1',
                #                     '<EOS>']],
                #    'target_text_ids': [[1, 6, 10, 20, 2]],
                #    'target_length': [5]
                # }
    """

    def __init__(self, hparams, device: Optional[torch.device] = None):
        self._hparams = HParams(hparams, self.default_hparams())

        src_hparams = self.hparams.source_dataset
        tgt_hparams = self.hparams.target_dataset

        # create vocabulary
        self._src_bos_token = src_hparams["bos_token"]
        self._src_eos_token = src_hparams["eos_token"]
        self._src_transforms = src_hparams["other_transformations"]
        self._src_vocab = Vocab(src_hparams.vocab_file,
                                bos_token=src_hparams.bos_token,
                                eos_token=src_hparams.eos_token)

        if tgt_hparams["processing_share"]:
            self._tgt_bos_token = src_hparams["bos_token"]
            self._tgt_eos_token = src_hparams["eos_token"]
        else:
            self._tgt_bos_token = tgt_hparams["bos_token"]
            self._tgt_eos_token = tgt_hparams["eos_token"]
        tgt_bos_token = utils.default_str(self._tgt_bos_token,
                                          SpecialTokens.BOS)
        tgt_eos_token = utils.default_str(self._tgt_eos_token,
                                          SpecialTokens.EOS)
        if tgt_hparams["vocab_share"]:
            if (tgt_bos_token == self._src_vocab.bos_token and
                    tgt_eos_token == self._src_vocab.eos_token):
                self._tgt_vocab = self._src_vocab
            else:
                self._tgt_vocab = Vocab(src_hparams["vocab_file"],
                                        bos_token=tgt_bos_token,
                                        eos_token=tgt_eos_token)
        else:
            self._tgt_vocab = Vocab(tgt_hparams["vocab_file"],
                                    bos_token=tgt_bos_token,
                                    eos_token=tgt_eos_token)

        # create embeddings
        self._src_embedding = MonoTextData.make_embedding(
            src_hparams.embedding_init,
            self._src_vocab.token_to_id_map_py)

        if self._hparams.target_dataset.embedding_init_share:
            self._tgt_embedding = self._src_embedding
        else:
            tgt_emb_file = tgt_hparams.embedding_init["file"]
            self._tgt_embedding = None
            if tgt_emb_file is not None and tgt_emb_file != "":
                self._tgt_embedding = MonoTextData.make_embedding(
                    tgt_hparams.embedding_init,
                    self._tgt_vocab.token_to_id_map_py)

        # create data source
        self._src_delimiter = src_hparams.delimiter
        self._src_max_seq_length = src_hparams.max_seq_length
        self._src_length_filter_mode = _LengthFilterMode(
            src_hparams.length_filter_mode)
        self._src_pad_length = self._src_max_seq_length
        if self._src_pad_length is not None:
            self._src_pad_length += sum(
                int(x is not None and x != '')
                for x in [src_hparams.bos_token, src_hparams.eos_token])

        src_data_source = TextLineDataSource(
            src_hparams.files,
            compression_type=src_hparams.compression_type)

        self._tgt_transforms = tgt_hparams["other_transformations"]
        self._tgt_delimiter = tgt_hparams.delimiter
        self._tgt_max_seq_length = tgt_hparams.max_seq_length
        self._tgt_length_filter_mode = _LengthFilterMode(
            tgt_hparams.length_filter_mode)
        self._tgt_pad_length = self._tgt_max_seq_length
        if self._tgt_pad_length is not None:
            self._tgt_pad_length += sum(
                int(x is not None and x != '')
                for x in [tgt_hparams.bos_token, tgt_hparams.eos_token])

        tgt_data_source = TextLineDataSource(
            tgt_hparams.files,
            compression_type=tgt_hparams.compression_type)

        data_source: DataSource[Tuple[List[str], List[str]]]
        data_source = ZipDataSource(  # type: ignore
            src_data_source, tgt_data_source)

        if ((self._src_length_filter_mode is _LengthFilterMode.DISCARD and
             self._src_max_seq_length is not None) or
                (self._tgt_length_filter_mode is _LengthFilterMode.DISCARD and
                 self._tgt_max_seq_length is not None)):
            max_source_length = self._src_max_seq_length or math.inf
            max_tgt_length = self._tgt_max_seq_length or math.inf

            def filter_fn(raw_example):
                return (len(raw_example[0]) <= max_source_length and
                        len(raw_example[1]) <= max_tgt_length)

            data_source = FilterDataSource(data_source, filter_fn)

        super().__init__(data_source, hparams, device=device)
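
    # Illustrative sketch (hypothetical file names): with ``vocab_share`` and
    # ``processing_share`` enabled, the constructor above reuses the source
    # vocabulary file and the source processing settings (delimiter, BOS/EOS
    # tokens, other transformations) for the target side, e.g.
    #
    #     data = PairedTextData(hparams={
    #         'source_dataset': {'files': 'src.txt', 'vocab_file': 'vocab.txt'},
    #         'target_dataset': {'files': 'tgt.txt',
    #                            'vocab_share': True,
    #                            'processing_share': True},
    #         'batch_size': 2,
    #     })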

    @staticmethod
    def default_hparams():
        r"""Returns a dictionary of default hyperparameters.

        .. code-block:: python

            {
                # (1) Hyperparams specific to text dataset
                "source_dataset": {
                    "files": [],
                    "compression_type": None,
                    "vocab_file": "",
                    "embedding_init": {},
                    "delimiter": None,
                    "max_seq_length": None,
                    "length_filter_mode": "truncate",
                    "pad_to_max_seq_length": False,
                    "bos_token": None,
                    "eos_token": "<EOS>",
                    "other_transformations": [],
                    "variable_utterance": False,
                    "utterance_delimiter": "|||",
                    "max_utterance_cnt": 5,
                    "data_name": "source",
                },
                "target_dataset": {
                    # ...
                    # Same fields are allowed as in "source_dataset" with the
                    # same default values, except for the following new
                    # fields/values:
                    "bos_token": "<BOS>",
                    "vocab_share": False,
                    "embedding_init_share": False,
                    "processing_share": False,
                    "data_name": "target"
                },
                # (2) General hyperparams
                "num_epochs": 1,
                "batch_size": 64,
                "allow_smaller_final_batch": True,
                "shuffle": True,
                "shuffle_buffer_size": None,
                "shard_and_shuffle": False,
                "num_parallel_calls": 1,
                "prefetch_buffer_size": 0,
                "max_dataset_size": -1,
                "seed": None,
                "name": "paired_text_data",
                # (3) Bucketing
                "bucket_boundaries": [],
                "bucket_batch_sizes": None,
                "bucket_length_fn": None,
            }

        Here:

        1. Hyperparameters in the :attr:`"source_dataset"` and
           :attr:`"target_dataset"` fields have the same definition as those
           in :meth:`texar.torch.data.MonoTextData.default_hparams`, for
           source and target text, respectively.

           For the new hyperparameters in :attr:`"target_dataset"`:

           `"vocab_share"`: bool
               Whether to share the vocabulary of source. If `True`, the
               vocab file of target is ignored.

           `"embedding_init_share"`: bool
               Whether to share the embedding initial value of source. If
               `True`, :attr:`"embedding_init"` of target is ignored.
               :attr:`"vocab_share"` must be `True` to share the embedding
               initial value.

           `"processing_share"`: bool
               Whether to share the processing configurations of source,
               including `"delimiter"`, `"bos_token"`, `"eos_token"`, and
               `"other_transformations"`.

        2. For the **general** hyperparameters, see
           :meth:`texar.torch.data.DatasetBase.default_hparams` for details.

        3. For **bucketing** hyperparameters, see
           :meth:`texar.torch.data.MonoTextData.default_hparams` for details,
           except that the default `"bucket_length_fn"` is the maximum
           sequence length of source and target sequences.

           .. warning::
               Bucketing is not yet supported. These options will be ignored.
        """
        hparams = TextDataBase.default_hparams()
        hparams["name"] = "paired_text_data"
        hparams.update(_default_paired_text_dataset_hparams())
        return hparams
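
    # Illustrative sketch: user-provided hparams are merged onto these
    # defaults via ``HParams`` in ``__init__``, so only fields that differ
    # need to be specified. For instance, the target-side BOS token defaults
    # to "<BOS>" while the source side adds none:
    #
    #     PairedTextData.default_hparams()["target_dataset"]["bos_token"]
    #     # -> "<BOS>"
    #     PairedTextData.default_hparams()["source_dataset"]["bos_token"]
    #     # -> None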

    @staticmethod
    def make_embedding(src_emb_hparams, src_token_to_id_map,
                       tgt_emb_hparams=None, tgt_token_to_id_map=None,
                       emb_init_share=False):
        r"""Optionally loads source and target embeddings from files
        (if provided), and returns respective
        :class:`texar.torch.data.Embedding` instances.
        """
        src_embedding = MonoTextData.make_embedding(src_emb_hparams,
                                                    src_token_to_id_map)

        if emb_init_share:
            tgt_embedding = src_embedding
        else:
            tgt_emb_file = tgt_emb_hparams["file"]
            tgt_embedding = None
            if tgt_emb_file is not None and tgt_emb_file != "":
                tgt_embedding = Embedding(tgt_token_to_id_map,
                                          tgt_emb_hparams)

        return src_embedding, tgt_embedding

    def process(self, raw_example: Tuple[List[str], List[str]]) -> \
            Tuple[List[str], List[str]]:
        # `process` truncates sentences and appends BOS/EOS tokens.
        src_words = raw_example[0]
        if (self._src_max_seq_length is not None and
                len(src_words) > self._src_max_seq_length):
            if self._src_length_filter_mode is _LengthFilterMode.TRUNC:
                src_words = src_words[:self._src_max_seq_length]

        if self._src_bos_token is not None and self._src_bos_token != '':
            src_words.insert(0, self._src_bos_token)
        if self._src_eos_token is not None and self._src_eos_token != '':
            src_words.append(self._src_eos_token)

        # apply the transformations to source
        for transform in self._src_transforms:
            src_words = transform(src_words)

        tgt_words = raw_example[1]
        if (self._tgt_max_seq_length is not None and
                len(tgt_words) > self._tgt_max_seq_length):
            if self._tgt_length_filter_mode is _LengthFilterMode.TRUNC:
                tgt_words = tgt_words[:self._tgt_max_seq_length]

        if self._tgt_bos_token is not None and self._tgt_bos_token != '':
            tgt_words.insert(0, self._tgt_bos_token)
        if self._tgt_eos_token is not None and self._tgt_eos_token != '':
            tgt_words.append(self._tgt_eos_token)

        # apply the transformations to target
        for transform in self._tgt_transforms:
            tgt_words = transform(tgt_words)

        return src_words, tgt_words

    @staticmethod
    def _get_name_prefix(src_hparams, tgt_hparams):
        name_prefix = [
            src_hparams["data_name"], tgt_hparams["data_name"]]
        if name_prefix[0] == name_prefix[1]:
            raise ValueError("'data_name' of source and target "
                             "datasets cannot be the same.")
        return name_prefix

    def collate(self, examples: List[Tuple[List[str], List[str]]]) -> Batch:
        # For `PairedTextData`, each example is represented as a tuple of
        # lists of strings. `collate` takes care of padding and
        # numericalization. If `pad_length` is `None`, pad to the longest
        # sentence in the batch.
        src_examples = [example[0] for example in examples]
        source_ids = [self._src_vocab.map_tokens_to_ids_py(sent)
                      for sent in src_examples]
        source_ids, source_lengths = padded_batch(
            source_ids, self._src_pad_length,
            pad_value=self._src_vocab.pad_token_id)
        src_pad_length = self._src_pad_length or max(source_lengths)
        src_examples = [
            sent + [''] * (src_pad_length - len(sent))
            if len(sent) < src_pad_length else sent
            for sent in src_examples]
        source_ids = torch.from_numpy(source_ids)
        source_lengths = torch.tensor(source_lengths, dtype=torch.long)

        tgt_examples = [example[1] for example in examples]
        target_ids = [self._tgt_vocab.map_tokens_to_ids_py(sent)
                      for sent in tgt_examples]
        target_ids, target_lengths = padded_batch(
            target_ids, self._tgt_pad_length,
            pad_value=self._tgt_vocab.pad_token_id)
        tgt_pad_length = self._tgt_pad_length or max(target_lengths)
        tgt_examples = [
            sent + [''] * (tgt_pad_length - len(sent))
            if len(sent) < tgt_pad_length else sent
            for sent in tgt_examples]
        target_ids = torch.from_numpy(target_ids)
        target_lengths = torch.tensor(target_lengths, dtype=torch.long)

        return Batch(len(examples),
                     source_text=src_examples,
                     source_text_ids=source_ids,
                     source_length=source_lengths,
                     target_text=tgt_examples,
                     target_text_ids=target_ids,
                     target_length=target_lengths)
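
    # Illustrative sketch: for a batch of two source sentences of lengths 3
    # and 5 (after EOS is appended) and no fixed pad length, ``collate``
    # pads the raw-text example with empty strings and the id sequence with
    # ``pad_token_id``, e.g.
    #
    #     source_text   = [['a', 'b', '<EOS>', '', ''],
    #                      ['c', 'd', 'e', 'f', '<EOS>']]
    #     source_length = tensor([3, 5])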

    def list_items(self) -> List[str]:
        r"""Returns the list of item names that the data can produce.

        Returns:
            A list of strings.
        """
        items = ['text', 'text_ids', 'length']
        src_name = self._hparams.source_dataset['data_name']
        tgt_name = self._hparams.target_dataset['data_name']
        if src_name is not None:
            src_items = [src_name + '_' + item for item in items]
        else:
            src_items = items
        if tgt_name is not None:
            tgt_items = [tgt_name + '_' + item for item in items]
        else:
            tgt_items = items
        return src_items + tgt_items
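
    # Illustrative sketch: with the default ``data_name`` values ("source"
    # and "target"), ``list_items()`` returns
    #
    #     ['source_text', 'source_text_ids', 'source_length',
    #      'target_text', 'target_text_ids', 'target_length']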

    @property
    def vocab(self):
        r"""A pair of :class:`~texar.torch.data.Vocab` instances that are
        the source and target vocabs, respectively.
        """
        return self._src_vocab, self._tgt_vocab

    @property
    def source_vocab(self):
        r"""The source vocab, an instance of :class:`~texar.torch.data.Vocab`.
        """
        return self._src_vocab

    @property
    def target_vocab(self):
        r"""The target vocab, an instance of :class:`~texar.torch.data.Vocab`.
        """
        return self._tgt_vocab

    @property
    def source_text_name(self):
        r"""The name of the source text field."""
        name = "{}_text".format(self.hparams.source_dataset["data_name"])
        return name

    @property
    def source_text_id_name(self):
        r"""The name of the source text ids field."""
        name = "{}_text_ids".format(self.hparams.source_dataset["data_name"])
        return name

    @property
    def source_length_name(self):
        r"""The name of the source length field."""
        name = "{}_length".format(self.hparams.source_dataset["data_name"])
        return name

    @property
    def target_text_name(self):
        r"""The name of the target text field."""
        name = "{}_text".format(self.hparams.target_dataset["data_name"])
        return name

    @property
    def target_text_id_name(self):
        r"""The name of the target text ids field."""
        name = "{}_text_ids".format(self.hparams.target_dataset["data_name"])
        return name

    @property
    def target_length_name(self):
        r"""The name of the target length field."""
        name = "{}_length".format(self.hparams.target_dataset["data_name"])
        return name

    def embedding_init_value(self):
        r"""A pair of `Tensor` containing the embedding values of source and
        target data loaded from file.
        """
        src_emb = self.hparams.source_dataset["embedding_init"]
        tgt_emb = self.hparams.target_dataset["embedding_init"]
        return src_emb, tgt_emb