Source code for

# Copyright 2019 The Texar Authors. All Rights Reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# See the License for the specific language governing permissions and
# limitations under the License.
Data class that supports reading pickled data as record structures.
import copy
import io
import pickle
import warnings
from enum import Enum
from typing import (
    Any, Callable, Dict, List, NamedTuple, Optional, Tuple, TypeVar, Union)

import numpy as np
import torch

from import DatasetBase, DataSource
from import Batch, padded_batch
from texar.torch.hyperparams import HParams
from texar.torch.utils.dtypes import get_numpy_dtype
from texar.torch.utils.types import MaybeList

__all__ = [

def _default_record_dataset_hparams():
    r"""Returns hyperparameters of a record dataset with default values.

    See :meth:`` for details.
    return {
        "files": [],
        "feature_types": None,
        "feature_original_types": None,
        "feature_convert_types": {},
        "image_options": {},
        "compression_type": None,
        "other_transformations": [],
        "num_shards": None,
        "shard_id": None,
        "data_name": None,
        "@no_typecheck": [

RawExample = TypeVar('RawExample')

[docs]class PickleDataSource(DataSource[RawExample]): r"""Data source for reading from (multiple) pickled binary files. Each file could contain multiple pickled objects, and each object is yielded as an example. This data source does not support indexing. Args: file_paths (str or list[str]): Paths to pickled binary files. lists_are_examples (bool): If `True`, lists will be treated as a single example; if `False`, each element in the list will be treated as separate examples. Default is `True`. Set this to `False` if the entire pickled binary file is a list. .. note:: It is recommended against storing all examples as a list, because in this case, all examples can only be accessed after the whole list is parsed. pickle_kwargs: Additional keyword arguments to pass to :meth:`pickle.load`. """ def __init__(self, file_paths: MaybeList[str], lists_are_examples: bool = True, **pickle_kwargs): if isinstance(file_paths, str): file_paths = [file_paths] self._file_paths = file_paths self._lists_are_examples = lists_are_examples self._pickle_kwargs = pickle_kwargs def __iter__(self): for path in self._file_paths: with open(path, 'rb') as f: if self._lists_are_examples: while True: try: ex = pickle.load(f, **self._pickle_kwargs) if isinstance(ex, list): yield from ex else: yield ex except EOFError: break else: while True: try: yield pickle.load(f, **self._pickle_kwargs) except EOFError: break
TransformFn = Callable[[bytes], torch.ByteTensor] def _create_image_transform(height: Optional[int], width: Optional[int], resize_method: Union[str, int] = 'bilinear') \ -> TransformFn: r"""Create a function based on `Pillow image transforms <>` that performs resizing with desired resize method (interpolation). Args: height (int, optional): Height of the transformed image. Set to `None` to not perform resizing. width (int, optional): Width of the transformed image. Set to `None` to not perform resizing. resize_method (str or int, optional): Interpolation method to use. Supported values are ``"nearest"`` (nearest neighbor), ``"bilinear"``, ``"bicubic"``, and ``"lanczos"``. Enum values from PIL (e.g., ``PIL.Image.BILINEAR``) are also supported. Defaults to ``"bilinear"``. Returns: The created transformation function. """ try: import PIL.Image except ImportError: raise ImportError( "To use image resizing with RecordData, the Pillow library must be " "installed. Please see " "") # We take the final part of a possibly dot-separated string for # compatibility reasons, because in Texar-TF `resize_method` could take the # form of "tf.image.ResizeMethod.BILINEAR". if isinstance(resize_method, int): interpolation = resize_method else: method = resize_method.lower().split('.')[-1] if method in ["nearest_neighbor", "nearest"]: interpolation = PIL.Image.NEAREST elif method == "bilinear": interpolation = PIL.Image.BILINEAR elif method == "bicubic": interpolation = PIL.Image.BICUBIC elif method == "lanczos": interpolation = PIL.Image.LANCZOS else: raise ValueError(f"Unsupported resize method '{resize_method}'") if height is None or width is None: size = None else: size = (height, width) def transform(raw_bytes): image = if size is not None: image = image.resize(size, interpolation) # Convert to torch Tensor. Adapted from # torchvision.transform.functional.to_tensor. if image.mode == '1': tensor = 255 * torch.from_numpy( np.array(image, np.uint8, copy=False)) else: tensor = torch.ByteTensor( torch.ByteStorage.from_buffer(image.tobytes())) # PIL image mode: L, LA, P, I, F, RGB, YCbCr, RGBA, CMYK. if image.mode == 'YCbCr': n_channel = 3 elif image.mode == 'I;16': n_channel = 1 else: n_channel = len(image.mode) tensor = tensor.view(image.size[1], image.size[0], n_channel) return tensor return transform class CollateMethod(Enum): StackedTensor = "stacked_tensor" PaddedTensor = "padded_tensor" List = "list" class FeatureDescription(NamedTuple): r"""Description of a feature.""" collate_method: CollateMethod dtype: Optional[np.dtype] shape: Optional[Tuple[int, ...]] _DEPRECATED_NAME_MAPPING = { "FixedLenFeature": "stacked_tensor", "FixedLenSequenceFeature": "padded_tensor", "VarLenFeature": "list", } def _convert_feature_hparams(feature_types: Union[Dict[str, Any], HParams]) \ -> Dict[str, FeatureDescription]: if isinstance(feature_types, HParams): feature_types = feature_types.todict() else: feature_types = copy.deepcopy(feature_types) show_deprecation_warning = False for key, value in feature_types.items(): if len(value) > 1 and value[1] in _DEPRECATED_NAME_MAPPING: feature_types[key] = ( value[0], _DEPRECATED_NAME_MAPPING[value[1]], *value[2:]) show_deprecation_warning = True if show_deprecation_warning: warnings.warn(f"RecordData feature types " f"{', '.join(repr(x) for x in _DEPRECATED_NAME_MAPPING)} " f"are deprecated. Please see RecordData.default_hparams " f"for update instructions.", UserWarning) features = {} for key, value in feature_types.items(): shape: Optional[Tuple[int, ...]] = None if len(value) == 3: if isinstance(value[-1], int): shape = (value[-1],) elif all(isinstance(x, int) for x in value[-1]): shape = tuple(value[-1]) if len(shape) == 0: shape = (1,) # scalar tensor else: raise ValueError(f"'shape' of feature '{key}' is not of type " f"int, tuple, or torch.Size") if len(value) < 2: collate_method = CollateMethod.StackedTensor else: try: collate_method = CollateMethod(value[1]) except ValueError: values = [x.value for x in CollateMethod.__members__.values()] raise ValueError( f"Unsupported feature collate method '{value[1]}' for " f"feature '{key}', only " f"{', '.join(repr(x) for x in values)} are " f"supported as of now.") dtype = None if value[0] is not None: dtype = get_numpy_dtype(value[0]) elif collate_method is not CollateMethod.List: raise ValueError(f"'dtype' for feature '{key}' must not be None " f"unless collate method is 'list'") features[key] = FeatureDescription(collate_method, dtype, shape) return features def _check_shape(tensor: np.ndarray, key: str, descriptor: FeatureDescription): if descriptor.shape is None: return # Check whether shape matches. if descriptor.collate_method is CollateMethod.PaddedTensor: shape = tensor.shape[1:] else: shape = tensor.shape if len(shape) == 0: shape = (1,) # scalar tensor if shape != descriptor.shape: if descriptor.collate_method is CollateMethod.PaddedTensor: raise ValueError( f"Expected tensor of shape {('any', *descriptor.shape)} for " f"feature {key}, but received tensor of shape {tensor.shape}") else: raise ValueError( f"Expected tensor of shape {descriptor.shape} for " f"feature {key}, but received tensor of shape {tensor.shape}")
[docs]class RecordData(DatasetBase[Dict[str, Any], Dict[str, Any]]): r"""Record data which loads and processes pickled files. This module can be used to process image data, features, etc. Args: hparams (dict): Hyperparameters. See :meth:`default_hparams` for the defaults. device: The device of the produced batches. For GPU training, set to current CUDA device. The module reads and restores data from pickled files and results in a dataset whose element is a Python `dict` that maps feature names to feature values. The features names and dtypes are specified in :attr:`hparams.dataset.feature_types`. The module also provides simple processing options for image data, such as image resize. Example: .. code-block:: python # Read data from pickled file hparams={ 'dataset': { 'files': 'image1.pkl', 'feature_types': { 'height': ['int64', 'list'], # or 'stacked_tensor' 'width': ['int64', 'list'], # or 'stacked_tensor' 'label': ['int64', 'stacked_tensor'], 'image_raw': ['bytes', 'stacked_tensor'], } }, 'batch_size': 1 } data = RecordData(hparams) iterator = DataIterator(data) batch = next(iter(iterator)) # get the first batch in dataset # batch == { # 'data': { # 'height': [239], # 'width': [149], # 'label': tensor([1]), # # # 'image_raw' is a NumPy ndarray of raw image bytes in this # # example. # 'image_raw': [...], # } # } .. code-block:: python # Read image data from pickled file and do resizing hparams={ 'dataset': { 'files': 'image2.pkl', 'feature_types': { 'label': ['int64', 'stacked_tensor'], 'image_raw': ['bytes', 'stacked_tensor'], }, 'image_options': { 'image_feature_name': 'image_raw', 'resize_height': 512, 'resize_width': 512, } }, 'batch_size': 1 } data = RecordData(hparams) iterator = DataIterator(data) batch = next(iter(iterator)) # get the first batch in dataset # batch == { # 'data': { # 'label': tensor([1]), # # # "image_raw" is a tensor of image pixel data in this # # example. Each image has a width of 512 and height of 512. # 'image_raw': tensor([...]) # } # } """ def __init__(self, hparams=None, device: Optional[torch.device] = None, data_source: Optional[DataSource] = None): self._hparams = HParams(hparams, self.default_hparams()) feature_types = self._hparams.dataset.feature_original_types if feature_types is not None: warnings.warn( "'feature_original_types' of RecordData is deprecated. Please " "see default_hparams of RecordData for update instructions") if self._hparams.dataset.feature_types is not None: feature_types = self._hparams.dataset.feature_types elif feature_types is None: raise ValueError("'feature_types' must be specified") self._features = _convert_feature_hparams(feature_types) convert_types = self._hparams.dataset.feature_convert_types self._convert_types = {key: get_numpy_dtype(value) for key, value in convert_types.items()} for key, dtype in self._convert_types.items(): self._features[key] = self._features[key]._replace(dtype=dtype) image_options = self._hparams.dataset.image_options if isinstance(image_options, HParams): image_options = [image_options] self._image_transforms: Dict[str, TransformFn] = {} for options in image_options: key = options.get('image_feature_name') if key is None or key not in self._features: continue self._image_transforms[key] = _create_image_transform( options.get('resize_height'), options.get('resize_width'), options.get('resize_method') or 'bilinear') self._other_transforms = self._hparams.dataset.other_transformations if data_source is None: data_source = PickleDataSource[Dict[str, Any]]( self._hparams.dataset.files) super().__init__(data_source, hparams, device) class _RecordWriter(io.BytesIO): def __init__(self, file_path: str, features: Dict[str, FeatureDescription]): super().__init__() self._file_path = file_path self._features = features self._file_handle = open(self._file_path, 'wb') def close(self) -> None: self._file_handle.close() def write(self, example: Dict[str, Any]): # type: ignore converted = {} for key, descriptor in self._features.items(): value = example[key] if descriptor.collate_method is CollateMethod.List: converted[key] = value continue # Convert to NumPy array. value = np.asarray(value, dtype=descriptor.dtype) _check_shape(value, key, descriptor) converted[key] = value pickle.dump(converted, self._file_handle)
[docs] @classmethod def writer(cls, file_path: str, feature_types: Dict[str, Tuple[Any, ...]]) \ -> '_RecordWriter': r"""Construct a file writer object that saves records in pickled format. Example: .. code-block:: python file_path = "data/train.pkl" feature_types = { "input_ids": ["int64", "stacked_tensor", 128], "label_ids": ["int64", "stacked_tensor"], } with, feature_types) as writer: writer.write({ "input_ids": np.random.randint(0, 100, size=128), "label_ids": np.random.randint(0, 100), }) Args: file_path (str): Path to save the dataset. feature_types: Feature names and types. Please refer to :meth:`default_hparams` for details. Returns: A file writer object. """ feature_types = _convert_feature_hparams(feature_types) return cls._RecordWriter(file_path, feature_types)
[docs] @staticmethod def default_hparams(): r"""Returns a dictionary of default hyperparameters. .. code-block:: python { # (1) Hyperparameters specific to the record data 'dataset': { 'files': [], 'feature_types': {}, 'feature_convert_types': {}, 'image_options': {}, "num_shards": None, "shard_id": None, "other_transformations": [], "data_name": None, } # (2) General hyperparameters "num_epochs": 1, "batch_size": 64, "allow_smaller_final_batch": True, "shuffle": True, "shuffle_buffer_size": None, "shard_and_shuffle": False, "num_parallel_calls": 1, "prefetch_buffer_size": 0, "max_dataset_size": -1, "seed": None, "name": "tfrecord_data", } Here: 1. For the hyperparameters in the :attr:`"dataset"` field: `"files"`: str or list A (list of) pickled file path(s). `"feature_types"`: dict The feature names (`str`) with their descriptions in the form of ``feature_name: [dtype, feature_collate_method, shape]``: - ``dtype`` is a Python type (``int``, ``str``), dtype instance from PyTorch (``torch.float``), NumPy (``np.int64``), or TensorFlow (``tf.string``), or their stringified names such as ``"torch.float"`` and ``"np.int64"``. The feature will be read from the files and parsed into this dtype. - ``feature_collate_method`` is of type ``str``, and describes how features are collated in the batch. Available values are: - ``"stacked_tensor"``: Features are assumed to be tensors of a fixed shape (or scalars). When collating, features are stacked, with the batch dimension being the first dimension. This is the default value if ``feature_collate_method`` is not specified. For example: - 5 scalar features -> a tensor of shape [5]. - 4 tensor features, each of shape [6, 5] -> a tensor of shape [4, 6, 5]. - ``"padded_tensor"``: Features are assumed to be tensors, with all dimensions except the first having the same size. When collating, features are padded with zero values along the end of the first dimension so that every tensor has the same size, and then stacked, with the batch dimension being the first dimension. For example: - 3 tensor features, with shapes [4, 7, 8], [5, 7, 8], and [4, 7, 8] -> a tensor of shape [3, 5, 7, 8]. - ``"list"``: Features can be any objects. When collating, the features are stored in a Python list. - ``shape`` is optional, and can be of type ``int``, `tuple``, or ``torch.Size``. If specified, shapes of tensor features will be checked, depending on the ``feature_collate_method``: - ``"stacked_tensor"``: The shape of every feature tensor must be ``shape``. - ``"padded_tensor"``: The shape (excluding first dimension) of every feature tensor must be ``shape``. - ``"list"``: ``shape`` is ignored. .. note:: Shape check is performed before any transformations are applied. Example: .. code-block:: python feature_types = { "input_ids": ["int64", "stacked_tensor", 128], "label_ids": ["int64", "stacked_tensor"], "name_lists": ["string", "list"], } .. note:: This field is named `"feature_original_types"` in Texar-TF. This name is still supported, but is deprecated in favor of `"feature_types"`. Texar-TF also uses different names for feature types: - ``"FixedLenFeature"`` corresponds to ``"stacked_tensor"``. - ``"FixedLenSequenceFeature"`` corresponds to ``"padded_tensor"``. - ``"VarLenFeature"`` corresponds to ``"list"``. These names are also accepted in Texar-PyTorch, but are deprecated in favor of the new names. `"feature_convert_types"`: dict, optional Specifies dtype converting after reading the data files. This `dict` maps feature names to desired data dtypes. For example, you can first read a feature into dtype ``torch.int32`` by specifying in :attr:`"feature_types"` above, and convert the feature to dtype ``"torch.long"`` by specifying here. Features not specified here will not do dtype-convert. - ``dtype`` is a Python type (`int`, `str`), dtype instance from PyTorch (``torch.float``), NumPy (``np.int64``), or TensorFlow (``tf.string``), or their stringified names such as ``"torch.float"`` and ``"np.int64"``. Note that this converting process happens after all the data are restored. Example: .. code-block:: python feature_convert_types = { "input_ids": "int32", "label_ids": "int32", } `"image_options"`: dict, optional Specifies the image feature name and performs image resizing, includes three fields: - `"image_feature_name"`: str The name of the feature which contains the image data. If set, the image data will be restored in a `numpy.ndarray`. - `"resize_height"`: int The height of the image after resizing. - `"resize_width"`: int The width of the image after resizing. If any of :attr:`"resize_height"` or :attr:`"resize_width"` is not set, image data will be restored with original shape. `"num_shards"`: int, optional The number of data shards in distributed mode. Usually set to the number of processes in distributed computing. Used in combination with :attr:`"shard_id"`. .. warning:: Sharding is not yet supported. This option (and related ones below) will be ignored. `"shard_id"`: int, optional Sets the unique id to identify a shard. The module will processes only the corresponding shard of the whole data. Used in combination with :attr:`"num_shards"`. For example, in a case of distributed computing on 2 GPUs, the hyperparameters of the data module for the two processes can be configured as below, respectively. For GPU 0: .. code-block:: python dataset: { ... "num_shards": 2, "shard_id": 0 } For GPU 1: .. code-block:: python dataset: { ... "num_shards": 2, "shard_id": 1 } Also refer to `examples/bert` for a use case. `"other_transformations"`: list A list of transformation functions or function names/paths to further transform each single data instance. `"data_name"`: str Name of the dataset. 2. For the **general** hyperparameters, see :meth:`` for details. """ hparams = DatasetBase.default_hparams() hparams["name"] = "record_data" hparams.update({ "dataset": _default_record_dataset_hparams() }) return hparams
def process(self, raw_example: Dict[str, Any]) -> Dict[str, Any]: for key, descriptor in self._features.items(): _check_shape(raw_example[key], key, descriptor) example = raw_example for key, dtype in self._convert_types.items(): example[key] = np.asarray(example[key], dtype=dtype) for key, transform in self._image_transforms.items(): example[key] = transform(example[key]) for transform in self._other_transforms: example = transform(example) return example def collate(self, examples: List[Dict[str, Any]]) -> Batch: batch = {} for key, descriptor in self._features.items(): values = [ex[key] for ex in examples] if descriptor.collate_method is not CollateMethod.List: # NumPy functions work on PyTorch tensors too. if descriptor.collate_method is CollateMethod.StackedTensor: values = np.stack(values, axis=0) # type: ignore else: # padded_tensor values, _ = padded_batch(values) # type: ignore if (not isinstance(values, torch.Tensor) and descriptor.dtype not in [np.str_, np.bytes_]): values = torch.from_numpy(values) # type: ignore else: # Just put everything in a Python list. pass batch[key] = values return Batch(len(examples), batch)
[docs] def list_items(self) -> List[str]: r"""Returns the list of item names that the data can produce. Returns: A list of strings. """ return list(self._features.keys())
@property def feature_names(self): r"""A list of feature names. """ return self.list_items()