Executor

class texar.torch.run.Executor(model: torch.nn.modules.module.Module, *, train_data: Optional[texar.torch.data.data.data_base.DatasetBase] = None, valid_data: Optional[texar.torch.data.data.data_base.DatasetBase] = None, test_data: Union[texar.torch.data.data.data_base.DatasetBase, Sequence[Union[texar.torch.data.data.data_base.DatasetBase, Tuple[str, texar.torch.data.data.data_base.DatasetBase]]], Mapping[str, texar.torch.data.data.data_base.DatasetBase], None] = None, batching_strategy: Optional[texar.torch.data.data.sampler.BatchingStrategy] = None, device: Optional[torch.device] = None, tbx_logging_dir: Optional[str] = None, tbx_log_every: Optional[texar.torch.run.condition.Condition] = None, checkpoint_dir: Optional[str] = None, max_to_keep: Optional[int] = None, save_every: Union[texar.torch.run.condition.Condition, Sequence[texar.torch.run.condition.Condition], None] = None, save_training_state: bool = True, train_metrics: Union[texar.torch.run.metric.base_metric.Metric, Sequence[Union[texar.torch.run.metric.base_metric.Metric, Tuple[str, texar.torch.run.metric.base_metric.Metric]]], Mapping[str, texar.torch.run.metric.base_metric.Metric], None] = None, optimizer: Union[torch.optim.optimizer.Optimizer, Dict[str, Any], None] = None, lr_scheduler: Union[torch.optim.lr_scheduler._LRScheduler, Dict[str, Any], None] = None, stop_training_on: Union[texar.torch.run.condition.Condition, Sequence[texar.torch.run.condition.Condition], None] = None, num_iters_per_update: int = 1, grad_clip: Optional[float] = None, valid_metrics: Union[texar.torch.run.metric.base_metric.Metric, Sequence[Union[texar.torch.run.metric.base_metric.Metric, Tuple[str, texar.torch.run.metric.base_metric.Metric]]], Mapping[str, texar.torch.run.metric.base_metric.Metric], None] = None, validate_every: Union[texar.torch.run.condition.Condition, Sequence[texar.torch.run.condition.Condition], None] = None, plateau_condition: Union[texar.torch.run.condition.Condition, 
Sequence[texar.torch.run.condition.Condition], None] = None, action_on_plateau: Union[Callable[[Executor], None], Sequence[Callable[[Executor], None]], None] = None, validate_mode: str = 'eval', test_metrics: Union[texar.torch.run.metric.base_metric.Metric, Sequence[Union[texar.torch.run.metric.base_metric.Metric, Tuple[str, texar.torch.run.metric.base_metric.Metric]]], Mapping[str, texar.torch.run.metric.base_metric.Metric], None] = None, test_every: Union[texar.torch.run.condition.Condition, Sequence[texar.torch.run.condition.Condition], None] = None, test_mode: str = 'predict', log_every: Union[texar.torch.run.condition.Condition, Sequence[texar.torch.run.condition.Condition], None] = None, log_format: Optional[str] = None, log_destination: Union[str, pathlib.Path, IO[str], Sequence[Union[str, pathlib.Path, IO[str]]], None] = None, print_model_arch: bool = True, valid_log_format: Optional[str] = None, test_log_format: Optional[str] = None, valid_progress_log_format: Optional[str] = None, test_progress_log_format: Optional[str] = None, show_live_progress: Union[bool, str, List[str]] = False)[source]

Executor is a substitute for the general training/evaluation loop. It is designed with the following goals in mind:

  1. Minimize the amount of boilerplate code that is essentially the same across all experiments.
  2. Provide best practices and hide hideous details from the user.
  3. Guarantee reproducibility (runs with the same configuration and seed always produce the same result) and portability (the same code runs with or without a GPU).
  4. Allow flexible configuration and support user-overridden behaviors.

Example

Here is a realistic training loop example using Executor, showcasing the features built in. Executor takes care of common training procedures including logging, checkpoint management, evaluation metrics, validation, and patience-based early-stopping.

from texar.torch.run import *

executor = Executor(
    model=model,
    train_data=datasets["train"],
    valid_data=datasets["dev"],
    test_data=datasets["test"],
    checkpoint_dir=args.save_dir,
    save_every=cond.validation(better=True),
    train_metrics=("loss", metric.RunningAverage(args.display_steps)),
    optimizer={"type": torch.optim.Adam},
    grad_clip=args.grad_clip,
    log_every=cond.iteration(args.display_steps),
    validate_every=cond.epoch(1),
    valid_metrics=[
        metric.PearsonR(pred_name="preds"),
        ("loss", metric.Average())],
    plateau_condition=[
        cond.consecutive(cond.validation(better=False), 2)],
    action_on_plateau=[
        action.early_stop(patience=2),
        action.reset_params(),
        action.scale_lr(0.8)],
    stop_training_on=cond.iteration(args.max_train_steps),
    test_mode='eval',
)
executor.train()

Concepts

To make full use of Executor, you’ll need to understand the following concepts:

  • Event: Events are a set of pre-defined time spans within the training and evaluation loop. Common events include a training iteration (Event.Iteration), an epoch (Event.Epoch), or an entire validation (Event.Validation). Please refer to the enum class Event for the full list of events. The beginning and end of events are called event points.

  • Condition: Conditions are checks performed at the beginning or end of events. If a check passes, we say that the condition “triggers”. To give a few examples:

    • cond.iteration(num_iters) is checked only at the end of training iterations, and the check passes when the number of iterations passed equals num_iters. Thus, the condition triggers at the end of every num_iters iterations.
    • cond.validation(better=True) is checked only at the end of validations, and the check passes if the current validation result is better than the previous best. Thus, the condition triggers whenever a finished validation yields an improved result.
    • cond.time is checked at the end of iterations, and at the beginning and end of training, validation, and testing. It checks whether the elapsed time has reached the specified duration, and triggers when it does.

    Custom conditions must subclass the Condition class.

  • Action: Actions are special callback functions that are called either at the beginning or end of an event (actions registered by on_event()), or when conditions trigger (actions registered by on(), or by specifying save_every and similar arguments). This is similar to the hook functions that exist in common frameworks. Actions take a single argument, the Executor instance itself, and can perform arbitrary operations with it.

    Custom actions can be simple functions, or subclass the Action class.

  • Metric: Metrics are used to evaluate the output of models. They are categorized into two classes: SimpleMetric and StreamingMetric. The only difference is that streaming metrics support incremental computation of metric values, so they can be used to aggregate results over the training set, or provide intermediate results on-the-fly.

    Custom metrics must subclass one of the above classes.
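
To make the idea of incremental computation concrete, here is a minimal pure-Python sketch of a running average. The class name StreamingAverage and its add()/value() methods are illustrative assumptions and do not reproduce the actual StreamingMetric interface; the point is only that the value can be read at any time without revisiting earlier batches.

```python
from typing import List


class StreamingAverage:
    """Toy running average; illustrative only, not texar's metric API."""

    def __init__(self) -> None:
        self.total = 0.0
        self.count = 0

    def add(self, values: List[float]) -> None:
        # Fold a new batch of values into the running sums.
        self.total += sum(values)
        self.count += len(values)

    def value(self) -> float:
        # Readable at any point; earlier batches are never revisited.
        return self.total / self.count if self.count else 0.0


avg = StreamingAverage()
avg.add([1.0, 2.0])
first = avg.value()   # average of the first batch only
avg.add([6.0])
second = avg.value()  # average over everything seen so far
```

A non-streaming metric would instead have to store all predictions and recompute from scratch each time its value is requested.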

Customization

You can easily extend the Executor class by subclassing it and overriding methods. Methods of interest include:

  • _train_step(): Perform a single step of training (i.e. process a single batch, call backward(), and potentially call optimizer updates). Takes the data batch as argument.
  • _validate_step(): Perform a single step of validation. Takes the data batch as argument.
  • _test_step(): Perform a single step of testing. Takes the data batch as argument.
  • _train_loop(): Run the entire training loop. Takes the data iterator as argument.
  • _validate_loop(): Run the entire validation loop. Takes the data iterator as argument.
  • _test_loop(): Run the entire testing loop. Takes the data iterator as argument.

You can also define custom events by writing a new enum class and modifying the _EVENT_TYPES attribute. Event points can be signaled by calling _fire_event(). For example:

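from enum import Enum, auto

import torch
import torch.nn.functional as F
from texar.torch.run import Executor
from texar.torch.run.condition import Event

class GANEvent(Enum):
    DiscriminatorUpdate = auto()
    GeneratorUpdate = auto()

class GANExecutor(Executor):
    _EVENT_TYPES = (Event, GANEvent)

    def __init__(self, *args, optimizer_g, optimizer_d, **kwargs):
        # Pass both optimizers as a dictionary keyed by name.
        kwargs["optimizer"] = {
            "g": optimizer_g,
            "d": optimizer_d,
        }
        super().__init__(*args, **kwargs)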

    def _train_step(self, batch):
        self._fire_event(GANEvent.GeneratorUpdate, False)
        z = torch.randn(len(batch), args.latent_dim)
        fake_image = self.model.generator(z)
        logits = self.model.discriminator(fake_image)
        g_loss = F.binary_cross_entropy(logits, torch.ones(len(batch)))
        g_loss.backward()
        self.optimizer["g"].step()
        self.optimizer["g"].zero_grad()
        self._fire_event(GANEvent.GeneratorUpdate, True)

        self._fire_event(GANEvent.DiscriminatorUpdate, False)
        real_logits = self.model.discriminator(batch.image)
        fake_logits = self.model.discriminator(fake_image.detach())
        real_loss = F.binary_cross_entropy(real_logits, torch.ones(len(batch)))
        fake_loss = F.binary_cross_entropy(fake_logits, torch.zeros(len(batch)))
        d_loss = (real_loss + fake_loss) / 2
        d_loss.backward()
        self.optimizer["d"].step()
        self.optimizer["d"].zero_grad()
        self._fire_event(GANEvent.DiscriminatorUpdate, True)

        return {"g_loss": g_loss, "d_loss": d_loss}

Arguments

The constructor of Executor takes many arguments, almost all of which are keyword-only and optional. Some arguments can take values of multiple types, either a single instance of a specific type, or a list or dictionary of values of that type.

The arguments are grouped by function below.

General arguments:

model: torch.nn.Module

The model to train or evaluate. The model must be a subclass of torch.nn.Module, with its forward() method taking a single argument batch and returning a dictionary of torch.Tensors:

  • The batch argument is of type Batch, which is the batch object produced by your provided dataset.
  • The returned dictionary must contain an entry named "loss", which will be used as the loss to backpropagate during training. You can also include other values and use them in metrics.

If the model performs different routines during training and evaluation (for instance, a sequence-to-sequence model may train using teacher-forcing but evaluate using beam search-based inference), you can define another method predict() following the same signature as forward(). To use predict() instead of forward() in validation or testing, set validate_mode or test_mode to "predict" instead of "eval".

If the model you use does not follow this convention, you will need to wrap the model in a new class. The following example demonstrates how to wrap XLNetRegressor:

class RegressorWrapper(tx.modules.XLNetRegressor):
    def forward(self, batch):
        preds = super().forward(token_ids=batch.input_ids,
                                segment_ids=batch.segment_ids,
                                input_mask=batch.input_mask)
        loss = (preds - batch.label_ids) ** 2
        loss = loss.sum() / len(batch)
        return {"loss": loss, "preds": preds}
train_data: DatasetBase
The dataset used during training. Must be specified for training.
valid_data: DatasetBase
The dataset used during validation. If not specified, you cannot perform validation during training (e.g., by setting validate_every).
test_data: DatasetBase, or a list or dictionary
The dataset(s) used during testing. If not specified, you cannot perform testing during training (e.g., by setting test_every).
batching_strategy: BatchingStrategy
The batching strategy to use. This will be passed as the batching_strategy argument for DataIterator during training, and also during evaluation if the corresponding mode is set to "eval".
validate_mode: str
The evaluation mode for validation. Available choices are "eval" and "predict". Defaults to "eval". When the mode is set to "eval", the forward() method of the model will be called; when set to "predict", the predict() method of the model will be called.
test_mode: str
The evaluation mode for testing. Available choices are "eval" and "predict". Defaults to "predict".
device: torch.device
The device on which the model and data should be placed. Defaults to None, in which case GPUs will be used if available.

Arguments for tensorboard logging:

tbx_logging_dir: str
Path to the directory for storing tensorboard logs.
tbx_log_every: Condition
Conditions that, when triggered, save the tensorboard logs for train metrics. If None, log_every will be used.

Arguments for checkpoint management:

checkpoint_dir: str
Path to the directory for storing checkpoints. If not specified, you cannot save/load the model during training (e.g., by setting save_every or using reset_params).
max_to_keep: int

The maximum number of checkpoints to keep in the checkpoint directory. When the number of checkpoints exceeds this limit, the oldest one will be removed. If None, no such limit is imposed. Defaults to None.
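
The "oldest one is removed" policy can be sketched with a bounded queue. This is a minimal illustration under assumed names (save_checkpoint and the checkpoint labels are hypothetical), not Executor's actual bookkeeping. It also shows why a best-performing checkpoint stored in the same directory can be evicted by later periodic snapshots.

```python
from collections import deque

max_to_keep = 3
kept = deque()  # oldest checkpoint sits on the left


def save_checkpoint(name: str) -> None:
    # Record the new checkpoint, then evict the oldest ones beyond the limit.
    kept.append(name)
    while len(kept) > max_to_keep:
        kept.popleft()


# "epoch1-best" is the best model so far, but it is also the oldest entry.
for name in ["epoch1-best", "epoch2", "epoch3", "epoch4"]:
    save_checkpoint(name)

remaining = list(kept)
```

After the fourth save, the best-performing checkpoint is gone even though nothing better has replaced it.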

Note

Be careful when saving periodic snapshots along with the best-performing checkpoint. Periodic snapshots might overwrite the best models if the checkpoint limit is exceeded.

A better workaround is to only save periodic snapshots with the built-in mechanism, and register a custom action for saving a single best performing checkpoint:

# Don't
executor = Executor(
    save_every=[cond.epoch(1), cond.validation(better=True)],
    max_to_keep=3, ...)

# Do
executor = Executor(
    save_every=cond.epoch(1),
    max_to_keep=3, ...)

@executor.on(cond.validation(better=True))
def save_best_model(executor):
    executor.save(path=another_directory, max_to_keep=1)
save_every: Condition, or a list

Conditions that, when triggered, save the model.

In the following example, the model will be saved every 1000 iterations, or whenever validation results improve.

save_every=[cond.validation(better=True), cond.iteration(1000)]
save_training_state: bool
If False, only save model parameters in checkpoint. If True, also save optimizer and scheduler states, along with random number generator states from Python, NumPy, and PyTorch. Defaults to True.
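
Saving random number generator states matters because it lets a resumed run draw the same random numbers as an uninterrupted one. The sketch below demonstrates this with Python's built-in random module only; how Executor stores these states internally is not shown here, and NumPy and PyTorch generators would need their own state handling.

```python
import random

random.seed(1234)
random.random()                    # advance the generator a little
saved_state = random.getstate()    # what a checkpoint would need to store

# Numbers an uninterrupted run would draw next.
expected = [random.random() for _ in range(3)]

random.setstate(saved_state)       # "resume" from the saved state
resumed = [random.random() for _ in range(3)]
```

With the state restored, the resumed sequence matches the uninterrupted one exactly.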

Arguments for training:

train_metrics: Metric, or a list or dictionary

The metrics computed over the training set. In case of multiple metrics, two sets of metric values will be compared in the provided order.

For example, if two metrics f1 (F1) and loss (Average) are defined (in this order), when comparing two sets of values, the one with a higher f1 is considered better. If the two sets have the same f1 value, the one with a lower loss is considered better.
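
The ordered comparison described above can be sketched as follows. This is an illustrative reimplementation under assumptions: the METRICS list, the is_better() helper, and the higher-is-better flags are hypothetical and do not reflect how Executor stores or compares metrics internally.

```python
from typing import Dict, List, Tuple

# (name, higher_is_better) in priority order; hypothetical configuration.
METRICS: List[Tuple[str, bool]] = [("f1", True), ("loss", False)]


def is_better(new: Dict[str, float], old: Dict[str, float]) -> bool:
    for name, higher_is_better in METRICS:
        if new[name] == old[name]:
            continue  # tie on this metric: fall through to the next one
        return (new[name] > old[name]) == higher_is_better
    return False  # identical on every metric


a = {"f1": 0.80, "loss": 0.50}
b = {"f1": 0.80, "loss": 0.40}  # same f1 as a, lower loss
c = {"f1": 0.85, "loss": 0.90}  # higher f1, much higher loss
```

Here b beats a on the tie-breaking loss, while c beats b because f1 comes first and the loss is never consulted.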

Acceptable values include:

  • A single Metric, or a list of Metrics. These metrics will be automatically named when they’re logged.
  • A tuple of (str, Metric), or a list of such tuples. These metrics are explicitly named according to the provided strings. Note that names must be unique, and should be valid identifier names.
  • An OrderedDict mapping names to metrics. Note that a plain (unordered) dictionary is not accepted.

Note

Metrics that are logged will be evaluated once every time logging is performed. For efficiency considerations, such metrics should be StreamingMetrics. Please take extra care when implementing your own metrics.

optimizer: torch.optim.Optimizer, or a dictionary of hyperparameters
The optimizer used during training. This can be a torch.optim.Optimizer instance, or a dictionary of hyperparameters that will be passed into texar.torch.utils.get_instance(). Must be specified for training.
lr_scheduler: LRScheduler, or a dictionary of hyperparameters, optional
The learning rate scheduler used during training. This can be an LRScheduler instance, or a dictionary of hyperparameters that will be passed into texar.torch.utils.get_instance().
stop_training_on: Condition, or a list

Conditions that, when triggered, will stop training.

In the following example, training will be terminated after 5 epochs or 20000 iterations, whichever comes first:

stop_training_on=[cond.epoch(5), cond.iteration(20000)]
num_iters_per_update: int

Number of iterations to run before performing a parameter update. When this value is greater than 1, the loss is scaled by its reciprocal. Defaults to 1, in which case the parameters are updated after each .backward() call.

This can be used to accumulate gradients across multiple batches, in order to simulate the effect of using a large batch size on a machine with limited memory.
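
The reason for scaling the loss by the reciprocal can be checked with a small worked example. The sketch below uses a one-parameter squared-error model in plain Python (the helper grad_mean_squared_error and the data are made up for illustration) and assumes equally sized micro-batches; it is not Executor's implementation.

```python
import math
from typing import List, Tuple

Example = Tuple[float, float]  # (x, y)


def grad_mean_squared_error(w: float, batch: List[Example]) -> float:
    # d/dw of mean((w * x - y) ** 2) over the batch.
    return sum(2.0 * (w * x - y) * x for x, y in batch) / len(batch)


w = 0.5
data = [(1.0, 2.0), (2.0, 3.0), (3.0, 5.0), (4.0, 4.0)]

# Gradient from one large batch of 4 examples.
full_grad = grad_mean_squared_error(w, data)

# Two micro-batches of 2 examples each, each contribution scaled by 1/2.
num_iters_per_update = 2
accumulated = 0.0
for micro_batch in (data[:2], data[2:]):
    accumulated += grad_mean_squared_error(w, micro_batch) / num_iters_per_update
```

The accumulated gradient equals the large-batch gradient, which is why the update after accumulation behaves like a single step on the larger batch.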

grad_clip: float, optional
Maximum norm of the gradients. Please refer to nn.utils.clip_grad_norm_ for details. Defaults to None, i.e. no clipping.
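
As a rough illustration of norm-based clipping, the following pure-Python sketch rescales a gradient vector whose norm exceeds the limit. The function clip_by_global_norm is a hypothetical helper written for this example; the PyTorch utility operates on parameter tensors and handles details this sketch leaves out.

```python
import math
from typing import List


def clip_by_global_norm(grads: List[float], max_norm: float) -> List[float]:
    # Compute the L2 norm over all gradient entries together.
    total_norm = math.sqrt(sum(g * g for g in grads))
    if total_norm <= max_norm:
        return list(grads)  # already within the limit: leave unchanged
    scale = max_norm / total_norm
    return [g * scale for g in grads]  # same direction, norm == max_norm


clipped = clip_by_global_norm([3.0, 4.0], max_norm=1.0)   # norm 5 is scaled down
untouched = clip_by_global_norm([0.3, 0.4], max_norm=1.0)  # norm 0.5 is kept
```

Clipping changes only the magnitude of the update, not its direction.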

Arguments for validation:

valid_metrics: Metric, or a list or dictionary
The metrics computed over the validation set. Please see train_metrics for details.
validate_every: Condition, or a list

Conditions that, when triggered, perform validation.

In the following example, the model will be validated once per epoch.

validate_every=cond.epoch(1)
plateau_condition: Condition, or a list

Conditions that, when triggered, indicate that training has reached a plateau, i.e., the model has stopped improving.

In the following example, we consider that training has reached a plateau if validation metrics have not improved for 3 consecutive validations.

plateau_condition=cond.consecutive(cond.validation(better=False), 3)
action_on_plateau: Action, or a list

Actions that will be called when training has reached a plateau.

In the following example, we perform patience-based early-stopping when reaching plateaus. A patience of 2 means training will be terminated after plateau is reached twice. We also scale the learning rate by 0.8, and reset the model & optimizer parameters to the previous best checkpoint.

action_on_plateau=[
    action.reset_params(), action.scale_lr(0.8),
    action.early_stop(patience=2)]
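
The patience semantics can be sketched as a small counter that is invoked once per plateau. The EarlyStop class below is a hypothetical stand-in written for illustration; it returns a flag instead of terminating an executor and is not the implementation of action.early_stop.

```python
class EarlyStop:
    """Toy patience counter; illustrative only."""

    def __init__(self, patience: int) -> None:
        self.patience = patience
        self.plateaus = 0

    def __call__(self) -> bool:
        # Called each time a plateau is reached; True means "stop training".
        self.plateaus += 1
        return self.plateaus >= self.patience


stop = EarlyStop(patience=2)
decisions = [stop(), stop()]  # first plateau is tolerated, second one stops
```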

Arguments for testing:

test_metrics: Metric, or a list or dictionary

The metrics computed over the test set. Please see train_metrics for details. test() can only be called if test_metrics is not None.

Note

valid_metrics will be automatically shared with test_metrics if:

  1. test_metrics is None;
  2. validate_mode is the same as test_mode.

In this case, calling validation while testing (or vice versa) could cause incorrect results since the same Metric instances are used.

test_every: Condition, or a list

Conditions that, when triggered, perform testing.

In the following example, the model will be tested whenever validation results improve.

test_every=cond.validation(better=True)

Arguments for logging:

log_every: Condition, or a list

Conditions that, when triggered, perform logging.

In the following example, a log will be printed every 100 iterations, and after every epoch.

log_every=[cond.iteration(100), cond.epoch()]
log_destination: str, IO object, or a list

Logging destinations. Acceptable values include:

  • An IO object. This can be an opened file, sys.stdout, or any other file-like object.
  • A string, denoting the path to a log file. Executor will open the file and close it when the program exits. The file will be opened in “append” ("a") mode to prevent accidentally overwriting previous logs. To force overwrite, supply the file object instead.
  • A list, with each element being one of the above.

When writing to a file, special syntax for terminals (e.g. color codes) is omitted. Also, live progress is not written to files.

By default, the log is only written to sys.stdout.

log_format: str

The format string for logs during training.

The format string follows the syntax of Python format strings. The status variables you can reference include:

  • epoch (int): The current epoch.

  • iteration (int): The current iteration.

  • progress (float): The epoch progress represented in percentage, i.e. a floating-point number between 0 and 100. It should be noted that progress may not be accurate, and may not be available if the data is loaded lazily.

  • speed (float): Average number of data examples processed per second. It should be noted that speed may not be accurate.

  • time: The current date and time. The time format can be set using the “format spec” syntax of Python format strings, e.g., {time:%H:%M:%S} prints the time in the format 08:26:03. If the time format is not specified, it is equivalent to {time:%Y-%m-%d %H:%M:%S}, which corresponds to the format 2018-07-06 08:26:03. For more information on time formatting, please refer to the documentation for Python's datetime.strftime() method.

  • metric: An aggregated representation of all metrics, in the format of <name1>: <value1>, <name2>: <value2>, .... The format spec for metric will be applied to all metrics whose values support such a format spec. For more fine-grained control over metric formatting, use the per-metric lookups below.

  • metric.<name>: Value of the metric under the specified name <name>.

  • <name>: Value of the metric under the specified name <name>.

    Note

    The metric can only be looked-up if its name does not coincide with built-in status variables. For instance, a metric named “loss” can be looked-up by loss and metric.loss, but a metric named “time” can only be looked-up by metric.time.
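
Since the format strings follow standard Python syntax, the behavior of format specs can be tried out with str.format() directly. The sketch below uses a hand-built status dictionary with made-up values; it only demonstrates the format-spec syntax and does not go through Executor's own logging machinery.

```python
from datetime import datetime

# Hypothetical status values, standing in for what a logger would supply.
status = {
    "epoch": 3,
    "iteration": 1200,
    "time": datetime(2018, 7, 6, 8, 26, 3),
    "loss": 0.4321,
}

# "%H:%M:%S" is passed to the datetime object, ".3f" to the float.
line = "{time:%H:%M:%S} epoch {epoch} iter {iteration} loss {loss:.3f}".format(**status)

# The full date-and-time pattern mentioned above.
stamp = "{time:%Y-%m-%d %H:%M:%S}".format(**status)
```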

The default format string is:

which produces logs similar to:

valid_log_format: str

The format string for logs when validation completes. Please refer to log_format for details of the format string.

The default format string is:

which produces logs similar to:

test_log_format: str, optional

The format string for logs when testing completes. Please refer to log_format for details of the format string.

If None, the format string for validation (valid_log_format) is used. Defaults to None.

valid_progress_log_format: str

The format string for logs during validation if live progress is enabled. Please refer to log_format for details of the format string.

The default format string is:

which produces logs similar to:

test_progress_log_format: str

The format string for logs during testing if live progress is enabled. Please refer to log_format for details of the format string.

If None, the format string for validation (valid_progress_log_format) is used. Defaults to None.

print_model_arch: bool
If True, the model architecture will be logged in a readable format. Defaults to True.
show_live_progress: bool or str

Controls whether live progress will be shown. Acceptable values include:

  • True: Live progress is enabled for training, validation, and testing.
  • False: Live progress is disabled.
  • "train", "valid", "test", or a list of these strings: Live progress is enabled for the specified stages only.

If live progress is enabled for a certain stage, the specified format string will be shown similar to a sticky progress bar at the bottom of the terminal window, and updated after each iteration.

Note that live progress is only shown on terminals. It will not be printed to log files.

Warning

This may incur extra overhead because an update requires re-evaluating metrics. Make sure that all metrics logged are StreamingMetrics. You can also explicitly log only streaming metrics, or disable live progress for certain stages.

write_log(log_str: str, *, mode: str = 'info', skip_tty: bool = False, skip_non_tty: bool = False) → None[source]

Write a string to log.

Parameters:
  • log_str (str) – The string to log.
  • mode (str) –

    The logging mode. Supported values are:

    • "log": A plain log. Logs created by log_every and eval_log_every are in this format.
    • "info": Indicates a result or notification. Includes a header with "INFO" and a timestamp in green color. This is the default mode.
    • "warning": Indicates an unexpected or incorrect situation. Includes a header with "WARNING" and a timestamp. The entire warning is in red color.
  • skip_tty – If True, the log string will not be written to terminal log destinations. Defaults to False.
  • skip_non_tty – If True, the log string will not be written to non-terminal (e.g., files) destinations. Defaults to False.
on(cond: texar.torch.run.condition.Condition, func=None)[source]

Register a function as an action triggered on a condition. For example:

executor = Executor(...)

@executor.on(cond.iteration(10, mode="valid"))
def log_during_validation(executor):
    logging.info("Validation iter %d", executor.status["iteration"])

The action function takes exactly one argument: the executor instance itself.

Parameters:
  • cond – The condition that will call the function when triggered. Must be of type Condition.
  • func (optional) – The function to register. If not None, the function will be registered; if None, a decorator will be returned.
Returns:

  • If func is None, this method will return a decorator to wrap around the function to register as hook.
  • If func is not None, this method will return the Executor itself, allowing chained calls.
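
The two return behaviors correspond to a common registration pattern: the same method works as a decorator or as a plain call. The following sketch shows that pattern on a hypothetical Registry class, with a string label standing in for a real condition; it is not Executor's code.

```python
from typing import Callable, List, Optional, Tuple


class Registry:
    """Toy registry illustrating the decorator-or-call pattern."""

    def __init__(self) -> None:
        self.actions: List[Tuple[str, Callable]] = []

    def on(self, cond: str, func: Optional[Callable] = None):
        if func is not None:
            # Direct call: register and return self so calls can be chained.
            self.actions.append((cond, func))
            return self

        def decorator(f: Callable) -> Callable:
            # Decorator use: register and hand the function back unchanged.
            self.actions.append((cond, f))
            return f

        return decorator


registry = Registry()


@registry.on("every_epoch")
def log_epoch(executor):
    return "epoch done"


def save_model(executor):
    return "saved"


chained = registry.on("every_10_iters", save_model).on("on_improve", save_model)
```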

on_event(event, point='end', func=None)[source]

Register a function as an action triggered on an event point. For example:

executor = Executor(...)

@executor.on_event(Event.Epoch, 'end')
def log_at_end_of_epoch(executor):
    logging.info("Epoch %d done", executor.status["epoch"])

The action function takes exactly one argument: the executor instance itself.

Parameters:
  • event – The event to hook on. Must be an enum value from Event.
  • point (str) – The point of event to hook on. Supported values are "begin" and "end". Defaults to "end".
  • func (optional) – The function to register. If not None, the function will be registered; if None, a decorator will be returned.
Returns:

  • If func is None, this method will return a decorator to wrap around the function to register as hook.
  • If func is not None, this method will return the Executor itself, allowing chained calls.

save(path: Optional[str] = None, save_training_state: Optional[bool] = None)[source]

Save a snapshot of the current state to a checkpoint file.

Parameters:
  • path (str, optional) – Path to the checkpoint directory. If None, checkpoint_dir in the constructor arguments will be used. Defaults to None.
  • save_training_state (bool) – If True, save the entire training state to the checkpoint. If False, only save model weights. If None, the value from the constructor arguments will be used. Defaults to None.
load(path: Optional[str] = None, load_training_state: bool = True, allow_failure: bool = False) → Optional[pathlib.Path][source]

Load a previous model checkpoint from file.

Parameters:
  • path (str, optional) – Path to a specific checkpoint or a checkpoint directory. If a directory is specified, the most recent checkpoint in the directory is loaded. If None, checkpoint_dir in the constructor arguments will be used. Defaults to None.
  • load_training_state (bool) – If True, will load entire training state from checkpoint (if the checkpoint contains training state). Otherwise, just load model weights. Defaults to True.
  • allow_failure (bool) – If True, no exceptions will be raised if no checkpoints were found. Defaults to False. Note that exceptions are still raised if the provided path does not exist, or the selected checkpoint is corrupted.
Returns:

Path of the loaded checkpoint, or None if load failed.

terminate() → None[source]

Terminate training. This method is intended to be called within actions. An example use case would be to implement a custom early-stopping mechanism.

It is guaranteed that no other event points will be fired once terminate() is called. However, conditions and actions under the same event point are still called.

remove_action() → None[source]

Remove the current action being run. This method is intended to be called within actions. An example use case would be to implement an action that is only run once at a certain event point.

train()[source]

Start the training loop.

test(dataset: Union[texar.torch.data.data.data_base.DatasetBase, Sequence[Union[texar.torch.data.data.data_base.DatasetBase, Tuple[str, texar.torch.data.data.data_base.DatasetBase]]], Mapping[str, texar.torch.data.data.data_base.DatasetBase], None] = None)[source]

Start the test loop.

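Parameters:dataset (optional) –

The dataset(s) to test on. Acceptable values include:

  • A single DatasetBase instance.
  • A list of datasets, or of (name, dataset) tuples.
  • A dictionary mapping names to datasets.

If None, test_data from the constructor arguments is used. Defaults to None.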

_write_log(log_str: str, skip_tty: bool = False, skip_non_tty: bool = False, newline: bool = True, clear_line: bool = False)[source]

Write a string to log.

Parameters:
  • log_str (str) – The string to log.
  • skip_tty – If True, the log string will not be written to terminal log destinations. Defaults to False.
  • skip_non_tty – If True, the log string will not be written to non-terminal (e.g., files) destinations. Defaults to False.
  • newline – If True, print a newline character after printing the log string. Defaults to True.
  • clear_line – If True, clear the line before printing. This only works for terminals. Defaults to False.
_create_logging_fn(format_str: str, metrics: OrderedDict[str, Metric], tracker: texar.torch.run.executor_utils.ProgressTracker, warn_non_streaming_metric: bool = False) → Callable[[Executor], str][source]

Given a logging format string, create a function that takes the executor instance as argument and returns the logging string.

Parameters:
  • format_str (str) – The logging format string.
  • metrics – The metrics dictionary that will be used in the logging hook function.
  • tracker – The progress tracker that will be used in the logging hook function.
  • warn_non_streaming_metric (bool) – If True, will issue a warning if any of the provided metrics is not a StreamingMetric. Defaults to False. This is set to True when the logging string is used as the status line.
Returns:

A hook function to print logs given the format string. Note that the hook function can accept additional arguments to pass to executor.write_log(), allowing it to be used in combination with functools.partial().

_fire_event(event: texar.torch.run.condition.Event, end: bool)[source]

Signal the beginning or end of an event. Internally, this is where conditions are checked and actions are executed.

Parameters:
  • event – The Event to fire.
  • end – If True, the fired event point is the end of event. If False, the fired event point is the beginning of event.
Raises:

If any triggered action calls terminate(), ExecutorTerminateSignal is thrown after all conditions are checked and actions executed.
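
The deferred termination described here can be sketched as a flag that is set by terminate() and only turned into an exception once every action at the event point has run. MiniLoop and TerminateSignal below are hypothetical stand-ins for illustration; the real method also checks conditions and distinguishes event types.

```python
class TerminateSignal(Exception):
    """Hypothetical stand-in for ExecutorTerminateSignal."""


class MiniLoop:
    def __init__(self) -> None:
        self.actions = []
        self.terminated = False
        self.log = []

    def terminate(self) -> None:
        self.terminated = True  # only sets a flag; nothing is raised yet

    def fire_event(self) -> None:
        for action in self.actions:
            action(self)  # every action at this event point still runs
        if self.terminated:
            raise TerminateSignal  # raised once, after all actions finished


loop = MiniLoop()
loop.actions = [
    lambda ex: ex.terminate(),        # first action requests termination
    lambda ex: ex.log.append("ran"),  # second action is still executed
]

try:
    loop.fire_event()
    stopped = False
except TerminateSignal:
    stopped = True
```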

_validate_step(batch: texar.torch.data.data.dataset_utils.Batch)[source]

Perform one step of validation, i.e., perform a forward pass (or decoding, depending on validate_mode) for a single batch.

Parameters:batch – The batch to validate on.
Returns:The dictionary containing values returned by the model. This is used to compute metrics.
_test_step(batch: texar.torch.data.data.dataset_utils.Batch)[source]

Perform one step of testing, i.e., perform a forward pass (or decoding, depending on test_mode) for a single batch.

Parameters:batch – The batch to test on.
Returns:The dictionary containing values returned by the model. This is used to compute metrics.
_train_step(batch: texar.torch.data.data.dataset_utils.Batch)[source]

Perform one step of training, i.e., perform a forward and backward pass for a single batch. Parameter updates should also be performed when necessary.

Parameters:batch – The batch to train on.
Returns:The dictionary containing values returned by the model. This is used to compute metrics.
_train_loop(iterator: texar.torch.data.data.data_iterators.DataIterator) → None[source]

Run the entire training loop given the data iterator.

Parameters:iterator – The iterator over the training data.
_validate_loop(iterator: texar.torch.data.data.data_iterators.DataIterator) → None[source]

Run the validation loop given the data iterator.

Parameters:iterator – The iterator over the validation data.
_test_loop(iterator: texar.torch.data.data.data_iterators.DataIterator) → None[source]

Run the entire testing loop given the data iterator.

Parameters: iterator – The iterator over the test data.

Conditions

Event

class texar.torch.run.condition.Event[source]

An enumeration.

Condition

class texar.torch.run.condition.Condition[source]

epoch

class texar.torch.run.condition.epoch(num_epochs: int = 1)[source]

Triggers when the specified number of epochs has ended.

Parameters: num_epochs (int) – The number of epochs to wait before triggering the event. In other words, the event is triggered every num_epochs epochs.

iteration

class texar.torch.run.condition.iteration(num_iters: int = 1, mode: str = 'train')[source]

Triggers when the specified number of iterations has ended.

Parameters:
  • num_iters (int) – The number of iterations to wait before triggering the event. In other words, the event is triggered every num_iters iterations.
  • mode (str) – The mode under which iterations are counted. Available choices are "train", "valid", and "test". Defaults to "train".

validation

class texar.torch.run.condition.validation(num_validations: int = 1, better: Optional[bool] = None)[source]

Triggers when validation ends, and optionally checks if validation results improve or worsen.

Parameters:
  • num_validations (int) – The number of validations to wait before triggering the event. In other words, the event is triggered every num_validations validations.
  • better (bool, optional) – If True, this event only triggers when validation results improve; if False, only triggers when results worsen. Defaults to None, in which case the event triggers regardless of results.

consecutive

class texar.torch.run.condition.consecutive(cond: texar.torch.run.condition.Condition, times: int, clear_after_trigger: bool = True)[source]

Triggers when the specified condition passes its check a given number of times in a row.

For example: consecutive(validation(better=False), times=3) would trigger if validation results do not improve 3 times in a row.

Warning

This method works by calling the inner condition at each event point that it registers. The consecutive counter is reset to zero if any check returns False. Thus, the behavior of consecutive might be different from what you expect. For instance:

  • cond.consecutive(cond.iteration(1), n_times) is equivalent to cond.iteration(n_times).
  • cond.consecutive(cond.iteration(2), n_times) will never trigger.

Using consecutive with conditions other than validation is not recommended. You should also be careful when implementing custom conditions for use with consecutive.
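
The two cases listed in the warning can be reproduced with a stand-alone model of the counter. This is a sketch of the documented behavior, not Texar's code: EveryN, Consecutive, and trigger_points are hypothetical helpers, and EveryN merely imitates an iteration-style condition that passes on every n-th check.

```python
class EveryN:
    """Toy condition: check() passes on every n-th call."""

    def __init__(self, n):
        self.n = n
        self.count = 0

    def check(self):
        self.count += 1
        return self.count % self.n == 0


class Consecutive:
    """Toy wrapper: passes after `times` consecutive passes of `cond`."""

    def __init__(self, cond, times, clear_after_trigger=True):
        self.cond = cond
        self.times = times
        self.clear_after_trigger = clear_after_trigger
        self.streak = 0

    def check(self):
        if not self.cond.check():
            self.streak = 0            # any failed check resets the counter
            return False
        self.streak += 1
        if self.streak >= self.times:
            if self.clear_after_trigger:
                self.streak = 0
            return True
        return False


def trigger_points(cond, num_events):
    """Return the 1-based event indices at which `cond` passes."""
    return [i for i in range(1, num_events + 1) if cond.check()]


same_as_every_3 = trigger_points(Consecutive(EveryN(1), 3), 9)    # [3, 6, 9]
every_3 = trigger_points(EveryN(3), 9)                            # [3, 6, 9]
never = trigger_points(Consecutive(EveryN(2), 3), 9)              # []
sticky = trigger_points(Consecutive(EveryN(1), 3, clear_after_trigger=False), 6)  # [3, 4, 5, 6]
```

Because EveryN(2) fails on every other check, the streak never exceeds one, which is why the wrapped condition cannot trigger.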

Warning

Conditions are stateful objects. Using a registered condition as the inner condition here could result in unexpected behavior. For example:

my_cond = cond.validation(better=True)
executor.on(my_cond, some_action)
executor.on(cond.consecutive(my_cond, 2), some_other_action)

In the code above, if no other conditions are registered, some_other_action will never be called. This is because both conditions are checked at the end of each iteration, but the consecutive condition internally checks my_cond, which has already updated the previous best result that it stored. As a result, the check will never succeed.
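
The effect of sharing a stateful condition can be seen in isolation with a toy model. ValidationImproved below is a hypothetical class that only mimics the "remember the best result" behavior described above; it is not the library's validation condition.

```python
class ValidationImproved:
    """Toy condition: passes when the new score beats the best seen so far."""

    def __init__(self):
        self.best = None

    def check(self, score):
        improved = self.best is None or score > self.best
        if improved:
            self.best = score          # state changes as a side effect
        return improved


scores = [0.5, 0.6, 0.7, 0.8]

# One shared instance, checked twice per validation: once directly and
# once again through a wrapper such as `consecutive`.
shared = ValidationImproved()
direct_results, wrapped_results = [], []
for s in scores:
    direct_results.append(shared.check(s))    # updates the stored best
    wrapped_results.append(shared.check(s))   # compares against itself

# Two separate instances do not interfere with each other.
first, second = ValidationImproved(), ValidationImproved()
independent_results = [(first.check(s), second.check(s)) for s in scores]
```

With the shared instance the second check always fails even though every score is an improvement; creating a fresh condition for the wrapper avoids the problem.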

Parameters:
  • cond – The base condition to check.
  • times (int) – The number of times the base condition should pass checks consecutively.
  • clear_after_trigger (bool) – Whether the counter should be cleared after the event is triggered. If clear_after_trigger is set to False, once this event is triggered, it will trigger every time cond is triggered, until cond fails to trigger at some point. Defaults to True.

once

class texar.torch.run.condition.once(cond: texar.torch.run.condition.Condition)[source]

Triggers only when the specified condition triggers for the first time.

Internally, this condition calls the remove_action() method to remove itself from the registered actions.

For example: once(iteration(5)) would only trigger on the 5th iteration of the entire training loop.

Warning

Conditions are stateful objects. Using a registered condition as the inner condition here could result in unexpected behavior. Please refer to consecutive for a concrete example.

Parameters: cond – The base condition to check.
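
A minimal stand-alone sketch of the "first time only" behavior follows. It uses a simple flag instead of the remove_action() mechanism mentioned above, and EveryN and Once are hypothetical toy classes rather than library code.

```python
class EveryN:
    """Toy condition: check() passes on every n-th call."""

    def __init__(self, n):
        self.n = n
        self.count = 0

    def check(self):
        self.count += 1
        return self.count % self.n == 0


class Once:
    """Toy wrapper: passes only the first time `cond` passes."""

    def __init__(self, cond):
        self.cond = cond
        self.fired = False

    def check(self):
        if self.fired:
            return False               # already triggered, stay silent
        if self.cond.check():
            self.fired = True
            return True
        return False


once = Once(EveryN(5))
fired_at = [i for i in range(1, 21) if once.check()]   # [5]
```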

time

class texar.torch.run.condition.time(*, hours: Optional[float] = None, minutes: Optional[float] = None, seconds: Optional[float] = None, only_training: bool = True)[source]

Metrics

Metric

class texar.torch.run.metric.Metric(*, pred_name: Optional[str], label_name: Optional[str] = 'label', higher_is_better: Optional[bool] = None)[source]

Base class of all metrics. You should not directly inherit this class, but inherit from SimpleMetric or StreamingMetric instead.

Subclasses can override the class attributes to indicate their behaviors:

  • higher_is_better: If True, higher (comparison using the greater than operator > returns True) values are considered better metric values. If False, lower values are considered better. Defaults to True.
  • requires_pred: If True, predicted values are required to compute the metric value. Defaults to True.
  • requires_label: If True, labels are required to compute the metric value. Defaults to True.
Keyword Arguments:
 
  • pred_name (str, optional) – Name of the predicted value. This will be used as the key to the dictionary returned by the model.
  • label_name (str, optional) – Name of the label. This will be used as the key to the batch object returned by the dataset. Defaults to "label".
  • higher_is_better (bool, optional) – If specified, the higher_is_better attribute for the instance is overwritten by the specified value.
metric_name

Name of the metric. By default, the class name is used.

pred_name

Name of the predicted value. This will be used as the key to the dictionary returned by the model.

label_name

Name of the label (ground truth / gold value). This will be used as the key to the batch object returned by the dataset.

reset() → None[source]

Reset the internal state of the metric, and erase all previously added data points.

add(predicted: Sequence[Input], labels: Sequence[Input]) → None[source]

Record a data batch in the metric.

Parameters:
  • predicted – The list of predicted values.
  • labels – The list of labels.
value() → Value[source]

Compute the metric value.

Returns: The metric value.
better(cur: Value, prev: Value) → Optional[bool][source]

Compare two metric values and return which is better.

Parameters:
  • cur – The “current” metric value.
  • prev – The “previous” metric value.
Returns:

Return value is either a bool or None.

  • If True, the current metric value is considered better.
  • If False, the previous metric value is considered better.
  • If None, the two values are considered to be the same, or uncomparable.
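
The three-way result can be sketched for plain float values as follows. This is an illustrative stand-alone function and not the library's method; treating equal values as None is an assumption consistent with the description above.

```python
from typing import Optional


def better(cur: float, prev: float, higher_is_better: bool = True) -> Optional[bool]:
    """Return True if `cur` is better, False if `prev` is, None if tied."""
    if cur == prev:
        return None
    # `cur > prev` is "better" exactly when higher values are preferred.
    return (cur > prev) == higher_is_better


accuracy_improved = better(0.9, 0.8)                      # True
loss_improved = better(0.5, 0.7, higher_is_better=False)  # True
loss_worsened = better(0.9, 0.8, higher_is_better=False)  # False
tie = better(0.8, 0.8)                                    # None
```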

SimpleMetric

class texar.torch.run.metric.SimpleMetric(*, pred_name: Optional[str], label_name: Optional[str] = 'label', higher_is_better: Optional[bool] = None)[source]

Base class of simple metrics. Simple metrics are metrics that do not support incremental computation. The value of the metric is computed only after all data points have been added.

The default implementation of add() simply stores the predicted values and labels into lists.

reset() → None[source]

Reset the internal state of the metric, and erase all previously added data points.

add(predicted: Sequence[Input], labels: Sequence[Input])[source]

Record a data batch in the metric.

Parameters:
  • predicted – The list of predicted values.
  • labels – The list of labels.
value()[source]

Compute the metric value.

Returns: The metric value.

StreamingMetric

class texar.torch.run.metric.StreamingMetric(*, pred_name: Optional[str], label_name: Optional[str] = 'label', higher_is_better: Optional[bool] = None)[source]

Base class of streaming metrics. Streaming metrics are metrics that support incremental computation. The value of the metric may be queried before all data points have been added, and the computation should not be expensive.

The default implementation of add() only keeps track of the number of data points added. You should override this method.
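
To contrast the two base classes, the sketch below pairs a "simple" metric that must keep every data point (a median) with a "streaming" one that keeps only running totals (a mean). Both classes are hypothetical toys that do not inherit from the library's classes, and their add() takes only predicted values for brevity.

```python
import statistics


class ToySimpleMedian:
    """Simple-style metric: stores everything, computes at the end."""

    def __init__(self):
        self.reset()

    def reset(self):
        self._values = []

    def add(self, predicted):
        self._values.extend(predicted)

    def value(self):
        return statistics.median(self._values)


class ToyStreamingMean:
    """Streaming-style metric: constant-size state, cheap to query."""

    def __init__(self):
        self.reset()

    def reset(self):
        self._total = 0.0
        self._count = 0

    def add(self, predicted):
        self._total += sum(predicted)
        self._count += len(predicted)

    def value(self):
        return self._total / self._count if self._count else 0.0


median, mean = ToySimpleMedian(), ToyStreamingMean()
median.add([1.0, 5.0])
mean.add([1.0, 5.0])
mean_after_first_batch = mean.value()   # 3.0, queried mid-stream
median.add([6.0])
mean.add([6.0])
final_median = median.value()           # 5.0
final_mean = mean.value()               # 4.0
```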

reset() → None[source]

Reset the internal state of the metric, and erase all previously added data points.

add(predicted: Sequence[Input], labels: Sequence[Input]) → None[source]

Record a data batch in the metric.

Parameters:
  • predicted – The list of predicted values.
  • labels – The list of labels.

Accuracy

class texar.torch.run.metric.Accuracy(*, pred_name: Optional[str], label_name: Optional[str] = 'label', higher_is_better: Optional[bool] = None)[source]

The accuracy metric for evaluating classification tasks. Accuracy is defined as the ratio of correct (exactly matching) predictions out of all predictions.

Accuracy is a StreamingMetric that requires both predicted values and labels. Accuracy values are float numbers between 0 and 1, with higher values being better.

Keyword Arguments:
 
  • pred_name (str) – Name of the predicted value. This will be used as the key to the dictionary returned by the model.
  • label_name (str) – Name of the label. This will be used as the key to the batch object returned by the dataset. Defaults to "label".

ConfusionMatrix

class texar.torch.run.metric.ConfusionMatrix(*, pred_name: Optional[str], label_name: Optional[str] = 'label', higher_is_better: Optional[bool] = None)[source]

The confusion matrix is an evaluation metric for classification tasks.

Confusion matrix is a StreamingMetric that requires both predicted values and labels. Confusion matrix values are NumPy arrays, with no clear definition of “better”. Comparing two confusion matrices is not meaningful.

The value indexed at (i, j) of the confusion matrix is the number of data points whose predicted label is i and whose ground truth label is j. Labels are internally mapped to indices.

Keyword Arguments:
 
  • pred_name (str) – Name of the predicted value. This will be used as the key to the dictionary returned by the model.
  • label_name (str) – Name of the label. This will be used as the key to the batch object returned by the dataset. Defaults to "label".
class_id

Mapping of predicted values and labels to indices within the matrix.
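
The (i, j) convention and the label-to-index mapping can be sketched without NumPy as below. This is an illustration only: confusion_matrix is a hypothetical helper, nested lists replace the NumPy array, and assigning indices in order of first appearance is an assumption of the sketch.

```python
from typing import Dict, Hashable, List, Sequence, Tuple


def confusion_matrix(
    predicted: Sequence[Hashable], labels: Sequence[Hashable]
) -> Tuple[List[List[int]], Dict[Hashable, int]]:
    # Map each distinct value to a row/column index (assumed ordering:
    # first appearance, predictions before labels).
    class_id: Dict[Hashable, int] = {}
    for value in list(predicted) + list(labels):
        class_id.setdefault(value, len(class_id))

    size = len(class_id)
    matrix = [[0] * size for _ in range(size)]
    for pred, label in zip(predicted, labels):
        # Row index: predicted label; column index: ground truth label.
        matrix[class_id[pred]][class_id[label]] += 1
    return matrix, class_id


matrix, class_id = confusion_matrix(
    predicted=["cat", "dog", "cat", "cat"],
    labels=["cat", "cat", "dog", "cat"],
)
# class_id -> {"cat": 0, "dog": 1}
# matrix   -> [[2, 1], [1, 0]]
```

For instance, matrix[1][0] counts the single example predicted as "dog" whose ground truth is "cat".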

Precision

class texar.torch.run.metric.Precision(mode: str = 'binary', pos_label: Optional[Input] = None, *, pred_name: str, label_name: str = 'label')[source]

The precision metric for evaluating classification tasks. Precision is defined as the ratio tp / (tp + fp), where tp is the number of true positives and fp is the number of false positives.

Precision is a StreamingMetric that requires both predicted values and labels. Precision values are float numbers between 0 and 1, with higher values being better.

Parameters:
  • mode (str) –

    The mode for computing averages across multiple labels. Defaults to "binary". Available options include:

    • "binary": Only report results for the class specified by pos_label. This is only meaningful for binary classification tasks.
    • "micro": Return the precision value computed using globally counted true positives and false positives.
    • "macro": Return the unweighted average of precision values for each label.
    • "weighted": Return the average of precision values for each label, weighted by the number of true instances for each label.
  • pos_label (str, optional) – The label for the positive class. Only used if mode is set to "binary".
Keyword Arguments:
 
  • pred_name (str) – Name of the predicted value. This will be used as the key to the dictionary returned by the model.
  • label_name (str) – Name of the label. This will be used as the key to the batch object returned by the dataset. Defaults to "label".
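
The four averaging modes can be compared on a small example. The function below is a stand-alone sketch of the definitions listed above, not the library's implementation; the name precision, the batch-at-once interface, and returning 0.0 when a denominator is zero are all assumptions of the sketch.

```python
from collections import Counter


def precision(predicted, labels, mode="binary", pos_label=None):
    tp, fp, support = Counter(), Counter(), Counter()
    for pred, label in zip(predicted, labels):
        support[label] += 1            # number of true instances per label
        if pred == label:
            tp[pred] += 1
        else:
            fp[pred] += 1              # predicted `pred`, but it was wrong
    classes = sorted(set(predicted) | set(labels))

    def per_class(c):
        denom = tp[c] + fp[c]
        return tp[c] / denom if denom else 0.0   # assumed zero-division rule

    if mode == "binary":
        return per_class(pos_label)
    if mode == "micro":
        total = sum(tp.values()) + sum(fp.values())
        return sum(tp.values()) / total if total else 0.0
    if mode == "macro":
        return sum(per_class(c) for c in classes) / len(classes)
    if mode == "weighted":
        return sum(per_class(c) * support[c] for c in classes) / len(labels)
    raise ValueError(f"Unknown mode: {mode}")


predicted = ["a", "a", "b", "b", "a", "c"]
labels = ["a", "b", "b", "b", "a", "a"]
scores = {
    mode: precision(predicted, labels, mode, pos_label="a")
    for mode in ("binary", "micro", "macro", "weighted")
}
# Per-class precision here is a: 2/3, b: 1, c: 0, so
# binary = 2/3, micro = 4/6, macro = 5/9, weighted = 5/6.
```

Recall follows the same pattern with false negatives in place of false positives.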

Recall

class texar.torch.run.metric.Recall(mode: str = 'binary', pos_label: Optional[Input] = None, *, pred_name: str, label_name: str = 'label')[source]

The recall metric for evaluating classification tasks. Recall is defined as the ratio tp / (tp + fn), where tp is the number of true positives and fn is the number of false negatives.

Recall is a StreamingMetric that requires both predicted values and labels. Recall values are float numbers between 0 and 1, with higher values being better.

Parameters:
  • mode (str) –

    The mode for computing averages across multiple labels. Defaults to "binary". Available options include:

    • "binary": Only report results for the class specified by pos_label. This is only meaningful for binary classification tasks.
    • "micro": Return the recall value computed using globally counted true positives and false negatives.
    • "macro": Return the unweighted average of recall values for each label.
    • "weighted": Return the average of recall values for each label, weighted by the number of true instances for each label.
  • pos_label (str, optional) – The label for the positive class. Only used if mode is set to "binary".
Keyword Arguments:
 
  • pred_name (str) – Name of the predicted value. This will be used as the key to the dictionary returned by the model.
  • label_name (str) – Name of the label. This will be used as the key to the batch object returned by the dataset. Defaults to "label".

F1

class texar.torch.run.metric.F1(mode: str = 'binary', pos_label: Optional[Input] = None, *, pred_name: str, label_name: str = 'label')[source]

The F1 metric for evaluating classification tasks. F1 is defined as the harmonic mean of precision and recall.

F1 is a StreamingMetric that requires both predicted values and labels. F1 values are float numbers between 0 and 1, with higher values being better.

Parameters:
  • mode (str) –

    The mode for computing averages across multiple labels. Defaults to "binary". Available options include:

    • "binary": Only report results for the class specified by pos_label. This is only meaningful for binary classification tasks.
    • "micro": Return the F1 value computed using globally counted true positives, false positives, and false negatives.
    • "macro": Return the unweighted average of F1 values for each label.
    • "weighted": Return the average of F1 values for each label, weighted by the number of true instances for each label.
  • pos_label (str, optional) – The label for the positive class. Only used if mode is set to "binary".
Keyword Arguments:
 
  • pred_name (str) – Name of the predicted value. This will be used as the key to the dictionary returned by the model.
  • label_name (str) – Name of the label. This will be used as the key to the batch object returned by the dataset. Defaults to "label".
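
For a single positive class, the harmonic mean of precision P and recall R can be written out in terms of the counts tp, fp, and fn used in the Precision and Recall definitions. This is the standard textbook identity, shown here for reference:

```latex
F_1 = \frac{2 P R}{P + R}
    = \frac{2\,\mathrm{tp}}{2\,\mathrm{tp} + \mathrm{fp} + \mathrm{fn}},
\qquad
P = \frac{\mathrm{tp}}{\mathrm{tp} + \mathrm{fp}},
\quad
R = \frac{\mathrm{tp}}{\mathrm{tp} + \mathrm{fn}}
```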

PearsonR

class texar.torch.run.metric.PearsonR(*, pred_name: Optional[str], label_name: Optional[str] = 'label', higher_is_better: Optional[bool] = None)[source]

The Pearson correlation coefficient (Pearson’s r) metric for evaluating regression tasks. Pearson’s r is a measure of linear correlation between two sets of variables. Pearson’s r ranges between -1 and 1, with 1 indicating total positive linear correlation, -1 indicating total negative linear correlation, and 0 indicating no linear correlation.

Pearson’s r is a StreamingMetric that requires both predicted values and labels. Pearson’s r values are float numbers between -1 and 1, with higher values being better.

Keyword Arguments:
 
  • pred_name (str) – Name of the predicted value. This will be used as the key to the dictionary returned by the model.
  • label_name (str) – Name of the label. This will be used as the key to the batch object returned by the dataset. Defaults to "label".
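
One way to compute Pearson’s r incrementally is to keep running sums, as sketched below. ToyPearsonR is a hypothetical class and not the library's implementation; returning 0.0 when either variable has zero variance is an assumption of the sketch.

```python
import math


class ToyPearsonR:
    """Streaming Pearson's r from running sums of x, y, x*x, y*y, x*y."""

    def __init__(self):
        self.n = 0
        self.sx = self.sy = self.sxx = self.syy = self.sxy = 0.0

    def add(self, predicted, labels):
        for x, y in zip(predicted, labels):
            self.n += 1
            self.sx += x
            self.sy += y
            self.sxx += x * x
            self.syy += y * y
            self.sxy += x * y

    def value(self):
        n = self.n
        cov = n * self.sxy - self.sx * self.sy
        var_x = n * self.sxx - self.sx ** 2
        var_y = n * self.syy - self.sy ** 2
        prod = var_x * var_y
        if prod <= 0:
            return 0.0                 # assumed value for degenerate input
        return cov / math.sqrt(prod)


positive = ToyPearsonR()
positive.add([1.0, 2.0], [2.0, 4.0])   # first batch
positive.add([3.0], [6.0])             # second batch
r_positive = positive.value()          # close to 1.0

negative = ToyPearsonR()
negative.add([1.0, 2.0, 3.0], [6.0, 4.0, 2.0])
r_negative = negative.value()          # close to -1.0
```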

RMSE

class texar.torch.run.metric.RMSE(*, pred_name: Optional[str], label_name: Optional[str] = 'label', higher_is_better: Optional[bool] = None)[source]

The root mean squared error (RMSE) metric for evaluating regression tasks. RMSE is defined as the square root of the mean of the squared residuals (differences between predicted values and ground truth values); it coincides with the standard deviation of the residuals when their mean is zero.

RMSE is a StreamingMetric that requires both predicted values and labels. RMSE values are float numbers with a lower bound of 0. Lower values are better.

Keyword Arguments:
 
  • pred_name (str) – Name of the predicted value. This will be used as the key to the dictionary returned by the model.
  • label_name (str) – Name of the label. This will be used as the key to the batch object returned by the dataset. Defaults to "label".
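
Written out, with predicted values \hat{y}_i, ground truth values y_i, and n data points (symbols introduced here for illustration), the standard definition of RMSE is:

```latex
\mathrm{RMSE} = \sqrt{\frac{1}{n} \sum_{i=1}^{n} \left( \hat{y}_i - y_i \right)^2}
```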

Average

class texar.torch.run.metric.Average(*, pred_name: str = 'loss', higher_is_better: bool = False)[source]

The average of a specific predicted value.

Average is a StreamingMetric that requires only predicted values. Average values are unbounded float numbers. By default, lower values are better, but the behavior can be configured.

Keyword Arguments:
 
  • pred_name (str) – Name of the predicted value. This will be used as the key to the dictionary returned by the model. Defaults to "loss".
  • higher_is_better (bool, optional) – If specified, the higher_is_better attribute for the instance is overwritten by the specified value. Defaults to False.

AveragePerplexity

class texar.torch.run.metric.AveragePerplexity(*, pred_name: str = 'loss', higher_is_better: bool = False)[source]

RunningAverage

class texar.torch.run.metric.RunningAverage(queue_size: int, *, pred_name: str = 'loss', higher_is_better: bool = False)[source]

The running average of a specific predicted value, i.e., the average computed over the most recent queue_size values.

Running average is a StreamingMetric that requires only predicted values. Running average values are unbounded float numbers. By default, lower values are better, but the behavior can be configured.

Keyword Arguments:
 
  • queue_size (int) – Size of the queue to keep history values. The running average is computed over the most recent queue_size values.
  • pred_name (str) – Name of the predicted value. This will be used as the key to the dictionary returned by the model. Defaults to "loss".
  • higher_is_better (bool, optional) – If specified, the higher_is_better attribute for the instance is overwritten by the specified value.
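
The "most recent queue_size values" behavior maps naturally onto a bounded queue. The sketch below uses collections.deque for this; ToyRunningAverage is a hypothetical class, and whether the library uses the same data structure is not something this sketch claims.

```python
from collections import deque


class ToyRunningAverage:
    """Average over the most recent `queue_size` values."""

    def __init__(self, queue_size):
        # A deque with maxlen silently drops the oldest entries.
        self._queue = deque(maxlen=queue_size)

    def add(self, values):
        self._queue.extend(values)

    def value(self):
        return sum(self._queue) / len(self._queue) if self._queue else 0.0


running = ToyRunningAverage(queue_size=3)
running.add([4.0, 2.0])
early = running.value()     # 3.0, only two values seen so far
running.add([0.0, 1.0])     # 4.0 falls out of the window
late = running.value()      # 1.0, average of 2.0, 0.0, 1.0
```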

LR

class texar.torch.run.metric.LR(optimizer: torch.optim.optimizer.Optimizer, param_group: int = 0)[source]

The learning rate (LR) of the given optimizer. This is not exactly a metric, but rather a convenience object to print learning rates in the log.

LR is a StreamingMetric that requires neither predicted values nor labels. LR values are unbounded float numbers, with no clear definition of “better”. Comparing two learning rates is not meaningful.

Keyword Arguments:
 
  • optimizer – The optimizer instance.
  • param_group (int, optional) – Index of the parameter group to obtain the learning rate from. Defaults to 0. You don’t need to specify this if the optimizer contains only one parameter group (e.g., constructed using optim_class(model.parameters())).
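
PyTorch optimizers expose their parameter groups as optimizer.param_groups, a list of dictionaries that each carry an "lr" entry. The sketch below reads a learning rate by group index; FakeOptimizer is a hypothetical stand-in so the example runs without PyTorch, and current_lr is an illustrative helper rather than the library's code.

```python
class FakeOptimizer:
    """Stand-in that mimics the `param_groups` layout of PyTorch optimizers."""

    def __init__(self, learning_rates):
        self.param_groups = [{"lr": lr} for lr in learning_rates]


def current_lr(optimizer, param_group=0):
    # Look up the learning rate of the selected parameter group.
    return optimizer.param_groups[param_group]["lr"]


optimizer = FakeOptimizer([1e-3, 1e-4])
first_lr = current_lr(optimizer)                   # group 0 by default
second_lr = current_lr(optimizer, param_group=1)   # explicit group index
```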

Actions

reset_params

class texar.torch.run.action.reset_params(training_state: bool = True)[source]

scale_lr

class texar.torch.run.action.scale_lr(scale: float)[source]

early_stop

class texar.torch.run.action.early_stop(patience: int)[source]