Skip to content

Commit

Permalink
[JAX] Replace uses of jax.xla_computation() with jax.jit().lower().
Browse files Browse the repository at this point in the history
jax.xla_computation() is deprecated in favor of jax.jit(...).lower(...).

The most common replacements are either jax.jit(...).lower(...).compiler_ir(dialect='hlo') or jax.jit(...).lower(...).cost_analysis().

PiperOrigin-RevId: 509613403
  • Loading branch information
hawkinsp authored and romanngg committed Mar 9, 2023
1 parent 5d38d3e commit c5f8eb9
Showing 1 changed file with 4 additions and 7 deletions.
11 changes: 4 additions & 7 deletions neural_tangents/_src/empirical.py
Original file line number Diff line number Diff line change
Expand Up @@ -2275,15 +2275,12 @@ def _get_fwd(


def _get_flops(f: Callable, optimize: bool, *a, **kw) -> float:
e = jax.jit(f).lower(*a, **kw)
if optimize:
e = jax.jit(f).lower(*a, **kw).compile()
return e.cost_analysis()[0]['flops']
analysis = e.compile().cost_analysis()[0]
else:
m = jax.xla_computation(f)(*a, **kw)
client = jax.lib.xla_bridge.get_backend()
m = m.as_hlo_module()
analysis = jax.lib.xla_client._xla.hlo_module_cost_analysis(client, m)
return analysis['flops']
analysis = e.cost_analysis()
return analysis['flops']



Expand Down

0 comments on commit c5f8eb9

Please sign in to comment.