diff --git a/paddle/phi/kernels/cpu/segment_pool_grad_kernel.cc b/paddle/phi/kernels/cpu/segment_pool_grad_kernel.cc index 585c27bdcec97..a5c9dc4c55e49 100644 --- a/paddle/phi/kernels/cpu/segment_pool_grad_kernel.cc +++ b/paddle/phi/kernels/cpu/segment_pool_grad_kernel.cc @@ -23,4 +23,6 @@ PD_REGISTER_KERNEL(segment_pool_grad, ALL_LAYOUT, phi::SegmentPoolGradKernel, float, - double) {} + double, + int, + int64_t) {} diff --git a/paddle/phi/kernels/cpu/segment_pool_kernel.cc b/paddle/phi/kernels/cpu/segment_pool_kernel.cc index d0413457f8177..ad76a7a86bcb2 100644 --- a/paddle/phi/kernels/cpu/segment_pool_kernel.cc +++ b/paddle/phi/kernels/cpu/segment_pool_kernel.cc @@ -18,5 +18,11 @@ #include "paddle/phi/backends/cpu/cpu_context.h" #include "paddle/phi/core/kernel_registry.h" -PD_REGISTER_KERNEL( - segment_pool, CPU, ALL_LAYOUT, phi::SegmentPoolKernel, float, double) {} +PD_REGISTER_KERNEL(segment_pool, + CPU, + ALL_LAYOUT, + phi::SegmentPoolKernel, + float, + double, + int, + int64_t) {} diff --git a/paddle/phi/kernels/funcs/segment_pooling.cc b/paddle/phi/kernels/funcs/segment_pooling.cc index bf4a21f37223d..fbd744430aa11 100644 --- a/paddle/phi/kernels/funcs/segment_pooling.cc +++ b/paddle/phi/kernels/funcs/segment_pooling.cc @@ -149,10 +149,19 @@ template class SegmentPoolFunctor; template class SegmentPoolFunctor; template class SegmentPoolFunctor; template class SegmentPoolFunctor; +template class SegmentPoolFunctor; +template class SegmentPoolFunctor; +template class SegmentPoolFunctor; +template class SegmentPoolFunctor; + template class SegmentPoolGradFunctor; template class SegmentPoolGradFunctor; template class SegmentPoolGradFunctor; template class SegmentPoolGradFunctor; +template class SegmentPoolGradFunctor; +template class SegmentPoolGradFunctor; +template class SegmentPoolGradFunctor; +template class SegmentPoolGradFunctor; } // namespace funcs } // namespace phi diff --git a/paddle/phi/kernels/funcs/segment_pooling.cu b/paddle/phi/kernels/funcs/segment_pooling.cu index 305cd39f077bc..95606b1526729 100644 --- a/paddle/phi/kernels/funcs/segment_pooling.cu +++ b/paddle/phi/kernels/funcs/segment_pooling.cu @@ -453,10 +453,19 @@ template class SegmentPoolFunctor; template class SegmentPoolFunctor; template class SegmentPoolFunctor; template class SegmentPoolFunctor; +template class SegmentPoolFunctor; +template class SegmentPoolFunctor; +template class SegmentPoolFunctor; +template class SegmentPoolFunctor; + template class SegmentPoolGradFunctor; template class SegmentPoolGradFunctor; template class SegmentPoolGradFunctor; template class SegmentPoolGradFunctor; +template class SegmentPoolGradFunctor; +template class SegmentPoolGradFunctor; +template class SegmentPoolGradFunctor; +template class SegmentPoolGradFunctor; } // namespace funcs } // namespace phi diff --git a/paddle/phi/kernels/gpu/segment_pool_grad_kernel.cu b/paddle/phi/kernels/gpu/segment_pool_grad_kernel.cu index d9618dc159a6d..9d1769e18b4b8 100644 --- a/paddle/phi/kernels/gpu/segment_pool_grad_kernel.cu +++ b/paddle/phi/kernels/gpu/segment_pool_grad_kernel.cu @@ -24,4 +24,6 @@ PD_REGISTER_KERNEL(segment_pool_grad, ALL_LAYOUT, phi::SegmentPoolGradKernel, float, - double) {} + double, + int, + int64_t) {} diff --git a/paddle/phi/kernels/gpu/segment_pool_kernel.cu b/paddle/phi/kernels/gpu/segment_pool_kernel.cu index c38e935adf837..3128e534166ac 100644 --- a/paddle/phi/kernels/gpu/segment_pool_kernel.cu +++ b/paddle/phi/kernels/gpu/segment_pool_kernel.cu @@ -19,5 +19,11 @@ #include "paddle/phi/core/dense_tensor.h" #include "paddle/phi/core/kernel_registry.h" -PD_REGISTER_KERNEL( - segment_pool, GPU, ALL_LAYOUT, phi::SegmentPoolKernel, float, double) {} +PD_REGISTER_KERNEL(segment_pool, + GPU, + ALL_LAYOUT, + phi::SegmentPoolKernel, + float, + double, + int, + int64_t) {} diff --git a/python/paddle/incubate/tensor/math.py b/python/paddle/incubate/tensor/math.py index 9f577d5ff3802..2d0b079ee9280 100644 --- a/python/paddle/incubate/tensor/math.py +++ b/python/paddle/incubate/tensor/math.py @@ -29,7 +29,7 @@ def segment_sum(data, segment_ids, name=None): where sum is over j such that `segment_ids[j] == i`. Args: - data (Tensor): A tensor, available data type float32, float64. + data (Tensor): A tensor, available data type float32, float64, int32, int64. segment_ids (Tensor): A 1-D tensor, which have the same size with the first dimension of input data. Available data type is int32, int64. @@ -54,7 +54,8 @@ def segment_sum(data, segment_ids, name=None): out, tmp = _C_ops.segment_pool(data, segment_ids, 'pooltype', "SUM") return out - check_variable_and_dtype(data, "X", ("float32", "float64"), "segment_pool") + check_variable_and_dtype(data, "X", ("float32", "float64", "int32", + "int64"), "segment_pool") check_variable_and_dtype(segment_ids, "SegmentIds", ("int32", "int64"), "segment_pool") @@ -82,7 +83,7 @@ def segment_mean(data, segment_ids, name=None): of all index 'segment_ids[j] == i'. Args: - data (tensor): a tensor, available data type float32, float64. + data (tensor): a tensor, available data type float32, float64, int32, int64. segment_ids (tensor): a 1-d tensor, which have the same size with the first dimension of input data. available data type is int32, int64. @@ -107,7 +108,8 @@ def segment_mean(data, segment_ids, name=None): out, tmp = _C_ops.segment_pool(data, segment_ids, 'pooltype', "MEAN") return out - check_variable_and_dtype(data, "X", ("float32", "float64"), "segment_pool") + check_variable_and_dtype(data, "X", ("float32", "float64", "int32", + "int64"), "segment_pool") check_variable_and_dtype(segment_ids, "SegmentIds", ("int32", "int64"), "segment_pool") @@ -134,7 +136,7 @@ def segment_min(data, segment_ids, name=None): where min is over j such that `segment_ids[j] == i`. Args: - data (tensor): a tensor, available data type float32, float64. + data (tensor): a tensor, available data type float32, float64, int32, int64. segment_ids (tensor): a 1-d tensor, which have the same size with the first dimension of input data. available data type is int32, int64. @@ -159,7 +161,8 @@ def segment_min(data, segment_ids, name=None): out, tmp = _C_ops.segment_pool(data, segment_ids, 'pooltype', "MIN") return out - check_variable_and_dtype(data, "X", ("float32", "float64"), "segment_pool") + check_variable_and_dtype(data, "X", ("float32", "float64", "int32", + "int64"), "segment_pool") check_variable_and_dtype(segment_ids, "SegmentIds", ("int32", "int64"), "segment_pool") @@ -186,7 +189,7 @@ def segment_max(data, segment_ids, name=None): where max is over j such that `segment_ids[j] == i`. Args: - data (tensor): a tensor, available data type float32, float64. + data (tensor): a tensor, available data type float32, float64, int32, int64. segment_ids (tensor): a 1-d tensor, which have the same size with the first dimension of input data. available data type is int32, int64. @@ -211,7 +214,8 @@ def segment_max(data, segment_ids, name=None): out, tmp = _C_ops.segment_pool(data, segment_ids, 'pooltype', "MAX") return out - check_variable_and_dtype(data, "X", ("float32", "float64"), "segment_pool") + check_variable_and_dtype(data, "X", ("float32", "float64", "int32", + "int64"), "segment_pool") check_variable_and_dtype(segment_ids, "SegmentIds", ("int32", "int64"), "segment_pool")