From 6e7c76a368ea676f66ebe0380f8d128da63284fd Mon Sep 17 00:00:00 2001 From: Rob Suderman Date: Tue, 30 Jul 2024 11:08:10 -0700 Subject: [PATCH] Reland non-splat `tensor.from_elements` to `flow` https://github.com/iree-org/iree/issues/17086 We can support multi-element `tensor.from_elements` by using `flow.tensor.store` to store each element into an empty tensor. --- .../onnx_cpu_llvm_sync.json | 3 -- .../external_test_suite/onnx_gpu_cuda.json | 7 ++--- .../onnx_gpu_rocm_rdna3.json | 10 +++---- .../external_test_suite/onnx_gpu_vulkan.json | 6 ++-- .../Flow/Conversion/TensorToFlow/Patterns.cpp | 30 ++++++++++++++----- .../TensorToFlow/test/from_elements.mlir | 30 +++++++++++-------- 6 files changed, 50 insertions(+), 36 deletions(-) diff --git a/build_tools/pkgci/external_test_suite/onnx_cpu_llvm_sync.json b/build_tools/pkgci/external_test_suite/onnx_cpu_llvm_sync.json index 8d461e7f73814..3646f61216bf1 100644 --- a/build_tools/pkgci/external_test_suite/onnx_cpu_llvm_sync.json +++ b/build_tools/pkgci/external_test_suite/onnx_cpu_llvm_sync.json @@ -86,16 +86,13 @@ "onnx/node/generated/test_castlike_STRING_to_FLOAT_expanded", "onnx/node/generated/test_center_crop_pad_crop", "onnx/node/generated/test_center_crop_pad_crop_and_pad", - "onnx/node/generated/test_center_crop_pad_crop_and_pad_expanded", "onnx/node/generated/test_center_crop_pad_crop_axes_chw", "onnx/node/generated/test_center_crop_pad_crop_axes_chw_expanded", "onnx/node/generated/test_center_crop_pad_crop_axes_hwc", "onnx/node/generated/test_center_crop_pad_crop_axes_hwc_expanded", - "onnx/node/generated/test_center_crop_pad_crop_expanded", "onnx/node/generated/test_center_crop_pad_crop_negative_axes_hwc", "onnx/node/generated/test_center_crop_pad_crop_negative_axes_hwc_expanded", "onnx/node/generated/test_center_crop_pad_pad", - "onnx/node/generated/test_center_crop_pad_pad_expanded", "onnx/node/generated/test_col2im", "onnx/node/generated/test_col2im_5d", "onnx/node/generated/test_col2im_dilations", diff --git 
a/build_tools/pkgci/external_test_suite/onnx_gpu_cuda.json b/build_tools/pkgci/external_test_suite/onnx_gpu_cuda.json index a41a6e0d4e223..94c9f544235d0 100644 --- a/build_tools/pkgci/external_test_suite/onnx_gpu_cuda.json +++ b/build_tools/pkgci/external_test_suite/onnx_gpu_cuda.json @@ -13,6 +13,8 @@ "skip_run_tests": [ "onnx/node/generated/test_gather_elements_negative_indices", "onnx/node/generated/test_gridsample_zeros_padding", + "onnx/node/generated/test_group_normalization_epsilon_expanded", + "onnx/node/generated/test_group_normalization_example_expanded", "onnx/node/generated/test_scatter_elements_with_negative_indices", "onnx/node/generated/test_resize_downsample_scales_linear_align_corners", "onnx/node/generated/test_resize_downsample_scales_linear_half_pixel_symmetric" @@ -92,16 +94,13 @@ "onnx/node/generated/test_castlike_STRING_to_FLOAT_expanded", "onnx/node/generated/test_center_crop_pad_crop", "onnx/node/generated/test_center_crop_pad_crop_and_pad", - "onnx/node/generated/test_center_crop_pad_crop_and_pad_expanded", "onnx/node/generated/test_center_crop_pad_crop_axes_chw", "onnx/node/generated/test_center_crop_pad_crop_axes_chw_expanded", "onnx/node/generated/test_center_crop_pad_crop_axes_hwc", "onnx/node/generated/test_center_crop_pad_crop_axes_hwc_expanded", - "onnx/node/generated/test_center_crop_pad_crop_expanded", "onnx/node/generated/test_center_crop_pad_crop_negative_axes_hwc", "onnx/node/generated/test_center_crop_pad_crop_negative_axes_hwc_expanded", "onnx/node/generated/test_center_crop_pad_pad", - "onnx/node/generated/test_center_crop_pad_pad_expanded", "onnx/node/generated/test_col2im", "onnx/node/generated/test_col2im_5d", "onnx/node/generated/test_col2im_dilations", @@ -158,9 +157,7 @@ "onnx/node/generated/test_gridsample_volumetric_nearest_align_corners_0", "onnx/node/generated/test_gridsample_volumetric_nearest_align_corners_1", "onnx/node/generated/test_group_normalization_epsilon", - 
"onnx/node/generated/test_group_normalization_epsilon_expanded", "onnx/node/generated/test_group_normalization_example", - "onnx/node/generated/test_group_normalization_example_expanded", "onnx/node/generated/test_gru_batchwise", "onnx/node/generated/test_gru_defaults", "onnx/node/generated/test_gru_seq_length", diff --git a/build_tools/pkgci/external_test_suite/onnx_gpu_rocm_rdna3.json b/build_tools/pkgci/external_test_suite/onnx_gpu_rocm_rdna3.json index c6a9a35bbb95f..df5c8ff728715 100644 --- a/build_tools/pkgci/external_test_suite/onnx_gpu_rocm_rdna3.json +++ b/build_tools/pkgci/external_test_suite/onnx_gpu_rocm_rdna3.json @@ -11,7 +11,10 @@ "skip_compile_tests": [ "onnx/node/generated/test_dequantizelinear" ], - "skip_run_tests": [], + "skip_run_tests": [ + "onnx/node/generated/test_group_normalization_epsilon_expanded", + "onnx/node/generated/test_group_normalization_example_expanded" + ], "expected_compile_failures": [ "onnx/node/generated/test_adagrad", "onnx/node/generated/test_adagrad_multiple", @@ -87,16 +90,13 @@ "onnx/node/generated/test_castlike_STRING_to_FLOAT_expanded", "onnx/node/generated/test_center_crop_pad_crop", "onnx/node/generated/test_center_crop_pad_crop_and_pad", - "onnx/node/generated/test_center_crop_pad_crop_and_pad_expanded", "onnx/node/generated/test_center_crop_pad_crop_axes_chw", "onnx/node/generated/test_center_crop_pad_crop_axes_chw_expanded", "onnx/node/generated/test_center_crop_pad_crop_axes_hwc", "onnx/node/generated/test_center_crop_pad_crop_axes_hwc_expanded", - "onnx/node/generated/test_center_crop_pad_crop_expanded", "onnx/node/generated/test_center_crop_pad_crop_negative_axes_hwc", "onnx/node/generated/test_center_crop_pad_crop_negative_axes_hwc_expanded", "onnx/node/generated/test_center_crop_pad_pad", - "onnx/node/generated/test_center_crop_pad_pad_expanded", "onnx/node/generated/test_col2im", "onnx/node/generated/test_col2im_5d", "onnx/node/generated/test_col2im_dilations", @@ -153,9 +153,7 @@ 
"onnx/node/generated/test_gridsample_volumetric_nearest_align_corners_0", "onnx/node/generated/test_gridsample_volumetric_nearest_align_corners_1", "onnx/node/generated/test_group_normalization_epsilon", - "onnx/node/generated/test_group_normalization_epsilon_expanded", "onnx/node/generated/test_group_normalization_example", - "onnx/node/generated/test_group_normalization_example_expanded", "onnx/node/generated/test_gru_batchwise", "onnx/node/generated/test_gru_defaults", "onnx/node/generated/test_gru_seq_length", diff --git a/build_tools/pkgci/external_test_suite/onnx_gpu_vulkan.json b/build_tools/pkgci/external_test_suite/onnx_gpu_vulkan.json index bf88b192f71c0..2acf8ac9547e9 100644 --- a/build_tools/pkgci/external_test_suite/onnx_gpu_vulkan.json +++ b/build_tools/pkgci/external_test_suite/onnx_gpu_vulkan.json @@ -11,7 +11,9 @@ "onnx/node/generated/test_dequantizelinear" ], "skip_run_tests": [ - "onnx/node/generated/test_det_nd" + "onnx/node/generated/test_det_nd", + "onnx/node/generated/test_group_normalization_epsilon_expanded", + "onnx/node/generated/test_group_normalization_example_expanded" ], "expected_compile_failures": [ "onnx/node/generated/test_adagrad", @@ -187,9 +189,7 @@ "onnx/node/generated/test_gridsample_volumetric_nearest_align_corners_1", "onnx/node/generated/test_gridsample_zeros_padding", "onnx/node/generated/test_group_normalization_epsilon", - "onnx/node/generated/test_group_normalization_epsilon_expanded", "onnx/node/generated/test_group_normalization_example", - "onnx/node/generated/test_group_normalization_example_expanded", "onnx/node/generated/test_gru_batchwise", "onnx/node/generated/test_gru_defaults", "onnx/node/generated/test_gru_seq_length", diff --git a/compiler/src/iree/compiler/Dialect/Flow/Conversion/TensorToFlow/Patterns.cpp b/compiler/src/iree/compiler/Dialect/Flow/Conversion/TensorToFlow/Patterns.cpp index 0a033622ac5e9..d5794afd504e3 100644 --- a/compiler/src/iree/compiler/Dialect/Flow/Conversion/TensorToFlow/Patterns.cpp 
+++ b/compiler/src/iree/compiler/Dialect/Flow/Conversion/TensorToFlow/Patterns.cpp @@ -187,17 +187,33 @@ struct ConvertTensorFromElementsPattern } auto tensorType = op.getType(); if (!tensorType.hasRank()) { - return failure(); + return rewriter.notifyMatchFailure(op, + "unranked result type not supported"); } - // Check that all the dimensions are 1. - if (!llvm::all_of(tensorType.getShape(), - [](int64_t dim) { return dim == 1; })) { - return failure(); + if (op.getNumOperands() == 1) { + rewriter.replaceOpWithNewOp( + op, tensorType, op.getOperand(0), ValueRange()); + return success(); } - rewriter.replaceOpWithNewOp( - op, tensorType, op.getOperand(0), ValueRange()); + const int64_t rank = tensorType.getRank(); + Value result = rewriter.create( + op.getLoc(), tensorType.getShape(), tensorType.getElementType()); + SmallVector ivs(rank); + for (int i = 0, s = op.getNumOperands(); i < s; ++i) { + int64_t index = i; + for (int j = rank - 1; j >= 0; --j) { + int64_t iv = index % tensorType.getDimSize(j); + index = index / tensorType.getDimSize(j); + ivs[j] = rewriter.create(op.getLoc(), iv); + } + + result = rewriter.create( + op.getLoc(), op.getOperand(i), result, ivs); + } + + rewriter.replaceOp(op, result); return success(); } }; diff --git a/compiler/src/iree/compiler/Dialect/Flow/Conversion/TensorToFlow/test/from_elements.mlir b/compiler/src/iree/compiler/Dialect/Flow/Conversion/TensorToFlow/test/from_elements.mlir index 13489f8038c9e..087d0e3159b58 100644 --- a/compiler/src/iree/compiler/Dialect/Flow/Conversion/TensorToFlow/test/from_elements.mlir +++ b/compiler/src/iree/compiler/Dialect/Flow/Conversion/TensorToFlow/test/from_elements.mlir @@ -9,18 +9,6 @@ util.func public @tensor.from_elements__to__flow.tensor.splat(%arg0: i8) -> (i8) util.return %1 : i8 } -// ----- -// CHECK: util.func public @tensor.from_elements__not_convertible(%[[arg0:.*]]: i8) -util.func public @tensor.from_elements__not_convertible(%arg0: i8) -> (i8) { - // CHECK: %[[c0:.*]] = 
arith.constant 0 - %c0 = arith.constant 0 : index - // CHECK: %[[res:.*]] = tensor.from_elements %[[arg0]], %[[arg0]] : tensor<2xi8> - %0 = tensor.from_elements %arg0, %arg0 : tensor<2xi8> - // CHECK: flow.tensor.load %[[res]][%[[c0]]] - %1 = flow.tensor.load %0[%c0] : tensor<2xi8> - util.return %1 : i8 -} - // ----- util.func public @tensor.from_elements__within_dispatch_workgroups_not_converted() -> tensor { %x = arith.constant 100 : index @@ -44,3 +32,21 @@ util.func public @tensor.from_elements_0D(%arg0 : f32) -> tensor { // CHECK-SAME: %[[ARG0:.+]]: f32 // CHECK: %[[SPLAT:.+]] = flow.tensor.splat %[[ARG0]] : tensor // CHECK: util.return %[[SPLAT]] + +// ----- + +// CHECK-LABEL: util.func public @tensor.from_elements_2D +util.func @tensor.from_elements_2D(%arg0 : f32, %arg1 : f32, %arg2 : f32, %arg3 : f32, %arg4 : f32, %arg5 : f32) -> tensor<2x3xf32> { + %0 = tensor.from_elements %arg0, %arg1, %arg2, %arg3, %arg4, %arg5 : tensor<2x3xf32> + // CHECK-DAG: %[[C2:.+]] = arith.constant 2 : index + // CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index + // CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index + // CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<2x3xf32> + // CHECK: %[[STORE0:.+]] = flow.tensor.store %arg0, %[[EMPTY]][%[[C0]], %[[C0]]] : tensor<2x3xf32> + // CHECK: %[[STORE1:.+]] = flow.tensor.store %arg1, %[[STORE0]][%[[C0]], %[[C1]]] : tensor<2x3xf32> + // CHECK: %[[STORE2:.+]] = flow.tensor.store %arg2, %[[STORE1]][%[[C0]], %[[C2]]] : tensor<2x3xf32> + // CHECK: %[[STORE3:.+]] = flow.tensor.store %arg3, %[[STORE2]][%[[C1]], %[[C0]]] : tensor<2x3xf32> + // CHECK: %[[STORE4:.+]] = flow.tensor.store %arg4, %[[STORE3]][%[[C1]], %[[C1]]] : tensor<2x3xf32> + // CHECK: %[[STORE5:.+]] = flow.tensor.store %arg5, %[[STORE4]][%[[C1]], %[[C2]]] : tensor<2x3xf32> + util.return %0 : tensor<2x3xf32> +}