Source code for texar.torch.losses.losses_utils

# Copyright 2019 The Texar Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Various utilities for losses.
"""

from typing import Optional

import torch

from texar.torch.utils.shapes import mask_sequences, transpose_batch_time
from texar.torch.utils.types import MaybeList

__all__ = [
    "mask_and_reduce",
    "reduce_batch_time",
    "reduce_dimensions",
]


def mask_and_reduce(sequence: torch.Tensor,
                    sequence_length: Optional[torch.LongTensor],
                    rank: int = 2,
                    average_across_batch: bool = True,
                    average_across_timesteps: bool = False,
                    average_across_remaining: bool = False,
                    sum_over_batch: bool = False,
                    sum_over_timesteps: bool = True,
                    sum_over_remaining: bool = True,
                    dtype: Optional[torch.dtype] = None,
                    time_major: bool = False) -> torch.Tensor:
    r"""Masks out sequence entries that are beyond the respective sequence
    lengths, and reduces (averages or sums) away dimensions.

    This is a combination of :func:`~texar.torch.utils.shapes.mask_sequences`
    and :func:`~texar.torch.losses.losses_utils.reduce_batch_time`.

    Args:
        sequence: A tensor of sequence values.
            If `time_major=False` (default), this must be a tensor of shape
            `[batch_size, max_time, d_2, ..., d_rank]`, where the rank of
            the tensor is specified with :attr:`rank`.
            The batch and time dimensions are exchanged if `time_major` is
            `True`.
        sequence_length: A tensor of shape `[batch_size]`. Time steps beyond
            the respective sequence lengths will be made zero. If `None`,
            no masking is performed.
        rank (int): The rank of :attr:`sequence`. Must be >= 2. Default is 2,
            i.e., :attr:`sequence` is a 2D tensor consisting of batch and
            time dimensions.
        average_across_batch (bool): If set, average the sequence across the
            batch dimension. Must not set `average_across_batch` and
            `sum_over_batch` at the same time.
        average_across_timesteps (bool): If set, average the sequence across
            the time dimension. Must not set `average_across_timesteps` and
            `sum_over_timesteps` at the same time.
        average_across_remaining (bool): If set, average the sequence across
            the remaining dimensions. Must not set `average_across_remaining`
            and `sum_over_remaining` at the same time.
        sum_over_batch (bool): If set, sum the sequence across the batch
            dimension. Must not set `average_across_batch` and
            `sum_over_batch` at the same time.
        sum_over_timesteps (bool): If set, sum the sequence across the time
            dimension. Must not set `average_across_timesteps` and
            `sum_over_timesteps` at the same time.
        sum_over_remaining (bool): If set, sum the sequence across the
            remaining dimensions. Must not set `average_across_remaining`
            and `sum_over_remaining` at the same time.
        dtype (torch.dtype): The dtype used when masking, passed to
            :func:`~texar.torch.utils.shapes.mask_sequences`.
        time_major (bool): The shape format of the inputs. If `True`,
            :attr:`sequence` must have shape `[max_time, batch_size, ...]`.
            If `False` (default), :attr:`sequence` must have shape
            `[batch_size, max_time, ...]`.

    Returns:
        A tensor containing the masked and reduced sequence.
    """
    if rank < 2:
        raise ValueError('`rank` must be >= 2.')

    if time_major:
        sequence = transpose_batch_time(sequence)

    if sequence_length is not None:
        sequence = mask_sequences(sequence, sequence_length,
                                  dtype=dtype, time_major=False)
    if rank > 2:
        if average_across_remaining and sum_over_remaining:
            raise ValueError("Only one of `average_across_remaining` and "
                             "`sum_over_remaining` can be set.")
        if average_across_remaining:
            for axis in sorted(list(range(2, rank)), reverse=True):
                sequence = torch.mean(sequence, dim=axis)
        elif sum_over_remaining:
            for axis in sorted(list(range(2, rank)), reverse=True):
                sequence = torch.sum(sequence, dim=axis)

    sequence = reduce_batch_time(sequence,
                                 sequence_length,
                                 average_across_batch,
                                 average_across_timesteps,
                                 sum_over_batch,
                                 sum_over_timesteps)

    reduce_time = average_across_timesteps or sum_over_timesteps
    reduce_batch = average_across_batch or sum_over_batch
    if not reduce_time and not reduce_batch and time_major:
        sequence = transpose_batch_time(sequence)

    return sequence


def reduce_batch_time(sequence: torch.Tensor,
                      sequence_length: Optional[torch.LongTensor],
                      average_across_batch: bool = True,
                      average_across_timesteps: bool = False,
                      sum_over_batch: bool = False,
                      sum_over_timesteps: bool = True) -> torch.Tensor:
    r"""Averages or sums over the batch and time dimensions of
    :attr:`sequence`, which is of shape `[batch_size, max_time, ...]`.

    Assumes :attr:`sequence` has already been properly masked according to
    :attr:`sequence_length`.

    Args:
        sequence: A tensor to reduce.
        sequence_length: A tensor of shape `[batch_size]` giving the valid
            length of each sequence. When averaging across the time
            dimension, the sum over time is divided by these lengths rather
            than by `max_time`. If `None`, the average is taken over all
            `max_time` steps.
        average_across_batch (bool): If set, average the sequence across the
            batch dimension. Must not set `average_across_batch` and
            `sum_over_batch` at the same time.
        average_across_timesteps (bool): If set, average the sequence across
            the time dimension. Must not set `average_across_timesteps` and
            `sum_over_timesteps` at the same time.
        sum_over_batch (bool): If set, sum the sequence across the batch
            dimension. Must not set `average_across_batch` and
            `sum_over_batch` at the same time.
        sum_over_timesteps (bool): If set, sum the sequence across the time
            dimension. Must not set `average_across_timesteps` and
            `sum_over_timesteps` at the same time.

    Returns:
        A tensor with the reduced dimensions.
    """
    if average_across_timesteps and sum_over_timesteps:
        raise ValueError("Only one of `average_across_timesteps` and "
                         "`sum_over_timesteps` can be set.")
    if average_across_batch and sum_over_batch:
        raise ValueError("Only one of `average_across_batch` and "
                         "`sum_over_batch` can be set.")

    if sum_over_timesteps:
        sequence = torch.sum(sequence, dim=1)
    elif average_across_timesteps:
        if sequence_length is None:
            sequence = torch.mean(sequence, dim=1)
        else:
            sequence = (torch.sum(sequence, dim=1).float()
                        / sequence_length.float())

    if sum_over_batch:
        sequence = torch.sum(sequence, dim=0)
    elif average_across_batch:
        sequence = torch.mean(sequence, dim=0)

    return sequence
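
# Usage sketch (illustrative; not part of the original module). When
# averaging across time with `sequence_length` given, the per-example sum
# over time is divided by the true length, not by `max_time`:
#
#   >>> import torch
#   >>> from texar.torch.losses.losses_utils import reduce_batch_time
#   >>> masked = torch.tensor([[1., 2., 0.],
#   ...                        [3., 0., 0.]])  # already masked
#   >>> lengths = torch.tensor([2, 1])
#   >>> reduce_batch_time(masked, lengths,
#   ...                   average_across_batch=True,
#   ...                   average_across_timesteps=True,
#   ...                   sum_over_batch=False,
#   ...                   sum_over_timesteps=False)
#   tensor(2.2500)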


def reduce_dimensions(tensor: torch.Tensor,
                      average_axes: Optional[MaybeList[int]] = None,
                      sum_axes: Optional[MaybeList[int]] = None,
                      keepdims: Optional[bool] = None) -> torch.Tensor:
    r"""Averages or sums over dimensions of :attr:`tensor`.

    :attr:`average_axes` and :attr:`sum_axes` must be mutually exclusive.
    That is, elements in :attr:`average_axes` must not be contained in
    :attr:`sum_axes`, and vice versa.

    Args:
        tensor: A tensor to reduce.
        average_axes (optional): A (list of) `int` that indicates the
            dimensions to reduce by taking the average.
        sum_axes (optional): A (list of) `int` that indicates the dimensions
            to reduce by taking the sum.
        keepdims (optional): If `True`, retains reduced dimensions with
            length 1.

    Returns:
        A tensor with the reduced dimensions.
    """
    reduced_axes = set()
    if average_axes is not None:
        if not isinstance(average_axes, (list, tuple)):
            average_axes = [average_axes]
        for average_axis in average_axes:
            tensor = torch.mean(tensor, dim=average_axis, keepdim=True)
        reduced_axes.update(average_axes)

    if sum_axes is not None:
        if not isinstance(sum_axes, (list, tuple)):
            sum_axes = [sum_axes]
        for sum_axis in sum_axes:
            tensor = torch.sum(tensor, dim=sum_axis, keepdim=True)
        reduced_axes.update(sum_axes)

    if average_axes is not None and sum_axes is not None:
        if len(reduced_axes) != len(average_axes) + len(sum_axes):
            raise ValueError('`average_axes` and `sum_axes` must not '
                             'have overlapping elements.')

    if not keepdims:
        for axis in sorted(list(reduced_axes), reverse=True):
            tensor = torch.squeeze(tensor, dim=axis)

    return tensor
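
# Usage sketch (illustrative; not part of the original module). Mixed
# average and sum reductions; the reduced axes are squeezed out because
# `keepdims` defaults to a falsy value:
#
#   >>> import torch
#   >>> from texar.torch.losses.losses_utils import reduce_dimensions
#   >>> x = torch.ones(2, 3, 4)
#   >>> y = reduce_dimensions(x, average_axes=1, sum_axes=2)
#   >>> y.shape
#   torch.Size([2])
#   >>> y
#   tensor([4., 4.])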