Skip to content

Commit

Permalink
[BugFix] Add error hint for one_hot gpu version (PaddlePaddle#41335)
Browse files Browse the repository at this point in the history
* add one_hot gpu hint

* move allow_out_of_range judgement

* delete useless unittest
  • Loading branch information
DesmonDay committed Apr 7, 2022
1 parent 5b85f3d commit bd096a0
Show file tree
Hide file tree
Showing 3 changed files with 29 additions and 55 deletions.
57 changes: 21 additions & 36 deletions paddle/phi/kernels/cpu/one_hot_kernel.cc
Original file line number Diff line number Diff line change
Expand Up @@ -25,18 +25,12 @@ struct OneHotV2OpFunctor {
DenseTensor* out_;
int depth_;
const DeviceContext& ctx_;
bool allow_out_of_range_;

OneHotV2OpFunctor(const DenseTensor* in,
DenseTensor* out,
int depth,
const DeviceContext& ctx,
bool allow_out_of_range = false)
: in_(in),
out_(out),
depth_(depth),
ctx_(ctx),
allow_out_of_range_(allow_out_of_range) {}
const DeviceContext& ctx)
: in_(in), out_(out), depth_(depth), ctx_(ctx) {}

template <typename OutT>
void apply() const {
Expand All @@ -45,32 +39,24 @@ struct OneHotV2OpFunctor {
auto* p_out_data = ctx_.template Alloc<OutT>(out_);
funcs::set_constant(ctx_, out_, 0.0);

if (allow_out_of_range_) {
for (int i = 0; i < numel; ++i) {
if (p_in_data[i] >= 0 && p_in_data[i] < depth_) {
*(p_out_data + i * depth_ + p_in_data[i]) = 1.0;
}
}
} else {
for (int i = 0; i < numel; ++i) {
PADDLE_ENFORCE_GE(
p_in_data[i],
0,
phi::errors::InvalidArgument(
"Illegal index value, Input(input) value should be at least 0, "
"but received input (%d) less than 0",
p_in_data[i]));
PADDLE_ENFORCE_LT(
p_in_data[i],
depth_,
phi::errors::InvalidArgument(
"Illegal index value, Input(input) value should be less than "
"Input(depth), "
"but received input (%d) not less than depth (%d)",
p_in_data[i],
depth_));
*(p_out_data + i * depth_ + p_in_data[i]) = 1.0;
}
for (int i = 0; i < numel; ++i) {
PADDLE_ENFORCE_GE(
p_in_data[i],
0,
phi::errors::InvalidArgument(
"Illegal index value, Input(input) value should be at least 0, "
"but received input (%d) less than 0",
p_in_data[i]));
PADDLE_ENFORCE_LT(
p_in_data[i],
depth_,
phi::errors::InvalidArgument(
"Illegal index value, Input(input) value should be less than "
"Input(depth), "
"but received input (%d) not less than depth (%d)",
p_in_data[i],
depth_));
*(p_out_data + i * depth_ + p_in_data[i]) = 1.0;
}
}
};
Expand All @@ -89,8 +75,7 @@ void OneHotRawKernel(const Context& dev_ctx,
}

phi::VisitDataType(dtype,
OneHotV2OpFunctor<Context, T>(
&x, out, depth, dev_ctx, allow_out_of_range));
OneHotV2OpFunctor<Context, T>(&x, out, depth, dev_ctx));
}

} // namespace phi
Expand Down
9 changes: 8 additions & 1 deletion paddle/phi/kernels/gpu/one_hot_kernel.cu
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,14 @@ __global__ void FillOutputKernel(const InT* p_in_data,
const int64_t numel,
const int depth) {
int idx = blockIdx.x * blockDim.x + threadIdx.x;
if (idx < numel && p_in_data[idx] >= 0 && p_in_data[idx] < depth) {
if (idx < numel) {
PADDLE_ENFORCE(p_in_data[idx] >= 0 && p_in_data[idx] < depth,
"Illegal index value, Input(input) value should be "
"greater than or equal to 0, and less than depth [%d], "
"but received [%lld].",
depth,
p_in_data[idx]);

*(p_out_data + (idx * depth) + p_in_data[idx]) = 1.0;
}
}
Expand Down
18 changes: 0 additions & 18 deletions python/paddle/fluid/tests/unittests/test_one_hot_v2_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,24 +117,6 @@ def test_check_output(self):
self.check_output()


class TestOneHotOp_out_of_range(OpTest):
def setUp(self):
self.op_type = 'one_hot_v2'
depth = 10
x_lod = [[4, 1, 3, 3]]
x = [np.random.choice([-1, depth]) for i in range(sum(x_lod[0]))]
x = np.array(x).astype('int32').reshape([sum(x_lod[0])])

out = np.zeros(shape=(np.product(x.shape), depth)).astype('float32')

self.inputs = {'X': (x, x_lod)}
self.attrs = {'depth': depth, 'allow_out_of_range': True}
self.outputs = {'Out': (out, x_lod)}

def test_check_output(self):
self.check_output()


class TestOneHotOp_exception(unittest.TestCase):
def setUp(self):
self.op_type = 'one_hot_v2'
Expand Down

0 comments on commit bd096a0

Please sign in to comment.