# Copyright 2019 The Texar Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Conditions for the Executor module.
"""
import functools
import types
from abc import ABC, abstractmethod
from enum import Enum, auto
from time import time as time_now
from typing import Any, Dict, Optional, Tuple
from texar.torch.run.executor_utils import MetricList
from texar.torch.utils.types import MaybeTuple
# pylint: disable=unused-argument
__all__ = [
"Event",
"EventPoint",
"Condition",
"epoch",
"iteration",
"validation",
"consecutive",
"time",
]
[docs]class Event(Enum):
Iteration = auto()
Epoch = auto()
Training = auto()
Validation = auto()
ValidationIteration = auto()
Testing = auto()
TestingIteration = auto()
ParameterUpdate = auto()
EventPoint = Tuple[Event, bool]
[docs]class Condition(ABC):
_hooks: Dict[EventPoint, Any]
@property
@abstractmethod
def _hash_attributes(self) -> MaybeTuple[Any]:
raise NotImplementedError
def __eq__(self, other: Any) -> bool:
if not isinstance(other, Condition):
return False
return self._hash_attributes == other._hash_attributes # pylint: disable=protected-access
def __hash__(self):
return hash(self._hash_attributes)
@property
def hooks(self) -> Dict[EventPoint, Any]:
return self._hooks
def __init__(self):
self._hooks = {}
for hook_name in dir(self):
if not hook_name.startswith("check_"):
continue
name = hook_name
if name.endswith("_begin"):
name = name[6:-6]
point = False
elif name.endswith("_end"):
name = name[6:-4]
point = True
else:
raise ValueError(
"Final part of hook name must be 'begin' or 'end'")
if name not in Event.__members__:
name = ''.join(x.capitalize() for x in name.split("_"))
if name not in Event.__members__:
raise ValueError(
f"Hook name '{hook_name}' is not a valid event")
event = Event.__members__[name]
self._hooks[(event, point)] = getattr(self, hook_name)
[docs]class epoch(Condition):
r"""Triggers when the specified number of epochs has ended.
Args:
num_epochs (int): The number of epochs to wait before triggering the
event. In other words, the event is triggered every
:attr:`num_epochs` epochs.
"""
def __init__(self, num_epochs: int = 1):
if not isinstance(num_epochs, int) or num_epochs <= 0:
raise ValueError("`num_epochs` must be a positive integer")
super().__init__()
self.num_epochs = num_epochs
self.count = 0
@property
def _hash_attributes(self):
return self.num_epochs
def check_epoch_end(self, executor) -> bool:
self.count += 1
if self.count == self.num_epochs:
self.count = 0
return True
return False
[docs]class iteration(Condition):
r"""Triggers when the specified number of iterations had ended.
Args:
num_iters (int): The number of iterations to wait before triggering the
event. In other words, the event is triggered every
:attr:`num_iters` iterations.
mode (str): The mode under which iterations are counted. Available
choices are ``"train"``, ``"valid"``, and ``"test"``. Defaults to
``"train"``.
"""
def __new__(cls, num_iters: int = 1, mode: str = "train"):
obj = super().__new__(cls)
# pylint: disable=protected-access
if mode == "train":
obj.check_iteration_end = obj._check_iteration_end
elif mode == "valid":
obj.check_validation_iteration_end = obj._check_iteration_end
elif mode == "test":
obj.check_testing_iteration_end = obj._check_iteration_end
else:
raise ValueError(f"Invalid mode {mode}")
# pylint: enable=protected-access
return obj
def __init__(self, num_iters: int = 1, mode: str = "train"):
if not isinstance(num_iters, int) or num_iters <= 0:
raise ValueError("`num_iters` must be a positive integer")
super().__init__()
self.num_iters = num_iters
self.count = 0
@property
def _hash_attributes(self):
return self.num_iters
def _check_iteration_end(self, executor) -> bool:
self.count += 1
if self.count == self.num_iters:
self.count = 0
return True
return False
[docs]class validation(Condition):
r"""Triggers when validation ends, and optionally checks if validation
results improve or worsen.
Args:
num_validations (int): The number of validations to wait before
triggering the event. In other words, the event is triggered every
:attr:`num_validations` validations.
better (bool, optional): If `True`, this event only triggers when
validation results improve; if `False`, only triggers when results
worsen. Defaults to `None`, in which case the event triggers
regardless of results.
"""
def __init__(self, num_validations: int = 1, better: Optional[bool] = None):
if not isinstance(num_validations, int) or num_validations <= 0:
raise ValueError("`num_validations` must be a positive integer")
super().__init__()
self.num_valids = num_validations
self.count = 0
self.better = better
self.prev_result: Optional[MetricList] = None
@property
def _hash_attributes(self):
return self.num_valids, self.better
def check_validation_end(self, executor) -> bool:
self.count += 1
if self.count < self.num_valids:
return False
self.count = 0
if self.better is None:
return True
metrics = executor.status["eval_metric"]
cur_result = MetricList(metrics)
if self.prev_result is not None:
better = cur_result > self.prev_result
else:
better = True
if better:
self.prev_result = cur_result
return better == self.better
[docs]class consecutive(Condition):
r"""Triggers when the specified condition passes checks for several times
consecutively.
For example: :python:`consecutive(validation(better=False), times=3)` would
trigger if validation results do not improve for 3 times in a row.
.. warning::
This method works by calling the inner condition at each event point
that it registers. The consecutive counter is reset to zero if any check
returns `False`. Thus, the behavior of :class:`consecutive` might be
different to what you expect. For instance:
- :python:`cond.consecutive(cond.iteration(1), n_times)` is equivalent
to :python:`cond.iteration(n_times)`.
- :python:`cond.consecutive(cond.iteration(2), n_times)` will never
trigger.
It is recommended against using :class:`consecutive` for conditions
except :class:`validation`. You should also be careful when implementing
custom conditions for using with :class:`consecutive`.
.. warning::
Conditions are stateful objects. Using a registered condition as the
inner condition here could result in unexpected behaviors. For example:
.. code-block:: python
my_cond = cond.validation(better=True)
executor.on(my_cond, some_action)
executor.on(cond.consecutive(my_cond, 2), some_other_action)
In the code above, if no other conditions are registered,
:python:`some_other_action` will never be called. This is because both
conditions are checked at the end of each iteration, but the
:class:`consecutive` condition internally checks :python:`my_cond`,
which has already updated the previous best result that it stored. As a
result, the check will never succeed.
Args:
cond: The base condition to check.
times (int): The number of times the base condition should pass checks
consecutively.
clear_after_trigger (bool): Whether the counter should be cleared after
the event is triggered. If :attr:`clear_after_trigger` is set to
`False`, once this event is triggered, it will trigger every time
:attr:`cond` is triggered, until :attr:`cond` fails to trigger at
some point. Defaults to `True`.
"""
def __init__(self, cond: Condition, times: int,
clear_after_trigger: bool = True):
super().__init__()
self.cond = cond
self.times = times
self.count = 0
self.clear_after_trigger = clear_after_trigger
for hook_point, method in self.cond.hooks.items():
self._hooks[hook_point] = self._create_check_method(method)
@property
def _hash_attributes(self):
return self.cond, self.times, self.clear_after_trigger
def _create_check_method(self, method):
@functools.wraps(method)
def check_fn(self, executor) -> bool:
if method(executor):
self.count += 1
if self.count >= self.times:
if self.clear_after_trigger:
self.count = 0
return True
else:
self.count = 0
return False
return types.MethodType(check_fn, self)
[docs]class once(Condition):
r"""Triggers only when the specified condition triggers for the first time.
Internally, this condition calls the
:meth:`~texar.torch.run.Executor.remove_action` method to remove itself from
the registered actions.
For example: :python:`once(iteration(5))` would only trigger on the 5th
epoch of the entire training loop.
.. warning::
Conditions are stateful objects. Using a registered condition as the
inner condition here could result in unexpected behaviors. Please refer
to :class:`consecutive` for a concrete example.
Args:
cond: The base condition to check.
"""
def __init__(self, cond: Condition):
super().__init__()
self.cond = cond
for hook_point, method in self.cond.hooks.items():
self._hooks[hook_point] = self._create_check_method(method)
@property
def _hash_attributes(self):
return self.cond
def _create_check_method(self, method):
@functools.wraps(method)
def check_fn(self, executor) -> bool:
if method(executor):
executor.remove_action()
return True
return False
return types.MethodType(check_fn, self)
[docs]class time(Condition):
def __init__(self, *, hours: Optional[float] = None,
minutes: Optional[float] = None,
seconds: Optional[float] = None,
only_training: bool = True):
super().__init__()
self.seconds = 0.0
if hours is not None:
self.seconds += hours * 3600.0
if minutes is not None:
self.seconds += minutes * 60.0
if seconds is not None:
self.seconds += seconds
self.only_training = only_training
self.start_time: Optional[float] = None
self.accumulated_time = 0.0
@property
def _hash_attributes(self):
return self.seconds, self.only_training
def _should_trigger(self) -> bool:
total_time = self.accumulated_time
if self.start_time is None:
cur_time = None
else:
cur_time = time_now()
total_time += cur_time - self.start_time
self.start_time = cur_time
if total_time >= self.seconds:
self.accumulated_time = 0.0
return True
else:
self.accumulated_time = total_time
return False
def check_training_begin(self, executor) -> bool:
self.start_time = time_now()
return False
def check_training_end(self, executor) -> bool:
return self._should_trigger()
def check_validation_begin(self, executor) -> bool:
if self.only_training and self.start_time is not None:
self.accumulated_time += time_now() - self.start_time
self.start_time = None
return self._should_trigger()
def check_validation_end(self, executor) -> bool:
if self.only_training:
self.start_time = time_now()
return False
else:
return self._should_trigger()
def check_testing_begin(self, executor) -> bool:
if self.only_training and self.start_time is not None:
self.accumulated_time += time_now() - self.start_time
self.start_time = None
return self._should_trigger()
def check_testing_end(self, executor) -> bool:
if self.only_training:
self.start_time = time_now()
return False
else:
return self._should_trigger()
def check_iteration_end(self, executor) -> bool:
return self._should_trigger()