# Copyright 2019 The Texar Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Hyperparameter manager
"""
import copy
import json
from typing import (
Any, Dict, ItemsView, Iterator, KeysView, Optional, Tuple, Union)
# Public names exported by ``from <this module> import *``.
__all__ = ['HParams']
def _type_name(value):
return type(value).__name__
class HParams:
    r"""A class that maintains hyperparameters for configuring Texar modules.

    The class has several useful features:

    - **Auto-completion of missing values.** Users can specify only a subset
      of hyperparameters they care about. Other hyperparameters will
      automatically take the default values. The auto-completion performs
      **recursively** so that hyperparameters taking `dict` values will also
      be auto-completed. Setting a hyperparameter that has a default to
      `None` is equivalent to omitting it: the default value is used.
      **All Texar modules** provide a :meth:`default_hparams` containing
      allowed hyperparameters and their default values. For example:

      .. code-block:: python

          ## Recursive auto-completion
          default_hparams = {"a": 1, "b": {"c": 2, "d": 3}}
          hparams = {"b": {"c": 22}}
          hparams_ = HParams(hparams, default_hparams)
          hparams_.todict() == {"a": 1, "b": {"c": 22, "d": 3}}
              # "a" and "d" are auto-completed

          ## All Texar modules have built-in `default_hparams`
          hparams = {"dropout_rate": 0.1}
          emb = tx.modules.WordEmbedder(hparams=hparams, ...)
          emb.hparams.todict() == {
              "dropout_rate": 0.1,  # provided value
              "dim": 100            # default value
              ...
          }

    - **Automatic type-check.** For most hyperparameters, the provided value
      must have the same or a compatible dtype as the default value.
      :class:`HParams` does the necessary type-check, and raises an error if
      an improper dtype is provided. Also, hyperparameters not listed in
      `default_hparams` are not allowed, except for `"kwargs"` as detailed
      below.

    - **Flexible dtype for specified hyperparameters.** Some hyperparameters
      may allow different dtypes of values.

      - Hyperparameters named `"type"` are not type-checked.
        For example, in :func:`~texar.torch.core.get_rnn_cell`,
        hyperparameter `"type"` can take the value of an RNNCell class, its
        string name or module path, or an RNNCell class instance. (String
        name or module path is allowed so that users can specify the value
        in YAML configuration files.)

      - For other hyperparameters, list them in the `"@no_typecheck"` field
        in :meth:`default_hparams` to skip type-check. For example, in
        :class:`~texar.torch.modules.Conv1DNetwork`, hyperparameter
        `"kernel_size"` can be set to either a `list` of `int`\ s or simply
        an `int`.

    - **Special flexibility of keyword argument hyperparameters.**
      Hyperparameters named ``"kwargs"`` are used as keyword arguments for a
      class constructor or a function call. Such hyperparameters take a
      `dict`, and users can add arbitrary valid keyword arguments to the
      dict. For example:

      .. code-block:: python

          default_rnn_cell_hparams = {
              "type": "LSTMCell",
              "kwargs": {"num_units": 256}
              # Other hyperparameters
              ...
          }

          my_hparams = {
              "kwargs": {
                  "num_units": 123,
                  # Other valid keyword arguments for LSTMCell constructor
                  "forget_bias": 0.0,
                  "activation": "torch.nn.functional.relu"
              }
          }

          _ = HParams(my_hparams, default_rnn_cell_hparams)

    - **Rich interfaces.** An :class:`HParams` instance provides rich
      interfaces for accessing, updating, or adding hyperparameters.

      .. code-block:: python

          hparams = HParams(my_hparams, default_hparams)
          # Access
          hparams.type == hparams["type"]
          # Update
          hparams.type = "GRUCell"
          hparams.kwargs = { "num_units": 100 }
          hparams.kwargs.num_units == 100
          # Add new
          hparams.add_hparam("index", 1)
          hparams.index == 1

          # Convert to `dict` (recursively)
          type(hparams.todict()) == dict

          # I/O
          with open("hparams.dump", 'wb') as f:
              pickle.dump(hparams, f)
          with open("hparams.dump", 'rb') as f:
              hparams_loaded = pickle.load(f)

    Args:
        hparams: A `dict` or an :class:`HParams` instance containing
            hyperparameters. If `None`, all hyperparameters are set to default
            values.
        default_hparams (dict): Hyperparameters with default values. If
            `None`, hyperparameters are fully defined by :attr:`hparams`.
        allow_new_hparam (bool): If `False` (default), :attr:`hparams` cannot
            contain hyperparameters that are not included in
            :attr:`default_hparams`, except for the case of :attr:`"kwargs"`
            as above.
    """

    # - The default hyperparameters in :attr:`"kwargs"` are used (for
    #   type-check and complementing missing hyperparameters) only when
    #   :attr:`"type"` takes the default value (i.e., missing in
    #   :attr:`hparams`, set to `None`, or set to the same value as the
    #   default). In this case :attr:`kwargs` is allowed to contain new keys
    #   not included in :attr:`default_hparams["kwargs"]`.
    #
    # - If :attr:`"type"` is set to another value and :attr:`"kwargs"` is
    #   missing in :attr:`hparams`, :attr:`"kwargs"` is set to an empty
    #   dictionary.

    def __init__(self, hparams: Optional[Union['HParams', Dict[str, Any]]],
                 default_hparams: Optional[Dict[str, Any]],
                 allow_new_hparam: bool = False):
        if isinstance(hparams, HParams):
            hparams = hparams.todict()
        if default_hparams is not None:
            parsed_hparams = self._parse(
                hparams, default_hparams, allow_new_hparam)
        else:
            # Without defaults, `hparams` acts as its own default structure.
            parsed_hparams = self._parse(hparams, hparams)
        if parsed_hparams is None:
            # Both arguments were `None`: behave as an empty container rather
            # than storing `None`, which would break every accessor.
            parsed_hparams = {}
        # Bypass our own `__setattr__`, which only accepts known names.
        super().__setattr__('_hparams', parsed_hparams)

    @staticmethod
    def _parse(hparams: Optional[Dict[str, Any]],
               default_hparams: Optional[Dict[str, Any]],
               allow_new_hparam: bool = False) -> Optional[Dict[str, Any]]:
        r"""Parses hyperparameters.

        Args:
            hparams (dict): Hyperparameters. If `None`, all hyperparameters
                are set to default values. A `None` value for a
                hyperparameter that appears in :attr:`default_hparams` is
                treated as if that hyperparameter were omitted.
            default_hparams (dict): Hyperparameters with default values.
                If `None`, hyperparameters are fully defined by
                :attr:`hparams`.
            allow_new_hparam (bool): If `False` (default), :attr:`hparams`
                cannot contain hyperparameters that are not included in
                :attr:`default_hparams`, except the case of :attr:`"kwargs"`.

        Return:
            A dictionary of parsed hyperparameters. Returns `None` if both
            :attr:`hparams` and :attr:`default_hparams` are `None`.

        Raises:
            ValueError: If :attr:`hparams` is not `None` and
                :attr:`default_hparams` is `None`.
            ValueError: If :attr:`default_hparams` contains "kwargs" but does
                not contain "type".
            ValueError: If a hyperparameter is unknown (and new ones are not
                allowed) or has a value incompatible with its default.
        """
        if hparams is None and default_hparams is None:
            return None

        if hparams is None:
            return HParams._parse(default_hparams, default_hparams)

        if default_hparams is None:
            raise ValueError("`default_hparams` cannot be `None` if `hparams` "
                             "is not `None`.")

        no_typecheck_names = default_hparams.get('@no_typecheck', [])

        if "kwargs" in default_hparams and "type" not in default_hparams:
            raise ValueError("Ill-defined hyperparameter structure: 'kwargs' "
                             "must accompany with 'type'.")

        # A `None` value for a hyperparameter that has a default means "use
        # the default", so such entries are handled exactly as if omitted.
        provided = {name: value for name, value in hparams.items()
                    if value is not None or name not in default_hparams}

        # The defaults of "kwargs" only apply while "type" keeps its default
        # value ('type' is guaranteed present when 'kwargs' is, see above).
        type_overridden = (
            'kwargs' in default_hparams and 'type' in provided and
            provided['type'] != default_hparams['type'])

        parsed_hparams = copy.deepcopy(default_hparams)

        # Parse recursively for params of type dictionary that are missing
        # in `hparams`.
        for name, value in default_hparams.items():
            if name in provided or not isinstance(value, dict):
                continue
            if name == 'kwargs' and type_overridden:
                # Set params named "kwargs" to empty dictionary if "type"
                # takes value other than default.
                parsed_hparams[name] = HParams({}, {})
            else:
                parsed_hparams[name] = HParams(value, value)

        # Parse the user-provided hyperparameters.
        for name, value in provided.items():
            if name not in default_hparams:
                if allow_new_hparam:
                    parsed_hparams[name] = HParams._parse_value(value, name)
                    continue
                raise ValueError(
                    "Unknown hyperparameter: %s. "
                    "Only when allow_new_hparam is set to True and "
                    "only hyperparameters named 'kwargs' can contain new "
                    "entries undefined in default hyperparameters." % name)

            default_value = default_hparams[name]
            if default_value is None:
                # No default to type-check against: accept as-is.
                parsed_hparams[name] = HParams._parse_value(value)
                continue

            # Parse recursively for params of type dictionary.
            if isinstance(value, dict):
                if (name not in no_typecheck_names and
                        not isinstance(default_value, dict)):
                    raise ValueError(
                        "Hyperparameter '%s' must have type %s, got %s" %
                        (name, _type_name(default_value), _type_name(value)))
                if name == "kwargs":
                    if type_overridden:
                        # Leave "kwargs" as-is if "type" takes value
                        # other than default.
                        parsed_hparams[name] = HParams(value, value)
                    else:
                        # Allow new hyperparameters if "type" takes default
                        # value.
                        parsed_hparams[name] = HParams(
                            value, default_value, allow_new_hparam=True)
                elif name in no_typecheck_names:
                    parsed_hparams[name] = HParams(value, value)
                else:
                    parsed_hparams[name] = HParams(
                        value, default_value, allow_new_hparam)
                continue

            # Do not type-check hyperparameter named "type" and accompanied
            # with "kwargs".
            if name == 'type' and 'kwargs' in default_hparams:
                parsed_hparams[name] = value
                continue

            if name in no_typecheck_names:
                parsed_hparams[name] = value
            elif isinstance(value, type(default_value)):
                parsed_hparams[name] = value
            elif callable(value) and callable(default_value):
                parsed_hparams[name] = value
            else:
                # Attempt a conversion to the default's type (e.g. int->float).
                try:
                    parsed_hparams[name] = type(default_value)(value)
                except TypeError as err:
                    raise ValueError(
                        "Hyperparameter '%s' must have type %s, got %s" %
                        (name, _type_name(default_value), _type_name(value))
                    ) from err

        return parsed_hparams

    @staticmethod
    def _parse_value(value: Any, name: Optional[str] = None) -> Any:
        r"""Wraps a `dict` value in :class:`HParams`, except for "kwargs".

        Non-dict values are returned unchanged.
        """
        if isinstance(value, dict) and name != "kwargs":
            return HParams(value, None)
        return value

    def __getattr__(self, name: str) -> Any:
        r"""Retrieves the value of the hyperparameter.
        """
        if name == '_hparams':
            return super().__getattribute__('_hparams')
        if name not in self._hparams:
            # Raise AttributeError to allow copy.deepcopy, etc
            raise AttributeError("Unknown hyperparameter: %s" % name)
        return self._hparams[name]

    def __getitem__(self, name: str) -> Any:
        r"""Retrieves the value of the hyperparameter.
        """
        return self.__getattr__(name)

    def __setattr__(self, name: str, value: Any):
        r"""Sets the value of an existing hyperparameter.

        Raises:
            ValueError: If :attr:`name` is not an existing hyperparameter.
        """
        if name not in self._hparams:
            raise ValueError(
                "Unknown hyperparameter: %s. Only the `kwargs` "
                "hyperparameters can contain new entries undefined "
                "in default hyperparameters." % name)
        self._hparams[name] = self._parse_value(value, name)

    def items(self) -> ItemsView[str, Any]:
        r"""Returns the list of hyperparameter `(name, value)` pairs.
        """
        return self._hparams.items()

    def keys(self) -> KeysView[str]:
        r"""Returns the list of hyperparameter names.
        """
        return self._hparams.keys()

    def __iter__(self) -> Iterator[Tuple[str, Any]]:
        r"""Iterates over `(name, value)` pairs (not over names alone).
        """
        for name, value in self._hparams.items():
            yield name, value

    def __len__(self) -> int:
        return len(self._hparams)

    def __contains__(self, name) -> bool:
        return name in self._hparams

    def __str__(self) -> str:
        r"""Return a JSON-formatted string of the hyperparameters.

        Values that JSON cannot encode (e.g. classes or callables stored in
        `"type"`) are rendered with :func:`str` instead of raising.
        """
        hparams_dict = self.todict()
        return json.dumps(hparams_dict, sort_keys=True, indent=2, default=str)

    def get(self, name: str, default: Optional[Any] = None) -> Any:
        r"""Returns the hyperparameter value for the given name. If name is not
        available then returns :attr:`default`.

        Args:
            name (str): the name of hyperparameter.
            default: the value to be returned in case name does not exist.
        """
        try:
            return self.__getattr__(name)
        except AttributeError:
            return default

    def add_hparam(self, name: str, value: Any):
        r"""Adds a new hyperparameter.

        Raises:
            ValueError: If :attr:`name` is already a hyperparameter or an
                attribute/method of this object.
        """
        if (name in self._hparams) or hasattr(self, name):
            raise ValueError("Hyperparameter name already exists: %s" % name)
        self._hparams[name] = self._parse_value(value, name)

    def todict(self) -> Dict[str, Any]:
        r"""Returns a copy of hyperparameters as a dictionary.

        Nested :class:`HParams` are converted recursively; other values are
        deep-copied (sharing one memo so aliasing among them is preserved).
        """
        memo: Dict[int, Any] = {}
        dict_: Dict[str, Any] = {}
        for name, value in self._hparams.items():
            if isinstance(value, HParams):
                dict_[name] = value.todict()
            else:
                dict_[name] = copy.deepcopy(value, memo)
        return dict_