Skip to content

Commit

Permalink
review comments
Browse files Browse the repository at this point in the history
  • Loading branch information
Tristan Konolige committed Jul 29, 2022
1 parent f8df56b commit 0f64454
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 6 deletions.
16 changes: 10 additions & 6 deletions python/tvm/utils/roofline/cuda.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
from ...rpc.base import RPC_SESS_MASK
from ...rpc.client import RPCSession
from . import registry
from ...contrib import utils
from ...contrib import utils, nvcc


@registry.estimate_peak_flops.register("cuda")
Expand All @@ -50,7 +50,7 @@ def estimate_peak_flops_tensorcore(
----------
target : Target
Target to run on. This should be as specific to the actual hardware as
possible to make sure that LLVM generates the best vector code.
possible.
dev : Device
Device to run on.
remote : Optional[RPCSession]
Expand Down Expand Up @@ -166,7 +166,7 @@ def peak_bandwidth_tir(a: T.handle, b: T.handle, blocks: T.int32, warp_size: T.i
# pylint: disable=invalid-name, missing-function-docstring
N = T.var("int32")
A = T.match_buffer(a, [blocks, N, 4, warp_size], "float32")
B = T.match_buffer(b, [blocks, warp_size, 4], "float32")
B = T.match_buffer(b, [blocks, 4, warp_size], "float32")
for i in T.thread_binding(blocks, "blockIdx.x"):
for k in T.serial(N):
for l in T.unroll(4):
Expand All @@ -175,7 +175,7 @@ def peak_bandwidth_tir(a: T.handle, b: T.handle, blocks: T.int32, warp_size: T.i
# += is necessary to introduce a data dependency for all
# elements of A, preventing the backend from removing the
# `k` loop and setting `k` to the loop extent.
B[i, j, l] += A[i, k, l, j]
B[i, l, j] += A[i, k, l, j]


@registry.estimate_peak_bandwidth.register("cuda")
Expand Down Expand Up @@ -206,8 +206,12 @@ def estimate_peak_bandwidth(
float
Peak memory bandwidth in bytes/seconds.
"""
blocks = 1024
assert nvcc.have_tensorcore(
dev.compute_version
), "CUDA roofline only works with devices that have tensorcores"
warp_size = dev.warp_size
# These sizes seem large enough to give the card time to hit a fixpoint on memory bandwidth
blocks = 1024
size = 1024

specialized = peak_bandwidth_tir.specialize(
Expand All @@ -227,6 +231,6 @@ def estimate_peak_bandwidth(
f = remote.load_module("peak_bandwidth.tar")

a = nd.empty((blocks, size, 4, warp_size), dtype="float32", device=dev)
b = nd.empty((blocks, warp_size, 4), dtype="float32", device=dev)
b = nd.empty((blocks, 4, warp_size), dtype="float32", device=dev)
times = f.time_evaluator(f.entry_name, dev, repeat=10, number=1)(a, b)
return a.numpy().size * 4 / times.min # 4 bytes per float32
1 change: 1 addition & 0 deletions src/tir/ir/specialize.cc
Original file line number Diff line number Diff line change
Expand Up @@ -363,6 +363,7 @@ PrimFunc Specialize(PrimFunc func, const Map<Var, ObjectRef>& param_map) {
} else if (instance->IsInstance<PrimExprNode>()) {
UpdateSpecializeVarMap(func, param, Downcast<PrimExpr>(instance), &var_map);
} else {
CHECK(instance.defined()) << "Specialize instance is not defined for param " << param;
LOG(FATAL) << "TypeError: specialize expected instance to be Buffer or PrimExpr, but got "
<< instance->GetTypeKey();
}
Expand Down

0 comments on commit 0f64454

Please sign in to comment.