-
Notifications
You must be signed in to change notification settings - Fork 5.6k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[Prim][PIR] PIR Prim support intarray, scalar, combineop #58581
Changes from all commits
783343f
8d0a0d9
b2cff37
93893ea
59ca1c6
fadbbb4
d43bd98
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -32,12 +32,20 @@ def _build_tensor_tuple(xs): | |
return TypeError(f"Type {type(xs)} is not supported.") | ||
|
||
|
||
def _analyse_decomp_results(orig_outs, decomp_outs): | ||
assert len(orig_outs) == len(decomp_outs) | ||
def _analyse_decomp_results(orig_outs, decomp_outs, op): | ||
intermediate_values = op.get_output_intermediate_value() | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 同上 |
||
assert len(orig_outs) == len(decomp_outs) == len(intermediate_values) | ||
res = [] | ||
for org_item, new_item in zip(orig_outs, decomp_outs): | ||
for org_item, new_item, value in zip( | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 这儿value 改成 There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. ok, thks |
||
orig_outs, decomp_outs, intermediate_values | ||
): | ||
if isinstance(org_item, pir.OpResult): | ||
assert len(new_item) == 1 and isinstance(new_item[0], pir.OpResult) | ||
if value: | ||
assert new_item[0] is None | ||
else: | ||
assert len(new_item) == 1 and isinstance( | ||
new_item[0], pir.OpResult | ||
) | ||
res.append(new_item[0]) | ||
else: | ||
res.append(new_item) | ||
|
@@ -256,7 +264,9 @@ def _decompose_subgraph(block, orig_vars, dst_vars, op_filter): | |
orig_outs = op.results() | ||
if has_sink_decomp_rule: | ||
decomp_outs = call_decomp(op) | ||
new_outs = _analyse_decomp_results(orig_outs, decomp_outs) | ||
new_outs = _analyse_decomp_results( | ||
orig_outs, decomp_outs, op | ||
) | ||
else: | ||
new_outs = _build_tensor_tuple(decom_rule(*input_args)) | ||
|
||
|
@@ -389,7 +399,9 @@ def decompose_fwd_op( | |
pir.set_insertion_point(fwd_op) | ||
if has_sink_decomp_rule: | ||
decomp_outs = call_decomp(fwd_op) | ||
new_outs = _analyse_decomp_results(orig_outs, decomp_outs) | ||
new_outs = _analyse_decomp_results( | ||
orig_outs, decomp_outs, fwd_op | ||
) | ||
else: | ||
new_outs = _build_tensor_tuple(decom_rule(*input_args)) | ||
|
||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
换成这个名字更容易理解一点?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
或许叫get_output_intermediate_status?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nice