Failed precondition error when trying to compile hlo #5544
Comments
Can you provide the jax_to_hlo.py command you're using to generate the HLO? |
Of course! I started with R as an input and nbrs as a global parameter inside the function, and used jax_to_hlo.py from the command line:
Then I tried to insert nbrs as a constant and called jax_to_hlo.py from inside a Python file:
Both hlo.txt files contain "type=UNSIGNED", which produces an error when I try to compile them.
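A quick way to confirm which instructions carry the attribute is to scan the HLO text for it. The following is a minimal sketch, not part of the original report: the helper name `find_compare_types` and the `SAMPLE` module are made up for illustration, and the sample is a simplified approximation of HLO text rather than an excerpt from the attached hlo.txt.

```python
import re

# Matches a comparison-type attribute in HLO text, e.g. "type=UNSIGNED".
COMPARE_TYPE = re.compile(r"\btype=([A-Z]+)\b")


def find_compare_types(hlo_text, wanted="UNSIGNED"):
    """Return (line_number, stripped_line) pairs whose type attribute equals `wanted`."""
    hits = []
    # splitlines() also copes with Windows-style (CRLF) line endings.
    for number, line in enumerate(hlo_text.splitlines(), start=1):
        match = COMPARE_TYPE.search(line)
        if match and match.group(1) == wanted:
            hits.append((number, line.strip()))
    return hits


# Illustrative, simplified snippet only; NOT taken from the attached hlo.txt.
SAMPLE = """\
HloModule sample
ENTRY main {
  a = u32[] parameter(0)
  b = u32[] parameter(1)
  lt = pred[] compare(a, b), direction=LT, type=UNSIGNED
  ROOT eq = pred[] compare(a, b), direction=EQ
}
"""

hits = find_compare_types(SAMPLE)
for number, line in hits:
    print(f"line {number}: {line}")
```

To use it on a real export, read the file contents into a string and pass that string instead of `SAMPLE`.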
Hm, when I run your Python code above for calling jax_to_hlo, the resulting HLO doesn't have the error-causing unsigned comparisons. Can you confirm that you're running the latest jax and jaxlib versions, or try upgrading if not? I also noticed the HLO file you provided uses Windows-style newlines. Do you happen to be running on Windows? I wouldn't think that would make a difference here, but it may be worth checking.
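Both checks can be scripted. This is a rough sketch using only the Python standard library; the helper names `installed_version` and `count_crlf` are illustrative and do not come from jax or from the thread.

```python
from importlib import metadata


def installed_version(package):
    """Return the installed version string of `package`, or None if it is not installed."""
    try:
        return metadata.version(package)
    except metadata.PackageNotFoundError:
        return None


def count_crlf(data):
    """Count Windows-style (CRLF) line endings in raw file contents (bytes)."""
    return data.count(b"\r\n")


# Report the versions of the packages mentioned in this thread.
for name in ("jax", "jaxlib", "jax-md"):
    print(name, installed_version(name))

# Small in-memory example; for a real file, pass open(path, "rb").read() instead.
example = b"HloModule m\r\nENTRY main {\r\n}\n"
print("CRLF line endings:", count_crlf(example))
```

The file should be opened in binary mode for the newline check, since text mode may translate line endings before they can be counted.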
Apparently it was caused by an older version. After updating jax (0.2.6 -> 0.2.9), jax-md (0.1.8 -> 0.1.10) and jaxlib (0.1.56 -> 0.1.59), I was able to compile it without the type=UNSIGNED! I then tried to compile a simple function and check the output, similar to the example in #5337. I used R and nbrs as constants and just returned the mean of the index:
For this example the output should be 3938.076923076923, but the output I get from compiling it with the PJRT client is:
Did I make a mistake here? I used the following file to compile the hlo:
Thank you for your help!
Are you sure that's your up-to-date C++ code? It's not passing any arguments into |
Hello skye, sorry for the delayed answer. In the example above I did not use an input_shape; I used only constants and extracted the HLO with the Python file. I made a better example this time. I tried using global parameters inside the function (like the position array or the nbrs array), but when using jax numpy arrays the output value was wrong (probably because the value of the array is not available during compilation). I made a really small reproducible Python example with three functions: one uses a globally defined jax numpy array, one uses an original numpy array, and one takes an array as input. Each just shifts an index by a scalar and returns it. I extracted the HLO with the following commands:
Then I compiled it with the pjrt client:
I got the following outputs:
For jax numpy array:
For input array:
I guess this is indeed expected behaviour? The concrete values are not available for the jax numpy array. Thank you for your help!
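The structural difference between a closed-over array and an input array can be inspected with `jax.make_jaxpr`. The sketch below is an assumption-laden illustration, not the reproduction attached to this thread: `GLOBAL_IDX`, `shift_global` and `shift_input` are hypothetical stand-ins, and it only shows how many inputs the traced computation has. It does not by itself explain the wrong output seen through jax_to_hlo.py.

```python
import jax
import jax.numpy as jnp

# Hypothetical stand-in for the globally defined index array in the thread.
GLOBAL_IDX = jnp.arange(5)


def shift_global(shift):
    """Shift the closed-over global array by a scalar."""
    return GLOBAL_IDX + shift


def shift_input(idx, shift):
    """Shift an explicitly passed array by a scalar."""
    return idx + shift


# Trace both functions; the traced computation lists its inputs as `invars`.
closed = jax.make_jaxpr(shift_global)(1)
explicit = jax.make_jaxpr(shift_input)(GLOBAL_IDX, 1)

print("inputs with global array:", len(closed.jaxpr.invars))
print("inputs with array argument:", len(explicit.jaxpr.invars))
print(explicit)
```

In the closed-over version the array is not an input of the traced computation, so a caller has no way to supply it at run time. In the explicit version it is an ordinary parameter.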
Hey, sorry for not following up on this! Did you manage to resolve your issue?
Hello, I was not able to use an nbrs list as a constant, but I found a workaround. I now pass the idx, the reference position and the max occupancy of the neighbor list as input parameters and create a new NeighborList object. Then I just update the idx and the reference. This seems to work very well for my applications!
Great, glad to hear you found a workaround, and thanks for sharing!
Hello,
I am currently trying to run a jax function from C++, similar to #5337 and #2766. However, when trying to compile the HLO file produced by jax_to_hlo.py, I get the following error:
I tried to run the HloModule with PJRT and with LocalClient/LocalExecutable and got the same error.
The function I'm trying to run in C++ updates the neighbor list and returns the updated indices. The example is mostly taken from this notebook from jax-md: https://colab.research.google.com/github/google/jax-md/blob/master/notebooks/nve_neighbor_list.ipynb
I have attached the Python file and the produced hlo.txt file. I used R as input and nbrs as a constant when exporting the HLO (I also tried both as constants and both as global parameters inside the function, but got the same results).
The code to compile the hlo is mostly taken from #2766.
I run the C++ file with bazel.
When I manually change type=UNSIGNED to SIGNED or FLOAT inside the hlo.txt file, the file compiles but gives (as expected) the wrong result.
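A small worked example of why swapping the comparison type is not semantics-preserving in general: the same 32-bit pattern orders differently when read as signed versus unsigned. This is a plain-Python sketch with made-up values. It does not reproduce the actual comparisons in the attached HLO, and `as_signed32` is an illustrative helper.

```python
def as_signed32(value):
    """Reinterpret the bit pattern of a 32-bit unsigned integer as two's-complement signed."""
    return int.from_bytes(value.to_bytes(4, "little"), "little", signed=True)


# Made-up operands: `a` has its top bit set, so it is negative when read as signed.
a, b = 4_000_000_000, 1

unsigned_lt = a < b                          # unsigned view: 4000000000 < 1
signed_lt = as_signed32(a) < as_signed32(b)  # signed view: -294967296 < 1

print("a as signed 32-bit:", as_signed32(a))
print("unsigned a < b:", unsigned_lt)
print("signed   a < b:", signed_lt)
```

Any value with the top bit set can flip the outcome of the comparison in this way, while values below 2**31 compare identically under both interpretations.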
Is this a bug in the HLO pipeline? Any help is appreciated!
Thanks in advance!
HLOerror.zip