Skip to content

Commit

Permalink
rewrite multitask metrics
Browse files Browse the repository at this point in the history
  • Loading branch information
ZhiyuanChen committed Oct 2, 2024
1 parent 46025c1 commit 73da69a
Show file tree
Hide file tree
Showing 6 changed files with 126 additions and 200 deletions.
11 changes: 1 addition & 10 deletions danling/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,14 +19,7 @@

from danling import metrics, modules, optim, registry, runner, tensors, typing, utils

from .metrics import (
AverageMeter,
AverageMeters,
MetricMeter,
MetricMeters,
MultiTaskAverageMeters,
MultiTaskMetricMeters,
)
from .metrics import AverageMeter, AverageMeters, MetricMeter, MetricMeters
from .optim import LRScheduler
from .registry import GlobalRegistry, Registry
from .runner import AccelerateRunner, BaseRunner, TorchRunner
Expand Down Expand Up @@ -65,10 +58,8 @@
"MultiTaskMetrics",
"MetricMeter",
"MetricMeters",
"MultiTaskMetricMeters",
"AverageMeter",
"AverageMeters",
"MultiTaskAverageMeters",
"NestedTensor",
"PNTensor",
"tensor",
Expand Down
9 changes: 4 additions & 5 deletions danling/metrics/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,24 +19,23 @@

from lazy_imports import try_import

from .average_meter import AverageMeter, AverageMeters, MultiTaskAverageMeters
from .metric_meter import MetricMeter, MetricMeters, MultiTaskMetricMeters
from .average_meter import AverageMeter, AverageMeters
from .metric_meter import MetricMeter, MetricMeters
from .preprocesses import preprocess_binary, preprocess_multiclass, preprocess_multilabel, preprocess_regression

with try_import() as lazy_import:
from .functional import accuracy, auprc, auroc, f1_score, mcc, pearson, r2_score, rmse, spearman
from .metrics import Metrics, MultiTaskMetrics
from .metrics import Metrics
from .multi_task import MultiTaskMetrics


