Tkurth/sgbn fixes #1685

Merged · 4 commits · Jul 2, 2023

Changes from 3 commits
apex/contrib/csrc/cudnn_gbn/cudnn_gbn.cpp (108 changes: 70 additions & 38 deletions)
@@ -7,6 +7,12 @@
 
 #include "norm_sample.h"
 
+// define this enum:
+enum bn_type { BN_FWD, BN_BWD };
+
+// this is a global variable
+static std::map<std::vector<int64_t>, cudnn_frontend::ExecutionPlan> gbn_plan_cache;
+
 at::Tensor gbn_forward(const at::Tensor& x,
                        const at::Tensor& scale,
                        const at::Tensor& bias,
@@ -38,28 +44,43 @@ at::Tensor gbn_forward(const at::Tensor& x,
     void_peer_buffers.push_back((void*)addr);
   }
 
+  // we need the peer size for the buffer reset
+  size_t peer_size = 1;
+  for (size_t i = 0; i < 4; ++i){
+    peer_size *= peerDims[i];
+  }
+
   // sanity check
   assert(bn_group == void_peer_buffers.size());
-  run_batch_norm_forward(
-      perChannelDims,
-      epsilonDims,
-      tensorDims,
-      peerDims,
-      x.data_ptr(),
-      y.data_ptr(),
-      scale.data_ptr(),
-      bias.data_ptr(),
-      running_mean.data_ptr(),
-      running_var.data_ptr(),
-      running_mean.data_ptr(),
-      running_var.data_ptr(),
-      minibatch_mean.data_ptr(),
-      minibatch_inv_var.data_ptr(),
-      void_peer_buffers,
-      epsilon,
-      momentum,
-      rank_id
-  );
+
+  // check if plan already exists
+  std::vector<int64_t> fv = {(int64_t)BN_FWD, N, C, H, W, bn_group, (int64_t)CUDNN_DATA_HALF};
+  if ( gbn_plan_cache.find(fv) == gbn_plan_cache.end() ) {
+    auto plan = run_batch_norm_forward(tensorDims, perChannelDims, epsilonDims, peerDims, CUDNN_DATA_HALF);
+    gbn_plan_cache.insert(std::make_pair(fv, plan));
+  }
+
Review comment from a Contributor:
It looks like some of the code makes assumptions about the input tensor(s)' memory layout. If so, there should be checks like is_contiguous(at::MemoryFormat::ChannelsLast).

Reply from @azrael417 (Contributor, Author), Jun 29, 2023:
This is done on the python frontend. That check is here.
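A minimal sketch of the guard the reviewer is suggesting, written in C++ for illustration even though, per the author, the actual check lives in the Python frontend. The placement inside gbn_forward and the error message are assumptions, not part of this PR:

// hypothetical layout guard (not in this PR): fail fast if the input
// does not have the channels-last layout the kernels assume
TORCH_CHECK(x.is_contiguous(at::MemoryFormat::ChannelsLast),
            "gbn_forward: input tensor must be channels-last (NHWC)");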

+  // get plan and handle
+  auto plan = gbn_plan_cache.find(fv)->second;
+
+  // execute
+  execute_batch_norm_forward(plan,
+                             x.data_ptr(),
+                             y.data_ptr(),
+                             scale.data_ptr(),
+                             bias.data_ptr(),
+                             running_mean.data_ptr(),
+                             running_var.data_ptr(),
+                             running_mean.data_ptr(),
+                             running_var.data_ptr(),
+                             minibatch_mean.data_ptr(),
+                             minibatch_inv_var.data_ptr(),
+                             void_peer_buffers,
+                             static_cast<double>(epsilon),
+                             static_cast<double>(momentum),
+                             peer_size,
+                             rank_id);
 
   return y;
 }

@@ -98,26 +119,37 @@ std::vector<at::Tensor> gbn_backward(
     void_peer_buffers.push_back((void*)addr);
   }
 
+  // we need the peer size for the buffer reset
+  size_t peer_size = 1;
+  for (size_t i = 0; i < 4; ++i){
+    peer_size *= peerDims[i];
+  }
+
   assert(bn_group == void_peer_buffers.size());
 
-  run_batch_norm_backward(
-      perChannelDims,
-      epsilonDims,
-      tensorDims,
-      peerDims,
-      x.data_ptr(),
-      dy.data_ptr(),
-      scale.data_ptr(),
-      minibatch_mean.data_ptr(),
-      minibatch_inv_var.data_ptr(),
-      x_grad.data_ptr(),
-      scale_grad.data_ptr(),
-      bias_grad.data_ptr(),
-      void_peer_buffers,
-      epsilon,
-      rank_id);
-
+  std::vector<int64_t> fv = {(int64_t)BN_BWD, N, C, H, W, bn_group, (int64_t)CUDNN_DATA_HALF};
+  if ( gbn_plan_cache.find(fv) == gbn_plan_cache.end() ) {
+    auto plan = run_batch_norm_backward(tensorDims, perChannelDims, epsilonDims, peerDims, CUDNN_DATA_HALF);
+    gbn_plan_cache.insert(std::make_pair(fv, plan));
+  }
+
+  // get plan and handle
+  auto plan = gbn_plan_cache.find(fv)->second;
+
+  // execute
+  execute_batch_norm_backward(plan,
+                              x.data_ptr(),
+                              dy.data_ptr(),
+                              scale.data_ptr(),
+                              minibatch_mean.data_ptr(),
+                              minibatch_inv_var.data_ptr(),
+                              void_peer_buffers,
+                              x_grad.data_ptr(),
+                              scale_grad.data_ptr(),
+                              bias_grad.data_ptr(),
+                              static_cast<double>(epsilon),
+                              peer_size,
+                              rank_id);
 
   return std::vector<at::Tensor>{x_grad, scale_grad, bias_grad};
 }
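Aside on the pattern introduced above: both entry points now build a cuDNN execution plan at most once per problem configuration and reuse it from the global gbn_plan_cache map. Below is a self-contained sketch of that caching pattern; the Plan struct and build_plan builder are stand-ins for cudnn_frontend::ExecutionPlan and run_batch_norm_forward/backward, which require the cuDNN frontend headers, so all names here are illustrative:

#include <cstdint>
#include <map>
#include <vector>

// stand-in for cudnn_frontend::ExecutionPlan
struct Plan { int id; };

// stand-in for the (expensive) plan builders run_batch_norm_forward/backward
static Plan build_plan(const std::vector<int64_t>& key) {
  return Plan{static_cast<int>(key.size())};
}

// same shape as gbn_plan_cache: keyed on a feature vector such as
// {direction, N, C, H, W, bn_group, dtype}
static std::map<std::vector<int64_t>, Plan> plan_cache;

static const Plan& get_or_build(const std::vector<int64_t>& key) {
  auto it = plan_cache.find(key);
  if (it == plan_cache.end()) {
    // first call with this configuration: build and cache the plan
    it = plan_cache.emplace(key, build_plan(key)).first;
  }
  // later calls with the same configuration reuse the cached plan
  return it->second;
}

Note that std::map accepts std::vector<int64_t> as a key directly because std::vector already defines lexicographic operator<, so no custom comparator or hash is needed for the shape/type feature vector.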