From db4feb2f4f0be84867a637b61601364fa89b9cf2 Mon Sep 17 00:00:00 2001
From: liuyuang
Date: Thu, 10 Aug 2023 07:53:11 +0800
Subject: [PATCH 1/2] fix fused linear grad add ut

---
 .../test_fused_linear_param_grad_add.py      | 62 ++++++++++++-------
 tools/gpups_test.sh                          |  2 +-
 2 files changed, 41 insertions(+), 23 deletions(-)

diff --git a/test/legacy_test/test_fused_linear_param_grad_add.py b/test/legacy_test/test_fused_linear_param_grad_add.py
index e707bbc41fa2a..6f98745e68827 100644
--- a/test/legacy_test/test_fused_linear_param_grad_add.py
+++ b/test/legacy_test/test_fused_linear_param_grad_add.py
@@ -54,7 +54,7 @@ def recreate(x, multi_precision):
     return paddle.to_tensor(x.numpy())
 
 
-def run_ground_truth(x, dy, dweight, dbias, multi_precision):
+def run_ground_truth(x, dy, dweight, dbias, multi_precision, has_bias):
     x, dy, dweight, dbias = recreate([x, dy, dweight, dbias], multi_precision)
 
     dweight_tmp = paddle.matmul(
@@ -69,24 +69,35 @@ def run_ground_truth(x, dy, dweight, dbias, multi_precision):
         assert dweight.dtype == dweight.dtype
         dweight += dweight_tmp
 
-    dbias_tmp = dy.reshape([-1, dy.shape[-1]]).sum(axis=0)
-    if dbias is None:
-        dbias = dbias_tmp
-    else:
-        assert dbias.shape == dbias_tmp.shape
-        assert dbias.dtype == dbias_tmp.dtype
-        dbias += dbias_tmp
+    if has_bias:
+        dbias_tmp = dy.reshape([-1, dy.shape[-1]]).sum(axis=0)
+        if dbias is None:
+            dbias = dbias_tmp
+        else:
+            assert dbias.shape == dbias_tmp.shape
+            assert dbias.dtype == dbias_tmp.dtype
+            dbias += dbias_tmp
 
-    return promote_dtype(dweight).numpy(), promote_dtype(dbias).numpy()
+        return promote_dtype(dweight).numpy(), promote_dtype(dbias).numpy()
+    else:
+        return promote_dtype(dweight).numpy()
 
 
-def run_fused_linear_param_grad_add(x, dy, dweight, dbias, multi_precision):
+def run_fused_linear_param_grad_add(
+    x, dy, dweight, dbias, multi_precision, has_bias
+):
     dweight_new, dbias_new = _C_ops.fused_linear_param_grad_add(
-        x, dy, dweight, dbias, multi_precision
+        x, dy, dweight, dbias, multi_precision, has_bias
     )
     if dweight is not None:
         assert dweight_new.data_ptr() == dweight.data_ptr()
-    return promote_dtype(dweight_new).numpy(), promote_dtype(dbias_new).numpy()
+    if has_bias:
+        return (
+            promote_dtype(dweight_new).numpy(),
+            promote_dtype(dbias_new).numpy(),
+        )
+    else:
+        return promote_dtype(dweight_new).numpy()
 
 
 class TestMainClassBase(unittest.TestCase):
@@ -103,7 +114,9 @@ def rand(self, shape, dtype=None):
         x = paddle.to_tensor(x)
         return x.astype(dtype or self.dtype)
 
-    def generate_rand_inputs(self, has_dweight, has_dbias, multi_precision):
+    def generate_rand_inputs(
+        self, has_dweight, has_dbias, multi_precision, has_bias
+    ):
         x_shape = self.shape
         dy_shape = self.shape[:-1] + [self.output_size]
         dweight_shape = [self.shape[-1], self.output_size]
@@ -118,7 +131,7 @@ def generate_rand_inputs(
         else:
             dweight = None
 
-        if has_dbias:
+        if has_bias and has_dbias:
             dbias = self.rand(dbias_shape)
             if multi_precision:
                 dbias = promote_dtype(dbias)
@@ -126,14 +139,16 @@ def generate_rand_inputs(
             dbias = None
         return x, dy, dweight, dbias
 
-    def check_main(self, has_dweight, has_dbias, multi_precision):
-        print(has_dweight, has_dbias, multi_precision)
+    def check_main(self, has_dweight, has_dbias, multi_precision, has_bias):
+        print(has_dweight, has_dbias, multi_precision, has_bias)
         x, dy, dweight, dbias = self.generate_rand_inputs(
-            has_dweight, has_dbias, multi_precision
+            has_dweight, has_dbias, multi_precision, has_bias
+        )
+        res1 = run_ground_truth(
+            x, dy, dweight, dbias, multi_precision, has_bias
         )
-        res1 = run_ground_truth(x, dy, dweight, dbias, multi_precision)
         res2 = run_fused_linear_param_grad_add(
-            x, dy, dweight, dbias, multi_precision
+            x, dy, dweight, dbias, multi_precision, has_bias
         )
         self.assertEqual(len(res1), len(res2))
         for r1, r2 in zip(res1, res2):
@@ -153,9 +168,12 @@ def test_main(self):
             return
 
         for has_dweight in [False, True]:
-            for has_dbias in [False, True]:
-                for multi_precision in [False, True]:
-                    self.check_main(has_dweight, has_dbias, multi_precision)
+            for has_bias in [False, True]:
+                for has_dbias in [False, True]:
+                    for multi_precision in [False, True]:
+                        self.check_main(
+                            has_dweight, has_dbias, multi_precision, has_bias
+                        )
 
 
 class TestMainClassBF16(TestMainClassBase):
diff --git a/tools/gpups_test.sh b/tools/gpups_test.sh
index a833e48efc3bd..d0f9dd19341bd 100644
--- a/tools/gpups_test.sh
+++ b/tools/gpups_test.sh
@@ -12,7 +12,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-# Disable Test list: test_fused_linear_param_grad_add
 
 serial_list="^test_conv2d_op$|\
 ^test_conv2d_transpose_op$|\
@@ -69,6 +68,7 @@ parallel_list="^init_phi_test$|\
 ^test_fused_gemm_epilogue_op$|\
 ^test_fused_gemm_epilogue_op_with_es$|\
 ^test_fused_layernorm_residual_dropout_bias$|\
+^test_fused_linear_param_grad_add$|\
 ^test_fused_linear_pass$|\
 ^test_fused_matmul_bias$|\
 ^test_fused_multi_transformer_decoder_pass$|\

From 5b75adc031080d6332225e1f011ab6703dd1ced4 Mon Sep 17 00:00:00 2001
From: liuyuang
Date: Thu, 10 Aug 2023 08:28:10 +0800
Subject: [PATCH 2/2] fix ci

---
 test/legacy_test/test_fused_linear_param_grad_add.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/test/legacy_test/test_fused_linear_param_grad_add.py b/test/legacy_test/test_fused_linear_param_grad_add.py
index 6f98745e68827..762b2a99b52e9 100644
--- a/test/legacy_test/test_fused_linear_param_grad_add.py
+++ b/test/legacy_test/test_fused_linear_param_grad_add.py
@@ -140,7 +140,6 @@ def generate_rand_inputs(
         return x, dy, dweight, dbias
 
     def check_main(self, has_dweight, has_dbias, multi_precision, has_bias):
-        print(has_dweight, has_dbias, multi_precision, has_bias)
         x, dy, dweight, dbias = self.generate_rand_inputs(
             has_dweight, has_dbias, multi_precision, has_bias
         )