Skip to content

Commit

Permalink
Tkurth/sgbn fixes (NVIDIA#1685)
Browse files Browse the repository at this point in the history
* fixing order of class instantiation and device extraction in mixed precision lamb

* this commit fixes the SGBN graph capture problem by caching the cudnn plan and re-using it

* disentangling the mplamb MR and SGBN MR

* cleaner caching
  • Loading branch information
azrael417 committed Jul 2, 2023
1 parent 30a7ad3 commit 8ffc901
Show file tree
Hide file tree
Showing 3 changed files with 622 additions and 511 deletions.
108 changes: 70 additions & 38 deletions apex/contrib/csrc/cudnn_gbn/cudnn_gbn.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,12 @@

#include "norm_sample.h"

// define this enum:
enum bn_type { BN_FWD, BN_BWD };

// this is a global variable
static std::map<std::vector<int64_t>, cudnn_frontend::ExecutionPlan> gbn_plan_cache;

at::Tensor gbn_forward(const at::Tensor& x,
const at::Tensor& scale,
const at::Tensor& bias,
Expand Down Expand Up @@ -38,28 +44,43 @@ at::Tensor gbn_forward(const at::Tensor& x,
void_peer_buffers.push_back((void*)addr);
}

// we need the peer size for the buffer reset
size_t peer_size = 1;
for (size_t i = 0; i < 4; ++i){
peer_size *= peerDims[i];
}

// sanity check
assert(bn_group == void_peer_buffers.size());
run_batch_norm_forward(
perChannelDims,
epsilonDims,
tensorDims,
peerDims,
x.data_ptr(),
y.data_ptr(),
scale.data_ptr(),
bias.data_ptr(),
running_mean.data_ptr(),
running_var.data_ptr(),
running_mean.data_ptr(),
running_var.data_ptr(),
minibatch_mean.data_ptr(),
minibatch_inv_var.data_ptr(),
void_peer_buffers,
epsilon,
momentum,
rank_id
);

// check if plan already exists
std::vector<int64_t> fv = {(int64_t)BN_FWD, N, C, H, W, bn_group, (int64_t)CUDNN_DATA_HALF};
if ( gbn_plan_cache.find(fv) == gbn_plan_cache.end() ) {
auto plan = run_batch_norm_forward(tensorDims, perChannelDims, epsilonDims, peerDims, CUDNN_DATA_HALF);
gbn_plan_cache.emplace(fv, std::move(plan));
}

// get plan and handle
auto plan = gbn_plan_cache.find(fv)->second;

// execute
execute_batch_norm_forward(plan,
x.data_ptr(),
y.data_ptr(),
scale.data_ptr(),
bias.data_ptr(),
running_mean.data_ptr(),
running_var.data_ptr(),
running_mean.data_ptr(),
running_var.data_ptr(),
minibatch_mean.data_ptr(),
minibatch_inv_var.data_ptr(),
void_peer_buffers,
static_cast<double>(epsilon),
static_cast<double>(momentum),
peer_size,
rank_id);

return y;
}

Expand Down Expand Up @@ -98,26 +119,37 @@ std::vector<at::Tensor> gbn_backward(
void_peer_buffers.push_back((void*)addr);
}

// we need the peer size for the buffer reset
size_t peer_size = 1;
for (size_t i = 0; i < 4; ++i){
peer_size *= peerDims[i];
}

assert(bn_group == void_peer_buffers.size());

run_batch_norm_backward(
perChannelDims,
epsilonDims,
tensorDims,
peerDims,
x.data_ptr(),
dy.data_ptr(),
scale.data_ptr(),
minibatch_mean.data_ptr(),
minibatch_inv_var.data_ptr(),
x_grad.data_ptr(),
scale_grad.data_ptr(),
bias_grad.data_ptr(),
void_peer_buffers,
epsilon,
rank_id);


std::vector<int64_t> fv = {(int64_t)BN_BWD, N, C, H, W, bn_group, (int64_t)CUDNN_DATA_HALF};
if ( gbn_plan_cache.find(fv) == gbn_plan_cache.end() ) {
auto plan = run_batch_norm_backward(tensorDims, perChannelDims, epsilonDims, peerDims, CUDNN_DATA_HALF);
gbn_plan_cache.emplace(fv, std::move(plan));
}

// get plan and handle
auto plan = gbn_plan_cache.find(fv)->second;

// execute
execute_batch_norm_backward(plan,
x.data_ptr(),
dy.data_ptr(),
scale.data_ptr(),
minibatch_mean.data_ptr(),
minibatch_inv_var.data_ptr(),
void_peer_buffers,
x_grad.data_ptr(),
scale_grad.data_ptr(),
bias_grad.data_ptr(),
static_cast<double>(epsilon),
peer_size,
rank_id);

return std::vector<at::Tensor>{x_grad, scale_grad, bias_grad};
}
Expand Down
Loading

0 comments on commit 8ffc901

Please sign in to comment.