Utilities

Frequent Use

AverageRecorder

class texar.torch.utils.AverageRecorder(size=None)

Maintains the moving averages (i.e., the average of the latest N records) of (possibly multiple) fields.

Fields are determined by the first call to add().

Parameters

size (int, optional) – The window size of moving average. If None, the average of all added records is maintained.

Example

from texar.torch.utils import AverageRecorder

## Use to maintain moving average of training loss
avg_rec = AverageRecorder(size=10)  # average over the latest 10 records
while training:
    loss_0, loss_1 = ...
    avg_rec.add([loss_0, loss_1])
    # avg_rec.avg() == [0.12343452, 0.567800323]
    # avg_rec.avg(0) == 0.12343452
    # avg_rec.to_str(precision=2) == '0.12 0.57'

## Use to maintain average of test metrics on the whole test set
avg_rec = AverageRecorder()  # average over ALL records
while test:
    metric_0, metric_1 = ...
    avg_rec.add({'m0': metric_0, 'm1': metric_1})  # dict is allowed
print(avg_rec.to_str(precision=4, delimiter=' , '))
# 'm0: 0.1234 , m1: 0.5678'
#
# avg_rec.avg() == {'m0': 0.12343452, 'm1': 0.567800323}
# avg_rec.avg('m0') == 0.12343452

add(record, weight=None)

Appends a new record.

record can be a list, dict, or a single scalar. The record type is determined the first time add() is called, and all subsequent calls to add() must pass a record of the same type.

In subsequent calls to add(), record may contain only a subset of the fields given in the first call.

Example

recorder = AverageRecorder()
recorder.add({'1': 0.2, '2': 0.2})  # 1st call to `add`
x = recorder.add({'1': 0.4})  # 2nd call to `add`
# x == {'1': 0.3, '2': 0.2}

Parameters
  • record – A single scalar, a list of scalars, or a dict of scalars.

  • weight (optional) – A scalar weight of the new record, used to compute a weighted average. If None, the weight is set to 1. For example, to compute the average value of a metric over a whole dataset, add the metric's average on each batch as the record and set weight to the batch size.

Returns

The (moving) average after appending the record, with the same type as record.

avg(id_or_name=None)

Returns the (moving) average.

Parameters

id_or_name (optional) – A single element or a list of elements. Each element is the index (if the record type is list) or name (if the record type is dict) of the field whose average is calculated. If not given, the averages of all fields are returned.

Returns

The average value(s). If id_or_name is a single element (not a list), returns the average value of the corresponding field. If id_or_name is a list, returns the average values in the same type as the record passed to add().

reset(id_or_name=None)

Resets the record.

Parameters

id_or_name (optional) – A single element or a list of elements. Each element is the index (if the record type is list) or name (if the record type is dict) of the field to reset. If None, all fields are reset.

to_str(precision=None, delimiter=' ')

Returns a string of the average values of the records.

Parameters
  • precision (int, optional) – The number of decimal places to keep in the returned string. For example, for an average value of 0.1234, precision = 2 leads to "0.12".

  • delimiter (str) – The delimiter string that separates fields.

Returns

A string of the average values.

If record is of type dict, the string is a concatenation of "field_name: average_value", delimited with delimiter. For example, "field_name_1: 0.1234 field_name_2: 0.5678 ...".

Otherwise, the string is a concatenation of the average values. For example, "0.1234 0.5678 ...".
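The windowed, weighted averaging described above can be sketched for a single scalar field as follows. This is a hypothetical class (ScalarAverageSketch), not AverageRecorder's actual internals; it assumes an empty recorder averages to 0.0, and list or dict records could be handled by keeping one such object per field.

```python
from collections import deque
from typing import Deque, Optional, Tuple


class ScalarAverageSketch:
    """Hypothetical single-field recorder keeping a windowed, weighted average."""

    def __init__(self, size: Optional[int] = None) -> None:
        if size is not None and size <= 0:
            raise ValueError("size must be a positive int or None")
        # Each entry is (value * weight, weight). With maxlen=size the deque
        # drops the oldest entry once the window is full; maxlen=None keeps all.
        self._records: Deque[Tuple[float, float]] = deque(maxlen=size)

    def add(self, record: float, weight: Optional[float] = None) -> float:
        w = 1.0 if weight is None else float(weight)
        self._records.append((record * w, w))
        return self.avg()

    def avg(self) -> float:
        total_weight = sum(w for _, w in self._records)
        if total_weight == 0:
            return 0.0  # nothing recorded yet (assumed convention)
        return sum(vw for vw, _ in self._records) / total_weight

    def reset(self) -> None:
        self._records.clear()
```

With size=2, adding 1.0, 2.0, and 4.0 leaves only the last two records in the window, so the average is 3.0.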

compat_as_text

texar.torch.utils.compat_as_text(str_)

Converts strings into unicode (Python 2) or str (Python 3).

Parameters

str_ – A string or other data types convertible to string, or an n-D numpy array or (possibly nested) list of such elements.

Returns

The converted strings of the same structure/shape as str_.
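For Python 3, the conversion amounts to decoding bytes and calling str() on everything else, recursing through nested lists so the output keeps the input's structure. The sketch below is a hypothetical helper (compat_as_text_sketch) that assumes UTF-8 encoded bytes and does not handle NumPy arrays.

```python
from typing import Any


def compat_as_text_sketch(value: Any) -> Any:
    """Hypothetical Python 3 helper: convert values (or nested lists of them) to str."""
    if isinstance(value, list):
        # Recurse so the output keeps the same nesting as the input.
        return [compat_as_text_sketch(v) for v in value]
    if isinstance(value, bytes):
        return value.decode("utf-8")  # assumes UTF-8 encoded bytes
    return str(value)
```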

write_paired_text

texar.torch.utils.write_paired_text(src, tgt, fname, append=False, mode='h', sep='\t', src_fname_suffix='src', tgt_fname_suffix='tgt')

Writes paired text to a file.

Parameters
  • src – A list (or array) of str source text.

  • tgt – A list (or array) of str target text.

  • fname (str) – The output filename.

  • append (bool) – Whether to append content to the end of the file if it already exists.

  • mode (str) –

    The mode of writing, with the following options:

    • 'h': The "horizontal" mode. Each source-target pair is written on one line, separated by sep, e.g.:

      source_1 target_1
      source_2 target_2
      
    • 'v': The "vertical" mode. Each source-target pair is written on two consecutive lines, e.g.:

      source_1
      target_1
      source_2
      target_2
      
    • 's': The "separate" mode. Each source-target pair is written to corresponding lines of two files named "{fname}.{src_fname_suffix}" and "{fname}.{tgt_fname_suffix}", respectively.

  • sep (str) – The separator between source and target on each line. Used only when mode is "h".

  • src_fname_suffix (str) – Used when mode is "s". The suffix of the source output filename. For example, with (fname='output', src_fname_suffix='src'), the output source file is named output.src.

  • tgt_fname_suffix (str) – Used when mode is "s". The suffix to the target output filename.

Returns

The filename(s). If mode is "h" or "v", returns fname. If mode is "s", returns a list of filenames ["{fname}.{src_fname_suffix}", "{fname}.{tgt_fname_suffix}"], i.e., ["{fname}.src", "{fname}.tgt"] with the default suffixes.
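The three writing modes can be illustrated with a small standalone re-implementation. This is a hypothetical sketch (write_paired_text_sketch) of the documented behavior, assuming UTF-8 files and rejecting source and target lists of different lengths; it is not the library's own code.

```python
from typing import List, Sequence, Union


def write_paired_text_sketch(
    src: Sequence[str],
    tgt: Sequence[str],
    fname: str,
    append: bool = False,
    mode: str = "h",
    sep: str = "\t",
    src_fname_suffix: str = "src",
    tgt_fname_suffix: str = "tgt",
) -> Union[str, List[str]]:
    """Hypothetical re-implementation of the three documented writing modes."""
    if mode not in ("h", "v", "s"):
        raise ValueError(f"Unknown mode: {mode!r}")
    if len(src) != len(tgt):
        raise ValueError("src and tgt must have the same length")
    file_mode = "a" if append else "w"

    if mode == "s":
        # Separate mode: line i of each file holds the i-th source/target.
        fn_src = f"{fname}.{src_fname_suffix}"
        fn_tgt = f"{fname}.{tgt_fname_suffix}"
        with open(fn_src, file_mode, encoding="utf-8") as f_src, \
                open(fn_tgt, file_mode, encoding="utf-8") as f_tgt:
            for s, t in zip(src, tgt):
                f_src.write(s + "\n")
                f_tgt.write(t + "\n")
        return [fn_src, fn_tgt]

    with open(fname, file_mode, encoding="utf-8") as f:
        for s, t in zip(src, tgt):
            if mode == "h":
                f.write(f"{s}{sep}{t}\n")  # one pair per line
            else:
                f.write(f"{s}\n{t}\n")  # mode "v": two lines per pair
    return fname
```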

IO

maybe_create_dir

texar.torch.utils.maybe_create_dir(dirname)

Creates a directory if it does not exist.

Parameters

dirname (str) – Path to the directory.

Returns

Whether a new directory was created (False if dirname already existed).

Return type

bool
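The contract above (create if missing, report whether anything was created) maps directly onto the standard library. The sketch below is a hypothetical helper (maybe_create_dir_sketch); it assumes missing parent directories should also be created, which may differ from the library's behavior.

```python
import os


def maybe_create_dir_sketch(dirname: str) -> bool:
    """Create dirname (and any missing parents); return True only if it was new."""
    if os.path.isdir(dirname):
        return False
    os.makedirs(dirname)
    return True
```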

DType

get_numpy_dtype

texar.torch.utils.get_numpy_dtype(dtype)

Returns the equivalent NumPy dtype.

Parameters

dtype – A str, Python numeric or string type, NumPy data type, or PyTorch dtype.

Returns

The corresponding NumPy dtype.

maybe_hparams_to_dict

texar.torch.utils.maybe_hparams_to_dict(hparams)

If hparams is an instance of HParams, converts it to a dict and returns it. If hparams is a dict, returns it as is.

Parameters

hparams – The HParams instance (or dict) to convert.

Returns

The corresponding dict instance.

Return type

dict

Shape

mask_sequences

texar.torch.utils.mask_sequences(sequence, sequence_length, dtype=None, time_major=False)

Masks out sequence entries that are beyond the respective sequence lengths. Masks along the time dimension.

sequence and sequence_length can each be either a Python array or a Tensor. If both are Python arrays (or None), the return value is a Python array as well.

Parameters
  • sequence – A Tensor or Python array of sequence values. If time_major is False (default), it must have shape [batch_size, max_time, ...]; if time_major is True, the batch and time dimensions are swapped.

  • sequence_length – A Tensor or Python array of shape [batch_size]. Entries at time steps beyond the respective sequence lengths are set to zero.

  • dtype (dtype) – Type of sequence. If None, infer from sequence automatically.

  • time_major (bool) – The shape format of the inputs. If True, sequence must have shape [max_time, batch_size, ...]. If False (default), sequence must have shape [batch_size, max_time, ...].

Returns

The masked sequence, i.e., a Tensor or Python array of the same shape as sequence, with the masked-out entries set to zero.

If both sequence and sequence_length are Python arrays, the returned value is a Python array as well.
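For plain Python arrays in the default batch-major layout, masking reduces to zeroing every entry whose time index is at or beyond that example's length. The helper below is a hypothetical 2-D sketch (mask_sequences_sketch); it ignores time_major, dtype, and extra trailing dimensions.

```python
from typing import List, Sequence


def mask_sequences_sketch(sequence: Sequence[Sequence[float]],
                          sequence_length: Sequence[int]) -> List[List[float]]:
    """Zero out entries at time steps >= each example's length (batch-major)."""
    masked = []
    for row, length in zip(sequence, sequence_length):
        # Keep time steps t < length, replace the rest with zero.
        masked.append([x if t < length else 0 for t, x in enumerate(row)])
    return masked
```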

transpose_batch_time

texar.torch.utils.transpose_batch_time(inputs)

Transposes inputs between time-major and batch-major.

Parameters

inputs – A Tensor of shape [batch_size, max_time, ...] (batch-major) or [max_time, batch_size, ...] (time-major), or a (possibly nested) tuple of such elements.

Returns

A Tensor (or a possibly nested tuple of Tensors) with the batch and time dimensions of inputs transposed.

flatten

texar.torch.utils.flatten(tensor, preserve_dims, flattened_dim=None)

Flattens a tensor while keeping several leading dimensions.

preserve_dims must be less than or equal to tensor’s rank.

Parameters
  • tensor – A Tensor to flatten.

  • preserve_dims (int) – The number of leading dimensions to preserve.

  • flattened_dim (int, optional) – The size of the resulting flattened dimension. If not given, infer automatically.

Returns

A Tensor with rank preserve_dims + 1.

Example

import torch
from texar.torch.utils import flatten
x = torch.ones(2, 3, 4, 5)  # shape [d_1, d_2, d_3, d_4]
y = flatten(x, 2)  # y.shape == [d_1, d_2, d_3 * d_4] == [2, 3, 20]
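The resulting shape follows from simple arithmetic: keep the first preserve_dims sizes and multiply the remaining ones together. The hypothetical helper flattened_shape below computes only the shape (no tensors involved); note that when preserve_dims equals the rank, the product of an empty slice is 1, so a trailing dimension of size 1 appears and the rank is still preserve_dims + 1.

```python
import math
from typing import List, Sequence


def flattened_shape(shape: Sequence[int], preserve_dims: int) -> List[int]:
    """Shape after keeping preserve_dims leading dims and merging the rest."""
    if not 0 <= preserve_dims <= len(shape):
        raise ValueError("preserve_dims must be between 0 and the rank")
    return list(shape[:preserve_dims]) + [math.prod(shape[preserve_dims:])]
```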

pad_and_concat

texar.torch.utils.pad_and_concat(values, axis, pad_axis=None, pad_constant_values=0)

Concatenates tensors along one dimension. Pads each of the other dimensions of the tensors to the corresponding maximum size if necessary.

Parameters
  • values – A list of Tensors of the same rank.

  • axis (int) – A Python int. The dimension along which to concatenate.

  • pad_axis (int or list, optional) – A Python int or a list of int. Dimensions to pad. Paddings are only added to the end of corresponding dimensions. If None, all dimensions except the axis dimension are padded.

  • pad_constant_values – The scalar pad value to use. Must be of the same type as the tensors.

Returns

A Tensor resulting from padding and concatenation of the input tensors.

Raises

ValueError – If the ranks of values are not consistent.

Example

import torch
from texar.torch.utils import pad_and_concat

a = torch.ones([1, 2])
b = torch.ones([2, 3])

c = pad_and_concat([a, b], 0)
# c.shape == [3, 3]
# c == [[1, 1, 0],
#       [1, 1, 1],
#       [1, 1, 1]]

d = pad_and_concat([a, b], 1)
# d.shape == [2, 5]
# d == [[1, 1, 1, 1, 1],
#       [0, 0, 1, 1, 1]]
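The output shape in the example above follows one rule: sizes add up along the concatenation axis, and every other dimension is padded to the maximum size among the inputs. The hypothetical helper pad_and_concat_shape below computes only that shape, assuming pad_axis=None (all non-concatenation dimensions padded) and a non-negative axis.

```python
from typing import List, Sequence


def pad_and_concat_shape(shapes: Sequence[Sequence[int]], axis: int) -> List[int]:
    """Output shape of padding-then-concatenating tensors with the given shapes."""
    ranks = {len(s) for s in shapes}
    if len(ranks) != 1:
        raise ValueError("all inputs must have the same rank")
    rank = ranks.pop()
    out = []
    for dim in range(rank):
        sizes = [s[dim] for s in shapes]
        # Concatenation axis: sizes add up; other axes: padded to the maximum.
        out.append(sum(sizes) if dim == axis else max(sizes))
    return out
```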

Dictionary

dict_patch

texar.torch.utils.dict_patch(tgt_dict, src_dict)

Recursively patches tgt_dict by adding items from src_dict that do not exist in tgt_dict.

If corresponding items in src_dict and tgt_dict are both dicts, the tgt_dict item is patched recursively.

Parameters
  • tgt_dict (dict) – Target dictionary to patch.

  • src_dict (dict) – Source dictionary.

Returns

The patched tgt_dict.

Return type

dict
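The recursive patching rule can be sketched as below: keys missing from the target are copied over, existing non-dict values are left untouched, and nested dicts are patched recursively. This hypothetical helper (dict_patch_sketch) modifies tgt_dict in place and deep-copies added values; whether the library does either is an assumption.

```python
import copy
from typing import Any, Dict, Optional


def dict_patch_sketch(tgt_dict: Dict[str, Any],
                      src_dict: Optional[Dict[str, Any]]) -> Dict[str, Any]:
    """Add items of src_dict missing from tgt_dict, recursing into nested dicts."""
    if src_dict is None:
        return tgt_dict
    for key, value in src_dict.items():
        if key not in tgt_dict:
            tgt_dict[key] = copy.deepcopy(value)
        elif isinstance(value, dict) and isinstance(tgt_dict[key], dict):
            tgt_dict[key] = dict_patch_sketch(tgt_dict[key], value)
        # Otherwise the existing tgt_dict value wins.
    return tgt_dict
```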

dict_lookup

texar.torch.utils.dict_lookup(dict_, keys, default=None)

Looks up keys in the dictionary, returns the corresponding values.

The default is used for keys not present in the dictionary.

Parameters
  • dict_ (dict) – A dictionary for lookup.

  • keys – A numpy array or a (possibly nested) list of keys.

  • default (optional) – The value returned for keys not in dict_. An error is raised if default is not given and a key is not in the dictionary.

Returns

A numpy array of values with the same structure as keys.

Raises

TypeError – If a key is not in dict_ and default is None.
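The lookup can be sketched over plain nested lists: recurse through the structure of keys and replace each key with its value. This hypothetical helper (dict_lookup_sketch) returns nested Python lists rather than a NumPy array, and raises TypeError for a missing key with no default to mirror the documented contract.

```python
from typing import Any, Dict


def dict_lookup_sketch(dict_: Dict[Any, Any], keys: Any, default: Any = None) -> Any:
    """Hypothetical lookup over a (possibly nested) list of keys."""
    if isinstance(keys, list):
        # Recurse so the result mirrors the nesting of keys.
        return [dict_lookup_sketch(dict_, k, default) for k in keys]
    if keys in dict_:
        return dict_[keys]
    if default is None:
        # Mirrors the documented contract: missing key and no default -> TypeError.
        raise TypeError(f"Key {keys!r} not found and no default given")
    return default
```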

dict_fetch

texar.torch.utils.dict_fetch(src_dict, tgt_dict_or_keys)

Fetches a sub-dictionary of src_dict with the keys in tgt_dict_or_keys.

Parameters
  • src_dict – A dictionary or instance of HParams. The source dictionary to fetch values from.

  • tgt_dict_or_keys – A dictionary, instance of HParams, or a list (or a dict_keys/KeysView) of keys to be included in the output dictionary.

Returns

A new dictionary that is a sub-dictionary of src_dict.

dict_pop

texar.torch.utils.dict_pop(dict_, pop_keys, default=None)

Removes keys from a dictionary and returns their values.

Parameters
  • dict_ (dict) – A dictionary from which items are removed.

  • pop_keys – A key or a list of keys to remove. The respective values (or default, for keys not in dict_) are returned.

  • default (optional) – Value to be returned when a key is not in dict_. The default value is None.

Returns

A dict of the items removed from dict_.
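The behavior described above is a thin wrapper around dict.pop. The hypothetical helper dict_pop_sketch below accepts a single key or a list of keys, removes them from dict_ in place, and returns a dict of the removed values, with default filled in for keys that were absent.

```python
from typing import Any, Dict


def dict_pop_sketch(dict_: Dict[Any, Any], pop_keys: Any,
                    default: Any = None) -> Dict[Any, Any]:
    """Remove pop_keys from dict_ and return {key: value_or_default}."""
    if not isinstance(pop_keys, (list, tuple)):
        pop_keys = [pop_keys]  # allow a single key
    return {key: dict_.pop(key, default) for key in pop_keys}
```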

flatten_dict

texar.torch.utils.flatten_dict(dict_, parent_key='', sep='.')

Flattens a nested dictionary. Namedtuples within the dictionary are also converted to dictionaries.

Adapted from: https://github.com/google/seq2seq/blob/master/seq2seq/models/model_base.py

Parameters
  • dict_ (dict) – The dictionary to flatten.

  • parent_key (str) – A prefix to prepend to each key.

  • sep (str) – Separator that intervenes between parent and child keys. For example, if sep == ".", then { "a": { "b": 3 } } is converted into { "a.b": 3 }.

Returns

A new flattened dict.
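The flattening rule can be sketched recursively: join the parent and child keys with sep, convert namedtuples to dicts, and descend into nested mappings. This hypothetical helper (flatten_dict_sketch) follows the same general approach as the seq2seq code it is adapted from, but it assumes string keys and is not the library's exact implementation.

```python
from collections.abc import MutableMapping
from typing import Any, Dict


def flatten_dict_sketch(dict_: Dict[str, Any], parent_key: str = "",
                        sep: str = ".") -> Dict[str, Any]:
    """Flatten nested dicts (and namedtuples) into one level of joined keys."""
    items = []
    for key, value in dict_.items():
        new_key = parent_key + sep + key if parent_key else key
        if isinstance(value, tuple) and hasattr(value, "_asdict"):
            value = value._asdict()  # namedtuple -> dict
        if isinstance(value, MutableMapping):
            items.extend(flatten_dict_sketch(value, new_key, sep).items())
        else:
            items.append((new_key, value))
    return dict(items)
```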

String

strip_token

texar.torch.utils.strip_token(str_, token, is_token_list=False)

Returns a copy of strings with leading and trailing tokens removed.

Note that besides token, all leading and trailing whitespace characters are also removed.

If is_token_list is False, the function assumes that tokens in str_ are separated by whitespace characters.

Parameters
  • str_ – A str, or an n-D numpy array or (possibly nested) list of str.

  • token (str) – The token to strip, e.g., the "<PAD>" token defined in SpecialTokens.

  • is_token_list (bool) – Whether each sentence in str_ is a list of tokens. If False, each sentence in str_ is assumed to contain tokens separated by space characters.

Returns

The stripped strings of the same structure/shape as str_.

Example

str_ = '<PAD> a sentence <PAD> <PAD>  '
str_stripped = strip_token(str_, '<PAD>')
# str_stripped == 'a sentence'

str_ = ['<PAD>', 'a', 'sentence', '<PAD>', '<PAD>', '', '']
str_stripped = strip_token(str_, '<PAD>', is_token_list=True)
# str_stripped == ['a', 'sentence']
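For a single whitespace-separated string, stripping leading and trailing occurrences of a token can be sketched with a split and two index scans. This hypothetical helper (strip_token_sketch) handles only one str (no nested lists or arrays) and, because it splits on whitespace, also collapses internal whitespace runs into single spaces.

```python
def strip_token_sketch(text: str, token: str) -> str:
    """Remove leading/trailing occurrences of token from a space-separated string."""
    tokens = text.split()  # also drops leading/trailing whitespace
    start, end = 0, len(tokens)
    while start < end and tokens[start] == token:
        start += 1
    while end > start and tokens[end - 1] == token:
        end -= 1
    return " ".join(tokens[start:end])
```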

strip_eos

texar.torch.utils.strip_eos(str_, eos_token='<EOS>', is_token_list=False)

Removes the EOS token and all subsequent tokens.

If is_token_list is False, the function assumes that tokens in str_ are separated by whitespace characters.

Parameters
  • str_ – A str, or an n-D numpy array or (possibly nested) list of str.

  • eos_token (str) – The EOS token. Default is "<EOS>", as defined in SpecialTokens.EOS.

  • is_token_list (bool) – Whether each sentence in str_ is a list of tokens. If False, each sentence in str_ is assumed to contain tokens separated by space characters.

Returns

Strings of the same structure/shape as str_.
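Truncating at the first EOS token can be sketched for a single whitespace-separated string as follows. The helper strip_eos_sketch is hypothetical: it handles one str only and normalizes whitespace as a side effect of splitting.

```python
def strip_eos_sketch(text: str, eos_token: str = "<EOS>") -> str:
    """Drop the first eos_token and everything after it."""
    tokens = text.split()
    if eos_token in tokens:
        tokens = tokens[:tokens.index(eos_token)]
    return " ".join(tokens)
```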

strip_special_tokens

texar.torch.utils.strip_special_tokens(str_, strip_pad='<PAD>', strip_bos='<BOS>', strip_eos='<EOS>', is_token_list=False)

Removes special tokens in strings, including:

  • Removes EOS and all subsequent tokens

  • Removes leading and trailing PAD tokens

  • Removes leading BOS tokens

Note that besides the special tokens, all leading and trailing whitespace characters are also removed.

This function combines strip_eos(), strip_pad(), and strip_bos().

Parameters
  • str_ – A str, or an n-D numpy array or (possibly nested) list of str.

  • strip_pad (str) – The PAD token to strip from the strings (i.e., remove the leading and trailing PAD tokens of the strings). Default is "<PAD>" as defined in SpecialTokens.PAD. Set to None or False to disable the stripping.

  • strip_bos (str) – The BOS token to strip from the strings (i.e., remove the leading BOS tokens of the strings). Default is "<BOS>" as defined in SpecialTokens.BOS. Set to None or False to disable the stripping.

  • strip_eos (str) – The EOS token to strip from the strings (i.e., remove the EOS tokens and all subsequent tokens of the strings). Default is "<EOS>" as defined in SpecialTokens.EOS. Set to None or False to disable the stripping.

  • is_token_list (bool) – Whether each sentence in str_ is a list of tokens. If False, each sentence in str_ is assumed to contain tokens separated by space characters.

Returns

Strings of the same shape as str_, with special tokens stripped.

str_join

texar.torch.utils.str_join(tokens, sep=' ')

Concatenates tokens along the last dimension with intervening occurrences of sep.

Parameters
  • tokens – An n-D numpy array or (possibly nested) list of str.

  • sep (str) – The string intervening between the tokens.

Returns

An (n-1)-D numpy array (or list) of str.
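Joining along the last dimension means each innermost list of str becomes one string, so the nesting depth drops by one. The hypothetical helper str_join_sketch below works on nested Python lists only (no NumPy arrays); an empty innermost list joins to the empty string.

```python
from typing import Any


def str_join_sketch(tokens: Any, sep: str = " ") -> Any:
    """Join the innermost lists of str, reducing nesting depth by one."""
    if all(isinstance(t, str) for t in tokens):
        return sep.join(tokens)  # innermost level (also handles [])
    return [str_join_sketch(t, sep) for t in tokens]
```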

default_str

texar.torch.utils.default_str(str_, default)

Returns str_ if it is neither None nor empty; otherwise returns default.

Parameters
  • str_ – A string.

  • default – A string.

Returns

Either str_ or default.

uniquify_str

texar.torch.utils.uniquify_str(str_, str_set)

Uniquifies str_ if str_ is included in str_set.

This is done by appending a number to str_. Returns str_ directly if it is not included in str_set.

Parameters
  • str_ (str) – A string to uniquify.

  • str_set (set, dict, or list) – A collection of strings. The returned string is guaranteed to be different from the elements in the collection.

Returns

The uniquified string. Returns str_ directly if it is already unique.

Example

print(uniquify_str('name', ['name', 'name_1']))
# 'name_2'
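One way to obtain the result in the example is to try the suffixes _1, _2, ... until one is not in the collection. The helper uniquify_str_sketch is hypothetical; the "_N" suffix format is inferred from the 'name_2' example above.

```python
from typing import Collection


def uniquify_str_sketch(str_: str, str_set: Collection[str]) -> str:
    """Append _1, _2, ... to str_ until it no longer collides with str_set."""
    if str_ not in str_set:
        return str_
    i = 1
    while f"{str_}_{i}" in str_set:
        i += 1
    return f"{str_}_{i}"
```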

Meta

check_or_get_class

texar.torch.utils.check_or_get_class(class_or_name, module_paths=None, superclass=None)

Returns the class and checks that it inherits from superclass.

Parameters
  • class_or_name – Name or full path to the class, or the class itself.

  • module_paths (list, optional) – Paths to candidate modules to search for the class. This is used if class_or_name is a string and the class cannot be located solely based on class_or_name. The first module in the list that contains the class is used.

  • superclass (optional) – A class (or a list of classes) that the target class must inherit from.

Returns

The target class.

Raises
  • ValueError – If class is not found based on class_or_name and module_paths.

  • TypeError – If the class does not inherit from superclass.

get_class

texar.torch.utils.get_class(class_name, module_paths=None)

Returns the class based on class name.

Parameters
  • class_name (str) – Name or full path to the class.

  • module_paths (list) – Paths to candidate modules to search for the class. This is used if the class cannot be located solely based on class_name. The first module in the list that contains the class is used.

Returns

The target class.

Raises

ValueError – If class is not found based on class_name and module_paths.
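The lookup order described above (a full dotted path first, then each candidate module in turn) can be sketched with importlib. The helper get_class_sketch is hypothetical and uses only standard-library behavior; the library's own resolution rules may differ in details.

```python
import importlib
from typing import List, Optional


def get_class_sketch(class_name: str, module_paths: Optional[List[str]] = None) -> type:
    """Resolve a class from a full path or from the first matching candidate module."""
    # First try a fully qualified name such as "collections.OrderedDict".
    if "." in class_name:
        module_name, _, attr = class_name.rpartition(".")
        try:
            cls = getattr(importlib.import_module(module_name), attr)
            if isinstance(cls, type):
                return cls
        except (ImportError, AttributeError):
            pass
    # Otherwise search the candidate modules in order; the first hit wins.
    for path in module_paths or []:
        try:
            module = importlib.import_module(path)
        except ImportError:
            continue
        cls = getattr(module, class_name, None)
        if isinstance(cls, type):
            return cls
    raise ValueError(f"Class not found: {class_name}")
```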

check_or_get_instance

texar.torch.utils.check_or_get_instance(ins_or_class_or_name, kwargs, module_paths=None, classtype=None)

Returns a class instance and checks types.

Parameters
  • ins_or_class_or_name

    Can be of 3 types:

    • A class to instantiate.

    • A string of the name or full path to a class to instantiate.

    • A class instance whose type is checked.

  • kwargs (dict) – Keyword arguments for the class constructor. Ignored if ins_or_class_or_name is a class instance.

  • module_paths (list, optional) – Paths to candidate modules to search for the class. This is used if the class cannot be located solely based on ins_or_class_or_name. The first module in the list that contains the class is used.

  • classtype (optional) – A class (or a list of classes) that the instance must be an instance of.

Raises
  • ValueError – If the class is not found based on ins_or_class_or_name and module_paths.

  • ValueError – If kwargs contains arguments that are invalid for the class construction.

  • TypeError – If the instance is not an instantiation of classtype.

get_instance

texar.torch.utils.get_instance(class_or_name, kwargs, module_paths=None)

Creates a class instance.

Parameters
  • class_or_name – A class, or its name or full path to a class to instantiate.

  • kwargs (dict) – Keyword arguments for the class constructor.

  • module_paths (list, optional) – Paths to candidate modules to search for the class. This is used if the class cannot be located solely based on class_or_name. The first module in the list that contains the class is used.

Returns

A class instance.

Raises
  • ValueError – If class is not found based on class_or_name and module_paths.

  • ValueError – If kwargs contains arguments that are invalid for the class construction.

check_or_get_instance_with_redundant_kwargs

texar.torch.utils.check_or_get_instance_with_redundant_kwargs(ins_or_class_or_name, kwargs, module_paths=None, classtype=None)

Returns a class instance and checks types.

Only those keyword arguments in kwargs that are included in the class construction method are used.

Parameters
  • ins_or_class_or_name

    Can be of 3 types:

    • A class to instantiate.

    • A string of the name or module path to a class to instantiate.

    • A class instance whose type is checked.

  • kwargs (dict) – Keyword arguments for the class constructor.

  • module_paths (list, optional) – Paths to candidate modules to search for the class. This is used if the class cannot be located solely based on ins_or_class_or_name. The first module in the list that contains the class is used.

  • classtype (optional) – A (list of) classes of which the instance must be an instantiation.

Raises
  • ValueError – If the class is not found based on ins_or_class_or_name and module_paths.

  • ValueError – If kwargs contains arguments that are invalid for the class construction.

  • TypeError – If the instance is not an instantiation of classtype.

get_instance_with_redundant_kwargs

texar.torch.utils.get_instance_with_redundant_kwargs(class_name, kwargs, module_paths=None)

Creates a class instance.

Only those keyword arguments in kwargs that are included in the class construction method are used.

Parameters
  • class_name (str) – A class, or its name or module path.

  • kwargs (dict) – A dictionary of arguments for the class constructor. It may include invalid arguments which will be ignored.

  • module_paths (list of str) – A list of paths to candidate modules to search for the class. This is used if the class cannot be located solely based on class_name. The first module in the list that contains the class is used.

Returns

A class instance.

Raises

ValueError – If class is not found based on class_name and module_paths.

get_function

texar.torch.utils.get_function(fn_or_name, module_paths=None)

Returns the function with the specified name and module.

Parameters
  • fn_or_name (str or callable) – Name or full path to a function, or the function itself.

  • module_paths (list, optional) – A list of paths to candidate modules to search for the function. This is used only when the function cannot be located solely based on fn_or_name. The first module in the list that contains the function is used.

Returns

A function.

Raises

ValueError – If no function named fn_or_name is found.

call_function_with_redundant_kwargs

texar.torch.utils.call_function_with_redundant_kwargs(fn, kwargs)

Calls a function and returns the results.

Only those keyword arguments in kwargs that are included in the function’s argument list are used to call the function.

Parameters
  • fn (function) – A callable. If fn is not a Python function, fn.__call__ is called.

  • kwargs (dict) – A dict of arguments for the callable. It may include invalid arguments which will be ignored.

Returns

The results returned by calling fn.
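Filtering out keyword arguments a callable does not accept can be sketched with inspect.signature, which for callable objects reports the parameters of __call__ (without self). The helper call_with_redundant_kwargs_sketch is hypothetical; as a simplifying assumption it does not special-case functions that themselves take **kwargs.

```python
import inspect
from typing import Any, Callable, Dict


def call_with_redundant_kwargs_sketch(fn: Callable[..., Any],
                                      kwargs: Dict[str, Any]) -> Any:
    """Call fn with only those entries of kwargs that fn declares as parameters."""
    params = inspect.signature(fn).parameters
    # Keep only keyword arguments that fn actually declares.
    selected = {k: v for k, v in kwargs.items() if k in params}
    return fn(**selected)
```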

get_args

texar.torch.utils.get_args(fn)

Gets the arguments of a function.

Parameters

fn (callable) – The function to inspect.

Returns

A list of argument names (str) of the function.

Return type

list

get_default_arg_values

texar.torch.utils.get_default_arg_values(fn)

Gets the arguments and respective default values of a function.

Only arguments with default values are included in the output dictionary.

Parameters

fn (callable) – The function to inspect.

Returns

A dictionary that maps argument names (str) to their default values. The dictionary is empty if no arguments have default values.

Return type

dict
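Both introspection helpers above can be approximated with inspect.signature. The functions below are hypothetical sketches; for bound methods inspect.signature omits self, which may differ from what get_args reports.

```python
import inspect
from typing import Any, Callable, Dict, List


def get_args_sketch(fn: Callable[..., Any]) -> List[str]:
    """Names of all parameters of fn, in declaration order."""
    return list(inspect.signature(fn).parameters)


def get_default_arg_values_sketch(fn: Callable[..., Any]) -> Dict[str, Any]:
    """Map each parameter that has a default value to that default."""
    return {
        name: p.default
        for name, p in inspect.signature(fn).parameters.items()
        if p.default is not inspect.Parameter.empty
    }
```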

get_instance_kwargs

texar.torch.utils.get_instance_kwargs(kwargs, hparams)

Makes a dictionary of keyword arguments with the following structure:

kwargs_ = {'hparams': dict(hparams), **kwargs}.

This is typically used for constructing a module that takes a set of arguments as well as an argument named "hparams".

Parameters
  • kwargs (dict) – A dict of keyword arguments. Can be None.

  • hparams – A dict or an instance of HParams. Can be None.

Returns

A dict that contains the keyword arguments in kwargs, and an additional keyword argument named "hparams".

Misc

ceildiv

texar.torch.utils.ceildiv(a, b)

Computes division with the result rounded up.

For example, 5 / 2 = 2.5, so ceildiv(5, 2) = 3.

Parameters
  • a (int) – The dividend.

  • b (int) – The divisor.

Returns

The quotient, rounded up.

Return type

int
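For integers, rounding up can be done without floating point by negating, floor-dividing, and negating back, since -(-a // b) equals ceil(a / b). The function ceildiv_sketch is a hypothetical standalone version, not necessarily how the library computes it.

```python
def ceildiv_sketch(a: int, b: int) -> int:
    """Integer division rounded toward positive infinity."""
    # Floor division of the negated dividend, negated back, rounds up.
    return -(-a // b)
```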

map_structure

texar.torch.utils.map_structure(fn, obj)

Maps a function over all elements in a (possibly nested) collection.

Parameters
  • fn (callable) – The function to call on elements.

  • obj – The collection to map function over.

Returns

The collection in the same structure, with elements mapped.

map_structure_zip

texar.torch.utils.map_structure_zip(fn, objs)

Maps a function over tuples formed by taking one element from each (possibly nested) collection. All collections must have identical structures.

Note

Although identical structures are required, this is not enforced by assertions. The structure of the first collection is assumed to be the structure of all collections.

For rare cases where collections need to have different structures, refer to no_map().

Parameters
  • fn (callable) – The function to call on elements.

  • objs – The list of collections to map function over.

Returns

A collection with the same structure, with elements mapped.
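Both mapping helpers can be sketched as recursions over lists, tuples, and dicts. In the zip variant, the structure of the first collection drives the recursion and fn receives one leaf from each collection as separate positional arguments. Both functions are hypothetical sketches; they rebuild tuples as plain tuples (so namedtuples are not preserved) and do not handle no_map().

```python
from typing import Any, Callable, List


def map_structure_sketch(fn: Callable[[Any], Any], obj: Any) -> Any:
    """Apply fn to every leaf of a nested list/tuple/dict structure."""
    if isinstance(obj, list):
        return [map_structure_sketch(fn, x) for x in obj]
    if isinstance(obj, tuple):
        return tuple(map_structure_sketch(fn, x) for x in obj)
    if isinstance(obj, dict):
        return {k: map_structure_sketch(fn, v) for k, v in obj.items()}
    return fn(obj)


def map_structure_zip_sketch(fn: Callable[..., Any], objs: List[Any]) -> Any:
    """Apply fn to corresponding leaves of several identically shaped structures."""
    first = objs[0]  # structure of the first collection is assumed for all
    if isinstance(first, list):
        return [map_structure_zip_sketch(fn, [o[i] for o in objs])
                for i in range(len(first))]
    if isinstance(first, tuple):
        return tuple(map_structure_zip_sketch(fn, [o[i] for o in objs])
                     for i in range(len(first)))
    if isinstance(first, dict):
        return {k: map_structure_zip_sketch(fn, [o[k] for o in objs]) for k in first}
    return fn(*objs)
```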

flatten

texar.torch.utils.nest.flatten(structure)

Returns a flat list from a given nested structure. If structure is not a sequence, tuple, or dict, returns a single-element list: [structure].

In the case of dict instances, the sequence consists of the values, sorted by key to ensure deterministic behavior. This is also true for OrderedDict instances: their sequence order is ignored, and the sorting order of keys is used instead. The same convention is followed in pack_sequence_as(). This correctly repacks dictionaries and OrderedDicts after they have been flattened, and also allows flattening an OrderedDict and then repacking it using a corresponding plain dict, or vice versa. Dictionaries with non-sortable keys cannot be flattened.

Users must not modify any collections used in structure while this function is running.

Parameters

structure – an arbitrarily nested structure or a scalar object. Note, numpy arrays are considered scalars.

Returns

A Python list, the flattened version of the input.

Raises

TypeError – If structure is or contains a dict with non-sortable keys.

pack_sequence_as

texar.torch.utils.nest.pack_sequence_as(structure, flat_sequence)

Returns a given flattened sequence packed into a given structure. If structure is a scalar, flat_sequence must be a single-element list; in this case the return value is flat_sequence[0].

If structure is or contains a dict instance, the keys are sorted to pack the flat sequence in deterministic order. This is also true for OrderedDict instances: their sequence order is ignored, and the sorting order of keys is used instead. The same convention is followed in flatten(). This correctly repacks dictionaries and OrderedDicts after they have been flattened, and also allows flattening an OrderedDict and then repacking it using a corresponding plain dict, or vice versa. Dictionaries with non-sortable keys cannot be flattened.

Parameters
  • structure – Nested structure, whose structure is given by nested lists, tuples, and dictionaries. Note: numpy arrays and strings are considered scalars.

  • flat_sequence – The flat sequence to pack.

Returns

flat_sequence converted to have the same recursive structure as structure.

Raises
  • ValueError – If flat_sequence and structure have different element counts.

  • TypeError – If structure is or contains a dict with non-sortable keys.
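The sorted-key convention shared by flatten() and pack_sequence_as() can be sketched as a pair of recursions over lists, tuples, and dicts, treating everything else (including strings) as a scalar. These are hypothetical helpers: they rebuild dicts as plain dicts and do not support namedtuples. Sorting non-sortable keys raises TypeError naturally, and a length mismatch raises ValueError before packing.

```python
from typing import Any, List, Tuple


def _is_nested(obj: Any) -> bool:
    return isinstance(obj, (list, tuple, dict))


def flatten_sketch(structure: Any) -> List[Any]:
    """Flatten nested lists/tuples/dicts; dict values are visited in sorted-key order."""
    if not _is_nested(structure):
        return [structure]
    children = ([structure[k] for k in sorted(structure)]
                if isinstance(structure, dict) else list(structure))
    flat: List[Any] = []
    for child in children:
        flat.extend(flatten_sketch(child))
    return flat


def pack_sequence_as_sketch(structure: Any, flat_sequence: List[Any]) -> Any:
    """Inverse of flatten_sketch: refill structure's leaves from flat_sequence."""
    if len(flatten_sketch(structure)) != len(flat_sequence):
        raise ValueError("structure and flat_sequence have different element counts")

    def _pack(struct: Any, index: int) -> Tuple[Any, int]:
        if not _is_nested(struct):
            return flat_sequence[index], index + 1
        if isinstance(struct, dict):
            packed = {}
            for k in sorted(struct):  # same key order as flatten_sketch
                packed[k], index = _pack(struct[k], index)
            return packed, index
        items = []
        for child in struct:
            item, index = _pack(child, index)
            items.append(item)
        return type(struct)(items), index

    return _pack(structure, 0)[0]
```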
