[XLA] Add more info for docs (#2063)
ShengYang1 authored and guizili0 committed Apr 26, 2023
1 parent 5b9f729 commit 76d3c55
Showing 3 changed files with 58 additions and 35 deletions.
77 changes: 50 additions & 27 deletions docs/guide/OpenXLA_Support_on_GPU.md
@@ -38,41 +38,64 @@ Then we can get the library with xla extension **./bazel-bin/itex/libitex_xla_
$ export PJRT_NAMES_AND_LIBRARY_PATHS='xpu:Your_itex_path/bazel-bin/itex/libitex_xla_extension.so'
$ export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:Your_Python_site-packages/jaxlib # Some functions defined in xla_extension.so are needed by libitex_xla_extension.so
$ export ONEDNN_VERBOSE=1 # Optional. Enable oneDNN verbose logging to check whether kernels run on the GPU.
$ export ITEX_VERBOSE=1 # Optional. Print detailed optimization/compilation/execution info.
```
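Before running a real workload, you can confirm that JAX actually picked up the plugin. Below is a minimal sanity check; it assumes the `xpu` platform name set via the `PJRT_NAMES_AND_LIBRARY_PATHS` prefix above, and the expected output is an assumption that depends on your hardware.
```python
import jax

# The platform name "xpu" comes from the PJRT_NAMES_AND_LIBRARY_PATHS prefix above.
print(jax.default_backend())  # expected to print "xpu" once the plugin loads
print(jax.devices())          # expected to list one entry per Intel GPU
```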
* **Run the JAX Python code below.**
```python
import jax
import jax.numpy as jnp
from jax import random
key = random.PRNGKey(0)
size = 3000
x = random.normal(key, (size, size), dtype=jnp.float32)
y = jnp.dot(x, x.T).block_until_ready()
print(y)
@jax.jit
def lax_conv():
    key = jax.random.PRNGKey(0)
    lhs = jax.random.uniform(key, (2,1,9,9), jnp.float32)
    rhs = jax.random.uniform(key, (1,1,4,4), jnp.float32)
    side = jax.random.uniform(key, (1,1,1,1), jnp.float32)
    out = jax.lax.conv_with_general_padding(lhs, rhs, (1,1), ((0,0),(0,0)), (1,1), (1,1))
    out = jax.nn.relu(out)
    out = jnp.multiply(out, side)
    return out
print(lax_conv())
```
* **Reference result:**
```
onednn_verbose,info,oneDNN v3.1.0 (commit xxxx)
onednn_verbose,info,cpu,runtime:DPC++,nthr:1
onednn_verbose,info,cpu,isa:Intel 64
onednn_verbose,info,gpu,runtime:DPC++
onednn_verbose,info,cpu,engine,0,backend:OpenCL,name:Intel(R) Xeon(R) Gold 6346 CPU @ 3.10GHz,driver_version:2022.15.12,binary_kernels:disabled
onednn_verbose,info,gpu,engine,0,backend:Level Zero,name:Intel(R) Data Center GPU Flex Series 170 [0x56c0],driver_version:1.3.25018,binary_kernels:enabled
onednn_verbose,info,gpu,engine,1,backend:Level Zero,name:Intel(R) Data Center GPU Flex Series 170 [0x56c0],driver_version:1.3.25018,binary_kernels:enabled
onednn_verbose,info,experimental features are enabled
onednn_verbose,info,use batch_normalization stats one pass is enabled
onednn_verbose,info,experimental functionality for sparse domain is enabled
onednn_verbose,info,prim_template:operation,engine,primitive,implementation,prop_kind,memory_descriptors,attributes,auxiliary,problem_desc,exec_time
onednn_verbose,exec,gpu:0,matmul,jit:gemm:any,undef,src_f32::blocked:abc:f0 wei_f32::blocked:abc:f0 dst_f32::blocked:abc:f0,attr-scratchpad:user ,,1x3000x3000:1x3000x3000:1x3000x3000,xxxxxxxx
[[2938.1716 17.388428 36.508217 ... 32.315964 51.31904 -34.432026]
[17.388428 3031.179 41.194576 ... 47.248768 58.077858 -13.371612]
[36.508217 41.194576 3000.4697 ... 8.10901 -42.501842 26.495111]
...
[32.315964 47.248768 8.10901 ... 2916.339 34.38107 39.404522]
[51.31904 58.077858 -42.501842 ... 34.38107 3032.2844 63.69183 ]
[-34.432026 -13.371612 26.495111 ... 39.404522 63.69183 3033.4866 ]]
I itex/core/devices/gpu/itex_gpu_runtime.cc:129] Selected platform: Intel(R) Level-Zero
I itex/core/compiler/xla/service/service.cc:176] XLA service 0x56060b5ae740 initialized for platform sycl (this does not guarantee that XLA will be used). Devices:
I itex/core/compiler/xla/service/service.cc:184] StreamExecutor device (0): <undefined>, <undefined>
I itex/core/compiler/xla/service/service.cc:184] StreamExecutor device (1): <undefined>, <undefined>
[[[[2.0449753 2.093208 2.1844783 1.9769732 1.5857391 1.6942389]
[1.9218378 2.2862523 2.1549542 1.8367321 1.3978379 1.3860377]
[1.9456574 2.062028 2.0365305 1.901286 1.5255247 1.1421617]
[2.0621 2.2933435 2.1257985 2.1095486 1.5584903 1.1229166]
[1.7746235 2.2446113 1.7870374 1.8216239 1.557919 0.9832508]
[2.0887792 2.5433128 1.9749291 2.2580051 1.6096935 1.264905 ]]]
[[[2.175818 2.0094342 2.005763 1.6559253 1.3896458 1.4036925]
[2.1342552 1.8239582 1.6091168 1.434404 1.671778 1.7397764]
[1.930626 1.659667 1.6508744 1.3305787 1.4061482 2.0829628]
[2.130649 1.6637266 1.594426 1.2636002 1.7168686 1.8598001]
[1.9009514 1.7938274 1.4870623 1.6193901 1.5297288 2.0247464]
[2.0905268 1.7598859 1.9362347 1.9513799 1.9403584 2.1483061]]]]
```
If `ITEX_VERBOSE=1` is set, the log looks like this:
```
I itex/core/compiler/xla/service/hlo_pass_pipeline.cc:301] Running HLO pass pipeline on module jit_lax_conv: optimization
I itex/core/compiler/xla/service/hlo_pass_pipeline.cc:181] HLO pass fusion
I itex/core/compiler/xla/service/hlo_pass_pipeline.cc:181] HLO pass fusion_merger
I itex/core/compiler/xla/service/hlo_pass_pipeline.cc:181] HLO pass multi_output_fusion
I itex/core/compiler/xla/service/hlo_pass_pipeline.cc:181] HLO pass gpu-conv-rewriter
I itex/core/compiler/xla/service/hlo_pass_pipeline.cc:181] HLO pass onednn-fused-convolution-rewriter
I itex/core/compiler/xla/service/gpu/gpu_compiler.cc:1221] Build kernel via LLVM kernel compilation.
I itex/core/compiler/xla/service/gpu/spir_compiler.cc:255] CompileTargetBinary - CompileToSpir time: 11 us (cumulative: 99.2 ms, max: 74.9 ms, #called: 8)
I itex/core/compiler/xla/pjrt/pjrt_stream_executor_client.cc:2201] Executing computation jit_lax_conv; num_replicas=1 num_partitions=1 num_addressable_devices=1
I itex/core/compiler/xla/pjrt/pjrt_stream_executor_client.cc:2268] Replicated execution complete.
I itex/core/compiler/xla/pjrt/pjrt_stream_executor_client.cc:1208] PjRtStreamExecutorBuffer::Delete
I itex/core/compiler/xla/pjrt/pjrt_stream_executor_client.cc:1299] PjRtStreamExecutorBuffer::ToLiteral
```
Check that the workload runs on the GPU rather than the CPU. For example, the line "onednn_verbose,exec,**gpu**:0,matmul, ..." means the "matmul" primitive was executed on the GPU.
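To automate this check, here is a small sketch that counts the oneDNN primitives executed on the GPU. The script name `jax_example.py` is hypothetical; it is assumed to contain the JAX code above.
```python
import os
import subprocess
import sys

# Run the example with oneDNN verbose logging enabled and capture its output.
env = dict(os.environ, ONEDNN_VERBOSE="1")
result = subprocess.run(
    [sys.executable, "jax_example.py"],  # hypothetical script holding the code above
    env=env,
    capture_output=True,
    text=True,
)

# oneDNN logs every primitive executed on the GPU with an "exec,gpu" prefix.
gpu_lines = [
    line
    for line in (result.stdout + result.stderr).splitlines()
    if line.startswith("onednn_verbose,exec,gpu")
]
print(f"{len(gpu_lines)} oneDNN primitives ran on the GPU")
```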
**4. More JAX examples.**
More runnable examples are available at [https://github.com/google/jax](https://github.com/google/jax/tree/jaxlib-v0.4.4/examples).
14 changes: 7 additions & 7 deletions itex/core/compiler/xla/service/gpu/onednn_fused_conv_rewriter.cc
@@ -179,7 +179,7 @@ StatusOr<bool> FuseConvertToFloat(HloComputation* comp) {
if (!Match(instr, pattern)) {
continue;
}
if (!ConsumeFuel("cudnn-fused-convolution-rewriter", [&] {
if (!ConsumeFuel("onednn-fused-convolution-rewriter", [&] {
return absl::StrCat("FuseConvertToFloat: ", conv->ToString());
})) {
continue;
@@ -229,7 +229,7 @@ StatusOr<bool> FuseConvAlpha(HloComputation* comp) {
if (config.conv_result_scale() != 1) {
continue;
}
if (!ConsumeFuel("cudnn-fused-convolution-rewriter", [&] {
if (!ConsumeFuel("onednn-fused-convolution-rewriter", [&] {
return absl::StrCat("FuseConvAlpha: ", conv->ToString());
})) {
continue;
@@ -327,7 +327,7 @@ StatusOr<bool> FuseBiasOrSideInput(HloComputation* comp) {
continue;
}

if (!ConsumeFuel("cudnn-fused-convolution-rewriter", [&] {
if (!ConsumeFuel("onednn-fused-convolution-rewriter", [&] {
return absl::StrCat("FuseBiasOrSideInput: ", conv->ToString());
})) {
continue;
@@ -401,7 +401,7 @@ StatusOr<bool> FuseSideInputAlpha(HloComputation* comp) {
}))))) {
continue;
}
if (!ConsumeFuel("cudnn-fused-convolution-rewriter", [&] {
if (!ConsumeFuel("onednn-fused-convolution-rewriter", [&] {
return absl::StrCat("FuseSideInputAlpha: ", conv->ToString());
})) {
continue;
@@ -481,7 +481,7 @@ StatusOr<bool> FuseRelu(HloComputation* comp) {
continue;
}

if (!ConsumeFuel("cudnn-fused-convolution-rewriter", [&] {
if (!ConsumeFuel("onednn-fused-convolution-rewriter", [&] {
return absl::StrCat("FuseRelu: ", conv->ToString());
})) {
continue;
@@ -524,7 +524,7 @@ StatusOr<bool> FuseConvertToF16(HloComputation* comp) {
0, m::GetTupleElement(m::Op().WithPredicate(IsConvCustomCall))));
continue;
}
if (!ConsumeFuel("cudnn-fused-convolution-rewriter", [&] {
if (!ConsumeFuel("onednn-fused-convolution-rewriter", [&] {
return absl::StrCat("FuseConvertToF16: ", conv->ToString());
})) {
continue;
@@ -609,7 +609,7 @@ StatusOr<bool> FuseConvertToS8(HloComputation* comp) {
} else {
continue;
}
if (!ConsumeFuel("cudnn-fused-convolution-rewriter", [&] {
if (!ConsumeFuel("onednn-fused-convolution-rewriter", [&] {
return absl::StrCat("FuseConvertToS8: ", conv->ToString());
})) {
continue;
@@ -97,7 +97,7 @@ namespace gpu {
class CudnnFusedConvRewriter : public HloModulePass {
public:
absl::string_view name() const override {
return "cudnn-fused-convolution-rewriter";
return "onednn-fused-convolution-rewriter";
}

StatusOr<bool> Run(HloModule* module) override;
