Source code for texar.torch.data.data.scalar_data

# Copyright 2019 The Texar Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Various data classes that define data reading, parsing, batching, and other
preprocessing operations.
"""
from typing import (List, Optional, Union)

from distutils.util import strtobool
import numpy as np
import torch

from texar.torch.data.data.data_base import DatasetBase, DataSource
from texar.torch.data.data.dataset_utils import Batch
from texar.torch.data.data.text_data_base import TextLineDataSource
from texar.torch.hyperparams import HParams
from texar.torch.utils.dtypes import get_numpy_dtype, \
    get_supported_scalar_types, torch_bool


__all__ = [
    "_default_scalar_dataset_hparams",
    "ScalarData"
]


def _default_scalar_dataset_hparams():
    """Returns hyperparameters of a scalar dataset with default values.

    See :meth:`texar.torch.data.ScalarData.default_hparams` for details.
    """
    return {
        "files": [],
        "compression_type": None,
        "data_type": "int",
        "data_name": "data",
        "other_transformations": [],
        "@no_typecheck": ["files"]
    }


[docs]class ScalarData(DatasetBase[List[str], Union[int, float]]): r"""Scalar data where each line of the files is a scalar (int or float), e.g., a data label. Args: hparams (dict): Hyperparameters. See :meth:`default_hparams` for the defaults. device: The device of the produced batches. For GPU training, set to current CUDA device. The processor reads and processes raw data and results in a dataset whose element is a python `dict` including one field. The field name is specified in :attr:`hparams["dataset"]["data_name"]`. If not specified, the default name is `"data"`. The field name can be accessed through :attr:`data_name`. This field is a Tensor of shape `[batch_size]` containing a batch of scalars, of either int or float type as specified in :attr:`hparams`. Example: .. code-block:: python hparams={ 'dataset': { 'files': 'data.txt', 'data_name': 'label' }, 'batch_size': 2 } data = ScalarData(hparams) iterator = DataIterator(data) for batch in iterator: # batch contains the following # batch == { # 'label': [2, 9] # } """ def __init__(self, hparams, device: Optional[torch.device] = None, data_source: Optional[DataSource] = None): self._hparams = HParams(hparams, self.default_hparams()) self._other_transforms = self._hparams.dataset.other_transformations data_type = self._hparams.dataset["data_type"] if data_type not in get_supported_scalar_types(): raise ValueError(f"Unsupported data type '{data_type}'") # In Pytorch versions < 1.1.0, "torch.uint8" is treated as "bool" type # hence we set self.data_type = np.uint8 here if data_type == "bool": self._data_type = get_numpy_dtype(str(torch_bool)) else: self._data_type = get_numpy_dtype(data_type) if data_source is None: data_source = TextLineDataSource( self._hparams.dataset.files, compression_type=self._hparams.dataset.compression_type) super().__init__(data_source, hparams, device=device)
[docs] @staticmethod def default_hparams(): r"""Returns a dictionary of default hyperparameters. .. code-block:: python { # (1) Hyperparams specific to scalar dataset "dataset": { "files": [], "compression_type": None, "data_type": "int", "other_transformations": [], "data_name": "data", } # (2) General hyperparams "num_epochs": 1, "batch_size": 64, "allow_smaller_final_batch": True, "shuffle": True, "shuffle_buffer_size": None, "shard_and_shuffle": False, "num_parallel_calls": 1, "prefetch_buffer_size": 0, "max_dataset_size": -1, "seed": None, "name": "scalar_data", } Here: 1. For the hyperparameters in the :attr:`"dataset"` field: `"files"`: str or list A (list of) file path(s). Each line contains a single scalar number. `"compression_type"`: str, optional One of "" (no compression), "ZLIB", or "GZIP". `"data_type"`: str The scalar type. Types defined in :meth:`~texar.torch.utils.dtypes.get_supported_scalar_types` are supported. `"other_transformations"`: list A list of transformation functions or function names/paths to further transform each single data instance. (More documentations to be added.) `"data_name"`: str Name of the dataset. 2. For the **general** hyperparameters, see :meth:`texar.torch.data.DatasetBase.default_hparams` for details. """ hparams = DatasetBase.default_hparams() hparams["name"] = "scalar_data" hparams.update({ "dataset": _default_scalar_dataset_hparams() }) return hparams
def process(self, raw_example: List[str]) -> Union[bool, int, float]: assert len(raw_example) == 1 example_: Union[int, str] if self._data_type == np.bool_: example_ = strtobool(raw_example[0]) else: example_ = raw_example[0] example = self._data_type(example_) for transform in self._other_transforms: example = transform(example) return example def collate(self, examples: List[Union[bool, int, float]]) -> Batch: # convert the list of strings into appropriate tensors here examples_np = np.array(examples, dtype=self._data_type) collated_examples = torch.from_numpy(examples_np) return Batch(len(examples), batch={self.data_name: collated_examples})
[docs] def list_items(self): r"""Returns the list of item names that the data can produce. Returns: A list of strings. """ return [self.hparams.dataset["data_name"]]
@property def data_name(self): r"""The name of the data tensor, "data" by default if not specified in :attr:`hparams`. """ return self.hparams.dataset["data_name"]