
# Copyright 2019 The Texar Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Various RNN classifiers.
"""

from typing import Optional, Tuple, TypeVar

import torch
import torch.nn as nn
from torch.nn import functional as F

from texar.torch.core.cell_wrappers import RNNCellBase
from texar.torch.hyperparams import HParams
from texar.torch.modules.classifiers.classifier_base import ClassifierBase
from texar.torch.modules.encoders.rnn_encoders import \
        UnidirectionalRNNEncoder
from texar.torch.utils.utils import dict_fetch


__all__ = [
    "UnidirectionalRNNClassifier",
]

State = TypeVar('State')


class UnidirectionalRNNClassifier(ClassifierBase):
    r"""One directional RNN classifier. This is a combination of the
    :class:`~texar.torch.modules.UnidirectionalRNNEncoder` with a
    classification layer. Both step-wise classification and sequence-level
    classification are supported, specified in :attr:`hparams`.

    Arguments are the same as in
    :class:`~texar.torch.modules.UnidirectionalRNNEncoder`.

    Args:
        input_size (int): The number of expected features in the input
            for the cell.
        cell: (RNNCell, optional) If not specified, a cell is created as
            specified in :attr:`hparams["rnn_cell"]`.
        output_layer (optional): An instance of :torch_nn:`Module`. Applies
            to the RNN cell output of each step. If `None` (default), the
            output layer is created as specified in
            :attr:`hparams["output_layer"]`.
        hparams (dict or HParams, optional): Hyperparameters. Missing
            hyperparameters will be set to default values. See
            :meth:`default_hparams` for the hyperparameter structure and
            default values.
    """

    def __init__(self,
                 input_size: int,
                 cell: Optional[RNNCellBase[State]] = None,
                 output_layer: Optional[nn.Module] = None,
                 hparams=None):
        super().__init__(hparams=hparams)

        # Create the underlying encoder
        encoder_hparams = dict_fetch(
            hparams, UnidirectionalRNNEncoder.default_hparams())

        self._encoder = UnidirectionalRNNEncoder(
            input_size=input_size,
            cell=cell,
            output_layer=output_layer,
            hparams=encoder_hparams)

        # Create an additional classification layer if needed
        self.num_classes = self._hparams.num_classes
        if self.num_classes <= 0:
            self._logits_layer = None
        else:
            logit_kwargs = self._hparams.logit_layer_kwargs
            if logit_kwargs is None:
                logit_kwargs = {}
            elif not isinstance(logit_kwargs, HParams):
                raise ValueError("hparams['logit_layer_kwargs'] "
                                 "must be a dict.")
            else:
                logit_kwargs = logit_kwargs.todict()

            if self._hparams.clas_strategy == 'all_time':
                self._logits_layer = nn.Linear(
                    self._encoder.output_size *
                    self._hparams.max_seq_length,
                    self.num_classes,
                    **logit_kwargs)
            else:
                self._logits_layer = nn.Linear(
                    self._encoder.output_size, self.num_classes,
                    **logit_kwargs)

        self.is_binary = (self.num_classes == 1) or \
                         (self.num_classes <= 0 and
                          self._encoder.output_size == 1)
    @staticmethod
    def default_hparams():
        r"""Returns a dictionary of hyperparameters with default values.

        .. code-block:: python

            {
                # (1) Same hyperparameters as in UnidirectionalRNNEncoder
                ...
                # (2) Additional hyperparameters
                "num_classes": 2,
                "logit_layer_kwargs": None,
                "clas_strategy": "final_time",
                "max_seq_length": None,
                "name": "unidirectional_rnn_classifier"
            }

        Here:

        1. Same hyperparameters as in
           :class:`~texar.torch.modules.UnidirectionalRNNEncoder`. See
           :meth:`~texar.torch.modules.UnidirectionalRNNEncoder.default_hparams`.
           An instance of UnidirectionalRNNEncoder is created for feature
           extraction.

        2. Additional hyperparameters:

           `"num_classes"`: int
               Number of classes:

               - If **> 0**, an additional :torch_nn:`Linear` layer is
                 appended to the encoder to compute the logits over
                 classes.
               - If **<= 0**, no dense layer is appended. The number of
                 classes is assumed to be the final dense layer size of
                 the encoder.

           `"logit_layer_kwargs"`: dict
               Keyword arguments for the logit :torch_nn:`Linear` layer
               constructor, except for argument ``out_features`` which is
               set to ``num_classes``. Ignored if no extra logit layer is
               appended.

           `"clas_strategy"`: str
               The classification strategy, one of:

               - **final_time**: Sequence-level classification based on
                 the output of the final time step. Each sequence has a
                 class.
               - **all_time**: Sequence-level classification based on the
                 outputs of all time steps. Each sequence has a class.
               - **time_wise**: Step-wise classification, i.e., make a
                 classification for each time step based on its output.

           `"max_seq_length"`: int, optional
               Maximum possible length of input sequences. Required if
               `clas_strategy` is `all_time`.

           `"name"`: str
               Name of the classifier.
        """
        hparams = UnidirectionalRNNEncoder.default_hparams()
        hparams.update({
            "num_classes": 2,
            "logit_layer_kwargs": None,
            "clas_strategy": "final_time",
            "max_seq_length": None,
            "name": "unidirectional_rnn_classifier"
        })
        return hparams
    def forward(self,  # type: ignore
                inputs: torch.Tensor,
                sequence_length: Optional[torch.LongTensor] = None,
                initial_state: Optional[State] = None,
                time_major: bool = False) \
            -> Tuple[torch.Tensor, torch.LongTensor]:
        r"""Feeds the inputs through the network and makes classification.

        The arguments are the same as in
        :class:`~texar.torch.modules.UnidirectionalRNNEncoder`.

        Args:
            inputs: A 3D Tensor of shape ``[batch_size, max_time, dim]``.
                The first two dimensions :attr:`batch_size` and
                :attr:`max_time` are exchanged if :attr:`time_major` is
                `True`.
            sequence_length (optional): A 1D :tensor:`LongTensor` of shape
                ``[batch_size]``. Sequence lengths of the batch inputs.
                Used to copy-through state and zero-out outputs when past
                a batch element's sequence length.
            initial_state (optional): Initial state of the RNN.
            time_major (bool): The shape format of the :attr:`inputs` and
                :attr:`outputs` Tensors. If `True`, these tensors are of
                shape ``[max_time, batch_size, depth]``. If `False`
                (default), these tensors are of shape
                ``[batch_size, max_time, depth]``.

        Returns:
            A tuple ``(logits, preds)``, containing the logits over
            classes and the predictions, respectively.

            - If ``clas_strategy`` is ``final_time`` or ``all_time``:

              - If ``num_classes`` == 1, ``logits`` and ``preds`` are
                both of shape ``[batch_size]``.
              - If ``num_classes`` > 1, ``logits`` is of shape
                ``[batch_size, num_classes]`` and ``preds`` is of shape
                ``[batch_size]``.

            - If ``clas_strategy`` is ``time_wise``:

              - If ``num_classes`` == 1, ``logits`` and ``preds`` are
                both of shape ``[batch_size, max_time]``.
              - If ``num_classes`` > 1, ``logits`` is of shape
                ``[batch_size, max_time, num_classes]`` and ``preds`` is
                of shape ``[batch_size, max_time]``.
              - If ``time_major`` is `True`, the batch and time
                dimensions are exchanged.
        """
        enc_outputs, _ = self._encoder(inputs=inputs,
                                       sequence_length=sequence_length,
                                       initial_state=initial_state,
                                       time_major=time_major)

        # Compute logits
        strategy = self._hparams.clas_strategy
        if strategy == 'time_wise':
            logits = enc_outputs
        elif strategy == 'final_time':
            if time_major:
                logits = enc_outputs[-1, :, :]
            else:
                logits = enc_outputs[:, -1, :]
        elif strategy == 'all_time':
            logit_input_dim = (self._encoder.output_size *
                               self._hparams.max_seq_length)
            if time_major:
                # Zero-pad the time dimension (dim 0) up to
                # max_seq_length, then move the batch dimension to the
                # front before flattening all time steps of each example
                # into a single feature vector.
                length_diff = self._hparams.max_seq_length - inputs.shape[0]
                logit_input = F.pad(enc_outputs,
                                    [0, 0, 0, 0, 0, length_diff])
                logits = logit_input.transpose(0, 1).contiguous().view(
                    -1, logit_input_dim)
            else:
                # Zero-pad the time dimension (dim 1) up to
                # max_seq_length, then flatten all time steps of each
                # example into a single feature vector.
                length_diff = self._hparams.max_seq_length - inputs.shape[1]
                logit_input = F.pad(enc_outputs,
                                    [0, 0, 0, length_diff, 0, 0])
                logits = logit_input.view(-1, logit_input_dim)
        else:
            raise ValueError('Unknown classification strategy: {}'.format(
                strategy))

        if self._logits_layer is not None:
            logits = self._logits_layer(logits)

        # Compute predictions
        if strategy == "time_wise":
            if self.is_binary:
                logits = torch.squeeze(logits, -1)
                preds = (logits > 0).long()
            else:
                preds = torch.argmax(logits, dim=-1)
        else:
            if self.is_binary:
                preds = (logits > 0).long()
                logits = torch.flatten(logits)
            else:
                preds = torch.argmax(logits, dim=-1)
            preds = torch.flatten(preds)

        return logits, preds
    @property
    def output_size(self) -> int:
        r"""The feature size of :meth:`forward` output :attr:`logits`.
        If the :attr:`logits` size is only determined by the input
        (i.e. if ``num_classes`` == 1), the feature size is equal to
        ``-1``. Otherwise, it is equal to the last dimension of the
        :attr:`logits` size.
        """
        if self._hparams.num_classes == 1:
            logit_dim = -1
        elif self._hparams.num_classes > 1:
            logit_dim = self._hparams.num_classes
        elif self._hparams.clas_strategy == 'all_time':
            logit_dim = (self._encoder.output_size *
                         self._hparams.max_seq_length)
        elif self._hparams.clas_strategy == 'final_time':
            logit_dim = self._encoder.output_size
        elif self._hparams.clas_strategy == 'time_wise':
            # Logits are the raw per-step encoder outputs.
            logit_dim = self._encoder.output_size
        return logit_dim
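
# A minimal usage sketch (editorial example, not part of the original
# module): sequence-level classification with the default "final_time"
# strategy. The sizes (feature dim 100, batch 64, 16 steps, 10 classes)
# are arbitrary illustration values; the encoder uses its default cell.
if __name__ == "__main__":
    clf = UnidirectionalRNNClassifier(
        input_size=100,
        hparams={"num_classes": 10, "clas_strategy": "final_time"})
    inputs = torch.randn(64, 16, 100)  # [batch_size, max_time, dim]
    logits, preds = clf(inputs)
    print(logits.shape)  # torch.Size([64, 10])
    print(preds.shape)   # torch.Size([64])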
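
# Step-wise classification sketch (editorial example): with
# "clas_strategy" set to "time_wise", a class is predicted at every
# time step, so both logits and preds keep the time dimension.
if __name__ == "__main__":
    clf = UnidirectionalRNNClassifier(
        input_size=100,
        hparams={"num_classes": 10, "clas_strategy": "time_wise"})
    logits, preds = clf(torch.randn(64, 16, 100))
    print(logits.shape)  # torch.Size([64, 16, 10])
    print(preds.shape)   # torch.Size([64, 16])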
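
# "all_time" sketch (editorial example): the outputs of all time steps
# are flattened into one feature vector per sequence, so
# "max_seq_length" must be given; shorter inputs are zero-padded up to
# that length before the logits layer.
if __name__ == "__main__":
    clf = UnidirectionalRNNClassifier(
        input_size=100,
        hparams={"num_classes": 10, "clas_strategy": "all_time",
                 "max_seq_length": 16})
    logits, preds = clf(torch.randn(64, 12, 100))  # 12 <= max_seq_length
    print(logits.shape)  # torch.Size([64, 10])
    print(preds.shape)   # torch.Size([64])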