diff --git a/jax/_src/pallas/mosaic/lowering.py b/jax/_src/pallas/mosaic/lowering.py index df694ca0e296..53919200f51f 100644 --- a/jax/_src/pallas/mosaic/lowering.py +++ b/jax/_src/pallas/mosaic/lowering.py @@ -30,6 +30,7 @@ from jax._src import state from jax._src.interpreters import mlir from jax._src.interpreters import partial_eval as pe +from jax._src.lax import lax as lax_internal from jax._src.lax.control_flow import for_loop from jax._src.lib.mlir import ir from jax._src.lib.mlir.dialects import arith @@ -1036,6 +1037,14 @@ def _pow_lowering_rule(ctx: LoweringRuleContext, x, y): skip_mlir_conversions.add(lax.pow_p) +def _integer_pow_lowering_rule(ctx: LoweringRuleContext, x, *, y): + return lower_fun(lax_internal._integer_pow, multiple_results=False)( + ctx, x, y=y) + + +lowering_rules[lax.integer_pow_p] = _integer_pow_lowering_rule + + def _exp2_lowering_rule(ctx: LoweringRuleContext, x): # exp2 in JAX lowers to exp(ln2 * x), not to pow2. We match that behavior # here.