Add jax.lax example
tomwhite committed Oct 11, 2024
1 parent 134f15c commit df5f7a1
Showing 1 changed file with 152 additions and 24 deletions.
176 changes: 152 additions & 24 deletions benchmark-numba-vs-jax.ipynb
@@ -2,7 +2,7 @@
"cells": [
{
"cell_type": "markdown",
"id": "4551c294",
"id": "477c6d2e",
"metadata": {},
"source": [
"# Sgkit: benchmarking Numba vs JAX\n",
@@ -13,7 +13,7 @@
{
"cell_type": "code",
"execution_count": 1,
"id": "54ee793e",
"id": "c040f675",
"metadata": {},
"outputs": [],
"source": [
@@ -23,7 +23,7 @@
},
{
"cell_type": "markdown",
"id": "2f20edb5",
"id": "4ea0650a",
"metadata": {},
"source": [
"## Numba\n",
@@ -34,7 +34,7 @@
{
"cell_type": "code",
"execution_count": 2,
"id": "0dbae777",
"id": "c9365d52",
"metadata": {},
"outputs": [],
"source": [
@@ -44,7 +44,7 @@
{
"cell_type": "code",
"execution_count": 3,
"id": "ad31bde3",
"id": "e19d0a3c",
"metadata": {},
"outputs": [],
"source": [
@@ -54,17 +54,17 @@
{
"cell_type": "code",
"execution_count": 4,
"id": "ffc03e46",
"id": "ee1b2ccd",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"n_variant: 10000, n_sample: 100, time: 0.003943204879760742\n",
"n_variant: 100000, n_sample: 1000, time: 0.13063406944274902\n",
"n_variant: 20000, n_sample: 5000, time: 0.21017789840698242\n",
"n_variant: 10000, n_sample: 10000, time: 0.2675461769104004\n"
"n_variant: 10000, n_sample: 100, time: 0.0037958621978759766\n",
"n_variant: 100000, n_sample: 1000, time: 0.1375119686126709\n",
"n_variant: 20000, n_sample: 5000, time: 0.20614981651306152\n",
"n_variant: 10000, n_sample: 10000, time: 0.22558307647705078\n"
]
}
],
@@ -85,7 +85,7 @@
},
{
"cell_type": "markdown",
"id": "13db463d",
"id": "365b72f0",
"metadata": {},
"source": [
"## JAX\n",
@@ -96,7 +96,7 @@
{
"cell_type": "code",
"execution_count": 5,
"id": "f1f42b1e",
"id": "07ff3a47",
"metadata": {},
"outputs": [],
"source": [
@@ -107,7 +107,7 @@
{
"cell_type": "code",
"execution_count": 6,
"id": "0f1bf161",
"id": "fa22b278",
"metadata": {},
"outputs": [],
"source": [
@@ -122,7 +122,7 @@
},
{
"cell_type": "markdown",
"id": "d0a429a4",
"id": "ececfa1a",
"metadata": {},
"source": [
"The user-level function is very similar to the Numba version. The main difference is that we use JAX's `vmap` and `jit` functions to vectorize and compile the `count_alleles_jax` function."
@@ -131,7 +131,7 @@
{
"cell_type": "code",
"execution_count": 7,
"id": "a87c28db",
"id": "b5517191",
"metadata": {},
"outputs": [],
"source": [
@@ -150,7 +150,7 @@
{
"cell_type": "code",
"execution_count": 8,
"id": "dc37ee2f",
"id": "578adcd8",
"metadata": {},
"outputs": [],
"source": [
@@ -200,17 +200,17 @@
{
"cell_type": "code",
"execution_count": 9,
"id": "de81f08a",
"id": "b6deaa29",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"n_variant: 10000, n_sample: 100, time: 0.027678966522216797\n",
"n_variant: 100000, n_sample: 1000, time: 4.066641092300415\n",
"n_variant: 20000, n_sample: 5000, time: 4.6699299812316895\n",
"n_variant: 10000, n_sample: 10000, time: 4.4864161014556885\n"
"n_variant: 10000, n_sample: 100, time: 0.028610944747924805\n",
"n_variant: 100000, n_sample: 1000, time: 3.8228919506073\n",
"n_variant: 20000, n_sample: 5000, time: 3.199921131134033\n",
"n_variant: 10000, n_sample: 10000, time: 3.4262049198150635\n"
]
}
],
@@ -231,7 +231,7 @@
},
{
"cell_type": "markdown",
"id": "ed602a14",
"id": "8285d789",
"metadata": {},
"source": [
"The JAX version is a lot slower than Numba - over an order of magnitude slower for the last three results.\n",
@@ -241,10 +241,138 @@
"But it's not clear if there is something wrong with the JAX code or whether it can't do as well as Numba for this problem."
]
},
{
"cell_type": "markdown",
"id": "48c04502",
"metadata": {},
"source": [
"## LAX\n",
"\n",
"JAX has some lower-level primitives in the `jax.lax` module that might be suitable for this problem. In particular, `jax.lax.scan` can be used to implement a `count_alleles` function without using `jax.numpy` operations. Would this be more efficient?"
]
},
{
"cell_type": "code",
"execution_count": 10,
"id": "bdef139a",
"id": "990c8e52",
"metadata": {},
"outputs": [],
"source": [
"from jax import lax\n",
"\n",
"def _count_alleles(res, el):\n",
"    res = res.at[el].add(1)\n",
"    return res, None\n",
"\n",
"def count_alleles_lax(g, out):\n",
"    counts, _ = lax.scan(_count_alleles, out, g)\n",
"    return counts"
]
},
{
"cell_type": "markdown",
"id": "785de6bf",
"metadata": {},
"source": [
"Note that we pass in the output array like Numba does, rather than allocating it in the loop (like `jax.numpy` does above)."
]
},
{
"cell_type": "code",
"execution_count": 11,
"id": "23696f46",
"metadata": {},
"outputs": [],
"source": [
"def count_call_alleles_lax(\n",
"    ds: Dataset,\n",
"    *,\n",
"    call_genotype: Hashable = variables.call_genotype,\n",
"    merge: bool = True,\n",
") -> Dataset:\n",
"    variables.validate(ds, {call_genotype: variables.call_genotype_spec})\n",
"    n_alleles = ds.sizes[\"alleles\"]\n",
"    G = da.asarray(ds[call_genotype])\n",
"    if G.numblocks[2] > 1:\n",
"        raise ValueError(\n",
"            f\"Variable {call_genotype} must have only a single chunk in the ploidy dimension. \"\n",
"            \"Consider rechunking to change the size of chunks.\"\n",
"        )\n",
"    shape = (G.chunks[0], G.chunks[1], n_alleles)\n",
"\n",
"    # call vmap twice to vectorize over first two dimensions (variants, samples)\n",
"    count_alleles_vectorized = jax.vmap(jax.vmap(count_alleles_lax))\n",
"\n",
"    # jit compile\n",
"    count_alleles_vectorized_jit = jax.jit(count_alleles_vectorized)\n",
"\n",
"    # precompile...\n",
"    count_alleles_vectorized_jit(np.ones((4, 4, 2), dtype=np.int8), np.zeros((4, 4, 2), dtype=np.int8)).block_until_ready()\n",
"\n",
"    # zero-initialize the output array: lax.scan accumulates the counts into it\n",
"    N = np.zeros((G.chunks[0][0], G.chunks[1][0], n_alleles), dtype=np.uint8)\n",
"    new_ds = create_dataset(\n",
"        {\n",
"            variables.call_allele_count: (\n",
"                (\"variants\", \"samples\", \"alleles\"),\n",
"                da.map_blocks(\n",
"                    count_alleles_vectorized_jit,\n",
"                    G,\n",
"                    N,\n",
"                    chunks=shape,\n",
"                    dtype=np.uint8,\n",
"                    drop_axis=2,\n",
"                    new_axis=2,\n",
"                ),\n",
"            )\n",
"        }\n",
"    )\n",
"    return conditional_merge_datasets(ds, new_ds, merge)"
]
},
{
"cell_type": "code",
"execution_count": 12,
"id": "8326ecb6",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"n_variant: 10000, n_sample: 100, time: 0.02508997917175293\n",
"n_variant: 100000, n_sample: 1000, time: 2.639747142791748\n",
"n_variant: 20000, n_sample: 5000, time: 2.751932144165039\n",
"n_variant: 10000, n_sample: 10000, time: 2.7614219188690186\n"
]
}
],
"source": [
"for n_variant, n_sample in matrix:\n",
"    ds = sg.simulate_genotype_call_dataset(\n",
"        n_variant=n_variant, n_sample=n_sample, missing_pct=0.0\n",
"    )\n",
"    ds = ds.chunk({\"variants\": 10000, \"samples\": 1000})\n",
"\n",
"    ds = count_call_alleles_lax(ds)\n",
"    start = time.time()\n",
"    ds = ds.load()\n",
"    end = time.time()\n",
"\n",
"    print(f\"n_variant: {n_variant}, n_sample: {n_sample}, time: {end - start}\")"
]
},
{
"cell_type": "markdown",
"id": "23768bcb",
"metadata": {},
"source": [
"The LAX version is a bit faster than the JAX NumPy version (about 2.6 to 2.8 seconds versus 3.2 to 3.8 seconds for the three larger datasets), but not by much, and it is still around an order of magnitude slower than Numba."
]
},
{
"cell_type": "code",
"execution_count": 13,
"id": "d008de7f",
"metadata": {},
"outputs": [
{
@@ -453,7 +581,7 @@
{
"cell_type": "code",
"execution_count": null,
"id": "cdd17f41",
"id": "8e6c899d",
"metadata": {},
"outputs": [],
"source": []
Expand Down
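
For readers who want to try the `jax.lax.scan` approach outside the notebook, here is a minimal standalone sketch that does not depend on sgkit or Dask. The names `_count_allele`, `count_alleles_scan` and `count_alleles_batched`, and the toy genotype array, are illustrative assumptions rather than part of this commit. The masking of negative (missing) calls is also an addition: the notebook benchmark uses `missing_pct=0.0` and does not need it.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax import lax


def _count_allele(counts, allele):
    # Scan step: bump the bin for this allele index.
    # Assumption (not in the commit): a negative index marks a missing call,
    # so it contributes 0 instead of wrapping around to the last bin.
    idx = allele.astype(jnp.int32)
    increment = jnp.where(idx >= 0, 1, 0).astype(counts.dtype)
    return counts.at[idx].add(increment), None


def count_alleles_scan(g, out):
    # g: (ploidy,) allele indices for one call.
    # out: (n_alleles,) zero-initialized counts, used as the scan carry.
    counts, _ = lax.scan(_count_allele, out, g)
    return counts


# vmap twice to map over the variants and samples dimensions, then jit.
count_alleles_batched = jax.jit(jax.vmap(jax.vmap(count_alleles_scan)))

# Hypothetical toy data: 2 variants x 2 samples x ploidy 2, one missing call.
G = np.array([[[0, 1], [1, 1]], [[0, 0], [-1, 1]]], dtype=np.int8)
out = np.zeros((2, 2, 2), dtype=np.uint8)

counts = np.asarray(count_alleles_batched(G, out))
print(counts)
```

When timing code like this, keep in mind that `jax.jit` specializes the compiled function on input shapes and dtypes, so a warm-up call only removes compilation time for later calls whose inputs have the same shapes and dtypes.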