Skip to content

Commit

Permalink
[PIR] Migrate argmin/argmax into PIR (PaddlePaddle#57909)
Browse files Browse the repository at this point in the history
  • Loading branch information
0x45f authored and jiahy0825 committed Oct 26, 2023
1 parent e9e6250 commit baf9203
Show file tree
Hide file tree
Showing 10 changed files with 39 additions and 43 deletions.
4 changes: 2 additions & 2 deletions paddle/phi/api/yaml/ops.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,7 @@
backward : angle_grad

- op : argmax
args : (Tensor x, Scalar(int64_t) axis, bool keepdims = false, bool flatten = false, int dtype = 3)
args : (Tensor x, Scalar(int64_t) axis, bool keepdims = false, bool flatten = false, DataType dtype = DataType::INT64)
output : Tensor(out)
infer_meta :
func : ArgMinMaxInferMeta
Expand All @@ -143,7 +143,7 @@
data_type : x

- op : argmin
args : (Tensor x, Scalar(int64_t) axis, bool keepdims = false, bool flatten = false, int dtype = 3)
args : (Tensor x, Scalar(int64_t) axis, bool keepdims = false, bool flatten = false, DataType dtype = DataType::INT64)
output : Tensor(out)
infer_meta :
func : ArgMinMaxInferMeta
Expand Down
21 changes: 9 additions & 12 deletions paddle/phi/infermeta/unary.cc
Original file line number Diff line number Diff line change
Expand Up @@ -148,18 +148,19 @@ void ArgMinMaxInferMeta(const MetaTensor& x,
const Scalar& axis,
bool keepdims,
bool flatten,
int dtype,
DataType dtype,
MetaTensor* out,
MetaConfig config) {
PADDLE_ENFORCE_EQ(
(dtype < 0 || dtype == 2 || dtype == 3),
(dtype == DataType::UNDEFINED || dtype == DataType::INT32 ||
dtype == DataType::INT64),
true,
phi::errors::InvalidArgument(
"The attribute of dtype in argmin/argmax must be [%s] or [%s], but "
"received [%s]",
DataTypeToString(DataType::INT32),
DataTypeToString(DataType::INT64),
DataTypeToString(phi::TransToPhiDataType(dtype))));
DataTypeToString(dtype)));

