Skip to content

Commit

Permalink
Update clu to use inspect.get_annotations(cls) oveer cls.__annotations__
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 723027214
  • Loading branch information
CLU Authors authored and copybara-github committed Feb 4, 2025
1 parent 43acbbd commit d3278c2
Showing 1 changed file with 5 additions and 2 deletions.
7 changes: 5 additions & 2 deletions clu/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ def evaluate(variables_p, test_ds):
"""
from __future__ import annotations
from collections.abc import Mapping, Sequence
import inspect
from typing import Any, TypeVar, Protocol

from absl import logging
Expand Down Expand Up @@ -536,7 +537,8 @@ def empty(cls: type[C]) -> C:
_reduction_counter=_ReductionCounter(jnp.array(1, dtype=jnp.int32)),
**{
metric_name: metric.empty()
for metric_name, metric in cls.__annotations__.items()
for metric_name, metric
in inspect.get_annotations(cls, eval_str=True).items()
})

@classmethod
Expand All @@ -546,7 +548,8 @@ def _from_model_output(cls: type[C], **kwargs) -> C:
_reduction_counter=_ReductionCounter(jnp.array(1, dtype=jnp.int32)),
**{
metric_name: metric.from_model_output(**kwargs)
for metric_name, metric in cls.__annotations__.items()
for metric_name, metric
in inspect.get_annotations(cls, eval_str=True).items()
})

@classmethod
Expand Down

0 comments on commit d3278c2

Please sign in to comment.