Add a 'rolling_buffer' scheduling primitive #7925
Conversation
@junrushao1994 @merrymercy could you take a look at this?
Can anyone take a look at this?
It would be great if you could fix the CI errors before requesting review :)
Only two minor comments, otherwise LGTM
Hi @mbaret,
Thanks for the work! I did an initial pass with a lookout for documentation and coding style.
I'll do a technical pass later.
@Hzfengsy took 3 tries, but CI is now passing :)
Broadly looks good.
Just a few nits and clarification questions.
Friendly ping @junrushao1994! Some feedback would be appreciated, along with knowing what more needs to be done to get this in.
roll_axis = -1
for loop in iter_vars:
    iter_var = loop.loop_var
    if iter_var in bound_iter_vars:
Clarification question: why can't we just look at bound_iter_vars directly? Are there non-outermost iter_vars being identified?
It's because we don't necessarily iterate over a tensor in the same order as its bounds (e.g. we don't have to go axis 0, 1, 2...)
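To illustrate the point, here is a minimal, self-contained sketch (illustrative names only, not the actual pass code): because a schedule can reorder loops, the loop nest need not visit the tensor's axes in bound order, so the pass scans the nest to find where each bound iter_var actually sits.

bound_iter_vars = ["ax1", "ax0"]        # the tensor's axes, in axis order
loop_nest = ["ax0", "ko", "ax1", "ki"]  # loop order after a reorder()

positions = {}
for depth, loop_var in enumerate(loop_nest):
    if loop_var in bound_iter_vars:
        positions[loop_var] = depth

print(positions)  # {'ax0': 0, 'ax1': 2}: axis 0 is iterated outermost here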
Adding cc: @jcf94
Overall looks good to me. Thanks @mbaret, this is really interesting work. And thanks for the reminder, @manupa-arm!
By the way, this PR doesn't extend the Relay integration support. Does that mean we currently can't use this feature in an end-to-end model? In my understanding, Relay's fusion pass does not support fusing multiple pool ops into one subgraph.
return tvm.tir.transform.prim_func_pass(
    _ftransform, opt_level=0, name="tir.InjectRollingBuffer"
)
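For readers unfamiliar with the pattern, the following is a minimal sketch of how a Python-side prim_func_pass like this is constructed and applied; tir.ExamplePass and its no-op transform are hypothetical stand-ins, not the rolling-buffer logic from this PR.

import tvm
from tvm import te

def make_example_pass():
    def _ftransform(func, mod, ctx):
        # A real pass would return a rewritten PrimFunc; this one is a no-op.
        return func

    return tvm.tir.transform.prim_func_pass(
        _ftransform, opt_level=0, name="tir.ExamplePass"
    )

A = te.placeholder((16,), name="A")
B = te.compute((16,), lambda i: A[i] + 1, name="B")
mod = tvm.lower(te.create_schedule(B.op), [A, B])
mod = make_example_pass()(mod)  # a pass object is applied by calling it on an IRModule
print(mod)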
This seems like a pretty complex pass; would you consider rewriting it as a C++ implementation? (Not necessarily in this PR; we can add a TODO here if the C++ migration is planned.)
Yes, I'd definitely consider doing that rewrite at some point. I was actually waiting to see how the new 'scheduling passes' would look for TensorIR so that I could potentially follow any such pattern there. Unfortunately I don't currently have time, so I'd appreciate taking this in with the TODO. Perhaps we can revisit once TensorIR is complete?
src/te/operation/compute_op.cc
bool skip_ivar_domain = !stage->rolling_buffer;
ret.init_predicates =
    MakeBoundCheck(stage, dom_map, ret.init_vmap, skip_ivar_domain, skip_iter);
Suggested change:
- bool skip_ivar_domain = !stage->rolling_buffer;
- ret.init_predicates =
-     MakeBoundCheck(stage, dom_map, ret.init_vmap, skip_ivar_domain, skip_iter);
+ ret.init_predicates =
+     MakeBoundCheck(stage, dom_map, ret.init_vmap, !stage->rolling_buffer, skip_iter);
Nit: skipping the domain check of the IterVar here looks more like a hack to me, though I don't have any better suggestion. It seems hard to add further checks to guard this operation.
I agree this is a slight hack. It's because with a rolling buffer we need to expand the size and scope of the intermediate buffers, and if we drop the bound checks they end up getting corrupted. To do this more accurately, we'd need to replicate something closer to Halide's store_at, I think, to formally change the realization point of the intermediate tensor. However, this is my current workaround.
This doesn't currently improve the Relay build flow; rather, it adds a new scheduling primitive primarily for use with tvm::build at the moment. We do, however, intend to introduce some inter-operating scheduling as part of this design https://discuss.tvm.apache.org/t/rfc-cascade-scheduling/8119 and hope to make use of this primitive as part of that.
@mbaret Thanks for your explanation, this looks good to me.
Support declaring a tensor as a rolling buffer so that it will be realized as a circular buffer without the need to recompute elements. Change-Id: I32e0878bb1402ff0276adf3da3f9a4aaac46dd30
Ping @manupa-arm, could you please take another look at this patch?
LGTM
Ping @junrushao1994, could you take a look?
LGTM :-)
@mbaret It seems this PR is ready to merge once the conflict has been resolved? Do we have a GitHub issue to track the progress of your cascade scheduling RFC? If not, I suggest opening one, since this scheduling primitive only works under specific conditions. 😄
It would appear the conflict is quite non-trivial (the removal of the Python driver). I don't especially want to work around that by registering my pass in the global registry and then calling it from the C++ API, so as not to pollute that with a Python dependency. Given this problem, I shall close this PR for now and rewrite the pass in C++ when I find the time.
Support declaring a tensor as a rolling buffer so that it will be realized as a circular buffer without the need to recompute elements. For further detail you can take a look at this RFC: https://discuss.tvm.apache.org/t/rfc-introducing-a-rolling-buffer-scheduling-primitive/9836
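A hypothetical usage sketch of the primitive follows; the rolling_buffer() stage method and the scheduling shown are assumptions based on this PR and the linked RFC, since the PR was closed before merging and the exact API may differ. Two chained 3-tap reductions are striped, the intermediate tensor B is computed per stripe, and marking B as a rolling buffer asks the tir.InjectRollingBuffer lowering pass to keep the overlapping elements in a circular buffer instead of recomputing them.

import tvm
from tvm import te

A = te.placeholder((16,), name="A")
B = te.compute((14,), lambda i: A[i] + A[i + 1] + A[i + 2], name="B")
C = te.compute((12,), lambda i: B[i] + B[i + 1] + B[i + 2], name="C")

s = te.create_schedule(C.op)
co, ci = s[C].split(C.op.axis[0], factor=4)  # produce C in stripes of 4
s[B].compute_at(s[C], co)                    # without rolling, 2 elements of B overlap per stripe
s[B].rolling_buffer()                        # assumed primitive added by this PR

print(tvm.lower(s, [A, C], simple_mode=True))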