Skip to content
This repository was archived by the owner on Jan 24, 2024. It is now read-only.

fix paddle test: arg_min/arg_max/flip/strided_slice #1485

Merged
merged 3 commits on
Jun 1, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion cinn/frontend/op_mappers/paddle/arg_min_max.cc
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ void ArgOpMapperHelper(const paddle::cpp::OpDesc& op_desc, const OpMapperContext
auto out_name = op_desc.Output("Out").front();

auto x = ctx.GetVar(x_name);
auto axis = utils::GetAttrOrDefault<int32_t>(op_desc, "axis", -1);
auto axis = utils::GetAttrOrDefault<int64_t>(op_desc, "axis", -1);
CHECK(op_desc.HasAttr("axis")) << "Argmax/Argmin op should has attribute \"axis\"! Please check.";

auto keepdims = utils::GetAttrOrDefault<bool>(op_desc, "keepdims", false);
Expand Down
8 changes: 7 additions & 1 deletion cinn/frontend/op_mappers/paddle/flip.cc
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,13 @@ void FlipOpMapper(const paddle::cpp::OpDesc& op_desc, const OpMapperContext& ctx
auto axes = utils::GetAttrOrDefault<std::vector<int>>(op_desc, "axis", std::vector<int>{});
VLOG(4) << "out_name = flip(" << x_name << ", axis=[" << cinn::utils::Join(axes, ", ") << "])";

auto x = ctx.GetVar(x_name);
auto x = ctx.GetVar(x_name);
const auto& ndim = x->shape.size();
for (auto& axis : axes) {
if (axis < 0) {
axis += ndim;
}
}
auto out = ctx.Builder()->Flip(x, axes);

ctx.AddVar(out_name, out);
Expand Down
12 changes: 6 additions & 6 deletions cinn/hlir/op/transform.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1446,14 +1446,14 @@ std::vector<std::vector<int>> InferShapeForSlice(const std::vector<std::vector<i
for (int i = 0; i < axes.size(); i++) {
if (ends[i] < 0) {
ends[i] = output_shape[axes[i]] + ends[i];
}
if (starts[i] < 0) {
starts[i] = output_shape[axes[i]] + starts[i];
}
if (ends[i] > output_shape[axes[i]]) {
} else if (ends[i] > output_shape[axes[i]]) {
ends[i] = output_shape[axes[i]];
}
if (starts[i] > output_shape[axes[i]]) {
if (starts[i] < -output_shape[axes[i]]) {
starts[i] = 0;
} else if (starts[i] < 0) {
starts[i] = output_shape[axes[i]] + starts[i];
} else if (starts[i] > output_shape[axes[i]]) {
starts[i] = output_shape[axes[i]] - 1;
}

Expand Down
7 changes: 4 additions & 3 deletions cinn/hlir/pe/transform.cc
Original file line number Diff line number Diff line change
Expand Up @@ -915,10 +915,11 @@ ir::Tensor Slice(const ir::Tensor& A,
}
std::vector<int> new_starts(starts);
for (int i = 0; i < axes.size(); i++) {
if (new_starts[i] < 0) {
if (new_starts[i] < -input_shape[axes[i]]) {
new_starts[i] = 0;
} else if (new_starts[i] < 0) {
new_starts[i] = input_shape[axes[i]] + new_starts[i];
}
if (new_starts[i] > input_shape[axes[i]]) {
} else if (new_starts[i] > input_shape[axes[i]]) {
new_starts[i] = input_shape[axes[i]] - 1;
}
}
Expand Down
14 changes: 10 additions & 4 deletions cinn/runtime/cuda/cinn_cuda_runtime_source.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -685,11 +685,14 @@ CINN_NVGPU_LT_NUM(fp32, float)
CINN_NVGPU_LT_NUM(fp64, double)
CINN_NVGPU_LT_NUM(int32, int)
CINN_NVGPU_LT_NUM(int64, long long int)
#ifdef CINN_CUDA_FP16
CINN_NVGPU_LT_NUM(fp16, float16)
#endif

#undef CINN_NVGPU_LT_NUM

#define CINN_NVGPU_GT_NUM(TYPE_SUFFIX, TYPE) \
__device__ inline int cinn_nvgpu_gt_num_##TYPE_SUFFIX( \
#define CINN_NVGPU_GT_NUM(TYPE_SUFFIX, TYPE) \
__device__ inline int cinn_nvgpu_gt_num_##TYPE_SUFFIX( \
const TYPE *buf, const int size, const TYPE num, const int offset, const int stride) { \
int out = 0; \
for (int i = (size - 1) * stride + offset; i >= offset; i -= stride) { \
Expand All @@ -702,11 +705,14 @@ CINN_NVGPU_GT_NUM(fp32, float)
CINN_NVGPU_GT_NUM(fp64, double)
CINN_NVGPU_GT_NUM(int32, int)
CINN_NVGPU_GT_NUM(int64, long long int)
#ifdef CINN_CUDA_FP16
CINN_NVGPU_GT_NUM(fp16, float16)
#endif

#undef CINN_NVGPU_GT_NUM

#define CINN_NVGPU_INDEX_ADD(TYPE_SUFFIX, TYPE) \
__device__ inline TYPE cinn_nvgpu_index_add_##TYPE_SUFFIX(const TYPE x, \
#define CINN_NVGPU_INDEX_ADD(TYPE_SUFFIX, TYPE) \
__device__ inline TYPE cinn_nvgpu_index_add_##TYPE_SUFFIX(const TYPE x, \
const int axis_indice, \
const TYPE *__restrict__ y, \
const int offset, \
Expand Down
34 changes: 31 additions & 3 deletions cinn/runtime/cuda/cuda_instrinsics_float16.cc
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,35 @@ CINN_REGISTER_HELPER(cuda_intrinsics_float16) {

#undef REGISTER_EXTERN_FUNC_1_IN_1_FP16_OUT_BOOL

#define _REGISTER_CINN_NVGPU_INDEX_ADD(TYPE_SUFFIX, TYPE) \
#define REGISTER_CINN_NVGPU_GT_NUM(TYPE_SUFFIX, TYPE) \
REGISTER_FACKED_EXTERN_FUNC_HELPER(cinn_nvgpu_gt_num_##TYPE_SUFFIX, target) \
.SetRetType<int>() \
.AddInputType<cinn_buffer_t *>() \
.AddInputType<int>() \
.AddInputType<TYPE>() \
.AddInputType<int>() \
.AddInputType<int>() \
.End();

REGISTER_CINN_NVGPU_GT_NUM(fp16, float16);

#undef REGISTER_CINN_NVGPU_GT_NUM

#define REGISTER_CINN_NVGPU_LT_NUM(TYPE_SUFFIX, TYPE) \
REGISTER_FACKED_EXTERN_FUNC_HELPER(cinn_nvgpu_lt_num_##TYPE_SUFFIX, target) \
.SetRetType<int>() \
.AddInputType<cinn_buffer_t *>() \
.AddInputType<int>() \
.AddInputType<TYPE>() \
.AddInputType<int>() \
.AddInputType<int>() \
.End();

REGISTER_CINN_NVGPU_LT_NUM(fp16, float16);

#undef REGISTER_CINN_NVGPU_LT_NUM

#define REGISTER_CINN_NVGPU_INDEX_ADD(TYPE_SUFFIX, TYPE) \
REGISTER_FACKED_EXTERN_FUNC_HELPER(cinn_nvgpu_index_add_##TYPE_SUFFIX, target) \
.SetRetType<TYPE>() \
.AddInputType<TYPE>() \
Expand All @@ -88,9 +116,9 @@ CINN_REGISTER_HELPER(cuda_intrinsics_float16) {
.AddInputType<int>() \
.End();

_REGISTER_CINN_NVGPU_INDEX_ADD(fp16, float16);
REGISTER_CINN_NVGPU_INDEX_ADD(fp16, float16);

#undef _REGISTER_CINN_NVGPU_INDEX_ADD
#undef REGISTER_CINN_NVGPU_INDEX_ADD

return true;
}