diff --git a/paddle/fluid/pir/dialect/op_generator/decomp_interface_gen_op_list.py b/paddle/fluid/pir/dialect/op_generator/decomp_interface_gen_op_list.py index ce4c0a1cd094fd..5a1e8f361bd626 100644 --- a/paddle/fluid/pir/dialect/op_generator/decomp_interface_gen_op_list.py +++ b/paddle/fluid/pir/dialect/op_generator/decomp_interface_gen_op_list.py @@ -24,11 +24,14 @@ "dropout", "full_like", "gelu", + "instance_norm", "layer_norm", + "leaky_relu", "mean", "pow", "relu", "rsqrt", + "sigmoid", "silu", "softmax", "sqrt", @@ -44,11 +47,14 @@ "dropout", "full_like", "gelu", + "instance_norm", "layer_norm", + "leaky_relu", "mean", "pow", "relu", "rsqrt", + "sigmoid", "silu", "softmax", "sqrt", diff --git a/paddle/fluid/pir/dialect/operator/ir/ops_backward.yaml b/paddle/fluid/pir/dialect/operator/ir/ops_backward.yaml index 90221982ebbddf..71124cf5593960 100644 --- a/paddle/fluid/pir/dialect/operator/ir/ops_backward.yaml +++ b/paddle/fluid/pir/dialect/operator/ir/ops_backward.yaml @@ -344,6 +344,7 @@ kernel : func : hardswish_grad inplace : (out_grad -> x_grad) + composite : hardswish_grad(x, out_grad, x_grad) - backward_op : hsigmoid_loss_grad forward : hsigmoid_loss (Tensor x, Tensor label, Tensor w, Tensor bias, Tensor path, Tensor code, int num_classes, bool is_sparse) -> Tensor(out), Tensor(pre_out), Tensor(w_out) diff --git a/paddle/fluid/primitive/composite/composite.h b/paddle/fluid/primitive/composite/composite.h index 1ab1f33f4f5f68..cb3366229a92f4 100644 --- a/paddle/fluid/primitive/composite/composite.h +++ b/paddle/fluid/primitive/composite/composite.h @@ -528,6 +528,115 @@ Tensor gelu_decomp(const Tensor& x, bool approximate) { } } +template +Tensor sigmoid_decomp(const Tensor& x) { + auto org_dtype = x.dtype(); + Tensor x_cast = x; + + bool need_cast = is_half_dtype(org_dtype); + if (need_cast) { + x_cast = cast(x, phi::DataType::FLOAT32); + } + + // res = 1 / (1 + exp(-x)) + auto one = full(common::vectorize(x_cast.dims()), 1, x_cast.dtype()); + auto exp_tmp = exp( + full(common::vectorize(x_cast.dims()), -1, x_cast.dtype()) * x_cast); + auto res = one / (one + exp_tmp); + if (need_cast) { + return cast(res, org_dtype); + } else { + return res; + } +} + +template +Tensor leaky_relu_decomp(const Tensor& x, float negative_slope) { + auto multiply_tmp = + full(phi::vectorize(x.dims()), negative_slope, x.dtype()) * x; + if (negative_slope < 1.0) { + return maximum(x, multiply_tmp); + } else { + return minimum(x, multiply_tmp); + } +} + +template +std::tuple instance_norm_decomp( + const Tensor& x, + const paddle::optional& scale, + const paddle::optional& bias, + float epsilon) { + auto org_dtype = x.dtype(); + Tensor x_cast = x; + + bool need_cast = is_half_dtype(org_dtype); + if (need_cast) { + x_cast = cast(x, phi::DataType::FLOAT32); + } + + std::vector axis; + auto x_dim = common::vectorize(x.dims()); + for (size_t i = 2; i < x_dim.size(); i++) { + axis.push_back(static_cast(i)); + } + + // out = (x - mean(x)) / sqrt(var + epsilon)) + // var = mean((x-mean(x))^2) + auto mean_ = mean_decomp(x_cast, IntArray(axis), true); + auto difference = x_cast - mean_; + auto var_tmp1 = difference * difference; + auto variance = mean_decomp(var_tmp1, IntArray(axis), true); + auto var_tmp3 = variance + epsilon; + auto rsqrt_var = elementwise_pow( + var_tmp3, + full(common::vectorize(var_tmp3.dims()), 0.5, var_tmp3.dtype())); + auto out = difference / rsqrt_var; + + auto scale_ptr = scale.get_ptr(); + auto bias_ptr = bias.get_ptr(); + std::vector slice_shape(x_dim.size(), 1); + slice_shape[1] = 
x_dim[1]; + + Tensor scale_cast; + if (scale_ptr) { + if (slice_shape != scale_ptr->shape()) { + scale_cast = reshape(*scale_ptr, slice_shape); + } else { + scale_cast = *scale_ptr; + } + if (need_cast) { + scale_cast = cast(scale_cast, phi::DataType::FLOAT32); + } + out = out * scale_cast; + } + Tensor bias_cast; + if (bias_ptr) { + if (slice_shape != bias_ptr->shape()) { + bias_cast = reshape(*bias_ptr, slice_shape); + } else { + bias_cast = *bias_ptr; + } + if (need_cast) { + bias_cast = cast(bias_cast, phi::DataType::FLOAT32); + } + out = out + bias_cast; + } + + std::vector res_shape(1, -1); + auto mean_out = reshape(mean_, res_shape); + auto variance_out = reshape(1 / rsqrt_var, res_shape); + + Tensor res; + if (need_cast) { + res = cast(out, org_dtype); + } else { + res = out; + } + + return std::make_tuple(res, mean_out, variance_out); +} + } // namespace details } // namespace primitive diff --git a/test/legacy_test/test_activation_op.py b/test/legacy_test/test_activation_op.py index 8a9379f528c1eb..7b2b70d7b515f6 100644 --- a/test/legacy_test/test_activation_op.py +++ b/test/legacy_test/test_activation_op.py @@ -390,7 +390,7 @@ def if_enable_cinn(self): pass def test_check_output(self): - self.check_output(check_pir=True) + self.check_output(check_pir=True, check_prim_pir=True) def test_check_grad(self): if self.dtype == np.float16: @@ -411,7 +411,7 @@ def init_dtype(self): def test_check_output(self): with paddle.static.scope_guard(paddle.static.Scope()): - self.check_output(check_prim=False) + self.check_output(check_prim=False, check_prim_pir=False) def test_check_grad(self): self.check_grad( @@ -420,6 +420,7 @@ def test_check_grad(self): max_relative_error=0.006, check_prim=False, check_pir=True, + check_prim_pir=False, ) @@ -428,7 +429,9 @@ def init_dtype(self): self.dtype = np.complex128 def test_check_grad(self): - self.check_grad(['X'], 'Out', check_prim=False, check_pir=True) + self.check_grad( + ['X'], 'Out', check_prim=False, check_pir=True, check_prim_pir=False + ) class TestSigmoid_ZeroDim(TestSigmoid): @@ -469,7 +472,9 @@ def if_enable_cinn(self): def test_check_output(self): place = core.CUDAPlace(0) - self.check_output_with_place(place, check_prim=True, check_pir=True) + self.check_output_with_place( + place, check_prim=True, check_pir=True, check_prim_pir=True + ) def test_check_grad(self): place = core.CUDAPlace(0) @@ -2555,7 +2560,7 @@ def if_enable_cinn(self): pass def test_check_output(self): - self.check_output(check_prim=True, check_pir=True) + self.check_output(check_prim=True, check_pir=True, check_prim_pir=True) def test_check_grad(self): if self.dtype == np.float16: @@ -3038,7 +3043,9 @@ def test_check_grad(self): else False, only_check_prim=self.if_only_check_prim(), check_pir=True, - check_prim_pir=True, + check_prim_pir=True + if self.dtype not in [np.complex64, np.complex128] + else False, ) def test_check_output(self): @@ -4832,7 +4839,11 @@ def test_check_grad(self): ) create_test_act_fp16_class(TestExpm1) create_test_act_fp16_class( - TestSigmoid, check_prim=True, enable_cinn=True, check_pir=True + TestSigmoid, + check_prim=True, + enable_cinn=True, + check_pir=True, + check_prim_pir=True, ) create_test_act_fp16_class( TestSilu, check_prim=True, enable_cinn=True, check_prim_pir=True @@ -4929,18 +4940,24 @@ def test_check_grad(self): create_test_act_fp16_class(TestHardSwish, check_prim=True, check_pir=True) create_test_act_fp16_class(TestMish, check_pir=True) create_test_act_fp16_class( - TestLeakyRelu, check_prim=True, enable_cinn=True, 
check_pir=True + TestLeakyRelu, + check_prim=True, + enable_cinn=True, + check_pir=True, + check_prim_pir=True, +) +create_test_act_fp16_class( + TestLeakyReluAlpha1, check_prim=True, enable_cinn=True, check_prim_pir=True ) create_test_act_fp16_class( - TestLeakyReluAlpha1, check_prim=True, enable_cinn=True + TestLeakyReluAlpha2, check_prim=True, enable_cinn=True, check_prim_pir=True ) create_test_act_fp16_class( - TestLeakyReluAlpha2, check_prim=True, enable_cinn=True + TestLeakyReluAlpha3, check_prim=True, enable_cinn=True, check_prim_pir=True ) create_test_act_fp16_class( - TestLeakyReluAlpha3, check_prim=True, enable_cinn=True + TestLeakyRelu_ZeroDim, check_prim=True, check_prim_pir=True ) -create_test_act_fp16_class(TestLeakyRelu_ZeroDim, check_prim=True) create_test_act_fp16_class( TestRsqrt, check_prim=True, @@ -5017,7 +5034,9 @@ def test_check_grad(self): TestExpFp32_Prim, check_prim=True, check_prim_pir=True ) create_test_act_bf16_class(TestExpm1) -create_test_act_bf16_class(TestSigmoid, check_prim=True, check_pir=True) +create_test_act_bf16_class( + TestSigmoid, check_prim=True, check_pir=True, check_prim_pir=True +) create_test_act_bf16_class(TestSilu, check_prim=True, check_prim_pir=True) create_test_act_bf16_class(TestLogSigmoid) create_test_act_bf16_class(TestTanh, check_prim=True, check_prim_pir=True) @@ -5089,11 +5108,21 @@ def test_check_grad(self): create_test_act_bf16_class(TestSwish) create_test_act_bf16_class(TestHardSwish, check_prim=True, check_pir=True) create_test_act_bf16_class(TestMish, check_pir=True) -create_test_act_bf16_class(TestLeakyRelu, check_prim=True, check_pir=True) -create_test_act_bf16_class(TestLeakyReluAlpha1, check_prim=True) -create_test_act_bf16_class(TestLeakyReluAlpha2, check_prim=True) -create_test_act_bf16_class(TestLeakyReluAlpha3, check_prim=True) -create_test_act_bf16_class(TestLeakyRelu_ZeroDim, check_prim=True) +create_test_act_bf16_class( + TestLeakyRelu, check_prim=True, check_pir=True, check_prim_pir=True +) +create_test_act_bf16_class( + TestLeakyReluAlpha1, check_prim=True, check_prim_pir=True +) +create_test_act_bf16_class( + TestLeakyReluAlpha2, check_prim=True, check_prim_pir=True +) +create_test_act_bf16_class( + TestLeakyReluAlpha3, check_prim=True, check_prim_pir=True +) +create_test_act_bf16_class( + TestLeakyRelu_ZeroDim, check_prim=True, check_prim_pir=True +) create_test_act_bf16_class( TestRsqrt, check_prim=True, check_pir=True, check_prim_pir=True ) diff --git a/test/legacy_test/test_instance_norm_op.py b/test/legacy_test/test_instance_norm_op.py index c5fd7af6b48799..3ac10a9547d5c8 100644 --- a/test/legacy_test/test_instance_norm_op.py +++ b/test/legacy_test/test_instance_norm_op.py @@ -130,7 +130,7 @@ def setUp(self): } def test_check_output(self): - self.check_output(check_prim=True, check_pir=True) + self.check_output(check_prim=True, check_pir=True, check_prim_pir=True) def test_check_grad(self): self.check_grad( diff --git a/test/legacy_test/test_instance_norm_op_v2.py b/test/legacy_test/test_instance_norm_op_v2.py index fe8e26aaec7839..90641cc20ef8df 100644 --- a/test/legacy_test/test_instance_norm_op_v2.py +++ b/test/legacy_test/test_instance_norm_op_v2.py @@ -220,7 +220,12 @@ def setUp(self): def test_check_output(self): self.check_output( - atol=self.atol, check_prim=self.check_prim, check_pir=True + atol=self.atol, + check_prim=self.check_prim, + check_pir=True, + check_prim_pir=False + if os.getenv("FLAGS_enable_pir_in_executor") + else True, ) def test_check_grad(self): @@ -275,7 +280,13 @@ def 
set_err_thre(self): def test_check_output(self): place = core.CUDAPlace(0) self.check_output_with_place( - place, atol=self.atol, check_prim=self.check_prim, check_pir=True + place, + atol=self.atol, + check_prim=self.check_prim, + check_pir=True, + check_prim_pir=False + if os.getenv("FLAGS_enable_pir_in_executor") + else True, ) def test_check_grad(self): @@ -350,7 +361,12 @@ def init_shape(self): def test_check_output(self): place = core.CUDAPlace(0) self.check_output_with_place( - place, check_prim=self.check_prim, check_pir=True + place, + check_prim=self.check_prim, + check_pir=True, + check_prim_pir=False + if os.getenv("FLAGS_enable_pir_in_executor") + else True, ) def test_check_grad(self):
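
Reviewer note, not part of the patch: the new composite rules in paddle/fluid/primitive/composite/composite.h lower sigmoid, leaky_relu, and instance_norm into primitive ops (exp, maximum/minimum, mean, elementwise_pow, reshape). The following is a minimal NumPy sketch of the same math, handy for eyeballing the decompositions against the framework kernels. The function names and the epsilon default are illustrative assumptions, not APIs introduced by this patch.

    # Standalone reference sketch (assumption: NumPy only, names are illustrative).
    import numpy as np

    def sigmoid_ref(x):
        # Mirrors sigmoid_decomp: res = 1 / (1 + exp(-x)).
        return 1.0 / (1.0 + np.exp(-x))

    def leaky_relu_ref(x, negative_slope):
        # Mirrors leaky_relu_decomp: maximum(x, slope * x) when slope < 1,
        # minimum(x, slope * x) otherwise.
        scaled = negative_slope * x
        return np.maximum(x, scaled) if negative_slope < 1.0 else np.minimum(x, scaled)

    def instance_norm_ref(x, scale=None, bias=None, epsilon=1e-5):
        # Mirrors instance_norm_decomp: normalize over the spatial axes (2..N-1)
        # per sample and per channel; out = (x - mean) / sqrt(var + epsilon).
        axes = tuple(range(2, x.ndim))
        mean = x.mean(axis=axes, keepdims=True)
        var = ((x - mean) ** 2).mean(axis=axes, keepdims=True)
        std = np.sqrt(var + epsilon)
        out = (x - mean) / std
        # scale/bias are broadcast with shape [1, C, 1, ...], as slice_shape does.
        bshape = [1, x.shape[1]] + [1] * (x.ndim - 2)
        if scale is not None:
            out = out * scale.reshape(bshape)
        if bias is not None:
            out = out + bias.reshape(bshape)
        # Saved mean and saved inverse std (1 / sqrt(var + epsilon)) are flattened
        # to 1-D, matching the reshape(..., {-1}) outputs of the decomp.
        return out, mean.reshape(-1), (1.0 / std).reshape(-1)

Usage is the obvious one: feed the same random inputs to these helpers and to the corresponding paddle.nn.functional calls and compare within the tolerances used by the updated tests (check_prim_pir=True paths in test_activation_op.py and the instance_norm tests).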