-
Notifications
You must be signed in to change notification settings - Fork 5.6k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add expand composite rule #50810
Add expand composite rule #50810
Changes from 10 commits
08b28fd
60d6ca5
016e01a
f6f2d95
0b0103e
3fa65fe
0b77b5b
18d72ed
1e40d36
9c1e3f3
5d0af9e
107bba4
b7f209c
e546b09
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -192,6 +192,42 @@ def mean_composite(x, axis, keepdim): | |
return divide(sum_x, norm) | ||
|
||
|
||
@REGISTER_COMPOSITE('expand_v2') | ||
def expand_v2_composite(x, shape): | ||
""" | ||
define composite rule of op expnad_v2, expand_v2->expand | ||
repeat_times = shape / x.shape | ||
out = tile(x, repeat_times = repeat_times) | ||
""" | ||
shape_in = x.shape | ||
assert len(shape) >= len(shape_in) | ||
dim_out = len(shape) | ||
dim_in = len(shape_in) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. polish 3 lines |
||
if dim_out == 0: | ||
return x | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. dim_out can't be 0, remove the code. |
||
repeat_times = [] | ||
for i in range(dim_out): | ||
offset = dim_out - i | ||
dim = dim_in - offset | ||
size_in = shape_in[dim] if dim >= 0 else 1 | ||
size_out = shape[i] | ||
if size_out == -1: | ||
assert dim >= 0 | ||
repeat = 1 | ||
else: | ||
assert size_out % size_in == 0 | ||
repeat = int(size_out / size_in) | ||
repeat_times.append(repeat) | ||
if dim_in < dim_out: | ||
shape_in_expand = [] | ||
for i in range(dim_out - dim_in): | ||
shape_in_expand.append(1) | ||
shape_in_expand.extend(shape_in) | ||
x_reshape = reshape(x, shape_in_expand) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Add comment to show why we need reshape first There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Under the static graph, the tile op will set the expanded dimension to -1 when expanding the tensor, which will result in failure to pass the shape check. So for tensors that need to expand the dimension, reshape will be used in advance There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. |
||
return tile(x_reshape, repeat_times=repeat_times) | ||
return tile(x, repeat_times=repeat_times) | ||
|
||
|
||
@REGISTER_COMPOSITE('stack') | ||
def stack_composite(x, axis): | ||
""" | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
之前没有组合测试的单测可以去掉,例如Shape Tensor类的