Skip to content

Commit

Permalink
[SVE] Add codegen support for scalable buffer accesses (apache#16696)
Browse files Browse the repository at this point in the history
This commit adds support for generating code for scalable loads and
stores. It also adds support for the creation of scalable broadcast
operations.


Co-authored-by: Elen Kalda <elen.kalda@arm.com>
Co-authored-by: Neil Hickey <neil.hickey@arm.com>
  • Loading branch information
3 people authored and thaisacs committed Apr 3, 2024
1 parent 44828dd commit 22d32df
Show file tree
Hide file tree
Showing 10 changed files with 249 additions and 39 deletions.
16 changes: 12 additions & 4 deletions include/tvm/runtime/data_type.h
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,8 @@ class DataType {
}
return -lanes_as_int;
}
/*! \return get vscale factor or lanes depending on scalability of the vector. */
int get_lanes_or_vscale_factor() { return is_scalable_vector() ? vscale_factor() : lanes(); }
/*! \return true if the type is a non-scalable type with exactly one lane. */
bool is_scalar() const {
  // Scalable vectors are never scalar; testing this first means lanes() is only
  // queried for non-scalable types.
  if (is_scalable_vector()) {
    return false;
  }
  return lanes() == 1;
}
/*! \return whether type is a scalar type. */
Expand Down Expand Up @@ -211,10 +213,13 @@ class DataType {
/*!
* \brief Construct an uint type.
* \param bits The number of bits in the type.
* \param lanes The number of lanes
* \param lanes The number of lanes.
* \param is_scalable Whether the data type is scalable.
* \return The constructed data type.
*/
static DataType UInt(int bits, int lanes = 1) { return DataType(kDLUInt, bits, lanes); }
static DataType UInt(int bits, int lanes = 1, bool is_scalable = false) {
  // Unsigned integers use the kDLUInt type code; the other arguments are forwarded unchanged.
  DataType uint_type(kDLUInt, bits, lanes, is_scalable);
  return uint_type;
}
/*!
* \brief Construct a float type.
* \param bits The number of bits in the type.
Expand Down Expand Up @@ -243,10 +248,13 @@ class DataType {
static DataType NVFloat8E5M2(int lanes = 1) {
  // 8-bit type constructed with the kE5M2Float type code.
  return DataType(kE5M2Float, /*bits=*/8, lanes);
}
/*!
* \brief Construct a bool type.
* \param lanes The number of lanes
* \param lanes The number of lanes.
* \param is_scalable Whether the data type is scalable.
* \return The constructed data type.
*/
static DataType Bool(int lanes = 1) { return DataType::UInt(1, lanes); }
static DataType Bool(int lanes = 1, bool is_scalable = false) {
  // A bool is represented as a 1-bit unsigned integer.
  constexpr int kBoolBits = 1;
  return DataType::UInt(kBoolBits, lanes, is_scalable);
}
/*!
* \brief Construct a handle type.
* \param bits The number of bits in the type.
Expand Down
7 changes: 7 additions & 0 deletions python/tvm/testing/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1045,6 +1045,13 @@ def _has_cpu_feat(features):
)


# Test-requirement marker for the AArch64 SVE extension, registered under the
# feature name "arm_sve". The run-time check is deferred via a lambda and
# delegates to _has_cpu_feat("sve").
requires_aarch64_sve = Feature(
"arm_sve",
"AArch64 SVE",
run_time_check=lambda: _has_cpu_feat("sve"),
)


requires_x86_vnni = Feature(
"x86_vnni",
"x86 VNNI Extensions",
Expand Down
66 changes: 35 additions & 31 deletions src/target/llvm/codegen_llvm.cc
Original file line number Diff line number Diff line change
Expand Up @@ -587,10 +587,17 @@ llvm::Type* CodeGenLLVM::DTypeToLLVMType(const DataType& dtype) const {
LOG(FATAL) << "do not support " << dtype;
}
}
if (dtype.lanes() != 1) {
if (!dtype.is_scalar()) {
#if TVM_LLVM_VERSION >= 110
return llvm::FixedVectorType::get(etype, dtype.lanes());
if (dtype.is_scalable_vector()) {
return llvm::VectorType::get(etype, dtype.vscale_factor(), true);
} else {
return llvm::FixedVectorType::get(etype, dtype.lanes());
}
#else
ICHECK(!dtype.is_scalable_vector())
<< "Versions of LLVM < 11 do not support scalable vectors. Please upgrade to a later "
"version.";
return llvm::VectorType::get(etype, dtype.lanes());
#endif
} else {
Expand Down Expand Up @@ -749,26 +756,6 @@ std::unique_ptr<CodeGenLLVM::DebugInfo> CodeGenLLVM::CreateDebugInfo(llvm::Modul
return debug_info;
}

/*!
 * \brief Build a vector of `lanes` elements in which every element is `value`.
 *
 * The scalar is inserted into element 0 of an undef vector, and that vector is then
 * shuffled with an all-zero mask so that every result lane reads element 0.
 *
 * \param value The LLVM value to replicate; its type becomes the vector element type.
 * \param lanes The number of lanes in the resulting vector.
 * \return The value produced by the shufflevector instruction.
 */
