Skip to content

Commit

Permalink
[SVE] Add vscale builtin
Browse files Browse the repository at this point in the history
Add a vscale builtin and lowering to `llvm.vscale`. This will be used in
subsequent patches for expressing scalable vectors in TIR.

Co-authored-by: Luke Hutton <luke.hutton@arm.com>
Co-authored-by: Neil Hickey <neil.hickey@arm.com>
  • Loading branch information
3 people committed Jan 29, 2024
1 parent 90320b2 commit 98a3eb3
Show file tree
Hide file tree
Showing 7 changed files with 46 additions and 3 deletions.
5 changes: 5 additions & 0 deletions include/tvm/tir/builtin.h
Original file line number Diff line number Diff line change
Expand Up @@ -909,6 +909,11 @@ TVM_DLL const Op& anylist_setitem_call_packed();
*/
TVM_DLL const Op& anylist_setitem_call_cpacked();

/*!
* \brief Get the target's vscale value
*/
TVM_DLL const Op& vscale();

/*! \brief The kind of structure field info used in intrinsic */
enum TVMStructFieldKind : int {
// array head address
Expand Down
2 changes: 2 additions & 0 deletions python/tvm/script/ir_builder/tir/ir.py
Original file line number Diff line number Diff line change
Expand Up @@ -1862,6 +1862,7 @@ def wrapped(*args, **kwargs):
anylist_resetitem = _op_wrapper(_tir_op.anylist_resetitem)
anylist_setitem_call_packed = _op_wrapper(_tir_op.anylist_setitem_call_packed)
anylist_setitem_call_cpacked = _op_wrapper(_tir_op.anylist_setitem_call_cpacked)
vscale = _op_wrapper(_tir_op.vscale)


def _dtype_forward(func):
Expand Down Expand Up @@ -2199,4 +2200,5 @@ def wrapped(*args, **kwargs):
"IterVar",
"CommReducer",
"Range",
"vscale",
]
1 change: 1 addition & 0 deletions python/tvm/tir/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,7 @@
from .op import q_multiply_shift, q_multiply_shift_per_axis, shift_left, shift_right
from .op import TVMBackendAllocWorkspace, TVMBackendFreeWorkspace
from .op import start_profile_intrinsic, end_profile_intrinsic
from .op import vscale
from .generic import add, subtract, multiply

from .schedule import StmtSRef, BlockScope, ScheduleState, Schedule, ScheduleError
Expand Down
10 changes: 10 additions & 0 deletions python/tvm/tir/op.py
Original file line number Diff line number Diff line change
Expand Up @@ -3338,6 +3338,16 @@ def anylist_setitem_call_cpacked(list_handle, index, func_name, *args):
)


def vscale():
"""Get the target's vscale value
Returns
-------
call : PrimExpr
Call to the vscale intrinsic
"""
return call_intrin("int32", "tir.vscale")


# pylint: disable=unnecessary-lambda
sum = comm_reducer(lambda x, y: x + y, lambda t: const(0, dtype=t), name="sum")
min = comm_reducer(lambda x, y: _ffi_api._OpMin(x, y, None), max_value, name="min") # type: ignore
Expand Down
6 changes: 6 additions & 0 deletions src/target/llvm/codegen_llvm.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1478,6 +1478,12 @@ llvm::Value* CodeGenLLVM::CreateIntrinsic(const CallNode* op) {
return builder_->CreateAssumption(cond);
} else if (op->op.same_as(builtin::tvm_thread_invariant())) {
return MakeValue(op->args[0]);
#if TVM_LLVM_VERSION >= 110
} else if (op->op.same_as(builtin::vscale())) {
llvm::Intrinsic::ID id = llvm::Intrinsic::vscale;
llvm::Function* f = GetIntrinsicDecl(id, builder_->getInt32Ty(), {});
return builder_->CreateCall(f);
#endif
} else {
LOG(FATAL) << "unknown intrinsic " << op->op;
}
Expand Down
3 changes: 3 additions & 0 deletions src/tir/op/builtin.cc
Original file line number Diff line number Diff line change
Expand Up @@ -394,6 +394,9 @@ TIR_DEFINE_BUILTIN_FUNC(anylist_setitem_call_packed)

TIR_DEFINE_BUILTIN_FUNC(anylist_setitem_call_cpacked)
.set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kOpaque));

TIR_DEFINE_BUILTIN_FUNC(vscale).set_attr<TCallEffectKind>("TCallEffectKind",
Integer(CallEffectKind::kOpaque));
} // namespace builtin
} // namespace tir
} // namespace tvm
22 changes: 19 additions & 3 deletions tests/python/codegen/test_target_codegen_aarch64.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,8 @@
# under the License.
import tvm
from tvm import te
from tvm.script import tir as TIR
from tvm.script import tir as T
import re
import os
import ctypes
import pytest

from tvm.target.codegen import llvm_version_major
Expand Down Expand Up @@ -476,5 +474,23 @@ def check_correct_assembly(type):
check_correct_assembly(type=dtype)


@pytest.mark.skipif(
llvm_version_major() < 10, reason="Vscale is not supported in earlier versions of LLVM"
)
def test_codegen_vscale():
target = "llvm -mtriple=aarch64-linux-gnu -mattr=+sve"
vscale = tvm.tir.vscale()

@T.prim_func
def main(A: T.Buffer((5,), "int32")):
for i in range(5):
A[i] = 2 * vscale

build_mod = tvm.build(main, target=target)
llvm = build_mod.get_source()

assert re.findall(r"llvm.vscale.i32", llvm), "No vscale in generated LLVM."


if __name__ == "__main__":
tvm.testing.main()

0 comments on commit 98a3eb3

Please sign in to comment.