From 1620e6843f8d051edba9b168b0a54d53fee3b580 Mon Sep 17 00:00:00 2001 From: limin2021 <1121099234@qq.com> Date: Wed, 2 Mar 2022 02:07:33 +0000 Subject: [PATCH 1/6] add support of int16 for gather op. --- paddle/fluid/operators/gather_op.cu | 7 +++++++ python/paddle/tensor/manipulation.py | 2 +- 2 files changed, 8 insertions(+), 1 deletion(-) diff --git a/paddle/fluid/operators/gather_op.cu b/paddle/fluid/operators/gather_op.cu index 19568835a6e960..7d5e52385a75e0 100644 --- a/paddle/fluid/operators/gather_op.cu +++ b/paddle/fluid/operators/gather_op.cu @@ -45,6 +45,8 @@ class GatherOpCUDAKernel : public framework::OpKernel { axis = static_cast(cpu_axis.data()[0]); } else if (axis_type == framework::proto::VarType::INT64) { axis = static_cast(cpu_axis.data()[0]); + } else if (axis_type == framework::proto::VarType::INT16) { + axis = static_cast(cpu_axis.data()[0]); } } const auto &place = ctx.GetPlace(); @@ -54,6 +56,8 @@ class GatherOpCUDAKernel : public framework::OpKernel { GatherV2CUDAFunction(x, index, axis, output, place, ctx); } else if (index_type == framework::proto::VarType::INT64) { GatherV2CUDAFunction(x, index, axis, output, place, ctx); + } else if (index_type == framework::proto::VarType::INT16) { + GatherV2CUDAFunction(x, index, axis, output, place, ctx); } return; } @@ -64,6 +68,8 @@ class GatherOpCUDAKernel : public framework::OpKernel { GPUGather(ctx.device_context(), *x, *index, output); } else if (index_type == framework::proto::VarType::INT64) { GPUGather(ctx.device_context(), *x, *index, output); + } else if (index_type == framework::proto::VarType::INT16) { + GPUGather(ctx.device_context(), *x, *index, output); } } }; @@ -130,6 +136,7 @@ REGISTER_OP_CUDA_KERNEL(gather, ops::GatherOpCUDAKernel, ops::GatherOpCUDAKernel, ops::GatherOpCUDAKernel, ops::GatherOpCUDAKernel, + ops::GatherOpCUDAKernel, ops::GatherOpCUDAKernel); REGISTER_OP_CUDA_KERNEL(gather_grad, ops::GatherGradOpCUDAKernel, ops::GatherGradOpCUDAKernel, diff --git a/python/paddle/tensor/manipulation.py b/python/paddle/tensor/manipulation.py index 53bb9a88075628..2909db79f81125 100755 --- a/python/paddle/tensor/manipulation.py +++ b/python/paddle/tensor/manipulation.py @@ -1402,7 +1402,7 @@ def gather(x, index, axis=None, name=None): return _C_ops.gather(x, index, None, "axis", axis, "overwrite", False) check_variable_and_dtype( - x, 'x', ['float16', 'float32', 'float64', 'int32', 'int64', 'uint8'], + x, 'x', ['float16', 'float32', 'float64', 'int32', 'int64', 'int16', 'uint8'], 'gather') check_variable_and_dtype(index, 'index', ['int32', 'int64'], 'gather') From 8ce6e163a1c485c86e14729ce30227f724acb6d0 Mon Sep 17 00:00:00 2001 From: limin2021 <1121099234@qq.com> Date: Wed, 2 Mar 2022 06:11:59 +0000 Subject: [PATCH 2/6] Recover formats. --- python/paddle/tensor/manipulation.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/python/paddle/tensor/manipulation.py b/python/paddle/tensor/manipulation.py index 2d82ec87d2dcb5..556e9239ed1859 100755 --- a/python/paddle/tensor/manipulation.py +++ b/python/paddle/tensor/manipulation.py @@ -2419,7 +2419,8 @@ def _var_to_list(var): axes_y = [] if np.issubdtype(type(axes), np.integer): assert axes >= 0, ( - "The 'axes' in " + op_type + " should not be negative, but received axes={axes}.") + "The 'axes' in " + op_type + + f" should not be negative, but received axes={axes}.") axes_x = range(x.ndim - axes, x.ndim) axes_y = range(axes) else: From f3b20aceebe9862fa8221f099008beb560a16494 Mon Sep 17 00:00:00 2001 From: limin2021 <1121099234@qq.com> Date: Wed, 2 Mar 2022 06:12:50 +0000 Subject: [PATCH 3/6] Recover formats. --- python/paddle/tensor/manipulation.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/python/paddle/tensor/manipulation.py b/python/paddle/tensor/manipulation.py index 556e9239ed1859..a4994f459d0c3b 100755 --- a/python/paddle/tensor/manipulation.py +++ b/python/paddle/tensor/manipulation.py @@ -2419,8 +2419,8 @@ def _var_to_list(var): axes_y = [] if np.issubdtype(type(axes), np.integer): assert axes >= 0, ( - "The 'axes' in " + op_type + - f" should not be negative, but received axes={axes}.") + "The 'axes' in " + op_type + + f" should not be negative, but received axes={axes}.") axes_x = range(x.ndim - axes, x.ndim) axes_y = range(axes) else: From df2f454cdee7ed9652117545b6323848b8d16e51 Mon Sep 17 00:00:00 2001 From: limin2021 <1121099234@qq.com> Date: Wed, 2 Mar 2022 06:14:32 +0000 Subject: [PATCH 4/6] fix. --- paddle/fluid/operators/gather_op.cu | 1 + 1 file changed, 1 insertion(+) diff --git a/paddle/fluid/operators/gather_op.cu b/paddle/fluid/operators/gather_op.cu index 8a4c21bafe25b0..dfa221f9a98237 100644 --- a/paddle/fluid/operators/gather_op.cu +++ b/paddle/fluid/operators/gather_op.cu @@ -137,6 +137,7 @@ REGISTER_OP_CUDA_KERNEL(gather, ops::GatherOpCUDAKernel, ops::GatherOpCUDAKernel, ops::GatherOpCUDAKernel, ops::GatherOpCUDAKernel, + ops::GatherOpCUDAKernel, ops::GatherOpCUDAKernel); REGISTER_OP_CUDA_KERNEL(gather_grad, ops::GatherGradOpCUDAKernel, ops::GatherGradOpCUDAKernel, From 8854e5522d39f59a1d7db861188ddef199788c15 Mon Sep 17 00:00:00 2001 From: limin2021 <1121099234@qq.com> Date: Wed, 2 Mar 2022 09:43:05 +0000 Subject: [PATCH 5/6] Fix format. --- python/paddle/tensor/manipulation.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/python/paddle/tensor/manipulation.py b/python/paddle/tensor/manipulation.py index a4994f459d0c3b..183d321cdaf6e8 100755 --- a/python/paddle/tensor/manipulation.py +++ b/python/paddle/tensor/manipulation.py @@ -1402,7 +1402,8 @@ def gather(x, index, axis=None, name=None): return _C_ops.gather(x, index, None, "axis", axis, "overwrite", False) check_variable_and_dtype( - x, 'x', ['float16', 'float32', 'float64', 'int32', 'int64', 'int16', 'uint8'], + x, 'x', + ['float16', 'float32', 'float64', 'int16', 'int32', 'int64', 'uint8'], 'gather') check_variable_and_dtype(index, 'index', ['int32', 'int64'], 'gather') @@ -2419,8 +2420,8 @@ def _var_to_list(var): axes_y = [] if np.issubdtype(type(axes), np.integer): assert axes >= 0, ( - "The 'axes' in " + op_type + - f" should not be negative, but received axes={axes}.") + "The 'axes' in " + op_type + + f" should not be negative, but received axes={axes}.") axes_x = range(x.ndim - axes, x.ndim) axes_y = range(axes) else: From 99e9777777fc4ac1f521316690dca576db46e92e Mon Sep 17 00:00:00 2001 From: limin2021 <1121099234@qq.com> Date: Wed, 2 Mar 2022 09:52:37 +0000 Subject: [PATCH 6/6] Fix format. --- python/paddle/tensor/manipulation.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/python/paddle/tensor/manipulation.py b/python/paddle/tensor/manipulation.py index 183d321cdaf6e8..32ccecbc6d9f02 100755 --- a/python/paddle/tensor/manipulation.py +++ b/python/paddle/tensor/manipulation.py @@ -2420,8 +2420,8 @@ def _var_to_list(var): axes_y = [] if np.issubdtype(type(axes), np.integer): assert axes >= 0, ( - "The 'axes' in " + op_type + - f" should not be negative, but received axes={axes}.") + "The 'axes' in " + op_type + + f" should not be negative, but received axes={axes}.") axes_x = range(x.ndim - axes, x.ndim) axes_y = range(axes) else: