Store correct JAX representation in QJIT object #520
@grwlf Any thoughts on this? |
@erick-xanadu Wouldn't the treedef ( |
@dime10 I am not sure. I think |
I think there are no errors here, but different viewpoints are possible. Consider the following program:

    @qjit(keep_intermediate=True)
    def fun(a):
        r = jnp.ones((a + 1,), dtype=int)
        return r

The full Jaxpr of it is

Here,

    module @fun {
      func.func public @jit_fun(%arg0: tensor<i64>) -> tensor<?xi64> attributes {llvm.emit_c_interface} {
        %0 = stablehlo.constant dense<1> : tensor<i64>
        %1 = stablehlo.add %arg0, %0 : tensor<i64>
        %2 = stablehlo.convert %1 : (tensor<i64>) -> tensor<i32>
        %3 = stablehlo.reshape %2 : (tensor<i32>) -> tensor<1xi32>
        %4 = stablehlo.dynamic_broadcast_in_dim %0, %3, dims = [] : (tensor<i64>, tensor<1xi32>) -> tensor<?xi64>
        return %4 : tensor<?xi64> // <------- Note: only one return value
      }
    }

The question is: which version of Jaxpr should we call the correct one? I suggest to think of |
Thanks @grwlf! I'd be inclined to say that Another question, is |
Do you have some arguments for this? Do you think we can call
Note that it is not true for |
By unlisted you mean that the jaxpr contains a variable that will be eliminated via dead code elimination? (Similarly to why we have I think ideally the jaxpr we produce should be valid. I understand that the moment it was decided to remove this return value, we deviated from that. Do you remember why this return value was removed? |
Not exactly that. Consider the Jaxpr after the implicit outputs reduction is applied
Here,
I think it is removed solely to keep the StableHLO lowering code satisfied. StableHLO does not need dimension variables, so I believe (I didn't look very carefully there) that its lowering code permissively ignores them and everything keeps working.
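One way to compare what a jaxpr returns with what the lowered function returns is to inspect both with stock JAX. The following is a minimal sketch under stated assumptions: it uses plain `jax.jit` rather than `qjit`, and a static-shape stand-in for `fun` (the length is fixed at 3), so no dimension variables appear. It only shows where to look; it does not reproduce the dynamic-shape behaviour discussed in this thread.

```python
import jax
import jax.numpy as jnp


def fun(a):
    # Static-shape stand-in (an assumption for illustration): the length is
    # fixed at 3 instead of depending on the argument `a`.
    return jnp.ones((3,), dtype=int) + a


# Capture stage: trace the function into a ClosedJaxpr.
closed_jaxpr = jax.make_jaxpr(fun)(1)
num_jaxpr_outputs = len(closed_jaxpr.jaxpr.outvars)

# IR stage: lower the same function and get the MLIR module as text
# (recent JAX releases emit the StableHLO dialect here).
mlir_text = jax.jit(fun).lower(1).as_text()

print(closed_jaxpr)
print("jaxpr outputs:", num_jaxpr_outputs)
print(mlir_text)
```

In the static case the number of jaxpr outputs and the number of values returned by the lowered function agree; the question in this issue is what to report when they do not.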
To summarize the resolution we came to:
|
This is part 2 of a refactor started in #529.

The QJIT class is reworked into 5 distinct compilation stages:
- pre-compilation (like autograph)
- capture (jaxpr generation)
- ir-generation (mlir generation)
- compilation (llvm and binary code generation - cannot be split up since this happens in the compiler driver)
- execution

The class is also streamlined by using a new compilation cache to handle previously compiled functions and signature lookups.

One point of contention might be the results produced by the split of the `trace_to_mlir` function, which have been simplified and need to be double checked against #520. EDIT: c71c322 should address this concern

[sc-57014] closes #268 closes #520
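The staged design with a compilation cache can be pictured roughly as follows. This is a toy sketch only: `ToyQJIT`, its `_pre_compile`, `_capture`, `_generate_ir` and `_compile` methods, and the type-based signature key are all hypothetical names chosen for illustration, not Catalyst's actual classes or methods. The stages pass placeholder strings along instead of real jaxprs or MLIR.

```python
from typing import Any, Callable, Dict, Tuple


class ToyQJIT:
    """Hypothetical staged compiler wrapper; not Catalyst's real QJIT."""

    def __init__(self, fn: Callable[..., Any]) -> None:
        self.fn = self._pre_compile(fn)
        # Compilation cache: argument signature -> compiled callable.
        self._cache: Dict[Tuple[type, ...], Callable[..., Any]] = {}
        self.compile_count = 0

    def _pre_compile(self, fn: Callable[..., Any]) -> Callable[..., Any]:
        # Stage 1: source-level transforms (autograph-like) would go here.
        return fn

    def _capture(self, signature: Tuple[type, ...]) -> str:
        # Stage 2: placeholder for jaxpr generation.
        return f"jaxpr<{self.fn.__name__}:{signature}>"

    def _generate_ir(self, captured: str) -> str:
        # Stage 3: placeholder for MLIR generation.
        return f"mlir<{captured}>"

    def _compile(self, ir: str) -> Callable[..., Any]:
        # Stage 4: placeholder for LLVM and binary generation; this toy
        # simply hands back the original Python function.
        self.compile_count += 1
        return self.fn

    def __call__(self, *args: Any) -> Any:
        signature = tuple(type(a) for a in args)
        compiled = self._cache.get(signature)
        if compiled is None:
            # Cache miss: run capture, IR generation and compilation once.
            compiled = self._compile(self._generate_ir(self._capture(signature)))
            self._cache[signature] = compiled
        # Stage 5: execution.
        return compiled(*args)


def add_one(x):
    return x + 1


staged = ToyQJIT(add_one)
# Two int calls share one cache entry; the float call triggers a second compile.
results = [staged(1), staged(2), staged(1.5)]
print(results, staged.compile_count)
```

Keeping each stage in its own method means the intermediate results (captured program, generated IR) can be stored and inspected separately, which is the property this issue is concerned with.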
In this function:
we obtain a jaxpr representation from `make_jaxpr` and then proceed to do some post-processing on it. I think we are returning the wrong jaxpr (it should be `jaxpr2`), and we can rename the variables so that they are more appropriately called: