# Copyright 2019 The Texar Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Utility functions related to tensor shapes.
"""
from typing import Any, List, Optional, Union
import numpy as np
import torch
from texar.torch.utils import utils
from texar.torch.utils.types import MaybeList
__all__ = [
"transpose_batch_time",
"get_batch_size",
"get_rank",
"mask_sequences",
"flatten",
"pad_and_concat",
]
def transpose_batch_time(inputs: torch.Tensor) -> torch.Tensor:
r"""Transposes inputs between time-major and batch-major.

    Args:
        inputs: A Tensor of shape ``[batch_size, max_time, ...]``
            (batch-major) or ``[max_time, batch_size, ...]`` (time-major).

    Returns:
        A Tensor with the batch and time dimensions of :attr:`inputs`
        transposed.
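
    Example (an illustrative sketch; shapes are arbitrary):

        .. code-block:: python

            x = torch.rand(32, 10, 64)   # batch-major: [batch_size, max_time, dim]
            y = transpose_batch_time(x)  # y.shape == [10, 32, 64]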
"""
return inputs.transpose(0, 1)
def get_batch_size(tensor: torch.Tensor) -> int:
r"""Returns an ``int`` representing the batch size, i.e.,
the size of the 1st dimension of :attr:`tensor`.
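
    Example (illustrative):

        .. code-block:: python

            x = torch.zeros(8, 5)
            get_batch_size(x)  # 8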
"""
return tensor.size(0)
def get_rank(tensor: torch.Tensor) -> int:
r"""Returns the tensor rank as a Python ``int``. The input tensor can also
be a Python array.
Args:
tensor: A Tensor or Python array.
Returns:
        A Python ``int`` representing the rank of :attr:`tensor`. Python
        arrays are converted with ``np.asarray`` before the rank is computed.
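
    Example (illustrative):

        .. code-block:: python

            get_rank(torch.zeros(2, 3, 4))  # 3
            get_rank([[1, 2], [3, 4]])      # 2 (nested Python list)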
"""
if torch.is_tensor(tensor):
rank = tensor.dim()
else:
array = np.asarray(tensor)
rank = array.ndim
return rank
def mask_sequences(sequence: Union[torch.Tensor, List[int]],
sequence_length: Union[torch.LongTensor, List[int]],
dtype: Optional[torch.dtype] = None,
time_major: bool = False) -> torch.Tensor:
r"""Masks out sequence entries that are beyond the respective sequence
lengths. Masks along the time dimension.
    Both :attr:`sequence` and :attr:`sequence_length` can be Tensors or Python
    arrays. Python-array inputs are converted to Tensors, so the returned
    value is always a Tensor.

Args:
sequence: A Tensor or Python array of sequence values.
If ``time_major==False`` (default), this must be a Tensor of shape
            ``[batch_size, max_time, ...]``. The batch and time dimensions are
            exchanged if ``time_major==True``.
sequence_length: A Tensor or python array of shape ``[batch_size]``.
Time steps beyond the respective sequence lengths will be
made zero.
dtype (dtype): Type of :attr:`sequence`. If `None`, infer from
:attr:`sequence` automatically.
time_major (bool): The shape format of the inputs. If `True`,
:attr:`sequence` must have shape
``[max_time, batch_size, ...]``.
If `False` (default), :attr:`sequence` must have
shape ``[batch_size, max_time, ...]``.

    Returns:
        The masked sequence, i.e., a Tensor of the same shape as
        :attr:`sequence` but with entries beyond the respective sequence
        lengths set to zero.
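
    Example (an illustrative sketch; values are arbitrary):

        .. code-block:: python

            x = torch.ones(2, 4)
            lengths = torch.tensor([3, 1])
            mask_sequences(x, lengths)
            # tensor([[1., 1., 1., 0.],
            #         [1., 0., 0., 0.]])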
"""
if not torch.is_tensor(sequence):
sequence = torch.tensor(sequence, dtype=dtype)
sequence: torch.Tensor
rank = sequence.dim()
if rank < 2:
raise ValueError("`sequence` must be 2D or higher order.")
if time_major:
sequence = transpose_batch_time(sequence)
max_time = sequence.size(1)
if dtype is None:
dtype = sequence.dtype
mask = utils.sequence_mask(sequence_length, max_time, dtype=dtype)
mask = mask.view(*mask.size(), *([1] * (rank - 2)))
sequence = sequence * mask
if time_major:
sequence = transpose_batch_time(sequence)
return sequence
def flatten(tensor: torch.Tensor, preserve_dims: int,
flattened_dim: Optional[int] = None) -> torch.Tensor:
r"""Flattens a tensor whiling keeping several leading dimensions.
:attr:`preserve_dims` must be less than or equal to tensor's rank.
Args:
tensor: A Tensor to flatten.
preserve_dims (int): The number of leading dimensions to preserve.
flattened_dim (int, optional): The size of the resulting flattened
dimension. If not given, infer automatically.
Returns:
        A Tensor with rank :attr:`preserve_dims` + 1.

    Example:

        .. code-block:: python

            x = torch.ones(d_1, d_2, d_3, d_4)
            y = flatten(x, 2)  # y.shape == [d_1, d_2, d_3 * d_4]
"""
if preserve_dims > tensor.dim():
raise ValueError(
"`preserve_dims` must be less than or equal to tensor's rank")
if flattened_dim is None:
flattened_dim = -1
shape = tensor.size()[:preserve_dims] + (flattened_dim,)
tensor_ = tensor.reshape(shape)
return tensor_
def pad_and_concat(values: List[torch.Tensor], axis: int,
pad_axis: Optional[MaybeList[int]] = None,
pad_constant_values: Any = 0) -> torch.Tensor:
r"""Concatenates tensors along one dimension. Pads each of other dimensions
of the tensors to the corresponding maximum size if necessary.
Args:
values: A list of Tensors of the same rank.
axis (int): A Python int. Dimension along which to concatenate.
pad_axis (int or list, optional): A Python int or a list of int.
Dimensions to pad. Paddings are only added to the end of
corresponding dimensions. If `None`, all dimensions except the
:attr:`axis` dimension are padded.
        pad_constant_values: The scalar pad value to use. Must be of the same
            type as the tensors.
Returns:
A ``Tensor`` resulting from padding and concatenation of the input
tensors.
Raises:
        ValueError: If the tensors in :attr:`values` do not all have the same
            rank.

    Example:

        .. code-block:: python

            a = torch.ones([1, 2])
            b = torch.ones([2, 3])

            c = pad_and_concat([a, b], 0)
            # c.shape == [3, 3]
            # c == [[1, 1, 0],
            #       [1, 1, 1],
            #       [1, 1, 1]]

            d = pad_and_concat([a, b], 1)
            # d.shape == [2, 5]
            # d == [[1, 1, 1, 1, 1],
            #       [0, 0, 1, 1, 1]]
"""
rank = values[0].dim()
if any(value.dim() != rank for value in values):
raise ValueError("All tensors in `values` must have the same rank.")
if pad_axis is None:
pad_axis = [r for r in range(rank) if r != axis]
elif isinstance(pad_axis, int):
pad_axis = [pad_axis]
for pad_dim in pad_axis:
max_dim_size = max(v.size(pad_dim) for v in values)
for i, v in enumerate(values):
pad_shape: List[int] = list(v.size())
if pad_shape[pad_dim] == max_dim_size:
continue
pad_shape[pad_dim] = max_dim_size - pad_shape[pad_dim]
padding = values[0].new_full(tuple(pad_shape), pad_constant_values)
values[i] = torch.cat((v, padding), dim=pad_dim)
return torch.cat(values, dim=axis)