Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add expand composite rule #50810

Merged
merged 14 commits into from
Mar 13, 2023
Merged
104 changes: 103 additions & 1 deletion python/paddle/fluid/tests/unittests/test_expand_v2_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,7 +179,7 @@ def test_check_output(self):
self.check_output()


# Situation 56: input x is Integer
# Situation 6: input x is Integer
class TestExpandV2OpInt64_t(OpTest):
def setUp(self):
self.op_type = "expand_v2"
Expand Down Expand Up @@ -332,6 +332,108 @@ def test_grad(self):
self.func(p)


# Situation 7: comp case, shape is a list(without tensor)
class TestExpandV2CompOpRank1(OpTest):
def setUp(self):
self.op_type = "expand_v2"
self.prim_op_type = "comp"
self.init_data()
self.python_api = paddle.expand
self.inputs = {'X': np.random.random(self.ori_shape).astype("float64")}
self.attrs = {'shape': self.shape}
output = np.tile(self.inputs['X'], self.expand_times)
self.outputs = {'Out': output}
self.enable_cinn = True

def init_data(self):
self.ori_shape = [100]
self.shape = [100]
self.expand_times = [1]

def test_check_output(self):
self.check_output(check_prim=True)

def test_check_grad(self):
self.check_grad(['X'], 'Out', check_prim=True)


class TestExpandV2OpCompRank2_DimExpanding(TestExpandV2CompOpRank1):
def init_data(self):
self.ori_shape = [120]
self.shape = [2, 120]
self.expand_times = [2, 1]


class TestExpandV2CompOpRank2(TestExpandV2CompOpRank1):
def init_data(self):
self.ori_shape = [1, 140]
self.shape = [12, 140]
self.expand_times = [12, 1]


class TestExpandV2CompOpRank3_Corner(TestExpandV2CompOpRank1):
def init_data(self):
self.ori_shape = (2, 10, 5)
self.shape = (2, 10, 5)
self.expand_times = (1, 1, 1)


class TestExpandV2CompOpRank4(TestExpandV2CompOpRank1):
def init_data(self):
self.ori_shape = (2, 4, 5, 7)
self.shape = (-1, -1, -1, -1)
self.expand_times = (1, 1, 1, 1)


# Situation 8: comp case, input x is Integer
class TestExpandV2CompOpInteger(OpTest):
def setUp(self):
self.op_type = "expand_v2"
self.prim_op_type = "comp"
self.python_api = paddle.expand
self.inputs = {
'X': np.random.randint(10, size=(2, 4, 5)).astype("int32")
}
self.attrs = {'shape': [2, 4, 5]}
output = np.tile(self.inputs['X'], (1, 1, 1))
self.outputs = {'Out': output}

def test_check_output(self):
self.check_output(check_prim=True)


# Situation 9: comp case, input x is Bool
class TestExpandV2CompOpBoolean(OpTest):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

之前没有组合测试的单测可以去掉,例如Shape Tensor类的

def setUp(self):
self.op_type = "expand_v2"
self.prim_op_type = "comp"
self.python_api = paddle.expand
self.inputs = {'X': np.random.randint(2, size=(2, 4, 5)).astype("bool")}
self.attrs = {'shape': [2, 4, 5]}
output = np.tile(self.inputs['X'], (1, 1, 1))
self.outputs = {'Out': output}

def test_check_output(self):
self.check_output(check_prim=True)


# Situation 10: comp case, input x is Integer
class TestExpandV2CompOpInt64_t(OpTest):
def setUp(self):
self.op_type = "expand_v2"
self.prim_op_type = "comp"
self.python_api = paddle.expand
self.inputs = {
'X': np.random.randint(10, size=(2, 4, 5)).astype("int64")
}
self.attrs = {'shape': [2, 4, 5]}
output = np.tile(self.inputs['X'], (1, 1, 1))
self.outputs = {'Out': output}

def test_check_output(self):
self.check_output(check_prim=True)


if __name__ == "__main__":
paddle.enable_static()
unittest.main()
34 changes: 34 additions & 0 deletions python/paddle/incubate/autograd/composite_rules.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,6 +192,40 @@ def mean_composite(x, axis, keepdim):
return divide(sum_x, norm)


@REGISTER_COMPOSITE('expand_v2')
def expand_v2_composite(x, shape):
"""
define composite rule of op expnad_v2, expand_v2->expand
repeat_times = shape / x.shape
out = tile(x, repeat_times = repeat_times)
"""
shape_in = x.shape
dim_out = len(shape)
dim_in = len(shape_in)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

polish 3 lines

assert dim_in <= dim_out and dim_out >= 0
repeat_times = []
for i in range(dim_out):
offset = dim_out - i
dim = dim_in - offset
size_in = shape_in[dim] if dim >= 0 else 1
size_out = shape[i]
if size_out == -1:
assert dim >= 0
repeat = 1
else:
assert size_out % size_in == 0
repeat = int(size_out / size_in)
repeat_times.append(repeat)
if dim_in < dim_out:
shape_in_expand = []
for i in range(dim_out - dim_in):
shape_in_expand.append(1)
shape_in_expand.extend(shape_in)
x_reshape = reshape(x, shape_in_expand)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Add comment to show why we need reshape first

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Under the static graph, the tile op will set the expanded dimension to -1 when expanding the tensor, which will result in failure to pass the shape check. So for tensors that need to expand the dimension, reshape will be used in advance

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

return tile(x_reshape, repeat_times=repeat_times)
return tile(x, repeat_times=repeat_times)


@REGISTER_COMPOSITE('stack')
def stack_composite(x, axis):
"""
Expand Down
2 changes: 2 additions & 0 deletions python/paddle/incubate/autograd/primitives.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@
from paddle.tensor import sum # noqa: F401
from paddle.tensor import tan # noqa: F401
from paddle.tensor import tanh # noqa: F401
from paddle.tensor import tile # noqa: F401
from paddle.tensor import uniform # noqa: F401
from paddle.tensor import zeros # noqa: F401
from paddle.tensor.creation import assign # noqa: F401
Expand Down Expand Up @@ -124,6 +125,7 @@
'fill_constant',
'reshape',
'full',
'tile',
'concat',
'uniform',
'greater_equal',
Expand Down