diff --git a/docs/guide/OpenXLA_Support_on_GPU.md b/docs/guide/OpenXLA_Support_on_GPU.md index 711c82923..4c0979c77 100644 --- a/docs/guide/OpenXLA_Support_on_GPU.md +++ b/docs/guide/OpenXLA_Support_on_GPU.md @@ -38,41 +38,64 @@ Then we can get the library with xla extension **./bazel-bin/itex/libitex_xla_ $ export PJRT_NAMES_AND_LIBRARY_PATHS='xpu:Your_itex_path/bazel-bin/itex/libitex_xla_extension.so' $ export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:Your_Python_site-packages/jaxlib # Some functions defined in xla_extension.so are needed by libitex_xla_extension.so -$ export ONEDNN_VERBOSE=1 # Optional variable setting. Enable onednn verbose to check if it runs on GPU. +$ export ITEX_VERBOSE=1 # Optional variable setting. It shows detailed optimization/compilation/execution info. ``` * **Run the below jax python code.** ```python +import jax import jax.numpy as jnp -from jax import random -key = random.PRNGKey(0) -size = 3000 -x = random.normal(key, (size, size), dtype=jnp.float32) -y = jnp.dot(x, x.T).block_until_ready() -print(y) + +@jax.jit +def lax_conv(): + key = jax.random.PRNGKey(0) + lhs = jax.random.uniform(key, (2,1,9,9), jnp.float32) + rhs = jax.random.uniform(key, (1,1,4,4), jnp.float32) + side = jax.random.uniform(key, (1,1,1,1), jnp.float32) + out = jax.lax.conv_with_general_padding(lhs, rhs, (1,1), ((0,0),(0,0)), (1,1), (1,1)) + out = jax.nn.relu(out) + out = jnp.multiply(out, side) + return out + +print(lax_conv()) ``` * **Reference result:** ``` -onednn_verbose,info,oneDNN v3.1.0 (commit xxxx) -onednn_verbose,info,cpu,runtime:DPC++,nthr:1 -onednn_verbose,info,cpu,isa:Intel 64 -onednn_verbose,info,gpu,runtime:DPC++ -onednn_verbose,info,cpu,engine,0,backend:OpenCL,name:Intel(R) Xeon(R) Gold 6346 CPU @ 3.10GHz,driver_version:2022.15.12,binary_kernels:disabled -onednn_verbose,info,gpu,engine,0,backend:Level Zero,name:Intel(R) Data Center GPU Flex Series 170 [0x56c0],driver_version:1.3.25018,binary_kernels:enabled -onednn_verbose,info,gpu,engine,1,backend:Level Zero,name:Intel(R) Data Center GPU Flex Series 170 [0x56c0],driver_version:1.3.25018,binary_kernels:enabled -onednn_verbose,info,experimental features are enabled -onednn_verbose,info,use batch_normalization stats one pass is enabled -onednn_verbose,info,experimental functionality for sparse domain is enabled -onednn_verbose,info,prim_template:operation,engine,primitive,implementation,prop_kind,memory_descriptors,attributes,auxiliary,problem_desc,exec_time -onednn_verbose,exec,gpu:0,matmul,jit:gemm:any,undef,src_f32::blocked:abc:f0 wei_f32::blocked:abc:f0 dst_f32::blocked:abc:f0,attr-scratchpad:user ,,1x3000x3000:1x3000x3000:1x3000x3000,xxxxxxxx -[[2938.1716 17.388428 36.508217 ... 32.315964 51.31904 -34.432026] - [17.388428 3031.179 41.194576 ... 47.248768 58.077858 -13.371612] - [36.508217 41.194576 3000.4697 ... 8.10901 -42.501842 26.495111] - ... - [32.315964 47.248768 8.10901 ... 2916.339 34.38107 39.404522] - [51.31904 58.077858 -42.501842 ... 34.38107 3032.2844 63.69183 ] - [-34.432026 -13.371612 26.495111 ... 39.404522 63.69183 3033.4866 ]] +I itex/core/devices/gpu/itex_gpu_runtime.cc:129] Selected platform: Intel(R) Level-Zero +I itex/core/compiler/xla/service/service.cc:176] XLA service 0x56060b5ae740 initialized for platform sycl (this does not guarantee that XLA will be used). Devices: +I itex/core/compiler/xla/service/service.cc:184] StreamExecutor device (0): , +I itex/core/compiler/xla/service/service.cc:184] StreamExecutor device (1): , +[[[[2.0449753 2.093208 2.1844783 1.9769732 1.5857391 1.6942389] + [1.9218378 2.2862523 2.1549542 1.8367321 1.3978379 1.3860377] + [1.9456574 2.062028 2.0365305 1.901286 1.5255247 1.1421617] + [2.0621 2.2933435 2.1257985 2.1095486 1.5584903 1.1229166] + [1.7746235 2.2446113 1.7870374 1.8216239 1.557919 0.9832508] + [2.0887792 2.5433128 1.9749291 2.2580051 1.6096935 1.264905 ]]] + + + [[[2.175818 2.0094342 2.005763 1.6559253 1.3896458 1.4036925] + [2.1342552 1.8239582 1.6091168 1.434404 1.671778 1.7397764] + [1.930626 1.659667 1.6508744 1.3305787 1.4061482 2.0829628] + [2.130649 1.6637266 1.594426 1.2636002 1.7168686 1.8598001] + [1.9009514 1.7938274 1.4870623 1.6193901 1.5297288 2.0247464] + [2.0905268 1.7598859 1.9362347 1.9513799 1.9403584 2.1483061]]]] +``` +If `ITEX_VERBOSE=1` is set, the log looks like this: +``` +I itex/core/compiler/xla/service/hlo_pass_pipeline.cc:301] Running HLO pass pipeline on module jit_lax_conv: optimization +I itex/core/compiler/xla/service/hlo_pass_pipeline.cc:181] HLO pass fusion +I itex/core/compiler/xla/service/hlo_pass_pipeline.cc:181] HLO pass fusion_merger +I itex/core/compiler/xla/service/hlo_pass_pipeline.cc:181] HLO pass multi_output_fusion +I itex/core/compiler/xla/service/hlo_pass_pipeline.cc:181] HLO pass gpu-conv-rewriter +I itex/core/compiler/xla/service/hlo_pass_pipeline.cc:181] HLO pass onednn-fused-convolution-rewriter + +I itex/core/compiler/xla/service/gpu/gpu_compiler.cc:1221] Build kernel via LLVM kernel compilation. +I itex/core/compiler/xla/service/gpu/spir_compiler.cc:255] CompileTargetBinary - CompileToSpir time: 11 us (cumulative: 99.2 ms, max: 74.9 ms, #called: 8) + +I itex/core/compiler/xla/pjrt/pjrt_stream_executor_client.cc:2201] Executing computation jit_lax_conv; num_replicas=1 num_partitions=1 num_addressable_devices=1 +I itex/core/compiler/xla/pjrt/pjrt_stream_executor_client.cc:2268] Replicated execution complete. +I itex/core/compiler/xla/pjrt/pjrt_stream_executor_client.cc:1208] PjRtStreamExecutorBuffer::Delete +I itex/core/compiler/xla/pjrt/pjrt_stream_executor_client.cc:1299] PjRtStreamExecutorBuffer::ToLiteral ``` -Check it runs on GPU but not CPU. For example, "onednn_verbose,exec,**gpu**:0,matmul, ..." means "matmul" runs on GPU. **4. More JAX examples.** Get examples from [https://github.com/google/jax](https://github.com/google/jax/tree/jaxlib-v0.4.4/examples) to run. diff --git a/itex/core/compiler/xla/service/gpu/onednn_fused_conv_rewriter.cc b/itex/core/compiler/xla/service/gpu/onednn_fused_conv_rewriter.cc index 661116339..9ee5ae2f3 100644 --- a/itex/core/compiler/xla/service/gpu/onednn_fused_conv_rewriter.cc +++ b/itex/core/compiler/xla/service/gpu/onednn_fused_conv_rewriter.cc @@ -179,7 +179,7 @@ StatusOr FuseConvertToFloat(HloComputation* comp) { if (!Match(instr, pattern)) { continue; } - if (!ConsumeFuel("cudnn-fused-convolution-rewriter", [&] { + if (!ConsumeFuel("onednn-fused-convolution-rewriter", [&] { return absl::StrCat("FuseConvertToFloat: ", conv->ToString()); })) { continue; @@ -229,7 +229,7 @@ StatusOr FuseConvAlpha(HloComputation* comp) { if (config.conv_result_scale() != 1) { continue; } - if (!ConsumeFuel("cudnn-fused-convolution-rewriter", [&] { + if (!ConsumeFuel("onednn-fused-convolution-rewriter", [&] { return absl::StrCat("FuseConvAlpha: ", conv->ToString()); })) { continue; @@ -327,7 +327,7 @@ StatusOr FuseBiasOrSideInput(HloComputation* comp) { continue; } - if (!ConsumeFuel("cudnn-fused-convolution-rewriter", [&] { + if (!ConsumeFuel("onednn-fused-convolution-rewriter", [&] { return absl::StrCat("FuseBiasOrSideInput: ", conv->ToString()); })) { continue; @@ -401,7 +401,7 @@ StatusOr FuseSideInputAlpha(HloComputation* comp) { }))))) { continue; } - if (!ConsumeFuel("cudnn-fused-convolution-rewriter", [&] { + if (!ConsumeFuel("onednn-fused-convolution-rewriter", [&] { return absl::StrCat("FuseSideInputAlpha: ", conv->ToString()); })) { continue; @@ -481,7 +481,7 @@ StatusOr FuseRelu(HloComputation* comp) { continue; } - if (!ConsumeFuel("cudnn-fused-convolution-rewriter", [&] { + if (!ConsumeFuel("onednn-fused-convolution-rewriter", [&] { return absl::StrCat("FuseRelu: ", conv->ToString()); })) { continue; @@ -524,7 +524,7 @@ StatusOr FuseConvertToF16(HloComputation* comp) { 0, m::GetTupleElement(m::Op().WithPredicate(IsConvCustomCall)))); continue; } - if (!ConsumeFuel("cudnn-fused-convolution-rewriter", [&] { + if (!ConsumeFuel("onednn-fused-convolution-rewriter", [&] { return absl::StrCat("FuseConvertToF16: ", conv->ToString()); })) { continue; @@ -609,7 +609,7 @@ StatusOr FuseConvertToS8(HloComputation* comp) { } else { continue; } - if (!ConsumeFuel("cudnn-fused-convolution-rewriter", [&] { + if (!ConsumeFuel("onednn-fused-convolution-rewriter", [&] { return absl::StrCat("FuseConvertToS8: ", conv->ToString()); })) { continue; diff --git a/itex/core/compiler/xla/service/gpu/onednn_fused_conv_rewriter.h b/itex/core/compiler/xla/service/gpu/onednn_fused_conv_rewriter.h index 04411bbb3..7e112b40c 100644 --- a/itex/core/compiler/xla/service/gpu/onednn_fused_conv_rewriter.h +++ b/itex/core/compiler/xla/service/gpu/onednn_fused_conv_rewriter.h @@ -97,7 +97,7 @@ namespace gpu { class CudnnFusedConvRewriter : public HloModulePass { public: absl::string_view name() const override { - return "cudnn-fused-convolution-rewriter"; + return "onednn-fused-convolution-rewriter"; } StatusOr Run(HloModule* module) override;