From e78d8a321e4f57771c596aad751cbf9982e5357c Mon Sep 17 00:00:00 2001 From: Emily Fertig Date: Fri, 15 Sep 2023 16:00:19 -0700 Subject: [PATCH] [Pallas] Add Mosaic lowering rule for fpowi. PiperOrigin-RevId: 565800521 --- jax/_src/pallas/mosaic/lowering.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/jax/_src/pallas/mosaic/lowering.py b/jax/_src/pallas/mosaic/lowering.py index df694ca0e296..53919200f51f 100644 --- a/jax/_src/pallas/mosaic/lowering.py +++ b/jax/_src/pallas/mosaic/lowering.py @@ -30,6 +30,7 @@ from jax._src import state from jax._src.interpreters import mlir from jax._src.interpreters import partial_eval as pe +from jax._src.lax import lax as lax_internal from jax._src.lax.control_flow import for_loop from jax._src.lib.mlir import ir from jax._src.lib.mlir.dialects import arith @@ -1036,6 +1037,14 @@ def _pow_lowering_rule(ctx: LoweringRuleContext, x, y): skip_mlir_conversions.add(lax.pow_p) +def _integer_pow_lowering_rule(ctx: LoweringRuleContext, x, *, y): + return lower_fun(lax_internal._integer_pow, multiple_results=False)( + ctx, x, y=y) + + +lowering_rules[lax.integer_pow_p] = _integer_pow_lowering_rule + + def _exp2_lowering_rule(ctx: LoweringRuleContext, x): # exp2 in JAX lowers to exp(ln2 * x), not to pow2. We match that behavior # here.