Skip to content

Commit

Permalink
move allow_out_of_range judgement
Browse files Browse the repository at this point in the history
  • Loading branch information
DesmonDay committed Apr 6, 2022
1 parent 1e0a163 commit bda18e2
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 70 deletions.
57 changes: 21 additions & 36 deletions paddle/phi/kernels/cpu/one_hot_kernel.cc
Original file line number Diff line number Diff line change
Expand Up @@ -25,18 +25,12 @@ struct OneHotV2OpFunctor {
DenseTensor* out_;
int depth_;
const DeviceContext& ctx_;
bool allow_out_of_range_;

OneHotV2OpFunctor(const DenseTensor* in,
DenseTensor* out,
int depth,
const DeviceContext& ctx,
bool allow_out_of_range = false)
: in_(in),
out_(out),
depth_(depth),
ctx_(ctx),
allow_out_of_range_(allow_out_of_range) {}
const DeviceContext& ctx)
: in_(in), out_(out), depth_(depth), ctx_(ctx) {}

template <typename OutT>
void apply() const {
Expand All @@ -45,32 +39,24 @@ struct OneHotV2OpFunctor {
auto* p_out_data = ctx_.template Alloc<OutT>(out_);
funcs::set_constant(ctx_, out_, 0.0);

if (allow_out_of_range_) {
for (int i = 0; i < numel; ++i) {
if (p_in_data[i] >= 0 && p_in_data[i] < depth_) {
*(p_out_data + i * depth_ + p_in_data[i]) = 1.0;
}
}
} else {
for (int i = 0; i < numel; ++i) {
PADDLE_ENFORCE_GE(
p_in_data[i],
0,
phi::errors::InvalidArgument(
"Illegal index value, Input(input) value should be at least 0, "
"but received input (%d) less than 0",
p_in_data[i]));
PADDLE_ENFORCE_LT(
p_in_data[i],
depth_,
phi::errors::InvalidArgument(
"Illegal index value, Input(input) value should be less than "
"Input(depth), "
"but received input (%d) not less than depth (%d)",
p_in_data[i],
depth_));
*(p_out_data + i * depth_ + p_in_data[i]) = 1.0;
}
for (int i = 0; i < numel; ++i) {
PADDLE_ENFORCE_GE(
p_in_data[i],
0,
phi::errors::InvalidArgument(
"Illegal index value, Input(input) value should be at least 0, "
"but received input (%d) less than 0",
p_in_data[i]));
PADDLE_ENFORCE_LT(
p_in_data[i],
depth_,
phi::errors::InvalidArgument(
"Illegal index value, Input(input) value should be less than "
"Input(depth), "
"but received input (%d) not less than depth (%d)",
p_in_data[i],
depth_));
*(p_out_data + i * depth_ + p_in_data[i]) = 1.0;
}
}
};
Expand All @@ -89,8 +75,7 @@ void OneHotRawKernel(const Context& dev_ctx,
}

phi::VisitDataType(dtype,
OneHotV2OpFunctor<Context, T>(
&x, out, depth, dev_ctx, allow_out_of_range));
OneHotV2OpFunctor<Context, T>(&x, out, depth, dev_ctx));
}

} // namespace phi
Expand Down
42 changes: 8 additions & 34 deletions paddle/phi/kernels/gpu/one_hot_kernel.cu
Original file line number Diff line number Diff line change
Expand Up @@ -29,17 +29,6 @@ __global__ void FillOutputKernel(const InT* p_in_data,
const int64_t numel,
const int depth) {
int idx = blockIdx.x * blockDim.x + threadIdx.x;
if (idx < numel && p_in_data[idx] >= 0 && p_in_data[idx] < depth) {
*(p_out_data + (idx * depth) + p_in_data[idx]) = 1.0;
}
}

template <typename InT, typename OutT>
__global__ void FillOutputKernelV2(const InT* p_in_data,
OutT* p_out_data,
const int64_t numel,
const int depth) {
int idx = blockIdx.x * blockDim.x + threadIdx.x;
if (idx < numel) {
PADDLE_ENFORCE(p_in_data[idx] >= 0 && p_in_data[idx] < depth,
"Illegal index value, Input(input) value should be "
Expand All @@ -58,18 +47,12 @@ struct OneHotV2OpCUDAFunctor {
DenseTensor* out_;
const DeviceContext& ctx_;
int depth_;
bool allow_out_of_range_;

OneHotV2OpCUDAFunctor(const DenseTensor* in,
DenseTensor* out,
int depth,
const DeviceContext& ctx,
bool allow_out_of_range = false)
: in_(in),
out_(out),
depth_(depth),
ctx_(ctx),
allow_out_of_range_(allow_out_of_range) {}
const DeviceContext& ctx)
: in_(in), out_(out), depth_(depth), ctx_(ctx) {}

template <typename OutT>
void apply() const {
Expand All @@ -79,19 +62,11 @@ struct OneHotV2OpCUDAFunctor {
auto stream = ctx_.stream();
funcs::set_constant(ctx_, out_, 0.0);

if (allow_out_of_range_) {
FillOutputKernel<<<(numel + PADDLE_CUDA_NUM_THREADS - 1) /
PADDLE_CUDA_NUM_THREADS,
PADDLE_CUDA_NUM_THREADS,
0,
stream>>>(p_in_data, p_out_data, numel, depth_);
} else {
FillOutputKernelV2<<<(numel + PADDLE_CUDA_NUM_THREADS - 1) /
PADDLE_CUDA_NUM_THREADS,
FillOutputKernel<<<(numel + PADDLE_CUDA_NUM_THREADS - 1) /
PADDLE_CUDA_NUM_THREADS,
0,
stream>>>(p_in_data, p_out_data, numel, depth_);
}
PADDLE_CUDA_NUM_THREADS,
0,
stream>>>(p_in_data, p_out_data, numel, depth_);
}
};

Expand All @@ -108,9 +83,8 @@ void OneHotRawKernel(const Context& dev_ctx,
out->Resize(out_dims);
}

phi::VisitDataType(dtype,
OneHotV2OpCUDAFunctor<Context, T>(
&x, out, depth, dev_ctx, allow_out_of_range));
phi::VisitDataType(
dtype, OneHotV2OpCUDAFunctor<Context, T>(&x, out, depth, dev_ctx));
}

} // namespace phi
Expand Down

0 comments on commit bda18e2

Please sign in to comment.