Skip to content

Commit

Permalink
Fix 堆栈溢出 (stack overflow) of case3: paddle.metric.accuracy (#49984)
Browse files Browse the repository at this point in the history
* add input check for accuracyOp

* add input check for gpu/accuracyOp

* add unittest

* use rank instead of dimensions in message

* update unittest

* update unittest
  • Loading branch information
RedContritio authored Feb 3, 2023
1 parent 85490f7 commit 9741121
Show file tree
Hide file tree
Showing 3 changed files with 62 additions and 1 deletion.
23 changes: 23 additions & 0 deletions paddle/phi/kernels/cpu/accuracy_kernel.cc
Original file line number Diff line number Diff line change
Expand Up @@ -35,10 +35,33 @@ void AccuracyRawKernel(const Context& dev_ctx,
const int64_t* indices_data = indices.data<int64_t>();
const int64_t* label_data = label.data<int64_t>();

PADDLE_ENFORCE_EQ(
inference.dims().size(),
2,
phi::errors::InvalidArgument(
"Rank(Input) of AccuracyOp must be 2, with shape "
"[sample_number, class_dim], But received rank(Input) is %d",
inference.dims().size()));

size_t num_samples = inference.dims()[0];
size_t class_dim = inference.dims()[1];
*accuracy_data = 0.0f;

PADDLE_ENFORCE_GT(label.dims().size(),
0,
phi::errors::InvalidArgument(
"Rank(Label) of AccuracyOp must greater than 0, "
"But received rank(Label) is %d",
label.dims().size()));

PADDLE_ENFORCE_GE(
label.dims()[0],
inference.dims()[0],
phi::errors::InvalidArgument("num_samples(%d) of Label should less than "
"or equal to num_samples(%d) of Input",
label.dims()[0],
num_samples));

if (num_samples == 0) {
return;
}
Expand Down
23 changes: 23 additions & 0 deletions paddle/phi/kernels/gpu/accuracy_kernel.cu
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,14 @@ void AccuracyRawKernel(const Context& dev_ctx,
const int64_t* indices_data = indices.data<int64_t>();
const int64_t* label_data = label.data<int64_t>();

PADDLE_ENFORCE_EQ(
inference.dims().size(),
2,
phi::errors::InvalidArgument(
"Rank(Input) of AccuracyOp must be 2, with shape "
"[sample_number, class_dim], But received rank(Input) is %d",
inference.dims().size()));

int* correct_data = dev_ctx.template Alloc<int>(correct);
int* total_data = dev_ctx.template Alloc<int>(total);
T* accuracy_data = dev_ctx.template Alloc<T>(accuracy);
Expand All @@ -91,6 +99,21 @@ void AccuracyRawKernel(const Context& dev_ctx,
auto stream = dev_ctx.stream();
phi::backends::gpu::GpuMemsetAsync(accuracy_data, 0, sizeof(T), stream);

PADDLE_ENFORCE_GT(label.dims().size(),
0,
phi::errors::InvalidArgument(
"Rank(Label) of AccuracyOp must greater than 0, "
"But received rank(Label) is %d",
label.dims().size()));

PADDLE_ENFORCE_GE(
label.dims()[0],
inference.dims()[0],
phi::errors::InvalidArgument("num_samples(%d) of Label should less than "
"or equal to num_samples(%d) of Input",
label.dims()[0],
num_samples));

if (num_samples == 0) {
return;
}
Expand Down
17 changes: 16 additions & 1 deletion python/paddle/fluid/tests/unittests/test_accuracy_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ def test_check_output(self):


class TestAccuracyOpError(unittest.TestCase):
def test_errors(self):
def test_type_errors(self):
with program_guard(Program(), Program()):
# The input type of accuracy_op must be Variable.
x1 = fluid.create_lod_tensor(
Expand All @@ -75,12 +75,27 @@ def test_errors(self):
x2 = paddle.static.data(name='x2', shape=[-1, 4], dtype="int32")
self.assertRaises(TypeError, paddle.static.accuracy, x2, label)
self.assertRaises(TypeError, paddle.metric.accuracy, x2, label)

x3 = paddle.static.data(
name='input', shape=[-1, 2], dtype="float16"
)
paddle.static.accuracy(input=x3, label=label)
paddle.metric.accuracy(input=x3, label=label)

def test_value_errors(self):
with program_guard(Program(), Program()):
paddle.disable_static()

# The input rank of accuracy_op must be 2.
with self.assertRaises(ValueError):
x3 = paddle.to_tensor([0.1], dtype='float32')
label3 = paddle.to_tensor(
np.reshape([0], [1, 1]), dtype='int32'
)
paddle.metric.accuracy(x3, label3)

paddle.enable_static()


class TestAccuracyAPI1(unittest.TestCase):
def setUp(self):
Expand Down

0 comments on commit 9741121

Please sign in to comment.