From f30633ed4ea38c975f2f9f1de8a859ed8852a7af Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Richard=20L=C3=B6wenstr=C3=B6m?= Date: Thu, 30 Nov 2023 19:04:43 +0100 Subject: [PATCH] improve: replace class configs --- lantern/early_stopping.py | 6 ++---- lantern/metric.py | 20 +++++--------------- lantern/metric_table.py | 12 ++++-------- 3 files changed, 11 insertions(+), 27 deletions(-) diff --git a/lantern/early_stopping.py b/lantern/early_stopping.py index f69e131..383d6d8 100644 --- a/lantern/early_stopping.py +++ b/lantern/early_stopping.py @@ -1,19 +1,17 @@ from typing import Optional + import torch.utils.tensorboard from lantern import FunctionalBase -class EarlyStopping(FunctionalBase): +class EarlyStopping(FunctionalBase, arbitrary_types_allowed=True): """Keeps track of the best score and how long ago it was calculated.""" tensorboard_logger: torch.utils.tensorboard.SummaryWriter best_score: Optional[float] = None scores_since_improvement: int = -1 - class Config: - arbitrary_types_allowed = True - def score(self, value): if self.best_score is None or value > self.best_score: return self.replace( diff --git a/lantern/metric.py b/lantern/metric.py index c6a0c7d..a382c55 100644 --- a/lantern/metric.py +++ b/lantern/metric.py @@ -1,17 +1,15 @@ -import numpy as np import functools +from typing import Any, Callable, List, Optional, Union + +import numpy as np + from lantern import FunctionalBase, star -from typing import Callable, Any, Optional, List, Union class MapMetric(FunctionalBase): map_fn: Optional[Callable[..., Any]] state: List[Any] - class Config: - arbitrary_types_allowed = True - allow_mutation = True - def __init__(self, state=list(), map_fn=None): super().__init__( state=state, @@ -101,14 +99,10 @@ def __iter__(self): Metric = MapMetric -class ReduceMetric(FunctionalBase): +class ReduceMetric(FunctionalBase, arbitrary_types_allowed=True): reduce_fn: Callable[..., Any] state: Any - class Config: - arbitrary_types_allowed = True - allow_mutation = True - def update_(self, *args, **kwargs): self.state = self.reduce_fn(self.state, *args, **kwargs) return self @@ -141,10 +135,6 @@ class AggregateMetric(FunctionalBase): metric: Union[MapMetric, ReduceMetric] aggregate_fn: Callable - class Config: - arbitrary_types_allowed = True - allow_mutation = True - def map(self, fn): return self.replace(aggregate_fn=lambda state: fn(self.aggregate_fn(state))) diff --git a/lantern/metric_table.py b/lantern/metric_table.py index 7241dbe..2152dd5 100644 --- a/lantern/metric_table.py +++ b/lantern/metric_table.py @@ -1,18 +1,14 @@ import textwrap +from typing import Any, Dict, Union + import pandas as pd -from lantern import FunctionalBase -from typing import Dict, Union, Any -# from wire_damage.tools import MapMetric, ReduceMetric, AggregateMetric +from lantern import FunctionalBase -class MetricTable(FunctionalBase): +class MetricTable(FunctionalBase, arbitrary_types_allowed=True): name: str metrics: Dict[str, Any] - # metrics: Dict[str, Union[MapMetric, ReduceMetric, AggregateMetric]] - - class Config: - arbitrary_types_allowed = True def __init__(self, name, metrics): super().__init__(