# Copyright 2019 The Texar Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Various optimization-related utilities.
"""
import functools
from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union
from mypy_extensions import TypedDict
import torch
import torch.nn as nn
from torch.nn.utils import clip_grad_norm_
from torch.optim.lr_scheduler import _LRScheduler
from torch.optim.optimizer import Optimizer
from texar.torch.hyperparams import HParams
from texar.torch.utils import utils
from texar.torch.utils.types import MaybeList
__all__ = [
"default_optimization_hparams",
"get_optimizer",
"get_scheduler",
"get_grad_clip_fn",
"get_train_op",
"BertAdam"
]


def default_optimization_hparams() -> Dict[str, Any]:
r"""Returns a `dict` of default hyperparameters of training op
and their default values
.. code-block:: python
{
"optimizer": {
"type": "Adam",
"kwargs": {
"lr": 0.001
}
},
"learning_rate_decay": {
"type": "",
"kwargs": {}
},
"gradient_clip": {
"type": "",
"kwargs": {}
},
"gradient_noise_scale": None,
"name": None
}

    Here:

    `"optimizer"`: dict
        Hyperparameters of a
        :torch_docs:`torch.optim.Optimizer <optim.html#torch.optim.Optimizer>`.

        - `"type"` specifies the optimizer class. This can be

          - The string name or full module path of an optimizer class.
            If the class name is provided, the class must be in module
            :torch_docs:`torch.optim <optim.html>`, :mod:`texar.torch.custom`,
            or :mod:`texar.torch.core.optimization`.
          - An optimizer class.
          - An instance of an optimizer class.

          For example

          .. code-block:: python

              "type": "Adam"                     # class name
              "type": "my_module.MyOptimizer"    # module path
              "type": texar.torch.custom.BertAdam    # class
              "type": my_module.MyOptimizer      # class

        - `"kwargs"` is a `dict` specifying keyword arguments for creating
          the optimizer class instance, with :python:`opt_class(**kwargs)`.
          Ignored if `"type"` is a class instance.
`"learning_rate_decay"`: dict
Hyperparameters of learning rate decay function. The learning rate
starts decay from :attr:`"start_decay_step"` and keeps unchanged after
:attr:`"end_decay_step"` or reaching :attr:`"min_learning_rate"`.
The decay function is specified in `"type"` and `"kwargs"`.
- `"type"` can be a decay function or its name or module path. If
function name is provided, it must be from module
:torch_docs:`torch.optim <optim.html>` or :mod:`texar.torch.custom`,
:mod:`texar.torch.core.optimization`.
- `"kwargs"` is a `dict` of keyword arguments for the function
excluding arguments named `"global_step"` and `"learning_rate"`.
The function is called with
:python:`lr = decay_fn(learning_rate=lr, global_step=offset_step,
**kwargs)`, where `offset_step` is the global step offset as above.
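
        For example, the following specifies exponential decay using
        :class:`~torch.optim.lr_scheduler.ExponentialLR` (a sketch; any
        scheduler class can be substituted):

        .. code-block:: python

            "learning_rate_decay": {
                "type": "ExponentialLR",
                "kwargs": {
                    "gamma": 0.99
                }
            }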
`"gradient_clip"`: dict
Hyperparameters of gradient clipping. The gradient clipping function
takes a list of `(gradients, variables)` tuples and returns a list
of `(clipped_gradients, variables)` tuples. Typical examples include
:torch_nn:`utils.clip_grad_norm_` and
:torch_nn:`utils.clip_grad_value_`.
"type" specifies the gradient clip function, and can be a function,
or its name or module path. If function name is provided, the
function must be from module :mod:`torch.nn.utils`,
:mod:`texar.torch.custom`, or :mod:`texar.torch.core.optimization`.
`"kwargs"` specifies keyword arguments to the function, except arguments
named `"parameters"`.
`"gradient_noise_scale"`: float, optional
Adds 0-mean normal noise scaled by this value to gradient.
"""
return {
"optimizer": {
"type": "Adam",
"kwargs": {
"lr": 0.001
}
},
"learning_rate_decay": {
"type": "",
"kwargs": {}
},
"gradient_clip": {
"type": "",
"kwargs": {}
},
"gradient_noise_scale": None,
# TODO(zhiting): allow module-level control of gradient_multipliers
"name": None
}


def get_optimizer(
params: Iterable[Union[torch.Tensor, Dict[str, Any]]],
hparams: Optional[Union[HParams, Dict[str, Any]]] = None) -> \
Optimizer:
r"""Creates a optimizer instance.
Args:
params: an iterable of :class:`torch.Tensor` or
:class:`dict`. Specifies what Tensors should be optimized.
hparams (dict or HParams, optional): hyperparameters. Missing
hyperparameters are set to default values automatically. See
:func:`~texar.torch.core.default_optimization_hparams` for
all hyperparameters and default values.

    Returns:
The :torch_docs:`torch.optim.Optimizer
<optim.html#torch.optim.Optimizer>` instance specified in
:attr:`hparams`.
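
    Example:

        A minimal sketch, assuming a hypothetical model ``my_model``:

        .. code-block:: python

            optimizer = get_optimizer(
                params=my_model.parameters(),
                hparams={"optimizer": {"type": "Adam",
                                       "kwargs": {"lr": 1e-3}}})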
"""
if hparams is None or isinstance(hparams, dict):
hparams = HParams(hparams, default_optimization_hparams())
hparams_opt = hparams["optimizer"]
optimizer_type = hparams_opt["type"]
    if isinstance(optimizer_type, Optimizer):
        # If "type" is already an optimizer instance, use it as-is;
        # "kwargs" are ignored in this case (see the hparams docs above).
        return optimizer_type
else:
optimizer_modules = ['torch.optim',
'texar.torch.custom']
try:
optimizer_class = utils.check_or_get_class( # type: ignore
optimizer_type, optimizer_modules, Optimizer)
except TypeError:
raise ValueError(
"Unrecognized optimizer. Must be string name of the "
"optimizer class, or the class which is a subclass of "
"torch.optim.Optimizer, or an instance of the subclass of "
"Optimizer.")
optimizer_kwargs = hparams_opt["kwargs"].todict()
optimizer_kwargs.update({"params": params})
optimizer = optimizer_class(**optimizer_kwargs) # type: ignore
return optimizer


def get_scheduler(optimizer: Optimizer,
hparams: Optional[Union[HParams, Dict[str, Any]]] = None) -> \
Optional[_LRScheduler]:
r"""Creates a scheduler instance.
Args:
optimizer: A :torch_docs:`torch.optim.Optimizer
<optim.html#torch.optim.Optimizer>` instance.
hparams (dict or HParams, optional): hyperparameters. Missing
hyperparameters are set to default values automatically. See
:func:`~texar.torch.core.default_optimization_hparams` for
all hyperparameters and default values.

    Returns:
A :torch_docs:`torch.optim.lr_scheduler._LRScheduler
<optim.html#how-to-adjust-learning-rate>` instance.
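
    Example:

        A minimal sketch, assuming :attr:`optimizer` was obtained from
        :func:`get_optimizer`:

        .. code-block:: python

            scheduler = get_scheduler(
                optimizer,
                hparams={"learning_rate_decay": {"type": "ExponentialLR",
                                                 "kwargs": {"gamma": 0.99}}})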
"""
if hparams is None or isinstance(hparams, dict):
hparams = HParams(hparams, default_optimization_hparams())
hparams_scheduler = hparams["learning_rate_decay"]
scheduler_type = hparams_scheduler["type"]
if scheduler_type == "" or scheduler_type is None:
scheduler = None
else:
        if isinstance(scheduler_type, _LRScheduler):
            # If "type" is already a scheduler instance, use it as-is;
            # "kwargs" are ignored in this case.
            return scheduler_type
else:
scheduler_modules = ['torch.optim.lr_scheduler',
'texar.torch.custom']
try:
scheduler_class = utils.check_or_get_class( # type: ignore
scheduler_type, scheduler_modules, _LRScheduler)
except TypeError:
raise ValueError(
"Unrecognized lr_scheduler. Must be string name of the "
"lr_scheduler class, or the class which is a subclass of "
"torch.optim._LRScheduler.")
scheduler_kwargs = hparams_scheduler["kwargs"].todict()
scheduler_kwargs.update({"optimizer": optimizer})
scheduler = scheduler_class(**scheduler_kwargs) # type: ignore
return scheduler


def get_grad_clip_fn(hparams: Optional[Union[HParams,
Dict[str, Any]]] = None) -> \
Optional[Callable[[torch.Tensor], Optional[torch.Tensor]]]:
r"""Create a gradient clipping function.
Args:
hparams (dict or HParams, optional): hyperparameters. Missing
hyperparameters are set to default values automatically. See
:func:`~texar.torch.core.default_optimization_hparams` for
all hyperparameters and default values.
Returns:
A gradient clipping function.
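
    Example:

        A minimal sketch; the returned function is called with the
        ``parameters`` keyword argument (here on a hypothetical model
        ``my_model``):

        .. code-block:: python

            grad_clip_fn = get_grad_clip_fn(
                hparams={"gradient_clip": {"type": "clip_grad_norm_",
                                           "kwargs": {"max_norm": 5.0}}})
            if grad_clip_fn is not None:
                grad_clip_fn(parameters=my_model.parameters())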
"""
if hparams is None or isinstance(hparams, dict):
hparams = HParams(hparams, default_optimization_hparams())
hparams_grad_clip = hparams["gradient_clip"]
grad_clip_type = hparams_grad_clip["type"]
if grad_clip_type == "" or grad_clip_type is None:
grad_clip_fn = None
else:
grad_clip_modules = ['torch.nn.utils',
'texar.torch.custom']
grad_clip_fn = utils.get_function(grad_clip_type, grad_clip_modules)
grad_clip_fn_kwargs = hparams_grad_clip["kwargs"].todict()
grad_clip_fn = functools.partial(grad_clip_fn, **grad_clip_fn_kwargs)
return grad_clip_fn


def get_train_op(params: Optional[Iterable[Union[torch.Tensor,
Dict[str, Any]]]] = None,
optimizer: Optional[Optimizer] = None,
scheduler: Optional[_LRScheduler] = None,
hparams: Optional[Union[HParams, Dict[str, Any]]] = None) -> \
Callable[[], None]:
r"""Creates a training op.
Args:
params: an iterable of :class:`torch.Tensor` or
:class:`dict`. Specifies what Tensors should be optimized.
optimizer: A :torch_docs:`torch.optim.Optimizer
<optim.html#torch.optim.Optimizer>` instance.
scheduler: A :torch_docs:`torch.optim.lr_scheduler._LRScheduler
<optim.html#how-to-adjust-learning-rate>` instance.
hparams (dict or HParams, optional): hyperparameters. Missing
hyperparameters are set to default values automatically. See
:func:`~texar.torch.core.default_optimization_hparams` for
all hyperparameters and default values.
Returns:
The callable used for variable optimization.
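
    Example:

        A minimal training-loop sketch, assuming a hypothetical model
        ``my_model``, data iterator ``data_iterator``, and loss function
        ``compute_loss``:

        .. code-block:: python

            train_op = get_train_op(params=my_model.parameters())
            for batch in data_iterator:
                loss = compute_loss(my_model, batch)
                loss.backward()
                train_op()  # clips gradients, steps the optimizer and
                            # scheduler, and zeroes gradients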
"""
hparams = HParams(hparams, default_optimization_hparams())
if params is None and optimizer is None and scheduler is None:
raise ValueError("'params', 'optimizer' and 'scheduler' must not be "
"None simultaneously.")
if scheduler is None:
if optimizer is None and params is not None:
optimizer = get_optimizer(params, hparams)
if optimizer is not None:
scheduler = get_scheduler(optimizer, hparams)
else:
optimizer = scheduler.optimizer # type: ignore
grad_clip_fn = get_grad_clip_fn(hparams)
# TODO: Support per-parameter options in the future.
params_list: List[nn.Parameter] = []
for param_group in optimizer.param_groups: # type: ignore
params = param_group["params"]
if isinstance(params, torch.Tensor):
params_list.append(params)
elif isinstance(params, list):
params_list += params

    def _train_op():
if grad_clip_fn is not None:
grad_clip_fn(parameters=params_list)
optimizer.step()
# TODO: Ideally, scheduler should be used in the epoch level.
if scheduler is not None:
scheduler.step()
optimizer.zero_grad()
return _train_op


class BertAdamParamDict(TypedDict):
r"""The :attr:`param_groups` dictionary used in PyTorch optimizers."""
params: List[nn.Parameter]
lr: float
betas: Tuple[float, float]
eps: float
weight_decay: float
max_grad_norm: float


class BertAdamStateDict(TypedDict):
r"""The :attr:`state` dictionary used in :class:`BertAdam` optimizer."""
next_m: torch.Tensor
next_v: torch.Tensor


OptimParamType = Union[
MaybeList[Iterable[nn.Parameter]], # model.parameters()
MaybeList[Dict[str, Any]], # {"params": ..., "other_kwargs": ...}
]


class BertAdam(Optimizer):
    r"""Implements BERT version of Adam algorithm with weight decay fix.

    Args:
params (iterable): iterable of parameters to optimize or dicts defining
parameter groups
lr (float, optional): learning rate (default: 1e-3)
betas (Tuple[float, float], optional): coefficients used for computing
running averages of gradient and its square (default: (0.9, 0.999))
eps (float, optional): term added to the denominator to improve
numerical stability (default: 1e-8)
weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
        max_grad_norm (float, optional): maximum norm for the gradients;
            a non-positive value (e.g. -1) means no clipping (default: 1.0)
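
    Example:

        A minimal sketch, assuming a hypothetical model ``my_model``, with
        weight decay disabled for bias parameters:

        .. code-block:: python

            decay, no_decay = [], []
            for name, param in my_model.named_parameters():
                (no_decay if "bias" in name else decay).append(param)
            optimizer = BertAdam(
                [{"params": decay, "weight_decay": 0.01},
                 {"params": no_decay, "weight_decay": 0.0}],
                lr=2e-5, max_grad_norm=1.0)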
"""
param_groups: List[BertAdamParamDict]
state: Dict[nn.Parameter, BertAdamStateDict]

    def __init__(self, params: OptimParamType,
lr: float = 0.001, betas: Tuple[float, float] = (0.9, 0.999),
eps: float = 1e-08, weight_decay: float = 0,
max_grad_norm: float = 1.0):
if lr < 0.0:
raise ValueError(f"Invalid learning rate: {lr}")
if eps < 0.0:
raise ValueError(f"Invalid epsilon value: {eps}")
if not 0.0 <= betas[0] < 1.0:
raise ValueError(f"Invalid beta parameter at index 0: {betas[0]}")
if not 0.0 <= betas[1] < 1.0:
raise ValueError(f"Invalid beta parameter at index 1: {betas[1]}")
defaults = dict(lr=lr, betas=betas, eps=eps,
weight_decay=weight_decay, max_grad_norm=max_grad_norm)
super().__init__(params, defaults) # type: ignore

    def step(self, closure: Optional[Callable[[], float]] = None):
        r"""Performs a single optimization step.

        Arguments:
closure (callable, optional): A closure that reevaluates the model
and returns the loss.
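
        Example:

            A sketch of using a closure, assuming ``optimizer`` is a
            :class:`BertAdam` instance and ``compute_loss`` is a hypothetical
            function returning the loss for the current batch:

            .. code-block:: python

                def closure():
                    optimizer.zero_grad()
                    loss = compute_loss()
                    loss.backward()
                    return loss

                optimizer.step(closure)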
"""
loss = None
if closure is not None:
loss = closure()
for group in self.param_groups:
for p in group['params']:
if p.grad is None:
continue
grad = p.grad.data
if grad.is_sparse:
raise RuntimeError(
"Adam does not support sparse gradients, please "
"consider SparseAdam instead")
state = self.state[p]
# State initialization
if len(state) == 0:
# Exponential moving average of gradient values
state['next_m'] = torch.zeros_like(p.data)
# Exponential moving average of squared gradient values
state['next_v'] = torch.zeros_like(p.data)
next_m, next_v = state['next_m'], state['next_v']
beta1, beta2 = group['betas']
# Add grad clipping
if group['max_grad_norm'] > 0:
clip_grad_norm_(p, group['max_grad_norm'])
# Decay the first and second moment running average coefficient
# In-place operations to update the averages at the same time
                next_m.mul_(beta1).add_(grad, alpha=1 - beta1)
                next_v.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
update = next_m / (next_v.sqrt() + group['eps'])
                # Just adding the square of the weights to the loss function
                # is *not* the correct way of using L2 regularization or
                # weight decay with Adam, since that will interact with the
                # m and v parameters in strange ways.
#
# Instead we want to decay the weights in a manner that doesn't
# interact with the m/v parameters. This is equivalent to adding
# the square of the weights to the loss with plain
# (non-momentum) SGD.
if group['weight_decay'] > 0.0:
update += group['weight_decay'] * p.data
lr = group['lr']
update_with_lr = lr * update
p.data.add_(-update_with_lr)
# No bias correction
# bias_correction1 = 1 - beta1 ** state['step']
# bias_correction2 = 1 - beta2 ** state['step']
return loss