[PTX] Intrinsics for async copy from global to shared (SM80) (apache#11368)

* register ptx builtin for async copy

* add basic codegen

* add test

* update codegen

* wip

* codegen bug fixed, test working

* add commit group

* add doc
masahi committed May 25, 2022
1 parent fdcc22b commit 8208e73
Showing 4 changed files with 121 additions and 0 deletions.
12 changes: 12 additions & 0 deletions src/target/source/codegen_cuda.cc
@@ -891,6 +891,18 @@ void CodeGenCUDA::VisitExpr_(const CallNode* op, std::ostream& os) {
os << "for (int i = 0; i < " << num_elem << "; ++i) {\n";
os << dst << "[" << dst_offset << " + i] = 0.0;";
os << "}\n";
} else if (op->op.same_as(builtin::ptx_cp_async())) {
std::string dst = this->PrintExpr(op->args[0]);
std::string dst_offset = this->PrintExpr(op->args[1]);
std::string src = this->PrintExpr(op->args[2]);
std::string src_offset = this->PrintExpr(op->args[3]);
std::string size = this->PrintExpr(op->args[4]);
this->stream << PrintCpAsyncAssembly(dst, dst_offset, src, src_offset, size);
} else if (op->op.same_as(builtin::ptx_commit_group())) {
this->stream << "__asm__ __volatile__(\"cp.async.commit_group;\");\n\n";
} else if (op->op.same_as(builtin::ptx_wait_group())) {
std::string N = this->PrintExpr(op->args[0]);
this->stream << "__asm__ __volatile__(\"cp.async.wait_group " + N + ";\");\n\n";
} else {
CodeGenC::VisitExpr_(op, os);
}
26 changes: 26 additions & 0 deletions src/target/source/ptx.cc
@@ -638,5 +638,31 @@ std::string PrintLoadMatrixAssembly(bool trans, int num, const std::string& type
return asm_code;
}

std::string PrintCpAsyncAssembly(const std::string& shared_ptr,
const std::string& shared_elem_offset,
const std::string& global_ptr,
const std::string& global_elem_offset, const std::string& bytes) {
std::string asm_code = R"(
{
unsigned int addr;
__asm__ __volatile__(
"{ .reg .u64 addr; cvta.to.shared.u64 addr, %1; cvt.u32.u64 %0, addr; }\n"
: "=r"(addr)
: "l"((void *)({smem_addr}))
);
__asm__ __volatile__(
"cp.async.cg.shared.global [%0], [%1], %2;"
:: "r"(addr), "l"((void*)({global_ptr})), "n"({bytes})
);
}
)";
Replacer replacer;
replacer.register_rule("{smem_addr}", shared_ptr + " + " + shared_elem_offset);
replacer.register_rule("{global_ptr}", global_ptr + " + " + global_elem_offset);
replacer.register_rule("{bytes}", bytes);
asm_code = replacer.rewrite(asm_code);
return asm_code;
}

} // namespace codegen
} // namespace tvm
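
For illustration, the following is a minimal hand-written sketch of the CUDA code this template produces for a single call, together with the commit/wait statements emitted by codegen_cuda.cc above. The kernel name copy_one_chunk and the buffer names src/dst_shared are illustrative only, not actual codegen output; the pattern assumes an SM80 (or newer) target.

// Minimal sketch of the emitted cp.async pattern (sm_80+); kernel and buffer
// names are illustrative, not actual TVM codegen output.
#include <cuda_fp16.h>

__global__ void copy_one_chunk(half* __restrict__ src, half* __restrict__ out) {
  __shared__ half dst_shared[8];
  {
    unsigned int addr;
    // Convert the generic shared-memory pointer to a 32-bit shared-space address.
    __asm__ __volatile__(
        "{ .reg .u64 addr; cvta.to.shared.u64 addr, %1; cvt.u32.u64 %0, addr; }\n"
        : "=r"(addr)
        : "l"((void *)(dst_shared + 0))
    );
    // Asynchronously copy 16 bytes (8 float16 values) from global to shared memory.
    __asm__ __volatile__(
        "cp.async.cg.shared.global [%0], [%1], %2;"
        :: "r"(addr), "l"((void*)(src + 0)), "n"(16)
    );
  }
  // Emitted by builtin::ptx_commit_group() and builtin::ptx_wait_group(0).
  __asm__ __volatile__("cp.async.commit_group;");
  __asm__ __volatile__("cp.async.wait_group 0;");
  for (int i = 0; i < 8; ++i) {
    out[i] = dst_shared[i];
  }
}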
13 changes: 13 additions & 0 deletions src/target/source/ptx.h
@@ -79,6 +79,19 @@ std::string PrintLoadMatrixAssembly(bool trans, int num, const std::string& type
const std::string& smem_ptr,
const std::string& smem_elem_offset);

/*!
* \brief Print ptx cp.async assembly string given parameters.
* \param shared_ptr: The pointer to the destination shared memory.
* \param shared_elem_offset: The offset into the shared memory.
* \param global_ptr: The pointer to the global memory.
* \param global_elem_offset: The offset into the global memory.
* \param bytes: The number of bytes to copy, valid values are 4, 8, and 16.
*/
std::string PrintCpAsyncAssembly(const std::string& shared_ptr,
const std::string& shared_elem_offset,
const std::string& global_ptr,
const std::string& global_elem_offset, const std::string& bytes);

} // namespace codegen
} // namespace tvm

70 changes: 70 additions & 0 deletions tests/python/unittest/test_tir_ptx_cp_async.py
@@ -0,0 +1,70 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import tvm
from tvm.script import tir as T
import numpy as np
import tvm.testing


@T.prim_func
def ptx_cp_async(A: T.Buffer[(32, 128), "float16"], B: T.Buffer[(32, 128), "float16"]) -> None:
T.func_attr({"global_symbol": "default_function", "tir.noalias": True})
bx = T.env_thread("blockIdx.x")
tx = T.env_thread("threadIdx.x")
T.launch_thread(bx, 1)
T.launch_thread(tx, 32)
with T.block():
A_shared = T.alloc_buffer([32, 128], "float16", scope="shared")
T.reads(A[0:32, 0:128])
T.writes(B[0:32, 0:128])

for i in range(16):
T.evaluate(
T.ptx_cp_async(
A_shared.data, tx * 128 + 8 * i, A.data, tx * 128 + 8 * i, 16, dtype="float16"
)
)

# TODO(masahi): Remove dtype requirement from TVMScript parser
T.evaluate(T.ptx_commit_group(dtype="float16"))
T.evaluate(T.ptx_wait_group(0, dtype="float16"))

for i in range(128):
B[tx, i] = A_shared[tx, i]


@tvm.testing.requires_cuda
def test_ptx_cp_async():
f = ptx_cp_async
arch = tvm.contrib.nvcc.get_target_compute_version()
major, _ = tvm.contrib.nvcc.parse_compute_version(arch)
if major < 8:
# Require at least SM80
return

mod = tvm.build(f, target="cuda")
A_np = np.random.rand(32, 128).astype("float16")
B_np = np.zeros((32, 128)).astype("float16")
dev = tvm.cuda(0)
A_nd = tvm.nd.array(A_np, device=dev)
B_nd = tvm.nd.array(B_np, device=dev)
mod(A_nd, B_nd)
tvm.testing.assert_allclose(B_nd.numpy(), A_np)


if __name__ == "__main__":
test_ptx_cp_async()
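
For reference, here is a hedged, hand-written approximation of the kernel this prim_func should compile to: each of the 32 threads issues sixteen 16-byte cp.async copies (8 float16 elements each) for its 128-element row, commits them as a single group, waits on the group, then writes the row from shared memory to B. The kernel name default_function_kernel0 and the variable names are guesses; the exact generated source may differ.

// Approximate shape of the generated kernel (sm_80+); names and layout are
// illustrative guesses, not the exact TVM output.
#include <cuda_fp16.h>

extern "C" __global__ void default_function_kernel0(half* __restrict__ A,
                                                    half* __restrict__ B) {
  __shared__ half A_shared[32 * 128];
  int tx = threadIdx.x;  // 32 threads, one 128-element row per thread
  for (int i = 0; i < 16; ++i) {
    unsigned int addr;
    __asm__ __volatile__(
        "{ .reg .u64 addr; cvta.to.shared.u64 addr, %1; cvt.u32.u64 %0, addr; }\n"
        : "=r"(addr)
        : "l"((void *)(A_shared + (tx * 128 + 8 * i)))
    );
    // Each copy moves 16 bytes = 8 float16 elements; 16 copies cover the row.
    __asm__ __volatile__(
        "cp.async.cg.shared.global [%0], [%1], %2;"
        :: "r"(addr), "l"((void*)(A + (tx * 128 + 8 * i))), "n"(16)
    );
  }
  __asm__ __volatile__("cp.async.commit_group;");
  __asm__ __volatile__("cp.async.wait_group 0;");
  for (int i = 0; i < 128; ++i) {
    B[tx * 128 + i] = A_shared[tx * 128 + i];
  }
}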
