# Source code for texar.torch.data.tokenizers.xlnet_tokenizer

# Copyright 2019 The Texar Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Pre-trained XLNet Tokenizer.

Code structure adapted from:
    `https://github.com/huggingface/pytorch-transformers/blob/master/pytorch_transformers/tokenization_xlnet.py`
"""

from typing import Any, Dict, List, Optional, Tuple

import os
import unicodedata
from shutil import copyfile
import sentencepiece as spm

from texar.torch.modules.pretrained.xlnet import PretrainedXLNetMixin
from texar.torch.data.tokenizers.tokenizer_base import TokenizerBase
from texar.torch.utils.utils import truncate_seq_pair

# Public API of this module: only the tokenizer class is exported by
# ``from ... import *``; the module-level constants below are not listed.
__all__ = [
    "XLNetTokenizer",
]

# Marker that SentencePiece prefixes to pieces beginning a new word
# (U+2581 LOWER ONE EIGHTH BLOCK). ``map_token_to_text`` turns it back
# into a space.
SPIECE_UNDERLINE = '\u2581'

# Segment ids emitted by ``XLNetTokenizer.encode_text``:
#   SEG_ID_A   -> tokens of the first text (and the separator after it)
#   SEG_ID_B   -> tokens of the second text (and the separator after it)
#   SEG_ID_CLS -> the trailing classification token
#   SEG_ID_SEP -> defined for completeness; not used by encode_text here
#   SEG_ID_PAD -> left padding positions
(SEG_ID_A, SEG_ID_B, SEG_ID_CLS, SEG_ID_SEP, SEG_ID_PAD) = range(5)


class XLNetTokenizer(PretrainedXLNetMixin, TokenizerBase):
    r"""Pre-trained XLNet Tokenizer.

    Args:
        pretrained_model_name (optional): a `str`, the name of
            pre-trained model (e.g., `xlnet-base-cased`). Please refer to
            :class:`~texar.torch.modules.PretrainedXLNetMixin` for
            all supported models.
            If None, the model name in :attr:`hparams` is used.
        cache_dir (optional): the path to a folder in which the
            pre-trained models will be cached. If `None` (default),
            a default directory (``texar_data`` folder under user's home
            directory) will be used.
        hparams (dict or HParams, optional): Hyperparameters. Missing
            hyperparameter will be set to default values. See
            :meth:`default_hparams` for the hyperparameter structure
            and default values.
    """

    _IS_PRETRAINED = True
    # Per-model maximum input length; `None` means "no model-specific
    # limit", in which case `max_len` is left as configured elsewhere.
    _MAX_INPUT_SIZE = {
        'xlnet-base-cased': None,
        'xlnet-large-cased': None,
    }
    # File name used by `save_vocab` for the copied vocabulary.
    _VOCAB_FILE_NAMES = {'vocab_file': 'spiece.model'}
    # Vocabulary file name inside each pre-trained model directory.
    _VOCAB_FILE_MAP = {
        'vocab_file': {
            'xlnet-base-cased': 'spiece.model',
            'xlnet-large-cased': 'spiece.model',
        }
    }

    def __init__(self,
                 pretrained_model_name: Optional[str] = None,
                 cache_dir: Optional[str] = None,
                 hparams=None):
        self.load_pretrained_config(pretrained_model_name, cache_dir, hparams)

        super().__init__(hparams=None)

        # Type hint only (no runtime effect); presumably lets type checkers
        # accept the `self.__dict__` assignment in `__setstate__`.
        self.__dict__: Dict

        self.config = {
            'do_lower_case': self.hparams['do_lower_case'],
            'remove_space': self.hparams['remove_space'],
            'keep_accents': self.hparams['keep_accents'],
        }

        if self.pretrained_model_dir is not None:
            # A pre-trained model was resolved: take the vocabulary from its
            # directory and the model-specific length limit, if any.
            assert self.pretrained_model_name is not None
            vocab_file = os.path.join(
                self.pretrained_model_dir,
                self._VOCAB_FILE_MAP['vocab_file'][self.pretrained_model_name])
            if self._MAX_INPUT_SIZE.get(self.pretrained_model_name):
                self.max_len = self._MAX_INPUT_SIZE[self.pretrained_model_name]
        else:
            # No pre-trained model: everything comes from `hparams`.
            vocab_file = self.hparams['vocab_file']
            if self.hparams.get('max_len'):
                self.max_len = self.hparams['max_len']

        # Check for `None` explicitly: `os.path.isfile(None)` raises a
        # TypeError rather than returning False.
        if vocab_file is None or not os.path.isfile(vocab_file):
            raise ValueError("Can't find a vocabulary file at path "
                             "'{}'".format(vocab_file))

        self.do_lower_case = self.hparams["do_lower_case"]
        self.remove_space = self.hparams["remove_space"]
        self.keep_accents = self.hparams["keep_accents"]
        self.vocab_file = vocab_file

        self.sp_model = spm.SentencePieceProcessor()
        self.sp_model.Load(self.vocab_file)

    # spm.SentencePieceProcessor() is a SwigPyObject object which cannot be
    # pickled. We need to define __getstate__ here.
    def __getstate__(self):
        # Drop the unpicklable processor and ship the vocab path separately
        # so `__setstate__` can reload the model from it.
        state = self.__dict__.copy()
        state["sp_model"] = None
        state["vocab_file"] = None
        return state, self.vocab_file

    # spm.SentencePieceProcessor() is a SwigPyObject object which cannot be
    # pickled. We need to define __setstate__ here.
    def __setstate__(self, d):
        # `d` is the `(state, vocab_file)` pair produced by `__getstate__`.
        self.__dict__, self.vocab_file = d
        self.sp_model = spm.SentencePieceProcessor()
        self.sp_model.Load(self.vocab_file)

    def _preprocess_text(self, inputs: str) -> str:
        r"""Pre-process the text, including removing space,
        stripping accents, and lower-casing the text.
        """
        if self.remove_space:
            # Strip the ends and collapse internal whitespace runs.
            outputs = ' '.join(inputs.strip().split())
        else:
            outputs = inputs
        # Normalize LaTeX-style quotes to a plain double quote.
        outputs = outputs.replace("``", '"').replace("''", '"')

        if not self.keep_accents:
            # Decompose characters (NFKD) and drop the combining marks.
            outputs = unicodedata.normalize('NFKD', outputs)
            outputs = ''.join([c for c in outputs
                               if not unicodedata.combining(c)])
        if self.do_lower_case:
            outputs = outputs.lower()

        return outputs

    def _map_text_to_token(self, text: str,  # type: ignore
                           sample: bool = False) -> List[str]:
        r"""Tokenize `text` into SentencePiece pieces.

        If `sample` is True, pieces are drawn with
        `SampleEncodeAsPieces(text, 64, 0.1)` instead of the deterministic
        `EncodeAsPieces`.
        """
        text = self._preprocess_text(text)
        if not sample:
            pieces = self.sp_model.EncodeAsPieces(text)
        else:
            pieces = self.sp_model.SampleEncodeAsPieces(text, 64, 0.1)

        new_pieces: List[str] = []
        for piece in pieces:
            # A piece like "123," (digit followed by a trailing comma) is
            # re-encoded without the comma, and the comma is appended as its
            # own piece.
            if len(piece) > 1 and piece[-1] == ',' and piece[-2].isdigit():
                cur_pieces = self.sp_model.EncodeAsPieces(
                    piece[:-1].replace(SPIECE_UNDERLINE, ''))
                # Re-encoding may add a word-start marker the original piece
                # did not have; remove it in that case.
                if piece[0] != SPIECE_UNDERLINE and \
                        cur_pieces[0][0] == SPIECE_UNDERLINE:
                    if len(cur_pieces[0]) == 1:
                        cur_pieces = cur_pieces[1:]
                    else:
                        cur_pieces[0] = cur_pieces[0][1:]
                cur_pieces.append(piece[-1])
                new_pieces.extend(cur_pieces)
            else:
                new_pieces.append(piece)

        return new_pieces

    def save_vocab(self, save_dir: str) -> Tuple[str]:
        r"""Save the sentencepiece vocabulary (copy original file) to
        a directory.

        Raises:
            ValueError: if `save_dir` is not an existing directory.
        """
        if not os.path.isdir(save_dir):
            raise ValueError("Vocabulary path ({}) should be a "
                             "directory".format(save_dir))
        out_vocab_file = os.path.join(save_dir,
                                      self._VOCAB_FILE_NAMES['vocab_file'])

        # Skip the copy when source and destination are the same file.
        if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file):
            copyfile(self.vocab_file, out_vocab_file)

        return (out_vocab_file,)

    @property
    def vocab_size(self) -> int:
        return len(self.sp_model)

    def _map_token_to_id(self, token: str) -> int:
        return self.sp_model.PieceToId(token)

    def _map_id_to_token(self, index: int) -> str:
        token = self.sp_model.IdToPiece(index)
        return token

    def map_token_to_text(self, tokens: List[str]) -> str:
        r"""Maps a sequence of tokens (string) in a single string."""
        # Word-start markers become spaces; the leading one is stripped.
        out_string = ''.join(tokens).replace(SPIECE_UNDERLINE, ' ').strip()
        return out_string

    def encode_text(self,
                    text_a: str,
                    text_b: Optional[str] = None,
                    max_seq_length: Optional[int] = None) -> \
            Tuple[List[int], List[int], List[int]]:
        r"""Adds special tokens to a sequence or sequence pair and computes
        the corresponding segment ids and input mask for XLNet specific
        tasks. The sequence will be truncated if its length is larger than
        ``max_seq_length``. Shorter sequences are padded on the left.

        A XLNet sequence has the following format:
        X `[sep_token]` `[cls_token]`

        A XLNet sequence pair has the following format:
        A `[sep_token]` B `[sep_token]` `[cls_token]`

        Args:
            text_a: The first input text.
            text_b: The second input text.
            max_seq_length: Maximum sequence length. If `None`,
                :attr:`max_len` is used.

        Returns:
            A tuple of `(input_ids, segment_ids, input_mask)`, where

            - ``input_ids``: A list of input token ids with added
              special token ids, left-padded with id 0.
            - ``segment_ids``: A list of segment ids, left-padded with
              ``SEG_ID_PAD``.
            - ``input_mask``: A list of mask ids. The mask has 0 for real
              tokens and 1 for padding tokens.
        """
        if max_seq_length is None:
            max_seq_length = self.max_len

        cls_token_id = self._map_token_to_id(self.cls_token)
        sep_token_id = self._map_token_to_id(self.sep_token)

        token_ids_a = self.map_text_to_id(text_a)
        assert isinstance(token_ids_a, list)

        token_ids_b = None
        if text_b:
            token_ids_b = self.map_text_to_id(text_b)

        if token_ids_b:
            assert isinstance(token_ids_b, list)
            # Modifies `token_ids_a` and `token_ids_b` in place so that the
            # total length is less than the specified length.
            # Account for [CLS], [SEP], [SEP] with "- 3"
            truncate_seq_pair(token_ids_a, token_ids_b, max_seq_length - 3)

            input_ids = (token_ids_a + [sep_token_id] + token_ids_b +
                         [sep_token_id] + [cls_token_id])
            segment_ids = [SEG_ID_A] * (len(token_ids_a) + 1) + \
                [SEG_ID_B] * (len(token_ids_b) + 1) + [SEG_ID_CLS]
        else:
            # Account for [CLS] and [SEP] with "- 2"
            token_ids = token_ids_a[:max_seq_length - 2]

            input_ids = token_ids + [sep_token_id] + [cls_token_id]
            segment_ids = [SEG_ID_A] * (len(input_ids) - 1) + [SEG_ID_CLS]

        # 0 marks real tokens; padding positions get 1 below.
        input_mask = [0] * len(input_ids)

        # Zero-pad up to the maximum sequence length (padding is prepended).
        input_ids = [0] * (max_seq_length - len(input_ids)) + input_ids
        input_mask = [1] * (max_seq_length - len(input_mask)) + input_mask
        segment_ids = ([SEG_ID_PAD] * (max_seq_length - len(segment_ids)) +
                       segment_ids)

        assert len(input_ids) == max_seq_length
        assert len(input_mask) == max_seq_length
        assert len(segment_ids) == max_seq_length

        return input_ids, segment_ids, input_mask

    def encode_text_for_generation(
            self,
            text: str,
            max_seq_length: Optional[int] = None,
            append_eos_token: bool = True) -> Tuple[List[int], int]:
        r"""Adds special tokens to a sequence and computes the corresponding
        sequence length for XLNet specific tasks. The sequence will be
        truncated if its length is larger than ``max_seq_length``.

        A XLNet sequence has the following format:
        `[bos_token]` X `[eos_token]` `[pad_token]`

        Args:
            text: Input text.
            max_seq_length: Maximum sequence length. If `None`,
                :attr:`max_len` is used.
            append_eos_token: Whether to append ``eos_token`` after the
                sequence.

        Returns:
            A tuple of `(input_ids, seq_len)`, where

            - ``input_ids``: A list of input token ids with added
              special tokens, right-padded with the id of ``pad_token``.
            - ``seq_len``: The sequence length (before padding).
        """
        if max_seq_length is None:
            max_seq_length = self.max_len

        token_ids = self.map_text_to_id(text)
        assert isinstance(token_ids, list)

        bos_token_id = self._map_token_to_id(self.bos_token)
        eos_token_id = self._map_token_to_id(self.eos_token)
        pad_token_id = self._map_token_to_id(self.pad_token)

        if append_eos_token:
            input_ids = token_ids[:max_seq_length - 2]
            input_ids = [bos_token_id] + input_ids + [eos_token_id]
        else:
            input_ids = token_ids[:max_seq_length - 1]
            input_ids = [bos_token_id] + input_ids

        seq_len = len(input_ids)

        # Pad up to the maximum sequence length.
        input_ids = input_ids + [pad_token_id] * (max_seq_length - seq_len)

        assert len(input_ids) == max_seq_length

        return input_ids, seq_len

    @staticmethod
    def default_hparams() -> Dict[str, Any]:
        r"""Returns a dictionary of hyperparameters with default values.

        * The tokenizer is determined by the constructor argument
          :attr:`pretrained_model_name` if it's specified. In this case,
          `hparams` are ignored.
        * Otherwise, the tokenizer is determined by
          `hparams['pretrained_model_name']` if it's specified. All other
          configurations in `hparams` are ignored.
        * If the above two are `None`, the tokenizer is defined by the
          configurations in `hparams`.

        .. code-block:: python

            {
                "pretrained_model_name": "xlnet-base-cased",
                "vocab_file": None,
                "max_len": None,
                "bos_token": "<s>",
                "eos_token": "</s>",
                "unk_token": "<unk>",
                "sep_token": "<sep>",
                "pad_token": "<pad>",
                "cls_token": "<cls>",
                "mask_token": "<mask>",
                "additional_special_tokens": ["<eop>", "<eod>"],
                "do_lower_case": False,
                "remove_space": True,
                "keep_accents": False,
                "name": "xlnet_tokenizer",
            }

        Here:

        `"pretrained_model_name"`: str or None
            The name of the pre-trained XLNet model.

        `"vocab_file"`: str or None
            The path to a sentencepiece vocabulary file.

        `"max_len"`: int or None
            The maximum sequence length that this model might ever be used
            with.

        `"bos_token"`: str
            Beginning of sentence token.

        `"eos_token"`: str
            End of sentence token.

        `"unk_token"`: str
            Unknown token.

        `"sep_token"`: str
            Separation token.

        `"pad_token"`: str
            Padding token.

        `"cls_token"`: str
            Classification token.

        `"mask_token"`: str
            Masking token.

        `"additional_special_tokens"`: list
            A list of additional special tokens.

        `"do_lower_case"`: bool
            Whether to lower-case the text.

        `"remove_space"`: bool
            Whether to strip leading/trailing whitespace and collapse runs
            of whitespace in the text into single spaces.

        `"keep_accents"`: bool
            Whether to keep the accents in the text.

        `"name"`: str
            Name of the tokenizer.
        """
        return {
            'pretrained_model_name': 'xlnet-base-cased',
            'vocab_file': None,
            'max_len': None,
            'bos_token': '<s>',
            'eos_token': '</s>',
            'unk_token': '<unk>',
            'sep_token': '<sep>',
            'pad_token': '<pad>',
            'cls_token': '<cls>',
            'mask_token': '<mask>',
            'additional_special_tokens': ['<eop>', '<eod>'],
            'do_lower_case': False,
            'remove_space': True,
            'keep_accents': False,
            'name': 'xlnet_tokenizer',
            '@no_typecheck': ['pretrained_model_name'],
        }

    @classmethod
    def _transform_config(cls, pretrained_model_name: str,
                          cache_dir: str):
        r"""Returns the configuration of the pre-trained XLNet tokenizer."""
        return {
            'vocab_file': None,
            'max_len': None,
            'bos_token': '<s>',
            'eos_token': '</s>',
            'unk_token': '<unk>',
            'sep_token': '<sep>',
            'pad_token': '<pad>',
            'cls_token': '<cls>',
            'mask_token': '<mask>',
            'additional_special_tokens': ['<eop>', '<eod>'],
            'do_lower_case': False,
            'remove_space': True,
            'keep_accents': False,
        }