Source code for ml_logger.metrics

"""Implementation of different type of metrics."""

import operator
from typing import Any, Iterable, Optional, Union

from ml_logger.types import ComparisonOpType, LogType, NumType, ValueType


class BaseMetric:
    """Common parent of all metrics.

    This class only defines the shared interface (``reset``, ``update``,
    ``get_val``) and is not meant to be used directly.
    """

    def __init__(self, name: str):
        """Remember the metric name and initialise the value via ``reset``.

        Args:
            name (str): Name of the metric
        """
        self.name = name
        self.val: ValueType
        self.reset()

    def reset(self) -> None:
        """Put the stored value back to its default, ``0``."""
        self.val = 0

    def update(self, val: Any) -> None:
        """Receive a new observation.

        The base implementation ignores the value; subclasses override this
        method to decide how the observation changes the metric.

        Args:
            val (Any): Current value offered to the metric
        """

    def get_val(self) -> ValueType:
        """Return the value currently held by the metric."""
        return self.val

    def __str__(self) -> str:
        current = self.get_val()
        return str(current)

    def __repr__(self) -> str:
        # The instance dict is printed in attribute insertion order.
        cls = self.__class__
        state = self.__dict__
        return f"{cls} {state}"
class CurrentMetric(BaseMetric):
    """Metric that keeps only the most recently supplied value."""

    def __init__(self, name: str):
        """Create the metric; the value starts at the base-class default.

        Args:
            name (str): Name of the metric
        """
        super().__init__(name=name)

    def update(self, val: ValueType) -> None:
        """Overwrite the stored value with ``val``.

        Args:
            val (ValueType): Current value. It replaces whatever was stored
                before.
        """
        self.val = val
class ConstantMetric(BaseMetric):
    """Metric holding one fixed value.

    The value given at construction is never changed by ``reset`` or
    ``update``. This is generally used for logging strings.
    """

    def __init__(self, name: str, val: ValueType):
        """Store the name and the fixed value.

        The base-class constructor is intentionally not invoked here; the
        two attributes are assigned directly, name first and value second.

        Args:
            name (str): Name of the metric
            val (ValueType): Value the metric reports for its whole lifetime
        """
        self.name, self.val = name, val

    def reset(self) -> None:
        """Leave the value untouched; a constant metric has nothing to reset."""

    def update(self, val: Optional[ValueType] = None) -> None:
        """Leave the value untouched.

        Args:
            val (Optional[ValueType]): Ignored
        """
class ComparisonMetric(BaseMetric):
    """Metric that keeps whichever value wins a comparison.

    This is generally used for logging best accuracy, least loss, etc.
    """

    def __init__(
        self, name: str, default_val: ValueType, comparison_op: ComparisonOpType
    ):
        """Set up the metric with its starting value and comparison rule.

        Args:
            name (str): Name of the metric
            default_val (ValueType): Value the metric starts from and
                returns to on ``reset``
            comparison_op (ComparisonOpType): Callable invoked as
                ``comparison_op(current_val, new_val)``. When it returns a
                true value, ``new_val`` replaces the current value.
        """
        # Assigned directly (no base-class constructor call), in this order.
        self.name = name
        self._default_val = default_val
        self.comparison_op = comparison_op
        self.val = default_val

    def reset(self) -> None:
        """Restore the value supplied as ``default_val`` at construction."""
        self.val = self._default_val

    def update(self, val: ValueType) -> None:
        """Keep ``val`` if the comparison operator prefers it.

        Args:
            val (ValueType): Candidate value. It is stored only when
                ``comparison_op(current_val, val)`` is true; otherwise the
                current value is kept.
        """
        should_replace = self.comparison_op(self.val, val)
        if should_replace:
            self.val = val
class MaxMetric(ComparisonMetric):
    """Metric that tracks the largest value seen.

    This is generally used for logging best accuracy, etc.
    """

    def __init__(self, name: str):
        """Start from negative infinity and compare with ``operator.lt``.

        With ``operator.lt`` the stored value is replaced whenever it is
        strictly less than the incoming one.

        Args:
            name (str): Name of the metric
        """
        lowest = float("-inf")
        super().__init__(name, lowest, operator.lt)
