Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[cutlass] Sparse conv3d backward fusion #52361

Merged
merged 32 commits into from
Apr 13, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
32 commits
Select commit Hold shift + click to select a range
d037457
commit for saving, not work now :(
umiswing Feb 21, 2023
63decbd
finally it pass compilation...
umiswing Feb 22, 2023
5a1f9a3
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
umiswing Feb 22, 2023
865f12a
change GetKey() to GenKey()
umiswing Feb 22, 2023
5084bc6
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
umiswing Mar 1, 2023
a21618f
works for fp16 and fp32 on sm 80.
umiswing Mar 2, 2023
630319d
clean the code.
umiswing Mar 3, 2023
31984a0
remove scripts for sm 70
umiswing Mar 3, 2023
cd414d9
remove some comment
umiswing Mar 3, 2023
4f24f11
remove some unused header.
umiswing Mar 6, 2023
62c8120
restructure code.
umiswing Mar 9, 2023
384c34e
restructure more codes.
umiswing Mar 10, 2023
1b8072b
remove some unused codes.
umiswing Mar 10, 2023
4c7c25f
commit for saving.
umiswing Mar 10, 2023
5eb554c
modify interface for backward.
umiswing Mar 13, 2023
447d4db
run successfully, result need to be checked.
umiswing Mar 14, 2023
7aebcdb
add split k, but still slow
umiswing Mar 14, 2023
afc967d
Fix a bug in conv_grad_kernel.cu.
umiswing Mar 15, 2023
6d2bc70
fix compile
umiswing Mar 15, 2023
a6d7c95
Merge branch 'fix_make' into spconv_back_fuse
umiswing Mar 16, 2023
d587a76
fix shape to key mapping error in conv_grad shape.
umiswing Mar 16, 2023
5ac88b7
try to add a reduce kernel, not work yet...
umiswing Mar 21, 2023
a96b299
Can pass compilation adding reduce, not work yet.
umiswing Mar 22, 2023
197c59a
using device::reduce, still not work yet. :(
umiswing Mar 22, 2023
bc20db8
Reduction run without illegal memory access, but still not correct.
umiswing Mar 22, 2023
2b861fc
Computes correctly, but super slow.
umiswing Mar 27, 2023
1431d1b
Works and fast now.
umiswing Mar 29, 2023
7f26ec7
Merge branch 'spconv_back_fuse' into back_fusion
umiswing Mar 30, 2023
f2c83fe
codestyle fix.
umiswing Mar 30, 2023
f8ed53b
revert some changes.
umiswing Mar 30, 2023
e602bd7
Add a status check
umiswing Mar 30, 2023
f215798
Remove backward fusion in fp16 since it's slow.
umiswing Apr 3, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 38 additions & 9 deletions paddle/phi/kernels/autotune/auto_tune_base.h
Original file line number Diff line number Diff line change
Expand Up @@ -177,18 +177,34 @@ class MatmulAutoTuner
}
};

