Skip to content

Commit

Permalink
Change the count type from int32 to float32 to avoid overflows.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 723026512
  • Loading branch information
CLU Authors authored and copybara-github committed Feb 4, 2025
1 parent 43acbbd commit cfac12b
Showing 1 changed file with 4 additions and 3 deletions.
7 changes: 4 additions & 3 deletions clu/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -784,20 +784,21 @@ class Average(Metric):

@classmethod
def empty(cls) -> Average:
return cls(total=jnp.array(0, jnp.float32), count=jnp.array(0, jnp.int32))
return cls(total=jnp.array(0, jnp.float32), count=jnp.array(0, jnp.float32))

@classmethod
def from_model_output(
cls, values: jnp.ndarray, mask: jnp.ndarray | None = None, **_
) -> Average:
values, mask = _broadcast_masks(values, mask)
return cls(
total=jnp.where(mask, values, jnp.zeros_like(values)).sum(),
total=jnp.where(mask, values, jnp.zeros_like(values)).sum().astype(
jnp.float32),
count=jnp.where(
mask,
jnp.ones_like(values, dtype=jnp.int32),
jnp.zeros_like(values, dtype=jnp.int32),
).sum(),
).sum().astype(jnp.float32),
)

def merge(self, other: Average) -> Average:
Expand Down

0 comments on commit cfac12b

Please sign in to comment.