Support int16_t in fill_constant_op (PaddlePaddle#35619)
ZzSean authored and AnnaTrainingG committed Sep 29, 2021
1 parent 1dd8717 commit 08c4a07
Showing 7 changed files with 16 additions and 20 deletions.
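
With this change, int16 joins the data types that fill_constant supports end to end: the Python dtype check, the CPU/CUDA kernel registrations, and the SetConstant instantiations. A minimal usage sketch, assuming a Paddle build that includes this patch (illustrative example, not part of the commit):

    import paddle.fluid as fluid

    # Before this commit, dtype='int16' was rejected by check_dtype with a TypeError;
    # with the kernels registered in the diffs below it creates an int16 tensor.
    t = fluid.layers.fill_constant(shape=[2, 3], dtype='int16', value=7)
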
4 changes: 2 additions & 2 deletions paddle/fluid/operators/fill_constant_op.cc
@@ -149,8 +149,8 @@ REGISTER_OPERATOR(
 REGISTER_OP_CPU_KERNEL(
     fill_constant, ops::FillConstantKernel<float>,
     ops::FillConstantKernel<double>, ops::FillConstantKernel<uint8_t>,
-    ops::FillConstantKernel<int64_t>, ops::FillConstantKernel<int>,
-    ops::FillConstantKernel<bool>,
+    ops::FillConstantKernel<int16_t>, ops::FillConstantKernel<int>,
+    ops::FillConstantKernel<int64_t>, ops::FillConstantKernel<bool>,
     ops::FillConstantKernel<paddle::platform::float16>,
     ops::FillConstantKernel<paddle::platform::bfloat16>,
     ops::FillConstantKernel<paddle::platform::complex<float>>,
4 changes: 2 additions & 2 deletions paddle/fluid/operators/fill_constant_op.cu.cc
@@ -18,8 +18,8 @@ namespace ops = paddle::operators;
 REGISTER_OP_CUDA_KERNEL(
     fill_constant, ops::FillConstantKernel<float>,
     ops::FillConstantKernel<double>, ops::FillConstantKernel<uint8_t>,
-    ops::FillConstantKernel<int64_t>, ops::FillConstantKernel<int>,
-    ops::FillConstantKernel<bool>,
+    ops::FillConstantKernel<int16_t>, ops::FillConstantKernel<int>,
+    ops::FillConstantKernel<int64_t>, ops::FillConstantKernel<bool>,
     ops::FillConstantKernel<paddle::platform::float16>,
     ops::FillConstantKernel<paddle::platform::complex<float>>,
     ops::FillConstantKernel<paddle::platform::complex<double>>);
2 changes: 2 additions & 0 deletions paddle/fluid/operators/math/math_function.cc
@@ -41,6 +41,7 @@ template struct SetConstant<platform::CPUDeviceContext, platform::float16>;
 template struct SetConstant<platform::CPUDeviceContext, platform::bfloat16>;
 template struct SetConstant<platform::CPUDeviceContext, float>;
 template struct SetConstant<platform::CPUDeviceContext, double>;
+template struct SetConstant<platform::CPUDeviceContext, int16_t>;
 template struct SetConstant<platform::CPUDeviceContext, int>;
 template struct SetConstant<platform::CPUDeviceContext, int64_t>;
 template struct SetConstant<platform::CPUDeviceContext, bool>;
@@ -56,6 +57,7 @@ template struct SetConstant<platform::XPUDeviceContext, platform::bfloat16>;
 template struct SetConstant<platform::XPUDeviceContext, float>;
 template struct SetConstant<platform::XPUDeviceContext, double>;
 template struct SetConstant<platform::XPUDeviceContext, uint8_t>;
+template struct SetConstant<platform::XPUDeviceContext, int16_t>;
 template struct SetConstant<platform::XPUDeviceContext, int>;
 template struct SetConstant<platform::XPUDeviceContext, int64_t>;
 template struct SetConstant<platform::XPUDeviceContext, bool>;
1 change: 1 addition & 0 deletions paddle/fluid/operators/math/math_function.cu
@@ -35,6 +35,7 @@ template struct SetConstant<platform::CUDADeviceContext, float>;
 template struct SetConstant<platform::CUDADeviceContext, double>;
 template struct SetConstant<platform::CUDADeviceContext, uint8_t>;
 template struct SetConstant<platform::CUDADeviceContext, int>;
+template struct SetConstant<platform::CUDADeviceContext, int16_t>;
 template struct SetConstant<platform::CUDADeviceContext, int64_t>;
 template struct SetConstant<platform::CUDADeviceContext, bool>;
 template struct SetConstant<platform::CUDADeviceContext,
14 changes: 7 additions & 7 deletions python/paddle/fluid/layers/tensor.py
@@ -674,7 +674,7 @@ def fill_constant(shape, dtype, value, force_cpu=False, out=None, name=None):
             If ``shape`` is a list or tuple, the elements of it should be integers or Tensors with shape [1].
             If ``shape`` is an Tensor, it should be an 1-D Tensor with date type int32 or int64.
         dtype(np.dtype|str): Data type of the output Tensor which can
-            be float16, float32, float64, uint8, int32, int64.
+            be float16, float32, float64, uint8, int16, int32, int64.
         value(bool|float|int|Tensor): The constant value used to initialize
             the Tensor to be created. If ``value`` is an Tensor, it should be an 1-D Tensor.
         force_cpu(bool, optional): data should be on CPU if it's true, default value is False.
@@ -712,7 +712,7 @@ def fill_constant(shape, dtype, value, force_cpu=False, out=None, name=None):
     attrs = {'force_cpu': force_cpu}
     dtype = convert_dtype(dtype)
     if not isinstance(value, Variable):
-        if dtype in ['uint8', 'int64', 'int32']:
+        if dtype in ['uint8', 'int16', 'int32', 'int64']:
             attrs['str_value'] = str(int(value))
             attrs['value'] = int(value)
         else:
@@ -725,7 +725,7 @@ def fill_constant(shape, dtype, value, force_cpu=False, out=None, name=None):
             out = _varbase_creator(dtype=dtype)

         if isinstance(value, Variable):
-            if dtype in ['uint8', 'int64', 'int32']:
+            if dtype in ['uint8', 'int16', 'int32', 'int64']:
                 attrs['str_value'] = str(int(value.numpy().item(0)))
             else:
                 attrs['str_value'] = str(float(value.numpy().item(0)))
@@ -745,10 +745,10 @@ def fill_constant(shape, dtype, value, force_cpu=False, out=None, name=None):
         inputs['ValueTensor'] = value

     check_shape(shape)
-    check_dtype(
-        dtype, 'dtype',
-        ['bool', 'float16', 'float32', 'float64', 'uint8', 'int32', 'int64'],
-        'fill_constant')
+    check_dtype(dtype, 'dtype', [
+        'bool', 'float16', 'float32', 'float64', 'uint8', 'int16', 'int32',
+        'int64'
+    ], 'fill_constant')
     check_type(shape, 'shape', (Variable, list, tuple), 'fill_constant')

     if out is not None:
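
The two dtype lists edited above decide how the fill value is stored in the op's attributes; adding 'int16' routes int16 fills through the integer branch instead of the float one. A simplified illustrative sketch of that branching (the helper name _encode_fill_value is made up for this example, not Paddle code):

    def _encode_fill_value(dtype, value):
        # Integer dtypes serialize the value with int(), all others with float().
        if dtype in ['uint8', 'int16', 'int32', 'int64']:
            return str(int(value))
        return str(float(value))

    assert _encode_fill_value('int16', 7.0) == '7'
    assert _encode_fill_value('float32', 7) == '7.0'
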
9 changes: 1 addition & 8 deletions python/paddle/fluid/tests/unittests/test_fill_constant_op.py
@@ -358,13 +358,6 @@ def test_errors(self):
                 shape=[1],
                 value=5,
                 dtype='uint4')
-            self.assertRaises(
-                TypeError,
-                fluid.layers.fill_constant,
-                shape=[1],
-                value=5,
-                dtype='int16',
-                out=x1)

             self.assertRaises(
                 TypeError,
@@ -375,7 +368,7 @@ def test_errors(self):
                 out=x1)

             # The argument dtype of fill_constant_op must be one of bool, float16,
-            #float32, float64, uint8, int32 or int64
+            #float32, float64, uint8, int16, int32 or int64
             x2 = fluid.layers.data(name='x2', shape=[1], dtype="int32")

             self.assertRaises(
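
The deleted assertion covered the old behavior, where dtype='int16' raised a TypeError; it is removed because int16 is now a valid dtype, while genuinely unsupported dtypes still fail the check. A hedged sketch of the expected behavior after this commit (assumed, not taken from the test file):

    import paddle.fluid as fluid

    with fluid.program_guard(fluid.Program(), fluid.Program()):
        # Still rejected: 'uint4' is not a recognized dtype.
        try:
            fluid.layers.fill_constant(shape=[1], value=5, dtype='uint4')
        except TypeError:
            pass
        # No longer rejected: 'int16' passes check_dtype after this change.
        fluid.layers.fill_constant(shape=[1], value=5, dtype='int16')
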
2 changes: 1 addition & 1 deletion python/paddle/fluid/tests/unittests/test_full_op.py
@@ -84,7 +84,7 @@ def test_errors(self):
                 TypeError, paddle.full, shape=[1], fill_value=5, dtype='uint4')

             # The argument dtype of full must be one of bool, float16,
-            #float32, float64, uint8, int32 or int64
+            #float32, float64, uint8, int16, int32 or int64

             # The argument shape's type of full_op must be list, tuple or Variable.
             def test_shape_type():
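
paddle.full builds on fill_constant, so the 2.x API should accept int16 as well; a small sketch of the expected result (an assumption based on the updated dtype list, not code from this commit):

    import paddle

    x = paddle.full(shape=[4], fill_value=3, dtype='int16')
    print(x.dtype)  # expected: paddle.int16
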
