Examples of Running a JAX function in C++ #5337
I'll show an example for a simple JAX function. Let me know if it suffices for your use case. There are two steps involved here:

Step 1: Use jax/tools/jax_to_hlo.py to save a JAX program.

Suppose we have a dummy JAX program:

import jax.numpy as jnp

def fn(x, y, z):
  return jnp.dot(x, y) / z

Let's convert it to HLO, with input shapes and constants provided (see the usage instructions in jax_to_hlo.py):

$ python3 jax_to_hlo.py \
  --fn prog.fn \
  --input_shapes '[("x", "f32[2,2]"), ("y", "f32[2,2]")]' \
  --constants '{"z": 2.0}' \
  --hlo_text_dest /tmp/fn_hlo.txt \
  --hlo_proto_dest /tmp/fn_hlo.pb

Pay special attention to the order of parameters specified in --input_shapes. Let's look at the saved HloModule:

$ cat /tmp/fn_hlo.txt
HloModule xla_computation_ordered_wrapper.9
ENTRY xla_computation_ordered_wrapper.9 {
constant.3 = pred[] constant(false)
parameter.1 = f32[2,2]{1,0} parameter(0)
parameter.2 = f32[2,2]{1,0} parameter(1)
dot.4 = f32[2,2]{1,0} dot(parameter.1, parameter.2), lhs_contracting_dims={1}, rhs_contracting_dims={0}
constant.5 = f32[] constant(2)
broadcast.6 = f32[2,2]{1,0} broadcast(constant.5), dimensions={}
divide.7 = f32[2,2]{1,0} divide(dot.4, broadcast.6)
ROOT tuple.8 = (f32[2,2]{1,0}) tuple(divide.7)
}

Note the single output with shape (f32[2,2]).

Step 2: Use the PJRT runtime to run the saved HloModule with user-provided input values.

Note that there are multiple C++ runtime APIs that we can use to run an HloModule -- HloRunner, LocalClient/LocalExecutable, or PJRT -- which are all in the TensorFlow tree. Since I'm most familiar with JAX's runtime API, I'll show an example using PJRT (see tensorflow/compiler/xla/pjrt/pjrt_client.h). Suppose we have the following BUILD and cc files.
The BUILD file looks like

The cc file looks like

// An example of reading an HloModule from an HloProto file and executing the
// module on a PJRT CPU client.
#include <memory>
#include <string>
#include <vector>
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/pjrt/cpu_device.h"
#include "tensorflow/compiler/xla/pjrt/pjrt_client.h"
#include "tensorflow/compiler/xla/status.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/tools/hlo_module_loader.h"
#include "tensorflow/core/platform/init_main.h"
#include "tensorflow/core/platform/logging.h"
int main(int argc, char** argv) {
  tensorflow::port::InitMain("", &argc, &argv);
  // Load HloModule from file.
  std::string hlo_filename = "/tmp/fn_hlo.txt";
  std::function<void(xla::HloModuleConfig*)> config_modifier_hook =
      [](xla::HloModuleConfig* config) { config->set_seed(42); };
  std::unique_ptr<xla::HloModule> test_module =
      LoadModuleFromFile(hlo_filename, xla::hlo_module_loader_details::Config(),
                         "txt", config_modifier_hook)
          .ValueOrDie();
  const xla::HloModuleProto test_module_proto = test_module->ToProto();
  // Run it using JAX C++ Runtime (PJRT).
  // Get a CPU client.
  std::unique_ptr<xla::PjRtClient> client =
      xla::GetCpuClient(/*asynchronous=*/true).ValueOrDie();
  // Compile XlaComputation to PjRtExecutable.
  xla::XlaComputation xla_computation(test_module_proto);
  xla::CompileOptions compile_options;
  std::unique_ptr<xla::PjRtExecutable> executable =
      client->Compile(xla_computation, compile_options).ValueOrDie();
  // Prepare inputs.
  xla::Literal literal_x =
      xla::LiteralUtil::CreateR2<float>({{1.0f, 2.0f}, {3.0f, 4.0f}});
  xla::Literal literal_y =
      xla::LiteralUtil::CreateR2<float>({{1.0f, 1.0f}, {1.0f, 1.0f}});
  std::unique_ptr<xla::PjRtBuffer> param_x =
      client->BufferFromHostLiteral(literal_x, client->local_devices()[0])
          .ValueOrDie();
  std::unique_ptr<xla::PjRtBuffer> param_y =
      client->BufferFromHostLiteral(literal_y, client->local_devices()[0])
          .ValueOrDie();
  // Execute on CPU.
  xla::ExecuteOptions execute_options;
  // One vector<buffer> for each device.
  std::vector<std::vector<std::unique_ptr<xla::PjRtBuffer>>> results =
      executable->Execute({{param_x.get(), param_y.get()}}, execute_options)
          .ValueOrDie();
  // Get result.
  std::shared_ptr<xla::Literal> result_literal =
      results[0][0]->ToLiteral().ValueOrDie();
  LOG(INFO) << "result = " << *result_literal;
  return 0;
}
To run it:

$ bazel run -c opt :main
2021-01-07 17:30:23.472798: I tensorflow/compiler/xla/examples/jax_cpp/main.cc:69] result = (
f32[2,2] {
  { 1.5, 1.5 },
  { 3.5, 3.5 }
}
)
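If you want to pull individual values out of the result rather than just printing it, something like this should work (a small sketch using xla::Literal::Get, reusing result_literal from above; the indices are just for illustration):

// Read a single element of the f32[2,2] output (row 0, column 1 here).
float v01 = result_literal->Get<float>({0, 1});
LOG(INFO) << "result[0][1] = " << v01;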
Do you think it makes sense to link this from the FAQ?
Awesome! This will be very helpful, thank you. I only have one question: if I understand correctly, the arguments to the function (e.g. network weights) can either be passed in as parameters at execution time or baked into the HLO as constants. Is one of these approaches preferred, or would they be more or less equivalent in terms of performance? Thanks again!
@gnecula we could turn this into a developer doc to make it more discoverable?

@drebain For the question on NN weights as constants or parameters: I believe making them constants will improve dispatch performance, but there may be issues with increased memory consumption. I would recommend trying both and measuring performance and memory consumption. Do let me know if you find any part of the example unclear, so I can make a developer doc that others can use (running a JAX program via the C++ API seems to have come up a few times).

Just a clarification on the Execute API:

// Executes on devices addressable by the client. Requires executable has a
// device_assignment and all devices in the device_assignment are addressable
// by the client.
virtual StatusOr<std::vector<std::vector<std::unique_ptr<PjRtBuffer>>>>
Execute(absl::Span<const std::vector<PjRtBuffer*>> argument_handles,
        const ExecuteOptions& options) = 0;

The inner vector corresponds to the parameters of the HLO computation. It's a Span<vector> because there is one set of parameters per device. If it's only a single device, then you just pass in {{param_x.get(), param_y.get()}}. Note the double braces.
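To make the nesting concrete, here is a minimal sketch of the single-device case, reusing executable, execute_options, param_x, and param_y from the CPU example above (a sketch only, not a separate API):

// Outer vector: one entry per device; inner vector: that device's parameter
// buffers, in HLO parameter order (parameter(0) = x, parameter(1) = y).
std::vector<std::vector<xla::PjRtBuffer*>> args = {
    {param_x.get(), param_y.get()},  // device 0
};
std::vector<std::vector<std::unique_ptr<xla::PjRtBuffer>>> results =
    executable->Execute(args, execute_options).ValueOrDie();
// results[0] holds the outputs produced on device 0.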
@drebain A bit of nuance on constants vs parameters: there seem to be some nontrivial tradeoffs (which may depend on the platform too). Some XLA experts recommend turning small arrays or scalars into constants while staying away from large constants. Unfortunately, there's no easy way to experiment with this, so you will have to play around with it to know.
I see. I guess it makes sense that this would depend on the size of the constants. I will try both on a medium-sized MLP network and report back on which is faster for future reference.
Ok, as promised I have done some basic benchmarks and it seems that compiling network weights in as constants gets me a roughly 30% performance increase over passing them as arguments. This is with a 6-layer, 256-unit MLP, running on an RTX 3090.
For GPU, you may need the following changes:

"//tensorflow/compiler/xla/pjrt:cpu_device",
PJRT is the lowest-level runtime API, and it is only advisable if one wants to avoid direct TF deps. Otherwise, please use jax2tf and the usual TF C++ server for running a SavedModel: https://github.com/google/jax/tree/main/jax/experimental/jax2tf
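For the jax2tf + SavedModel route, the C++ side is the standard SavedModel loading flow. Below is a rough sketch; the export directory /tmp/jax_model and the tensor names "serving_default_x:0" / "StatefulPartitionedCall:0" are placeholders and depend entirely on how the model was exported.

#include <vector>

#include "tensorflow/cc/saved_model/loader.h"
#include "tensorflow/cc/saved_model/tag_constants.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/platform/logging.h"

int main() {
  // Load a SavedModel previously produced via jax2tf + tf.saved_model.save.
  tensorflow::SavedModelBundle bundle;
  tensorflow::SessionOptions session_options;
  tensorflow::RunOptions run_options;
  tensorflow::Status status = tensorflow::LoadSavedModel(
      session_options, run_options, "/tmp/jax_model",
      {tensorflow::kSavedModelTagServe}, &bundle);
  if (!status.ok()) {
    LOG(ERROR) << "Failed to load SavedModel: " << status;
    return 1;
  }

  // Build a 2x2 f32 input tensor.
  tensorflow::Tensor x(tensorflow::DT_FLOAT, tensorflow::TensorShape({2, 2}));
  auto flat = x.flat<float>();
  flat(0) = 1.0f;
  flat(1) = 2.0f;
  flat(2) = 3.0f;
  flat(3) = 4.0f;

  // Feed/fetch names must match the exported serving signature.
  std::vector<tensorflow::Tensor> outputs;
  status = bundle.session->Run({{"serving_default_x:0", x}},
                               {"StatefulPartitionedCall:0"}, {}, &outputs);
  if (!status.ok()) {
    LOG(ERROR) << "Session::Run failed: " << status;
    return 1;
  }
  LOG(INFO) << "result = " << outputs[0].DebugString();
  return 0;
}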
Thank you for the script, it's very useful! Do you know how to cast a
@zhangqiaorjc Thank you for the example, it's very useful! I tried it using the newest codebase, but the output is

2022-06-18 22:26:10.013679: I jax_cpp/main.cc:68] result = (
f32[2,2] {
  { 2, 3 },
  { 2, 3 }
}
)

It looks like the input and output matrices are being interpreted in column-major order rather than row-major. Did this API change recently? Is there any build option to switch to row-major matrices like in Python?
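One thing I plan to experiment with (a sketch only, assuming xla::LiteralUtil::CreateR2WithLayout and xla::LayoutUtil::MakeLayout are available in the checkout, and it may or may not be the cause here) is giving the input literals an explicit row-major layout:

#include "tensorflow/compiler/xla/layout_util.h"
#include "tensorflow/compiler/xla/literal_util.h"

// Force a row-major layout (minor-to-major {1, 0}) on the 2x2 input so the
// host data is interpreted the same way as in the Python example.
xla::Literal literal_x = xla::LiteralUtil::CreateR2WithLayout<float>(
    {{1.0f, 2.0f}, {3.0f, 4.0f}}, xla::LayoutUtil::MakeLayout({1, 0}));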
I'm trying to determine to what extent something like what @zhangqiaorjc shared above could be used to deploy JAX code into real-time/low-memory/restricted (e.g. no dynamic memory allocation) environments. For example, would it be possible to "pre-compute" Does the
^ This is discussed here: #22184
I am trying to build the JAX C++ example for the GPU backend. It runs into a linking error. The BUILD file is as follows

Build commands

It runs into the following linking error
I tried to use the example above (#5337 (comment)), but in TensorFlow 2.17 the header
Are there any examples available of running jit functions defined in Python from C++? I see that there is an interface for generating something usable by XLA, but it is a bit unclear how to use the result of this when the function is dependent on variables/weights (e.g. a flax module). @zhangqiaorjc I am told you have some knowledge of this?

Thanks