
Failed precondition error when trying to compile hlo #5544

Closed
maxstupp opened this issue Jan 28, 2021 · 9 comments

@maxstupp

Hello,

I am currently trying to run a jax function from C++, similar to #5337 and #2766. However, when I try to compile the HLO file produced by jax_to_hlo.py, I get the following error:

2021-01-28 17:14:23.168398: I tensorflow/core/platform/profile_utils/cpu_utils.cc:112] CPU Frequency: 3699990000 Hz
2021-01-28 17:14:23.171305: I tensorflow/compiler/xla/service/service.cc:169] XLA service 0x5642643a24f0 initialized for platform Host (this does not guarantee that XLA will be used). Devices:
2021-01-28 17:14:23.171340: I tensorflow/compiler/xla/service/service.cc:177]   StreamExecutor device (0): Host, Default Version
2021-01-28 17:14:23.176602: F tensorflow/stream_executor/lib/statusor.cc:34] Attempting to fetch value instead of handling error Failed precondition: Expected comparison type SIGNED.
actual: UNSIGNED
operand: s64[]

I tried to run the HloModule with PJRT and with LocalClient/LocalExecutable and got the same error.

The function I'm trying to run in C++ updates the neighbor list and returns the updated indices. The example is mostly taken from this jax-md notebook: https://colab.research.google.com/github/google/jax-md/blob/master/notebooks/nve_neighbor_list.ipynb

@jit
def neighbor_update(R, nbrs):
    nbrs = neighbor_fn(R, nbrs)
    return nbrs.idx

I have attached the Python file and the produced hlo.txt file. I used R as input and nbrs as a constant when exporting the HLO (I also tried both as constants and both as global parameters inside the function, but got the same result).

The code to compile the HLO is mostly taken from #2766.
I build and run the C++ file with bazel.

#include <memory>
#include <string>
#include <vector>
#include "tensorflow/compiler/xla/tools/hlo_module_loader.h"
#include "tensorflow/core/platform/init_main.h"
#include "tensorflow/compiler/xla/service/platform_util.h"
#include "tensorflow/compiler/xla/client/client_library.h"

int main(int argc, char** argv) {
    tensorflow::port::InitMain("", &argc, &argv);
    using namespace xla;

    std::string pathToModel = "/tmp/nbrs_hlo.txt";
    se::Platform* platform =
            xla::PlatformUtil::GetPlatform("cpu").ValueOrDie();

    std::unique_ptr<xla::HloModule> module = xla::LoadModuleFromFile(pathToModel).ValueOrDie();
    xla::LocalClient* client = xla::ClientLibrary::GetOrCreateLocalClient(platform).ValueOrDie();
    xla::Shape state_shape = xla::ShapeUtil::MakeShape(xla::F64, {6400,2});
 
    std::vector<std::unique_ptr<xla::LocalExecutable>> linearizeExecutable =
        client->Compile(xla::XlaComputation(module->ToProto()),
                        {&state_shape},
                        xla::ExecutableBuildOptions()).ValueOrDie();
   
    return 0;
}

When I manually change type=UNSIGNED to SIGNED or FLOAT inside the hlo.txt file, the file compiles, but (as expected) gives the wrong result.

Is this a bug in the HLO pipeline? Any help is appreciated!
Thanks in advance!

HLOerror.zip

@skye
Member

skye commented Jan 28, 2021

Can you provide the jax_to_hlo.py command you're using to generate the HLO?

@skye skye self-assigned this Jan 28, 2021
@maxstupp
Author

maxstupp commented Jan 29, 2021

Of course!

I started with R as input and nbrs as a global parameter inside the function, and used jax_to_hlo.py from the command line:

python3 jax_to_hlo.py --fn neighborlist_test.neighbor_update --input_shapes '[("R", "f64[6400,2]")]' --hlo_text_dest /path_to_hlo/nbrs_hlo.txt

Then I tried to pass nbrs as a constant and called jax_to_hlo from inside a Python file:

import os 
from ast import literal_eval
from jax.lib import xla_client
import jax.numpy as jnp
from jax.tools.jax_to_hlo import jax_to_hlo
from neighborlist_test import neighbor_update, nbrs

hlo_text_dest = '/path_to_hlo/nbrs_hlo.txt'
hlo_proto_dest = '/path_to_hlo/nbrs_hlo.pb'

fn = neighbor_update

manual_input = '[("R", "f64[6400,2]")]'

constants = {"nbrs": nbrs} #nbrs as constant


input_shapes = [(name, xla_client.Shape(shape_str))
                  for name, shape_str in literal_eval(manual_input)]

hlo_proto, hlo_text = jax_to_hlo(fn, input_shapes, constants)

with open(hlo_text_dest, 'w') as f:
  f.write(hlo_text)

with open(hlo_proto_dest, 'wb') as f:
  f.write(hlo_proto)

Both hlo.txt files contain "type=UNSIGNED", which produces an error when compiling.

@skye
Member

skye commented Jan 29, 2021

Hm, when I run your Python code above for calling jax_to_hlo, the resulting HLO doesn't have the error-causing unsigned comparisons. Can you confirm that you're running the latest jax and jaxlib versions, or try upgrading if not?

I also noticed the HLO file you provided uses Windows-style newlines, do you happen to be running on Windows? I wouldn't think that would make a difference here, but maybe worth checking.
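One quick way to check the installed versions before upgrading (a minimal stdlib-only sketch; the package names are as published on PyPI):

```python
# Print installed package versions; importlib.metadata is in the
# standard library from Python 3.8 onward.
from importlib.metadata import version, PackageNotFoundError

for pkg in ("jax", "jaxlib", "jax-md"):
    try:
        print(pkg, version(pkg))
    except PackageNotFoundError:
        print(pkg, "not installed")
```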

@maxstupp
Author

Apparently it was caused by an older version. After updating jax (0.2.6 -> 0.2.9), jax-md (0.1.8 -> 0.1.10) and jaxlib (0.1.56 -> 0.1.59), I was able to compile it without the type=UNSIGNED error!

I tried to compile a simple function and check the output, similar to the example in #5337. I used R and nbrs as constants and just returned the mean of the indices:

@jit
def neighbor_update(R, nbrs):
    nbrs = neighbor_fn(R, nbrs)
    return onp.mean(nbrs.idx)

For this example the output should be 3938.076923076923, but the output I get from compiling it with the PJRT client is:

2021-01-31 16:27:00.691322: I tensorflow/compiler/xla/service/service.cc:177]   StreamExecutor device (0): Host, Default Version
result = (
f64[] 108799.5
)
Did I make a mistake here?

I used the following file to compile the hlo:

#include <memory>
#include <string>
#include <vector>
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/pjrt/cpu_device.h"
#include "tensorflow/compiler/xla/pjrt/pjrt_client.h"
#include "tensorflow/compiler/xla/status.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/tools/hlo_module_loader.h"
#include "tensorflow/core/platform/init_main.h"
#include "tensorflow/core/platform/logging.h"

int main(int argc, char** argv) {
  tensorflow::port::InitMain("", &argc, &argv);

  // Load HloModule from file.
  std::string hlo_filename = "/path_to_hlo/nbrs_hlo.txt";
  std::function<void(xla::HloModuleConfig*)> config_modifier_hook =
      [](xla::HloModuleConfig* config) { config->set_seed(42); };
  std::unique_ptr<xla::HloModule> test_module =
      LoadModuleFromFile(hlo_filename, xla::hlo_module_loader_details::Config(),
                         "txt", config_modifier_hook)
          .ValueOrDie();
  const xla::HloModuleProto test_module_proto = test_module->ToProto();
    
  // Run it using JAX C++ Runtime (PJRT).

  // Get a CPU client.
  std::unique_ptr<xla::PjRtClient> client =
      xla::GetCpuClient(/*asynchronous=*/true).ValueOrDie();

  // Compile XlaComputation to PjRtExecutable.
  xla::XlaComputation xla_computation(test_module_proto);
  
  xla::CompileOptions compile_options;
  
  std::unique_ptr<xla::PjRtExecutable> executable =
      client->Compile(xla_computation, compile_options).ValueOrDie();

  // Execute on CPU.
  xla::ExecuteOptions execute_options;
  // One vector<buffer> for each device.
  std::vector<std::vector<std::unique_ptr<xla::PjRtBuffer>>> results =
      executable->Execute({{}}, execute_options)
          .ValueOrDie();
       
  // Get result.
  std::shared_ptr<xla::Literal> result_literal =
      results[0][0]->ToLiteral().ValueOrDie();
  LOG(INFO) << "result = " << *result_literal;
 
  return 0;
}

Thank you for your help!

@skye
Member

skye commented Feb 1, 2021

Are you sure that's your up-to-date C++ code? It's not passing any arguments into executable->Execute, so I get an error from that when I try to run it.

@maxstupp
Author

maxstupp commented Feb 4, 2021

Hello skye,

sorry for the delayed answer. In the example above I did not use an input shape; I used only constants and extracted the HLO with the Python file. I made a better example this time:

I tried using global parameters inside the function (like the position array or the nbrs array), but when using jax numpy arrays the output value was wrong (probably because the value of the array is not available at compile time).

I made a really small reproducible Python example with 3 functions: one uses a globally defined jax numpy array, one an original numpy array, and one takes the array as an input. Each just shifts an index by a scalar and returns it.

array_test.zip
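For reference, a hypothetical reconstruction of what the three functions in array_test.py might look like (the function names match the commands below, but the array contents and the exact index are guesses inferred from the outputs further down; only the behaviour of the input-array variant is certain):

```python
import numpy as onp
import jax.numpy as jnp
from jax import jit

# Hypothetical stand-ins for the arrays defined in array_test.py.
ar_onp = onp.array([[0.0, 0.0], [0.0, 1.25], [0.0, 2.5], [0.0, 3.75],
                    [0.0, 5.0], [0.0, 6.25], [0.0, 7.5]])
ar_jnp = jnp.asarray(ar_onp)

@jit
def onp_array(scalar):
    # The plain NumPy value is baked into the trace as a constant.
    return ar_onp[1, 1] + scalar

@jit
def jnp_array(scalar):
    # A global jax array is closed over at trace time; in the jax_to_hlo
    # flow described above its concrete value was apparently lost.
    return ar_jnp[1, 1] + scalar

@jit
def input_array(ar, scalar):
    # Passing the array as an argument makes its value an explicit input.
    return ar[1, 1] + scalar
```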

I extracted the hlo with the following commands:

python3 jax_to_hlo.py --fn array_test.onp_array --input_shapes '[("scalar", "f32[1]")]' --hlo_text_dest /file_to_hlo/onp_array_hlo.txt

python3 jax_to_hlo.py --fn array_test.jnp_array --input_shapes '[("scalar", "f32[1]")]' --hlo_text_dest /file_to_hlo/jnp_array_hlo.txt

python3 jax_to_hlo.py --fn array_test.input_array --input_shapes '[("ar", "f32[7,2]")]' --constants '{"scalar": 3.0}' --hlo_text_dest /file_to_hlo/input_array_hlo.txt

Then I compiled it with the pjrt client:

#include <memory>
#include <string>
#include <vector>
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/pjrt/cpu_device.h"
#include "tensorflow/compiler/xla/pjrt/pjrt_client.h"
#include "tensorflow/compiler/xla/status.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/tools/hlo_module_loader.h"
#include "tensorflow/core/platform/init_main.h"
#include "tensorflow/core/platform/logging.h"

int main(int argc, char** argv) {
  tensorflow::port::InitMain("", &argc, &argv);

  // Load HloModule from file.
  std::string hlo_filename = "/path_to_hlo/input_array_hlo.txt";
  std::function<void(xla::HloModuleConfig*)> config_modifier_hook =
      [](xla::HloModuleConfig* config) { config->set_seed(42); };
  std::unique_ptr<xla::HloModule> test_module =
      LoadModuleFromFile(hlo_filename, xla::hlo_module_loader_details::Config(),
                         "txt", config_modifier_hook)
          .ValueOrDie();
  const xla::HloModuleProto test_module_proto = test_module->ToProto();
    
  // Run it using JAX C++ Runtime (PJRT).

  // Get a CPU client.
  std::unique_ptr<xla::PjRtClient> client =
      xla::GetCpuClient(/*asynchronous=*/true).ValueOrDie();

  // Compile XlaComputation to PjRtExecutable.
  xla::XlaComputation xla_computation(test_module_proto);
  
  xla::CompileOptions compile_options;
  
  std::unique_ptr<xla::PjRtExecutable> executable =
      client->Compile(xla_computation, compile_options).ValueOrDie();

  // Prepare inputs.
  xla::Literal literal_x =
      xla::LiteralUtil::CreateR2<float>({{0.0f, 0.0f}, {0.0f, 1.25f}, {0.0f, 2.5f}, {0.0f, 3.75f}, {0.0f, 5.0f}, {0.0f, 6.25f}, {0.0f, 7.5f}}); //for input_array function
      //xla::LiteralUtil::CreateR1<float>({3.0f}); //for the first 2 functions 

  std::unique_ptr<xla::PjRtBuffer> param_x =
      client->BufferFromHostLiteral(literal_x, client->local_devices()[0])
        .ValueOrDie();

  // Execute on CPU.
  xla::ExecuteOptions execute_options;
  // One vector<buffer> for each device.
  std::vector<std::vector<std::unique_ptr<xla::PjRtBuffer>>> results =  
      executable->Execute({{param_x.get()}}, execute_options)
          .ValueOrDie();
       
  // Get result.
  std::shared_ptr<xla::Literal> result_literal =
      results[0][0]->ToLiteral().ValueOrDie(); // results[device][output_index]
  //LOG(INFO) << "result = " << *result_literal;
  std::cout << "result = " << *result_literal; 
 
  return 0;
} 

I got the following outputs:
For original numpy array:

2021-02-04 12:24:20.563541: I tensorflow/compiler/xla/service/service.cc:177]   StreamExecutor device (0): Host, Default Version
result = (
f64[1] {4.25}
)

For jax numpy array:

2021-02-04 12:23:47.327335: I tensorflow/compiler/xla/service/service.cc:177]   StreamExecutor device (0): Host, Default Version
result = (
f64[1] {3}
)

For input array:

2021-02-04 12:24:58.751672: I tensorflow/compiler/xla/service/service.cc:177]   StreamExecutor device (0): Host, Default Version
result = (
f32[] 4.25
)

I guess this is indeed expected behaviour? The concrete values are not available for the jax numpy array.

Thank you for your help!

@skye
Member

skye commented Feb 23, 2021

Hey sorry for not following up on this! Did you manage to resolve your issue?

@maxstupp
Author

Hello,

I was not able to use the nbrs list as a constant, but I found a workaround. Now I use the idx, reference position and max occupancy of the neighbor list as input parameters and create a new NeighborList object. Then I just update the idx and reference. This seems to work very well for my applications!
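A minimal sketch of that workaround, using a namedtuple as a stand-in for jax-md's NeighborList (the field names here are hypothetical and the real dataclass has more fields; a real version would also call the neighbor_fn from the earlier snippets):

```python
from collections import namedtuple
import jax.numpy as jnp
from jax import jit

# Hypothetical stand-in for jax_md.partition.NeighborList with only
# the fields this workaround needs.
NeighborList = namedtuple("NeighborList",
                          ["idx", "reference_position", "max_occupancy"])

@jit
def neighbor_update(R, idx, reference_position, max_occupancy):
    # Rebuild the neighbor-list object from plain array inputs, so that
    # nothing has to be handed to jax_to_hlo as a constant.
    nbrs = NeighborList(idx, reference_position, max_occupancy)
    # A real version would call neighbor_fn(R, nbrs) here and return the
    # updated idx and reference positions; this sketch just echoes them.
    return nbrs.idx, nbrs.reference_position
```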

@skye
Member

skye commented Feb 23, 2021

Great, glad to hear you found a workaround, and thanks for sharing!
