# Adapted from Tensor2Tensor's implementation.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#         http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Modifications copyright (C) 2019 Texar
# ==============================================================================
"""
Implementation of beam search with penalties.

Adapted from:
    `https://github.com/tensorflow/tensor2tensor/blob/eb048f69c7ea860324122b87cb9caf59c52a27f3/tensor2tensor/utils/beam_search.py`
"""
from typing import Any, Callable, Optional, Tuple, TypeVar, overload

import torch

from texar.torch.utils import map_structure, torch_bool

__all__ = [
    'beam_search',
]

State = TypeVar('State')

# Default value for INF
INF = 1.0 * 1e7


def gather_nd(params: Any, indices: torch.Tensor) -> Any:
    r"""Gathers slices from :attr:`params` according to the 2D coordinates
    in :attr:`indices`, analogous to TensorFlow's ``tf.gather_nd``.

    :attr:`params` has shape `[batch_size, n, ...]` and :attr:`indices` has
    shape `[batch_size, beam_size, 2]`, where the last dimension holds
    `(batch index, candidate index)` pairs. Returns a tensor of shape
    `[batch_size, beam_size, ...]`. Non-tensor values are returned unchanged.
    """
    if not isinstance(params, torch.Tensor):
        return params
    assert len(indices.size()) == 3
    orig_size = params.size()
    # Convert (batch, candidate) coordinate pairs into flat indices over the
    # first dimension of ``params`` reshaped to `[batch_size * n, ...]`.
    index = indices[:, :, 1].view(-1) + indices[:, :, 0].view(-1) * orig_size[1]
    ret = torch.index_select(
        params.view(-1, *params.size()[2:]), dim=0, index=index
    )
    ret = ret.view(orig_size[0], indices.size(1), *orig_size[2:])

    return ret
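

# Illustrative sketch (not part of the original module): gathering beam
# entries with ``gather_nd``. The shapes and index values below are
# arbitrary and chosen only for the example.
def _example_gather_nd():
    params = torch.arange(2 * 3 * 4).view(2, 3, 4)    # [batch=2, n=3, 4]
    batch_pos = torch.tensor([[0, 0], [1, 1]])        # batch coordinates
    candidate_pos = torch.tensor([[2, 0], [1, 2]])    # candidate coordinates
    coordinates = torch.stack([batch_pos, candidate_pos], dim=2)  # [2, 2, 2]
    gathered = gather_nd(params, coordinates)
    assert gathered.size() == (2, 2, 4)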


def _merge_beam_dim(tensor: Any) -> Any:
    r"""Reshapes first two dimensions in to single dimension.

    Args:
        tensor: Tensor to reshape of shape `[A, B, ...]`.

    Returns:
        Reshaped tensor of shape `[A * B, ...]`.
    """
    if not isinstance(tensor, torch.Tensor):
        return tensor
    shape = list(tensor.size())
    shape[0] *= shape[1]  # batch -> batch * beam_size
    shape.pop(1)  # Remove beam dim
    return tensor.view(tuple(shape))


def _unmerge_beam_dim(tensor: Any, batch_size: int,
                      beam_size: int) -> Any:
    r"""Reshapes first dimension back to `[batch_size, beam_size]`.

    Args:
        tensor: Tensor to reshape of shape `[batch_size * beam_size, ...]`.
        batch_size: int, original batch size.
        beam_size: int, original beam size.

    Returns:
        Reshaped tensor of shape `[batch_size, beam_size, ...]`.
    """
    if not isinstance(tensor, torch.Tensor):
        return tensor
    shape = list(tensor.size())
    new_shape = [batch_size] + [beam_size] + shape[1:]
    return tensor.view(tuple(new_shape))
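

# Illustrative sketch (not part of the original module): the two helpers
# above are inverses of each other. The batch and beam sizes are arbitrary.
def _example_merge_unmerge_beam_dim():
    x = torch.zeros(2, 3, 5)                          # [batch, beam, ...]
    merged = _merge_beam_dim(x)                       # [batch * beam, ...]
    restored = _unmerge_beam_dim(merged, batch_size=2, beam_size=3)
    assert merged.size() == (6, 5)
    assert restored.size() == x.size()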


def _expand_to_beam_size(tensor: Any,
                         beam_size: int) -> Any:
    r"""Tiles a given tensor by :attr:`beam_size`.

    Args:
        tensor: tensor to tile. Shape: `[batch_size, ...]`.
        beam_size: Number of times to tile the tensor along the new beam
            dimension.

    Returns:
        Tiled tensor of shape `[batch_size, beam_size, ...]`.
    """
    if not isinstance(tensor, torch.Tensor):
        return tensor
    tensor = torch.unsqueeze(tensor, dim=1)
    tile_dims = [1] * len(tensor.size())
    tile_dims[1] = beam_size

    return tensor.repeat(tuple(tile_dims))
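

# Illustrative sketch (not part of the original module): tiling a batch of
# start ids across a new beam dimension. The values are arbitrary.
def _example_expand_to_beam_size():
    initial_ids = torch.tensor([7, 8])                # [batch_size=2]
    tiled = _expand_to_beam_size(initial_ids, beam_size=3)
    assert tiled.size() == (2, 3)
    assert tiled[0].tolist() == [7, 7, 7]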


def log_prob_from_logits(logits: torch.Tensor) -> torch.Tensor:
    r"""Converts logits to normalized log-probabilities, i.e., a log-softmax
    over the last dimension.
    """
    return logits - torch.logsumexp(logits, dim=-1, keepdim=True)
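

# Illustrative sketch (not part of the original module): the helper above is
# equivalent to a log-softmax over the last dimension.
def _example_log_prob_from_logits():
    logits = torch.randn(2, 5)
    log_probs = log_prob_from_logits(logits)
    assert torch.allclose(log_probs, torch.log_softmax(logits, dim=-1))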


def compute_batch_indices(batch_size: int, beam_size: int) -> torch.LongTensor:
    r"""Computes the i-th coordinate that contains the batch index for
    gathers.

    The batch index tensor is a tensor like `[[0,0,0,0,],[1,1,1,1],..]`.
    It says which batch the beam item is in. This will create the first
    dimension of the 2D coordinates needed for the gather.

    Args:
        batch_size: Batch size
        beam_size: Size of the beam.

    Returns:
        `[batch_size, beam_size]` tensor of ids.
    """
    batch_pos = torch.arange(batch_size)
    batch_pos = batch_pos.view(-1, 1).expand(batch_size, beam_size)
    return batch_pos
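

# Illustrative sketch (not part of the original module): each batch index is
# simply repeated ``beam_size`` times.
def _example_compute_batch_indices():
    batch_pos = compute_batch_indices(batch_size=2, beam_size=3)
    assert batch_pos.tolist() == [[0, 0, 0], [1, 1, 1]]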


def compute_topk_scores_and_seq(
    sequences: torch.LongTensor,
    scores: torch.Tensor,
    scores_to_gather: torch.Tensor,
    flags: torch.ByteTensor,
    beam_size: int,
    batch_size: int,
    states_to_gather: Optional[State] = None,
) -> Tuple[torch.LongTensor, torch.Tensor, torch.ByteTensor, Optional[State]]:
    r"""Given sequences and scores, will gather the top-k (`k = beam`) size
    sequences.

    This function is used to grow alive, and finished. It takes sequences,
    scores, and flags, and returns the top k from sequence
    :attr:`scores_to_gather`, and flags based on the values in scores.

    Args:
        sequences: Tensor of sequences that we need to gather from.
            Shape: `[batch_size, beam_size, seq_length]`.
        scores: Tensor of scores for each sequence in sequences. We will use
            these to compute the top-k. Shape: `[batch_size, beam_size]`.
        scores_to_gather: Tensor of scores for each sequence in sequences.
            Shape: `[batch_size, beam_size]`.
            We will return the gathered scores from here.
            Scores to gather is different from scores because for
            grow_alive, we will need to return log-probabilities, while for
            grow_finished, we will need to return the length penalized
            scores.
        flags: Tensor of booleans indicating whether each sequence has
            reached `EOS`.
        beam_size: int, the beam size.
        batch_size: int, the batch size.
        states_to_gather: (possibly nested structure of) decoding states.

    :returns: Tuple of:

        - `topk_seq`: `[batch_size, beam_size, decode_length]`.
        - `topk_gathered_scores`: `[batch_size, beam_size]`.
        - `topk_finished_flags`: `[batch_size, beam_size]`.
        - `topk_gathered_states`: gathered decoding states with the same
          structure as :attr:`states_to_gather`, or `None` if no states were
          given.
    """
    # by default top-k is for the last dimension
    _, topk_indexes = torch.topk(scores, k=beam_size)
    # The next three steps are to create coordinates for gather_nd to
    # pull out the top-k sequences from sequences based on scores.
    # batch pos is a tensor like [[0,0,0,0,],[1,1,1,1],..]. It says which
    # batch the beam item is in. This will create the i of the i,j
    # coordinate needed for the gather
    batch_pos = compute_batch_indices(batch_size, beam_size)
    batch_pos = batch_pos.to(device=topk_indexes.device)
    # top coordinates will give us the actual coordinates to do the gather.
    # stacking will create a tensor of dimension batch * beam * 2, where
    # the last dimension contains the i,j gathering coordinates.
    top_coordinates = torch.stack([batch_pos, topk_indexes], dim=2)

    # Gather up the highest scoring sequences.
    topk_seq = gather_nd(sequences, top_coordinates)
    topk_flags = gather_nd(flags, top_coordinates)
    topk_gathered_scores = gather_nd(scores_to_gather, top_coordinates)
    if states_to_gather is not None:
        topk_gathered_states = map_structure(
            lambda state: gather_nd(state, top_coordinates), states_to_gather
        )
    else:
        topk_gathered_states = states_to_gather
    return topk_seq, topk_gathered_scores, topk_flags, topk_gathered_states
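

# Illustrative sketch (not part of the original module): selecting the top
# ``beam_size`` candidates out of a larger pool. All tensors below are toy
# values chosen only for the example.
def _example_compute_topk_scores_and_seq():
    batch_size, beam_size = 1, 2
    # Four candidate sequences of length 3 for a single batch entry.
    sequences = torch.tensor([[[1, 2, 3], [4, 5, 6], [7, 8, 9], [1, 1, 1]]])
    scores = torch.tensor([[0.1, 0.9, 0.5, 0.3]])
    flags = torch.tensor([[False, True, False, False]])
    topk_seq, topk_scores, topk_flags, _ = compute_topk_scores_and_seq(
        sequences, scores, scores, flags, beam_size, batch_size)
    # The two highest-scoring candidates (indices 1 and 2) are kept.
    assert topk_seq.tolist() == [[[4, 5, 6], [7, 8, 9]]]
    assert torch.allclose(topk_scores, torch.tensor([[0.9, 0.5]]))
    assert topk_flags.tolist() == [[True, False]]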


# TODO: Remove these once pylint supports function stubs.
# pylint: disable=unused-argument,function-redefined

@overload
def beam_search(
    symbols_to_logits_fn: Callable[[torch.Tensor, State],
                                   Tuple[torch.Tensor, State]],
    initial_ids: torch.LongTensor,
    beam_size: int,
    decode_length: int,
    vocab_size: int,
    alpha: float,
    eos_id: int,
    states: State,
    stop_early: bool = True) -> Tuple[torch.LongTensor, torch.Tensor]: ...


@overload
def beam_search(
    symbols_to_logits_fn: Callable[[torch.Tensor], torch.Tensor],
    initial_ids: torch.LongTensor,
    beam_size: int,
    decode_length: int,
    vocab_size: int,
    alpha: float,
    eos_id: int,
    states: Optional[State] = None,
    stop_early: bool = True) -> Tuple[torch.LongTensor, torch.Tensor]: ...

# pylint: enable=unused-argument




# pylint: enable=function-redefined
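

# Illustrative sketch (not part of the original module): a toy call to
# ``beam_search`` using the stateless overload. The dummy "model" below,
# which assigns the same logit to every token, and all sizes are assumptions
# made only for this example.
def _example_beam_search():
    vocab_size = 10
    batch_size = 2

    def symbols_to_logits_fn(ids: torch.Tensor) -> torch.Tensor:
        # ``ids`` holds the ids decoded so far, flattened over batch and
        # beam; return one logit per vocabulary entry for each row.
        return torch.zeros(ids.size(0), vocab_size)

    initial_ids = torch.zeros(batch_size, dtype=torch.long)
    # ``seqs``: decoded ids per beam; ``log_probs``: their length-penalized
    # scores, one per `[batch, beam]` entry.
    seqs, log_probs = beam_search(
        symbols_to_logits_fn,
        initial_ids,
        beam_size=4,
        decode_length=5,
        vocab_size=vocab_size,
        alpha=0.6,
        eos_id=1)
    return seqs, log_probs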