From a4a66c8026e4b4b717714588d681af65e9b90630 Mon Sep 17 00:00:00 2001 From: yangguohao <70266361+yangguohao@users.noreply.github.com> Date: Wed, 29 Nov 2023 15:18:41 +0800 Subject: [PATCH] =?UTF-8?q?=E3=80=90PIR=E3=80=91add=20batch=5Fnorm=5Fgrad?= =?UTF-8?q?=5Fgrad=20in=20pir=20(#59373)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * add batch_norm_grad_grad in pir * fix * fix 2023-11-27 --- paddle/fluid/ir_adaptor/translator/op_compat_gen.py | 11 +++++++++++ paddle/phi/api/yaml/op_compat.yaml | 2 +- 2 files changed, 12 insertions(+), 1 deletion(-) diff --git a/paddle/fluid/ir_adaptor/translator/op_compat_gen.py b/paddle/fluid/ir_adaptor/translator/op_compat_gen.py index e0effa77bb05b0..e30a6770554f69 100644 --- a/paddle/fluid/ir_adaptor/translator/op_compat_gen.py +++ b/paddle/fluid/ir_adaptor/translator/op_compat_gen.py @@ -136,6 +136,17 @@ def insert_new_mutable_attributes( "grad_y_grad": "DDY", "grad_out_grad": "DDOut", } + op_arg_name_mappings["batch_norm_grad_grad"] = { + "scale_grad": "DScale", + "x_grad": "DX", + "grad_out_grad": "DDY", + "out_mean": "OutMean", + "out_variance": "OutVariance", + "grad_x_grad": "DDX", + "grad_scale_grad": "DDScale", + "grad_bias_grad": "DDBias", + "grad_out": "DY", + } op_name_normailzer_template = env.get_template("op_compat_info.cc.j2") with open(output_source_file, 'wt') as f: diff --git a/paddle/phi/api/yaml/op_compat.yaml b/paddle/phi/api/yaml/op_compat.yaml index 67551f9dd608f5..553df312fdec77 100755 --- a/paddle/phi/api/yaml/op_compat.yaml +++ b/paddle/phi/api/yaml/op_compat.yaml @@ -318,7 +318,7 @@ {auc : AUC, stat_pos_out : StatPosOut, stat_neg_out : StatNegOut} - op : batch_norm - backward : batch_norm_grad + backward : batch_norm_grad, batch_norm_double_grad(batch_norm_grad_grad) inputs: x : X mean : Mean