[LLVM] Codegen subroutine call when CallNode::op is GlobalVar #14901
Conversation
Previously, `CodeGenLLVM` required all TIR PrimFuncs to have the `kGlobalSymbol` attribute, using its value as the externally-visible symbol in the generated library. This commit relaxes that requirement, using the presence of `kGlobalSymbol` to indicate whether a function should be exposed externally. If `kGlobalSymbol` is not defined, then the symbol name is generated from the name of the `tvm::GlobalVar` with the prefix `"_internal_"`, and the symbol is not exposed externally. Since this does not change the codegen behavior for any function that was previously supported, this is not a breaking change.
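To illustrate the distinction, here is a minimal TVMScript sketch (the function names `main` and `subroutine` are hypothetical, and the exact TVMScript spelling may vary between TVM versions): `main` carries the `kGlobalSymbol` attribute (`"global_symbol"`) and keeps its externally-visible symbol, while `subroutine` does not, so it would be emitted under a generated `_internal_subroutine` name and hidden from the library's exports.

```python
import tvm
from tvm.script import ir as I, tir as T

@I.ir_module
class Module:
    @T.prim_func
    def main(A: T.Buffer((4,), "float32")):
        # kGlobalSymbol present: exposed externally as "main".
        T.func_attr({"global_symbol": "main"})
        Module.subroutine(A.data)

    @T.prim_func
    def subroutine(A_data: T.handle("float32")):
        # No kGlobalSymbol: symbol name is generated from the GlobalVar,
        # prefixed with "_internal_", and not exposed externally.
        A = T.decl_buffer((4,), "float32", data=A_data)
        A[0] = 0.0
```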
The functionality tested in this commit was added across several recent PRs, each of which tested their features in isolation. This PR adds unit tests to validate the end-to-end behavior of TIR subroutine calls. PRs building up to this point:

- TVMScript: apache#14889, apache#14915, apache#14919, apache#14941
- Functionality improvements of existing TIR passes: apache#14913, apache#14914, apache#14918, apache#14951
- Changes to the TIR lowering flow: apache#14942, apache#14985
- Codegen updates: apache#14958, apache#14901
- Compatibility updates/fixes: apache#14892, apache#14950, apache#14943, apache#14944, apache#14945, apache#14952, apache#14982, apache#14949
…#14901)

* [CodeGen][LLVM] Codegen to generate internal functions

  Previously, `CodeGenLLVM` required all TIR PrimFuncs to have the `kGlobalSymbol` attribute, using its value as the externally-visible symbol in the generated library. This commit relaxes that requirement, using the presence of `kGlobalSymbol` to indicate whether a function should be exposed externally. If `kGlobalSymbol` is not defined, then the symbol name is generated from the name of the `tvm::GlobalVar` with the prefix `"_internal_"`, and the symbol is not exposed externally. Since this does not change the codegen behavior for any function that was previously supported, this is not a breaking change.

* [Codegen][LLVM] Handle callsite for internal functions
* [UnitTest][LLVM] Added test for LLVM codegen for subroutine
Looks like this PR introduced some ROCm backend issues. Code to reproduce:

```python
import tvm
from tvm import te
import numpy as np
import tvm.testing
from tvm.script import tir as T
from tvm.tir import TensorIntrin
M = 64
N = 64
@tvm.script.ir_module
class MyModule:
    @T.prim_func
    def main(a: T.handle, b: T.handle):
        T.func_attr({"global_symbol": "main"})
        A = T.match_buffer(a, (M, N), dtype="float32")
        B = T.match_buffer(b, (M, N), dtype="float32")
        for i, j in T.grid(M, N):
            with T.block("B"):
                vi, vj = T.axis.remap("SS", [i, j])
                B[vi, vj] = A[vi, vj] * 2.0

ir_module = MyModule
sch = tvm.tir.Schedule(ir_module, debug_mask="all")
block_b = sch.get_block("B")
i, j = sch.get_loops(block_b)
sch.bind(i, "blockIdx.x")
sch.bind(j, "threadIdx.x")
print(sch.mod["main"].script())

ctx = tvm.rocm(0)
with tvm.transform.PassContext():
    rocm_mod = tvm.build(sch.mod, target="rocm")
```

```
Traceback (most recent call last):
File "../../tvm_rocm/memory_copy.py", line 38, in <module>
rocm_mod = tvm.build(sch.mod, target="rocm -mcpu=gfx90a")
File "/home/aiscuser/v-leiwang3/tvm/python/tvm/driver/build_module.py", line 281, in build
rt_mod_host = _driver_ffi.tir_to_runtime(annotated_mods, target_host)
File "/home/aiscuser/v-leiwang3/tvm/python/tvm/_ffi/_ctypes/packed_func.py", line 238, in __call__
raise get_last_ffi_error()
tvm._ffi.base.TVMError: Traceback (most recent call last):
7: TVMFuncCall
6: tvm::runtime::PackedFuncObj::Extractor<tvm::runtime::PackedFuncSubObj<tvm::runtime::TypedPackedFunc<tvm::runtime::Module (tvm::runtime::Map<tvm::Target, tvm::IRModule, void, void> const&, tvm::Target)>::AssignTypedLambda<tvm::{lambda(tvm::runtime::Map<tvm::Target, tvm::IRModule, void, void> const&, tvm::Target)#6}>(tvm::{lambda(tvm::runtime::Map<tvm::Target, tvm::IRModule, void, void> const&, tvm::Target)#6}, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >)::{lambda(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*)#1}> >::Call(tvm::runtime::PackedFuncObj const*, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >, tvm::runtime::TVMRetValue)
5: tvm::TIRToRuntime(tvm::runtime::Map<tvm::Target, tvm::IRModule, void, void> const&, tvm::Target const&)
4: tvm::codegen::Build(tvm::IRModule, tvm::Target)
3: tvm::runtime::PackedFuncObj::Extractor<tvm::runtime::PackedFuncSubObj<tvm::runtime::TypedPackedFunc<tvm::runtime::Module (tvm::IRModule, tvm::Target)>::AssignTypedLambda<tvm::runtime::Module (*)(tvm::IRModule, tvm::Target)>(tvm::runtime::Module (*)(tvm::IRModule, tvm::Target), std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >)::{lambda(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*)#1}> >::Call(tvm::runtime::PackedFuncObj const*, tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*)
2: tvm::codegen::BuildAMDGPU(tvm::IRModule, tvm::Target)
1: tvm::codegen::CodeGenLLVM::Finish()
0: tvm::codegen::CodeGenLLVM::Verify() const
File "/home/aiscuser/v-leiwang3/tvm/src/target/llvm/codegen_llvm.cc", line 361
TVMError: LLVM module verification failed with the following errors:
Calling convention requires void return type
ptr @main_kernel
Function return type does not match operand type of return inst!
ret void
i32
```

This can be resolved through:

```cpp
// src/target/llvm/codegen_amdgpu.cc
class CodeGenAMDGPU : public CodeGenLLVM {
 public:
  CodeGenAMDGPU() = default;
  virtual ~CodeGenAMDGPU() = default;

  llvm::Function* DeclareFunction(const GlobalVar& gvar, const PrimFunc& f) {
    // AMD GPU kernel functions should have a void return type.
    return this->DeclareFunctionInternal(gvar, f, true);
  }
};
```

This works because AMD GPU kernels should have a void return type, but by default the base LLVM codegen implementation is called, which returns `int`. Previously there were no issues, until this code snippet was added:

```cpp
// src/target/llvm/codegen_llvm.cc
if (auto it = functions_.find(gvar.get()); it != functions_.end()) {
  return it->second;
}
```

And if we want to leverage shared memory, there's another issue. I still have no idea how to resolve it; please help me take a look @Lunderberg
We should not support external globals in the device codegen.
Following up on this. Device-side codegen follows a different convention, and likely we cannot simply have a call from the host to the device via an external call. Instead, we need to do them through packed function calls (via the PackedFunc convention). So I wonder if we should revert some of the changes in the device codegen (nvptx, rocm) and only support subroutine rewriting on the host side. cc @Lunderberg
Agreed, the device-side codegen should not handle host-to-device calls. If any such calls exist (e.g. resulting from SplitHostDevice), they are lowered to packed function calls. I'm investigating the issue at the moment, and can reproduce it locally. (Thank you @LeiWang1999 for the test case!) It looks like it is a discrepancy between the declared and generated return types.
@LeiWang1999 It looks like the discrepancy is specifically between the return type of the PrimFunc and the `"ret_void"` argument supplied to the LLVM codegen. Long-term, rather than providing the `"ret_void"` argument explicitly, the return type should be taken from the `PrimFunc::ret_type` annotation.
Prior to this commit, the `"ret_void"` argument needed to be explicitly provided to `CodeGenLLVM::AddFunction` and `CodeGenLLVM::DeclareFunction`. If this was inconsistent with the `builtin::ret()` usage within the `PrimFunc`, it could cause an incorrect return type in the generated LLVM IR, resulting in LLVM IR verification failures. This commit removes the `"ret_void"` argument, instead using the type annotation in `PrimFunc::ret_type`, removing this opportunity for inconsistency. This PR is intended to fix a ROCm regression reported in apache#14901 (comment).
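As a sketch of what this changes (hypothetical function name, assuming a TVM build that includes the fix): the return type used by codegen now follows the PrimFunc's own annotation, so a `builtin::ret()` in the body is checked against `PrimFunc::ret_type` rather than against a separately supplied flag.

```python
import tvm
from tvm.script import tir as T

@T.prim_func
def returns_int() -> T.int32:
    # T.ret lowers to builtin::ret(); its value must agree with the
    # annotated return type, which codegen now consults directly.
    T.ret(42)

print(returns_int.ret_type)  # the annotation codegen reads, here int32
```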
@tqchen Forgot to respond to this part. That is the intent: the LLVM subroutines are supported only on the host side. The primary changes to nvptx/rocm were to make sure they could provide the arguments required by the updated `DeclareFunction`/`AddFunction` interface.
@LeiWang1999 Can you try running your test case with PR #15127? This is a draft PR that removes the `"ret_void"` argument.
Thanks @Lunderberg, I will take a look later. To reproduce the shared memory external global issue, the test case can be accessed at https://gist.github.com/LeiWang1999/d0db6e802f358a2a42fca32b5cd42441
@LeiWang1999 Thank you for the test case. I can reproduce the shmem error, though it looks like it predates the LLVM changes in this PR. Unfortunately, that means there isn't a clear before/after that could be compared and fixed. Do you know if shared memory on rocm has worked in the past? FYI, I also condensed it down to the following test case, which reproduces the same error:

```python
import tvm
from tvm.script import tir as T
@T.prim_func
def func():
    A = T.alloc_buffer(1, scope="shared")
    for i in T.thread_binding(1, thread="threadIdx.x"):
        A[0] = 0.0

tvm.build(func, target="rocm")
```
@Lunderberg At least as of commit 1d98634, the rocm shared memory should work.
I think the `ret_void` argument issue has been fixed through the related PRs, thanks @Lunderberg! But from my tests, the external global function issue still remains.
I am facing the same issue on the ROCm platform. Is there a solution yet?
@zhangxiao-stack Check out this branch: https://github.com/LeiWang1999/tvm/blob/lei/feat-hip/tests/python/unittest/test_tir_schedule_tensorize_hip_mfma.py#L112 This is an alternative HIP source code generation path, and it might be one possible interim solution on ROCm.
Analogous to #14901, treat GlobalVar callees as internal function calls in CodeGenC. This specific PR doesn't provide new end-to-end functionality, as the target="c" backend isn't compiled. It does lead into allowing subroutines in any target whose codegen derives from CodeGenC, which will depend on the single-module lowering flow in #14985.

* [CodeGenC] Added unit tests for desired behavior
* [CodeGenC] Handle GlobalVar callee as internal function call
* Update CodeGenC subclasses for updated interface
  - Call `DeclareFunction` for each `PrimFunc`, prior to any `AddFunction` calls
  - Provide both `GlobalVar` and `PrimFunc` to `AddFunction` calls
* Updated CRT test to expect forward declaration
* Provide forward declarations for call_extern in cmsis
* Avoid duplicate forward declaration

  C's automatic pointer cast (e.g. `void*` to `int*`) means that use of the arguments to infer the function signature may be incorrect. If a `call_extern` refers to a function within the same module, only output a single forward declaration based on the PrimFunc's parameters, not based on the CallNode's arguments.

* Updated expected ptx cuda
* Cast the AOT pools to the arg type
* Improved tvm::GetType for tvm_access_ptr and address_of

  These `Call` instances can return a `PointerType(PrimType(pointee_dtype))` rather than a `PrimType(DataType::Handle())`.

* [ARM][Topi] Update micro kernels to use same argument type as caller

  Previously, the micro kernels for gemm, avg_pool, max_pool, and tensordot relied on C's implicit type conversions for the arguments, when the caller's argument types differ from the signature's parameter types. This works, except when the codegen has auto-generated a forward declaration based on the caller's argument types, such as during AOT, which then causes a conflicting definition. Since the codegen cannot determine the function names from the `"pragma_import_c"` in order to suppress these forward declarations, this conflict can be more easily resolved by updating the micro kernel signatures. The three types of mismatches are below.

  - Use of `int` or `long` parameters, whose width may vary by compiler, instead of fixed-width types.
  - TIR expecting the data array's integer type to also be used as an error code's return type, rather than the micro kernels' `int32_t` error code.
  - Pointer conversion done during argument conversion.

  Type conversions are done at the start of each micro kernel, to avoid changing types that are used within the computational sections of each micro kernel.

* Updated unit tests with private=True

  Required for internal functions after PR #15214

* Docstring updates from review
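For context, a minimal sketch of observing the CodeGenC behavior described above (assuming the `Module` sketch from earlier on this page and a TVM build that includes these changes): building with `target="c"` and inspecting the emitted source should show a forward declaration for the internal subroutine ahead of its definition.

```python
# Build the two-function module with the C backend and inspect the
# emitted source, which should contain a forward declaration of the
# internal subroutine before the body of "main".
lib = tvm.build(Module, target="c")
print(lib.get_source())
```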
Previously, the `CallNode::op` had to be a known built-in operation. This commit allows LLVM codegen to produce a subroutine call to another function within the same IRModule, with `CallNode::op` specifying the `GlobalVar` that represents that function.
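Concretely, the node shape handled here can be sketched as follows (a standalone illustration; the name "subroutine" is arbitrary):

```python
import tvm
from tvm import tir

# A tir.Call whose op is a GlobalVar, i.e. a reference to another
# function in the same IRModule, rather than a built-in Op.
gvar = tvm.ir.GlobalVar("subroutine")
call = tir.Call("int32", gvar, [tir.const(1, "int32")])
assert isinstance(call.op, tvm.ir.GlobalVar)
```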