template <typename T, typename ReturnType, typename... Args>
template <bool TransposeA,
bool TransposeB,
typename T,
typename ReturnType,
typename... Args>
class GatherGemmScatterAutoTuner
: public AutoTuneBase<T, KernelCallback<T, ReturnType, T, T, Args...>> {
public:
static GatherGemmScatterAutoTuner<T, ReturnType, Args...>* Instance(
ReturnType (*func)(T, T, Args...)) {
static GatherGemmScatterAutoTuner<TransposeA,
TransposeB,
T,
ReturnType,
Args...>*
Instance(ReturnType (*func)(T, T, Args...)) {
static std::once_flag gather_gemm_scatter_init_flag;
static std::unique_ptr<GatherGemmScatterAutoTuner<T, ReturnType, Args...>>
static std::unique_ptr<GatherGemmScatterAutoTuner<TransposeA,
TransposeB,
T,
ReturnType,
Args...>>
instance;
std::call_once(gather_gemm_scatter_init_flag, [&] {
auto obj = MakeCallback<T>(func);
instance.reset(new GatherGemmScatterAutoTuner<T, ReturnType, Args...>);
instance.reset(new GatherGemmScatterAutoTuner<TransposeA,
TransposeB,
T,
ReturnType,
Args...>);
instance->AddCallBack(func);
});
return instance.get();
Expand All @@ -201,7 +217,8 @@ class GatherGemmScatterAutoTuner
Args... args) {
this->is_init_ = true;
this->CheckKernelSize();
auto& cache = AutoTuneCache::Instance().GetGatherGemmScatter<T>();
auto& cache = AutoTuneCache::Instance()
.GetGatherGemmScatter<T, TransposeA, TransposeB>();

if (cache.Find(key)) {
auto best_idx = cache.Get(key);
Expand Down Expand Up @@ -250,10 +267,22 @@ class GatherGemmScatterAutoTuner
return best_idx;
}
};
template <typename T, typename ReturnType, typename... Args>
static GatherGemmScatterAutoTuner<T, ReturnType, Args...>*
template <bool TransposeA,
bool TransposeB,
typename T,
typename ReturnType,
typename... Args>
static GatherGemmScatterAutoTuner<TransposeA,
TransposeB,
T,
ReturnType,
Args...>*
MakeGatherGemmScatterTuner(ReturnType (*func)(T, T, Args...)) {
return GatherGemmScatterAutoTuner<T, ReturnType, Args...>::Instance(func);
return GatherGemmScatterAutoTuner<TransposeA,
TransposeB,
T,
ReturnType,
Args...>::Instance(func);
}

// Define the auto_tuner inital object.
Expand Down
53 changes: 34 additions & 19 deletions paddle/phi/kernels/autotune/cache.h
Original file line number Diff line number Diff line change
Expand Up @@ -47,13 +47,15 @@ enum class AlgorithmType {
kMatmul = 5,
kGatherGemmScatterFP16NN = 6,
kGatherGemmScatterFP32NN = 7,
kGatherGemmScatterFP32TN = 8,
kGatherGemmScatterFP32NT = 9,
#if !defined(PADDLE_WITH_CUDNN_FRONTEND)
kAlgorithmCount = 8
kAlgorithmCount = 10
#else
kConvForwardV8 = 8,
kConvBackwardDataV8 = 9,
kConvBackwardFilterV8 = 10,
kAlgorithmCount = 11
kConvForwardV8 = 10,
kConvBackwardDataV8 = 11,
kConvBackwardFilterV8 = 12,
kAlgorithmCount = 13
#endif
};

Expand All @@ -73,6 +75,17 @@ using CudnnV8AlgorithmsTypeMap =
std::unordered_map<int64_t, CudnnFrontendPlanCache>;
#endif

// Generates one SFINAE-constrained GetGatherGemmScatter() member overload.
// The overload is enabled only when the element type T matches `dtype` AND the
// compile-time transpose flags match (`transpose_a`, `transpose_b`); it then
// returns the AlgorithmsCacheMap slot registered under `algo_type`, so each
// (dtype, layout) combination gets its own autotune cache.
// NOTE: keep this as a macro invocation list in sync with the
// kGatherGemmScatter* entries of AlgorithmType.
#define DEFINE_GET_GATHER_GEMM_SCATTER( \
dtype, transpose_a, transpose_b, algo_type) \
template <typename T, bool TransposeA, bool TransposeB> \
typename std::enable_if<std::is_same<T, dtype>::value && \
TransposeA == transpose_a && \
TransposeB == transpose_b, \
AlgorithmsCacheMap&>::type \
GetGatherGemmScatter() { \
return Get(algo_type); \
}

class AutoTuneCache {
public:
static AutoTuneCache& Instance() {
Expand All @@ -89,20 +102,22 @@ class AutoTuneCache {
ConvAlgorithmsCacheMap& GetConv(const AlgorithmType& algo_type) {
return conv_auto_tune_map_[static_cast<int64_t>(algo_type)];
}

template <typename T>
typename std::enable_if<std::is_same<T, float>::value,
AlgorithmsCacheMap&>::type
GetGatherGemmScatter() {
return Get(AlgorithmType::kGatherGemmScatterFP32NN);
}

template <typename T>
typename std::enable_if<std::is_same<T, phi::dtype::float16>::value,
AlgorithmsCacheMap&>::type
GetGatherGemmScatter() {
return Get(AlgorithmType::kGatherGemmScatterFP16NN);
}
DEFINE_GET_GATHER_GEMM_SCATTER(phi::dtype::float16,
false,
false,
AlgorithmType::kGatherGemmScatterFP16NN);
DEFINE_GET_GATHER_GEMM_SCATTER(float,
false,
false,
AlgorithmType::kGatherGemmScatterFP32NN);
DEFINE_GET_GATHER_GEMM_SCATTER(float,
true,
false,
AlgorithmType::kGatherGemmScatterFP32TN);
DEFINE_GET_GATHER_GEMM_SCATTER(float,
false,
true,
AlgorithmType::kGatherGemmScatterFP32NT);

#ifdef PADDLE_WITH_CUDNN_FRONTEND
CudnnFrontendPlanCache& GetConvV8(const AlgorithmType& algo_type) {
Expand Down
193 changes: 135 additions & 58 deletions paddle/phi/kernels/sparse/gpu/conv_grad_kernel.cu
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,13 @@ limitations under the License. */
#include "paddle/phi/kernels/funcs/blas/blas.h"
#include "paddle/phi/kernels/funcs/math_function.h"
#include "paddle/phi/kernels/sparse/gpu/conv.cu.h"
#ifdef PADDLE_WITH_CUTLASS
#include "paddle/phi/kernels/sparse/gpu/gather_gemm_scatter.h"
#endif

namespace phi {
namespace sparse {
extern size_t workspace_size;

// rulebook[3, rulebook_len]:
//[
Expand Down Expand Up @@ -130,34 +134,52 @@ void Conv3dCooGradGPUKernel(const GPUContext& dev_ctx,
phi::backends::gpu::GpuMemsetAsync(
out_index_ptr, 0, sizeof(int) * x.nnz() * 2, dev_ctx.stream());

GroupIndexsV2<<<config.block_per_grid,
config.thread_per_block,
0,
dev_ctx.stream()>>>(rulebook_len,
x.nnz(),
kernel_size,
offsets[kernel_size / 2],
rulebook_ptr,
out_index_ptr,
unique_value_ptr);
#ifdef PADDLE_WITH_CUTLASS
bool cutlass = true;
if (dev_ctx.GetComputeCapability() < 80) cutlass = false;

GatherV2<T, IntT>(dev_ctx,
x.values().data<T>(),
out_index_ptr,
unique_value_ptr,
x.nnz(),
kernel_size,
in_channels,
2,
in_features_ptr);
if (in_channels % 4 != 0 || out_channels % 4 != 0) cutlass = false;

Gather<T, IntT>(dev_ctx,
out_grad.values().data<T>(),
rulebook_ptr + rulebook_len,
rulebook_len,
out_channels,
out_grad_features_ptr);
if (std::is_same<T, phi::dtype::float16>::value ||
std::is_same<T, double>::value)
cutlass = false;

if (!std::is_same<IntT, int32_t>::value) cutlass = false;

if (!cutlass) {
#endif

GroupIndexsV2<<<config.block_per_grid,
config.thread_per_block,
0,
dev_ctx.stream()>>>(rulebook_len,
x.nnz(),
kernel_size,
offsets[kernel_size / 2],
rulebook_ptr,
out_index_ptr,
unique_value_ptr);

GatherV2<T, IntT>(dev_ctx,
x.values().data<T>(),
out_index_ptr,
unique_value_ptr,
x.nnz(),
kernel_size,
in_channels,
2,
in_features_ptr);

Gather<T, IntT>(dev_ctx,
out_grad.values().data<T>(),
rulebook_ptr + rulebook_len,
rulebook_len,
out_channels,
out_grad_features_ptr);

#ifdef PADDLE_WITH_CUTLASS
}
#endif
const T* kernel_ptr = kernel.data<T>();
for (int i = 0; i < kernel_size; i++) {
if (counter_ptr[i] <= 0 || (subm && i == half_kernel_size)) {
Expand All @@ -173,43 +195,98 @@ void Conv3dCooGradGPUKernel(const GPUContext& dev_ctx,
T* tmp_d_x_ptr = d_x_features_ptr + offsets[i] * in_channels;
T* tmp_d_kernel_ptr = d_kernel_ptr + i * in_channels * out_channels;

// call gemm: d_kernel = transpose(x) * out_grad
// (in_channels, n) * (n, out_channels)
blas.GEMM(CblasTrans,
CblasNoTrans,
K,
N,
M,
static_cast<T>(1),
tmp_in_ptr,
tmp_out_grad_ptr,
static_cast<T>(0),
tmp_d_kernel_ptr);
#ifdef PADDLE_WITH_CUTLASS
if (cutlass) {
const IntT* gather_x_indices = rulebook_ptr + offsets[i];
const IntT* scatter_x_indices = rulebook_ptr + offsets[i];
const IntT* gather_out_indices = rulebook_ptr + rulebook_len + offsets[i];
const size_t key = autotune::GenKey(M / features_num_range, N, K);
// call gemm: d_kernel = transpose(x) * out_grad
// (in_channels, n) * (n, out_channels)
static cutlass::device_memory::allocation<uint8_t> workspace(
workspace_size);
GatherGemmScatterDriver<T, IntT, true, false>(
dev_ctx,
key,
x.values().data<T>(),
out_grad.values().data<T>(),
tmp_d_kernel_ptr,
tmp_d_kernel_ptr,
in_channels,
out_channels,
counter_ptr[i],
gather_x_indices,
gather_out_indices,
static_cast<const IntT*>(nullptr),
static_cast<const T>(1.0),
static_cast<const T>(0.0),
&workspace);
// call gemm: d_x = out_grad * transpose(kernel)
// (n, out_channels) * (out_channels, in_channels)
GatherGemmScatterDriver<T, IntT, false, true>(
dev_ctx,
key,
out_grad.values().data<T>(),
tmp_kernel_ptr,
x_grad_values_ptr,
x_grad_values_ptr,
counter_ptr[i],
in_channels,
out_channels,
gather_out_indices,
static_cast<const IntT*>(nullptr),
scatter_x_indices,
static_cast<const T>(1.0),
static_cast<const T>(1.0),
nullptr);
} else {
#endif
// call gemm: d_kernel = transpose(x) * out_grad
// (in_channels, n) * (n, out_channels)
blas.GEMM(CblasTrans,
CblasNoTrans,
K,
N,
M,
static_cast<T>(1),
tmp_in_ptr,
tmp_out_grad_ptr,
static_cast<T>(0),
tmp_d_kernel_ptr);

// call gemm: d_x = out_grad * transpose(kernel)
// (n, out_channels) * (out_channels, in_channels)
blas.GEMM(CblasNoTrans,
CblasTrans,
M,
K,
N,
static_cast<T>(1),
tmp_out_grad_ptr,
tmp_kernel_ptr,
static_cast<T>(0),
tmp_d_x_ptr);
// call gemm: d_x = out_grad * transpose(kernel)
// (n, out_channels) * (out_channels, in_channels)
blas.GEMM(CblasNoTrans,
CblasTrans,
M,
K,
N,
static_cast<T>(1),
tmp_out_grad_ptr,
tmp_kernel_ptr,
static_cast<T>(0),
tmp_d_x_ptr);
#ifdef PADDLE_WITH_CUTLASS
}
#endif
}

// 4. scatter
phi::funcs::sparse::ScatterV2<T>(dev_ctx,
d_x_features_ptr,
out_index.data<int>(),
unique_value.data<int>(),
x_grad->nnz(),
kernel_size,
in_channels,
2,
x_grad_values_ptr);
#ifdef PADDLE_WITH_CUTLASS
if (!cutlass) {
#endif
phi::funcs::sparse::ScatterV2<T>(dev_ctx,
d_x_features_ptr,
out_index.data<int>(),
unique_value.data<int>(),
x_grad->nnz(),
kernel_size,
in_channels,
2,
x_grad_values_ptr);
#ifdef PADDLE_WITH_CUTLASS
}
#endif
}

template <typename T, typename Context>
Expand Down
Loading