Source code for texar.torch.run.metric.base_metric

# Copyright 2019 The Texar Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Base classes for Executor metrics.
"""

from abc import ABC, abstractmethod
from typing import Generic, List, Optional, Sequence, TypeVar

__all__ = [
    "Metric",
    "SimpleMetric",
    "StreamingMetric",
]

Input = TypeVar('Input')
Value = TypeVar('Value')


[docs]class Metric(Generic[Input, Value], ABC): r"""Base class of all metrics. You should not directly inherit this class, but inherit from :class:`SimpleMetric` or :class:`StreamingMetric` instead. Subclasses can override the class attributes to indicate their behaviors: - :attr:`higher_is_better`: If `True`, higher (comparison using the greater than operator ``>`` returns `True`) values are considered better metric values. If `False`, lower values are considered better. Defaults to `True`. - :attr:`required_pred`: If `True`, predicted values are required to compute the metric value. Defaults to `True`. - :attr:`requires_label`: If `True`, labels are required to compute the metric value. Defaults to `True`. Keyword Args: pred_name (str, optional): Name of the predicted value. This will be used as the key to the dictionary returned by the model. label_name (str, optional): Name of the label. This will be used as the key to the batch object returned by the dataset. Defaults to ``"label"``. higher_is_better (bool, optional): If specified, the :attr:`higher_is_better` attribute for the instance is overwritten by the specified value. """ higher_is_better: bool = True requires_pred: bool = True requires_label: bool = True def __init__(self, *, pred_name: Optional[str], label_name: Optional[str] = "label", higher_is_better: Optional[bool] = None): self.reset() if self.requires_pred and pred_name is None: raise ValueError(f"Metric {self.metric_name} requires a " f"prediction name, but None is provided") if self.requires_label and label_name is None: raise ValueError(f"Metric {self.metric_name} requires a " f"label name, but None is provided") if higher_is_better is not None: self.higher_is_better = higher_is_better if not self.requires_pred: pred_name = None if not self.requires_label: label_name = None self._pred_name = pred_name self._label_name = label_name @property def metric_name(self) -> str: r"""Name of the metric. By default, the class name is used.""" return self.__class__.__name__ @property def pred_name(self) -> Optional[str]: r"""Name of the predicted value. This will be used as the key to the dictionary returned by the model. """ return self._pred_name @property def label_name(self) -> Optional[str]: r"""Name of the label (ground truth / gold value). This will be used as the key to the batch object returned by the dataset. """ return self._label_name
[docs] @abstractmethod def reset(self) -> None: r"""Reset the internal state of the metric, and erase all previously added data points. """ raise NotImplementedError
[docs] @abstractmethod def add(self, predicted: Sequence[Input], labels: Sequence[Input]) -> None: r"""Record a data batch in the metric. Args: predicted: The list of predicted values. labels: The list of labels. """ raise NotImplementedError
[docs] @abstractmethod def value(self) -> Value: r"""Compute the metric value. Returns: The metric value. """ raise NotImplementedError
[docs] def better(self, cur: Value, prev: Value) -> Optional[bool]: r"""Compare two metric values and return which is better. Args: cur: The "current" metric value. prev: The "previous" metric value. Returns: Return value is either a `bool` or `None`. - If `True`, the current metric value is considered better. - If `False`, the previous metric value is considered better. - If `None`, the two values are considered to be the same, or uncomparable. """ result = (True if cur > prev else # type: ignore False if cur < prev else None) # type: ignore if not self.higher_is_better and result is not None: result = not result return result
[docs]class SimpleMetric(Metric[Input, Value], ABC): r"""Base class of simple metrics. Simple metrics are metrics that do not support incremental computation. The value of the metric is computed only after all data points have been added. The default implementation of :meth:`add` simply stores the predicted values and labels into lists. """ labels: List[Input] predicted: List[Input] _cached_value: Optional[Value]
[docs] def reset(self) -> None: self.labels = [] self.predicted = [] self._cached_value = None
[docs] def add(self, predicted: Sequence[Input], labels: Sequence[Input]): if len(predicted) != len(labels): raise ValueError( "Lists `predicted` and `labels` should have the same length") self.predicted.extend(predicted) self.labels.extend(labels) self._cached_value = None
[docs] def value(self): if self._cached_value is not None: return self._cached_value self._cached_value = self._value() return self._cached_value
def _value(self) -> Value: r"""Compute the metric value. This function is called in :meth:`texar.torch.run.metric.SimpleMetric.value` and the output is cached. This prevents recalculation of metrics which may be time consuming. Returns: The metric value. """ raise NotImplementedError
[docs]class StreamingMetric(Metric[Input, Value], ABC): r"""Base class of streaming metrics. Streaming metrics are metrics that support incremental computation. The value of the metric may be queried before all data points have been added, and the computation should not be expensive. The default implementation of :meth:`add` only keeps track of the number of data points added. You should override this method. """ count: int
[docs] def reset(self) -> None: self.count = 0
[docs] def add(self, predicted: Sequence[Input], labels: Sequence[Input]) -> None: if len(predicted) != len(labels): raise ValueError( "Lists `predicted` and `labels` should have the same length") self.count += len(predicted)