Skip to content

Commit

Permalink
[topi][CuDNN] Removed requirement for GPU from topi conv2d_cudnn.cuda…
Browse files Browse the repository at this point in the history
… and conv3d_cudnn.cuda (#8276)

Previously, `conv2d_cudnn.cuda` would use cudnn's benchmarking
function to select a forward convolution when `cfg.is_fallback`, and
`conv3d_cudnn.cuda` would use cudnn's benchmarking at all times.
After this commit, both expose the cudnn algorithm choice as an
option.  If `cfg.is_fallback`, the local device will be benchmarked if
present; otherwise, a default cudnn implementation will be selected.

In the future, to better support RPC use-cases, the fallback config
should be based on cudnn-specific parameters saved in the Target
object.

Co-authored-by: Eric Lunderberg <elunderberg@octoml.ai>
  • Loading branch information
Lunderberg and Lunderberg authored Jun 18, 2021
1 parent 5f94c1e commit bf3f000
Show file tree
Hide file tree
Showing 5 changed files with 71 additions and 8 deletions.
18 changes: 18 additions & 0 deletions python/tvm/contrib/cudnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,24 @@
_ALGO_TYPE = ["fwd", "bwd_filter", "bwd_data"]


def exists():
    """Check whether CuDNN is usable on the local machine.

    Returns
    -------
    exists: bool
        True if CuDNN support is enabled and a CuDNN-capable GPU
        exists. Otherwise, False.
    """
    # The C++ side only registers this global when TVM was built with
    # CuDNN support, so a missing function means "not available".
    checker = tvm.get_global_func("tvm.contrib.cudnn.exists", allow_missing=True)
    return bool(checker()) if checker is not None else False


def algo_to_index(algo_type, algo_name):
"""Return a index represents the algorithm, which can be used in
calling CuDNN function
Expand Down
12 changes: 9 additions & 3 deletions python/tvm/topi/cuda/conv2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,9 +117,15 @@ def conv2d_cudnn(
else:
dtype = data.dtype

cfg.define_knob("algo", range(8))
if cfg.is_fallback: # Let CUDNN choose the best algo
cfg["algo"] = OtherOptionEntity(-1)
cfg.define_knob("algo", range(cudnn.algo_to_index("fwd", "CUDNN_CONVOLUTION_FWD_ALGO_COUNT")))
if cfg.is_fallback:
if cudnn.exists():
# Let CUDNN choose the best algo, based on benchmarks run
# on the local machine. In the future, this should be
# based on parameters stored in the Target.
cfg["algo"] = OtherOptionEntity(-1)
else:
cfg["algo"] = OtherOptionEntity(0)

return cudnn.conv_forward(
data,
Expand Down
12 changes: 11 additions & 1 deletion python/tvm/topi/cuda/conv3d.py
Original file line number Diff line number Diff line change
Expand Up @@ -221,6 +221,16 @@ def conv3d_cudnn(
* ((KW - 1) * dilation_w + 1)
)

cfg.define_knob("algo", range(cudnn.algo_to_index("fwd", "CUDNN_CONVOLUTION_FWD_ALGO_COUNT")))
if cfg.is_fallback:
if cudnn.exists():
# Let CUDNN choose the best algo, based on benchmarks run
# on the local machine. In the future, this should be
# based on parameters stored in the Target.
cfg["algo"] = OtherOptionEntity(-1)
else:
cfg["algo"] = OtherOptionEntity(0)

return cudnn.conv_forward(
data,
kernel,
Expand All @@ -229,7 +239,7 @@ def conv3d_cudnn(
[dilation_d, dilation_h, dilation_w],
conv_mode=1,
tensor_format=tensor_format,
algo=-1, # let CUDNN choose the best algo
algo=cfg["algo"].val,
conv_dtype=dtype,
)

Expand Down
32 changes: 29 additions & 3 deletions src/runtime/contrib/cudnn/cudnn_utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -99,16 +99,38 @@ CuDNNThreadEntry::CuDNNThreadEntry() {
auto func = runtime::Registry::Get("device_api.cuda");
void* ret = (*func)();
cuda_api = static_cast<runtime::DeviceAPI*>(ret);
CUDNN_CALL(cudnnCreate(&handle));

// If no CuDNN-capable device is present, allow the CuDNNThreadEntry
// object to be created. This is needed for
// CuDNNThreadEntry::exists.
{
cudnnStatus_t create_res = cudnnCreate(&handle);
if (create_res == CUDNN_STATUS_NOT_INITIALIZED) {
return;
}
CUDNN_CALL(create_res);
}

CUDNN_CALL(cudnnSetStream(handle, stream));
conv_entry.cuda_api = cuda_api;
}

CuDNNThreadEntry::~CuDNNThreadEntry() { CUDNN_CALL(cudnnDestroy(handle)); }
// Destroy the cuDNN handle only if one was actually created. The
// constructor leaves `handle` null when cudnnCreate() reports
// CUDNN_STATUS_NOT_INITIALIZED (no CuDNN-capable device present), so
// calling cudnnDestroy() unconditionally would be invalid there.
CuDNNThreadEntry::~CuDNNThreadEntry() {
if (handle) {
CUDNN_CALL(cudnnDestroy(handle));
}
}

typedef dmlc::ThreadLocalStore<CuDNNThreadEntry> CuDNNThreadStore;

CuDNNThreadEntry* CuDNNThreadEntry::ThreadLocal() { return CuDNNThreadStore::Get(); }
/*!
 * \brief Fetch the per-thread CuDNN state object.
 * \param check_exists When true, ICHECK-fail if no cuDNN handle could be
 *        created on this thread (i.e. no CuDNN-capable device).
 * \return Pointer to the thread-local CuDNNThreadEntry.
 */
CuDNNThreadEntry* CuDNNThreadEntry::ThreadLocal(bool check_exists) {
  CuDNNThreadEntry* entry = CuDNNThreadStore::Get();
  if (check_exists) {
    ICHECK(entry->exists()) << "CUDNN_STATUS_NOT_INITIALIZED";
  }
  return entry;
}

// ConvEntry

Expand Down Expand Up @@ -148,5 +170,9 @@ SoftmaxEntry::SoftmaxEntry() { CUDNN_CALL(cudnnCreateTensorDescriptor(&shape_des

SoftmaxEntry::~SoftmaxEntry() { CUDNN_CALL(cudnnDestroyTensorDescriptor(shape_desc)); }

// Expose CuDNN availability to callers of the packed-function API
// (used by python/tvm/contrib/cudnn.py's exists()). Passes `false` to
// ThreadLocal so the query itself does not ICHECK-fail on machines
// without a CuDNN-capable device.
TVM_REGISTER_GLOBAL("tvm.contrib.cudnn.exists").set_body_typed([]() -> bool {
return CuDNNThreadEntry::ThreadLocal(false)->exists();
});

} // namespace contrib
} // namespace tvm
5 changes: 4 additions & 1 deletion src/runtime/contrib/cudnn/cudnn_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -93,11 +93,14 @@ struct SoftmaxEntry {
// Per-thread holder of CuDNN state: the cuDNN handle plus cached
// descriptor entries for convolution and softmax operations.
struct CuDNNThreadEntry {
CuDNNThreadEntry();
~CuDNNThreadEntry();

// True iff a cuDNN handle was successfully created, i.e. CuDNN is
// usable on this thread/machine.
bool exists() const { return handle; }

// Null when handle creation failed (no CuDNN-capable device).
cudnnHandle_t handle{nullptr};
ConvEntry conv_entry;
SoftmaxEntry softmax_entry;
runtime::DeviceAPI* cuda_api{nullptr};
// Returns the thread-local entry. With check_exists=true (the
// default), aborts via ICHECK when no cuDNN handle exists; pass
// false to merely query availability.
static CuDNNThreadEntry* ThreadLocal(bool check_exists = true);
};  // CuDNNThreadEntry

} // namespace contrib
Expand Down

0 comments on commit bf3f000

Please sign in to comment.