Skip to content

Commit

Permalink
Adding more validation checks to _ParallelConcatUpdate to avoid NPE.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 402569467
Change-Id: I2db122dab68be2a5e4e8dd3375f5a70c4d2307ec
  • Loading branch information
rohan100jain authored and tensorflower-gardener committed Oct 12, 2021
1 parent f6da17f commit f2c3931
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 0 deletions.
9 changes: 9 additions & 0 deletions tensorflow/core/kernels/inplace_ops.cc
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,15 @@ class ParallelConcatUpdate : public OpKernel {

void Compute(OpKernelContext* ctx) override {
auto value = ctx->input(0);
// Value should be at least rank 1. Also the 0th dimension should be
// at least loc_.
OP_REQUIRES(ctx, value.dims() >= 1,
errors::InvalidArgument("value should be at least rank 1."));
OP_REQUIRES(
ctx, value.dim_size(0) > loc_,
errors::InvalidArgument("0th dimension of value = ", value.dim_size(0),
" is less than loc_=", loc_));

auto update = ctx->input(1);

OP_REQUIRES(
Expand Down
17 changes: 17 additions & 0 deletions tensorflow/python/kernel_tests/array_ops/stack_op_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,16 @@

import numpy as np

from tensorflow.python import tf2
from tensorflow.python.eager import context
from tensorflow.python.eager import def_function
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gen_array_ops
from tensorflow.python.ops import gradient_checker_v2
from tensorflow.python.platform import test

Expand Down Expand Up @@ -69,6 +73,19 @@ def testSimpleParallelCPU(self):
c = array_ops.parallel_stack(xs)
self.assertAllEqual(c, data)

def testParallelConcatShapeZero(self):
if not tf2.enabled():
self.skipTest("only fails in TF2")

@def_function.function
def f():
y = gen_array_ops.parallel_concat(values=[["tf"]], shape=0)
return y

with self.assertRaisesRegex(errors.InvalidArgumentError,
r"0th dimension of value .* is less than"):
f()

def testSimpleParallelGPU(self):
# tf.parallel_stack is only supported in graph mode.
with ops.Graph().as_default():
Expand Down

0 comments on commit f2c3931

Please sign in to comment.