Source code for texar.torch.core.cell_wrappers

# Copyright 2019 The Texar Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
TensorFlow-style RNN cell wrappers.
"""

from typing import Callable, Generic, List, Optional, Tuple, TypeVar, Union

import torch
import torch.nn.functional as F
from torch import nn

from texar.torch.core.attention_mechanism import (
    AttentionMechanism, AttentionWrapperState, compute_attention)
from texar.torch.utils import utils
from texar.torch.utils.types import MaybeList

__all__ = [
    'RNNState',
    'LSTMState',
    'HiddenState',
    'wrap_builtin_cell',
    'RNNCellBase',
    'RNNCell',
    'GRUCell',
    'LSTMCell',
    'DropoutWrapper',
    'ResidualWrapper',
    'HighwayWrapper',
    'MultiRNNCell',
    'AttentionWrapper',
]

State = TypeVar('State')
RNNState = torch.Tensor
LSTMState = Tuple[torch.Tensor, torch.Tensor]

HiddenState = MaybeList[Union[RNNState, LSTMState]]


def wrap_builtin_cell(cell: nn.RNNCellBase):
    r"""Convert a built-in :torch_nn:`RNNCellBase` derived RNN cell to
    our wrapped version.

    Args:
        cell: the RNN cell to wrap around.

    Returns:
        The wrapped cell derived from
        :class:`texar.torch.core.cell_wrappers.RNNCellBase`.
    """
    # convert cls to corresponding derived wrapper class
    if isinstance(cell, nn.RNNCell):
        self = RNNCellBase.__new__(RNNCell)
    elif isinstance(cell, nn.GRUCell):
        self = RNNCellBase.__new__(GRUCell)
    elif isinstance(cell, nn.LSTMCell):
        self = RNNCellBase.__new__(LSTMCell)
    else:
        raise TypeError(f"Unrecognized class {type(cell)}.")

    RNNCellBase.__init__(self, cell)
    return self
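
# Usage sketch: wrapping a built-in ``torch.nn.LSTMCell``. The sizes (input 16,
# hidden 32) and batch size 8 are arbitrary values chosen purely for
# illustration; the module-level imports above are assumed.
#
#     cell = wrap_builtin_cell(nn.LSTMCell(16, 32))
#     output, (h, c) = cell(torch.randn(8, 16))   # zero (h, c) state is created
#     assert output.shape == (8, 32)              # output is the new hidden state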


class RNNCellBase(nn.Module, Generic[State]):
    r"""The base class for RNN cells in our framework. Major differences over
    :torch_nn:`RNNCell` are two-fold:

    1. Holds an :torch_nn:`Module` which could either be a built-in
       RNN cell or a wrapped cell instance. This design allows
       :class:`RNNCellBase` to serve as the base class for both vanilla
       cells and wrapped cells.

    2. Adds :meth:`zero_state` method for initialization of hidden states,
       which can also be used to implement batch-specific initialization
       routines.
    """

    def __init__(self, cell: Union[nn.RNNCellBase, 'RNNCellBase']):
        super().__init__()
        if not isinstance(cell, nn.Module):
            raise ValueError("Type of parameter 'cell' must be derived from "
                             "nn.Module, and have 'input_size' and "
                             "'hidden_size' attributes.")
        self._cell = cell

    @property
    def input_size(self) -> int:
        r"""The number of expected features in the input."""
        return self._cell.input_size

    @property
    def hidden_size(self) -> int:
        r"""The number of features in the hidden state."""
        return self._cell.hidden_size

    @property
    def _param(self) -> nn.Parameter:
        r"""Convenience method to access a parameter under the module. Useful
        when creating tensors of the same attributes using `param.new_*`.
        """
        return next(self.parameters())

    def init_batch(self):
        r"""Perform batch-specific initialization routines. For most cells
        this is a no-op.
        """
        pass

    def zero_state(self, batch_size: int) -> State:
        r"""Return zero-filled state tensor(s).

        Args:
            batch_size: int, the batch size.

        Returns:
            State tensor(s) initialized to zeros. Note that different
            subclasses might return tensors of different shapes and
            structures.
        """
        self.init_batch()
        if isinstance(self._cell, nn.RNNCellBase):
            state = self._param.new_zeros(
                batch_size, self.hidden_size, requires_grad=False)
        else:
            state = self._cell.zero_state(batch_size)
        return state

    def forward(self,  # type: ignore
                input: torch.Tensor, state: Optional[State] = None) \
            -> Tuple[torch.Tensor, State]:
        r"""
        Returns:
            A tuple of (output, state). For single layer RNNs, output is
            the same as state.
        """
        if state is None:
            batch_size = input.size(0)
            state = self.zero_state(batch_size)
        return self._cell(input, state)


class BuiltinCellWrapper(RNNCellBase[State]):
    r"""Base class for wrappers over built-in :torch_nn:`RNNCellBase`
    RNN cells.
    """

    def forward(self,  # type: ignore
                input: torch.Tensor, state: Optional[State] = None) \
            -> Tuple[torch.Tensor, State]:
        if state is None:
            batch_size = input.size(0)
            state = self.zero_state(batch_size)
        new_state = self._cell(input, state)
        return new_state, new_state


class RNNCell(BuiltinCellWrapper[RNNState]):
    r"""A wrapper over :torch_nn:`RNNCell`."""

    def __init__(self, input_size, hidden_size,
                 bias=True, nonlinearity="tanh"):
        cell = nn.RNNCell(
            input_size, hidden_size, bias=bias, nonlinearity=nonlinearity)
        super().__init__(cell)


class GRUCell(BuiltinCellWrapper[RNNState]):
    r"""A wrapper over :torch_nn:`GRUCell`."""

    def __init__(self, input_size, hidden_size, bias=True):
        cell = nn.GRUCell(input_size, hidden_size, bias=bias)
        super().__init__(cell)
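
# Usage sketch: built-in wrappers return ``(output, state)`` where, for
# single-layer RNN and GRU cells, the output is the same tensor as the new
# state (sizes below are arbitrary illustration values).
#
#     cell = GRUCell(16, 32)
#     state = cell.zero_state(batch_size=8)           # zeros of shape (8, 32)
#     output, state = cell(torch.randn(8, 16), state)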


class LSTMCell(BuiltinCellWrapper[LSTMState]):
    r"""A wrapper over :torch_nn:`LSTMCell`, additionally providing the option
    to initialize the forget-gate bias to a constant value.
    """

    def __init__(self, input_size, hidden_size, bias=True,
                 forget_bias: Optional[float] = None):
        if forget_bias is not None and not bias:
            raise ValueError("Parameter 'forget_bias' must be set to None "
                             "when 'bias' is set to False.")

        cell = nn.LSTMCell(input_size, hidden_size, bias=bias)
        if forget_bias is not None:
            with torch.no_grad():
                cell.bias_ih[hidden_size:(2 * hidden_size)].fill_(forget_bias)
                cell.bias_hh[hidden_size:(2 * hidden_size)].fill_(forget_bias)
        super().__init__(cell)

    def zero_state(self, batch_size: int) -> LSTMState:
        r"""Returns the zero state for LSTMs as (h, c)."""
        state = self._param.new_zeros(
            batch_size, self.hidden_size, requires_grad=False)
        return state, state

    def forward(self,  # type: ignore
                input: torch.Tensor, state: Optional[LSTMState] = None) \
            -> Tuple[torch.Tensor, LSTMState]:
        if state is None:
            batch_size = input.size(0)
            state = self.zero_state(batch_size)
        new_state = self._cell(input, state)
        return new_state[0], new_state
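
# Usage sketch: an LSTM cell whose forget-gate bias is initialized to 1.0
# (a common choice; all sizes are arbitrary illustration values).
#
#     cell = LSTMCell(16, 32, forget_bias=1.0)
#     output, (h, c) = cell(torch.randn(8, 16))       # zero (h, c) state created
#     assert output.shape == h.shape == c.shape == (8, 32)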


class DropoutWrapper(RNNCellBase[State]):
    r"""Operator adding dropout to inputs and outputs of the given cell."""

    def __init__(self, cell: RNNCellBase[State],
                 input_keep_prob: float = 1.0,
                 output_keep_prob: float = 1.0,
                 state_keep_prob: float = 1.0,
                 variational_recurrent=False):
        r"""Create a cell with added input, state, and/or output dropout.

        If `variational_recurrent` is set to `True` (**NOT** the default
        behavior), then the same dropout mask is applied at every step, as
        described in: Y. Gal, Z. Ghahramani. "A Theoretically Grounded
        Application of Dropout in Recurrent Neural Networks".
        https://arxiv.org/abs/1512.05287

        Otherwise a different dropout mask is applied at every time step.

        Note that state dropout is performed on the outgoing states of the
        cell; for LSTM cells this includes both the hidden state and the
        memory state.

        Args:
            cell: an RNNCell.
            input_keep_prob: float between 0 and 1, input keep probability;
                if it is constant and 1, no input dropout will be added.
            output_keep_prob: float between 0 and 1, output keep probability;
                if it is constant and 1, no output dropout will be added.
            state_keep_prob: float between 0 and 1, state keep probability;
                if it is constant and 1, no state dropout will be added.
                State dropout is performed on the outgoing states of the cell.
            variational_recurrent: bool. If `True`, then the same dropout
                pattern is applied across all time steps for one batch. This
                is implemented by initializing dropout masks in
                :meth:`zero_state`.
        """
        super().__init__(cell)

        for prob, attr in [(input_keep_prob, "input_keep_prob"),
                           (state_keep_prob, "state_keep_prob"),
                           (output_keep_prob, "output_keep_prob")]:
            if prob < 0.0 or prob > 1.0:
                raise ValueError(
                    f"Parameter '{attr}' must be between 0 and 1: {prob}")

        self._input_keep_prob = input_keep_prob
        self._output_keep_prob = output_keep_prob
        self._state_keep_prob = state_keep_prob

        self._variational_recurrent = variational_recurrent

        self._recurrent_input_mask: Optional[torch.Tensor] = None
        self._recurrent_output_mask: Optional[torch.Tensor] = None
        self._recurrent_state_mask: Optional[torch.Tensor] = None

    def _new_mask(self, batch_size: int, mask_size: int,
                  prob: float) -> torch.Tensor:
        return self._param.new_zeros(batch_size, mask_size).bernoulli_(prob)

    def init_batch(self):
        r"""Initialize dropout masks for variational dropout.

        Note that we do not create the dropout masks here, because the batch
        size may not be known until actual input is passed in.
        """
        self._recurrent_input_mask = None
        self._recurrent_output_mask = None
        self._recurrent_state_mask = None

    def _dropout(self, tensor: torch.Tensor, keep_prob: float,
                 mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        r"""Decides whether to perform standard dropout or recurrent
        dropout."""
        if keep_prob == 1.0 or not self.training:
            return tensor
        if mask is not None:
            return tensor.mul(mask).mul(1.0 / keep_prob)
        return F.dropout(tensor, 1.0 - keep_prob, self.training)

    def forward(self,  # type: ignore
                input: torch.Tensor, state: Optional[State] = None) \
            -> Tuple[torch.Tensor, State]:
        if self.training and self._variational_recurrent:
            # Create or check recurrent masks.
            batch_size = input.size(0)
            for name, size in [('input', self.input_size),
                               ('output', self.hidden_size),
                               ('state', self.hidden_size)]:
                prob = getattr(self, f'_{name}_keep_prob')
                if prob == 1.0:
                    continue
                mask = getattr(self, f'_recurrent_{name}_mask')
                if mask is None:
                    # Initialize the mask according to current batch size.
                    mask = self._new_mask(batch_size, size, prob)
                    setattr(self, f'_recurrent_{name}_mask', mask)
                else:
                    # Check that size matches.
                    if mask.size(0) != batch_size:
                        raise ValueError(
                            "Variational recurrent dropout mask does not "
                            "support variable batch sizes across time steps")

        input = self._dropout(input, self._input_keep_prob,
                              self._recurrent_input_mask)
        output, new_state = super().forward(input, state)
        output = self._dropout(output, self._output_keep_prob,
                               self._recurrent_output_mask)
        new_state = utils.map_structure(
            lambda x: self._dropout(
                x, self._state_keep_prob, self._recurrent_state_mask),
            new_state)
        return output, new_state
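
# Usage sketch: variational (recurrent) dropout around an LSTM cell. Sizes,
# keep probabilities, and the 5-step loop are arbitrary illustration values;
# ``zero_state`` resets the recurrent masks via ``init_batch``, so a fresh set
# of masks is drawn for each new batch.
#
#     cell = DropoutWrapper(LSTMCell(16, 32), input_keep_prob=0.9,
#                           output_keep_prob=0.9, state_keep_prob=0.9,
#                           variational_recurrent=True)
#     cell.train()                              # dropout is a no-op in eval mode
#     state = cell.zero_state(batch_size=8)
#     for x in torch.randn(5, 8, 16):           # same masks reused at every step
#         output, state = cell(x, state)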


class ResidualWrapper(RNNCellBase[State]):
    r"""RNNCell wrapper that ensures cell inputs are added to the outputs."""

    def forward(self,  # type: ignore
                input: torch.Tensor, state: Optional[State] = None) \
            -> Tuple[torch.Tensor, State]:
        output, new_state = super().forward(input, state)
        output = input + output
        return output, new_state
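
# Usage sketch: a residual connection around a GRU cell. Since the input is
# added to the output, ``input_size`` must equal ``hidden_size`` (the values
# below are arbitrary illustration values).
#
#     cell = ResidualWrapper(GRUCell(32, 32))
#     output, state = cell(torch.randn(8, 32))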


class HighwayWrapper(RNNCellBase[State]):
    r"""RNNCell wrapper that adds highway connection on cell input and output.

    Based on: `R. K. Srivastava, K. Greff, and J. Schmidhuber, "Highway
    networks", arXiv preprint arXiv:1505.00387, 2015.`
    https://arxiv.org/pdf/1505.00387.pdf
    """

    def __init__(self, cell: RNNCellBase[State],
                 carry_bias_init: Optional[float] = None,
                 couple_carry_transform_gates: bool = True):
        r"""Constructs a `HighwayWrapper` for `cell`.

        Args:
            cell: An instance of `RNNCell`.
            carry_bias_init: float, carry gate bias initialization.
            couple_carry_transform_gates: boolean, whether the carry and
                transform gates should be coupled.
        """
        super().__init__(cell)

        self.carry = nn.Linear(self.input_size, self.input_size)
        if not couple_carry_transform_gates:
            self.transform = nn.Linear(self.input_size, self.input_size)

        self._coupled = couple_carry_transform_gates
        if carry_bias_init is not None:
            nn.init.constant_(self.carry.bias, carry_bias_init)
            if not couple_carry_transform_gates:
                nn.init.constant_(self.transform.bias, -carry_bias_init)

    def forward(self,  # type: ignore
                input: torch.Tensor, state: Optional[State] = None) \
            -> Tuple[torch.Tensor, State]:
        output, new_state = super().forward(input, state)
        carry = torch.sigmoid(self.carry(input))
        if self._coupled:
            transform = 1 - carry
        else:
            transform = torch.sigmoid(self.transform(input))
        output = input * carry + output * transform
        return output, new_state
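
# Usage sketch: a highway connection around a GRU cell. Because the highway
# gates mix the cell input with the cell output, the wrapped cell must have
# ``input_size == hidden_size``; the sizes and bias below are arbitrary
# illustration values.
#
#     cell = HighwayWrapper(GRUCell(32, 32), carry_bias_init=-1.0)
#     output, state = cell(torch.randn(8, 32))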


class MultiRNNCell(RNNCellBase[List[State]]):
    r"""RNN cell composed sequentially of multiple simple cells.

    .. code-block:: python

        sizes = [128, 128, 64]
        cells = [LSTMCell(input_size, hidden_size)
                 for input_size, hidden_size in zip(sizes[:-1], sizes[1:])]
        stacked_rnn_cell = MultiRNNCell(cells)
    """

    _cell: nn.ModuleList  # type: ignore

    def __init__(self, cells: List[RNNCellBase[State]]):
        r"""Create a RNN cell composed sequentially of a number of RNNCells.

        Args:
            cells: list of RNNCells that will be composed in this order.

        Raises:
            ValueError: if cells is empty (not allowed).
        """
        if len(cells) == 0:
            raise ValueError("Parameter 'cells' should not be empty.")
        cell = nn.ModuleList(cells)
        super().__init__(cell)  # type: ignore

    @property
    def input_size(self):
        return self._cell[0].input_size

    @property
    def hidden_size(self):
        return self._cell[-1].hidden_size

    def init_batch(self):
        for cell in self._cell:
            cell.init_batch()

    def zero_state(self, batch_size: int) -> List[State]:
        states = [cell.zero_state(batch_size)  # type: ignore
                  for cell in self._cell]
        return states

    def forward(self,  # type: ignore
                input: torch.Tensor, state: Optional[List[State]] = None) \
            -> Tuple[torch.Tensor, List[State]]:
        r"""Run this multi-layer cell on inputs, starting from state."""
        if state is None:
            batch_size = input.size(0)
            state = self.zero_state(batch_size)
        new_states = []
        output = input
        for cell, hx in zip(self._cell, state):
            output, new_state = cell(output, hx)
            new_states.append(new_state)
        return output, new_states
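
# Usage sketch: stepping a two-layer stack of the wrapped LSTM cells defined
# above (sizes and batch size are arbitrary illustration values).
#
#     sizes = [16, 32, 32]
#     cells = [LSTMCell(in_size, out_size)
#              for in_size, out_size in zip(sizes[:-1], sizes[1:])]
#     stacked = MultiRNNCell(cells)
#     states = stacked.zero_state(batch_size=8)       # one (h, c) pair per layer
#     output, states = stacked(torch.randn(8, 16), states)
#     assert output.shape == (8, 32)                  # hidden size of last layer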


class AttentionWrapper(RNNCellBase[AttentionWrapperState]):
    r"""Wraps another `RNNCell` with attention."""

    def __init__(self, cell: RNNCellBase,
                 attention_mechanism: MaybeList[AttentionMechanism],
                 attention_layer_size: Optional[MaybeList[int]] = None,
                 alignment_history: bool = False,
                 cell_input_fn: Optional[Callable[[torch.Tensor, torch.Tensor],
                                                  torch.Tensor]] = None,
                 output_attention: bool = True):
        r"""Wraps RNN cell with attention. Construct the `AttentionWrapper`.

        Args:
            cell: An instance of RNN cell.
            attention_mechanism: A list of
                :class:`~texar.torch.core.AttentionMechanism` instances or a
                single instance.
            attention_layer_size: A list of Python integers or a single Python
                integer, the depth of the attention (output) layer(s). If
                `None` (default), use the context as attention at each time
                step. Otherwise, feed the context and cell output into the
                attention layer to generate attention at each time step. If
                `attention_mechanism` is a list, `attention_layer_size` must
                be a list of the same length.
            alignment_history (bool): whether to store alignment history from
                all time steps in the final output state.
            cell_input_fn (optional): A `callable`. The default is:
                `lambda inputs, attention: torch.cat((inputs, attention), -1)`.
            output_attention (bool): If `True` (default), the output at each
                time step is the attention value. This is the behavior of
                Luong-style attention mechanisms. If `False`, the output at
                each time step is the output of `cell`. This is the behavior
                of Bahdanau-style attention mechanisms. In both cases, the
                `attention` tensor is propagated to the next time step via the
                state and is used there. This flag only controls whether the
                attention mechanism is propagated up to the next cell in an
                RNN stack or to the top RNN output.

        Raises:
            TypeError: If :attr:`attention_layer_size` is not `None` and
                `attention_mechanism` is a list but
                :attr:`attention_layer_size` is not; or vice versa.
            ValueError: If :attr:`attention_layer_size` is not `None`,
                :attr:`attention_mechanism` is a list, and its length does not
                match that of :attr:`attention_layer_size`; or if
                :attr:`attention_layer_size` and `attention_layer` are set
                simultaneously.
""" super().__init__(cell) self._is_multi: bool if isinstance(attention_mechanism, (list, tuple)): self._is_multi = True attention_mechanisms = attention_mechanism for mechanism in attention_mechanisms: if not isinstance(mechanism, AttentionMechanism): raise TypeError( "attention_mechanism must contain only instances of " "AttentionMechanism, saw type: %s" % type(mechanism).__name__) else: self._is_multi = False if not isinstance(attention_mechanism, AttentionMechanism): raise TypeError( "attention_mechanism must be an AttentionMechanism or list " "of multiple AttentionMechanism instances, saw type: %s" % type(attention_mechanism).__name__) attention_mechanisms = [attention_mechanism] if cell_input_fn is None: cell_input_fn = ( lambda inputs, attention: torch.cat((inputs, attention), dim=-1)) else: if not callable(cell_input_fn): raise TypeError( "cell_input_fn must be callable, saw type: %s" % type(cell_input_fn).__name__) self._attention_layers: Optional[nn.ModuleList] if attention_layer_size is not None: if isinstance(attention_layer_size, (list, tuple)): attention_layer_sizes = tuple(attention_layer_size) else: attention_layer_sizes = (attention_layer_size,) if len(attention_layer_sizes) != len(attention_mechanisms): raise ValueError( "If provided, attention_layer_size must contain exactly " "one integer per attention_mechanism, saw: %d vs %d" % (len(attention_layer_sizes), len(attention_mechanisms))) self._attention_layers = nn.ModuleList( nn.Linear(attention_mechanisms[i].encoder_output_size + cell.hidden_size, attention_layer_sizes[i], False) for i in range(len(attention_layer_sizes))) self._attention_layer_size = sum(attention_layer_sizes) else: self._attention_layers = None self._attention_layer_size = sum( attention_mechanism.encoder_output_size for attention_mechanism in attention_mechanisms) self._cell = cell self.attention_mechanisms = attention_mechanisms self._cell_input_fn = cell_input_fn self._output_attention = output_attention self._alignment_history = alignment_history self._initial_cell_state = None def _item_or_tuple(self, seq): r"""Returns `seq` as tuple or the singular element. Which is returned is determined by how the AttentionMechanism(s) were passed to the constructor. Args: seq: A non-empty sequence of items or generator. Returns: Either the values in the sequence as a tuple if AttentionMechanism(s) were passed to the constructor as a sequence or the singular element. """ t = tuple(seq) if self._is_multi: return t else: return t[0] @property def output_size(self) -> int: r"""The number of features in the output tensor.""" if self._output_attention: return self._attention_layer_size else: return self._cell.hidden_size

    def zero_state(self, batch_size: int) -> AttentionWrapperState:
        r"""Return an initial (zero) state tuple for this
        :class:`AttentionWrapper`.

        .. note::
            Please see the initializer documentation for details of how to
            call :meth:`zero_state` if using an
            :class:`~texar.torch.core.AttentionWrapper` with a
            :class:`~texar.torch.modules.BeamSearchDecoder`.

        Args:
            batch_size: `0D` integer: the batch size.

        Returns:
            An :class:`~texar.torch.core.AttentionWrapperState` tuple
            containing zeroed out tensors and Python lists.
        """
        cell_state: torch.Tensor = super().zero_state(batch_size)  # type: ignore

        initial_alignments = [None for _ in self.attention_mechanisms]

        alignment_history: List[List[Optional[torch.Tensor]]]
        alignment_history = [[] for _ in initial_alignments]

        return AttentionWrapperState(
            cell_state=cell_state,
            time=0,
            attention=self._param.new_zeros(batch_size,
                                            self._attention_layer_size,
                                            requires_grad=False),
            alignments=self._item_or_tuple(initial_alignments),
            attention_state=self._item_or_tuple(initial_alignments),
            alignment_history=self._item_or_tuple(alignment_history))

    def forward(self,  # type: ignore
                inputs: torch.Tensor,
                state: Optional[AttentionWrapperState],
                memory: torch.Tensor,
                memory_sequence_length: Optional[torch.LongTensor] = None) -> \
            Tuple[torch.Tensor, AttentionWrapperState]:
        r"""Perform a step of attention-wrapped RNN.

        - Step 1: Mix the :attr:`inputs` and previous step's `attention`
          output via `cell_input_fn`.
        - Step 2: Call the wrapped `cell` with this input and its previous
          state.
        - Step 3: Score the cell's output with `attention_mechanism`.
        - Step 4: Calculate the alignments by passing the score through the
          `normalizer`.
        - Step 5: Calculate the context vector as the inner product between
          the alignments and the attention_mechanism's values (memory).
        - Step 6: Calculate the attention output by concatenating the cell
          output and context through the attention layer (a linear layer with
          `attention_layer_size` outputs).

        Args:
            inputs: (Possibly nested tuple of) Tensor, the input at this time
                step.
            state: An instance of
                :class:`~texar.torch.core.AttentionWrapperState` containing
                tensors from the previous time step.
            memory: The memory to query; usually the output of an RNN encoder.
                This tensor should be shaped `[batch_size, max_time, ...]`.
            memory_sequence_length: (optional) Sequence lengths for the batch
                entries in memory. If provided, the memory tensor rows are
                masked with zeros for values past the respective sequence
                lengths.

        Returns:
            A tuple `(attention_or_cell_output, next_state)`, where

            - `attention_or_cell_output` is the attention value or the cell
              output, depending on `output_attention`.
            - `next_state` is an instance of
              :class:`~texar.torch.core.AttentionWrapperState` containing the
              state calculated at this time step.

        Raises:
            TypeError: If `state` is not an instance of
                :class:`~texar.torch.core.AttentionWrapperState`.
        """
        if state is None:
            state = self.zero_state(batch_size=memory.shape[0])
        elif not isinstance(state, AttentionWrapperState):
            raise TypeError("Expected state to be instance of "
                            "AttentionWrapperState. Received type %s instead."
                            % type(state))

        # Step 1: Calculate the true inputs to the cell based on the
        # previous attention value.
        cell_inputs = self._cell_input_fn(inputs, state.attention)
        cell_state = state.cell_state
        cell_output, next_cell_state = self._cell(cell_inputs, cell_state)

        if self._is_multi:
            previous_attention_state = state.attention_state
            previous_alignment_history = state.alignment_history
        else:
            previous_attention_state = [state.attention_state]  # type: ignore
            previous_alignment_history = \
                [state.alignment_history]  # type: ignore

        all_alignments = []
        all_attentions = []
        all_attention_states = []
        maybe_all_histories = []
        for i, attention_mechanism in enumerate(self.attention_mechanisms):
            if previous_attention_state[i] is not None:
                attention_state = previous_attention_state[i]
            else:
                attention_state = attention_mechanism.initial_state(
                    memory.shape[0], memory.shape[1], self._param.dtype,
                    self._param.device)

            attention, alignments, next_attention_state = compute_attention(
                attention_mechanism=attention_mechanism,
                cell_output=cell_output,
                attention_state=attention_state,
                attention_layer=(self._attention_layers[i]
                                 if self._attention_layers else None),
                memory=memory,
                memory_sequence_length=memory_sequence_length)

            if self._alignment_history:
                alignment_history = (previous_alignment_history[i] +
                                     [alignments])
            else:
                alignment_history = previous_alignment_history[i]

            all_attention_states.append(next_attention_state)
            all_alignments.append(alignments)
            all_attentions.append(attention)
            maybe_all_histories.append(alignment_history)

        attention = torch.cat(all_attentions, 1)
        next_state = AttentionWrapperState(
            time=state.time + 1,
            cell_state=next_cell_state,
            attention=attention,
            attention_state=self._item_or_tuple(all_attention_states),
            alignments=self._item_or_tuple(all_alignments),
            alignment_history=self._item_or_tuple(maybe_all_histories))

        if self._output_attention:
            return attention, next_state
        else:
            return cell_output, next_state
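
# Usage sketch: one decoding step with attention over encoder outputs.
# ``attn_mech`` stands in for an AttentionMechanism instance (e.g. a Luong- or
# Bahdanau-style mechanism from texar.torch.core) whose construction is
# omitted here; its ``encoder_output_size`` must match the last dimension of
# ``encoder_outputs``. All sizes are arbitrary illustration values. With the
# default ``cell_input_fn``, the step input is concatenated with the previous
# attention, so the wrapped cell takes inputs of size 16 + 64.
#
#     wrapper = AttentionWrapper(LSTMCell(16 + 64, 32), attn_mech,
#                                attention_layer_size=64)
#     encoder_outputs = torch.randn(8, 20, 128)       # [batch, max_time, dim]
#     lengths = torch.full((8,), 20, dtype=torch.long)
#     out, state = wrapper(torch.randn(8, 16), None,
#                          memory=encoder_outputs,
#                          memory_sequence_length=lengths)
#     # ``out`` has size 64 (the attention output, since output_attention=True);
#     # feed ``state`` back in at the next decoding step.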