# Copyright 2019 The Texar Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Various connectors.
"""
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
import numpy as np
import torch
from torch import nn
from torch.distributions.distribution import Distribution
from texar.torch.core import get_activation_fn
from texar.torch.hyperparams import HParams
from texar.torch.modules.connectors.connector_base import ConnectorBase
from texar.torch.utils import nest
from texar.torch.utils import utils
from texar.torch.utils.types import MaybeTuple
__all__ = [
"ConstantConnector",
"ForwardConnector",
"MLPTransformConnector",
"ReparameterizedStochasticConnector",
"StochasticConnector",
# "ConcatConnector"
]
TensorStruct = Union[List[torch.Tensor],
Dict[Any, torch.Tensor],
MaybeTuple[torch.Tensor]]
OutputSize = MaybeTuple[Union[int, torch.Size]]
ActivationFn = Callable[[torch.Tensor], torch.Tensor]
LinearLayer = Callable[[torch.Tensor], torch.Tensor]
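# Illustrative examples of the aliases above (not from the original docs):
# a ``TensorStruct`` may be a single tensor, a list ``[t1, t2]``, a dict
# ``{"h": t}``, or a (nested) tuple such as ``(t1, (t2, t3))``; an
# ``OutputSize`` may be ``256``, ``torch.Size([2, 3])``, or ``(256, 256)``.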
def _assert_same_size(outputs: TensorStruct,
output_size: OutputSize):
r"""Check if outputs match output_size
Args:
outputs: A tensor or a (nested) tuple of tensors
output_size: Can be an ``int``, a ``torch.Size``, or a (nested)
tuple of ``int`` or ``torch.Size``.
"""
flat_output_size = nest.flatten(output_size)
flat_output = nest.flatten(outputs)
for (output, size) in zip(flat_output, flat_output_size):
if isinstance(size, torch.Size):
if output[0].size() != size:
raise ValueError("The output size does not match"
"the required output_size")
elif output[0].size()[-1] != size:
raise ValueError(
"The output size does not match the required output_size")
def _get_sizes(sizes: List[Any]) -> List[int]:
r"""
    Args:
        sizes: A list of ``int`` or ``torch.Size``. Each ``torch.Size``
            element is reduced to the product of its dimensions.

    Returns:
        A list of ``int`` sizes, with every ``torch.Size`` replaced by the
        product of its individual dimensions.
"""
    size_list = [int(np.prod(s)) if isinstance(s, torch.Size) else s
                 for s in sizes]
return size_list
def _sum_output_size(output_size: OutputSize) -> int:
r"""Return sum of all dim values in :attr:`output_size`
Args:
output_size: Can be an ``int``, a ``torch.Size``, or a (nested)
tuple of ``int`` or ``torch.Size``.
"""
flat_output_size = nest.flatten(output_size)
size_list = _get_sizes(flat_output_size)
ret = sum(size_list)
return ret
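# For example (illustrative): with ``output_size = (torch.Size([2, 3]), 5)``,
# the flattened sizes are ``[6, 5]``, so `_sum_output_size` returns ``11``.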
def _mlp_transform(inputs: TensorStruct,
output_size: OutputSize,
linear_layer: Optional[LinearLayer] = None,
activation_fn: Optional[ActivationFn] = None) -> Any:
r"""Transforms inputs through a fully-connected layer that creates
the output with specified size.
Args:
        inputs: A tensor of shape ``[batch_size, d1, ..., dn]``, or a
            (nested) tuple of such tensors. All dimensions except the last
            are flattened into the batch dimension; the last dimension is
            transformed by a dense layer.
        output_size: Can be an ``int``, a ``torch.Size``, or a (nested)
            tuple of ``int`` or ``torch.Size``.
        linear_layer (optional): The dense layer applied to the flattened
            inputs. If `None`, no linear transformation is performed.
        activation_fn (optional): Activation function applied to the output.
:returns:
If :attr:`output_size` is an ``int`` or a ``torch.Size``,
returns a tensor of shape ``[batch_size, *, output_size]``.
If :attr:`output_size` is a tuple of ``int`` or ``torch.Size``,
returns a tuple having the same structure as :attr:`output_size`,
where each element has the same size as defined in :attr:`output_size`.
"""
# Flatten inputs
flat_input = nest.flatten(inputs)
flat_input = [x.view(-1, x.size(-1)) for x in flat_input]
concat_input = torch.cat(flat_input, 1)
# Get output dimension
flat_output_size = nest.flatten(output_size)
size_list = _get_sizes(flat_output_size)
fc_output = concat_input
if linear_layer is not None:
fc_output = linear_layer(fc_output)
if activation_fn is not None:
fc_output = activation_fn(fc_output)
flat_output = torch.split(fc_output, size_list, dim=1)
flat_output = list(flat_output)
    flat_output = [
        torch.reshape(output, (-1,) + tuple(shape))
        if isinstance(shape, torch.Size) else output
        for output, shape in zip(flat_output, flat_output_size)]
output = nest.pack_sequence_as(structure=output_size,
flat_sequence=flat_output)
return output
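# A minimal sketch of `_mlp_transform` (the layer and shapes below are
# assumptions, not from the original docs):
#   layer = nn.Linear(6, 4 + 3)  # in-features 6; out-features sum to 7
#   out = _mlp_transform(torch.randn(16, 6),
#                        (torch.Size([2, 2]), torch.Size([3])), layer)
#   # out[0].size() == torch.Size([16, 2, 2])
#   # out[1].size() == torch.Size([16, 3])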
class ConstantConnector(ConnectorBase):
r"""Creates a constant tensor or (nested) tuple of Tensors that
contains a constant value.
Args:
output_size: Size of output **excluding** the batch dimension. For
example, set :attr:`output_size` to ``dim`` to generate output of
shape ``[batch_size, dim]``.
Can be an ``int``, a tuple of ``int``, a ``torch.Size``,
or a tuple of ``torch.Size``.
For example, to transform inputs to have decoder state size, set
:python:`output_size=decoder.state_size`.
            If :attr:`output_size` is a tuple ``(1, 2, 3)``, then the
            output structure will be
            ``([batch_size, 1], [batch_size, 2], [batch_size, 3])``.
If :attr:`output_size` is ``torch.Size([1, 2, 3])``, then the
output structure will be ``[batch_size, 1, 2, 3]``.
hparams (dict, optional): Hyperparameters. Missing
hyperparameter will be set to default values. See
:meth:`default_hparams` for the hyperparameter structure and
default values.
This connector does not have trainable parameters.
Example:
.. code-block:: python
state_size = (1, 2, 3)
connector = ConstantConnector(state_size, hparams={"value": 1.})
one_state = connector(batch_size=64)
# `one_state` structure: (Tensor_1, Tensor_2, Tensor_3),
# Tensor_1.size() == torch.Size([64, 1])
# Tensor_2.size() == torch.Size([64, 2])
# Tensor_3.size() == torch.Size([64, 3])
# Tensors are filled with 1.0.
size = torch.Size([1, 2, 3])
connector_size = ConstantConnector(size, hparams={"value": 2.})
size_state = connector_size(batch_size=64)
# `size_state` structure: Tensor with size [64, 1, 2, 3].
# Tensor is filled with 2.0.
"""
def __init__(self,
output_size: OutputSize,
hparams: Optional[HParams] = None):
super().__init__(output_size, hparams=hparams)
self.value = self.hparams.value
    @staticmethod
def default_hparams() -> dict:
r"""Returns a dictionary of hyperparameters with default values.
.. code-block:: python
{
"value": 0.,
"name": "constant_connector"
}
Here:
`"value"`: float
The constant scalar that the output tensor(s) has.
`"name"`: str
Name of the connector.
"""
return {
"value": 0.,
"name": "constant_connector"
}
    def forward(self,  # type: ignore
batch_size: Union[int, torch.Tensor]) -> Any:
r"""Creates output tensor(s) that has the given value.
Args:
batch_size: An ``int`` or ``int`` scalar tensor, the
batch size.
:returns:
A (structure of) tensor whose structure is the same as
:attr:`output_size`, with value specified by
``value`` or :attr:`hparams`.
"""
def full_tensor(x):
if isinstance(x, torch.Size):
return torch.full((batch_size,) + x, self.value)
else:
return torch.full((batch_size, x), self.value)
output = utils.map_structure(
full_tensor,
self._output_size)
return output
class ForwardConnector(ConnectorBase):
r"""Transforms inputs to have specified structure.
Example:
.. code-block:: python
            state_size = namedtuple('LSTMStateTuple', ['h', 'c'])(256, 256)
            # state_size == LSTMStateTuple(h=256, c=256)

            connector = ForwardConnector(state_size)
            output = connector([tensor_1, tensor_2])
            # output == LSTMStateTuple(h=tensor_1, c=tensor_2)
Args:
output_size: Size of output **excluding** the batch dimension. For
example, set :attr:`output_size` to ``dim`` to generate output of
shape ``[batch_size, dim]``.
Can be an ``int``, a tuple of ``int``, a ``torch.Size``, or a
tuple of ``torch.Size``.
For example, to transform inputs to have decoder state size, set
:python:`output_size=decoder.state_size`.
hparams (dict, optional): Hyperparameters. Missing
hyperparameter will be set to default values. See
:meth:`default_hparams` for the hyperparameter structure and
default values.
This connector does not have trainable parameters.
See :meth:`forward` for the inputs and outputs of the connector.
    The input to the connector must have the same structure as
    :attr:`output_size`, or must have the same number of elements and be
    re-packable into the structure of :attr:`output_size`. Note that if the
    input is or contains a ``dict`` instance, its keys will be sorted to
    pack in deterministic order (see
    :func:`~texar.torch.utils.nest.pack_sequence_as`).
"""
def __init__(self,
output_size: OutputSize,
hparams: Optional[HParams] = None):
super().__init__(output_size, hparams=hparams)
    @staticmethod
def default_hparams() -> dict:
r"""Returns a dictionary of hyperparameters with default values.
.. code-block:: python
{
"name": "forward_connector"
}
Here:
`"name"`: str
Name of the connector.
"""
return {
"name": "forward_connector"
}
    def forward(self,  # type: ignore
inputs: TensorStruct) -> Any:
r"""Transforms inputs to have the same structure as with
:attr:`output_size`. Values of the inputs are not changed.
:attr:`inputs` must either have the same structure, or have the same
number of elements with :attr:`output_size`.
Args:
inputs: The input (structure of) tensor to pass forward.
:returns:
A (structure of) tensors that re-packs :attr:`inputs` to have
the specified structure of :attr:`output_size`.
"""
flat_input = nest.flatten(inputs)
output = nest.pack_sequence_as(
self._output_size, flat_input)
return output
class ReparameterizedStochasticConnector(ConnectorBase):
r"""Samples from a distribution with reparameterization trick, and
transforms samples into specified size.
Reparameterization allows gradients to be back-propagated through the
stochastic samples. Used in, e.g., Variational Autoencoders (VAEs).
Example:
.. code-block:: python
# Initialized without num_samples
cell = LSTMCell(num_units=256)
# cell.state_size == LSTMStateTuple(c=256, h=256)
mu = torch.zeros([16, 100])
var = torch.ones([100])
connector = ReparameterizedStochasticConnector(
cell.state_size,
mlp_input_size=mu.size()[-1],
distribution="MultivariateNormal",
distribution_kwargs={
"loc": mu,
"scale_tril": torch.diag(var)})
output, sample = connector()
# output == LSTMStateTuple(c=tensor_of_shape_(16, 256),
# h=tensor_of_shape_(16, 256))
# sample == Tensor([16, 100])
output_, sample_ = connector(num_samples=4)
# output_ == LSTMStateTuple(c=tensor_of_shape_(4, 16, 256),
# h=tensor_of_shape_(4, 16, 256))
            # sample_ == Tensor([4, 16, 100])
Args:
output_size: Size of output **excluding** the batch dimension. For
example, set ``output_size`` to ``dim`` to generate output of
shape ``[batch_size, dim]``.
Can be an ``int``, a tuple of ``int``, a ``torch.Size``, or
a tuple of ``torch.Size``.
For example, to transform inputs to have decoder state size, set
:python:`output_size=decoder.state_size`.
        mlp_input_size: Size of the input to the MLP transformation, i.e.,
            the size of a distribution sample **excluding** the batch
            dimension. Can be an ``int``, a ``torch.Size``, or a tuple of
            ``int``.
        distribution: The distribution to sample from. Can be a class
            instance of a subclass of
            :torch:`distributions.distribution.Distribution`, or the name
            (``str``) of such a subclass.
distribution_kwargs (dict, optional): ``dict`` of keyword arguments
for the :attr:`distribution`. Its keys are `str`, which are names
of keyword arguments; Its values are corresponding values for each
argument.
hparams (dict, optional): Hyperparameters. Missing
hyperparameter will be set to default values. See
:meth:`default_hparams` for the hyperparameter structure and
default values.
"""
def __init__(self,
output_size: OutputSize,
mlp_input_size: Union[torch.Size, MaybeTuple[int], int],
distribution: Union[Distribution, str] = 'MultivariateNormal',
distribution_kwargs: Optional[Dict[str, Any]] = None,
hparams: Optional[HParams] = None):
super().__init__(output_size, hparams=hparams)
if distribution_kwargs is None:
distribution_kwargs = {}
self._dstr_type = distribution
self._dstr_kwargs = distribution_kwargs
for dstr_attr, dstr_val in distribution_kwargs.items():
if isinstance(dstr_val, torch.Tensor):
dstr_param = nn.Parameter(dstr_val)
distribution_kwargs[dstr_attr] = dstr_param
self.register_parameter(dstr_attr, dstr_param)
if isinstance(mlp_input_size, int):
input_feature = mlp_input_size
else:
input_feature = np.prod(mlp_input_size)
self._linear_layer = nn.Linear(
input_feature, _sum_output_size(output_size))
self._activation_fn = get_activation_fn(
self.hparams.activation_fn)
@staticmethod
def default_hparams() -> dict:
r"""Returns a dictionary of hyperparameters with default values.
.. code-block:: python
{
"activation_fn": "texar.torch.core.layers.identity",
"name": "reparameterized_stochastic_connector"
}
Here:
`"activation_fn"`: str
The activation function applied to the outputs of the MLP
transformation layer. Can be a function, or its name or module path.
`"name"`: str
Name of the connector.
"""
return {
"activation_fn": "texar.torch.core.layers.identity",
"name": "reparameterized_stochastic_connector"
}
def forward(self, # type: ignore
num_samples: Optional[Union[int, torch.Tensor]] = None,
transform: bool = True) -> Tuple[Any, Any]:
r"""Samples from a distribution and optionally performs transformation
with an MLP layer.
The distribution must be reparameterizable, i.e.,
:python:`Distribution.has_rsample == True`.
Args:
num_samples (optional): An ``int`` or ``int`` tensor.
Number of samples to generate. If not given,
generate a single sample. Note that if batch size has
already been included in :attr:`distribution`'s dimensionality,
:attr:`num_samples` should be left as ``None``.
transform (bool): Whether to perform MLP transformation of the
distribution samples. If ``False``, the structure/shape of a
sample must match :attr:`output_size`.
        :returns:
            A tuple (:attr:`output`, :attr:`sample`), where

            - output: A tensor or a (nested) tuple of tensors with the same
              structure and size as :attr:`output_size`. The batch dimension
              equals :attr:`num_samples` if specified, or is determined by
              the distribution dimensionality. If :attr:`transform` is
              `False`, it is equal to :attr:`sample`.
            - sample: The sample from the distribution, prior to
              transformation.
Raises:
ValueError: If distribution is not reparameterizable.
ValueError: The output does not match :attr:`output_size`.
"""
if isinstance(self._dstr_type, str):
dstr: Distribution = utils.check_or_get_instance(
self._dstr_type, self._dstr_kwargs,
["torch.distributions", "texar.torch.custom"])
else:
dstr = self._dstr_type
if not dstr.has_rsample:
raise ValueError("Distribution should be reparameterizable")
if num_samples:
sample = dstr.rsample([num_samples])
else:
sample = dstr.rsample()
if transform:
output = _mlp_transform(
sample,
self._output_size,
self._linear_layer,
self._activation_fn)
_assert_same_size(output, self._output_size)
else:
output = sample
return output, sample
class StochasticConnector(ConnectorBase):
r"""Samples from a distribution and transforms samples into specified size.
The connector is the same as
:class:`~texar.torch.modules.ReparameterizedStochasticConnector`, except
that here reparameterization is disabled, and thus the gradients cannot be
back-propagated through the stochastic samples.
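    Example:
        .. code-block:: python

            # A minimal sketch; the distribution and shapes below are
            # illustrative assumptions, not from the original docs.
            # `OneHotCategorical` is not reparameterizable, so gradients
            # do not flow through the samples.
            connector = StochasticConnector(
                output_size=256,
                mlp_input_size=100,
                distribution="OneHotCategorical",
                distribution_kwargs={"probs": torch.ones([16, 100]) / 100})
            output, sample = connector(transform=True)
            # output.size() == torch.Size([16, 256])
            # sample.size() == torch.Size([16, 100])

            # Drawing multiple samples (transform left as False, the
            # default, so the samples are returned unchanged):
            output_, sample_ = connector(num_samples=4)
            # sample_.size() == torch.Size([4, 16, 100])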
Args:
        output_size: Size of output **excluding** the batch dimension. For
            example, set ``output_size`` to ``dim`` to generate output of
            shape ``[batch_size, dim]``.
            Can be an ``int``, a tuple of ``int``, a ``torch.Size``, or a
            tuple of ``torch.Size``.
            For example, to transform inputs to have decoder state size, set
            :python:`output_size=decoder.state_size`.
        mlp_input_size: Size of the input to the MLP transformation, i.e.,
            the size of a distribution sample **excluding** the batch
            dimension. Can be an ``int``, a ``torch.Size``, or a tuple of
            ``int``.
        distribution: The distribution to sample from. Can be a class
            instance of a subclass of
            :torch:`distributions.distribution.Distribution`, or the name
            or module path (``str``) of such a subclass.
            The distribution must **not** be reparameterizable.
distribution_kwargs (dict, optional): ``dict`` of keyword arguments
for the :attr:`distribution`. Its keys are `str`, which are names
of keyword arguments; Its values are corresponding values for each
argument.
hparams (dict, optional): Hyperparameters. Missing
hyperparameter will be set to default values. See
:meth:`default_hparams` for the hyperparameter structure and
default values.
"""
def __init__(self,
output_size: OutputSize,
mlp_input_size: Union[torch.Size, MaybeTuple[int], int],
distribution: Union[Distribution, str] = 'MultivariateNormal',
distribution_kwargs: Optional[Dict[str, Any]] = None,
hparams: Optional[HParams] = None):
super().__init__(output_size, hparams=hparams)
if distribution_kwargs is None:
distribution_kwargs = {}
self._dstr_kwargs = distribution_kwargs
if isinstance(distribution, str):
self._dstr: Distribution = utils.check_or_get_instance(
distribution, self._dstr_kwargs,
["torch.distributions", "texar.torch.custom"])
else:
self._dstr = distribution
if self._dstr.has_rsample:
raise ValueError("Distribution should not be reparameterizable")
if isinstance(mlp_input_size, int):
input_feature = mlp_input_size
else:
input_feature = np.prod(mlp_input_size)
self._linear_layer = nn.Linear(
input_feature, _sum_output_size(output_size))
self._activation_fn = get_activation_fn(
self.hparams.activation_fn)
@staticmethod
def default_hparams():
r"""Returns a dictionary of hyperparameters with default values.
.. code-block:: python
{
"activation_fn": "texar.torch.core.layers.identity",
"name": "stochastic_connector"
}
Here:
`"activation_fn"`: str
The activation function applied to the outputs of the MLP
transformation layer. Can
be a function, or its name or module path.
`"name"`: str
Name of the connector.
"""
return {
"activation_fn": "texar.torch.core.layers.identity",
"name": "stochastic_connector"
}
def forward(self, # type: ignore
num_samples: Optional[Union[int, torch.Tensor]] = None,
                transform: bool = False) -> Tuple[Any, Any]:
r"""Samples from a distribution and optionally performs transformation
with an MLP layer.
        The inputs and outputs are the same as those of
        :class:`~texar.torch.modules.ReparameterizedStochasticConnector`,
        except that the distribution does not need to be reparameterizable,
        and gradients cannot be back-propagated through the samples.
Args:
num_samples (optional): An ``int`` or ``int`` tensor.
Number of samples to generate. If not given,
generate a single sample. Note that if batch size has
already been included in :attr:`distribution`'s dimensionality,
:attr:`num_samples` should be left as ``None``.
transform (bool): Whether to perform MLP transformation of the
distribution samples. If ``False``, the structure/shape of a
sample must match :attr:`output_size`.
:returns:
            A tuple (:attr:`output`, :attr:`sample`), where

            - output: A tensor or a (nested) tuple of tensors with the same
              structure and size as :attr:`output_size`. The batch dimension
              equals :attr:`num_samples` if specified, or is determined by
              the distribution dimensionality. If :attr:`transform` is
              `False`, it is equal to :attr:`sample`.
            - sample: The sample from the distribution, prior to
              transformation.
Raises:
            ValueError: If the distribution is reparameterizable.
ValueError: The output does not match :attr:`output_size`.
"""
if num_samples:
sample = self._dstr.sample([num_samples])
else:
sample = self._dstr.sample()
        if self._dstr.event_shape == torch.Size([]):
            # Give scalar samples an explicit trailing feature dimension.
            # (Comparing against a plain list `[]` was always False here,
            # because `torch.Size` is a tuple.)
            sample = torch.reshape(
                input=sample, shape=sample.size() + torch.Size([1]))
# Disable gradients through samples
sample = sample.detach().float()
if transform:
output = _mlp_transform(
sample,
self._output_size,
self._linear_layer,
self._activation_fn)
_assert_same_size(output, self._output_size)
else:
output = sample
return output, sample
# class ConcatConnector(ConnectorBase):
# r"""Concatenates multiple connectors into one connector. Used in, e.g.,
# semi-supervised variational autoencoders, disentangled representation
# learning, and other models.
#
# Args:
# output_size: Size of output excluding the batch dimension (eg.
# :attr:`output_size = p` if :attr:`output.shape` is :attr:`[N, p]`).
# Can be an int, a tuple of int, a torch.Size, or a tuple of
# torch.Size.
# For example, to transform to decoder state size, set
# `output_size=decoder.cell.state_size`.
# hparams (dict): Hyperparameters of the connector.
# """
#
# def __init__(self, output_size, hparams=None):
# super().__init__(self, output_size, hparams)
#
# @staticmethod
# def default_hparams():
# r"""Returns a dictionary of hyperparameters with default values.
#
# Returns:
#
# .. code-block:: python
#
# {
# "activation_fn": "texar.torch.core.layers.identity",
# "name": "concat_connector"
# }
#
# Here:
#
# `"activation_fn"`: (str or callable)
# The name or full path to the activation function applied to
# the outputs of the MLP layer. The activation functions can be:
#
# - Built-in activation functions defined in :
# - User-defined activation functions in `texar.torch.custom`.
# - External activation functions. Must provide the full path, \
# e.g., "my_module.my_activation_fn".
#
# The default value is :attr:`"identity"`, i.e., the MLP
# transformation is linear.
#
# `"name"`: str
# Name of the connector.
#
# The default value is "concat_connector".
# """
# return {
# "activation_fn": "texar.torch.core.layers.identity",
# "name": "concat_connector"
# }
#
# def forward(self, connector_inputs, transform=True):
# r"""Concatenate multiple input connectors
#
# Args:
# connector_inputs: a list of connector states
# transform (bool): If `True`, then the output are automatically
# transformed to match :attr:`output_size`.
#
# Returns:
# A Tensor or a (nested) tuple of Tensors of the same structure of
# the decoder state.
# """
# connector_inputs = [connector.float()
# for connector in connector_inputs]
# output = torch.cat(connector_inputs, dim=1)
#
# if transform:
# fn_modules = ['texar.torch.custom', 'torch', 'torch.nn']
# activation_fn = utils.get_function(self.hparams.activation_fn,
# fn_modules)
# output, linear_layer = _mlp_transform(
# output, self._output_size, activation_fn)
# self._linear_layers.append(linear_layer)
# _assert_same_size(output, self._output_size)
#
# self._add_internal_trainable_variables()
# self._built = True
#
# return output