Skip to content

Commit

Permalink
[Pallas] Add Mosaic lowering rule for fpowi.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 565800521
  • Loading branch information
emilyfertig authored and jax authors committed Sep 15, 2023
1 parent 65de2cf commit e78d8a3
Showing 1 changed file with 9 additions and 0 deletions.
9 changes: 9 additions & 0 deletions jax/_src/pallas/mosaic/lowering.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
from jax._src import state
from jax._src.interpreters import mlir
from jax._src.interpreters import partial_eval as pe
from jax._src.lax import lax as lax_internal
from jax._src.lax.control_flow import for_loop
from jax._src.lib.mlir import ir
from jax._src.lib.mlir.dialects import arith
Expand Down Expand Up @@ -1036,6 +1037,14 @@ def _pow_lowering_rule(ctx: LoweringRuleContext, x, y):
skip_mlir_conversions.add(lax.pow_p)


def _integer_pow_lowering_rule(ctx: LoweringRuleContext, x, *, y):
return lower_fun(lax_internal._integer_pow, multiple_results=False)(
ctx, x, y=y)


lowering_rules[lax.integer_pow_p] = _integer_pow_lowering_rule


def _exp2_lowering_rule(ctx: LoweringRuleContext, x):
# exp2 in JAX lowers to exp(ln2 * x), not to pow2. We match that behavior
# here.
Expand Down

0 comments on commit e78d8a3

Please sign in to comment.