Skip to content

Commit

Permalink
fix: cuda codegen vectorize cast
Browse files Browse the repository at this point in the history
  • Loading branch information
kongroo committed Mar 2, 2021
1 parent 485dfd6 commit 4900849
Show file tree
Hide file tree
Showing 2 changed files with 124 additions and 27 deletions.
134 changes: 113 additions & 21 deletions src/target/source/codegen_cuda.cc
Original file line number Diff line number Diff line change
Expand Up @@ -89,8 +89,8 @@ std::string CodeGenCUDA::Finish() {
decl_stream << " #define uint unsigned int\n";
decl_stream << " #define uchar unsigned char\n";
decl_stream << " #define ushort unsigned short\n";
decl_stream << " #define int64_t long\n";
decl_stream << " #define uint64_t ulong\n";
decl_stream << " #define int64_t long long\n";
decl_stream << " #define uint64_t unsigned long long\n";
decl_stream << "#endif\n";

return CodeGenC::Finish();
Expand Down Expand Up @@ -141,7 +141,21 @@ void CodeGenCUDA::PrintType(DataType t, std::ostream& os) { // NOLINT(*)
}
break;
case 32:
os << "float";
if (lanes <= 4) {
os << "float";
} else if (lanes <= 8) {
// Emit CUDA code to access fp32 vector elements for 4 < lanes <= 8.
//
// An 8-lane float vector (f8) is stored as ulonglong4 (ul4).
//
// f8[0] is emitted as ((float2*)(&(ul4.x)))->x
// f8[1] is emitted as ((float2*)(&(ul4.x)))->y
//
ICHECK_EQ(lanes % 2, 0) << "only support even lane for float type with lanes > 4";
os << "ulonglong" << lanes / 2;
} else {
fail = true;
}
break;
case 64:
os << "double";
Expand All @@ -151,6 +165,7 @@ void CodeGenCUDA::PrintType(DataType t, std::ostream& os) { // NOLINT(*)
break;
}
if (!fail && (t.is_scalar() || t.bits() == 16)) return;
if (!fail && (lanes > 4 && lanes <= 8 && t.bits() == 32)) return;
if (!fail && (lanes >= 2 && lanes <= 4)) {
os << lanes;
return;
Expand Down Expand Up @@ -238,12 +253,53 @@ void CodeGenCUDA::PrintType(DataType t, std::ostream& os) { // NOLINT(*)
break;
}
}
case 16:
os << "short";
break;
case 32:
os << "int";
case 16: {
if (t.is_scalar()) {
os << "short";
} else if (t.lanes() <= 4) {
os << "short" << lanes;
} else if (t.lanes() <= 8) {
// Emit CUDA code to access int16 vector elements for 4 < lanes <= 8.
//
// An 8-lane short vector (s8) is stored as int4 (i4).
//
// s8[0] is emitted as ((short2*)(&(i4.x)))->x
// s8[1] is emitted as ((short2*)(&(i4.x)))->y
// s8[2] is emitted as ((short2*)(&(i4.y)))->x
// s8[3] is emitted as ((short2*)(&(i4.y)))->y
//
ICHECK_EQ(t.lanes() % 2, 0) << "only support even lane for shorT type with lanes > 4";
os << "int" << t.lanes() / 2;
} else {
fail = true;
}
if (!fail) {
return;
}
}
case 32: {
if (t.is_scalar()) {
os << "int";
} else if (t.lanes() <= 4) {
os << "int" << t.lanes();
} else if (t.lanes() <= 8) {
// Emit CUDA code to access int32 vector elements for 4 < lanes <= 8.
//
// An 8-lane int32 vector (i8) is stored as longlong4 (l4).
//
// i8[0] is emitted as ((int2*)(&(l4.x)))->x
// i8[1] is emitted as ((int2*)(&(l4.x)))->y
//
ICHECK_EQ(lanes % 2, 0) << "only support even lane for int32 type with lanes > 4";
os << "longlong" << lanes / 2;
} else {
fail = true;
}
if (!fail) {
return;
}
break;
}
case 64: {
if (t.is_scalar()) {
os << "int64_t";
Expand Down Expand Up @@ -314,21 +370,36 @@ void CodeGenCUDA::PrintVecElemLoad(const std::string& vec, DataType t, int i,
}

static const char access[] = {'x', 'y', 'z', 'w'};
ICHECK(i >= 0 && i < (t.is_float16() ? 8 : 4));
if ((t.is_int()) && t.bits() == 8) {
if (t.lanes() == 2 || t.lanes() == 3) {
os << vec << "." << access[i % t.lanes()];
} else {
os << "((char)(" << vec << " >> " << i * 8 << "))";
}
} else if ((t.is_uint()) && t.bits() == 8) {
ICHECK(i >= 0 && i < (t.bits() == 8 ? 16 : (t.bits() == 16 || t.bits() == 32) ? 8 : 4));
if (t.bits() == 8 && (t.is_int() || t.is_uint())) {
std::string type_name = t.is_int() ? "char" : "unsigned char";
if (t.lanes() == 2 || t.lanes() == 3) {
os << vec << "." << access[i % t.lanes()];
} else {
os << "((unsigned char)(" << vec << " >> " << i * 8 << "))";
std::string ac = t.lanes() == 4 ? vec : (vec + "." + access[i / 4]);
os << "((" << type_name << ")(" << ac << " >> " << i % 4 * 8 << "))";
}
} else if (t.is_float16()) {
os << "((half2*)(&(" << vec << "." << access[i / 2] << ")))->" << access[i % 2];
} else if (t.lanes() > 4 && t.lanes() <= 8) {
std::string type_name;
if (t.bits() == 16) {
if (t.is_int()) {
type_name = "short";
} else if (t.is_uint()) {
type_name = "ushort";
}
} else if (t.bits() == 32) {
if (t.is_int()) {
type_name = "int";
} else if (t.is_uint()) {
type_name = "uint";
} else if (t.is_float()) {
type_name = "float";
}
}
ICHECK(!type_name.empty());
os << "((" << type_name << "2*)(&(" << vec << "." << access[i / 2] << ")))->" << access[i % 2];
} else {
os << vec << "." << access[i];
}
Expand All @@ -338,22 +409,43 @@ void CodeGenCUDA::PrintVecElemStore(const std::string& vec, DataType t, int i,
const std::string& value) {
this->PrintIndent();
static const char access[] = {'x', 'y', 'z', 'w'};
ICHECK(i >= 0 && i < (t.is_float16() ? 8 : 4));
ICHECK(i >= 0 && i < (t.bits() == 8 ? 16 : (t.bits() == 16 || t.bits() == 32) ? 8 : 4));
if (t.bits() == 8 && (t.is_int() || t.is_uint())) {
if (t.lanes() == 2 || t.lanes() == 3) {
stream << vec << '.' << access[i % t.lanes()] << "="
<< "(" << value << ");\n";
} else {
stream << vec << "=";
std::string ac = t.lanes() == 4 ? vec : (vec + "." + access[i / 4]);
stream << ac << "=";
// Do not read the first undef lane.
if (i != 0) {
stream << vec << " & ~(0x000000ff << " << i * 8 << ") |";
stream << ac << " & ~(0x000000ff << " << i % 4 * 8 << ") |";
}
stream << "(" << value << " << " << i * 8 << ");\n";
stream << "(" << value << " << " << i % 4 * 8 << ");\n";
}
} else if (t.is_float16()) {
stream << "((half2*)(&(" << vec << "." << access[i / 2] << ")))->" << access[i % 2] << " = "
<< value << ";\n";
} else if (t.lanes() > 4 && t.lanes() <= 8) {
std::string type_name;
if (t.bits() == 16) {
if (t.is_int()) {
type_name = "short";
} else if (t.is_uint()) {
type_name = "ushort";
}
} else if (t.bits() == 32) {
if (t.is_int()) {
type_name = "int";
} else if (t.is_uint()) {
type_name = "uint";
} else if (t.is_float()) {
type_name = "float";
}
}
ICHECK(!type_name.empty());
stream << "((" << type_name << "2*)(&(" << vec << "." << access[i / 2] << ")))->"
<< access[i % 2] << " = " << value << ";\n";
} else {
stream << vec << "." << access[i] << " = " << value << ";\n";
}
Expand Down
17 changes: 11 additions & 6 deletions tests/python/unittest/test_target_codegen_cuda.py
Original file line number Diff line number Diff line change
Expand Up @@ -498,7 +498,7 @@ def test_cuda_floormod_with_vectorization():
@tvm.testing.requires_gpu
@tvm.testing.requires_cuda
def test_vectorized_casts():
def check(t0, t1):
def check(t0, t1, factor):
if (t0 == "float16" or t1 == "float16") and not have_fp16(tvm.gpu(0).compute_version):
print("Skip because gpu does not have fp16 support")
return
Expand All @@ -511,8 +511,8 @@ def check(t0, t1):

# schedule
s = tvm.te.create_schedule(C.op)
ob, ib = s[C].split(s[C].op.axis[0], nparts=32)
_, iib = s[C].split(ib, factor=4)
ob, ib = s[C].split(s[C].op.axis[0], nparts=128//factor)
_, iib = s[C].split(ib, factor=factor)
s[C].vectorize(iib)
s[C].bind(ob, tx)
func = tvm.build(s, [A, B, C], "cuda")
Expand All @@ -538,9 +538,14 @@ def skip(t0, t1):
return True
return False

types = ["float16", "float32", "int8", "uint8", "int16", "uint16", "int32", "uint32"]
for t0, t1 in [(x, y) for x in types for y in types if not skip(x, y)]:
check(t0, t1)
types_4 = ["float16", "float32", "int8", "uint8", "int16", "uint16", "int32", "uint32", "float64", "int64", "uint64"]
types_8 = ["float16", "float32", "int8", "uint8", "int16", "uint16", "int32", "uint32"]
for t0, t1 in [(x, y) for x in types_4 for y in types_4 if not skip(x, y)]:
check(t0, t1, 4)
for t0, t1 in [(x, y) for x in types_8 for y in types_8 if not skip(x, y)]:
check(t0, t1, 8)
check('int8', 'uint8', 16)
check('uint8', 'int8', 16)


def sched(B):
Expand Down

0 comments on commit 4900849

Please sign in to comment.