
# Copyright 2019 The Texar Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Various sampler classes.
"""

# pylint: disable=protected-access

from typing import (
    Any, Callable, Generic, Iterator, List, Optional, Tuple, TypeVar, Union)

import torch
from torch.utils.data import sampler as torch_sampler

from texar.torch.data.data.data_base import DatasetBase


__all__ = [
    "BatchingStrategy",
    "TokenCountBatchingStrategy",
]

Example = TypeVar('Example')


# pylint: disable=attribute-defined-outside-init
# TODO: Remove this when Pylint fixes the bug. If the `disable` directive is not
#  added, Pylint incorrectly reports this error for `self.size` in subclasses of
#  `SamplerBase` in Python 3.6 due to use of the Generic class.
#  See Pylint issue: https://github.com/PyCQA/pylint/issues/2981

class SamplerBase(torch_sampler.Sampler, Generic[Example]):
    r"""A subclass of :torch_docs:`~torch.utils.data.Sampler
    <data.html#torch.utils.data.Sampler>` that supports:

    - Returning raw examples when required.
    - Creating iterators with unknown dataset size.

    This class is used internally in
    :class:`~texar.torch.data.data.DataIterator`. It calls the
    :meth:`~texar.torch.data.data.DatasetBase._prefetch_source` method to ensure
    the required number of raw examples are prefetched from source.

    Args:
        data: The :class:`~texar.torch.data.data.DatasetBase` instance.
    """
    size: Optional[int]

    def __init__(self, data: DatasetBase[Any, Example]):
        super().__init__(data)

        self._data = data
        self.size = None

    def _iterator_given_size(self, size: int) -> Iterator[int]:
        r"""Return an iterator that generates samples when the dataset size
        is given.

        Args:
            size: The dataset size.
        """
        raise NotImplementedError

    def _iterator_unknown_size(self) -> Iterator[int]:
        r"""Return an iterator that generates samples when the dataset size
        is unknown. This iterator must also call
        :meth:`texar.torch.data.data.DatasetBase._prefetch_source` and check
        whether the dataset size can be determined, before yielding the index.
        See example implementations for details.
        """
        raise NotImplementedError

    def __iter__(self) -> Union[Iterator[int], Iterator[Tuple[int, Example]]]:
        r"""Return an iterator based on the dataset settings.
        """
        self.size = self._data._dataset_size
        if (not self._data._fully_cached or
                self._data._should_call_prefetch_source):
            self._data._start_iteration()
            # First epoch of lazy loading, calling prefetch, and returning
            # indices and examples.
            iterator = self._iterator_unknown_size()
        else:
            # Non-lazy loading, or when dataset has been fully iterated.
            assert self.size is not None
            iterator = self._iterator_given_size(self.size)

        if self._data._should_call_prefetch_processed:
            # Processing routine is performed in main process. Yield
            # processed examples instead.
            map_fn = lambda idx: (idx, self._data._processed_cache[idx])
        elif self._data._should_yield_raw_example:
            # Return indices and examples for any epoch in this case.
            map_fn = lambda idx: (idx, self._data._source[idx])
        else:
            map_fn = None  # type: ignore
        if map_fn is not None:
            return map(map_fn, iterator)

        return iterator

    def __len__(self):
        if self.size is not None:
            return self.size
        raise AttributeError("Dataset size cannot be determined at this point")


class SequentialSampler(SamplerBase[Example]):
    r"""Samples elements sequentially, always in the same order. Same as
    :torch_docs:`~torch.utils.data.SequentialSampler
    <data.html#torch.utils.data.SequentialSampler>`
    """

    def _iterator_given_size(self, size: int) -> Iterator[int]:
        return iter(range(size))

    def _iterator_unknown_size(self) -> Iterator[int]:
        index = 0
        while True:
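            # Prefetch one example at a time; `_prefetch_source` returns the
            # dataset size once it can be determined, at which point iteration
            # stops.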
            cur_size = self._data._prefetch_source(index)
            if cur_size is not None:
                self.size = cur_size
                break
            yield index
            index += 1


class RandomSampler(SamplerBase[Example]):
    r"""Samples elements randomly. If without replacement, then sample from a
    shuffled dataset. If with replacement, then user can specify ``num_samples``
    to draw.

    This class uses :torch_docs:`torch.utils.data.RandomSampler
    <data.html#torch.utils.data.RandomSampler>` directly. Given the
    nature of such shuffling, it cannot be used for iterators with unknown size.

    Args:
        data: The :class:`~texar.torch.data.data.DatasetBase` instance.
        num_samples (int): number of samples to draw, default=len(dataset)
        replacement (bool): samples are drawn with replacement if `True`,
            default=False
    """

    def __init__(self, data: DatasetBase[Any, Example],
                 replacement: bool = False, num_samples: Optional[int] = None):
        super().__init__(data)
        self._sampler = torch_sampler.RandomSampler(
            data, replacement, num_samples)

    def _iterator_given_size(self, size: int) -> Iterator[int]:
        del size  # not used
        return iter(self._sampler)

    def _iterator_unknown_size(self) -> Iterator[int]:
        raise TypeError(
            "RandomSampler does not support lazy data loading. To perform "
            "shuffling with lazy loading, use BufferShuffleSampler.")


class BufferShuffleSampler(SamplerBase[Example]):
    r"""A :torch_docs:`~torch.utils.data.Sampler
    <data.html#torch.utils.data.Sampler>` that uses a shuffle buffer, as
    in TensorFlow. The buffer is first filled with data examples. Each time, a
    sample is drawn from the buffer, and the drawn sample is replaced with the
    next data example.

    This class is used internally in
    :class:`~texar.torch.data.data.DataIterator`. It calls the
    :meth:`~texar.torch.data.data.DatasetBase._prefetch_source` method to ensure
    the required number of raw examples are prefetched from source.

    Args:
        data: The :class:`~texar.torch.data.data.DatasetBase` instance.
        buffer_size: The size of the shuffle buffer. Use larger buffer sizes for
            more uniformly-random shuffling.
    """

    def __init__(self, data: DatasetBase[Any, Example], buffer_size: int):
        super().__init__(data)
        self.buffer_size = buffer_size

    def _iterator_given_size(self, size: int) -> Iterator[int]:
        if self.buffer_size >= size:
            yield from iter(torch.randperm(size).tolist())
            return

        buffer = list(range(self.buffer_size))
        for x in range(self.buffer_size, size):
            sample = torch.randint(self.buffer_size, (1,)).item()
            index = buffer[sample]
            yield index
            buffer[sample] = x
        yield from (buffer[x] for x in torch.randperm(self.buffer_size))

    def _iterator_unknown_size(self) -> Iterator[int]:
        buffer = list(range(self.buffer_size))
        x = self.buffer_size
        while True:
            sample = torch.randint(self.buffer_size, (1,)).item()
            index = buffer[sample]
            cur_size = self._data._prefetch_source(index)
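            # Prefetching may determine the dataset size; once the drawn index
            # lies beyond the end of the data, stop filling and drain the
            # buffer.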
            if cur_size is not None and index >= cur_size:
                self.size = cur_size
            if self.size is not None and index >= self.size:
                break
            yield index
            buffer[sample] = x
            x += 1
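        # Drain the remaining buffer in random order, skipping any indices
        # that fall beyond the end of the dataset.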
        yield from (buffer[x] for x in torch.randperm(self.buffer_size)
                    if buffer[x] < self.size)
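

# To illustrate the shuffle-buffer mechanism in isolation, the hypothetical
# helper below (illustration only, not part of the library) applies the same
# idea to a plain iterable: fill a fixed-size buffer, repeatedly yield a random
# slot and refill it with the next item, then drain what remains in random
# order.
def _buffer_shuffle_sketch(iterable, buffer_size: int) -> Iterator[Any]:
    r"""Yield items of ``iterable`` in shuffle-buffer order (illustrative)."""
    iterator = iter(iterable)
    buffer = []
    # Fill the buffer with the first `buffer_size` items.
    for item in iterator:
        buffer.append(item)
        if len(buffer) == buffer_size:
            break
    # Draw a random slot, yield its content, and refill it with the next item.
    for item in iterator:
        slot = torch.randint(len(buffer), (1,)).item()
        yield buffer[slot]
        buffer[slot] = item
    # Drain the remaining buffer in random order.
    yield from (buffer[i] for i in torch.randperm(len(buffer)))
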


# pylint: enable=attribute-defined-outside-init


class BatchingStrategy(Generic[Example]):
    r"""Decides batch boundaries in dynamic batching. Please refer to
    :class:`TokenCountBatchingStrategy` for a concrete example.
    """

    def reset_batch(self) -> None:
        r"""Reset the internal state of the batching strategy. This method is
        called at the start of iteration, and after each batch is yielded.
        """
        raise NotImplementedError

    def add_example(self, example: Example) -> bool:
        r"""Add an example into the current batch, and modify internal states
        accordingly. If the example should not be added to the batch, this
        method does not modify the internal state, and returns `False`.

        Args:
            example: The example to add to the batch.

        Returns:
            A boolean value indicating whether :attr:`example` should be added
            to the batch.
        """
        raise NotImplementedError
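

# As a minimal sketch of the protocol above (illustration only, not part of
# the library), a strategy that simply caps the number of examples per batch
# could look like the hypothetical class below.
class _FixedSizeBatchingStrategy(BatchingStrategy[Example]):
    r"""Batches at most ``batch_size`` examples, regardless of their length."""

    def __init__(self, batch_size: int):
        self.batch_size = batch_size
        self.count = 0

    def reset_batch(self) -> None:
        # Called before the first batch and after each batch is yielded.
        self.count = 0

    def add_example(self, example: Example) -> bool:
        del example  # not used; every example counts the same
        # Refuse once the batch is full; `DynamicBatchSampler` then yields the
        # current batch, calls `reset_batch`, and retries this same example.
        if self.count >= self.batch_size:
            return False
        self.count += 1
        return True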


class TokenCountBatchingStrategy(BatchingStrategy[Example]):
    r"""Create dynamically-sized batches so that the total number of tokens
    inside each batch is constrained.

    Args:
        max_tokens (int): The maximum number of tokens inside each batch.
        max_batch_size (int, optional): The maximum number of examples for each
            batch. If `None`, batches can contain arbitrary number of examples
            as long as the total number of tokens does not exceed
            :attr:`max_tokens`.
        length_fn (callable, optional): A function taking a data example as
            argument, and returning the number of tokens in the example. By
            default, :python:`len` is used, which is the desired behavior if
            the dataset in question is a
            :class:`~texar.torch.data.MonoTextData`.
    """

    sum_tokens: int
    cur_batch_size: int

    def __init__(self, max_tokens: int, max_batch_size: Optional[int] = None,
                 length_fn: Optional[Callable[[Example], int]] = None):
        self.max_batch_size = max_batch_size
        self.max_tokens = max_tokens
        self.length_fn: Callable[[Example], int]
        self.length_fn = length_fn or len  # type: ignore

    def reset_batch(self) -> None:
        self.sum_tokens = 0
        self.cur_batch_size = 0

    def add_example(self, example: Example) -> bool:
        if self.cur_batch_size == self.max_batch_size:
            return False
        cur_tokens = self.length_fn(example)
        if cur_tokens + self.sum_tokens > self.max_tokens:
            return False
        self.cur_batch_size += 1
        self.sum_tokens += cur_tokens
        return True
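

# A minimal usage sketch (assuming the standard Texar-PyTorch ``DataIterator``
# interface, which accepts a ``batching_strategy`` argument); batches drawn
# from the iterator then contain at most 1024 tokens each:
#
#     from texar.torch.data import DataIterator, MonoTextData
#
#     data = MonoTextData(hparams)  # `hparams` defined elsewhere
#     strategy = TokenCountBatchingStrategy(max_tokens=1024)
#     iterator = DataIterator(data, batching_strategy=strategy)
#     for batch in iterator:
#         ...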


class DynamicBatchSampler(torch_sampler.BatchSampler, Generic[Example]):
    r"""A subclass of :torch_docs:`~torch.utils.data.BatchSampler
    <data.html#torch.utils.data.BatchSampler>` that supports dynamic batching
    through a user-provided :class:`BatchingStrategy`. This class is used
    internally.

    Args:
        dataset: The dataset to create batches from.
        sampler: An instance of :class:`SamplerBase` that returns indices of
            each sampled example.
        strategy: An instance of :class:`BatchingStrategy` that decides whether
            a batch should be yielded.
    """

    def __init__(self, dataset: DatasetBase[Any, Example],  # pylint: disable=super-init-not-called
                 sampler: SamplerBase, strategy: BatchingStrategy[Example]):
        self.dataset = dataset
        self.sampler = sampler
        self.strategy = strategy

    def __iter__(self) -> Union[Iterator[List[int]],  # type: ignore
                                Iterator[List[Tuple[int, Example]]]]:
        batch = []  # type: ignore
        self.strategy.reset_batch()
        for idx in self.sampler:
            if isinstance(idx, tuple):
                example = self.dataset[idx[0]]
            else:
                example = self.dataset[idx]
            while not self.strategy.add_example(example):
                if len(batch) == 0:
                    raise ValueError(f"Batching strategy refused to add "
                                     f"example {idx} to empty batch.")
                yield batch
                batch = []
                self.strategy.reset_batch()
            batch.append(idx)
        if len(batch) > 0:
            yield batch
            self.strategy.reset_batch()

    def __len__(self):
        raise TypeError("DynamicBatchSampler does not support __len__")