__all__ = [
"Metrics",
"MultiTaskMetrics",
"MetricMeter",
"MetricMeters",
"MultiTaskMetricMeters",
"AverageMeter",
"AverageMeters",
"MultiTaskAverageMeters",
"regression_metrics",
"binary_metrics",
"multiclass_metrics",
Expand Down
79 changes: 0 additions & 79 deletions danling/metrics/average_meter.py
Original file line number Diff line number Diff line change
Expand Up @@ -211,82 +211,3 @@ def set(self, name: str, meter: AverageMeter) -> None: # pylint: disable=W0237
if not isinstance(meter, AverageMeter):
raise ValueError(f"Expected meter to be an instance of AverageMeter, but got {type(meter)}")
super().set(name, meter)


class MultiTaskAverageMeters(MultiTaskDict):
    r"""
    Manages multiple average meters in one object with multi-task support.

    See Also:
        [`AverageMeter`]: Computes and stores the average and current value.
        [`AverageMeters`]: Manage multiple average meters in one object.
        [`MetricMeters`]: Manage multiple metric meters in one object.

    Examples:
        >>> meters = MultiTaskAverageMeters()
        >>> meters.update({"loss": 0.6, "dataset1.cls.auroc": 0.7, "dataset1.reg.r2": 0.8, "dataset2.r2": 0.9})
        >>> f"{meters:.4f}"
        'loss: 0.6000 (0.6000)\ndataset1.cls.auroc: 0.7000 (0.7000)\ndataset1.reg.r2: 0.8000 (0.8000)\ndataset2.r2: 0.9000 (0.9000)'
        >>> meters['loss'].update(0.9, n=1)
        >>> f"{meters:.4f}"
        'loss: 0.9000 (0.7500)\ndataset1.cls.auroc: 0.7000 (0.7000)\ndataset1.reg.r2: 0.8000 (0.8000)\ndataset2.r2: 0.9000 (0.9000)'
        >>> meters.sum.dict()
        {'loss': 1.5, 'dataset1': {'cls': {'auroc': 0.7}, 'reg': {'r2': 0.8}}, 'dataset2': {'r2': 0.9}}
        >>> meters.count.dict()
        {'loss': 2, 'dataset1': {'cls': {'auroc': 1}, 'reg': {'r2': 1}}, 'dataset2': {'r2': 1}}
        >>> meters.reset()
        >>> f"{meters:.4f}"
        'loss: 0.0000 (nan)\ndataset1.cls.auroc: 0.0000 (nan)\ndataset1.reg.r2: 0.0000 (nan)\ndataset2.r2: 0.0000 (nan)'
        >>> meters = MultiTaskAverageMeters(return_average=True)
        >>> meters.update({"loss": 0.6, "dataset1.a.auroc": 0.7, "dataset1.b.auroc": 0.8, "dataset2.auroc": 0.9})
        >>> f"{meters:.4f}"
        'loss: 0.6000 (0.6000)\ndataset1.a.auroc: 0.7000 (0.7000)\ndataset1.b.auroc: 0.8000 (0.8000)\ndataset2.auroc: 0.9000 (0.9000)'
        >>> meters.update({"loss": 0.9, "dataset1.a.auroc": 0.8, "dataset1.b.auroc": 0.9, "dataset2.auroc": 1.0})
        >>> f"{meters:.4f}"
        'loss: 0.9000 (0.7500)\ndataset1.a.auroc: 0.8000 (0.7500)\ndataset1.b.auroc: 0.9000 (0.8500)\ndataset2.auroc: 1.0000 (0.9500)'
    """  # noqa: E501

    @property
    def sum(self) -> NestedDict[str, float]:
        """Collect the `sum` attribute of every meter yielded by `all_items()` into a `NestedDict`."""
        return NestedDict({key: meter.sum for key, meter in self.all_items()})

    @property
    def count(self) -> NestedDict[str, int]:
        """Collect the `count` attribute of every meter yielded by `all_items()` into a `NestedDict`."""
        return NestedDict({key: meter.count for key, meter in self.all_items()})

    def update(self, *args: Dict, **values: float) -> None:  # pylint: disable=W0237
        r"""
        Updates the average and current value in all meters.

        Args:
            *args: At most one mapping from meter name to the value to record.
            **values: Values given as keyword arguments. When a positional mapping is
                also given, keyword entries take precedence on a name clash.

        Raises:
            ValueError: If more than one positional argument is given.
            ValueError: If the value is not an instance of (int, float, Mapping).
        """  # noqa: E501

        if args:
            if len(args) > 1:
                raise ValueError("Expected only one positional argument, but got multiple.")
            # Merge into a fresh dict so the caller's mapping is never mutated.
            values = {**args[0], **values} if values else args[0]

        for meter, value in values.items():
            if not isinstance(value, (int, float, Mapping)):
                raise ValueError(f"Expected values to be int, float, or a Mapping, but got {type(value)}")
            self[meter].update(value)

    # evil hack, as the default_factory must not be set to make `NestedDict` happy
    # this have some side effects, it will break attribute style intermediate nested dict auto creation
    # but everything has a price
    def get(self, name: Any, default=None) -> Any:
        """
        Return the meter stored under `name`, creating an `AverageMeter` when it is missing.

        Names that start or end with an underscore, and names that are not strings,
        are looked up through the parent implementation and are not auto-created.
        """
        # `name` is annotated `Any`; guard the str-only checks so a non-string key
        # falls through to the parent lookup instead of raising AttributeError.
        if isinstance(name, str) and not (name.startswith("_") or name.endswith("_")):
            return self.setdefault(name, AverageMeter())
        return super().get(name, default)

    def set(self, name: str, meter: AverageMeter | AverageMeters) -> None:  # pylint: disable=W0237
        """
        Store `meter` under `name` after validating its type.

        Raises:
            ValueError: If `meter` is neither an `AverageMeter` nor an `AverageMeters`.
        """
        if not isinstance(meter, (AverageMeter, AverageMeters)):
            raise ValueError(
                f"Expected meter to be an instance of AverageMeter or AverageMeters, but got {type(meter)}"
            )
        super().set(name, meter)
56 changes: 1 addition & 55 deletions danling/metrics/metric_meter.py
Original file line number Diff line number Diff line change
Expand Up @@ -212,58 +212,4 @@ class MultiTaskMetricMeters(MultiTaskAverageMeters):
ValueError: Metric loss not found in ...
""" # noqa: E501

def __init__(self, *args, **kwargs):
    """
    Forward all arguments to the parent constructor with `default_factory` fixed to this class.

    NOTE(review): presumably this makes missing intermediate keys materialise as nested
    `MultiTaskMetricMeters` — confirm against `MultiTaskDict`.
    """
    super().__init__(*args, default_factory=MultiTaskMetricMeters, **kwargs)

def update(  # type: ignore[override] # pylint: disable=W0221
    self,
    values: Mapping[str, Tuple[Tensor | NestedTensor | Sequence, Tensor | NestedTensor | Sequence]],
) -> None:
    r"""
    Feed new values to the meters registered under each given name.

    Args:
        values: Mapping from a registered name to the payload for that entry.
            A Mapping payload is expanded as keyword arguments and a Sequence
            payload as positional arguments of the meter's `update`. When the
            entry is a nested `MultiTaskMetricMeters`, the same payload is
            passed to every meter returned by its `all_values()`.

    Raises:
        ValueError: If a name is not registered, if the registered entry has an
            unsupported type, or if a payload is neither a Mapping nor a Sequence.
    """

    for metric, value in values.items():
        if metric not in self:
            raise ValueError(f"Metric {metric} not found in {self}")

        entry = self[metric]
        # Work out which meters receive this payload.
        if isinstance(entry, MultiTaskMetricMeters):
            recipients = entry.all_values()
        elif isinstance(entry, (MetricMeters, MetricMeter)):
            recipients = (entry,)
        else:
            raise ValueError(
                f"Expected {metric} to be an instance of MultiTaskMetricMeters, MetricMeters, "
                f"or MetricMeter, but got {type(self[metric])}"
            )

        # The payload type is checked per recipient, so an empty nested
        # container never raises for a malformed payload.
        for recipient in recipients:
            if isinstance(value, Mapping):
                recipient.update(**value)
            elif isinstance(value, Sequence):
                recipient.update(*value)
            else:
                raise ValueError(f"Expected value to be a Mapping or Sequence, but got {type(value)}")

# MultiTaskAverageMeters.get is hacked
def get(self, name: Any, default=None) -> Any:
    """
    Look up `name` using `MultiTaskDict.get` directly.

    Calling `MultiTaskDict.get` explicitly (rather than `super().get`) skips the
    `get` override of the direct parent class `MultiTaskAverageMeters`.
    """
    return MultiTaskDict.get(self, name, default)

def set(  # pylint: disable=W0237
    self,
    name: str,
    metric: MetricMeter | MetricMeters | Callable,  # type: ignore[override]
) -> None:
    """
    Register `metric` under `name`.

    A callable is wrapped in a `MetricMeter` before being stored.

    Raises:
        ValueError: If `metric` is not callable and is neither a `MetricMeter`
            nor a `MetricMeters`.
    """
    meter = MetricMeter(metric) if callable(metric) else metric
    if isinstance(meter, (MetricMeter, MetricMeters)):
        super().set(name, meter)
        return
    # Only reachable when `metric` was not callable, so `metric` is still the caller's object.
    raise ValueError(f"Expected {metric} to be an instance of MetricMeter or MetricMeters, but got {type(metric)}")
pass
52 changes: 1 addition & 51 deletions danling/metrics/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -416,54 +416,4 @@ class MultiTaskMetrics(MultiTaskDict):
ValueError: Metric loss not found in ...
""" # noqa: E501

def __init__(self, *args, **kwargs):
    """
    Forward all arguments to the parent constructor with `default_factory` fixed to this class.

    NOTE(review): presumably this makes missing intermediate keys materialise as nested
    `MultiTaskMetrics` — confirm against `MultiTaskDict`.
    """
    super().__init__(*args, default_factory=MultiTaskMetrics, **kwargs)

def update(
    self,
    values: Mapping[str, Mapping[str, Tensor | NestedTensor | Sequence]],
) -> None:
    r"""
    Updates the average and current value in all metrics.

    Args:
        values: Mapping from a registered name to the payload for that entry.
            For a `Metrics`/`Metric` entry the payload is a Mapping (expanded as
            keyword arguments) or a Sequence (expanded as positional arguments).
            For a nested `MultiTaskMetrics` entry the payload maps child names
            to such payloads; children absent from it are left untouched.

    Raises:
        ValueError: If a name is not registered, if the registered entry has an
            unsupported type, or if a payload is neither a Mapping nor a Sequence.
    """

    for metric, value in values.items():
        if metric not in self:
            raise ValueError(f"Metric {metric} not found in {self}")
        if isinstance(self[metric], MultiTaskMetrics):
            for name, met in self[metric].items():
                if name in value:
                    val = value[name]
                    # Dispatch on the child's own payload, not on the enclosing mapping.
                    if isinstance(val, Mapping):
                        met.update(**val)
                    elif isinstance(val, Sequence):
                        met.update(*val)
                    else:
                        raise ValueError(f"Expected value to be a Mapping or Sequence, but got {type(val)}")
        elif isinstance(self[metric], (Metrics, Metric)):
            if isinstance(value, Mapping):
                self[metric].update(**value)
            elif isinstance(value, Sequence):
                self[metric].update(*value)
            else:
                raise ValueError(f"Expected value to be a Mapping or Sequence, but got {type(value)}")
        else:
            # Both fragments must be f-strings, otherwise the type is printed literally.
            raise ValueError(
                f"Expected {metric} to be an instance of MultiTaskMetrics, Metrics, or Metric, "
                f"but got {type(self[metric])}"
            )

def set(  # pylint: disable=W0237
    self,
    name: str,
    metric: Metrics | Metric,  # type: ignore[override]
) -> None:
    """
    Register `metric` under `name` after checking its type.

    Raises:
        ValueError: If `metric` is neither a `Metrics` nor a `Metric`.
    """
    if isinstance(metric, (Metrics, Metric)):
        super().set(name, metric)
        return
    raise ValueError(f"Expected {metric} to be an instance of Metrics or Metric, but got {type(metric)}")
pass
119 changes: 119 additions & 0 deletions danling/metrics/multi_task.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,119 @@
# DanLing
# Copyright (C) 2022-Present DanLing

# This program is free software: you can redistribute it and/or modify
# it under the terms of the following licenses:
# - The Unlicense
# - GNU Affero General Public License v3.0 or later
# - GNU General Public License v2.0 or later
# - BSD 4-Clause "Original" or "Old" License
# - MIT License
# - Apache License 2.0

# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
# See the LICENSE file for more details.

# pylint: disable=redefined-builtin
from __future__ import annotations

from collections.abc import Mapping, Sequence

from torch import Tensor
from torcheval.metrics import Metric as EvalMetric
from torchmetrics import Metric as TorchMetric

from danling.tensors import NestedTensor

from .metric_meter import MetricMeter
from .metrics import Metrics
from .utils import MultiTaskDict


class MultiTaskMetrics(MultiTaskDict):
    r"""
    Nested container that groups metrics by dataset / task.

    Leaves are `Metrics`, `MetricMeter`, `torchmetrics.Metric` or
    `torcheval.metrics.Metric` objects; intermediate levels are
    `MultiTaskMetrics` themselves.

    Examples:
        >>> from danling.metrics.functional import auroc, auprc, pearson, spearman, accuracy, mcc
        >>> metrics = MultiTaskMetrics()
        >>> metrics.dataset1.cls = Metrics(auroc=auroc, auprc=auprc)
        >>> metrics.dataset1.reg = Metrics(pearson=pearson, spearman=spearman)
        >>> metrics.dataset2 = Metrics(auroc=auroc, auprc=auprc)
        >>> metrics
        MultiTaskMetrics(<class 'danling.metrics.multi_task.MultiTaskMetrics'>,
          ('dataset1'): MultiTaskMetrics(<class 'danling.metrics.multi_task.MultiTaskMetrics'>,
            ('cls'): Metrics('auroc', 'auprc')
            ('reg'): Metrics('pearson', 'spearman')
          )
          ('dataset2'): Metrics('auroc', 'auprc')
        )
        >>> metrics.update({"dataset1.cls": {"input": [0.2, 0.4, 0.5, 0.7], "target": [0, 1, 0, 1]}, "dataset1.reg": {"input": [0.1, 0.4, 0.6, 0.8], "target": [0.2, 0.3, 0.5, 0.7]}, "dataset2": {"input": [0.1, 0.4, 0.6, 0.8], "target": [0, 1, 0, 1]}})
        >>> f"{metrics:.4f}"
        'dataset1.cls: auroc: 0.7500 (0.7500)\tauprc: 0.8333 (0.8333)\ndataset1.reg: pearson: 0.9691 (0.9691)\tspearman: 1.0000 (1.0000)\ndataset2: auroc: 0.7500 (0.7500)\tauprc: 0.8333 (0.8333)'
        >>> metrics.setattr("return_average", True)
        >>> metrics.update({"dataset1.cls": {"input": [0.1, 0.4, 0.6, 0.8], "target": [0, 0, 1, 0]}, "dataset1.reg": {"input": [0.2, 0.3, 0.5, 0.7], "target": [0.2, 0.4, 0.6, 0.8]}, "dataset2": {"input": [0.2, 0.3, 0.5, 0.7], "target": [0, 0, 1, 0]}})
        >>> f"{metrics:.4f}"
        'dataset1.cls: auroc: 0.6667 (0.7000)\tauprc: 0.5000 (0.5556)\ndataset1.reg: pearson: 0.9898 (0.9146)\tspearman: 1.0000 (0.9222)\ndataset2: auroc: 0.6667 (0.7333)\tauprc: 0.5000 (0.7000)'
        >>> metrics.update({"dataset1": {"cls": {"input": [0.1, 0.4, 0.6, 0.8], "target": [1, 0, 1, 0]}}})
        >>> f"{metrics:.4f}"
        'dataset1.cls: auroc: 0.2500 (0.5286)\tauprc: 0.5000 (0.4789)\ndataset1.reg: pearson: 0.9898 (0.9146)\tspearman: 1.0000 (0.9222)\ndataset2: auroc: 0.6667 (0.7333)\tauprc: 0.5000 (0.7000)'
        >>> metrics.update(dict(loss=""))  # doctest: +ELLIPSIS
        Traceback (most recent call last):
        ValueError: Metric loss not found in ...
    """  # noqa: E501

    def __init__(self, *args, **kwargs):
        """Forward all arguments to the parent constructor with `default_factory` fixed to this class."""
        super().__init__(*args, default_factory=MultiTaskMetrics, **kwargs)

    def update(
        self,
        values: Mapping[str, Mapping[str, Tensor | NestedTensor | Sequence]],
    ) -> None:
        r"""
        Updates the average and current value in all metrics.

        Args:
            values: Mapping from a registered name to the payload for that entry.
                For a leaf metric the payload is a Mapping (expanded as keyword
                arguments) or a Sequence (expanded as positional arguments).
                For a nested `MultiTaskMetrics` entry the payload maps child
                names to such payloads; children absent from it are left
                untouched, and children that are themselves `MultiTaskMetrics`
                are updated recursively.

        Raises:
            ValueError: If a name is not registered, if the registered entry has
                an unsupported type, or if a payload is neither a Mapping nor a
                Sequence.
        """

        for metric, value in values.items():
            if metric not in self:
                raise ValueError(f"Metric {metric} not found in {self}")
            target = self[metric]
            if isinstance(target, MultiTaskMetrics):
                for name, met in target.items():
                    if name not in value:
                        continue
                    val = value[name]
                    if isinstance(met, MultiTaskMetrics):
                        # Deeper nesting: the child takes a mapping, not expanded arguments.
                        met.update(val)
                    else:
                        # Dispatch on the child's own payload, not on the enclosing mapping.
                        _update_metric(met, val)
            elif isinstance(target, (Metrics, MetricMeter, TorchMetric, EvalMetric)):
                _update_metric(target, value)
            else:
                # Both fragments must be f-strings, otherwise the type is printed literally.
                raise ValueError(
                    f"Expected {metric} to be an instance of MultiTaskMetrics, Metrics, or Metric, "
                    f"but got {type(target)}"
                )

    def set(  # pylint: disable=W0237
        self,
        name: str,
        metric: Metrics | MetricMeter | TorchMetric | EvalMetric,  # type: ignore[override]
    ) -> None:
        """
        Register `metric` under `name` after checking its type.

        Raises:
            ValueError: If `metric` is not a `Metrics`, `MetricMeter`,
                `torcheval.metrics.Metric` or `torchmetrics.Metric`.
        """
        if not isinstance(metric, (Metrics, MetricMeter, TorchMetric, EvalMetric)):
            # A single concatenated message: a comma here would pass two arguments to ValueError.
            raise ValueError(
                f"Expected {metric} to be an instance of Metrics, MetricMeter, torcheval.Metric, "
                f"or torchmetrics.Metric, but got {type(metric)}"
            )
        super().set(name, metric)


def _update_metric(met, value) -> None:
    """Call `met.update` with `value` expanded as keyword (Mapping) or positional (Sequence) arguments."""
    if isinstance(value, Mapping):
        met.update(**value)
    elif isinstance(value, Sequence):
        met.update(*value)
    else:
        raise ValueError(f"Expected value to be a Mapping or Sequence, but got {type(value)}")

0 comments on commit 73da69a

Please sign in to comment.