Source code for texar.torch.modules.pretrained.roberta

# Copyright 2019 The Texar Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Utils of RoBERTa Modules.
"""

import os
from abc import ABC
from typing import Any, Dict

import torch

from texar.torch.modules.pretrained.pretrained_base import PretrainedMixin

__all__ = [
    "PretrainedRoBERTaMixin",
]

_ROBERTA_PATH = "https://dl.fbaipublicfiles.com/fairseq/models/"


class PretrainedRoBERTaMixin(PretrainedMixin, ABC):
    r"""A mixin class to support loading pre-trained checkpoints for modules
    that implement the RoBERTa model.

    The RoBERTa model was proposed in (Liu et al., 2019)
    `RoBERTa: A Robustly Optimized BERT Pretraining Approach`_.
    As a variant of the standard BERT model, RoBERTa trains for more
    iterations on more data with a larger batch size, along with other
    pre-training tweaks. Unlike standard BERT, the RoBERTa model does not
    use segment embeddings. Available model names include:

      * ``roberta-base``: RoBERTa using the BERT-base architecture,
        125M parameters.
      * ``roberta-large``: RoBERTa using the BERT-large architecture,
        355M parameters.

    We provide the following RoBERTa classes:

      * :class:`~texar.torch.modules.RoBERTaEncoder` for text encoding.
      * :class:`~texar.torch.modules.RoBERTaClassifier` for text
        classification and sequence tagging.

    .. _`RoBERTa: A Robustly Optimized BERT Pretraining Approach`:
        https://arxiv.org/abs/1907.11692
    """

    _MODEL_NAME = "RoBERTa"
    _MODEL2URL = {
        'roberta-base': _ROBERTA_PATH + "roberta.base.tar.gz",
        'roberta-large': _ROBERTA_PATH + "roberta.large.tar.gz",
    }

    @classmethod
    def _transform_config(cls, pretrained_model_name: str,
                          cache_dir: str) -> Dict[str, Any]:
        info = list(os.walk(cache_dir))
        root, _, files = info[0]
        config_path = None
        for file in files:
            if file.endswith('model.pt'):
                config_path = os.path.join(root, file)
        # Check before loading, so a missing checkpoint raises a clear error.
        if config_path is None:
            raise ValueError(f"Cannot find the config file in {cache_dir}")

        args = torch.load(config_path, map_location="cpu")['args']
        hidden_dim = args.encoder_embed_dim
        vocab_size = 50265
        position_size = args.max_positions + 2
        embedding_dropout = args.dropout
        num_blocks = args.encoder_layers
        num_heads = args.encoder_attention_heads
        dropout_rate = args.attention_dropout
        residual_dropout = args.dropout
        intermediate_size = args.encoder_ffn_embed_dim
        hidden_act = args.activation_fn

        configs = {
            'hidden_size': hidden_dim,
            'embed': {
                'name': 'word_embeddings',
                'dim': hidden_dim
            },
            'vocab_size': vocab_size,
            'position_embed': {
                'name': 'position_embeddings',
                'dim': hidden_dim
            },
            'position_size': position_size,
            'encoder': {
                'name': 'encoder',
                'embedding_dropout': embedding_dropout,
                'num_blocks': num_blocks,
                'multihead_attention': {
                    'use_bias': True,
                    'num_units': hidden_dim,
                    'num_heads': num_heads,
                    'output_dim': hidden_dim,
                    'dropout_rate': dropout_rate,
                    'name': 'self'
                },
                'residual_dropout': residual_dropout,
                'dim': hidden_dim,
                'eps': 1e-12,
                'use_bert_config': True,
                'poswise_feedforward': {
                    "layers": [{
                        'type': 'Linear',
                        'kwargs': {
                            'in_features': hidden_dim,
                            'out_features': intermediate_size,
                            'bias': True,
                        }
                    }, {
                        'type': 'Bert' + hidden_act.upper()
                    }, {
                        'type': 'Linear',
                        'kwargs': {
                            'in_features': intermediate_size,
                            'out_features': hidden_dim,
                            'bias': True,
                        }
                    }],
                },
            }
        }
        return configs

    def _init_from_checkpoint(self, pretrained_model_name: str,
                              cache_dir: str, **kwargs):
        global_tensor_map = {
            'decoder.sentence_encoder.embed_tokens.weight':
                'word_embedder._embedding',
            'decoder.sentence_encoder.embed_positions.weight':
                'position_embedder._embedding',
            'decoder.sentence_encoder.emb_layer_norm.weight':
                'encoder.input_normalizer.weight',
            'decoder.sentence_encoder.emb_layer_norm.bias':
                'encoder.input_normalizer.bias',
        }

        attention_tensor_map = {
            'final_layer_norm.weight':
                'encoder.output_layer_norm.{}.weight',
            'final_layer_norm.bias':
                'encoder.output_layer_norm.{}.bias',
            'fc1.weight':
                'encoder.poswise_networks.{}._layers.0.weight',
            'fc1.bias':
                'encoder.poswise_networks.{}._layers.0.bias',
            'fc2.weight':
                'encoder.poswise_networks.{}._layers.2.weight',
            'fc2.bias':
                'encoder.poswise_networks.{}._layers.2.bias',
            'self_attn_layer_norm.weight':
                'encoder.poswise_layer_norm.{}.weight',
            'self_attn_layer_norm.bias':
                'encoder.poswise_layer_norm.{}.bias',
            'self_attn.out_proj.weight':
                'encoder.self_attns.{}.O_dense.weight',
            'self_attn.out_proj.bias':
                'encoder.self_attns.{}.O_dense.bias',
            'self_attn.in_proj_weight': [
                'encoder.self_attns.{}.Q_dense.weight',
                'encoder.self_attns.{}.K_dense.weight',
                'encoder.self_attns.{}.V_dense.weight',
            ],
            'self_attn.in_proj_bias': [
                'encoder.self_attns.{}.Q_dense.bias',
                'encoder.self_attns.{}.K_dense.bias',
                'encoder.self_attns.{}.V_dense.bias',
            ],
        }

        checkpoint_path = os.path.abspath(os.path.join(cache_dir, 'model.pt'))
        device = next(self.parameters()).device
        params = torch.load(checkpoint_path, map_location=device)['model']

        for name, tensor in params.items():
            if name in global_tensor_map:
                v_name = global_tensor_map[name]
                pointer = self._name_to_variable(v_name)
                assert pointer.shape == tensor.shape
                pointer.data = tensor.data.type(pointer.dtype)
            elif name.startswith('decoder.sentence_encoder.layers.'):
                # Remove the prefix by slicing (``str.lstrip`` strips a set of
                # characters, not a prefix string), then split off the layer
                # index from the remaining parameter name.
                name = name[len('decoder.sentence_encoder.layers.'):]
                layer_num, layer_name = name.split('.', 1)
                if layer_name in attention_tensor_map:
                    v_names = attention_tensor_map[layer_name]
                    if isinstance(v_names, str):
                        pointer = self._name_to_variable(
                            v_names.format(layer_num))
                        assert pointer.shape == tensor.shape
                        pointer.data = tensor.data.type(pointer.dtype)
                    else:
                        # Q, K, V in self-attention: the fused projection is
                        # split into three equal chunks along dim 0.
                        tensors = torch.chunk(tensor, chunks=3, dim=0)
                        for i in range(3):
                            pointer = self._name_to_variable(
                                v_names[i].format(layer_num))
                            assert pointer.shape == tensors[i].shape
                            pointer.data = tensors[i].data.type(pointer.dtype)
                else:
                    raise NameError(f"Layer name '{layer_name}' not found")
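

# ---------------------------------------------------------------------------
# Minimal, self-contained sketch (not part of the library API): it illustrates
# the ``torch.chunk`` step in ``_init_from_checkpoint`` above, i.e. how the
# fused Q/K/V projection stored by fairseq (``self_attn.in_proj_weight`` of
# shape ``[3 * hidden_dim, hidden_dim]``) maps onto three separate dense
# layers. The tensor sizes and variable names below are illustrative
# assumptions, not values read from a real checkpoint.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    hidden_dim = 8  # hypothetical size; actual checkpoints use 768 or 1024

    # Fused projection as stored in a fairseq checkpoint: Q, K, V weights
    # stacked along dim 0, with a matching fused bias.
    in_proj_weight = torch.randn(3 * hidden_dim, hidden_dim)
    in_proj_bias = torch.randn(3 * hidden_dim)

    # Splitting along dim 0 recovers the per-projection parameters in
    # Q, K, V order, each of shape [hidden_dim, hidden_dim] / [hidden_dim].
    q_w, k_w, v_w = torch.chunk(in_proj_weight, chunks=3, dim=0)
    q_b, k_b, v_b = torch.chunk(in_proj_bias, chunks=3, dim=0)

    # Sanity check: applying the three chunks separately and concatenating
    # the outputs matches the fused projection.
    x = torch.randn(2, hidden_dim)
    fused = torch.nn.functional.linear(x, in_proj_weight, in_proj_bias)
    split = torch.cat([
        torch.nn.functional.linear(x, q_w, q_b),
        torch.nn.functional.linear(x, k_w, k_b),
        torch.nn.functional.linear(x, v_w, v_b),
    ], dim=-1)
    assert torch.allclose(fused, split, atol=1e-6)

    # Hypothetical end-to-end usage (commented out because it downloads a
    # checkpoint): the modules listed in the class docstring accept the model
    # names above, e.g.
    #     import texar.torch as tx
    #     encoder = tx.modules.RoBERTaEncoder(
    #         pretrained_model_name="roberta-base")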