Skip to content

Commit

Permalink
Updating terminology in README
Browse files Browse the repository at this point in the history
  • Loading branch information
dfm committed Nov 3, 2023
1 parent 06436b5 commit 6587a1c
Showing 1 changed file with 19 additions and 40 deletions.
59 changes: 19 additions & 40 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -142,11 +142,8 @@ wanted to emphasize a few points to consider.
The files in this repo come in three categories:

1. In the root directory, there are the standard packaging files like a
`setup.py` and `pyproject.toml`. Most of this setup is pretty standard, but
I'll highlight some of the unique elements in the packaging section below.
For example, we'll use a slightly strange combination of PEP-517/518 and
CMake to build the extensions. This isn't strictly necessary, but it's the
easiest packaging setup that I've been able to put together.
`pyproject.toml`. Most of this setup is pretty standard, but
I'll highlight some unique elements in the packaging section below.

2. Next, the `src/kepler_jax` directory is a Python module with the definition
of our JAX primitive roughly following the JAX [How primitives
Expand Down Expand Up @@ -260,17 +257,12 @@ pybind11 since that's what I'm most familiar with. The [LAPACK ops in
jaxlib][jaxlib-lapack] are implemented using Cython if you'd like to see an
example of how to do that.

Another choice that I've made is to use [CMake](https://cmake.org) to build the
extensions. It would be totally possible (and perhaps preferable if you only
support CPU usage) to stick to just using setuptools directly, but setuptools
doesn't seem to have great support for compiling CUDA extensions so that's why I
settled on CMake. In the end, it's not too painful since CMake can be included
as a build dependency in `pyproject.toml` so users won't have to install it
separately. Another build option would be to use [bazel](https://bazel.build) to
compile the code, like the JAX project, but I don't have any experience with it
so I decided to stick with what I know. _The key point is that we're just
compiling a regular old Python module so you can use whatever infrastructure
you're familiar with!_
Another choice that I've made is to use [scikit-build-core](scikit-build-core)
and [CMake](https://cmake.org) to build the extensions. Another build option
would be to use [bazel](https://bazel.build) to compile the code, like the JAX
project, but I don't have any experience with it, so I decided to stick with
what I know. _The key point is that we're just compiling a regular old Python
module, so you can use whatever infrastructure you're familiar with!_

With these choices out of the way, the boilerplate code required to define the
interface is, using the `cpu_kepler` function defined in the previous section as
Expand Down Expand Up @@ -305,20 +297,13 @@ this.
With that out of the way, the actual build routine is defined in the following
files:
- In `./pyproject.toml`, we specify that `pybind11` and `cmake` are required
build dependencies and that we'll use `setuptools.build_meta` as the build
backend.
- In `./pyproject.toml`, we specify that `pybind11` and `scikit-build-core` are
required build dependencies and that we'll use `scikit-build-core` as the
build backend.
- `setup.py` is a pretty typical setup file with a custom class for building the
extensions that executes CMake for the actual compilation step. This does
include some extra configuration arguments for CMake to make sure that it uses
the correct Python libraries and installs the compiled objects to the right
place. It might be possible to use something like [scikit-build][scikit-build]
to replace this step, but I struggled to get it working.
- Finally, `CMakeLists.txt` defines the build process for CMake using
[pybind11's support for CMake builds][pybind11-cmake]. This will also,
optionally, build the GPU ops as discussed below.
- Then, `CMakeLists.txt` defines the build process for CMake using [pybind11's
support for CMake builds][pybind11-cmake]. This will also, optionally, build
the GPU ops as discussed below.
With these files in place, we can now compile our XLA custom call ops using
Expand Down Expand Up @@ -605,18 +590,12 @@ __global__ void kepler_kernel(
## Building & packaging for the GPU

Since we're already using CMake to build our project, it's not too hard to add
support for CUDA. I've chosen to enable GPU builds by the environment variable
`KEPLER_JAX_CUDA=yes` that you'll see in both `setup.py` and `CMakeLists.txt`.
Other than conditionally adding an `Extension` in `setup.py`, everything else on
the Python side is the same. In `CMakeLists.txt`, we also add a conditional:
support for CUDA. I've chosen to enable GPU builds whenever CMake can detect
CUDA support using `CheckLanguage` in `CMakelists.txt`:

```cmake
if (KEPLER_JAX_CUDA)
enable_language(CUDA)
# ...
else()
message(STATUS "Building without CUDA")
endif()
include(CheckLanguage)
check_language(CUDA)
```

Then, to expose this to JAX, we need to update the translation rule from above as follows:
Expand Down Expand Up @@ -700,6 +679,6 @@ Colab:
[kepler-h]: https://github.com/dfm/extending-jax/blob/main/lib/kepler.h
[capsule]: https://docs.python.org/3/c-api/capsule.html "Capsules"
[jaxlib-lapack]: https://github.com/google/jax/blob/master/jaxlib/lapack.pyx "jax/lapack.pyx"
[scikit-build]: https://scikit-build.readthedocs.io/ "scikit-build"
[scikit-build-core]: https://github.com/scikit-build/scikit-build-core "scikit-build-core"
[pybind11-cmake]: https://pybind11.readthedocs.io/en/stable/compiling.html#building-with-cmake "Building with CMake"
[exoplanet-tutorial]: https://docs.exoplanet.codes/en/stable/tutorials/intro-to-pymc3/#A-more-realistic-example:-radial-velocity-exoplanets "A more realistic example: radial velocity exoplanets"

0 comments on commit 6587a1c

Please sign in to comment.