class MinMetric(ComparisonMetric):
    """Metric that tracks the smallest value seen.

    This is generally used for logging least loss, etc.
    """

    def __init__(self, name: str):
        """Start from positive infinity and compare with ``operator.gt``.

        With ``operator.gt`` the stored value is replaced whenever it is
        strictly greater than the incoming one.

        Args:
            name (str): Name of the metric
        """
        highest = float("inf")
        super().__init__(name, highest, operator.gt)
class AverageMetric(BaseMetric):
    """Metric to track a running, sample-weighted average.

    Attributes set by ``reset`` / ``update``:
        val: The value passed to the most recent ``update`` call.
        sum: Running total of ``val * n`` over all updates.
        count: Running total of ``n`` over all updates.
        avg: ``sum / count``, or ``0.0`` while ``count`` is zero.
    """

    def __init__(self, name: str):
        """Store the name and zero all accumulators.

        Args:
            name (str): Name of the metric
        """
        self.name = name
        self.val: float
        self.avg: float
        self.sum: float
        self.count: float
        self.reset()

    def reset(self) -> None:
        """Reset the last value, average, sum and count to ``0.0``."""
        self.val = 0.0
        self.avg = 0.0
        self.sum = 0.0
        self.count = 0.0

    def update(self, val: NumType, n: int = 1) -> None:
        """Fold a new (average, sample count) observation into the metric.

        Args:
            val (NumType): current average value
            n (int, optional): Number of samples used to compute the
                average. Defaults to 1
        """
        self.val = val
        self.sum += val * n
        self.count += n
        # An update with n=0 before any samples were counted (e.g. an empty
        # batch) leaves count at zero; report 0.0, the same value reset()
        # uses, instead of raising ZeroDivisionError.
        self.avg = self.sum / self.count if self.count else 0.0

    def get_val(self) -> float:
        """Get the current average value."""
        return self.avg
class SumMetric(AverageMetric):
    """Metric that reports the running sum.

    It reuses the accumulators maintained by ``AverageMetric`` and differs
    only in which of them ``get_val`` returns.
    """

    def __init__(self, name: str):
        """Create the metric with zeroed accumulators.

        Args:
            name (str): Name of the metric
        """
        super().__init__(name=name)

    def get_val(self) -> float:
        """Return the accumulated sum rather than the average."""
        total = self.sum
        return total
class MetricDict:
    """Container that drives a group of metrics by name."""

    def __init__(self, metric_list: Iterable[BaseMetric]):
        """Index the given metrics by their ``name`` attribute.

        Args:
            metric_list (Iterable[BaseMetric]): metrics to wrap over. If two
                metrics share a name, the later one replaces the earlier.
        """
        self._metrics_dict = {}
        for metric in metric_list:
            self._metrics_dict[metric.name] = metric

    def reset(self) -> None:
        """Call ``reset`` on every wrapped metric."""
        for metric in self._metrics_dict.values():
            metric.reset()

    def update(self, metrics_dict: Union[LogType, "MetricDict"]) -> None:
        """Feed current values to the wrapped metrics.

        Keys that do not match a wrapped metric are skipped. A value that is
        a ``str``, ``float`` or ``int`` is passed to the metric's ``update``
        as a single argument; any other value is unpacked into positional
        arguments.

        Args:
            metrics_dict (Union[LogType, MetricDict]): Current value of
                metrics. A ``MetricDict`` is first converted with
                ``to_dict``.
        """
        if isinstance(metrics_dict, MetricDict):
            metrics_dict = metrics_dict.to_dict()
        scalar_types = (str, float, int)
        for key, val in metrics_dict.items():
            if key not in self._metrics_dict:
                continue
            target = self._metrics_dict[key]
            if isinstance(val, scalar_types):
                target.update(val)
            else:
                target.update(*val)

    def __str__(self) -> str:
        # One repr per metric, one metric per line.
        return "\n".join(repr(metric) for metric in self._metrics_dict.values())

    def to_dict(self) -> LogType:
        """Convert the metrics into a dictionary for `LogBook`.

        Returns:
            LogType: Mapping from metric name to that metric's ``get_val()``
        """
        snapshot = {}
        for key, metric in self._metrics_dict.items():
            snapshot[key] = metric.get_val()
        return snapshot