From 73161305f009ef0a210f6da76234fe94f5316b5b Mon Sep 17 00:00:00 2001 From: Thorsten Kurth Date: Wed, 21 Jun 2023 23:40:44 -0700 Subject: [PATCH 1/4] fixing order of class instantiation and device extraction in mixed precision lamb --- apex/optimizers/fused_mixed_precision_lamb.py | 15 ++++++++++----- 1 file changed, 10 insertions(+), 5 deletions(-) diff --git a/apex/optimizers/fused_mixed_precision_lamb.py b/apex/optimizers/fused_mixed_precision_lamb.py index f1b2902ca..eaf70a1ce 100644 --- a/apex/optimizers/fused_mixed_precision_lamb.py +++ b/apex/optimizers/fused_mixed_precision_lamb.py @@ -12,22 +12,27 @@ def __init__(self, params, lr=1e-3, step=0, bias_correction=True, amsgrad=False, adam_w_mode=True, grad_averaging=True, max_grad_norm=1.0, use_nvlamb=False, reduced_precision_dtype=None): + if amsgrad: raise RuntimeError('FusedLAMB does not support the AMSGrad variant.') - - # The learning rate (lr) and optimizer step (step) should be located on device - # in order to faciliated device sync free execution + + # init defaults defaults = dict(lr=torch.tensor(lr, dtype=torch.float32), step=torch.tensor([step], dtype=torch.int), bias_correction=bias_correction, betas=betas, eps=eps, weight_decay=weight_decay, grad_averaging=grad_averaging, max_grad_norm=max_grad_norm) - tensor_state = ['lr', 'step'] - super(FusedMixedPrecisionLamb, self).__init__(params, defaults) + # init base module + super(FusedMixedPrecisionLamb, self).__init__(params, defaults) + + # The learning rate (lr) and optimizer step (step) should be located on device device = self.param_groups[0]['params'][0].device + # The learning rate (lr) and optimizer step (step) should be located on device + # in order to faciliated device sync free execution + tensor_state = ['lr', 'step'] for idx,group in enumerate(self.param_groups): for item in tensor_state: self.param_groups[idx][item] = group[item].to(device=device) From dafe66aa02d6d26030d8f5a251ddd1bf8c24ba00 Mon Sep 17 00:00:00 2001 From: Thorsten Kurth Date: Wed, 21 Jun 2023 23:46:34 -0700 Subject: [PATCH 2/4] this commit fixes the SGBN graph capture problem by caching the cudnn plan and re-using it --- apex/contrib/csrc/cudnn_gbn/cudnn_gbn.cpp | 108 ++- apex/contrib/csrc/cudnn_gbn/norm_sample.cpp | 867 ++++++++++---------- apex/contrib/csrc/cudnn_gbn/norm_sample.h | 158 +++- 3 files changed, 622 insertions(+), 511 deletions(-) diff --git a/apex/contrib/csrc/cudnn_gbn/cudnn_gbn.cpp b/apex/contrib/csrc/cudnn_gbn/cudnn_gbn.cpp index ded444a6a..c1d32e1d6 100644 --- a/apex/contrib/csrc/cudnn_gbn/cudnn_gbn.cpp +++ b/apex/contrib/csrc/cudnn_gbn/cudnn_gbn.cpp @@ -7,6 +7,12 @@ #include "norm_sample.h" +// define this enum: +enum bn_type { BN_FWD, BN_BWD }; + +// this is a global variable +static std::map, cudnn_frontend::ExecutionPlan> gbn_plan_cache; + at::Tensor gbn_forward(const at::Tensor& x, const at::Tensor& scale, const at::Tensor& bias, @@ -38,28 +44,43 @@ at::Tensor gbn_forward(const at::Tensor& x, void_peer_buffers.push_back((void*)addr); } + // we need the peer size for the buffer reset + size_t peer_size = 1; + for (size_t i = 0; i < 4; ++i){ + peer_size *= peerDims[i]; + } + + // sanity check assert(bn_group == void_peer_buffers.size()); - run_batch_norm_forward( - perChannelDims, - epsilonDims, - tensorDims, - peerDims, - x.data_ptr(), - y.data_ptr(), - scale.data_ptr(), - bias.data_ptr(), - running_mean.data_ptr(), - running_var.data_ptr(), - running_mean.data_ptr(), - running_var.data_ptr(), - minibatch_mean.data_ptr(), - minibatch_inv_var.data_ptr(), - 
void_peer_buffers, - epsilon, - momentum, - rank_id - ); + // check if plan already exists + std::vector fv = {(int64_t)BN_FWD, N, C, H, W, bn_group, (int64_t)CUDNN_DATA_HALF}; + if ( gbn_plan_cache.find(fv) == gbn_plan_cache.end() ) { + auto plan = run_batch_norm_forward(tensorDims, perChannelDims, epsilonDims, peerDims, CUDNN_DATA_HALF); + gbn_plan_cache.insert(std::make_pair(fv, plan)); + } + + // get plan and handle + auto plan = gbn_plan_cache.find(fv)->second; + + // execute + execute_batch_norm_forward(plan, + x.data_ptr(), + y.data_ptr(), + scale.data_ptr(), + bias.data_ptr(), + running_mean.data_ptr(), + running_var.data_ptr(), + running_mean.data_ptr(), + running_var.data_ptr(), + minibatch_mean.data_ptr(), + minibatch_inv_var.data_ptr(), + void_peer_buffers, + static_cast(epsilon), + static_cast(momentum), + peer_size, + rank_id); + return y; } @@ -98,26 +119,37 @@ std::vector gbn_backward( void_peer_buffers.push_back((void*)addr); } + // we need the peer size for the buffer reset + size_t peer_size = 1; + for (size_t i = 0; i < 4; ++i){ + peer_size *= peerDims[i]; + } + assert(bn_group == void_peer_buffers.size()); - run_batch_norm_backward( - perChannelDims, - epsilonDims, - tensorDims, - peerDims, - x.data_ptr(), - dy.data_ptr(), - scale.data_ptr(), - minibatch_mean.data_ptr(), - minibatch_inv_var.data_ptr(), - x_grad.data_ptr(), - scale_grad.data_ptr(), - bias_grad.data_ptr(), - void_peer_buffers, - epsilon, - rank_id); - - + std::vector fv = {(int64_t)BN_BWD, N, C, H, W, bn_group, (int64_t)CUDNN_DATA_HALF}; + if ( gbn_plan_cache.find(fv) == gbn_plan_cache.end() ) { + auto plan = run_batch_norm_backward(tensorDims, perChannelDims, epsilonDims, peerDims, CUDNN_DATA_HALF); + gbn_plan_cache.insert(std::make_pair(fv, plan)); + } + + // get plan and handle + auto plan = gbn_plan_cache.find(fv)->second; + + // execute + execute_batch_norm_backward(plan, + x.data_ptr(), + dy.data_ptr(), + scale.data_ptr(), + minibatch_mean.data_ptr(), + minibatch_inv_var.data_ptr(), + void_peer_buffers, + x_grad.data_ptr(), + scale_grad.data_ptr(), + bias_grad.data_ptr(), + static_cast(epsilon), + peer_size, + rank_id); return std::vector{x_grad, scale_grad, bias_grad}; } diff --git a/apex/contrib/csrc/cudnn_gbn/norm_sample.cpp b/apex/contrib/csrc/cudnn_gbn/norm_sample.cpp index d4835190b..e14502109 100644 --- a/apex/contrib/csrc/cudnn_gbn/norm_sample.cpp +++ b/apex/contrib/csrc/cudnn_gbn/norm_sample.cpp @@ -1,24 +1,24 @@ /* - * Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. - * - * Permission is hereby granted, free of charge, to any person obtaining a - * copy of this software and associated documentation files (the "Software"), - * to deal in the Software without restriction, including without limitation - * the rights to use, copy, modify, merge, publish, distribute, sublicense, - * and/or sell copies of the Software, and to permit persons to whom the - * Software is furnished to do so, subject to the following conditions: - * - * The above copyright notice and this permission notice shall be included in - * all copies or substantial portions of the Software. - * - * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR - * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, - * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL - * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER - * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING - * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER - * DEALINGS IN THE SOFTWARE. - */ +* Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +* +* Permission is hereby granted, free of charge, to any person obtaining a +* copy of this software and associated documentation files (the "Software"), +* to deal in the Software without restriction, including without limitation +* the rights to use, copy, modify, merge, publish, distribute, sublicense, +* and/or sell copies of the Software, and to permit persons to whom the +* Software is furnished to do so, subject to the following conditions: +* +* The above copyright notice and this permission notice shall be included in +* all copies or substantial portions of the Software. +* +* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL +* THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +* FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +* DEALINGS IN THE SOFTWARE. +*/ #include "norm_sample.h" #include @@ -27,31 +27,16 @@ #include #include -#define FatalError(s) { \ - std::stringstream _where, _message; \ - _where << __FILE__ << ':' << __LINE__; \ - _message << std::string(s) + "\n" << __FILE__ << ':' << __LINE__;\ - std::cerr << _message.str() << "\nAborting...\n"; \ - exit(EXIT_FAILURE); \ -} - -#define checkCUDNN(status) { \ - std::stringstream _error; \ - if (status != CUDNN_STATUS_SUCCESS) { \ - _error << "CUDNN failure\nError: " << cudnnGetErrorString(status); \ - FatalError(_error.str()); \ - } \ -} - -#define checkCudaErrors(status) { \ - std::stringstream _error; \ - if (status != 0) { \ - _error << "Cuda failure\nError: " << cudaGetErrorString(status); \ - FatalError(_error.str()); \ - } \ +// some helpers +int64_t checkCudaError(cudaError_t code, const char* expr, const char* file, int line) { + if (code) { + printf("CUDA error at %s:%d, code=%d (%s) in '%s'", file, line, (int)code, cudaGetErrorString(code), expr); + return 1; + } + return 0; } -int checkCudnnError(cudnnStatus_t code, const char* expr, const char* file, int line) { +int64_t checkCudnnError(cudnnStatus_t code, const char* expr, const char* file, int line) { if (code) { printf("CUDNN error at %s:%d, code=%d (%s) in '%s'\n", file, line, (int)code, cudnnGetErrorString(code), expr); return 1; @@ -59,416 +44,436 @@ int checkCudnnError(cudnnStatus_t code, const char* expr, const char* file, int return 0; } -#define checkCudnnErr(...) \ - do { \ - int err = checkCudnnError(__VA_ARGS__, #__VA_ARGS__, __FILE__, __LINE__); \ - if (err) { \ - return; \ - } \ - } while (0) - bool AllowAll(cudnnBackendDescriptor_t engine_config) { - (void)engine_config; - return false; + (void)engine_config; + return false; } -void generateStrides(const int64_t* dimA, int64_t* strideA, int nbDims, cudnnTensorFormat_t filterFormat) { - // For INT8x4 and INT8x32 we still compute standard strides here to input - // into the cuDNN functions. We will manually scale by resizeFactor in the cpu ref. 
- if (filterFormat == CUDNN_TENSOR_NCHW) { - strideA[nbDims - 1] = 1; - for (int64_t d = nbDims - 2; d >= 0; d--) { - strideA[d] = strideA[d + 1] * dimA[d + 1]; - } - } else { - // Here we assume that the format is CUDNN_TENSOR_NHWC - strideA[1] = 1; - strideA[nbDims - 1] = strideA[1] * dimA[1]; - for (int64_t d = nbDims - 2; d >= 2; d--) { - strideA[d] = strideA[d + 1] * dimA[d + 1]; - } - strideA[0] = strideA[2] * dimA[2]; +void generateStrides(const int64_t* dimA, int64_t* strideA, int64_t nbDims, cudnnTensorFormat_t filterFormat) { + // For INT8x4 and INT8x32 we still compute standard strides here to input + // into the cuDNN functions. We will manually scale by resizeFactor in the cpu ref. + if (filterFormat == CUDNN_TENSOR_NCHW) { + strideA[nbDims - 1] = 1; + for (int64_t d = nbDims - 2; d >= 0; d--) { + strideA[d] = strideA[d + 1] * dimA[d + 1]; + } + } else { + // Here we assume that the format is CUDNN_TENSOR_NHWC + strideA[1] = 1; + strideA[nbDims - 1] = strideA[1] * dimA[1]; + for (int64_t d = nbDims - 2; d >= 2; d--) { + strideA[d] = strideA[d + 1] * dimA[d + 1]; } + strideA[0] = strideA[2] * dimA[2]; + } } -void -run_batch_norm_forward( - int64_t *perChannelSum, - int64_t *epsilon, - int64_t *tensorDims, - int64_t *peerDims, - - void *xDevPtr, - void *yDevPtr, - void *scaledevPtr, - void *biasdevPtr, - void *in_meandevPtr, - void *in_vardevPtr, - void *out_meandevPtr, - void *out_vardevPtr, - void *saved_meandevPtr, - void *saved_inv_vardevPtr, - const std::vector &peer_devPtrs, - double epsilon_val, - double exponential_decay_factor, - int rank_id) -{ - cudaStream_t stream; - cudnnHandle_t handle_ = torch::native::getCudnnHandle(); - cudnnGetStream(handle_, &stream); - - try { - - // Creates the necessary tensor descriptors - int64_t tensor_stride[4]; - int64_t stride[4]; - int64_t peer_stride[4]; - - generateStrides(tensorDims, tensor_stride, 4, CUDNN_TENSOR_NHWC); - generateStrides(peerDims, peer_stride, 4, CUDNN_TENSOR_NHWC); - - auto tensor_create = [&tensor_stride, &tensorDims](cudnnDataType_t type, - int64_t id) { - return cudnn_frontend::TensorBuilder() - .setDim(4, tensorDims) - .setStrides(4, tensor_stride) - .setId(id) - .setAlignment(16) - .setDataType(type) - .build(); - }; - - auto peer_tensor_create = [&peer_stride, &tensorDims](cudnnDataType_t type, - int64_t id) { - return cudnn_frontend::TensorBuilder() - .setDim(4, tensorDims) - .setStrides(4, peer_stride) - .setId(id) - .setAlignment(16) - .setDataType(type) - .build(); - }; - - - generateStrides(perChannelSum, stride, 4, CUDNN_TENSOR_NHWC); - - auto per_channel_tensor_create = [&stride, &perChannelSum](cudnnDataType_t type, - int64_t id) { - return cudnn_frontend::TensorBuilder() - .setDim(4, perChannelSum) - .setStrides(4, stride) - .setId(id) - .setAlignment(16) - .setDataType(type) - .build(); - }; - - - auto xTensor = tensor_create(CUDNN_DATA_HALF, 100); - auto yTensor = tensor_create(CUDNN_DATA_HALF, 101); - auto scaleTensor = per_channel_tensor_create(CUDNN_DATA_FLOAT, 102); - auto biasTensor = per_channel_tensor_create(CUDNN_DATA_FLOAT, 103); - auto inMeanTensor = per_channel_tensor_create(CUDNN_DATA_FLOAT, 104); - auto inVarTensor = per_channel_tensor_create(CUDNN_DATA_FLOAT, 105); - auto outMeanTensor = per_channel_tensor_create(CUDNN_DATA_FLOAT, 106); - auto outVarTensor = per_channel_tensor_create(CUDNN_DATA_FLOAT, 107); - auto savedMeanTensor = per_channel_tensor_create(CUDNN_DATA_FLOAT, 108); - auto savedInvVarTensor = per_channel_tensor_create(CUDNN_DATA_FLOAT, 109); - - - int64_t 
epsilon_stride[4]; - generateStrides(epsilon, epsilon_stride, 4, CUDNN_TENSOR_NHWC); - auto scalar_tensor_create = [&epsilon_stride, &epsilon](cudnnDataType_t type, - int64_t id) { - return cudnn_frontend::TensorBuilder() - .setDim(4, epsilon) - .setStrides(4, epsilon_stride) - .setId(id) - .setAlignment(16) - .setDataType(type) - .setByValue(true) - .build(); - }; - - auto epsilonTensor = scalar_tensor_create(CUDNN_DATA_DOUBLE, 110); - auto expDecayTensor = scalar_tensor_create(CUDNN_DATA_DOUBLE, 111); - - std::vector peerStatTensors; - for (size_t i = 112; i < 112 + peer_devPtrs.size(); ++i) { - peerStatTensors.push_back(peer_tensor_create(CUDNN_DATA_FLOAT, i)); - } -#if (CUDNN_VERSION >= 8500) - // Batch normalization - cudnnBackendNormMode_t normalizationMode = CUDNN_BATCH_NORM; - - // Forward training - cudnnBackendNormFwdPhase_t phase = CUDNN_NORM_FWD_TRAINING; - - //Create a Finalize node - auto batch_norm_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_NORM_FORWARD_DESCRIPTOR) - .setNormalizationMode(normalizationMode) - .setNormFwdPhase(phase) - .setxDesc(xTensor) - .setScaleAndBias(scaleTensor, biasTensor) - .setPrevRunningMeanAndVar(inMeanTensor, inVarTensor) - .setNextRunningMeanAndVar(outMeanTensor, outVarTensor) - .setSavedMeanAndInvVar(savedMeanTensor, savedInvVarTensor) - .setEpsilonTensor(epsilonTensor) - .setExpDecayFactorTensor(expDecayTensor) - .setPeerStatTensor(peerStatTensors) - .setyDesc(yTensor) - .build(); +// runtime +cudnn_frontend::ExecutionPlan run_batch_norm_forward(int64_t *tensorDims, + int64_t *perChannelSum, + int64_t *epsilon, + int64_t *peerDims, + cudnnDataType_t data_type) { + + // get the cudnn handle + cudnnHandle_t handle = torch::native::getCudnnHandle(); + + // Creates the necessary tensor descriptors + int64_t tensor_stride[4]; + int64_t stride[4]; + int64_t peer_stride[4]; + + // NHWC format. GenerateStrides() takes care of this. 
Howeever, tensor dims should still be NCHW + generateStrides(tensorDims, tensor_stride, (int64_t)4, CUDNN_TENSOR_NHWC); + generateStrides(peerDims, peer_stride, (int64_t)4, CUDNN_TENSOR_NHWC); + + auto tensor_create = [&tensor_stride, &tensorDims](cudnnDataType_t type, + int64_t id) { + return cudnn_frontend::TensorBuilder() + .setDim(4, tensorDims) + .setStrides(4, tensor_stride) + .setId(id) + .setAlignment(16) + .setDataType(type) + .build(); + }; + + auto peer_tensor_create = [&peer_stride, &tensorDims](cudnnDataType_t type, + int64_t id) { + return cudnn_frontend::TensorBuilder() + .setDim(4, tensorDims) + .setStrides(4, peer_stride) + .setId(id) + .setAlignment(16) + .setDataType(type) + .build(); + }; + + + generateStrides(perChannelSum, stride, (int64_t)4, CUDNN_TENSOR_NHWC); + + auto per_channel_tensor_create = [&stride, &perChannelSum](cudnnDataType_t type, int64_t id) { + return cudnn_frontend::TensorBuilder() + .setDim(4, perChannelSum) + .setStrides(4, stride) + .setId(id) + .setAlignment(16) + .setDataType(type) + .build(); + }; + + auto xTensor = tensor_create(data_type, 100); + auto yTensor = tensor_create(data_type, 101); + auto scaleTensor = per_channel_tensor_create(CUDNN_DATA_FLOAT, 102); + auto biasTensor = per_channel_tensor_create(CUDNN_DATA_FLOAT, 103); + auto inMeanTensor = per_channel_tensor_create(CUDNN_DATA_FLOAT, 104); + auto inVarTensor = per_channel_tensor_create(CUDNN_DATA_FLOAT, 105); + auto outMeanTensor = per_channel_tensor_create(CUDNN_DATA_FLOAT, 106); + auto outVarTensor = per_channel_tensor_create(CUDNN_DATA_FLOAT, 107); + auto savedMeanTensor = per_channel_tensor_create(CUDNN_DATA_FLOAT, 108); + auto savedInvVarTensor = per_channel_tensor_create(CUDNN_DATA_FLOAT, 109); + + + int64_t epsilon_stride[4]; + generateStrides(epsilon, epsilon_stride, (int64_t)4, CUDNN_TENSOR_NHWC); + auto scalar_tensor_create = [&epsilon_stride, &epsilon](cudnnDataType_t type, int64_t id) { + return cudnn_frontend::TensorBuilder() + .setDim(4, epsilon) + .setStrides(4, epsilon_stride) + .setId(id) + .setAlignment(16) + .setDataType(type) + .setByValue(true) + .build(); + }; + + auto epsilonTensor = scalar_tensor_create(CUDNN_DATA_DOUBLE, 110); + auto expDecayTensor = scalar_tensor_create(CUDNN_DATA_DOUBLE, 111); + + // Create the two peer stat tensors. 
Jump IDs in case we need to add more tensors with UIDs + std::vector peerStatTensors; + for (size_t i = 112; i < 112 + peerDims[0]; ++i) { + peerStatTensors.push_back(peer_tensor_create(CUDNN_DATA_FLOAT, i)); + } - std::array ops = {&batch_norm_op}; - auto opGraph = cudnn_frontend::OperationGraphBuilder().setHandle(handle_).setOperationGraph(ops.size(), ops.data()).build(); - - cudnn_frontend::EngineConfigList filtered_configs; - auto statuses = - cudnn_frontend::get_heuristics_list<2>({"heuristics_instant" - , "heuristics_fallback" - }, opGraph,::AllowAll, filtered_configs, true); - - auto plan_builder = [&filtered_configs, &opGraph, &handle_]() { - for (size_t i = 0; i < filtered_configs.size(); i++) { - try { - auto plan = cudnn_frontend::ExecutionPlanBuilder().setHandle(handle_).setEngineConfig(filtered_configs[i], opGraph.getTag()).build(); - return plan; - } catch (cudnn_frontend::cudnnException &e) { - continue; - } - } - return cudnn_frontend::ExecutionPlanBuilder().setHandle(handle_).setEngineConfig(filtered_configs[0], opGraph.getTag()).build(); - }; - - CHECK(filtered_configs.size() > 0); - auto plan = plan_builder(); - - auto workspace_size = plan.getWorkspaceSize(); - - void* workspace_ptr = nullptr; - auto workspace_tensor = at::empty({(workspace_size+3)/4}, at::TensorOptions(at::kCUDA).dtype(at::kFloat)); - if (workspace_size > 0) { - workspace_ptr = workspace_tensor.data_ptr(); - } - std::vector data_ptrs {xDevPtr, yDevPtr, scaledevPtr, biasdevPtr, - in_meandevPtr, in_vardevPtr, out_meandevPtr, out_vardevPtr, - saved_meandevPtr, saved_inv_vardevPtr, - &epsilon_val, &exponential_decay_factor}; - data_ptrs.insert(data_ptrs.end(), peer_devPtrs.begin(), peer_devPtrs.end()); - std::vector uids; - for (size_t i = 100; i < 100 + data_ptrs.size(); ++i) { - uids.push_back(i); - } - - assert(data_ptrs.size() == uids.size()); - auto variantPack = cudnn_frontend::VariantPackBuilder() - .setWorkspacePointer(workspace_ptr) - .setDataPointers(data_ptrs.size(), data_ptrs.data()) - .setUids(uids.size(), uids.data()) - .build(); - cudnnStatus_t status = cudnnBackendExecute(handle_, plan.get_raw_desc(), variantPack.get_raw_desc()); - cudnn_frontend::throw_if([status]() { return (status != CUDNN_STATUS_SUCCESS); }, "Plan execute error", status); +#if (CUDNN_VERSION >= 8500) + // Batch normalization + cudnnBackendNormMode_t normalizationMode = CUDNN_BATCH_NORM; + + // Forward training + cudnnBackendNormFwdPhase_t phase = CUDNN_NORM_FWD_TRAINING; + + //Create a Finalize node + auto batch_norm_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_NORM_FORWARD_DESCRIPTOR) + .setNormalizationMode(normalizationMode) + .setNormFwdPhase(phase) + .setxDesc(xTensor) + .setScaleAndBias(scaleTensor, biasTensor) + .setPrevRunningMeanAndVar(inMeanTensor, inVarTensor) + .setNextRunningMeanAndVar(outMeanTensor, outVarTensor) + .setSavedMeanAndInvVar(savedMeanTensor, savedInvVarTensor) + .setEpsilonTensor(epsilonTensor) + .setExpDecayFactorTensor(expDecayTensor) + .setPeerStatTensor(peerStatTensors) + .setyDesc(yTensor) + .build(); + + std::array ops = {&batch_norm_op}; +#else + std::array ops = {}; #endif + auto opGraph = cudnn_frontend::OperationGraphBuilder().setHandle(handle).setOperationGraph(ops.size(), ops.data()).build(); + //std::cout << opGraph.describe() << std::endl; + + cudnn_frontend::EngineConfigList filtered_configs; + auto statuses = + cudnn_frontend::get_heuristics_list<2>({"heuristics_instant" + , "heuristics_fallback" + }, opGraph,::AllowAll, filtered_configs, true); + + //std::cout 
<< "get_heuristics_list Statuses: "; + //for (auto i = 0u ; i < statuses.size(); i++) { + // std::cout << cudnn_frontend::to_string(statuses[i]) << " "; + //} + //std::cout << std::endl; + //std::cout << "Filter config list has " << filtered_configs.size() << " configurations " << std::endl; + + // some verbose printing: + //std::cout << "Tensor shape: (" << tensorDims[0] << ", " << tensorDims[1] << ", " << tensorDims[2] << ", " << tensorDims[3] << ")" << std::endl; + + auto plan_builder = [&filtered_configs, &opGraph, &handle]() { + for (auto i = 0u; i < filtered_configs.size(); i++) { + try { + auto plan = cudnn_frontend::ExecutionPlanBuilder().setHandle(handle).setEngineConfig(filtered_configs[i], opGraph.getTag()).build(); + return plan; + } catch (cudnn_frontend::cudnnException &e) { + continue; + } + } + return cudnn_frontend::ExecutionPlanBuilder().setHandle(handle).setEngineConfig(filtered_configs[0], opGraph.getTag()).build(); + }; + + assert(filtered_configs.size() > 0); + auto plan = plan_builder(); + + return plan; + +} - size_t peer_size = 1; - for (size_t i = 0; i < 4; ++i){ - peer_size *= peerDims[i]; +void execute_batch_norm_forward(cudnn_frontend::ExecutionPlan plan, + void *xDevPtr, + void *yDevPtr, + void *scaledevPtr, + void *biasdevPtr, + void *in_meandevPtr, + void *in_vardevPtr, + void *out_meandevPtr, + void *out_vardevPtr, + void *saved_meandevPtr, + void *saved_inv_vardevPtr, + const std::vector &peer_devPtrs, + double epsilon_val, + double exponential_decay_factor, + size_t peer_size, + int rank_id) { + + // get handle + cudnnHandle_t handle_ = torch::native::getCudnnHandle(); + + // get stream + cudaStream_t stream; + cudnnGetStream(handle_, &stream); + + try { + // allocate workspace + auto workspace_size = plan.getWorkspaceSize(); + auto workspace_tensor = at::empty({(workspace_size+3)/4}, at::TensorOptions(at::kCUDA).dtype(at::kFloat)); + void* workPtr = nullptr; + if (workspace_size > 0) { + workPtr = workspace_tensor.data_ptr(); } + + // first the data pointers + std::vector data_ptrs {xDevPtr, yDevPtr, scaledevPtr, biasdevPtr, + in_meandevPtr, in_vardevPtr, out_meandevPtr, out_vardevPtr, + saved_meandevPtr, saved_inv_vardevPtr, + &epsilon_val, &exponential_decay_factor}; + data_ptrs.insert(data_ptrs.end(), peer_devPtrs.begin(), peer_devPtrs.end()); + // then the uids + std::vector uids; + for (size_t i = 100; i < 100 + data_ptrs.size(); ++i) { + uids.push_back(i); + } + auto variantPack = cudnn_frontend::VariantPackBuilder() + .setWorkspacePointer(workPtr) + .setDataPointers(data_ptrs.size(), data_ptrs.data()) + .setUids(uids.size(), uids.data()) + .build(); + //std::cout << "variantPack " << variantPack.describe() << std::endl; + cudnnStatus_t status = cudnnBackendExecute(handle_, plan.get_raw_desc(), variantPack.get_raw_desc()); + cudnn_frontend::throw_if([status]() { return (status != CUDNN_STATUS_SUCCESS); }, "Plan execute error", status); + // Reset local communication buffer cudaMemsetAsync(peer_devPtrs[rank_id], 0, peer_size*4, stream); - } catch (cudnn_frontend::cudnnException &e) { - struct cudaDeviceProp prop; - checkCudaErrors(cudaGetDeviceProperties(&prop, 0)); - if (prop.major == 8) { - std::cout << "[ERROR] Exception " << e.what() << std::endl; - CHECK(false); - } + + } catch (cudnn_frontend::cudnnException &e) { + struct cudaDeviceProp prop; + checkCudaErr(cudaGetDeviceProperties(&prop, 0)); + if (prop.major == 8) { + std::cout << "[ERROR] Exception " << e.what() << std::endl; + assert(false); } + } } -void -run_batch_norm_backward( - 
int64_t *perChannelSum, - int64_t *epsilon, - int64_t *tensorDims, - int64_t *peerDims, - - void *xDevPtr, - void *dyDevPtr, - void *scaledevPtr, - void *saved_meandevPtr, - void *saved_inv_vardevPtr, - void *dxDevPtr, - void *dscaledevPtr, - void *dbiasdevPtr, - const std::vector &peer_devPtrs, - double epsilon_val, - int rank_id) - -{ - cudaStream_t stream; - cudnnHandle_t handle_ = torch::native::getCudnnHandle(); - cudnnGetStream(handle_, &stream); - try { - - // Creates the necessary tensor descriptors - int64_t tensor_stride[4]; - int64_t stride[4]; - int64_t peer_stride[4]; - - generateStrides(tensorDims, tensor_stride, 4, CUDNN_TENSOR_NHWC); - generateStrides(peerDims, peer_stride, 4, CUDNN_TENSOR_NHWC); - - auto tensor_create = [&tensor_stride, &tensorDims](cudnnDataType_t type, - int64_t id) { - return cudnn_frontend::TensorBuilder() - .setDim(4, tensorDims) - .setStrides(4, tensor_stride) - .setId(id) - .setAlignment(16) - .setDataType(type) - .build(); - }; - - auto peer_tensor_create = [&peer_stride, &tensorDims](cudnnDataType_t type, - int64_t id) { - return cudnn_frontend::TensorBuilder() - .setDim(4, tensorDims) - .setStrides(4, peer_stride) - .setId(id) - .setAlignment(16) - .setDataType(type) - .build(); - }; - - generateStrides(perChannelSum, stride, 4, CUDNN_TENSOR_NHWC); - - auto per_channel_tensor_create = [&stride, &perChannelSum](cudnnDataType_t type, - int64_t id) { - return cudnn_frontend::TensorBuilder() - .setDim(4, perChannelSum) - .setStrides(4, stride) - .setId(id) - .setAlignment(16) - .setDataType(type) - .build(); - }; - - auto xTensor = tensor_create(CUDNN_DATA_HALF, 100); - auto dyTensor = tensor_create(CUDNN_DATA_HALF, 101); - auto scaleTensor = per_channel_tensor_create(CUDNN_DATA_FLOAT, 102); - auto savedMeanTensor = per_channel_tensor_create(CUDNN_DATA_FLOAT, 103); - auto savedInvVarTensor = per_channel_tensor_create(CUDNN_DATA_FLOAT, 104); - auto dxTensor = tensor_create(CUDNN_DATA_HALF, 105); - auto dScaleTensor = per_channel_tensor_create(CUDNN_DATA_FLOAT, 106); - auto dBiasTensor = per_channel_tensor_create(CUDNN_DATA_FLOAT, 107); - - - int64_t epsilon_stride[4]; - generateStrides(epsilon, epsilon_stride, 4, CUDNN_TENSOR_NHWC); - auto scalar_tensor_create = [&epsilon_stride, &epsilon](cudnnDataType_t type, - int64_t id) { - return cudnn_frontend::TensorBuilder() - .setDim(4, epsilon) - .setStrides(4, epsilon_stride) - .setId(id) - .setAlignment(16) - .setDataType(type) - .setByValue(true) - .build(); - }; - - auto epsilonTensor = scalar_tensor_create(CUDNN_DATA_DOUBLE, 108); - - - std::vector peerStatTensors; - for (size_t i = 109; i < 109 + peer_devPtrs.size(); ++i) { - peerStatTensors.push_back(peer_tensor_create(CUDNN_DATA_FLOAT, i)); - } - -#if (CUDNN_VERSION >= 8500) - // Batch normalization - cudnnBackendNormMode_t normalizationMode = CUDNN_BATCH_NORM; - - //Create a Finalize node - auto batch_norm_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_NORM_BACKWARD_DESCRIPTOR) - .setNormalizationMode(normalizationMode) - .setxDesc(xTensor) - .setSavedMeanAndInvVar(savedMeanTensor, savedInvVarTensor) - .setdyDesc(dyTensor) - .setScale(scaleTensor) - .setEpsilonTensor(epsilonTensor) - .setDScaleAndDBias(dScaleTensor, dBiasTensor) - .setdxDesc(dxTensor) - .setPeerStatTensor(peerStatTensors) +cudnn_frontend::ExecutionPlan run_batch_norm_backward(int64_t *tensorDims, + int64_t *perChannelSum, + int64_t *epsilon, + int64_t *peerDims, + cudnnDataType_t data_type) { + + // get cudnn handle + cudnnHandle_t handle = 
torch::native::getCudnnHandle(); + + // Creates the necessary tensor descriptors + int64_t tensor_stride[4]; + int64_t stride[4]; + int64_t peer_stride[4]; + + // NHWC format. GenerateStrides() takes care of this. Howeever, tensor dims should still be NCHW + generateStrides(tensorDims, tensor_stride, (int64_t)4, CUDNN_TENSOR_NHWC); + generateStrides(peerDims, peer_stride, (int64_t)4, CUDNN_TENSOR_NHWC); + + auto tensor_create = [&tensor_stride, &tensorDims](cudnnDataType_t type, int64_t id) { + return cudnn_frontend::TensorBuilder() + .setDim(4, tensorDims) + .setStrides(4, tensor_stride) + .setId(id) + .setAlignment(16) + .setDataType(type) + .build(); + }; + + auto peer_tensor_create = [&peer_stride, &peerDims](cudnnDataType_t type, int64_t id) { + return cudnn_frontend::TensorBuilder() + .setDim(4, peerDims) + .setStrides(4, peer_stride) + .setId(id) + .setAlignment(16) + .setDataType(type) + .build(); + }; + + generateStrides(perChannelSum, stride, (int64_t)4, CUDNN_TENSOR_NHWC); + + auto per_channel_tensor_create = [&stride, &perChannelSum](cudnnDataType_t type, int64_t id) { + return cudnn_frontend::TensorBuilder() + .setDim(4, perChannelSum) + .setStrides(4, stride) + .setId(id) + .setAlignment(16) + .setDataType(type) + .build(); + }; + + auto xTensor = tensor_create(data_type, 100); + auto dyTensor = tensor_create(data_type, 101); + auto scaleTensor = per_channel_tensor_create(CUDNN_DATA_FLOAT, 102); + auto savedMeanTensor = per_channel_tensor_create(CUDNN_DATA_FLOAT, 103); + auto savedInvVarTensor = per_channel_tensor_create(CUDNN_DATA_FLOAT, 104); + auto dxTensor = tensor_create(data_type, 105); + auto dScaleTensor = per_channel_tensor_create(CUDNN_DATA_FLOAT, 106); + auto dBiasTensor = per_channel_tensor_create(CUDNN_DATA_FLOAT, 107); + + int64_t epsilon_stride[4]; + generateStrides(epsilon, epsilon_stride, (int64_t)4, CUDNN_TENSOR_NHWC); + auto scalar_tensor_create = [&epsilon_stride, &epsilon](cudnnDataType_t type, int64_t id) { + return cudnn_frontend::TensorBuilder() + .setDim(4, epsilon) + .setStrides(4, epsilon_stride) + .setId(id) + .setAlignment(16) + .setDataType(type) + .setByValue(true) .build(); + }; - std::array ops = {&batch_norm_op}; - auto opGraph = cudnn_frontend::OperationGraphBuilder().setHandle(handle_).setOperationGraph(ops.size(), ops.data()).build(); - - cudnn_frontend::EngineConfigList filtered_configs; - auto statuses = - cudnn_frontend::get_heuristics_list<2>({"heuristics_instant" - , "heuristics_fallback" - }, opGraph,::AllowAll, filtered_configs, true); - - auto plan_builder = [&filtered_configs, &opGraph, &handle_]() { - for (size_t i = 0; i < filtered_configs.size(); i++) { - try { - auto plan = cudnn_frontend::ExecutionPlanBuilder().setHandle(handle_).setEngineConfig(filtered_configs[i], opGraph.getTag()).build(); - return plan; - } catch (cudnn_frontend::cudnnException &e) { - continue; - } - } - return cudnn_frontend::ExecutionPlanBuilder().setHandle(handle_).setEngineConfig(filtered_configs[0], opGraph.getTag()).build(); - }; - - CHECK(filtered_configs.size() > 0); - auto plan = plan_builder(); - - auto workspace_size = plan.getWorkspaceSize(); - - void* workspace_ptr = nullptr; - auto workspace_tensor = at::empty({(workspace_size+3)/4}, at::TensorOptions(at::kCUDA).dtype(at::kFloat)); - if (workspace_size > 0) { - workspace_ptr = workspace_tensor.data_ptr(); - } - std::vector data_ptrs {xDevPtr, dyDevPtr, scaledevPtr, - saved_meandevPtr, saved_inv_vardevPtr, - dxDevPtr, dscaledevPtr, dbiasdevPtr, &epsilon_val}; - 
data_ptrs.insert(data_ptrs.end(), peer_devPtrs.begin(), peer_devPtrs.end()); - std::vector uids; - for (size_t i = 100; i < 100 + data_ptrs.size(); ++i) { - uids.push_back(i); - } - - assert(data_ptrs.size() == uids.size()); - - auto variantPack = cudnn_frontend::VariantPackBuilder() - .setWorkspacePointer(workspace_ptr) - .setDataPointers(data_ptrs.size(), data_ptrs.data()) - .setUids(uids.size(), uids.data()) - .build(); - cudnnStatus_t status = cudnnBackendExecute(handle_, plan.get_raw_desc(), variantPack.get_raw_desc()); - cudnn_frontend::throw_if([status]() { return (status != CUDNN_STATUS_SUCCESS); }, "Plan execute error", status); + auto epsilonTensor = scalar_tensor_create(CUDNN_DATA_DOUBLE, 108); + + std::vector peerStatTensors; + for (size_t i = 109; i < 109 + peerDims[0]; ++i) { + peerStatTensors.push_back(peer_tensor_create(CUDNN_DATA_FLOAT, i)); + } + +#if (CUDNN_VERSION >= 8500) + // Batch normalization + cudnnBackendNormMode_t normalizationMode = CUDNN_BATCH_NORM; + + //Create a Finalize node + auto batch_norm_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_NORM_BACKWARD_DESCRIPTOR) + .setNormalizationMode(normalizationMode) + .setxDesc(xTensor) + .setSavedMeanAndInvVar(savedMeanTensor, savedInvVarTensor) + .setdyDesc(dyTensor) + .setScale(scaleTensor) + .setEpsilonTensor(epsilonTensor) + .setDScaleAndDBias(dScaleTensor, dBiasTensor) + .setdxDesc(dxTensor) + .setPeerStatTensor(peerStatTensors) + .build(); + + std::array ops = {&batch_norm_op}; +#else + std::array ops = {}; #endif + + auto opGraph = cudnn_frontend::OperationGraphBuilder().setHandle(handle).setOperationGraph(ops.size(), ops.data()).build(); + //std::cout << opGraph.describe() << std::endl; + + cudnn_frontend::EngineConfigList filtered_configs; + auto statuses = + cudnn_frontend::get_heuristics_list<2>({"heuristics_instant" + , "heuristics_fallback" + }, opGraph,::AllowAll, filtered_configs, true); + + auto plan_builder = [&filtered_configs, &opGraph, &handle]() { + for (auto i = 0u; i < filtered_configs.size(); i++) { + try { + auto plan = cudnn_frontend::ExecutionPlanBuilder().setHandle(handle).setEngineConfig(filtered_configs[i], opGraph.getTag()).build(); + return plan; + } catch (cudnn_frontend::cudnnException &e) { + continue; + } + } + return cudnn_frontend::ExecutionPlanBuilder().setHandle(handle).setEngineConfig(filtered_configs[0], opGraph.getTag()).build(); + }; - size_t peer_size = 1; - for (size_t i = 0; i < 4; ++i){ - peer_size *= peerDims[i]; + assert(filtered_configs.size() > 0); + auto plan = plan_builder(); + + return plan; +} + +void execute_batch_norm_backward(cudnn_frontend::ExecutionPlan plan, + void *xDevPtr, + void *dyDevPtr, + void *scaledevPtr, + void *saved_meandevPtr, + void *saved_inv_vardevPtr, + const std::vector &peer_devPtrs, + void *dxDevPtr, + void *dscaledevPtr, + void *dbiasdevPtr, + double epsilon_val, + size_t peer_size, + int rank_id) { + + // get handle + cudnnHandle_t handle_ = torch::native::getCudnnHandle(); + + // get stream + cudaStream_t stream; + cudnnGetStream(handle_, &stream); + + try { + // allocate workspace + auto workspace_size = plan.getWorkspaceSize(); + auto workspace_tensor = at::empty({(workspace_size+3)/4}, at::TensorOptions(at::kCUDA).dtype(at::kFloat)); + void* workPtr = nullptr; + if (workspace_size > 0) { + workPtr = workspace_tensor.data_ptr(); + } + + // create helper arrays + std::vector data_ptrs {xDevPtr, dyDevPtr, scaledevPtr, + saved_meandevPtr, saved_inv_vardevPtr, + dxDevPtr, dscaledevPtr, dbiasdevPtr, &epsilon_val}; + 
data_ptrs.insert(data_ptrs.end(), peer_devPtrs.begin(), peer_devPtrs.end()); + std::vector uids; + for (size_t i = 100; i < 100 + data_ptrs.size(); ++i) { + uids.push_back(i); } + + auto variantPack = cudnn_frontend::VariantPackBuilder() + .setWorkspacePointer(workPtr) + .setDataPointers(data_ptrs.size(), data_ptrs.data()) + .setUids(uids.size(), uids.data()) + .build(); + cudnnStatus_t status = cudnnBackendExecute(handle_, plan.get_raw_desc(), variantPack.get_raw_desc()); + + cudnn_frontend::throw_if([status]() { return (status != CUDNN_STATUS_SUCCESS); }, "Plan execute error", status); + // Reset local communication buffer cudaMemsetAsync(peer_devPtrs[rank_id], 0, peer_size*4, stream); - - } catch (cudnn_frontend::cudnnException &e) { - struct cudaDeviceProp prop; - checkCudaErrors(cudaGetDeviceProperties(&prop, 0)); - if (prop.major == 8) { - std::cout << "[ERROR] Exception " << e.what() << std::endl; - CHECK(false); - } + + } catch (cudnn_frontend::cudnnException &e) { + struct cudaDeviceProp prop; + checkCudaErr(cudaGetDeviceProperties(&prop, 0)); + if (prop.major == 8) { + std::cout << "[ERROR] Exception " << e.what() << std::endl; + assert(false); } -} \ No newline at end of file + } +} diff --git a/apex/contrib/csrc/cudnn_gbn/norm_sample.h b/apex/contrib/csrc/cudnn_gbn/norm_sample.h index 26d51bacc..0706416b5 100644 --- a/apex/contrib/csrc/cudnn_gbn/norm_sample.h +++ b/apex/contrib/csrc/cudnn_gbn/norm_sample.h @@ -1,5 +1,7 @@ +#pragma once + /* - * Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. * * Permission is hereby granted, free of charge, to any person obtaining a * copy of this software and associated documentation files (the "Software"), @@ -27,53 +29,125 @@ #include #include #include -#include #include #include #include -#include #include +#include + +/* some helpers + */ +void generateStrides(const int64_t* dimA, int64_t* strideA, int64_t nbDims, cudnnTensorFormat_t filterFormat); -void -run_batch_norm_forward( - int64_t *perChannelSum, - int64_t *epsilon, - int64_t *tensorDims, - int64_t *peerDims, +int64_t checkCudaError(cudaError_t code, const char* expr, const char* file, int line); +int64_t checkCudnnError(cudnnStatus_t code, const char* expr, const char* file, int line); - void *xDevPtr, - void *yDevPtr, - void *scaledevPtr, - void *biasdevPtr, - void *in_meandevPtr, - void *in_vardevPtr, - void *out_meandevPtr, - void *out_vardevPtr, - void *saved_meandevPtr, - void *saved_inv_vardevPtr, - const std::vector &peer_devPtrs, - double epsilon_val, - double exponential_decay_factor, - int rank_id -); +#define checkCudaErr(...) \ + do { \ + int64_t err = checkCudaError(__VA_ARGS__, #__VA_ARGS__, __FILE__, __LINE__); \ + assert(err == 0); \ + } while (0) -void -run_batch_norm_backward( - int64_t *perChannelSum, - int64_t *epsilon, - int64_t *tensorDims, - int64_t *peerDims, +#define checkCudnnErr(...) \ + do { \ + int64_t err = checkCudnnError(__VA_ARGS__, #__VA_ARGS__, __FILE__, __LINE__); \ + assert(err == 0); \ + } while (0) - void *xDevPtr, - void *dyDevPtr, - void *scaledevPtr, - void *saved_meandevPtr, - void *saved_inv_vardevPtr, - void *dxDevPtr, - void *dscaledevPtr, - void *dbiasdevPtr, - const std::vector &peer_devPtrs, - double epsilon_val, - int rank_id -); \ No newline at end of file +/** + * @brief Run a Group BN forward sample with 2 peer stat tensors. + * + * @param tensorDims an array with shape (N, C, H, W) for input tensor dims. 
Stride in NHWC or NCHW will take care of memory format + * @param perChannelSum an array with shape (1, C, 1, 1) to denote the sum values for each channel in the input tensor + * @param epsilon a scalar array with shape (1, 1, 1, 1) to represent the epsilon value for the BN + * @param peerDims an array with shape (num GPUs, 2 * C, 1, 1) to denote the tensor dimensions for peer stat tensor in GBN + + * + */ +cudnn_frontend::ExecutionPlan run_batch_norm_forward( + int64_t *tensorDims, + int64_t *perChannelSum, + int64_t *epsilon, + int64_t *peerDims, + cudnnDataType_t in_out_data_type); +/** + * @param xDevPtr input tensor device pointer + * @param yDevPtr output tensor device pointer + * @param scaledevPtr input scale device pointer for BN scaling + * @param biasdevPtr input scale device pointer for BN bias + * @param in_meandevPtr Input mean device pointer + * @param in_vardevPtr Input variance device pointer + * @param out_meandevPtr output mean device pointer + * @param out_vardevPtr output variance device pointer + * @param saved_meandevPtr saved mean device pointer for BN backward + * @param saved_inv_vardevPtr saved inverse variance device pointer for BN backward + * @param peer_devPtr1 peer stat tensor 1 device pointer + * @param peer_devPtr2 peer stat tensor 2 device pointer + * @param epsilon_val episilon value as a double + * @param exponential_decay_factor exponential_decay_factor as a value + * +**/ +void execute_batch_norm_forward(cudnn_frontend::ExecutionPlan plan, + void *xDevPtr, + void *yDevPtr, + void *scaledevPtr, + void *biasdevPtr, + void *in_meandevPtr, + void *in_vardevPtr, + void *out_meandevPtr, + void *out_vardevPtr, + void *saved_meandevPtr, + void *saved_inv_vardevPtr, + const std::vector &peer_devPtrs, + double epsilon_val, + double exponential_decay_factor, + size_t peer_size, + int rank_id); + +/** + * @brief Run a Group BN backward sample with 2 peer stat tensors. + * + * @param tensorDims an array with shape (N, C, H, W) for input tensor dims. Stride in NHWC or NCHW will take care of memory format + * @param perChannelSum an array with shape (1, C, 1, 1) to denote the sum values for each channel in the input tensor + * @param epsilon a scalar array with shape (1, 1, 1, 1) to represent the epsilon value for the BN + * @param peerDims an array with shape (num GPUs, 2 * C, 1, 1) to denote the tensor dimensions for peer stat tensor in GBN + * +*/ +cudnn_frontend::ExecutionPlan run_batch_norm_backward(int64_t *tensorDims, + int64_t *perChannelSum, + int64_t *epsilon, + int64_t *peerDims, + cudnnDataType_t data_type); + +/** + * @brief Run a Group BN backward sample with 2 peer stat tensors. 
+ * + * @param xDevPtr input tensor device pointer + * @param yDevPtr output tensor device pointer + * @param scaledevPtr input scale device pointer for BN scaling + * @param biasdevPtr input scale device pointer for BN bias + * @param in_meandevPtr Input mean device pointer + * @param in_vardevPtr Input variance device pointer + * @param out_meandevPtr output mean device pointer + * @param out_vardevPtr output variance device pointer + * @param saved_meandevPtr saved mean device pointer for BN backward + * @param saved_inv_vardevPtr saved inverse variance device pointer for BN backward + * @param peer_devPtr1 peer stat tensor 1 device pointer + * @param peer_devPtr2 peer stat tensor 2 device pointer + * @param epsilon_val episilon value as a double + * + */ +void execute_batch_norm_backward(cudnn_frontend::ExecutionPlan plan, + void *xDevPtr, + void *dyDevPtr, + void *scaledevPtr, + void *saved_meandevPtr, + void *saved_inv_vardevPtr, + const std::vector &peer_devPtrs, + void *dxDevPtr, + void *dscaledevPtr, + void *dbiasdevPtr, + double epsilon_val, + size_t peer_size, + int rank_id); From 3ad083dacad5f581541999f6917d189a6297ad4d Mon Sep 17 00:00:00 2001 From: Thorsten Kurth Date: Thu, 22 Jun 2023 00:48:24 -0700 Subject: [PATCH 3/4] disentangling the mplamb MR and SGBN MR --- apex/optimizers/fused_mixed_precision_lamb.py | 15 +++++---------- 1 file changed, 5 insertions(+), 10 deletions(-) diff --git a/apex/optimizers/fused_mixed_precision_lamb.py b/apex/optimizers/fused_mixed_precision_lamb.py index eaf70a1ce..f1b2902ca 100644 --- a/apex/optimizers/fused_mixed_precision_lamb.py +++ b/apex/optimizers/fused_mixed_precision_lamb.py @@ -12,27 +12,22 @@ def __init__(self, params, lr=1e-3, step=0, bias_correction=True, amsgrad=False, adam_w_mode=True, grad_averaging=True, max_grad_norm=1.0, use_nvlamb=False, reduced_precision_dtype=None): - if amsgrad: raise RuntimeError('FusedLAMB does not support the AMSGrad variant.') - - # init defaults + + # The learning rate (lr) and optimizer step (step) should be located on device + # in order to faciliated device sync free execution defaults = dict(lr=torch.tensor(lr, dtype=torch.float32), step=torch.tensor([step], dtype=torch.int), bias_correction=bias_correction, betas=betas, eps=eps, weight_decay=weight_decay, grad_averaging=grad_averaging, max_grad_norm=max_grad_norm) - - # init base module + tensor_state = ['lr', 'step'] super(FusedMixedPrecisionLamb, self).__init__(params, defaults) - - # The learning rate (lr) and optimizer step (step) should be located on device + device = self.param_groups[0]['params'][0].device - # The learning rate (lr) and optimizer step (step) should be located on device - # in order to faciliated device sync free execution - tensor_state = ['lr', 'step'] for idx,group in enumerate(self.param_groups): for item in tensor_state: self.param_groups[idx][item] = group[item].to(device=device) From 96b961f38128cce0452d53c8b501e28e220b28e2 Mon Sep 17 00:00:00 2001 From: Thorsten Kurth Date: Fri, 23 Jun 2023 12:23:21 +0200 Subject: [PATCH 4/4] cleaner caching --- apex/contrib/csrc/cudnn_gbn/cudnn_gbn.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/apex/contrib/csrc/cudnn_gbn/cudnn_gbn.cpp b/apex/contrib/csrc/cudnn_gbn/cudnn_gbn.cpp index c1d32e1d6..bebf0f44b 100644 --- a/apex/contrib/csrc/cudnn_gbn/cudnn_gbn.cpp +++ b/apex/contrib/csrc/cudnn_gbn/cudnn_gbn.cpp @@ -57,7 +57,7 @@ at::Tensor gbn_forward(const at::Tensor& x, std::vector fv = {(int64_t)BN_FWD, N, C, H, W, bn_group, (int64_t)CUDNN_DATA_HALF}; 
if ( gbn_plan_cache.find(fv) == gbn_plan_cache.end() ) { auto plan = run_batch_norm_forward(tensorDims, perChannelDims, epsilonDims, peerDims, CUDNN_DATA_HALF); - gbn_plan_cache.insert(std::make_pair(fv, plan)); + gbn_plan_cache.emplace(fv, std::move(plan)); } // get plan and handle @@ -130,7 +130,7 @@ std::vector gbn_backward( std::vector fv = {(int64_t)BN_BWD, N, C, H, W, bn_group, (int64_t)CUDNN_DATA_HALF}; if ( gbn_plan_cache.find(fv) == gbn_plan_cache.end() ) { auto plan = run_batch_norm_backward(tensorDims, perChannelDims, epsilonDims, peerDims, CUDNN_DATA_HALF); - gbn_plan_cache.insert(std::make_pair(fv, plan)); + gbn_plan_cache.emplace(fv, std::move(plan)); } // get plan and handle
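
For reference, the caching strategy introduced in PATCH 2/4 and tightened in PATCH 4/4 reduces to a lookup table keyed on everything that determines the execution plan (pass direction, N/C/H/W, bn_group, data type): the expensive plan construction runs once per shape, and later calls only execute the cached plan. The sketch below is a minimal, standalone illustration of that pattern under stated assumptions — Plan, build_plan and get_or_build are hypothetical stand-ins for cudnn_frontend::ExecutionPlan, run_batch_norm_forward/backward and the lookup code in cudnn_gbn.cpp; they are not apex or cuDNN APIs.

    // Minimal sketch of the plan-cache pattern, with a stand-in Plan type.
    #include <cstdint>
    #include <iostream>
    #include <map>
    #include <string>
    #include <utility>
    #include <vector>

    struct Plan {                                  // stand-in for cudnn_frontend::ExecutionPlan
        std::string tag;
        explicit Plan(std::string t) : tag(std::move(t)) {}
    };

    enum bn_type : int64_t { BN_FWD, BN_BWD };     // same role as the enum in cudnn_gbn.cpp

    // One process-wide cache keyed on everything that determines the plan.
    static std::map<std::vector<int64_t>, Plan> plan_cache;

    static Plan build_plan(const std::vector<int64_t>& key) {
        // Expensive, non-capturable work (graph build, heuristics, engine choice) goes here.
        return Plan("plan for N=" + std::to_string(key[1]));
    }

    static const Plan& get_or_build(int64_t pass, int64_t N, int64_t C, int64_t H, int64_t W,
                                    int64_t bn_group, int64_t dtype) {
        std::vector<int64_t> key = {pass, N, C, H, W, bn_group, dtype};
        auto it = plan_cache.find(key);
        if (it == plan_cache.end()) {
            // emplace + std::move avoids copying the freshly built plan (cf. PATCH 4/4).
            it = plan_cache.emplace(key, build_plan(key)).first;
        }
        return it->second;                         // later calls are a pure lookup
    }

    int main() {
        const Plan& a = get_or_build(BN_FWD, 32, 64, 56, 56, 2, 0);
        const Plan& b = get_or_build(BN_FWD, 32, 64, 56, 56, 2, 0);  // cache hit, no rebuild
        std::cout << a.tag << ", cached: " << std::boolalpha << (&a == &b) << "\n";
        return 0;
    }

Building the plan only on a cache miss, and merely executing it afterwards, is presumably what resolves the graph-capture problem described in PATCH 2/4, since repeated calls no longer re-run the heuristics and plan construction.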