Executor¶
- class texar.torch.run.Executor(model, *, train_data=None, valid_data=None, test_data=None, batching_strategy=None, device=None, tbx_logging_dir=None, tbx_log_every=None, checkpoint_dir=None, max_to_keep=None, save_every=None, save_training_state=True, train_metrics=None, optimizer=None, lr_scheduler=None, stop_training_on=None, num_iters_per_update=1, grad_clip=None, valid_metrics=None, validate_every=None, plateau_condition=None, action_on_plateau=None, validate_mode='eval', test_metrics=None, test_every=None, test_mode='predict', log_every=None, log_format=None, log_destination=None, print_model_arch=True, valid_log_format=None, test_log_format=None, valid_progress_log_format=None, test_progress_log_format=None, show_live_progress=False)[source]¶
Executor is a substitute for the general training/evaluation loop. It is designed with the following goals in mind:
- Minimize the amount of boilerplate code that is essentially the same across all experiments.
- Provide best practices and hide hideous details from the user.
- Guarantee reproducibility (runs with the same configuration and seed always produce the same result) and portability (the same code runs whether or not a GPU is used).
- Meanwhile, allow flexible configurations and support user-overridden behaviors.
Example
Here is a realistic training loop example using Executor, showcasing the features built in. Executor takes care of common training procedures including logging, checkpoint management, evaluation metrics, validation, and patience-based early-stopping.

    from texar.torch.run import *

    executor = Executor(
        model=model,
        train_data=datasets["train"],
        valid_data=datasets["dev"],
        test_data=datasets["test"],
        checkpoint_dir=args.save_dir,
        save_every=cond.validation(better=True),
        train_metrics=("loss", metric.RunningAverage(args.display_steps)),
        optimizer={"type": torch.optim.Adam},
        grad_clip=args.grad_clip,
        log_every=cond.iteration(args.display_steps),
        validate_every=cond.epoch(1),
        valid_metrics=[
            metric.PearsonR(pred_name="preds"),
            ("loss", metric.Average())],
        plateau_condition=[
            cond.consecutive(cond.validation(better=False), 2)],
        action_on_plateau=[
            action.early_stop(patience=2),
            action.reset_params(),
            action.scale_lr(0.8)],
        stop_training_on=cond.iteration(args.max_train_steps),
        test_mode='eval',
    )
    executor.train()
Concepts
To make full use of Executor, you’ll need to understand the following concepts:
- Event: Events are a set of pre-defined time spans within the training and evaluation loop. Common events include a training iteration (Event.Iteration), an epoch (Event.Epoch), or an entire validation (Event.Validation). Please refer to the enum class Event for the full list of events. The beginning and end of events are called event points.
- Condition: Conditions are checks performed at the beginning or end of events. If a check passes, we say that the condition “triggers”. To give a few examples:
  - cond.iteration(num_iters) is checked only at the end of training iterations, and the check passes when the number of iterations passed equals num_iters. Thus, the condition triggers at the end of every num_iters iterations.
  - cond.validation(better=True) is checked only at the end of validations, and the check passes if the current validation result is better than the previous best. Thus, the condition triggers whenever a finished validation yields an improved result.
  - cond.time is checked at the end of iterations, and at the beginning and end of training, validation, and testing. It checks whether the elapsed time has reached the specified duration, and triggers when it does.
  Custom conditions must subclass the Condition class.
- Action: Actions are special callback functions that are called either at the beginning or end of an event (actions registered by on_event()), or when conditions trigger (actions registered by on(), and by specifying save_every and similar arguments). This is similar to the hook functions that exist in common frameworks. Actions take a single argument – the Executor instance itself, and can perform any operations within.
  Custom actions can be simple functions, or subclass the Action class.
- Metric: Metrics are used to evaluate the output of models. They are categorized into two classes: SimpleMetric and StreamingMetric. The only difference is that streaming metrics support incremental computation of metric values, so they can be used to aggregate results over the training set, or provide intermediate results on-the-fly.
  Custom metrics must subclass one of the above classes.
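The interplay between events, conditions, and actions can be illustrated with a small standalone sketch. This is not the library's implementation: ToyExecutor, EveryN, and the two-member Event enum below are hypothetical stand-ins that only mimic the described behavior (conditions are checked at event points, and their actions run when they trigger).

```python
from enum import Enum, auto


class Event(Enum):
    # Hypothetical stand-in for the real event enum.
    Iteration = auto()
    Epoch = auto()


class EveryN:
    """Toy condition: triggers at the end of every n-th occurrence of `event`."""

    def __init__(self, event, n):
        self.event, self.n, self.count = event, n, 0

    def check(self, event, end):
        # Only react to the end point of the event this condition watches.
        if event is not self.event or not end:
            return False
        self.count += 1
        return self.count % self.n == 0


class ToyExecutor:
    def __init__(self):
        self.hooks = []  # (condition, action) pairs, in registration order
        self.log = []

    def on(self, cond, action):
        self.hooks.append((cond, action))
        return self

    def fire_event(self, event, end):
        # Check every condition at this event point; run actions that trigger.
        for cond, action in self.hooks:
            if cond.check(event, end):
                action(self)


executor = ToyExecutor()
executor.on(EveryN(Event.Iteration, 3), lambda ex: ex.log.append("every 3 iterations"))
executor.on(EveryN(Event.Epoch, 1), lambda ex: ex.log.append("end of epoch"))

for _ in range(6):
    executor.fire_event(Event.Iteration, end=False)  # beginning of iteration
    executor.fire_event(Event.Iteration, end=True)   # end of iteration
executor.fire_event(Event.Epoch, end=True)

print(executor.log)
```

Note that each action receives the executor itself as its only argument, which matches the calling convention described above.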
Customization
You can easily extend the Executor class by subclassing it and overriding methods. Methods of interest include:
- _train_step(): Perform a single step of training (i.e. process a single batch, call backward(), and potentially call optimizer updates). Takes the data batch as argument.
- _validate_step(): Perform a single step of validation. Takes the data batch as argument.
- _test_step(): Perform a single step of testing. Takes the data batch as argument.
- _train_loop(): Runs the entire training loop. Takes the data iterator as argument.
- _validate_loop(): Runs the entire validation loop. Takes the data iterator as argument.
- _test_loop(): Runs the entire testing loop. Takes the data iterator as argument.
You can also define custom events by writing a new enum class and modifying the _EVENT_TYPES attribute. Event points can be signaled by calling _fire_event(). For example:

    from enum import Enum, auto

    import torch
    import torch.nn.functional as F

    class GANEvent(Enum):
        DiscriminatorUpdate = auto()
        GeneratorUpdate = auto()

    class GANExecutor(Executor):
        _EVENT_TYPES = (Event, GANEvent)

        def __init__(self, *args, optimizer_g, optimizer_d, **kwargs):
            kwargs["optimizer"] = {
                "g": optimizer_g,
                "d": optimizer_d,
            }
            super().__init__(*args, **kwargs)

        def _train_step(self, batch):
            # Generator update, wrapped in its begin/end event points.
            self._fire_event(GANEvent.GeneratorUpdate, False)
            z = torch.randn(len(batch), args.latent_dim)
            fake_image = self.model.generator(z)
            logits = self.model.discriminator(fake_image)
            g_loss = F.binary_cross_entropy(logits, torch.ones(len(batch)))
            g_loss.backward()
            self.optimizer["g"].step()
            self.optimizer["g"].zero_grad()
            self._fire_event(GANEvent.GeneratorUpdate, True)

            # Discriminator update, wrapped in its begin/end event points.
            self._fire_event(GANEvent.DiscriminatorUpdate, False)
            real_logits = self.model.discriminator(batch.image)
            fake_logits = self.model.discriminator(fake_image.detach())
            real_loss = F.binary_cross_entropy(real_logits, torch.ones(len(batch)))
            fake_loss = F.binary_cross_entropy(fake_logits, torch.zeros(len(batch)))
            d_loss = (real_loss + fake_loss) / 2
            d_loss.backward()
            self.optimizer["d"].step()
            self.optimizer["d"].zero_grad()
            self._fire_event(GANEvent.DiscriminatorUpdate, True)

            return {"g_loss": g_loss, "d_loss": d_loss}
Arguments
The constructor of Executor takes many arguments, almost all of which are keyword-only and optional. Some arguments can take values of multiple types, either a single instance of a specific type, or a list or dictionary of values of that type.
Arguments grouped by functions:
General arguments:
- model: torch.nn.Module
The model to train or evaluate. The model must be a subclass of torch.nn.Module, with its forward() method taking a single argument batch and returning a dictionary of torch.Tensors:
- The batch argument is of type Batch, which is the batch object produced by your provided dataset.
- The returned dictionary must contain an entry named "loss", which will be used as the loss to backpropagate during training. You can also include other values and use them in metrics.
If the model performs different routines during training and evaluation (for instance, a sequence-to-sequence model may train using teacher-forcing but evaluate using beam search-based inference), you can define another method predict() following the same signature as forward(). To use predict() instead of forward() in validation or testing, set validate_mode or test_mode to "predict" instead of "eval".
If the model you use does not follow this convention, you will need to wrap the model in a new class. The following example demonstrates how to wrap XLNetRegressor:

    class RegressorWrapper(tx.modules.XLNetRegressor):
        def forward(self, batch):
            preds = super().forward(token_ids=batch.input_ids,
                                    segment_ids=batch.segment_ids,
                                    input_mask=batch.input_mask)
            loss = (preds - batch.label_ids) ** 2
            loss = loss.sum() / len(batch)
            return {"loss": loss, "preds": preds}
- train_data: DatasetBase
The dataset used during training. Must be specified for training.
- valid_data: DatasetBase
The dataset used during validation. If not specified, you cannot perform validation during training (e.g., by setting validate_every).
- test_data: DatasetBase, or a list or dictionary
The dataset(s) used during testing. If not specified, you cannot perform testing during training (e.g., by setting test_every).
- batching_strategy: BatchingStrategy
The batching strategy to use for batching. This will be passed as the batching_strategy argument for DataIterator during training, and during evaluation if the corresponding mode is set to "eval".
- validate_mode: str
The evaluation mode for validation. Available choices are "eval" and "predict". Defaults to "eval". When the mode is set to "eval", the forward() method of the model will be called; when set to "predict", the predict() method of the model will be called.
- test_mode: str
The evaluation mode for testing. Available choices are "eval" and "predict". Defaults to "predict".
- device: torch.device
The device on which the model and data should be placed. Defaults to None, in which case GPUs will be used if available.
Arguments for tensorboard logging:
- tbx_logging_dir: str
Path to the directory for storing tensorboard logs.
- tbx_log_every: Condition
Conditions that, when triggered, save the tensorboard logs for train metrics. If None, log_every will be used.
Arguments for checkpoint management:
- checkpoint_dir: str
Path to the directory for storing checkpoints. If not specified, you cannot save/load the model during training (e.g., by setting save_every or using reset_params).
- max_to_keep: int
The maximum number of checkpoints to keep in the checkpoint directory. When the number of checkpoints exceeds this limit, the oldest one will be removed. If None, no such limit is imposed. Defaults to None.
Note
Be careful when saving periodic snapshots along with the best performing checkpoint. Periodic snapshots might overwrite best models if the checkpoint limit is exceeded.
A better approach is to only save periodic snapshots with the built-in mechanism, and register a custom action for saving a single best performing checkpoint:

    # Don't
    executor = Executor(
        save_every=[cond.epoch(1), cond.validation(better=True)],
        max_to_keep=3, ...)

    # Do
    executor = Executor(
        save_every=cond.epoch(1),
        max_to_keep=3, ...)

    @executor.on(cond.validation(better=True))
    def save_best_model(executor):
        executor.save(path=another_directory, max_to_keep=1)
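To see why mixing the two kinds of checkpoints in one directory is risky, here is a minimal standalone sketch of an oldest-first rotation. The rotate() helper and the checkpoint names are hypothetical; the only behavior taken from the description above is that the oldest checkpoint is removed once max_to_keep is exceeded.

```python
from collections import deque


def rotate(checkpoints, name, max_to_keep):
    """Append `name`; return the evicted checkpoint name, if any."""
    checkpoints.append(name)
    if max_to_keep is not None and len(checkpoints) > max_to_keep:
        return checkpoints.popleft()  # the oldest checkpoint is removed
    return None


shared = deque()
removed = []
# The best model is saved right after epoch 1; periodic snapshots keep coming.
for name in ["epoch1", "best@epoch1", "epoch2", "epoch3", "epoch4"]:
    evicted = rotate(shared, name, max_to_keep=3)
    if evicted is not None:
        removed.append(evicted)

print(removed)       # the best checkpoint is among the evicted ones
print(list(shared))  # only periodic snapshots remain
```

Keeping the best checkpoint in a separate directory, as in the "Do" variant above, means it is never counted against the limit for periodic snapshots.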
- save_every: Condition, or a list
Conditions that, when triggered, save the model.
In the following example, the model will be saved every 1000 iterations, or whenever validation results improve.
save_every=[cond.validation(better=True), cond.iteration(1000)]
- save_training_state: bool
If False, only save model parameters in checkpoint. If True, also save optimizer and scheduler states, along with random number generator states from Python, NumPy, and PyTorch. Defaults to True.
Arguments for training:
- train_metrics: Metric, or a list or dictionary
The metrics computed over the training set. In case of multiple metrics, two sets of metric values will be compared in the provided order.
For example, if two metrics f1 (F1) and loss (Average) are defined (in this order), when comparing two sets of values, the one with a higher f1 is considered better. If the two sets have the same f1 value, the one with a lower loss is considered better.
Acceptable values include:
- A single Metric, or a list of Metrics. These metrics will be automatically named when they’re logged.
- A tuple of (str, Metric), or a list of such tuples. These metrics are explicitly named according to the provided strings. Note that names must be unique, and should be valid identifier names.
- An OrderedDict mapping names to metrics. Note that a plain (unordered) dictionary is not accepted.
Note
Metrics that are logged will be evaluated once every time logging is performed. For efficiency considerations, such metrics should be StreamingMetrics. Please take extra care when implementing your own metrics.
- optimizer: torch.optim.Optimizer, or a dictionary of hyperparameters
The optimizer used during training. This can be a torch.optim.Optimizer instance, or a dictionary of hyperparameters that will be passed into texar.torch.utils.get_instance(). Must be specified for training.
- lr_scheduler: LRScheduler, or a dictionary of hyperparameters, optional
The learning rate scheduler used during training. This can be an LRScheduler instance, or a dictionary of hyperparameters that will be passed into texar.torch.utils.get_instance().
- stop_training_on: Condition, or a list
Conditions that, when triggered, will stop training.
In the following example, training will be terminated after 5 epochs or 20000 iterations, whichever comes first:
stop_training_on=[cond.epoch(5), cond.iteration(20000)]
- num_iters_per_update: int
Number of iterations to run before performing a parameter update. When this value is greater than 1, the loss is scaled by its reciprocal. Defaults to 1, in which case the parameters are updated after each .backward() call.
This can be used to accumulate gradients across multiple batches, in order to simulate the effect of using a large batch size on a machine with limited memory.
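The reason for scaling the loss by the reciprocal can be checked numerically. The sketch below uses a toy one-parameter model in plain Python rather than PyTorch; mse_grad() and the data are made up for illustration. With equally sized micro-batches, the accumulated scaled gradients equal the gradient over the combined batch.

```python
def mse_grad(w, batch):
    """Gradient of the mean squared error of y = w * x over `batch`."""
    return sum(2 * (w * x - y) * x for x, y in batch) / len(batch)


data = [(1.0, 2.0), (2.0, 3.0), (3.0, 7.0), (4.0, 9.0)]
w = 0.5
num_iters_per_update = 2
micro_batches = [data[0:2], data[2:4]]

# Each micro-batch contribution is scaled by 1 / num_iters_per_update,
# mirroring "the loss is scaled by its reciprocal".
accumulated = 0.0
for batch in micro_batches:
    accumulated += mse_grad(w, batch) / num_iters_per_update

full_batch = mse_grad(w, data)
print(accumulated, full_batch)
```

The equality only holds exactly when every micro-batch has the same size; with a smaller final batch the accumulated value is a close approximation.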
- grad_clip: float, optional
Maximum norm of the gradients. Please refer to nn.utils.clip_grad_norm_ for details. Defaults to None, i.e. no clipping.
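As a rough illustration of what clipping by norm does, the following plain-Python sketch rescales a list of gradient values so that their global L2 norm does not exceed max_norm. It is modeled on the commonly documented behavior of nn.utils.clip_grad_norm_, but clip_by_global_norm() and the eps value are assumptions made for this example, not the library code.

```python
import math


def clip_by_global_norm(grads, max_norm, eps=1e-6):
    """Scale `grads` in place so their global L2 norm is at most `max_norm`.

    Returns the norm measured before clipping.
    """
    total_norm = math.sqrt(sum(g * g for g in grads))
    scale = min(1.0, max_norm / (total_norm + eps))  # never scale up
    for i, g in enumerate(grads):
        grads[i] = g * scale
    return total_norm


large = [3.0, 4.0]  # global norm is 5.0
norm_before = clip_by_global_norm(large, max_norm=1.0)
norm_after = math.sqrt(sum(g * g for g in large))

small = [0.3, 0.4]  # already within the limit, so left untouched
clip_by_global_norm(small, max_norm=1.0)
```

All gradients are scaled by the same factor, so the update direction is preserved and only its magnitude changes.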
Arguments for validation:
- valid_metrics: Metric, or a list or dictionary
The metrics computed over the validation set. Please see train_metrics for details.
- validate_every: Condition, or a list
Conditions that, when triggered, perform validation.
In the following example, the model will be validated once per epoch.

    validate_every=cond.epoch(1)

- plateau_condition: Condition, or a list
Conditions that, when triggered, indicate that training has reached a plateau, i.e., the model has stopped improving.
In the following example, we consider that training has reached a plateau if validation metrics have not improved for 3 consecutive validations.

    plateau_condition=cond.consecutive(cond.validation(better=False), 3)

- action_on_plateau: Action, or a list
Actions that will be called when training has reached a plateau.
In the following example, we perform patience-based early-stopping when reaching plateaus. A patience of 2 means training will be terminated after a plateau is reached twice. We also scale the learning rate by 0.8, and reset the model & optimizer parameters to the previous best checkpoint.

    action_on_plateau=[
        action.reset_params(),
        action.scale_lr(0.8),
        action.early_stop(patience=2)]
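The combination of a plateau condition and patience-based early stopping can be traced with a small standalone sketch. The function below is a hypothetical simplification: it treats lower validation loss as better, counts consecutive non-improving validations, and stops once the plateau has been reached patience times. It does not model resetting parameters or scaling the learning rate.

```python
def stopping_point(valid_losses, times=2, patience=2):
    """Return the index of the validation at which training stops, or None."""
    best = float("inf")
    bad_streak = 0  # consecutive validations without improvement
    plateaus = 0    # how many times the plateau condition has triggered
    for i, loss in enumerate(valid_losses):
        if loss < best:
            best, bad_streak = loss, 0
        else:
            bad_streak += 1
        if bad_streak == times:       # plateau reached
            plateaus += 1
            bad_streak = 0            # counter is cleared after triggering
            if plateaus == patience:  # patience exhausted: stop training
                return i
    return None


losses = [1.0, 0.9, 0.95, 0.93, 0.8, 0.85, 0.82, 0.81]
print(stopping_point(losses))
print(stopping_point([1.0, 0.9, 0.8]))
```

In the first run the plateau is reached at the fourth validation and again at the seventh, at which point training stops; the second run improves every time and never stops early.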
Arguments for testing:
- test_metrics: Metric, or a list or dictionary
The metrics computed over the test set. Please see train_metrics for details. test() can only be called if test_metrics is not None.
Note
valid_metrics will be automatically shared with test_metrics if:
- test_metrics is None;
- validate_mode is the same as test_mode.
In this case, calling validation while testing (or vice versa) could cause incorrect results since the same Metric instances are used.
- test_every: Condition, or a list
Conditions that, when triggered, perform testing.
In the following example, the model will be tested whenever validation results improve.
test_every=cond.validation(better=True)
Arguments for logging:
- log_every: Condition, or a list
Conditions that, when triggered, perform logging.
In the following example, a log will be printed every 100 iterations, and after every epoch.

    log_every=[cond.iteration(100), cond.epoch()]

- log_destination: str, IO object, or a list
Logging destinations. Acceptable values include:
- An IO object. This can be an opened file, sys.stdout, or any other file-like object.
- A string, denoting the path to a log file. Executor will open the file and close it when the program exits. The file will be opened in “append” ("a") mode to prevent accidentally overwriting previous logs. To force overwrite, supply the file object instead.
- A list, with each element being one of the above.
When writing to a file, special syntax for terminals (e.g. color codes) is omitted. Also, live progress is not written to files.
By default, the log is only written to sys.stdout.
- log_format: str
The format string for logs during training.
The format string follows the syntax of Python format strings. The status variables you can reference include:
- epoch (int): The current epoch.
- iteration (int): The current iteration.
- progress (float): The epoch progress represented in percentage, i.e. a floating-point number between 0 and 100. It should be noted that progress may not be accurate, and may not be available if the data is loaded lazily.
- speed (float): Average number of data examples processed per second. It should be noted that speed may not be accurate.
- time: The current date and time. Time format can be set using the “format spec” syntax of Python format strings, i.e.: {time:%H:%M:%S} prints time in the format 08:26:03. If time format is not specified, it is equivalent to {time:%Y-%m-%d %H:%M:%S}, which corresponds to the format 2018-07-06 08:26:03. For more information on time formatting, please refer to the documentation for Python's date.strftime().
- metric: An aggregated representation of all metrics, in the format of <name1>: <value1>, <name2>: <value2>, ... . The format spec for metric will be applied to all metrics whose value supports such format spec. For more fine-grained control over metric formatting, use the following methods.
- metric.<name>: Value of the metric under the specified name <name>.
- <name>: Value of the metric under the specified name <name>.
Note
The metric can only be looked up if its name does not coincide with built-in status variables. For instance, a metric named “loss” can be looked up by loss and metric.loss, but a metric named “time” can only be looked up by metric.time.
The default format string is:
{time} : Epoch {epoch} @ {iteration}it ({progress}%, {speed}), {{{metric:.3f}}}
which produces logs similar to:
2019-08-05 11:46:26 : Epoch 1 @ 800it (20.3%, 16.14ex/s), {loss: 0.358, Accuracy: 0.567}
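The default format string relies only on standard Python format-string behavior, which can be reproduced outside the library. In the sketch below, MetricView is a hypothetical helper that applies a single format spec to every metric value through __format__; how Executor aggregates metrics internally is not shown here. The doubled braces in the format string produce the literal braces around the metric list.

```python
from datetime import datetime


class MetricView:
    """Hypothetical helper: applies one format spec to every metric value."""

    def __init__(self, values):
        self.values = values

    def __format__(self, spec):
        return ", ".join(f"{name}: {format(value, spec)}"
                         for name, value in self.values.items())


log_format = ("{time} : Epoch {epoch} @ {iteration}it "
              "({progress}%, {speed}), {{{metric:.3f}}}")
status = {
    "time": datetime(2019, 8, 5, 11, 46, 26),
    "epoch": 1,
    "iteration": 800,
    "progress": 20.3,
    "speed": "16.14ex/s",
    "metric": MetricView({"loss": 0.3581, "Accuracy": 0.5672}),
}

line = log_format.format(**status)
short_time = "{time:%H:%M:%S}".format(time=status["time"])
print(line)
print(short_time)
```

A datetime formatted with an empty spec falls back to its string form, which is why {time} and {time:%Y-%m-%d %H:%M:%S} give the same text here.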
- valid_log_format: str
The format string for logs when validation completes. Please refer to log_format for details of the format string.
The default format string is:
{time} : Epoch {epoch}, {split} result = {{{metric:.3f}}}
which produces logs similar to:
2019-08-05 11:36:53 : Epoch 6, valid result = {PearsonR: 0.918, loss: 0.363}
- test_log_format: str, optional
The format string for logs when testing completes. Please refer to log_format for details of the format string.
If None, the format string for validation (valid_log_format) is used. Defaults to None.
- valid_progress_log_format: str
The format string for logs during validation if live progress is enabled. Please refer to log_format for details of the format string.
The default format string is:
{time} : Evaluating on {split} ({progress}%, {speed}), {{{metric:.3f}}}
which produces logs similar to:
2019-08-05 11:35:56 : Evaluating on test (65.4%, 1.12s/ex), {PearsonR: 0.911, loss: 0.384}
- test_progress_log_format: str
The format string for logs during testing if live progress is enabled. Please refer to log_format for details of the format string.
If None, the format string for validation (valid_progress_log_format) is used. Defaults to None.
- print_model_arch: bool
If True, the model architecture will be logged in a readable format. Defaults to True.
- show_live_progress: bool or str
Controls whether live progress will be shown. Acceptable values include:
- True: Live progress is enabled for training, validation, and testing.
- False: Live progress is disabled.
- "train", "valid", "test", or a list of these strings: Live progress is enabled for the specified stages only.
If live progress is enabled for a certain stage, the specified format string will be shown similar to a sticky progress bar at the bottom of the terminal window, and updated after each iteration.
Note that live progress is only shown on terminals. It will not be printed to log files.
Warning
This may incur extra overhead because an update requires re-evaluating metrics. Make sure that all metrics logged are StreamingMetrics. You can also explicitly log only streaming metrics, or disable live progress for certain stages.
- write_log(log_str, *, mode='info', skip_tty=False, skip_non_tty=False)[source]¶
Write a string to log.
- Parameters
log_str (str) – The string to log.
mode (str) – The logging mode. Supported values are:
- "log": A plain log. Logs created by log_every and eval_log_every are in this format.
- "info": Indicates a result or notification. Includes a header with "INFO" and a timestamp in green color. This is the default mode.
- "warning": Indicates an unexpected or incorrect situation. Includes a header with "WARNING" and a timestamp. The entire warning is in red color.
skip_tty – If True, the log string will not be written to terminal log destinations. Defaults to False.
skip_non_tty – If True, the log string will not be written to non-terminal (e.g., files) destinations. Defaults to False.
- on(cond, func=None)[source]¶
Register a function as an action triggered on a condition. For example:

    executor = Executor(...)

    @executor.on(cond.iteration(10, mode="valid"))
    def log_during_validation(executor):
        logging.info("Validation iter %d", executor.status["iteration"])

The action function takes exactly one argument: the executor instance itself.
- Parameters
cond – The condition that will call the function when triggered. Must be of type Condition.
func (optional) – The function to register. If not None, the function will be registered; if None, a decorator will be returned.
- Returns
If func is None, this method will return a decorator to wrap around the function to register as hook.
If func is not None, this method will return the Executor itself, allowing chained calls.
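The two return behaviors correspond to a common Python pattern in which one method serves both as a decorator factory and as a direct registration call. The sketch below shows that pattern in isolation; ToyExecutor and the string "conditions" are hypothetical placeholders and do not reflect how Executor stores its hooks.

```python
class ToyExecutor:
    """Hypothetical stand-in that only models action registration."""

    def __init__(self):
        self.hooks = []

    def on(self, cond, func=None):
        if func is not None:
            self.hooks.append((cond, func))
            return self  # direct form: return the executor for chaining

        def decorator(f):
            self.hooks.append((cond, f))
            return f     # decorator form: leave the function usable as-is
        return decorator


executor = ToyExecutor()


@executor.on("every-10-iterations")  # decorator form
def log_progress(executor):
    return "logged"


def save(executor):
    return "saved"


def notify(executor):
    return "notified"


# Direct form returns the executor, so registrations can be chained.
result = executor.on("on-improvement", save).on("on-epoch-end", notify)
registered = [cond for cond, _ in executor.hooks]
print(registered)
```

Returning the undecorated function from the decorator keeps log_progress callable under its own name, which is convenient for testing actions independently.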
- on_event(event, point='end', func=None)[source]¶
Register a function as an action triggered on an event point. For example:

    executor = Executor(...)

    @executor.on_event(Event.Epoch, 'end')
    def log_at_end_of_epoch(executor):
        logging.info("Epoch %d done", executor.status["epoch"])

The action function takes exactly one argument: the executor instance itself.
- Parameters
event – The event to hook on. Must be an enum value from Event.
point (str) – The point of event to hook on. Supported values are "begin" and "end". Defaults to "end".
func (optional) – The function to register. If not None, the function will be registered; if None, a decorator will be returned.
- Returns
If func is None, this method will return a decorator to wrap around the function to register as hook.
If func is not None, this method will return the Executor itself, allowing chained calls.
- save(path=None, save_training_state=None)[source]¶
Save a snapshot of the current state to a checkpoint file.
- Parameters
path (str, optional) – Path to the checkpoint directory. If None, checkpoint_dir in the constructor arguments will be used. Defaults to None.
save_training_state (bool) – If True, will save the entire training state to the checkpoint. If False, only save model weights. If None, the value from the constructor arguments will be used. Defaults to None.
- load(path=None, load_training_state=True, allow_failure=False)[source]¶
Load a previous model checkpoint from file.
- Parameters
path (str, optional) – Path to a specific checkpoint or a checkpoint directory. If a directory is specified, the most recent checkpoint in the directory is loaded. If None, checkpoint_dir in the constructor arguments will be used. Defaults to None.
load_training_state (bool) – If True, will load entire training state from checkpoint (if the checkpoint contains training state). Otherwise, just load model weights. Defaults to True.
allow_failure (bool) – If True, no exceptions will be raised if no checkpoints were found. Defaults to False. Note that exceptions are still raised if the provided path does not exist, or the selected checkpoint is corrupted.
- Returns
Path of the loaded checkpoint, or None if load failed.
- terminate()[source]¶
Terminate training. This method is intended to be called within actions. An example use case would be to implement a custom early-stopping mechanism.
It is guaranteed that no other event points will be fired once terminate() is called. However, conditions and actions under the same event point are still called.
- remove_action()[source]¶
Remove the current action being run. This method is intended to be called within actions. An example use case would be to implement an action that is only run once at a certain event point.
- test(dataset=None)[source]¶
Start the test loop.
- Parameters
dataset (optional) – The dataset(s) to test on. Acceptable values include:
- A single DatasetBase instance.
- A list of DatasetBase instances.
- A dictionary mapping names to DatasetBase instances.
If None, test_data from the constructor arguments is used. Defaults to None.
- _write_log(log_str, skip_tty=False, skip_non_tty=False, newline=True, clear_line=False)[source]¶
Write a string to log.
- Parameters
log_str (str) – The string to log.
skip_tty – If True, the log string will not be written to terminal log destinations. Defaults to False.
skip_non_tty – If True, the log string will not be written to non-terminal (e.g., files) destinations. Defaults to False.
newline – If True, print a newline character after printing the log string. Defaults to True.
clear_line – If True, clear the line before printing. This only works for terminals. Defaults to False.
- _create_logging_fn(format_str, metrics, tracker, warn_non_streaming_metric=False)[source]¶
Given a logging format string, create a function that takes the executor instance as argument and returns the logging string.
- Parameters
format_str (str) – The logging format string.
metrics – The metrics dictionary that will be used in the logging hook function.
tracker – The progress tracker that will be used in the logging hook function.
warn_non_streaming_metric (bool) – If True, will issue a warning if any of the provided metrics is not a StreamingMetric. Defaults to False. This is set to True when the logging string is used as the status line.
- Returns
A hook function to print logs given the format string. Note that the hook function can accept additional arguments to pass to executor.write_log(), allowing it to be used in combination with functools.partial().
- _fire_event(event, end)[source]¶
Signal the beginning or end of an event. Internally, this is where conditions are checked and actions are executed.
- Parameters
event – The Event to fire.
end – If True, the fired event point is the end of event. If False, the fired event point is the beginning of event.
- Raises
If any triggered action calls terminate(), ExecutorTerminateSignal is thrown after all conditions are checked and actions executed.
- _validate_step(batch)[source]¶
Perform one step of validation, i.e., perform a forward pass (or decoding, depending on validate_mode) for a single batch.
- Parameters
batch – The batch to validate on.
- Returns
The dictionary containing values returned by the model. This is used to compute metrics.
- _test_step(batch)[source]¶
Perform one step of testing, i.e., perform a forward pass (or decoding, depending on test_mode) for a single batch.
- Parameters
batch – The batch to test on.
- Returns
The dictionary containing values returned by the model. This is used to compute metrics.
- _train_step(batch)[source]¶
Perform one step of training, i.e., perform a forward and backward pass for a single batch. Parameter updates should also be performed when necessary.
- Parameters
batch – The batch to train on.
- Returns
The dictionary containing values returned by the model. This is used to compute metrics.
- _train_loop(iterator)[source]¶
Run the entire training loop given the data iterator.
- Parameters
iterator – The iterator over the training data.
Conditions¶
Event¶
Condition¶
epoch¶
iteration¶
validation¶
- class texar.torch.run.condition.validation(num_validations=1, better=None)[source]¶
Triggers when validation ends, and optionally checks if validation results improve or worsen.
- Parameters
num_validations (int) – The number of validations to wait before triggering the event. In other words, the event is triggered every num_validations validations.
better (bool, optional) – If True, this event only triggers when validation results improve; if False, only triggers when results worsen. Defaults to None, in which case the event triggers regardless of results.
consecutive¶
- class texar.torch.run.condition.consecutive(cond, times, clear_after_trigger=True)[source]¶
Triggers when the specified condition passes checks several times consecutively.
For example: consecutive(validation(better=False), times=3) would trigger if validation results do not improve for 3 times in a row.
Warning
This method works by calling the inner condition at each event point that it registers. The consecutive counter is reset to zero if any check returns False. Thus, the behavior of consecutive might be different from what you expect. For instance:
- cond.consecutive(cond.iteration(1), n_times) is equivalent to cond.iteration(n_times).
- cond.consecutive(cond.iteration(2), n_times) will never trigger.
It is recommended not to use consecutive with conditions other than validation. You should also be careful when implementing custom conditions for use with consecutive.
Warning
Conditions are stateful objects. Using a registered condition as the inner condition here could result in unexpected behaviors. For example:

    my_cond = cond.validation(better=True)
    executor.on(my_cond, some_action)
    executor.on(cond.consecutive(my_cond, 2), some_other_action)

In the code above, if no other conditions are registered, some_other_action will never be called. This is because both conditions are checked at the end of each validation, but the consecutive condition internally checks my_cond, which has already updated the previous best result that it stored. As a result, the check will never succeed.
- Parameters
cond – The base condition to check.
times (int) – The number of times the base condition should pass checks consecutively.
clear_after_trigger (bool) – Whether the counter should be cleared after the event is triggered. If clear_after_trigger is set to False, once this event is triggered, it will trigger every time cond is triggered, until cond fails to trigger at some point. Defaults to True.
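The counter behavior described in the first warning, as well as the effect of clear_after_trigger, can be reproduced with a standalone sketch. The EveryN and Consecutive classes below are simplified, hypothetical models of the semantics stated above (reset on any failed check, trigger after times passes in a row) and are not the library classes.

```python
class EveryN:
    """Toy inner condition: passes on every n-th check."""

    def __init__(self, n):
        self.n, self.count = n, 0

    def check(self):
        self.count += 1
        return self.count % self.n == 0


class Consecutive:
    """Toy wrapper: triggers after `times` passing checks in a row."""

    def __init__(self, cond, times, clear_after_trigger=True):
        self.cond, self.times = cond, times
        self.clear_after_trigger = clear_after_trigger
        self.streak = 0

    def check(self):
        if not self.cond.check():
            self.streak = 0  # any failed check resets the counter
            return False
        self.streak += 1
        if self.streak >= self.times:
            if self.clear_after_trigger:
                self.streak = 0
            return True
        return False


def trigger_points(cond, num_checks):
    """Return the 1-based check indices at which `cond` triggers."""
    return [i for i in range(1, num_checks + 1) if cond.check()]


every_check = trigger_points(Consecutive(EveryN(1), 3), 9)
every_other = trigger_points(Consecutive(EveryN(2), 3), 9)
sticky = trigger_points(Consecutive(EveryN(1), 3, clear_after_trigger=False), 6)
print(every_check, every_other, sticky)
```

Wrapping a condition that passes on every check behaves like a plain "every 3" condition, while wrapping one that passes on every second check never triggers, because each intermediate failed check resets the streak.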
once¶
- class texar.torch.run.condition.once(cond)[source]¶
Triggers only when the specified condition triggers for the first time.
Internally, this condition calls the remove_action() method to remove itself from the registered actions.
For example: once(iteration(5)) would only trigger on the 5th iteration of the entire training loop.
Warning
Conditions are stateful objects. Using a registered condition as the inner condition here could result in unexpected behaviors. Please refer to consecutive for a concrete example.
- Parameters
cond – The base condition to check.
time¶
Metrics¶
Metric¶
- class texar.torch.run.metric.Metric(*, pred_name, label_name='label', higher_is_better=None)[source]¶
Base class of all metrics. You should not directly inherit this class, but inherit from SimpleMetric or StreamingMetric instead.
Subclasses can override the class attributes to indicate their behaviors:
higher_is_better: If True, higher (comparison using the greater-than operator > returns True) values are considered better metric values. If False, lower values are considered better. Defaults to True.
requires_pred: If True, predicted values are required to compute the metric value. Defaults to True.
requires_label: If True, labels are required to compute the metric value. Defaults to True.
- Keyword Arguments
pred_name (str, optional) – Name of the predicted value. This will be used as the key to the dictionary returned by the model.
label_name (str, optional) – Name of the label. This will be used as the key to the batch object returned by the dataset. Defaults to "label".
higher_is_better (bool, optional) – If specified, the higher_is_better attribute for the instance is overwritten by the specified value.
- property metric_name¶
Name of the metric. By default, the class name is used.
- property pred_name¶
Name of the predicted value. This will be used as the key to the dictionary returned by the model.
- property label_name¶
Name of the label (ground truth / gold value). This will be used as the key to the batch object returned by the dataset.
- abstract reset()[source]¶
Reset the internal state of the metric, and erase all previously added data points.
- abstract add(predicted, labels)[source]¶
Record a data batch in the metric.
- Parameters
predicted – The list of predicted values.
labels – The list of labels.
- better(cur, prev)[source]¶
Compare two metric values and return which is better.
- Parameters
cur – The “current” metric value.
prev – The “previous” metric value.
- Returns
Return value is either a bool or None.
If True, the current metric value is considered better.
If False, the previous metric value is considered better.
If None, the two values are considered to be the same, or uncomparable.
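The three-valued return contract can be illustrated with a standalone function. This is a sketch under assumptions rather than the library's code: the free function better and its higher_is_better argument are hypothetical stand-ins for the method and the class attribute, and treating values that cannot be ordered as incomparable is an assumption.

```python
def better(cur, prev, higher_is_better=True):
    """Return True if `cur` is better, False if `prev` is better, else None."""
    try:
        if cur == prev:
            return None  # identical values: neither is better
        cur_is_greater = bool(cur > prev)
    except (TypeError, ValueError):
        # Assumption: values that cannot be ordered count as incomparable.
        return None
    # Flip the result when lower values are considered better.
    return cur_is_greater == higher_is_better


accuracy_improved = better(0.9, 0.8)                      # higher is better
loss_improved = better(0.9, 0.8, higher_is_better=False)  # lower is better
tie = better(0.5, 0.5)
mixed = better("a", 1)
```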
- finalize(executor)[source]¶
Finalize the metric. Called when the whole dataset has been fully iterated, e.g., at the end of an epoch, or the end of validation or testing.
The default behavior is a no-op. Most metrics won’t need this; special ones such as FileWriterMetric use it to perform one-time operations.
- Parameters
executor – The Executor instance.
SimpleMetric¶
- class texar.torch.run.metric.SimpleMetric(*, pred_name, label_name='label', higher_is_better=None)[source]¶
Base class of simple metrics. Simple metrics are metrics that do not support incremental computation. The value of the metric is computed only after all data points have been added.
The default implementation of add() simply stores the predicted values and labels into lists.
- reset()[source]¶
Reset the internal state of the metric, and erase all previously added data points.
StreamingMetric¶
- class texar.torch.run.metric.StreamingMetric(*, pred_name, label_name='label', higher_is_better=None)[source]¶
Base class of streaming metrics. Streaming metrics are metrics that support incremental computation. The value of the metric may be queried before all data points have been added, and the computation should not be expensive.
The default implementation of add() only keeps track of the number of data points added. You should override this method.
Accuracy¶
- class texar.torch.run.metric.Accuracy(*, pred_name, label_name='label', higher_is_better=None)[source]¶
The accuracy metric for evaluating classification tasks. Accuracy is defined as the ratio of correct (exactly matching) predictions out of all predictions.
Accuracy is a StreamingMetric that requires both predicted values and labels. Accuracy values are float numbers between 0 and 1, with higher values being better.
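A streaming accuracy only needs two counters, which is why its value can be queried cheaply between batches. The class below is an illustrative sketch, not the library's implementation: StreamingAccuracy and its value() method are hypothetical names, and returning 0.0 before any data is added is an assumption.

```python
class StreamingAccuracy:
    """Illustrative streaming accuracy that keeps only two counters."""

    def __init__(self):
        self.reset()

    def reset(self):
        # Erase all previously added data points.
        self.correct = 0
        self.count = 0

    def add(self, predicted, labels):
        # Count exact matches in this batch.
        self.count += len(labels)
        self.correct += sum(p == l for p, l in zip(predicted, labels))

    def value(self):
        # Assumption: report 0.0 when nothing has been added yet.
        return self.correct / self.count if self.count else 0.0


acc = StreamingAccuracy()
acc.add([1, 0, 1], [1, 1, 1])
partial = acc.value()  # value after the first batch
acc.add([0], [0])
final = acc.value()    # value after both batches
```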
ConfusionMatrix¶
- class texar.torch.run.metric.ConfusionMatrix(*, pred_name, label_name='label', higher_is_better=None)[source]¶
The confusion matrix is an evaluation metric for classification tasks.
Confusion matrix is a StreamingMetric that requires both predicted values and labels. Confusion matrix values are NumPy arrays, with no clear definition of “better”. Comparison of two confusion matrices is not meaningful.
The value indexed at (i, j) of the confusion matrix is the number of data points whose predicted label is i and whose ground truth label is j. Labels are internally mapped to indices.
- Keyword Arguments
- property class_id¶
Mapping of predicted values and labels to indices within the matrix.
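The (i, j) indexing convention and the label-to-index mapping can be sketched as follows. This is not the library's implementation: the function confusion_matrix is hypothetical, it returns nested lists instead of a NumPy array to stay dependency-free, and assigning indices in first-seen order is an assumption.

```python
def confusion_matrix(predicted, labels):
    """Build a matrix whose rows are predicted classes and columns are gold classes."""
    class_id = {}
    # Assumption: indices are assigned in the order values are first seen.
    for value in list(predicted) + list(labels):
        class_id.setdefault(value, len(class_id))
    size = len(class_id)
    matrix = [[0] * size for _ in range(size)]
    for pred, gold in zip(predicted, labels):
        matrix[class_id[pred]][class_id[gold]] += 1
    return matrix, class_id


matrix, class_id = confusion_matrix(
    ["cat", "dog", "cat", "dog"],  # predicted values
    ["cat", "cat", "cat", "dog"],  # ground truth labels
)
```

Here matrix[class_id["dog"]][class_id["cat"]] counts the data points predicted as "dog" whose ground truth is "cat".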
Precision¶
- class texar.torch.run.metric.Precision(mode='binary', pos_label=None, *, pred_name, label_name='label')[source]¶
The precision metric for evaluating classification tasks. Precision is defined as the ratio tp / (tp + fp), where tp is the number of true positives and fp is the number of false positives.
Precision is a StreamingMetric that requires both predicted values and labels. Precision values are float numbers between 0 and 1, with higher values being better.
- Parameters
mode (str) – The mode for computing averages across multiple labels. Defaults to "binary". Available options include:
"binary": Only report results for the class specified by pos_label. This is only meaningful for binary classification tasks.
"micro": Return the precision value computed using globally counted true positives and false positives.
"macro": Return the unweighted average of precision values for each label.
"weighted": Return the average of precision values for each label, weighted by the number of true instances for each label.
pos_label (str, optional) – The label for the positive class. Only used if mode is set to "binary".
- Keyword Arguments
Recall¶
- class texar.torch.run.metric.Recall(mode='binary', pos_label=None, *, pred_name, label_name='label')[source]¶
The recall metric for evaluating classification tasks. Recall is defined as the ratio tp / (tp + fn), where tp is the number of true positives and fn is the number of false negatives.
Recall is a StreamingMetric that requires both predicted values and labels. Recall values are float numbers between 0 and 1, with higher values being better.
- Parameters
mode (str) – The mode for computing averages across multiple labels. Defaults to "binary". Available options include:
"binary": Only report results for the class specified by pos_label. This is only meaningful for binary classification tasks.
"micro": Return the recall value computed using globally counted true positives and false negatives.
"macro": Return the unweighted average of recall values for each label.
"weighted": Return the average of recall values for each label, weighted by the number of true instances for each label.
pos_label (str, optional) – The label for the positive class. Only used if mode is set to "binary".
- Keyword Arguments
F1¶
- class texar.torch.run.metric.F1(mode='binary', pos_label=None, *, pred_name, label_name='label')[source]¶
The F1 metric for evaluating classification tasks. F1 is defined as the harmonic mean of precision and recall.
F1 is a StreamingMetric that requires both predicted values and labels. F1 values are float numbers between 0 and 1, with higher values being better.
- Parameters
mode (str) – The mode for computing averages across multiple labels. Defaults to "binary". Available options include:
"binary": Only report results for the class specified by pos_label. This is only meaningful for binary classification tasks.
"micro": Return the F1 value computed using globally counted true positives, false positives, and false negatives.
"macro": Return the unweighted average of F1 values for each label.
"weighted": Return the average of F1 values for each label, weighted by the number of true instances for each label.
pos_label (str, optional) – The label for the positive class. Only used if mode is set to "binary".
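The four averaging modes shared by Precision, Recall, and F1 can be compared on a small example. The sketch below is not the library's implementation: prf_counts, evaluate, and the helper names are hypothetical, and returning 0.0 when a denominator is zero is an assumption.

```python
from collections import Counter


def prf_counts(predicted, labels):
    """Count per-class true positives, false positives, and false negatives."""
    tp, fp, fn = Counter(), Counter(), Counter()
    for pred, gold in zip(predicted, labels):
        if pred == gold:
            tp[gold] += 1
        else:
            fp[pred] += 1
            fn[gold] += 1
    return tp, fp, fn


def _ratio(num, den):
    # Assumption: an undefined ratio is reported as 0.0.
    return num / den if den else 0.0


def _prf(tp, fp, fn):
    precision = _ratio(tp, tp + fp)
    recall = _ratio(tp, tp + fn)
    f1 = _ratio(2 * precision * recall, precision + recall)
    return precision, recall, f1


def evaluate(predicted, labels, mode="binary", pos_label=None):
    """Return (precision, recall, F1) under the given averaging mode."""
    tp, fp, fn = prf_counts(predicted, labels)
    if mode == "binary":
        return _prf(tp[pos_label], fp[pos_label], fn[pos_label])
    if mode == "micro":
        return _prf(sum(tp.values()), sum(fp.values()), sum(fn.values()))
    classes = set(predicted) | set(labels)
    per_class = {c: _prf(tp[c], fp[c], fn[c]) for c in classes}
    if mode == "macro":
        weights = {c: 1 / len(classes) for c in classes}
    elif mode == "weighted":
        support = Counter(labels)  # number of true instances per label
        weights = {c: support[c] / len(labels) for c in classes}
    else:
        raise ValueError(f"Unknown mode: {mode}")
    return tuple(sum(weights[c] * per_class[c][i] for c in classes)
                 for i in range(3))


predicted = ["a", "a", "b", "b", "c"]
labels = ["a", "b", "b", "b", "c"]
binary = evaluate(predicted, labels, mode="binary", pos_label="a")
micro = evaluate(predicted, labels, mode="micro")
macro = evaluate(predicted, labels, mode="macro")
weighted = evaluate(predicted, labels, mode="weighted")
```

In this example the label "b" has three true instances, so the "weighted" mode pulls the averages toward the scores of "b", while "macro" treats all three labels equally.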
- Keyword Arguments
PearsonR¶
- class texar.torch.run.metric.PearsonR(*, pred_name, label_name='label', higher_is_better=None)[source]¶
The Pearson correlation coefficient (Pearson’s r) metric for evaluating regression tasks. Pearson’s r is a measure of linear correlation between two sets of variables. Pearson’s r ranges between -1 and 1, with 1 indicating total positive linear correlation, -1 indicating total negative linear correlation, and 0 indicating no linear correlation.
Pearson’s r is a StreamingMetric that requires both predicted values and labels. Pearson’s r values are float numbers between -1 and 1, with higher values being better.
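Pearson’s r can be computed incrementally from six running sums, which is one way a streaming version might work. The class below is a sketch under that assumption, not the library's implementation: StreamingPearsonR and value() are hypothetical names, the sum-based formula is numerically naive, and returning 0.0 for a zero denominator is an assumption.

```python
import math


class StreamingPearsonR:
    """Illustrative streaming Pearson's r based on running sums."""

    def __init__(self):
        self.n = 0
        self.sx = self.sy = 0.0            # sums of x and y
        self.sxx = self.syy = self.sxy = 0.0  # sums of x*x, y*y, x*y

    def add(self, predicted, labels):
        for x, y in zip(predicted, labels):
            self.n += 1
            self.sx += x
            self.sy += y
            self.sxx += x * x
            self.syy += y * y
            self.sxy += x * y

    def value(self):
        num = self.n * self.sxy - self.sx * self.sy
        den = math.sqrt((self.n * self.sxx - self.sx ** 2)
                        * (self.n * self.syy - self.sy ** 2))
        # Assumption: report 0.0 when either variable has zero variance.
        return num / den if den else 0.0


pearson = StreamingPearsonR()
pearson.add([1.0, 2.0], [2.0, 4.0])
pearson.add([3.0], [6.0])
r_value = pearson.value()
```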
RMSE¶
- class texar.torch.run.metric.RMSE(*, pred_name, label_name='label', higher_is_better=None)[source]¶
The root mean squared error (RMSE) metric for evaluating regression tasks. RMSE is defined as the square root of the mean of the squared residuals (differences between predicted values and ground truth values).
RMSE is a StreamingMetric that requires both predicted values and labels. RMSE values are float numbers with a lower bound of 0. Lower values are better.
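RMSE lends itself to streaming computation because only the sum of squared residuals and the number of data points need to be kept. The following is a minimal sketch, not the library's implementation: StreamingRMSE and value() are hypothetical names, and returning 0.0 before any data is added is an assumption.

```python
import math


class StreamingRMSE:
    """Illustrative streaming RMSE that keeps a sum and a count."""

    def __init__(self):
        self.squared_error = 0.0
        self.count = 0

    def add(self, predicted, labels):
        for pred, gold in zip(predicted, labels):
            self.squared_error += (pred - gold) ** 2
            self.count += 1

    def value(self):
        # Assumption: report 0.0 when nothing has been added yet.
        if not self.count:
            return 0.0
        return math.sqrt(self.squared_error / self.count)


rmse = StreamingRMSE()
rmse.add([1.0, 2.0], [1.0, 2.0])  # two exact predictions
rmse.add([3.0, 4.0], [3.0, 8.0])  # one residual of -4
rmse_value = rmse.value()
```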
Average¶
- class texar.torch.run.metric.Average(*, pred_name='loss', higher_is_better=False)[source]¶
The average of a specific predicted value.
Average is a StreamingMetric that requires only predicted values. Average values are unbounded float numbers. By default, lower values are better, but the behavior can be configured.
- Keyword Arguments
AveragePerplexity¶
RunningAverage¶
- class texar.torch.run.metric.RunningAverage(queue_size, *, pred_name='loss', higher_is_better=False)[source]¶
The running average of a specific predicted value, i.e., the average computed over the most recent queue_size values.
Running average is a StreamingMetric that requires only predicted values. Running average values are unbounded float numbers. By default, lower values are better, but the behavior can be configured.
- Keyword Arguments
queue_size (int) – Size of the queue to keep history values. The running average is computed over the most recent queue_size values.
pred_name (str) – Name of the predicted value. This will be used as the key to the dictionary returned by the model. Defaults to "loss".
higher_is_better (bool, optional) – If specified, the higher_is_better attribute for the instance is overwritten by the specified value.
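The effect of queue_size can be seen in a small standalone sketch built on a bounded deque. This is not the library's implementation: RunningAverageSketch and value() are hypothetical names, and the class takes one value per add() call for simplicity.

```python
from collections import deque


class RunningAverageSketch:
    """Illustrative running average over the most recent `queue_size` values."""

    def __init__(self, queue_size):
        self.queue = deque(maxlen=queue_size)
        self.total = 0.0

    def add(self, value):
        # When the queue is full, the oldest value is about to be evicted.
        if len(self.queue) == self.queue.maxlen:
            self.total -= self.queue[0]
        self.queue.append(value)
        self.total += value

    def value(self):
        return self.total / len(self.queue) if self.queue else 0.0


running = RunningAverageSketch(queue_size=3)
for loss in [1.0, 2.0, 3.0]:
    running.add(loss)
first_window = running.value()   # average of 1, 2, 3
running.add(6.0)
second_window = running.value()  # average of 2, 3, 6
```

A smaller queue_size tracks recent values more closely, while a larger one gives a smoother curve in the training log.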
LR¶
- class texar.torch.run.metric.LR(optimizer, param_group=0)[source]¶
The learning rate (LR) of the given optimizer. This is not exactly a metric, but rather a convenience object to print learning rates in logs.
LR is a StreamingMetric that requires neither predicted values nor labels. LR values are unbounded float numbers, with no clear definition of “better”. Comparison of two learning rates is not meaningful.
- Keyword Arguments
optimizer – The optimizer instance.
param_group (int, optional) – Index of the parameter group to obtain the learning rate from. Defaults to 0. You don’t need to specify this if the optimizer contains only one parameter group (e.g., constructed using optim_class(model.parameters())).