Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[TensorIR][Schedule] New schedule primitive unsafe_hide_buffer_access #15144

Merged
merged 5 commits into from
Jun 24, 2023

Conversation

yzh119
Copy link
Member

@yzh119 yzh119 commented Jun 22, 2023

Motivation

Currently, our tensorize schedule primitives rely on buffer read/write regions in the given block to perform pattern matching. However, for workloads such as block sparse operators, the read/write regions include some indices arraies that may fail tensorize primitive.
In SparseTIR we introduce a new schedule primitive called hide_buffer_access which allows us to hide certain buffer regions in a block so that the read/write buffer regions would be recognized by the tensorize primitive to further utilize tensor acceleration units.
This PR upstreams this schedule primitive to TensorIR mainline.

The schedule primitive interface

    def hide_buffer_access(self, block: BlockRV, buf_type: str, buf_index_array: List[int]) -> None:
        """Hide some buffer access in a given block.

        Parameters
        ----------
        block : BlockRV
            The block where we hide read access.
        buf_type : str
            The buffer type: "read"/"write".
        buf_index_array : List[int]
            The array of buffer indices we hide access.
        """
        pass

Example

@T.prim_func
def indirect_mem_access(a: T.handle, idx_a: T.handle, b: T.handle, idx_b: T.handle) -> None:
    A = T.match_buffer(a, [128], dtype="float32")
    IA = T.match_buffer(idx_a, [10], dtype="int32")
    B = T.match_buffer(b, [128], dtype="float32")
    IB = T.match_buffer(idx_b, [10], dtype="int32")

    for i in range(10):
        with T.block("B"):
            vi = T.axis.spatial(10, i)
            T.reads(A[IA[vi]], IA[vi])
            T.writes(B[IB[vi]], IB[vi])
            B[IB[vi]] = A[IA[vi]]

After we perform hiding buffer access to IA[vi] via:

sch = tir.Schedule(indirect_mem_access, debug_mask="all")
block_b = sch.get_block("B")
sch.hide_buffer_access(block_b, "write", [1]) 

the desired transformed IR would be:

@T.prim_func
def indirect_mem_access_hide_ia(a: T.handle, idx_a: T.handle, b: T.handle, idx_b: T.handle) -> None:
    A = T.match_buffer(a, [128], dtype="float32")
    IA = T.match_buffer(idx_a, [10], dtype="int32")
    B = T.match_buffer(b, [128], dtype="float32")
    IB = T.match_buffer(idx_b, [10], dtype="int32")

    for i in range(10):
        with T.block("B"):
            vi = T.axis.spatial(10, i)
            T.reads(A[IA[vi]])
            T.writes(B[IB[vi]], IB[vi])
            B[IB[vi]] = A[IA[vi]]

The existing passes/schedules would not be influenced by this PR.

cc @junrushao @MasterJH5574 @masahi

@tvm-bot
Copy link
Collaborator

tvm-bot commented Jun 22, 2023

Thanks for contributing to TVM! Please refer to the contributing guidelines https://tvm.apache.org/docs/contribute/ for useful information and tips. Please request code reviews from Reviewers by @-ing them in a comment.

  • No users to tag found in teams: tensorir, schedule See #10317 for details

Generated by tvm-bot

Copy link
Contributor

@quic-sanirudh quic-sanirudh left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the PR, a nice addition :). As mentioned in my comment below, could we rename the primitive to unsafe_hide_buffer_access and perhaps add some comments in the docstring to indicate the chances of incorrect output resulting from using this primitive.

@yzh119
Copy link
Member Author

yzh119 commented Jun 24, 2023

Hi @quic-sanirudh, thank you for your suggestions, I have marked the schedule as unsafe and added some docstrings explaining it.

@yzh119 yzh119 changed the title [TensorIR][Schedule] New schedule primitive hide_buffer_access [TensorIR][Schedule] New schedule primitive unsafe_hide_buffer_access Jun 24, 2023
@quic-sanirudh
Copy link
Contributor

Hi @quic-sanirudh, thank you for your suggestions, I have marked the schedule as unsafe and added some docstrings explaining it.

Thanks for taking my suggestion. Looks good to me now

@yzh119 yzh119 merged commit 0a5f5f0 into apache:main Jun 24, 2023
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants