Skip to content

Commit

Permalink
more refactor
Browse files Browse the repository at this point in the history
  • Loading branch information
masahi committed Jan 17, 2022
1 parent 146464e commit dcbd9c9
Show file tree
Hide file tree
Showing 4 changed files with 57 additions and 99 deletions.
87 changes: 22 additions & 65 deletions src/runtime/contrib/cudnn/conv_forward.cc
Original file line number Diff line number Diff line change
Expand Up @@ -35,11 +35,17 @@ void ConvolutionForward(int mode, int format, int algo, int dims, int groups, co
const int stride[], const int dilation[], DLTensor* x, DLTensor* w,
DLTensor* y, const std::string& conv_dtype) {
CuDNNThreadEntry* entry_ptr = CuDNNThreadEntry::ThreadLocal();
SetConvDescriptors(entry_ptr, mode, format, algo, dims, groups, pad, stride, dilation, x, w, y,
conv_dtype);

// Set workspace
size_t workspace_size = 0;
// Set Mode
entry_ptr->conv_entry.mode = static_cast<cudnnConvolutionMode_t>(mode);
SetConvDescriptors(entry_ptr, format, dims, groups, pad, stride, dilation, x->shape, w->shape,
y->shape, x->dtype, conv_dtype);
// Set Device
entry_ptr->conv_entry.device = x->device;
// Set Algo
entry_ptr->conv_entry.fwd_algo = static_cast<cudnnConvolutionFwdAlgo_t>(algo);

// Set workspace
size_t workspace_size = 0;
CUDNN_CALL(cudnnGetConvolutionForwardWorkspaceSize(
entry_ptr->handle, entry_ptr->conv_entry.input_desc, entry_ptr->conv_entry.filter_desc,
entry_ptr->conv_entry.conv_desc, entry_ptr->conv_entry.output_desc,
Expand Down Expand Up @@ -112,67 +118,18 @@ void FindAlgo(int format, int dims, int groups, const int pad[], const int strid
const int dilation[], const int x_dim[], const int w_dim[], const int y_dim[],
const std::string& data_dtype, const std::string& conv_dtype, TVMRetValue* ret) {
CuDNNThreadEntry* entry_ptr = CuDNNThreadEntry::ThreadLocal();

// Set Data Type
entry_ptr->conv_entry.data_type = CuDNNDataType::DLTypeToCuDNNType(String2DLDataType(conv_dtype));
cudnnDataType_t data_type = CuDNNDataType::DLTypeToCuDNNType(String2DLDataType(data_dtype));
// Set Format
entry_ptr->conv_entry.tensor_format = static_cast<cudnnTensorFormat_t>(format);
// Dims includes N and C
int full_dims = dims + 2;

// conv desc
CUDNN_CALL(cudnnSetConvolutionGroupCount(entry_ptr->conv_entry.conv_desc, groups));

if (format == 1) {
ICHECK_EQ(full_dims, 4) << "Use of layout CUDNN_TENSOR_NHWC is only supported for 4d tensors";
int ni = 0;
int ci = 3;
int hi = 1;
int wi = 2;

// Set Input
CUDNN_CALL(cudnnSetTensor4dDescriptor(
entry_ptr->conv_entry.input_desc, entry_ptr->conv_entry.tensor_format, data_type,
static_cast<int>(x_dim[ni]), static_cast<int>(x_dim[ci]), static_cast<int>(x_dim[hi]),
static_cast<int>(x_dim[wi])));

CUDNN_CALL(cudnnSetFilter4dDescriptor(
entry_ptr->conv_entry.filter_desc, data_type, entry_ptr->conv_entry.tensor_format,
static_cast<int>(w_dim[ni]), static_cast<int>(w_dim[ci]), static_cast<int>(w_dim[hi]),
static_cast<int>(w_dim[wi])));
// Set Output
CUDNN_CALL(cudnnSetTensor4dDescriptor(
entry_ptr->conv_entry.output_desc, entry_ptr->conv_entry.tensor_format, data_type,
static_cast<int>(y_dim[ni]), static_cast<int>(y_dim[ci]), static_cast<int>(y_dim[hi]),
static_cast<int>(y_dim[wi])));

CUDNN_CALL(cudnnSetConvolution2dDescriptor(
entry_ptr->conv_entry.conv_desc, pad[0], pad[1], stride[0], stride[1], dilation[0],
dilation[1], entry_ptr->conv_entry.mode, entry_ptr->conv_entry.data_type));
} else {
CUDNN_CALL(cudnnSetConvolutionNdDescriptor(entry_ptr->conv_entry.conv_desc, dims, pad, stride,
dilation, CUDNN_CROSS_CORRELATION,
entry_ptr->conv_entry.data_type));

std::vector<int> tensor_stride(full_dims);
// input desc
GetCudnnStride(full_dims, x_dim, tensor_stride.data());
CUDNN_CALL(cudnnSetTensorNdDescriptor(entry_ptr->conv_entry.input_desc, data_type, full_dims,
x_dim, tensor_stride.data()));
// filter desc
CUDNN_CALL(cudnnSetFilterNdDescriptor(entry_ptr->conv_entry.filter_desc, data_type,
entry_ptr->conv_entry.tensor_format, full_dims, w_dim));

// output desc
GetCudnnStride(full_dims, y_dim, tensor_stride.data());
CUDNN_CALL(cudnnSetTensorNdDescriptor(entry_ptr->conv_entry.output_desc, data_type, full_dims,
y_dim, tensor_stride.data()));
}

if (cudnnGetVersion() > 7000) {
CUDNN_CALL(cudnnSetConvolutionMathType(entry_ptr->conv_entry.conv_desc, CUDNN_TENSOR_OP_MATH))
const int full_dims = dims + 2;
std::vector<int64_t> x_dim_int64(full_dims);
std::vector<int64_t> w_dim_int64(full_dims);
std::vector<int64_t> y_dim_int64(full_dims);
for (int i = 0; i < full_dims; ++i) {
x_dim_int64[i] = x_dim[i];
w_dim_int64[i] = w_dim[i];
y_dim_int64[i] = y_dim[i];
}
SetConvDescriptors(entry_ptr, format, dims, groups, pad, stride, dilation, x_dim_int64.data(),
w_dim_int64.data(), y_dim_int64.data(), String2DLDataType(data_dtype),
conv_dtype);

int returned_algo_count = 0;
cudnnConvolutionFwdAlgoPerf_t perf_results[CUDNN_CONVOLUTION_FWD_ALGO_COUNT];
Expand Down
56 changes: 27 additions & 29 deletions src/runtime/contrib/cudnn/cudnn_utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -161,22 +161,18 @@ void ConvEntry::CleanWorkspace() {
workspace_size = 0;
}

void SetConvDescriptors(CuDNNThreadEntry* entry_ptr, int mode, int format, int algo, int dims,
int groups, const int pad[], const int stride[], const int dilation[],
DLTensor* x, DLTensor* w, DLTensor* y, const std::string& conv_dtype) {
// Set Mode
entry_ptr->conv_entry.mode = static_cast<cudnnConvolutionMode_t>(mode);
void SetConvDescriptors(CuDNNThreadEntry* entry_ptr, int format, int dims, int groups,
const int pad[], const int stride[], const int dilation[], int64_t x_dim[],
int64_t w_dim[], int64_t y_dim[], DLDataType data_dtype,
const std::string& conv_dtype) {
// Set Format
entry_ptr->conv_entry.tensor_format = static_cast<cudnnTensorFormat_t>(format);
// Set Algo
entry_ptr->conv_entry.fwd_algo = static_cast<cudnnConvolutionFwdAlgo_t>(algo);
// Set Device
entry_ptr->conv_entry.device = x->device;
// Set Data Type
entry_ptr->conv_entry.data_type =
CuDNNDataType::DLTypeToCuDNNType(runtime::String2DLDataType(conv_dtype));

cudnnDataType_t data_type = CuDNNDataType::DLTypeToCuDNNType(x->dtype);
cudnnDataType_t cudnn_data_type = CuDNNDataType::DLTypeToCuDNNType(data_dtype);

// Dims includes N and C
int full_dims = dims + 2;

Expand Down Expand Up @@ -205,47 +201,49 @@ void SetConvDescriptors(CuDNNThreadEntry* entry_ptr, int mode, int format, int a
wi = 3;
}

// Set Filter
CUDNN_CALL(cudnnSetFilter4dDescriptor(
entry_ptr->conv_entry.filter_desc, data_type, entry_ptr->conv_entry.tensor_format,
static_cast<int>(w->shape[ni]), static_cast<int>(w->shape[ci]),
static_cast<int>(w->shape[hi]), static_cast<int>(w->shape[wi])));
// Set Input
CUDNN_CALL(cudnnSetTensor4dDescriptor(
entry_ptr->conv_entry.input_desc, entry_ptr->conv_entry.tensor_format, data_type,
static_cast<int>(x->shape[ni]), static_cast<int>(x->shape[ci]),
static_cast<int>(x->shape[hi]), static_cast<int>(x->shape[wi])));
entry_ptr->conv_entry.input_desc, entry_ptr->conv_entry.tensor_format, cudnn_data_type,
static_cast<int>(x_dim[ni]), static_cast<int>(x_dim[ci]), static_cast<int>(x_dim[hi]),
static_cast<int>(x_dim[wi])));
// Set Filter
CUDNN_CALL(cudnnSetFilter4dDescriptor(
entry_ptr->conv_entry.filter_desc, cudnn_data_type, entry_ptr->conv_entry.tensor_format,
static_cast<int>(w_dim[ni]), static_cast<int>(w_dim[ci]), static_cast<int>(w_dim[hi]),
static_cast<int>(w_dim[wi])));
// Set Output
CUDNN_CALL(cudnnSetTensor4dDescriptor(
entry_ptr->conv_entry.output_desc, entry_ptr->conv_entry.tensor_format, data_type,
static_cast<int>(y->shape[ni]), static_cast<int>(y->shape[ci]),
static_cast<int>(y->shape[hi]), static_cast<int>(y->shape[wi])));
entry_ptr->conv_entry.output_desc, entry_ptr->conv_entry.tensor_format, cudnn_data_type,
static_cast<int>(y_dim[ni]), static_cast<int>(y_dim[ci]), static_cast<int>(y_dim[hi]),
static_cast<int>(y_dim[wi])));
} else {
ICHECK_EQ(format, 0) << "Use of layout CUDNN_TENSOR_NHWC is supported only for 4-D tensors.";

CUDNN_CALL(cudnnSetConvolutionNdDescriptor(entry_ptr->conv_entry.conv_desc, dims, pad, stride,
dilation, entry_ptr->conv_entry.mode,
entry_ptr->conv_entry.data_type));

// Set Filter
for (int i = 0; i < full_dims; i++) {
dim[i] = static_cast<int>(w->shape[i]);
dim[i] = static_cast<int>(w_dim[i]);
}
CUDNN_CALL(cudnnSetFilterNdDescriptor(entry_ptr->conv_entry.filter_desc, data_type,
CUDNN_CALL(cudnnSetFilterNdDescriptor(entry_ptr->conv_entry.filter_desc, cudnn_data_type,
entry_ptr->conv_entry.tensor_format, full_dims,
dim.data()));
// Set Input
for (int i = 0; i < full_dims; i++) {
dim[i] = static_cast<int>(x->shape[i]);
dim[i] = static_cast<int>(x_dim[i]);
}
GetCudnnStride(full_dims, dim.data(), tensor_stride.data());
CUDNN_CALL(cudnnSetTensorNdDescriptor(entry_ptr->conv_entry.input_desc, data_type, full_dims,
dim.data(), tensor_stride.data()));
CUDNN_CALL(cudnnSetTensorNdDescriptor(entry_ptr->conv_entry.input_desc, cudnn_data_type,
full_dims, dim.data(), tensor_stride.data()));
// Set Output
for (int i = 0; i < full_dims; i++) {
dim[i] = static_cast<int>(y->shape[i]);
dim[i] = static_cast<int>(y_dim[i]);
}
GetCudnnStride(full_dims, dim.data(), tensor_stride.data());
CUDNN_CALL(cudnnSetTensorNdDescriptor(entry_ptr->conv_entry.output_desc, data_type, full_dims,
dim.data(), tensor_stride.data()));
CUDNN_CALL(cudnnSetTensorNdDescriptor(entry_ptr->conv_entry.output_desc, cudnn_data_type,
full_dims, dim.data(), tensor_stride.data()));
}

if (cudnnGetVersion() > 7000) {
Expand Down
9 changes: 5 additions & 4 deletions src/runtime/contrib/cudnn/cudnn_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ inline void GetCudnnStride(int nbdim, const int* dims, int* strides) {

struct ConvEntry {
cudnnConvolutionDescriptor_t conv_desc;
cudnnConvolutionMode_t mode;
cudnnConvolutionMode_t mode{CUDNN_CROSS_CORRELATION};
cudnnFilterDescriptor_t filter_desc;
cudnnDataType_t data_type;
cudnnTensorFormat_t tensor_format;
Expand Down Expand Up @@ -103,9 +103,10 @@ struct CuDNNThreadEntry {
static CuDNNThreadEntry* ThreadLocal(bool check_exists = true);
}; // CuDNNThreadEntry

void SetConvDescriptors(CuDNNThreadEntry* entry_ptr, int mode, int format, int algo, int dims,
int groups, const int pad[], const int stride[], const int dilation[],
DLTensor* x, DLTensor* w, DLTensor* y, const std::string& conv_dtype);
// Configure the cuDNN descriptors held in entry_ptr->conv_entry for a convolution.
//
// Writes conv_entry.tensor_format (from `format`) and conv_entry.data_type (parsed from
// `conv_dtype`), then fills the convolution, input, filter and output descriptors.
// The convolution mode is NOT a parameter: the descriptor setup reads
// entry_ptr->conv_entry.mode, so callers that need a non-default mode must assign it
// before calling. NOTE(review): the forward algorithm and device also appear to be the
// caller's responsibility (ConvolutionForward assigns them itself) -- confirm.
//
// entry_ptr   thread-local cuDNN state whose conv_entry is updated in place.
// format      cast to cudnnTensorFormat_t; the non-4-D path checks that it is 0.
// dims        number of spatial dimensions; shape arrays are read up to dims + 2 entries
//             (N and C included).
// groups      convolution group count -- presumably forwarded to the conv descriptor;
//             the use is outside the visible part of the definition, verify there.
// pad, stride, dilation
//             per-spatial-dimension settings passed to the convolution descriptor.
// x_dim, w_dim, y_dim
//             input, filter and output shapes; each element is narrowed to int.
// data_dtype  element type used for the input, filter and output descriptors.
// conv_dtype  type string used for conv_entry.data_type (the convolution descriptor type).
void SetConvDescriptors(CuDNNThreadEntry* entry_ptr, int format, int dims, int groups,
const int pad[], const int stride[], const int dilation[], int64_t x_dim[],
int64_t w_dim[], int64_t y_dim[], DLDataType data_dtype,
const std::string& conv_dtype);

} // namespace contrib
} // namespace tvm
Expand Down
4 changes: 3 additions & 1 deletion tests/python/contrib/test_cudnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -316,4 +316,6 @@ def test_conv_output_shape(conv_output_shape_kwargs):


if __name__ == "__main__":
sys.exit(pytest.main(sys.argv))
# sys.exit(pytest.main(sys.argv))
test_conv2d()
test_conv3d()

0 comments on commit dcbd9c9

Please sign in to comment.