Skip to content

Commit

Permalink
Expand device_put benchmarks to run with different numbers of array…
Browse files Browse the repository at this point in the history
…s and input types

For the upcoming batching changes for `device_put`, it is useful to benchmark `device_put` with varying numbers of arrays.

PiperOrigin-RevId: 641716268
  • Loading branch information
junwhanahn authored and jax authors committed Jun 9, 2024
1 parent a8246ea commit 6617a0d
Showing 1 changed file with 22 additions and 3 deletions.
25 changes: 22 additions & 3 deletions benchmarks/api_benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -678,10 +678,29 @@ def host_local_array_to_global_array(state):
(input_data, input_data), global_mesh, (in_pspec, in_pspec))

@google_benchmark.register
def device_put(state):
x = np.array(1, np.int32)
@google_benchmark.option.arg_names(['num_args'])
@google_benchmark.option.args([1])
@google_benchmark.option.args([10])
@google_benchmark.option.args([100])
@google_benchmark.option.args([1000])
def device_put_from_numpy_array(state):
x = [np.array(1, np.int32)] * state.range(0)
while state:
_ = jax.device_put(x).block_until_ready()
_ = jax.block_until_ready(jax.device_put(x))


@google_benchmark.register
@google_benchmark.option.arg_names(['num_args'])
@google_benchmark.option.args([1])
@google_benchmark.option.args([10])
@google_benchmark.option.args([100])
@google_benchmark.option.args([1000])
def device_put_from_jax_array(state):
x = [np.array(1, np.int32)] * state.range(0)
x = jax.block_until_ready(jax.device_put(x, device=jax.devices()[0]))
d = jax.devices()[1]
while state:
_ = jax.block_until_ready(jax.device_put(x, device=d))


@google_benchmark.register
Expand Down

0 comments on commit 6617a0d

Please sign in to comment.