Skip to content

Commit

Permalink
improve: replace class configs
Browse files Browse the repository at this point in the history
  • Loading branch information
samedii committed Nov 30, 2023
1 parent aff0288 commit f30633e
Show file tree
Hide file tree
Showing 3 changed files with 11 additions and 27 deletions.
6 changes: 2 additions & 4 deletions lantern/early_stopping.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,17 @@
from typing import Optional

import torch.utils.tensorboard

from lantern import FunctionalBase


class EarlyStopping(FunctionalBase):
class EarlyStopping(FunctionalBase, arbitrary_types_allowed=True):
"""Keeps track of the best score and how long ago it was calculated."""

tensorboard_logger: torch.utils.tensorboard.SummaryWriter
best_score: Optional[float] = None
scores_since_improvement: int = -1

class Config:
arbitrary_types_allowed = True

def score(self, value):
if self.best_score is None or value > self.best_score:
return self.replace(
Expand Down
20 changes: 5 additions & 15 deletions lantern/metric.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,15 @@
import numpy as np
import functools
from typing import Any, Callable, List, Optional, Union

import numpy as np

from lantern import FunctionalBase, star
from typing import Callable, Any, Optional, List, Union


class MapMetric(FunctionalBase):
map_fn: Optional[Callable[..., Any]]
state: List[Any]

class Config:
arbitrary_types_allowed = True
allow_mutation = True

def __init__(self, state=list(), map_fn=None):
super().__init__(
state=state,
Expand Down Expand Up @@ -101,14 +99,10 @@ def __iter__(self):
Metric = MapMetric


class ReduceMetric(FunctionalBase):
class ReduceMetric(FunctionalBase, arbitrary_types_allowed=True):
reduce_fn: Callable[..., Any]
state: Any

class Config:
arbitrary_types_allowed = True
allow_mutation = True

def update_(self, *args, **kwargs):
self.state = self.reduce_fn(self.state, *args, **kwargs)
return self
Expand Down Expand Up @@ -141,10 +135,6 @@ class AggregateMetric(FunctionalBase):
metric: Union[MapMetric, ReduceMetric]
aggregate_fn: Callable

class Config:
arbitrary_types_allowed = True
allow_mutation = True

def map(self, fn):
return self.replace(aggregate_fn=lambda state: fn(self.aggregate_fn(state)))

Expand Down
12 changes: 4 additions & 8 deletions lantern/metric_table.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,14 @@
import textwrap
from typing import Any, Dict, Union

import pandas as pd
from lantern import FunctionalBase
from typing import Dict, Union, Any

# from wire_damage.tools import MapMetric, ReduceMetric, AggregateMetric
from lantern import FunctionalBase


class MetricTable(FunctionalBase):
class MetricTable(FunctionalBase, arbitrary_types_allowed=True):
name: str
metrics: Dict[str, Any]
# metrics: Dict[str, Union[MapMetric, ReduceMetric, AggregateMetric]]

class Config:
arbitrary_types_allowed = True

def __init__(self, name, metrics):
super().__init__(
Expand Down

0 comments on commit f30633e

Please sign in to comment.