GoPJRT (Installing)
GoPJRT leverages OpenXLA to compile, optimize and accelerate numeric computations (with large data) from Go using various backends supported by OpenXLA: CPU, GPUs (Nvidia, AMD ROCm*, Intel*, Apple Metal*) and TPU*. It can be used to power Machine Learning frameworks (e.g. GoMLX), image processing, scientific computation, game AIs, etc.
And because Jax, TensorFlow and optionally PyTorch run on XLA, it is possible to run Jax functions in Go with GoPJRT, and probably TensorFlow and PyTorch as well. See example 2 in xlabuilder/README.md.
(*) Not tested or partially supported by the hardware vendor.
GoPJRT aims to be minimalist and robust: it provides well-maintained, extensible Go wrappers for OpenXLA PJRT.
GoPJRT is not very ergonomic (error handling everywhere), but it is expected to be a stable building block for other projects to create a friendlier API on top, the same way Jax is a friendlier Python API on top of XLA/PJRT.
One such friendlier API co-developed with GoPJRT is GoMLX, a Go machine learning framework. But GoPJRT may be used as a standalone, for lower level access to XLA and other accelerator use cases—like running Jax functions in Go, maybe an "accelerated" image processing or scientific simulation pipeline.
"PJRT" stands for "Pretty much Just another RunTime."
It is the heart of the OpenXLA project: it takes an IR (intermediate representation) of the "computation graph," JIT (Just-In-Time) compiles it (once), and executes it fast (many times). See Google's "PJRT: Simplifying ML Hardware and Framework Integration" blog post.
A "computation graph" is the part of your program (usually vectorial math/machine learning related) that one wants to "accelerate." It must be provided in an IR (intermediate representation) that is understood by the PJRT plugin. A few ways to create the computation graph IR:
- github.com/gomlx/stablehlo: StableHLO is the current preferred IR language for XLA PJRT. This library (co-developed with GoPJRT) is a Go API for building computation graphs in StableHLO that can be directly fed to GoPJRT. See examples below.
- github.com/gomlx/gopjrt/xlabuilder: a Go wrapper library for an XLA C++ library that generates an earlier IR (called MHLO). It is still supported by XLA and by GoPJRT, but it is being deprecated.
- Using Jax, TensorFlow, or PyTorch/XLA: they can output the StableHLO of JIT-compiled functions, which can be fed directly to PJRT (as text). We don't detail this here, but the authors did this a lot during the development of GoPJRT, github.com/gomlx/stablehlo, and github.com/gomlx/gopjrt/xlabuilder for testing.
Note
The IR (intermediate representation) that PJRT plugins accept is text, but it is not human-friendly to read or write. Small programs are debuggable, or can be used to probe which operations are being used behind the scenes, but it is definitely not friendly.
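For illustration, a tiny StableHLO program — a function that adds its input to itself — looks roughly like this (syntax per the StableHLO spec; this sample is hand-written, not generated by GoPJRT):

```mlir
module @example {
  func.func @main(%arg0: tensor<f32>) -> tensor<f32> {
    %0 = stablehlo.add %arg0, %arg0 : tensor<f32>
    return %0 : tensor<f32>
  }
}
```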
A "PJRT Plugin" is a dynamically linked library (a `.so` file on Linux, or a `.dylib` on Darwin).
Typically, there is one plugin per hardware platform you are supporting. E.g., there are PJRT plugins for CPU (Linux/amd64 for now, but it could likely be compiled for other CPUs; SIMD/AVX are well-supported), for TPUs (Google's accelerator), and for GPUs (Nvidia is well-supported; there are AMD and Intel PJRT plugins, but they were not tested). Others are in development.
- Minimalistic example, which assumes you have your StableHLO code in a variable (`[]byte`) called `stablehloCode`:
```go
var flagPluginName = flag.String("plugin", "cuda", "PJRT plugin name or full path")

// ... (error handling omitted for brevity; each err should be checked) ...
plugin, err := pjrt.GetPlugin(*flagPluginName)
client, err := plugin.NewClient(nil)
executor, err := client.Compile().WithStableHLO(stablehloCode).Done()

// Transfer the four scalar inputs to on-device buffers:
inputs := make([]*pjrt.Buffer, 4)
for ii, value := range []float32{minX, minY, maxX, maxY} {
	inputs[ii], err = pjrt.ScalarToBuffer(client, value)
}

outputs, err := executor.Execute(inputs...).Done()
flat, err := pjrt.BufferToArray[float32](outputs[0])
outputs[0].Destroy() // Don't wait for the GC; destroy the buffer immediately.
```
- See the mandelbrot.ipynb notebook for an example building the computation for a Mandelbrot image using `stablehlo`; it includes a sample of the computation's StableHLO IR.

The main package is `github.com/gomlx/gopjrt/pjrt`, which we'll refer to as simply `pjrt`.
The `pjrt` package includes the following main concepts:

- `Plugin`: represents a PJRT plugin. It is created by calling `pjrt.GetPlugin(name)` (where `name` is the name of the plugin). It is the main entry point to the PJRT plugin.
- `Client`: the first thing created after loading a plugin. It seems one can use a single `Client` per plugin; it's not very clear to me why one would create more than one `Client`.
- `LoadedExecutable`: created when one calls `Client.Compile` on a StableHLO program. The program is compiled and optimized for the PJRT target hardware and made ready to run.
- `Buffer`: represents a buffer with the input/output data for the computations in the accelerators. There are methods to transfer it to/from host memory. Buffers are the inputs and outputs of `LoadedExecutable.Execute`.
By default, PJRT plugins are loaded after the program starts (using `dlopen`). But there is also the option to pre-link the CPU PJRT plugin into your program. For that, import (as `_`) one of the following packages:

- `github.com/gomlx/gopjrt/pjrt/cpu/static`: pre-links the CPU PJRT plugin statically, so you don't need to distribute a CPU PJRT plugin with your program. But it's slower to build, potentially taking a few extra (annoying) seconds (static libraries are much slower to link).
- `github.com/gomlx/gopjrt/pjrt/cpu/dynamic`: pre-links the CPU PJRT plugin dynamically (as opposed to loading it after the Go program starts). It is fast to build, but it still requires deploying the PJRT plugin along with your program. Not commonly used, but a possibility.
While it uses CGO to dynamically load the plugin and call its C API, `pjrt` doesn't require anything other than the plugin to be installed.
For now, the project release includes a pre-built CPU plugin for Linux/amd64 only. It has been compiled for Macs before, but I don't have easy access to an Apple Mac to maintain it.
GoPJRT requires a C library installed for XlaBuilder and one or more "PJRT plugin" modules (the thing that actually does the JIT compilation of your computation graph). To make this easier, it provides an interactive and self-explanatory installer (it comes with lots of help messages):

`go run github.com/gomlx/gopjrt/cmd/gopjt_installer`

You can also directly provide the flags you want, to skip the interactive mode (so it can be used in scripts like Dockerfiles).
Note
For now it only works for Linux/amd64 (or Windows+WSL) and Nvidia CUDA. I managed to make it work for Darwin (macOS) before, but not having easy access to a Mac to maintain it, I eventually dropped it. I would also love to support AMD ROCm, but again, I don't have easy access to hardware to test/maintain it. If you feel like contributing or donating hardware/cloud credits, please contact me.
There are also some older bash install scripts under `github.com/gomlx/gopjrt/cmd`, but they are deprecated and will be removed in a few versions. Let me know if you need them.
If you want to build from scratch (both the `xlabuilder` and `pjrt` dependencies), go to the `c/` subdirectory and run `basel.sh`. It uses Bazel due to its dependencies on OpenXLA/XLA. If you are not on one of the supported platforms, you will need to create an `xla_configure.OS_ARCH.bazelrc` file.
See docs/devel.md for hints on how to compile a plugin from OpenXLA/XLA sources.
Also, see this blog post with the link and references to the Intel and Apple hardware plugins.
- When is feature X from PJRT or XlaBuilder going to be supported? Indeed, GoPJRT doesn't wrap everything, although it does cover the most common operations. The simple ops and structs are auto-generated, but many require hand-writing. Please, if a feature would be useful to your project, create an issue; I'm happy to add it. I focused on the needs of GoMLX, but the idea is that GoPJRT can serve other purposes, and I'm happy to support them.
- Why does PJRT spit out so many logs? Can we disable it?
This is a great question... imagine if every library we used decided it also wanted to clutter our stderr? I have an open question in Abseil about it. It may be an issue with Abseil Logging, which also has the separate problem of not allowing two different linked programs/libraries to call its initialization (see Issue #1656).
A hacky workaround is to duplicate fd 2, assign the duplicate to Go's `os.Stderr`, and then close fd 2, so PJRT plugins have nowhere to log. This hack is implemented in the function `pjrt.SuppressAbseilLoggingHack()`: call it before calling `pjrt.GetPlugin`. But it may have unintended consequences if some other library depends on fd 2 to work, or if a real exceptional situation needs to be reported and isn't.
Discussion in the Slack channel #gomlx (you can join the slack server here).
Environment variables that help control or debug how GoPJRT works:

- `PJRT_PLUGIN_LIBRARY_PATH`: path to search for PJRT plugins. GoPJRT also searches in `/usr/local/lib/gomlx/pjrt`, `${HOME}/.local/lib/gomlx/pjrt`, in the standard library paths for the system, and in the paths defined in `$LD_LIBRARY_PATH`.
- `XLA_FLAGS`: used by the C++ PJRT plugins. Documentation is linked from the Jax XLA_FLAGS page, but I found it easier to just set this to `--help` and have it print out the flags.
- `XLA_DEBUG_OPTIONS`: if set, it is parsed as a `DebugOptions` proto that is passed during the JIT compilation (`Client.Compile()`) of a computation graph. It is not documented how this works in PJRT (e.g., I observed a great slowdown when it is set, even to the default values), but the proto has some documentation.
- `GOPJRT_INSTALL_DIR` and `GOPJRT_NOSUDO`: used by the installation scripts; see the "Installing" section above.
- Google Drive Directory with Design Docs: Some links are outdated or redirected, but invaluable information.
- How to use the PJRT C API? #openxla/xla/issues/7038: discussion of folks trying to use PJRT in their projects. Some examples leveraging some of the XLA C++ library.
- How to use PJRT C API v.2 #openxla/xla/issues/7038.
- PJRT C API README.md: a collection of links to other documents.
- Public Design Document.
- Gemini helped quite a bit in parsing and understanding things—despite the hallucinations—other AIs may help as well.
All tests support the following build tags to pre-link the CPU plugin (as opposed to `dlopen`-ing it); select at most one of them:

- `--tags pjrt_cpu_static`: link (preload) the CPU PJRT plugin from the static library (`.a`) version. Slowest to build (but executes at the same speed).
- `--tags pjrt_cpu_dynamic`: link (preload) the CPU PJRT plugin from the dynamic library (`.so`) version. Faster to build, but deployments require shipping the `libpjrt_c_api_cpu_dynamic.so` file along.

For Darwin (macOS), static linking is currently hardcoded, so avoid using these tags there.
This project uses the following components from the OpenXLA project:

- A (slightly modified) copy of OpenXLA's `pjrt_c_api.h` file.
- OpenXLA PJRT CPU Plugin: enables execution of XLA computations on the CPU.
- OpenXLA PJRT CUDA Plugin: enables execution of XLA computations on NVIDIA GPUs.

We gratefully acknowledge the OpenXLA team for their valuable work in developing and maintaining these plugins.
GoPJRT is licensed under the Apache 2.0 license.
The OpenXLA project, including the `pjrt_c_api.h` file and the CPU and CUDA plugins, is licensed under the Apache 2.0 license.
The CUDA plugin also uses the Nvidia CUDA Toolkit, which is subject to Nvidia's licensing terms and must be installed by the user.
For more information about OpenXLA, please visit their website at openxla.org, or the GitHub page at github.com/openxla/xla.