
Store correct JAX representation in QJIT object #520

Closed
erick-xanadu opened this issue Feb 15, 2024 · 9 comments · Fixed by #531

In this function:

def trace_to_mlir(func, static_argnums, abstracted_axes, *args, **kwargs):
    # ... snip ...
    with EvaluationContext(EvaluationMode.CLASSICAL_COMPILATION):
        make_jaxpr_kwargs = {"static_argnums": static_argnums, "abstracted_axes": abstracted_axes}
        jaxpr, out_type, out_tree = make_jaxpr2(func, **make_jaxpr_kwargs)(*args, **kwargs)

    # We remove implicit Jaxpr result values since we are compiling a top-level jaxpr program.
    jaxpr2, out_type2 = jaxpr_remove_implicit(jaxpr, out_type)
    module, context = jaxpr_to_mlir(func.__name__, jaxpr2)
    return module, context, jaxpr, out_type2, out_tree

we obtain a jaxpr representation from make_jaxpr2 and then do some post-processing on it.

I think we are returning the wrong jaxpr (it should be jaxpr2), and we can rename the variables accordingly:

def trace_to_mlir(func, static_argnums, abstracted_axes, *args, **kwargs):
    # ... snip ...
    with EvaluationContext(EvaluationMode.CLASSICAL_COMPILATION):
        make_jaxpr_kwargs = {"static_argnums": static_argnums, "abstracted_axes": abstracted_axes}
        jaxpr, out_type, out_tree = make_jaxpr2(func, **make_jaxpr_kwargs)(*args, **kwargs)

    # We remove implicit Jaxpr result values since we are compiling a top-level jaxpr program.
    jaxpr, out_type = jaxpr_remove_implicit(jaxpr, out_type)
    module, context = jaxpr_to_mlir(func.__name__, jaxpr)
    return module, context, jaxpr, out_type, out_tree

erick-xanadu changed the title from "Are we storing an incorrect JAX representation" to "Store correct JAX representation in QJIT object" on Feb 15, 2024
@dime10 (Contributor) commented Feb 16, 2024

@grwlf Any thoughts on this?

@dime10 (Contributor) commented Feb 16, 2024

@erick-xanadu Wouldn't the treedef (out_tree) need updating along with the abstract values (out_type)?

@erick-xanadu (Contributor, Author)

@dime10 I am not sure. I think out_tree should be preserved as that is what the user expects.

@sergei-mironov (Contributor) commented Feb 20, 2024

> @grwlf Any thoughts on this?

I think there are no errors here, but different viewpoints are possible. Consider the following program:

from catalyst import qjit
import jax.numpy as jnp

@qjit(keep_intermediate=True)
def fun(a):
    r = jnp.ones((a + 1,), dtype=int)
    return r

The full Jaxpr of this program is:

{ lambda ; a:i64[]. let
    b:i64[] = add a 1
    c:i64[b] = broadcast_in_dim[broadcast_dimensions=() shape=(None,)] 1 b
  in (b, c) }  // <--------- Note two return values.

Here, b is a calculated dimension variable, returned along with its tensor. If this were not a top-level program, we might want to use b in subsequent calculations. But since it is a top-level Jaxpr, we know we don't need the dimensions anymore, so we remove the implicit outputs in order to get the desired StableHLO code. The corresponding IR is:

module @fun {
  func.func public @jit_fun(%arg0: tensor<i64>) -> tensor<?xi64> attributes {llvm.emit_c_interface} {
    %0 = stablehlo.constant dense<1> : tensor<i64>
    %1 = stablehlo.add %arg0, %0 : tensor<i64>
    %2 = stablehlo.convert %1 : (tensor<i64>) -> tensor<i32>
    %3 = stablehlo.reshape %2 : (tensor<i32>) -> tensor<1xi32>
    %4 = stablehlo.dynamic_broadcast_in_dim %0, %3, dims = [] : (tensor<i64>, tensor<1xi32>) -> tensor<?xi64>
    return %4 : tensor<?xi64>  // <-------  Note: only one return value
  }
}

The question is: which version of the Jaxpr should we call the correct one? I suggest thinking of jaxpr2 as an intermediate artifact of the StableHLO lowering, and returning jaxpr as the correct representation of the program.
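
For reference, this implicit-output behaviour can be reproduced outside Catalyst with JAX's experimental dynamic-shapes mode. A minimal sketch, assuming the experimental jax_dynamic_shapes flag (the exact jaxpr printed may vary between JAX versions):

import jax
import jax.numpy as jnp

# Experimental flag that allows traced values to appear in tensor shapes.
jax.config.update("jax_dynamic_shapes", True)

def fun(a):
    # The output shape depends on the traced value a, so the jaxpr has to
    # carry the computed dimension alongside the tensor itself.
    return jnp.ones((a + 1,), dtype=int)

# Tracing shows a dimension variable in the result type (e.g. i64[b]).
print(jax.make_jaxpr(fun)(3))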

@dime10 (Contributor) commented Feb 20, 2024

Thanks @grwlf! I'd be inclined to say that jaxpr2 should be the output of the "generate jaxpr" stage, and used in all subsequent processing. Unless the jaxpr version is actually needed for anything?

Another question: is out_tree (there is only one version of it) compatible with jaxpr or jaxpr2?

@sergei-mironov (Contributor) commented Feb 21, 2024

> I'd be inclined to say that jaxpr2 should be the output of the "generate jaxpr" stage, and used in all subsequent processing.

Do you have arguments for this? Do you think we can call jaxpr2 a valid Jaxpr program? I am not sure: its output might contain unlisted Jaxpr variables (in tensor shapes).

> Another question: is out_tree (there is only one version of it) compatible with jaxpr or jaxpr2?

out_tree describes the set of explicit results, so there is only one version, shared by both jaxprs.
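
As an aside, here is a minimal sketch (plain JAX; the result value is a made-up stand-in) of the role a treedef like out_tree plays in restoring the user-facing output structure from flat results:

import jax

# A made-up structured result, standing in for a traced function's output.
result = {"counts": [3, 4], "state": (1.0, 2.0)}

# tree_flatten returns the flat leaves plus a treedef (the role out_tree plays).
leaves, treedef = jax.tree_util.tree_flatten(result)

# Once compilation/execution produces flat outputs, the treedef restores the
# structure the user expects; implicit dimension results never appear in it.
restored = jax.tree_util.tree_unflatten(treedef, leaves)
assert restored == result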

Note that the same is not true for out_type. The original out_type lists implicit results, while out_type2 does not (I assume that jaxpr_remove_implicit removes the implicit part). Again, out_type2 might even be incorrect, strictly speaking: its OutDBIdx values might refer to non-existent positions in the list of outputs (this needs double-checking).

@erick-xanadu (Contributor, Author)

> I am not sure: its output might contain unlisted Jaxpr variables (in tensor shapes).

By unlisted, do you mean that the jaxpr contains a variable that will be eliminated via dead-code elimination? (Similar to why we have _no_cleanup_deadvars?)

I think that, ideally, the jaxpr we produce should be valid. I understand that the moment we decided to remove this return value, we deviated from that. Do you remember why this return value was removed?

@sergei-mironov (Contributor)

> > I am not sure: its output might contain unlisted Jaxpr variables (in tensor shapes).
>
> By unlisted, do you mean that the jaxpr contains a variable that will be eliminated via dead-code elimination? (Similar to why we have _no_cleanup_deadvars?)

Not exactly that. Consider the Jaxpr after the implicit-output removal is applied:

{ lambda ; a:i64[]. let
    b:i64[] = add a 1
    c:i64[b] = broadcast_in_dim[broadcast_dimensions=() shape=(None,)] 1 b
  in c }

Here, c has type i64[b], but there is no b variable anymore in the outer scope.

> Do you remember why this return value was removed?

I think it is removed solely to keep the StableHLO lowering code satisfied. StableHLO does not need dimension variables, so I believe (I didn't look very carefully) that the lowering code permissively ignores them and everything keeps working.

@dime10 (Contributor) commented Feb 22, 2024

To summarize the resolution we came to:

  • the jaxpr with implicit results will be the canonical representation at that level, in order to avoid potentially "incorrect" jaxprs in downstream applications
  • filtering implicit results from the jaxpr is a pre-processing step for the MLIR lowering only (see the sketch below)
  • the out_type after filtering will be removed, since it is redundant (it only contains (..., True) entries)
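
A hypothetical sketch of what trace_to_mlir could look like under this resolution (the name jaxpr_no_implicit is a placeholder; this is not necessarily the implementation that landed in #531):

def trace_to_mlir(func, static_argnums, abstracted_axes, *args, **kwargs):
    # ... snip ...
    with EvaluationContext(EvaluationMode.CLASSICAL_COMPILATION):
        make_jaxpr_kwargs = {"static_argnums": static_argnums, "abstracted_axes": abstracted_axes}
        jaxpr, out_type, out_tree = make_jaxpr2(func, **make_jaxpr_kwargs)(*args, **kwargs)

    # Filtering implicit results is a pre-processing step for the MLIR lowering
    # only; the canonical jaxpr (with implicit results) is what gets returned.
    jaxpr_no_implicit, _ = jaxpr_remove_implicit(jaxpr, out_type)
    module, context = jaxpr_to_mlir(func.__name__, jaxpr_no_implicit)

    # The filtered out_type is dropped as redundant; the original out_type
    # (listing implicit results) travels with the canonical jaxpr.
    return module, context, jaxpr, out_type, out_tree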

dime10 added a commit that referenced this issue Feb 22, 2024
dime10 added a commit that referenced this issue Feb 23, 2024
dime10 added a commit that referenced this issue Feb 23, 2024
This is part 2 of a refactor started in #529. 

The QJIT class is reworked into 5 distinct compilation stages:
- pre-compilation (like autograph)
- capture (jaxpr generation)
- ir-generation (mlir generation)
- compilation (llvm and binary code generation - cannot be split up since this happens in the compiler driver)
- execution

The class is also streamlined by using a new compilation cache to handle
previously compiled functions and signature lookups.

One point of contention might be the results produced by the split of the `trace_to_mlir` function, which have been simplified and need to be double-checked against #520. EDIT: commit c71c322 should address this concern.

[sc-57014]

closes #268 
closes #520
rauletorresc pushed a commit that referenced this issue Feb 26, 2024