llvm::Value* CodeGenLLVM::CreateBroadcast(llvm::Value* value, int lanes) {
// From LLVM 11 on, the fixed-width vector type is requested explicitly.
#if TVM_LLVM_VERSION >= 110
llvm::Type* type = llvm::FixedVectorType::get(value->getType(), lanes);
#else
llvm::Type* type = llvm::VectorType::get(value->getType(), lanes);
#endif
llvm::Constant* undef = llvm::UndefValue::get(type);
llvm::Constant* zero = ConstInt32(0);
// Place the scalar in lane 0; the remaining lanes stay undef until the shuffle below.
value = builder_->CreateInsertElement(undef, value, zero);
// The splat-mask constructor's signature differs between LLVM versions; every branch
// builds a non-scalable mask of `lanes` zeros.
#if TVM_LLVM_VERSION >= 120
llvm::Constant* mask = llvm::ConstantVector::getSplat(llvm::ElementCount::getFixed(lanes), zero);
#elif TVM_LLVM_VERSION >= 110
llvm::Constant* mask =
llvm::ConstantVector::getSplat(llvm::ElementCount(lanes, /*Scalable=*/false), zero);
#else
llvm::Constant* mask = llvm::ConstantVector::getSplat(lanes, zero);
#endif
// An all-zero mask selects lane 0 of the first operand for every output lane.
return builder_->CreateShuffleVector(value, undef, mask);
}

