Skip to content

Commit

Permalink
add while loop support
Browse files Browse the repository at this point in the history
  • Loading branch information
albi3ro committed Jan 8, 2025
1 parent a0c9176 commit 2c59a64
Show file tree
Hide file tree
Showing 2 changed files with 49 additions and 12 deletions.
33 changes: 27 additions & 6 deletions pennylane/capture/base_interpreter.py
Original file line number Diff line number Diff line change
Expand Up @@ -425,15 +425,27 @@ def handle_cond(self, *invals, jaxpr_branches, consts_slices, args_slice):

@PlxprInterpreter.register_primitive(while_loop_prim)
def handle_while_loop(
self, *invals, jaxpr_body_fn, jaxpr_cond_fn, body_slice, cond_slice, args_slice
self,
*invals,
jaxpr_body_fn,
jaxpr_cond_fn,
body_slice,
cond_slice,
args_slice,
abstract_shapes_slice,
):
"""Handle a while loop primitive."""
consts_body = invals[body_slice]
consts_cond = invals[cond_slice]
init_state = invals[args_slice]
abstract_shapes = invals[abstract_shapes_slice]

new_jaxpr_body_fn = jaxpr_to_jaxpr(copy(self), jaxpr_body_fn, consts_body, *init_state)
new_jaxpr_cond_fn = jaxpr_to_jaxpr(copy(self), jaxpr_cond_fn, consts_cond, *init_state)
new_jaxpr_body_fn = jaxpr_to_jaxpr(
copy(self), jaxpr_body_fn, consts_body, *abstract_shapes, *init_state
)
new_jaxpr_cond_fn = jaxpr_to_jaxpr(
copy(self), jaxpr_cond_fn, consts_cond, *abstract_shapes, *init_state
)

return while_loop_prim.bind(
*invals,
Expand All @@ -442,6 +454,7 @@ def handle_while_loop(
body_slice=body_slice,
cond_slice=cond_slice,
args_slice=args_slice,
abstract_shapes_slice=abstract_shapes_slice,
)


Expand Down Expand Up @@ -483,16 +496,24 @@ def handle_jacobian(self, *invals, jaxpr, n_consts, **params):


def flatten_while_loop(
self, *invals, jaxpr_body_fn, jaxpr_cond_fn, body_slice, cond_slice, args_slice
self,
*invals,
jaxpr_body_fn,
jaxpr_cond_fn,
body_slice,
cond_slice,
args_slice,
abstract_shapes_slice,
):
"""Handle the while loop by a flattened python strategy."""
consts_body = invals[body_slice]
consts_cond = invals[cond_slice]
init_state = invals[args_slice]
abstract_shapes_slice = invals[abstract_shapes_slice]

fn_res = init_state
while copy(self).eval(jaxpr_cond_fn, consts_cond, *fn_res)[0]:
fn_res = copy(self).eval(jaxpr_body_fn, consts_body, *fn_res)
while copy(self).eval(jaxpr_cond_fn, consts_cond, *abstract_shapes_slice, *fn_res)[0]:
fn_res = copy(self).eval(jaxpr_body_fn, consts_body, *abstract_shapes_slice, *fn_res)

return fn_res

Expand Down
28 changes: 22 additions & 6 deletions pennylane/compiler/qjit_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -411,16 +411,26 @@ def _get_while_loop_qfunc_prim():
while_loop_prim.multiple_results = True

@while_loop_prim.def_impl
def _(*args, jaxpr_body_fn, jaxpr_cond_fn, body_slice, cond_slice, args_slice):
def _(

Check notice on line 414 in pennylane/compiler/qjit_api.py

View check run for this annotation

codefactor.io / CodeFactor

pennylane/compiler/qjit_api.py#L414

Too many arguments (6/5) (too-many-arguments)
*args,
jaxpr_body_fn,
jaxpr_cond_fn,
body_slice,
cond_slice,
args_slice,
abstract_shapes_slice,
):

jaxpr_consts_body = args[body_slice]
jaxpr_consts_cond = args[cond_slice]
init_state = args[args_slice]

abstract_shapes = args[abstract_shapes_slice]
# If cond_fn(*init_state) is False, return the initial state
fn_res = init_state
while jax.core.eval_jaxpr(jaxpr_cond_fn, jaxpr_consts_cond, *fn_res)[0]:
fn_res = jax.core.eval_jaxpr(jaxpr_body_fn, jaxpr_consts_body, *fn_res)
fn_res = jax.core.eval_jaxpr(
jaxpr_body_fn, jaxpr_consts_body, *abstract_shapes, *fn_res
)

return fn_res

Expand Down Expand Up @@ -461,26 +471,32 @@ def _call_capture_enabled(self, *init_state):

while_loop_prim = _get_while_loop_qfunc_prim()

abstracted_axes, abstract_shapes = qml.capture.determine_abstracted_axes(init_state)

flat_body_fn = FlatFn(self.body_fn)
jaxpr_body_fn = jax.make_jaxpr(flat_body_fn)(*init_state)
jaxpr_cond_fn = jax.make_jaxpr(self.cond_fn)(*init_state)
jaxpr_body_fn = jax.make_jaxpr(flat_body_fn, abstracted_axes=abstracted_axes)(*init_state)
jaxpr_cond_fn = jax.make_jaxpr(self.cond_fn, abstracted_axes=abstracted_axes)(*init_state)

n_bf_c = len(jaxpr_body_fn.consts)
n_cf_c = len(jaxpr_cond_fn.consts)
end_abstract_shapes = -len(abstract_shapes) if abstract_shapes else None
body_consts = slice(0, n_bf_c)
cond_consts = slice(n_bf_c, n_bf_c + n_cf_c)
args_slice = slice(n_cf_c + n_bf_c, None)
args_slice = slice(n_cf_c + n_bf_c, end_abstract_shapes)
abstract_shapes_slice = slice(end_abstract_shapes, None) if abstract_shapes else slice(0, 0)

flat_args, _ = jax.tree_util.tree_flatten(init_state)
results = while_loop_prim.bind(
*jaxpr_body_fn.consts,
*jaxpr_cond_fn.consts,
*flat_args,
*abstract_shapes,
jaxpr_body_fn=jaxpr_body_fn.jaxpr,
jaxpr_cond_fn=jaxpr_cond_fn.jaxpr,
body_slice=body_consts,
cond_slice=cond_consts,
args_slice=args_slice,
abstract_shapes_slice=abstract_shapes_slice,
)
assert flat_body_fn.out_tree is not None, "Should be set when constructing the jaxpr"
return jax.tree_util.tree_unflatten(flat_body_fn.out_tree, results)
Expand Down

0 comments on commit 2c59a64

Please sign in to comment.