# Copyright 2019 The Texar Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Various data iterator classes.
"""
# pylint: disable=protected-access
from typing import (
Dict, Iterable, Iterator, List, Optional, Sequence, Union, Mapping)
import pkg_resources
import torch
from torch import __version__ as _torch_version # type: ignore
from torch.utils.data import DataLoader
from texar.torch.data.data.data_base import DatasetBase
from texar.torch.data.data.dataset_utils import Batch
from texar.torch.data.data.sampler import (
SamplerBase, SequentialSampler, RandomSampler, BufferShuffleSampler,
BatchingStrategy, DynamicBatchSampler)
from texar.torch.utils.types import MaybeSeq
from texar.torch.utils.utils import ceildiv, map_structure
# Replace the raw version string imported from `torch` with a parsed version
# object, so that the version-gated branches in this module can compare it
# against other parsed versions (e.g. "1.1.0", "1.2.0") with `>=`.
_torch_version = pkg_resources.parse_version(_torch_version)
# Public names of this module.
__all__ = [
"DataIterator",
"TrainTestDataIterator",
]
# Type of the `datasets` argument accepted by `DataIterator`: either a mapping
# from dataset name to dataset, or a `MaybeSeq[DatasetBase]` (which
# `DataIterator.__init__` handles as a single dataset or a sequence of them).
# `Dict` is invariant, `Mapping` is not.
DatasetsType = Union[Mapping[str, DatasetBase], MaybeSeq[DatasetBase]]
# pylint: disable=ungrouped-imports
if _torch_version >= pkg_resources.parse_version("1.2.0"): # PyTorch 1.2.0 +
from torch.utils.data._utils.pin_memory import ( # type: ignore
pin_memory as _pin_memory)
elif _torch_version >= pkg_resources.parse_version("1.1.0"): # PyTorch 1.1.0 +
from torch.utils.data._utils.pin_memory import ( # type: ignore
pin_memory_batch as _pin_memory)
else:
from torch.utils.data.dataloader import ( # type: ignore
pin_memory_batch as _pin_memory)
def move_memory(data, device):
    r"""Move every tensor contained in :attr:`data` onto :attr:`device`.

    Tensors are transferred with ``non_blocking=True``; values that are not
    :class:`torch.Tensor` instances are returned unchanged.

    Args:
        data: Either a :class:`~texar.torch.data.data.Batch`, or any structure
            accepted by ``map_structure``.
        device: The target device, passed to :meth:`torch.Tensor.to`.

    Returns:
        A new :class:`~texar.torch.data.data.Batch` of the same length if
        :attr:`data` is a batch, otherwise the result of mapping the transfer
        over :attr:`data`.
    """
    def _to_device(item):
        # Only tensors are transferred; everything else passes through.
        if not isinstance(item, torch.Tensor):
            return item
        return item.to(device=device, non_blocking=True)

    if not isinstance(data, Batch):
        return map_structure(_to_device, data)

    # A `Batch` is rebuilt field by field so the result is again a `Batch`.
    size = len(data)
    moved_fields = {}
    for name, field in data.items():
        moved_fields[name] = map_structure(_to_device, field)
    return Batch(size, batch=moved_fields)
if _torch_version >= pkg_resources.parse_version("1.2.0"): # PyTorch 1.2.0 +
# PyTorch 1.2 split the `_DataLoaderIter` class into two:
# `_SingleProcessDataLoaderIter` for when `num_workers == 0`, i.e. when
# multi-processing is disabled; `_MultiProcessingDataLoaderIter` for
# otherwise. The implementation is also slightly different from previous
# releases.
#
# To keep compatibility, our iterator classes should be a subclass of both
# PyTorch `_Single...`/`_Multi...` (for single/multi-process), and our own
# `_Cache...`/`_Data...` (for caching/no caching). This results in four
# different concrete classes, as this regex shows:
# `_[SM]P(Cache)?DataLoaderIter`.
#
# We only expose `_DataLoaderIter` and `_CacheDataLoaderIter` to other
# classes, and construct concrete classes in their `__new__` methods
# depending on the value of `num_workers`. This is for compatibility with
# previous versions, so we don't need to change other parts of the code.
from texar.torch.data.data.data_iterators_utils import \
TexarBaseDataLoaderIter as _BaseDataLoaderIter
from texar.torch.data.data.data_iterators_utils import \
TexarSingleProcessDataLoaderIter as _SingleProcessDataLoaderIter
from texar.torch.data.data.data_iterators_utils import \
TexarMultiProcessingDataLoaderIter as _MultiProcessingDataLoaderIter
class _DataLoaderIter(_BaseDataLoaderIter):
    r"""Iterates once over the DataLoader's dataset (PyTorch 1.2 and above).

    This behaves like the PyTorch base iterator, with one addition: the
    `allow_smaller_final_batch` setting is enforced here instead of through
    `drop_last` of :class:`~torch.utils.data.sampler.BatchSampler`. Using
    `drop_last` would cause the dataset to not load/process/cache certain
    elements from the final batch, which complicates the already complex
    logic.

    Instantiating this class yields one of its concrete subclasses, chosen by
    the loader's `num_workers`.
    """

    def __new__(cls, loader: 'SingleDatasetIterator'):
        # Multi-process flavour when workers are requested, single-process
        # flavour otherwise.
        concrete_cls = (_MPDataLoaderIter if loader.num_workers > 0
                        else _SPDataLoaderIter)
        return super().__new__(concrete_cls)

    def __init__(self, loader: 'SingleDatasetIterator'):
        # Both attributes are set before the parent initializer runs.
        self._batch_size = loader.batch_size
        self.device = loader.device
        super().__init__(loader)

    def __next__(self):
        batch = super().__next__()
        # `_batch_size` could be None if dynamic batching is used, in which
        # case no batch is ever considered too small.
        undersized = (self._batch_size is not None and
                      batch.batch_size < self._batch_size)
        if undersized and not self.dataset.hparams.allow_smaller_final_batch:
            # Drop the smaller final batch according to settings.
            raise StopIteration
        if self.device is None:
            return batch
        return move_memory(batch, self.device)
class _SPDataLoaderIter(_DataLoaderIter, _SingleProcessDataLoaderIter):
    r"""Concrete non-caching iterator returned by `_DataLoaderIter.__new__`
    when the loader's `num_workers` is not positive. It adds nothing beyond
    combining its two base classes.
    """
class _MPDataLoaderIter(_DataLoaderIter, _MultiProcessingDataLoaderIter):
    r"""Concrete non-caching iterator returned by `_DataLoaderIter.__new__`
    when the loader's `num_workers` is positive. It adds nothing beyond
    combining its two base classes.
    """
class _CacheDataLoaderIter(_BaseDataLoaderIter):
    r"""Iterates once over the DataLoader's dataset (PyTorch 1.2 and above).

    This class is used when examples are processed and returned by worker
    processes. The indices that make up each batch are recorded so that
    :meth:`texar.torch.data.data.DatasetBase._add_cached_examples` can be
    called to cache the processed examples, and only the
    :class:`~texar.torch.data.data.Batch` instance is returned to the user.

    Instantiating this class yields one of its concrete subclasses, chosen by
    the loader's `num_workers`.
    """

    def __new__(cls, loader: 'SingleDatasetIterator'):
        # Multi-process flavour when workers are requested, single-process
        # flavour otherwise.
        concrete_cls = (_MPCacheDataLoaderIter if loader.num_workers > 0
                        else _SPCacheDataLoaderIter)
        return super().__new__(concrete_cls)

    def __init__(self, loader: 'SingleDatasetIterator'):
        # Maps the send index of an outstanding batch to the example indices
        # it contains.
        self._indices_dict: Dict[int, List[int]] = {}
        self.device = loader.device
        self._batch_size = loader.batch_size
        super().__init__(loader)
class _SPCacheDataLoaderIter(_CacheDataLoaderIter,
                             _SingleProcessDataLoaderIter):
    r"""Concrete caching iterator used for same-process loading, i.e. the
    class returned by `_CacheDataLoaderIter.__new__` when the loader's
    `num_workers` is not positive.
    """

    def __next__(self):
        r"""Fetch the next batch, cache its processed examples in the dataset,
        and return the batch.

        Raises:
            StopIteration: When the sampler or the fetcher is exhausted, or
                when the batch is smaller than the configured batch size and
                the dataset does not allow a smaller final batch.
        """
        index = self._next_index()  # may raise StopIteration
        data = self.dataset_fetcher.fetch(index)  # may raise StopIteration
        if self.dataset._should_yield_raw_example:
            # Each entry of `index` is a sequence whose first element is the
            # example index; keep only that element for caching.
            index = [idx[0] for idx in index]
        examples, data = data
        # Cache before the size check, so that the examples of a dropped
        # final batch are still recorded in the dataset.
        self.dataset._add_cached_examples(index, examples)
        # Drop smaller final batch according to settings, consistent with the
        # other iterator classes and with `SingleDatasetIterator.__len__`.
        # Note that `_batch_size` could be None if dynamic batching is used.
        if (self._batch_size is not None and
                data.batch_size < self._batch_size and
                not self.dataset.hparams.allow_smaller_final_batch):
            raise StopIteration
        if self.pin_memory:
            data = move_memory(_pin_memory(data), self.device)
        return data
class _MPCacheDataLoaderIter(_CacheDataLoaderIter,
                             _MultiProcessingDataLoaderIter):
    r"""Concrete caching iterator used for multi-process loading, i.e. the
    class returned by `_CacheDataLoaderIter.__new__` when the loader's
    `num_workers` is positive.
    """
    dataset: DatasetBase

    worker_queue_idx: int  # so that Pylint gives no errors

    def _try_put_index(self):
        r"""Send the next batch of indices to an active worker, and remember
        which example indices were sent under the current send index.
        Does nothing if the sampler is exhausted or no worker is active.
        """
        assert self.tasks_outstanding < 2 * self.num_workers
        try:
            batch_indices = self._next_index()
        except StopIteration:
            return

        # Find the next active worker, if any.
        target_worker = None
        for _ in range(self.num_workers):
            candidate = next(self.worker_queue_idx_cycle)
            if self.workers_status[candidate]:
                target_worker = candidate
                break
        if target_worker is None:
            # No active worker was found.
            return

        # The worker receives the indices exactly as produced by the sampler.
        self.index_queues[target_worker].put((self.send_idx, batch_indices))
        if self.dataset._should_yield_raw_example:
            # For bookkeeping, keep only the first element of each entry.
            batch_indices = [entry[0] for entry in batch_indices]
        self._indices_dict[self.send_idx] = batch_indices
        self.task_info[self.send_idx] = (target_worker,)
        self.tasks_outstanding += 1
        self.send_idx += 1

    def _process_data(self, batch):
        r"""Post-process data received from a worker: cache the processed
        examples in the dataset and return only the batch.
        """
        received = super()._process_data(batch)
        # The indices were stored under the send index of this batch, looked
        # up here as `rcvd_idx - 1`; remove the entry once consumed.
        cached_indices = self._indices_dict.pop(self.rcvd_idx - 1)
        examples, collated = received
        self.dataset._add_cached_examples(cached_indices, examples)
        return collated

    def __next__(self):
        batch = super().__next__()
        # `_batch_size` could be None if dynamic batching is used.
        fixed_size = self._batch_size is not None
        if (fixed_size and
                batch.batch_size < self.dataset.batch_size and
                not self.dataset.hparams.allow_smaller_final_batch):
            # Drop the smaller final batch according to settings.
            raise StopIteration
        return move_memory(batch, self.device)
else:
# PyTorch 1.1 and lower defines only the class `_DataLoaderIter` for
# iterating over `DataLoader`.
from torch.utils.data.dataloader import ( # type: ignore
_DataLoaderIter as torch_DataLoaderIter)
class _DataLoaderIter(torch_DataLoaderIter):  # type: ignore
    r"""Iterates once over the DataLoader's dataset (PyTorch 1.1 and lower).

    This behaves like PyTorch
    :class:`torch.utils.data.dataloader._DataLoaderIter`, with one addition:
    the `allow_smaller_final_batch` setting is enforced here instead of
    through `drop_last` of :class:`~torch.utils.data.sampler.BatchSampler`.
    Using `drop_last` would cause the dataset to not load/process/cache
    certain elements from the final batch, which complicates the already
    complex logic.
    """

    def __init__(self, loader: 'SingleDatasetIterator'):
        # Both attributes are set before the parent initializer runs.
        self.device = loader.device
        self._batch_size = loader.batch_size
        super().__init__(loader)

    def __next__(self):
        batch = super().__next__()
        # `_batch_size` could be None if dynamic batching is used, in which
        # case no batch is ever considered too small.
        undersized = (self._batch_size is not None and
                      batch.batch_size < self._batch_size)
        if undersized and not self.dataset.hparams.allow_smaller_final_batch:
            # Drop the smaller final batch according to settings.
            raise StopIteration
        return move_memory(batch, self.device)
class _CacheDataLoaderIter(torch_DataLoaderIter):  # type: ignore
    r"""Iterates once over the DataLoader's dataset (PyTorch 1.1 and lower).

    This class is used when examples are processed and returned by worker
    processes. The indices that make up each batch are recorded so that
    :meth:`texar.torch.data.data.DatasetBase._add_cached_examples` can be
    called to cache the processed examples, and only the
    :class:`~texar.torch.data.data.Batch` instance is returned to the user.
    """
    dataset: DatasetBase

    worker_queue_idx: int  # so that Pylint gives no errors

    def __init__(self, loader: 'SingleDatasetIterator'):
        # Maps the send index of an outstanding batch to the example indices
        # it contains.
        self._indices_dict: Dict[int, List[int]] = {}
        self.device = loader.device
        self._batch_size = loader.batch_size
        super().__init__(loader)

    def _put_indices(self):
        r"""Send the next batch of indices to the current worker queue, and
        remember which example indices were sent under the current send index.
        Does nothing if the sampler is exhausted.
        """
        assert self.batches_outstanding < 2 * self.num_workers
        batch_indices = next(self.sample_iter, None)
        if batch_indices is None:
            return

        # The worker receives the indices exactly as produced by the sampler.
        queue = self.index_queues[self.worker_queue_idx]
        queue.put((self.send_idx, batch_indices))
        if self.dataset._should_yield_raw_example:
            # For bookkeeping, keep only the first element of each entry.
            batch_indices = [entry[0] for entry in batch_indices]
        self._indices_dict[self.send_idx] = batch_indices
        # Advance to the next worker queue in round-robin order.
        next_queue_idx = self.worker_queue_idx + 1
        self.worker_queue_idx = next_queue_idx % self.num_workers
        self.batches_outstanding += 1
        self.send_idx += 1

    def _process_next_batch(self, batch):
        r"""Post-process a batch received from a worker: cache the processed
        examples in the dataset and return only the batch.
        """
        received = super()._process_next_batch(batch)
        # The indices were stored under the send index of this batch, looked
        # up here as `rcvd_idx - 1`; remove the entry once consumed.
        cached_indices = self._indices_dict.pop(self.rcvd_idx - 1)
        examples, collated = received
        self.dataset._add_cached_examples(cached_indices, examples)
        return collated

    def _load_in_process(self):
        r"""Load, cache, and optionally pin one batch without workers."""
        batch_indices = next(self.sample_iter)  # may raise StopIteration
        fetched = [self.dataset[i] for i in batch_indices]
        collated = self.collate_fn(fetched)
        if self.dataset._should_yield_raw_example:
            # For caching, keep only the first element of each entry.
            batch_indices = [entry[0] for entry in batch_indices]
        examples, batch = collated
        self.dataset._add_cached_examples(batch_indices, examples)
        if self.pin_memory:
            batch = _pin_memory(batch)
        return batch

    def __next__(self):
        if self.num_workers != 0:
            batch = super().__next__()
        else:
            # Same-process loading.
            batch = self._load_in_process()
        # `_batch_size` could be None if dynamic batching is used.
        fixed_size = self._batch_size is not None
        if (fixed_size and
                batch.batch_size < self.dataset.batch_size and
                not self.dataset.hparams.allow_smaller_final_batch):
            # Drop the smaller final batch according to settings.
            raise StopIteration
        return move_memory(batch, self.device)
class SingleDatasetIterator(DataLoader):
    r"""Iterator for a single dataset. This iterator is based on the PyTorch
    :class:`~torch.utils.data.DataLoader` interface, with a custom shuffling
    routine. This class is used internally.

    Args:
        dataset: The dataset to iterate through. The dataset must be an
            instance of :class:`texar.torch.data.DatasetBase`, because
            configurations are read from the dataset `HParams`.
        batching_strategy: The batching strategy to use when performing dynamic
            batching. If `None`, fixed-sized batching is used.
        pin_memory: If `True`, tensors will be moved onto page-locked memory
            before returning. This argument is passed into the constructor for
            :torch_docs:`DataLoader <data.html#torch.utils.data.DataLoader>`.

            Defaults to `None`, which will set the value to `True` if the
            :class:`~texar.torch.data.DatasetBase` instance is set to use a CUDA
            device. Set to `True` or `False` to override this behavior.
    """
    dataset: DatasetBase

    def __init__(self, dataset: DatasetBase,
                 batching_strategy: Optional[BatchingStrategy] = None,
                 pin_memory: Optional[bool] = None):
        hparams = dataset.hparams
        do_shuffle = hparams.shuffle
        buffer_size = hparams.shuffle_buffer_size

        # Pick the sampler: sequential without shuffling, buffered shuffling
        # when a buffer size is configured, full random shuffling otherwise.
        sampler: SamplerBase
        if not do_shuffle:
            sampler = SequentialSampler(dataset)
        elif buffer_size is not None:
            sampler = BufferShuffleSampler(dataset, buffer_size)
        else:
            sampler = RandomSampler(dataset)

        # Options shared by both batching modes.
        loader_kwargs = dict(
            collate_fn=dataset._collate_and_maybe_return,
            num_workers=hparams.num_parallel_calls)

        on_cuda = dataset.device is not None and dataset.device.type == "cuda"
        if pin_memory is None:
            pin_memory = on_cuda
        loader_kwargs["pin_memory"] = pin_memory

        # Batches are moved to the dataset device only when memory is pinned
        # and that device is a CUDA device.
        self.device = dataset.device if (pin_memory and on_cuda) else None

        if batching_strategy is None:
            # Fixed-sized batching; a smaller final batch is dropped by the
            # iterator classes rather than through `drop_last`.
            super().__init__(
                dataset, batch_size=dataset.batch_size, drop_last=False,
                sampler=sampler, **loader_kwargs)
        else:
            # Dynamic batching.
            batch_sampler = DynamicBatchSampler(
                dataset, sampler, batching_strategy)
            super().__init__(
                dataset, batch_sampler=batch_sampler, **loader_kwargs)

    def __iter__(self):
        if not self.dataset._should_return_processed_examples:
            return _DataLoaderIter(self)
        # Accepts processed examples from workers and add to dataset cache.
        return _CacheDataLoaderIter(self)

    def __len__(self):
        r"""Number of batches in one pass over the dataset.

        Raises:
            TypeError: If `batch_size` is `None`, with the message that
                `__len__` is not supported for dynamic batching; computing the
                dataset length may also raise it.
        """
        batch_size = self.batch_size
        if batch_size is None:
            raise TypeError("__len__ not supported for dynamic batching")
        num_examples = len(self.dataset)  # may throw TypeError
        if not self.dataset.hparams.allow_smaller_final_batch:
            # The incomplete final batch is not counted.
            return num_examples // batch_size
        return ceildiv(num_examples, batch_size)
class DataIterator:
    r"""Data iterator that switches and iterates through multiple datasets.

    This is a wrapper of :class:`~texar.torch.data.SingleDatasetIterator`.

    Args:
        datasets: Datasets to iterate through. This can be:

            - A single instance of :class:`~texar.torch.data.DatasetBase`.
            - A `dict` that maps dataset name to instances of
              :class:`~texar.torch.data.DatasetBase`.
            - A `list` of instances of :class:`texar.torch.data.DatasetBase`.
              The name of instances (:attr:`texar.torch.data.DatasetBase.name`)
              must be unique.

        batching_strategy: The batching strategy to use when performing dynamic
            batching. If `None`, fixed-sized batching is used.
        pin_memory: If `True`, tensors will be moved onto page-locked memory
            before returning. This argument is passed into the constructor for
            :torch_docs:`DataLoader <data.html#torch.utils.data.DataLoader>`.

            Defaults to `None`, which will set the value to `True` if the
            :class:`~texar.torch.data.DatasetBase` instance is set to use a CUDA
            device. Set to `True` or `False` to override this behavior.

    Example:

        Create an iterator over two datasets and generating fixed-sized batches:

        .. code-block:: python

            train_data = MonoTextData(hparams_train)
            test_data = MonoTextData(hparams_test)
            iterator = DataIterator({'train': train_data, 'test': test_data})

            for epoch in range(200): # Run 200 epochs of train/test
                # Starts iterating through training data from the beginning.
                iterator.switch_to_dataset('train')
                for batch in iterator:
                    ... # Do training with the batch.

                # Starts iterating through test data from the beginning
                for batch in iterator.get_iterator('test'):
                    ... # Do testing with the batch.

        Dynamic batching based on total number of tokens:

        .. code-block:: python

            iterator = DataIterator(
                {'train': train_data, 'test': test_data},
                batching_strategy=TokenCountBatchingStrategy(max_tokens=1000))

        Dynamic batching with custom strategy (e.g. total number of tokens in
        examples from :class:`~texar.torch.data.PairedTextData`, including
        padding):

        .. code-block:: python

            class CustomBatchingStrategy(BatchingStrategy):
                def __init__(self, max_tokens: int):
                    self.max_tokens = max_tokens
                    self.reset_batch()

                def reset_batch(self) -> None:
                    self.max_src_len = 0
                    self.max_tgt_len = 0
                    self.cur_batch_size = 0

                def add_example(self, ex: Tuple[List[str], List[str]]) -> bool:
                    max_src_len = max(self.max_src_len, len(ex[0]))
                    max_tgt_len = max(self.max_tgt_len, len(ex[1]))
                    if ((max_src_len + max_tgt_len) *
                            (self.cur_batch_size + 1) > self.max_tokens):
                        return False
                    self.max_src_len = max_src_len
                    self.max_tgt_len = max_tgt_len
                    self.cur_batch_size += 1
                    return True

            iterator = DataIterator(
                {'train': train_data, 'test': test_data},
                batching_strategy=CustomBatchingStrategy(max_tokens=1000))
    """

    # TODO: Think about whether we should support save/load.

    def __init__(self, datasets: DatasetsType,
                 batching_strategy: Optional[BatchingStrategy] = None,
                 pin_memory: Optional[bool] = None):
        # Name under which a single, unnamed dataset is registered.
        self._default_dataset_name = 'data'
        if isinstance(datasets, DatasetBase):
            datasets = {self._default_dataset_name: datasets}
        elif isinstance(datasets, Sequence):
            if any(not isinstance(d, DatasetBase) for d in datasets):
                raise ValueError("`datasets` must be an non-empty list of "
                                 "`texar.torch.data.DatasetBase` instances.")
            num_datasets = len(datasets)
            datasets = {d.name: d for d in datasets}
            # Duplicate names collapse into one key, shrinking the dict.
            if len(datasets) < num_datasets:
                raise ValueError("Names of datasets must be unique.")

        # Wrap every dataset in its own single-dataset iterator.
        _datasets = {
            name: SingleDatasetIterator(dataset, batching_strategy, pin_memory)
            for name, dataset in datasets.items()}
        self._datasets = _datasets

        if len(self._datasets) <= 0:
            raise ValueError("`datasets` must not be empty.")

        # Set by `switch_to_dataset`; `None` means no dataset selected yet.
        self._current_dataset_name: Optional[str] = None

    @property
    def num_datasets(self) -> int:
        r"""Number of datasets.
        """
        return len(self._datasets)

    @property
    def dataset_names(self) -> List[str]:
        r"""A list of dataset names.
        """
        return list(self._datasets.keys())

    def _validate_dataset_name(self, dataset_name: Optional[str]) -> str:
        r"""Validate the provided dataset name, and return the validated name.

        If :attr:`dataset_name` is `None`, the name of the only dataset is
        returned.

        Raises:
            ValueError: If :attr:`dataset_name` is `None` while there is more
                than one dataset, or if the name is not a known dataset.
        """
        if dataset_name is None:
            if self.num_datasets > 1:
                raise ValueError("`dataset_name` is required if there are "
                                 "more than one datasets.")
            dataset_name = next(iter(self._datasets))
        if dataset_name not in self._datasets:
            # A single formatted message; passing the name as a second
            # exception argument would render the message as a tuple.
            raise ValueError(f"Dataset not found: {dataset_name}")
        return dataset_name

    def switch_to_dataset(self, dataset_name: Optional[str] = None):
        r"""Select the dataset that subsequent iteration over this object
        (e.g. ``for batch in iterator``) will use. Each new iteration starts
        from the beginning of the selected dataset.

        Args:
            dataset_name (optional): Name of the dataset. If not provided,
                there must be only one Dataset.
        """
        self._current_dataset_name = self._validate_dataset_name(dataset_name)

    def get_iterator(self,
                     dataset_name: Optional[str] = None) -> Iterator[Batch]:
        r"""Return a fresh iterator over a dataset, starting from the
        beginning. This does not change the currently selected dataset.

        Args:
            dataset_name (optional): Name of the dataset. If not provided,
                the dataset selected by :meth:`switch_to_dataset` is used; if
                no dataset has been selected, there must be only one Dataset.
        """
        if dataset_name is None and self._current_dataset_name is not None:
            # Fall back to the dataset chosen through `switch_to_dataset`.
            dataset_name = self._current_dataset_name
        else:
            dataset_name = self._validate_dataset_name(dataset_name)

        return iter(self._datasets[dataset_name])

    def __iter__(self) -> Iterator[Batch]:
        r"""Returns the iterator for the currently selected or default dataset.
        """
        return self.get_iterator()

    def __len__(self):
        r"""Length of the iterator of the currently selected dataset, or of
        the only dataset if none has been selected.
        """
        return len(self._datasets[
            self._validate_dataset_name(self._current_dataset_name)
        ])
class TrainTestDataIterator(DataIterator):
    r"""Data iterator that alternates between training, validation, and test
    datasets.

    :attr:`train`, :attr:`val`, and :attr:`test` are instances of
    :class:`~texar.torch.data.DatasetBase`. At least one of them must be
    provided.

    This is a wrapper of :class:`~texar.torch.data.DataIterator`.

    Args:
        train (optional): Training data.
        val (optional): Validation data.
        test (optional): Test data.
        batching_strategy: The batching strategy to use when performing dynamic
            batching. If `None`, fixed-sized batching is used.
        pin_memory: If `True`, tensors will be moved onto page-locked memory
            before returning. This argument is passed into the constructor for
            :torch_docs:`DataLoader <data.html#torch.utils.data.DataLoader>`.

            Defaults to `None`, which will set the value to `True` if the
            :class:`~texar.torch.data.DatasetBase` instance is set to use a CUDA
            device. Set to `True` or `False` to override this behavior.

    Example:

        .. code-block:: python

            train_data = MonoTextData(hparams_train)
            val_data = MonoTextData(hparams_val)
            iterator = TrainTestDataIterator(train=train_data, val=val_data)

            for epoch in range(200): # Run 200 epochs of train/val
                # Starts iterating through training data from the beginning.
                iterator.switch_to_train_data()
                for batch in iterator:
                    ... # Do training with the batch.

                # Starts iterating through val data from the beginning.
                for batch in iterator.get_val_iterator():
                    ... # Do validation on the batch.
    """

    def __init__(self, train: Optional[DatasetBase] = None,
                 val: Optional[DatasetBase] = None,
                 test: Optional[DatasetBase] = None,
                 batching_strategy: Optional[BatchingStrategy] = None,
                 pin_memory: Optional[bool] = None):
        # Fixed names under which the three splits are registered.
        self._train_name = 'train'
        self._val_name = 'val'
        self._test_name = 'test'

        # Register only the splits that were actually provided.
        candidates = (
            (self._train_name, train),
            (self._val_name, val),
            (self._test_name, test),
        )
        dataset_dict = {
            name: data for name, data in candidates if data is not None}
        if not dataset_dict:
            raise ValueError("At least one of `train`, `val`, and `test` "
                             "must be provided.")
        super().__init__(dataset_dict, batching_strategy, pin_memory)

    def _require_dataset(self, name: str, label: str) -> str:
        r"""Return :attr:`name` if a dataset is registered under it, and raise
        `ValueError` otherwise. :attr:`label` is the word used to describe the
        split in the error message.
        """
        if name in self._datasets:
            return name
        raise ValueError(f"{label} data not provided.")

    def switch_to_train_data(self) -> None:
        r"""Switch to training data."""
        self.switch_to_dataset(
            self._require_dataset(self._train_name, "Training"))

    def switch_to_val_data(self) -> None:
        r"""Switch to validation data."""
        self.switch_to_dataset(
            self._require_dataset(self._val_name, "Validation"))

    def switch_to_test_data(self) -> None:
        r"""Switch to test data."""
        self.switch_to_dataset(
            self._require_dataset(self._test_name, "Test"))

    def get_train_iterator(self) -> Iterable[Batch]:
        r"""Obtain an iterator over training data."""
        return self.get_iterator(
            self._require_dataset(self._train_name, "Training"))

    def get_val_iterator(self) -> Iterable[Batch]:
        r"""Obtain an iterator over validation data."""
        return self.get_iterator(
            self._require_dataset(self._val_name, "Validation"))

    def get_test_iterator(self) -> Iterable[Batch]:
        r"""Obtain an iterator over test data."""
        return self.get_iterator(
            self._require_dataset(self._test_name, "Test"))