llvm::Value* CodeGenLLVM::CreateVecSlice(llvm::Value* vec, int begin, int extent) {
int num_elems = GetVectorNumElements(vec);
if (extent == num_elems && begin == 0) return vec;
Expand Down Expand Up @@ -1693,7 +1680,8 @@ void CodeGenLLVM::BufferAccessHelper(
}

PrimExpr last_index = indices[indices.size() - 1];
ICHECK_EQ(value_dtype.lanes(), last_index.dtype().lanes() * buffer_element_dtype.lanes());
ICHECK_EQ(value_dtype.get_lanes_or_vscale_factor(),
last_index.dtype().get_lanes_or_vscale_factor() * buffer_element_dtype.lanes());

// Record index and elemtype in original form used for alias info
PrimExpr last_index_origin = last_index;
Expand Down Expand Up @@ -1736,8 +1724,6 @@ void CodeGenLLVM::BufferAccessHelper(
llvm::Value* last_index_value;
int subelement_i = i;
if (const RampNode* ramp = last_index.as<RampNode>()) {
// TODO(ekalda): P4 in https://github.com/apache/tvm/issues/16455
ICHECK(!last_index.dtype().is_scalable_vector());
PrimExpr offset = ramp->base + (ramp->stride * i);
last_index_value = MakeValue(offset);
} else if (last_index.dtype().lanes() > 1) {
Expand All @@ -1754,8 +1740,13 @@ void CodeGenLLVM::BufferAccessHelper(
all_index_values.push_back(last_index_value);

TypedPointer buffer_ptr =
CreateBufferPtr(MakeValue(buffer->data), buffer_element_dtype, all_index_values,
value_dtype.with_lanes(value_dtype.lanes() / last_index.dtype().lanes()));
value_dtype.is_scalable_vector()
? CreateBufferPtr(MakeValue(buffer->data), buffer_element_dtype, all_index_values,
value_dtype.with_scalable_vscale_factor(value_dtype.vscale_factor() /
last_index.dtype().lanes()))
: CreateBufferPtr(
MakeValue(buffer->data), buffer_element_dtype, all_index_values,
value_dtype.with_lanes(value_dtype.lanes() / last_index.dtype().lanes()));
auto instruction = make_instruction(buffer_ptr, subelement_i, alignment, is_volatile);
AddAliasInfo(instruction, buffer->data.get(), last_index_origin, buffer_element_dtype_origin);
}
Expand Down Expand Up @@ -1870,10 +1861,23 @@ llvm::Value* CodeGenLLVM::VisitExpr_(const ShuffleNode* op) {
}

/*!
 * \brief Emit LLVM IR that splats a scalar across a fixed-length or scalable vector.
 *
 * The scalar is inserted into lane 0 of an undef vector of the broadcast's type, and that
 * vector is then shuffled with an all-zero mask so that every result lane reads lane 0.
 *
 * \param op The broadcast node; op->value is the scalar and op->dtype the result vector type.
 * \return The value produced by the shufflevector instruction.
 */
llvm::Value* CodeGenLLVM::VisitExpr_(const BroadcastNode* op) {
  DataType dtype = op->dtype;
  llvm::Value* value = MakeValue(op->value);
  llvm::Type* type = DTypeToLLVMType(dtype);
  llvm::Constant* undef = llvm::UndefValue::get(type);
  llvm::Constant* zero = ConstInt32(0);
  // Place the scalar in lane 0; the remaining lanes stay undef until the shuffle below.
  value = builder_->CreateInsertElement(undef, value, zero);
#if TVM_LLVM_VERSION >= 110
  // The mask's element count carries the same scalability flag as the result type.
  llvm::ElementCount ec =
      llvm::ElementCount::get(dtype.get_lanes_or_vscale_factor(), dtype.is_scalable_vector());
  llvm::Constant* mask = llvm::ConstantVector::getSplat(ec, zero);
#else
  ICHECK(!dtype.is_scalable_vector())
      << "Versions of LLVM < 11 do not support scalable vectors. Please upgrade to a later "
         "version.";
  llvm::Constant* mask = llvm::ConstantVector::getSplat(dtype.lanes(), zero);
#endif
  // An all-zero mask selects lane 0 of the first operand for every output lane.
  return builder_->CreateShuffleVector(value, undef, mask);
}

void CodeGenLLVM::VisitStmt_(const BufferStoreNode* op) {
Expand Down
1 change: 0 additions & 1 deletion src/target/llvm/codegen_llvm.h
Original file line number Diff line number Diff line change
Expand Up @@ -468,7 +468,6 @@ class CodeGenLLVM : public ExprFunctor<llvm::Value*(const PrimExpr&)>,
llvm::Value* CreateAdd(DataType t, llvm::Value* a, llvm::Value* b);
llvm::Value* CreateSub(DataType t, llvm::Value* a, llvm::Value* b);
llvm::Value* CreateMul(DataType t, llvm::Value* a, llvm::Value* b);
llvm::Value* CreateBroadcast(llvm::Value* value, int lanes);
virtual TypedPointer CreateBufferPtr(llvm::Value* buffer_ptr, DataType buffer_element_dtype,
llvm::ArrayRef<llvm::Value*> indices, DataType value_dtype);
// Vector concatenation.
Expand Down
2 changes: 1 addition & 1 deletion src/tir/ir/data_type_rewriter.cc
Original file line number Diff line number Diff line change
Expand Up @@ -451,7 +451,7 @@ Stmt IndexDataTypeRewriter::VisitStmt_(const BufferStoreNode* op) {

Buffer new_buffer = GetRemappedBuffer(op->buffer);
auto value = this->VisitExpr(op->value);
if (new_buffer->dtype != value->dtype && value->dtype.lanes() == 1) {
if (new_buffer->dtype != value->dtype && value->dtype.is_scalar()) {
value = cast(new_buffer->dtype, value);
}
auto indices = VisitIndices(op->indices);
Expand Down
7 changes: 5 additions & 2 deletions src/tir/ir/expr.cc
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,9 @@ namespace tir {
CHECK(a.dtype() == b.dtype()) << "TypeError: mismatched types. " << a.dtype() << " vs. " \
<< b.dtype() << "\n"; \
ObjectPtr<T> node = make_object<T>(); \
node->dtype = DataType::Bool(a.dtype().lanes()); \
DataType a_dtype = a.dtype(); \
node->dtype = \
DataType::Bool(a_dtype.get_lanes_or_vscale_factor(), a_dtype.is_scalable_vector()); \
node->a = std::move(a); \
node->b = std::move(b); \
node->span = std::move(span); \
Expand Down Expand Up @@ -393,7 +395,8 @@ Not::Not(PrimExpr a, Span span) {
ICHECK(a.dtype().is_bool());

ObjectPtr<NotNode> node = make_object<NotNode>();
node->dtype = DataType::Bool(a.dtype().lanes());
DataType a_dtype = a.dtype();
node->dtype = DataType::Bool(a_dtype.get_lanes_or_vscale_factor(), a_dtype.is_scalable_vector());
node->a = std::move(a);
node->span = std::move(span);
data_ = std::move(node);
Expand Down
7 changes: 7 additions & 0 deletions src/tir/transforms/storage_rewrite.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1275,6 +1275,13 @@ class VectorTypeAccessChecker : public StmtExprVisitor {
auto it = info_map_.find(buffer);
ICHECK(it != info_map_.end()) << "Load/Store of buffer " << buffer->name_hint << " (" << buffer
<< ") occurred before its declaration.";

if (value_dtype.is_scalable_vector()) {
// Scalable types are not currently supported in storage_rewrite. Scalable buffer
// accesses are not currently checked and therefore are not rewritten.
return;
}

BufferVarInfo& var_info = it->second;

if (value_dtype.element_of() == DataType::Bool()) {
Expand Down
16 changes: 16 additions & 0 deletions tests/cpp/tir_scalable_datatype.cc
Original file line number Diff line number Diff line change
Expand Up @@ -162,6 +162,22 @@ TEST(ScalableDataType, TestScalableDataTypeInvalidLanesAccess) {
tvm::InternalError);
}

// Bool(4, true) must yield a scalable vector of 1-bit unsigned integers with vscale factor 4.
TEST(ScalableDataType, TestScalableBool) {
  tvm::DataType bool_vec = tvm::DataType::Bool(/*lanes=*/4, /*is_scalable=*/true);
  ASSERT_EQ(bool_vec.code(), kDLUInt);
  ASSERT_EQ(bool_vec.bits(), 1);
  ASSERT_EQ(bool_vec.vscale_factor(), 4);
  ASSERT_TRUE(bool_vec.is_scalable_vector());
}

// UInt(1, 4, true) must yield a scalable vector of 1-bit unsigned integers with vscale factor 4.
TEST(ScalableDataType, TestScalableUInt) {
  tvm::DataType uint_vec = tvm::DataType::UInt(/*bits=*/1, /*lanes=*/4, /*is_scalable=*/true);
  ASSERT_EQ(uint_vec.code(), kDLUInt);
  ASSERT_EQ(uint_vec.bits(), 1);
  ASSERT_EQ(uint_vec.vscale_factor(), 4);
  ASSERT_TRUE(uint_vec.is_scalable_vector());
}

// -----------
// Integration
// -----------
Expand Down
41 changes: 41 additions & 0 deletions tests/python/codegen/test_target_codegen_aarch64.py
Original file line number Diff line number Diff line change
Expand Up @@ -492,5 +492,46 @@ def main(A: T.Buffer((5,), "int32")):
assert re.findall(r"llvm.vscale.i32", llvm), "No vscale in generated LLVM."


@pytest.mark.skipif(
    llvm_version_major() < 11, reason="Vscale is not supported in earlier versions of LLVM"
)
def test_scalable_buffer_load_store():
    """A scalable-ramp copy between two buffers must lower to scalable vector load and store."""
    sve_target = "llvm -mtriple=aarch64-linux-gnu -mattr=+sve"

    @T.prim_func
    def my_func(a: T.handle, b: T.handle):
        A = T.match_buffer(a, (128,), "float32")
        B = T.match_buffer(b, (128,), "float32")
        T.func_attr({"global_symbol": "my_module", "tir.noalias": True})
        B[T.ramp(0, 1, 4 * T.vscale())] = A[T.ramp(0, 1, 4 * T.vscale())]

    # Build for an SVE-enabled AArch64 target and inspect the textual LLVM IR.
    llvm_ir = tvm.build(my_func, target=sve_target).get_source("ll")

    assert re.search(r"load <vscale x 4 x float>", llvm_ir), "No scalable load in generated LLVM."
    assert re.search(
        r" store <vscale x 4 x float>", llvm_ir
    ), "No scalable store in generated LLVM."


@pytest.mark.skipif(
    llvm_version_major() < 11, reason="Vscale is not supported in earlier versions of LLVM"
)
def test_scalable_broadcast():
    """Storing a scalable broadcast must lower to an insertelement/shufflevector splat."""
    sve_target = "llvm -mtriple=aarch64-linux-gnu -mattr=+sve"

    @T.prim_func
    def my_func(a: T.handle):
        A = T.match_buffer(a, (128,), "float32")
        T.func_attr({"global_symbol": "my_module", "tir.noalias": True})
        A[T.ramp(0, 1, 4 * T.vscale())] = T.broadcast(1, 4 * T.vscale())

    # Build for an SVE-enabled AArch64 target and inspect the textual LLVM IR.
    llvm_ir = tvm.build(my_func, target=sve_target).get_source("ll")

    broadcast_pattern = (
        r"shufflevector \(<vscale x 4 x float> insertelement \(<vscale x 4 x float>"
    )
    assert re.search(broadcast_pattern, llvm_ir), "No scalable broadcast in generated LLVM."
    assert re.search(
        r" store <vscale x 4 x float>", llvm_ir
    ), "No scalable store in generated LLVM."


if __name__ == "__main__":
tvm.testing.main()
Loading

0 comments on commit 22d32df

Please sign in to comment.