if (!config.is_runtime && axis.FromTensor()) {
std::vector<int64_t> vec;
Expand All @@ -177,10 +178,8 @@ void ArgMinMaxInferMeta(const MetaTensor& x,
}
}
out->set_dims(phi::make_ddim(vec));
if (dtype == 2) {
out->set_dtype(DataType::INT32);
} else if (dtype == 3) {
out->set_dtype(DataType::INT64);
if (dtype == DataType::INT32 || dtype == DataType::INT64) {
out->set_dtype(dtype);
}
return;
}
Expand Down Expand Up @@ -216,7 +215,7 @@ void ArgMinMaxInferMeta(const MetaTensor& x,
if (int_axis < 0) int_axis += x_rank;

if (config.is_runtime) {
if (dtype == phi::TransToProtoVarType(DataType::INT32)) {
if (dtype == DataType::INT32) {
int64_t all_element_num = 0;
if (flatten) {
all_element_num = phi::product(x_dims);
Expand Down Expand Up @@ -253,10 +252,8 @@ void ArgMinMaxInferMeta(const MetaTensor& x,
}

out->set_dims(phi::make_ddim(vec));
if (dtype == 2) {
out->set_dtype(DataType::INT32);
} else if (dtype == 3) {
out->set_dtype(DataType::INT64);
if (dtype == DataType::INT32 || dtype == DataType::INT64) {
out->set_dtype(dtype);
}
}

Expand Down
2 changes: 1 addition & 1 deletion paddle/phi/infermeta/unary.h
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ void ArgMinMaxInferMeta(const MetaTensor& x,
const Scalar& axis,
bool keepdims,
bool flatten,
int dtype,
DataType dtype,
MetaTensor* out,
MetaConfig config = MetaConfig());

Expand Down
4 changes: 2 additions & 2 deletions paddle/phi/kernels/arg_min_max_kernel.h
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ void ArgMinKernel(const Context& dev_ctx,
const Scalar& axis,
bool keepdims,
bool flatten,
int dtype,
DataType dtype,
DenseTensor* out);

template <typename T, typename Context>
Expand All @@ -34,7 +34,7 @@ void ArgMaxKernel(const Context& dev_ctx,
const Scalar& axis,
bool keepdims,
bool flatten,
int dtype,
DataType dtype,
DenseTensor* out);

} // namespace phi
10 changes: 5 additions & 5 deletions paddle/phi/kernels/cpu/arg_min_max_kernel.cc
Original file line number Diff line number Diff line change
Expand Up @@ -151,17 +151,17 @@ void ArgMinMaxKernel(const Context& dev_ctx,
const Scalar& axis,
bool keepdims,
bool flatten,
int dtype,
DataType dtype,
DenseTensor* out) {
if (dtype < 0) {
if (dtype == DataType::UNDEFINED) {
phi::VisitDataTypeTiny(
phi::DataType::INT64,
VisitDataArgMinMaxFunctor<Context, T, EnumArgMinMaxValue>(
dev_ctx, x, axis.to<int64_t>(), keepdims, flatten, out));
return;
}
phi::VisitDataTypeTiny(
phi::TransToPhiDataType(dtype),
dtype,
VisitDataArgMinMaxFunctor<Context, T, EnumArgMinMaxValue>(
dev_ctx, x, axis.to<int64_t>(), keepdims, flatten, out));
}
Expand All @@ -172,7 +172,7 @@ void ArgMinKernel(const Context& dev_ctx,
const Scalar& axis,
bool keepdims,
bool flatten,
int dtype,
DataType dtype,
DenseTensor* out) {
ArgMinMaxKernel<Context, T, ArgMinMaxType::kArgMin>(
dev_ctx, x, axis, keepdims, flatten, dtype, out);
Expand All @@ -184,7 +184,7 @@ void ArgMaxKernel(const Context& dev_ctx,
const Scalar& axis,
bool keepdims,
bool flatten,
int dtype,
DataType dtype,
DenseTensor* out) {
ArgMinMaxKernel<Context, T, ArgMinMaxType::kArgMax>(
dev_ctx, x, axis, keepdims, flatten, dtype, out);
Expand Down
10 changes: 5 additions & 5 deletions paddle/phi/kernels/gpu/arg_min_max_kernel.cu
Original file line number Diff line number Diff line change
Expand Up @@ -209,17 +209,17 @@ void ArgMinMaxOpCUDAKernel(const Context& dev_ctx,
const Scalar& axis,
bool keepdims,
bool flatten,
int dtype,
DataType dtype,
DenseTensor* out) {
if (dtype < 0) {
if (dtype == DataType::UNDEFINED) {
phi::VisitDataTypeTiny(
phi::DataType::INT64,
VisitDataCudaArgMinMaxFunctor<Context, T, Reducer>(
dev_ctx, x, axis.to<int64_t>(), keepdims, flatten, out));
return;
}
phi::VisitDataTypeTiny(
phi::TransToPhiDataType(dtype),
dtype,
VisitDataCudaArgMinMaxFunctor<Context, T, Reducer>(
dev_ctx, x, axis.to<int64_t>(), keepdims, flatten, out));
}
Expand All @@ -230,7 +230,7 @@ void ArgMinKernel(const Context& dev_ctx,
const Scalar& axis,
bool keepdims,
bool flatten,
int dtype,
DataType dtype,
DenseTensor* out) {
ArgMinMaxOpCUDAKernel<Context, T, cub::ArgMin>(
dev_ctx, x, axis, keepdims, flatten, dtype, out);
Expand All @@ -242,7 +242,7 @@ void ArgMaxKernel(const Context& dev_ctx,
const Scalar& axis,
bool keepdims,
bool flatten,
int dtype,
DataType dtype,
DenseTensor* out) {
ArgMinMaxOpCUDAKernel<Context, T, cub::ArgMax>(
dev_ctx, x, axis, keepdims, flatten, dtype, out);
Expand Down
2 changes: 1 addition & 1 deletion paddle/phi/kernels/gpu/multinomial_kernel.cu
Original file line number Diff line number Diff line change
Expand Up @@ -191,7 +191,7 @@ void MultinomialKernel(const Context& dev_ctx,

if (int_num_samples == 1) {
ArgMaxKernel<T, Context>(
dev_ctx, rand, -1, true, false, 3 /*proto::VarType::INT64*/, out);
dev_ctx, rand, -1, true, false, DataType::INT64, out);
} else {
std::vector<int64_t> out_dim_vec = vectorize<int64_t>(out->dims());
DenseTensor value = Empty<T, Context>(dev_ctx, IntArray(out_dim_vec));
Expand Down
13 changes: 4 additions & 9 deletions paddle/phi/kernels/xpu/arg_min_max_kernel.cc
Original file line number Diff line number Diff line change
Expand Up @@ -22,23 +22,18 @@

namespace phi {

namespace {
const int ARG_MAX_OUTPUT_DATATYPE_INT32 = 2;
const int ARG_MAX_OUTPUT_DATATYPE_INT64 = 3;
} // Anonymous namespace

template <typename T, typename Context>
void ArgMaxKernel(const Context& dev_ctx,
const DenseTensor& x,
const Scalar& axis,
bool keepdims,
bool flatten,
int dtype,
DataType dtype,
DenseTensor* out) {
using XPUType = typename XPUTypeTrait<T>::Type;
PADDLE_ENFORCE_EQ(
(dtype < 0 || dtype == ARG_MAX_OUTPUT_DATATYPE_INT32 ||
dtype == ARG_MAX_OUTPUT_DATATYPE_INT64),
(dtype == DataType::UNDEFINED || dtype == DataType::INT32 ||
dtype == DataType::INT64),
true,
errors::InvalidArgument(
"The attribute of dtype in xpu argmin/argmax must be [%s] or [%s], "
Expand All @@ -60,7 +55,7 @@ void ArgMaxKernel(const Context& dev_ctx,
}
auto xdims_vec = phi::vectorize<int>(x_dims);
int r = 0;
if (dtype != ARG_MAX_OUTPUT_DATATYPE_INT32) {
if (dtype != DataType::INT32) {
dev_ctx.template Alloc<int64_t>(out);
if (x.dims().size() == 0) {
xpu::constant(dev_ctx.x_context(),
Expand Down
12 changes: 8 additions & 4 deletions python/paddle/tensor/search.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,7 +171,9 @@ def argmax(x, axis=None, keepdim=False, dtype="int64", name=None):
print(out4)
# [[2, 2, 0, 1]]
"""
if axis is not None and not isinstance(axis, (int, Variable)):
if axis is not None and not isinstance(
axis, (int, Variable, paddle.pir.OpResult)
):
raise TypeError(
"The type of 'axis' must be int or Tensor or None in argmax, but received %s."
% (type(axis))
Expand All @@ -188,7 +190,7 @@ def argmax(x, axis=None, keepdim=False, dtype="int64", name=None):
flatten = True
axis = 0

if in_dynamic_mode():
if in_dynamic_or_pir_mode():
return _C_ops.argmax(x, axis, keepdim, flatten, var_dtype)
else:
helper = LayerHelper("argmax", **locals())
Expand Down Expand Up @@ -261,7 +263,9 @@ def argmin(x, axis=None, keepdim=False, dtype="int64", name=None):
print(out4)
# [[1, 1, 1, 2]]
"""
if axis is not None and not isinstance(axis, (int, Variable)):
if axis is not None and not isinstance(
axis, (int, Variable, paddle.pir.OpResult)
):
raise TypeError(
"The type of 'axis' must be int or Tensor or None in argmin, but received %s."
% (type(axis))
Expand All @@ -278,7 +282,7 @@ def argmin(x, axis=None, keepdim=False, dtype="int64", name=None):
flatten = True
axis = 0

if in_dynamic_mode():
if in_dynamic_or_pir_mode():
return _C_ops.argmin(x, axis, keepdim, flatten, var_dtype)
else:
helper = LayerHelper("argmin", **locals())
Expand Down
4 changes: 2 additions & 2 deletions test/legacy_test/test_arg_min_max_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ def setUp(self):
self.outputs = {'Out': np.argmax(self.x, axis=self.axis)}

def test_check_output(self):
self.check_output(check_cinn=True)
self.check_output(check_cinn=True, check_pir=True)


class TestCase0(BaseTestCase):
Expand Down Expand Up @@ -122,7 +122,7 @@ def setUp(self):
self.outputs = {'Out': np.argmax(x, axis=self.axis)}

def test_check_output(self):
self.check_output_with_place(paddle.CUDAPlace(0))
self.check_output_with_place(paddle.CUDAPlace(0), check_pir=True)


class TestArgMaxBF16OP(TestArgMinBF16OP):
Expand Down

0 comments on commit baf9203

Please sign in to comment.