From 34825d314a8e9f679dd94cf846584791a61e40fd Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Mon, 9 Oct 2023 02:59:20 -0700 Subject: [PATCH] Ignore incorrect type annotations related to jax dtypes PiperOrigin-RevId: 571882947 --- examples/flax_integration/main.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/flax_integration/main.py b/examples/flax_integration/main.py index 9e2fee0..be5f22a 100644 --- a/examples/flax_integration/main.py +++ b/examples/flax_integration/main.py @@ -151,7 +151,7 @@ def _loss_fn(params): scores = model.apply( flax.core.copy(model_state, {"params": params}), inputs ) - loss = loss_fn(scores, labels, where=mask, reduce_fn=jnp.mean) + loss = loss_fn(scores, labels, where=mask, reduce_fn=jnp.mean) # pytype: disable=wrong-arg-types # jnp-type return loss params = model_state["params"]