Changes to support TNLRV3 fine-tuning #4639

Merged (4 commits, Jul 30, 2020)
9 changes: 7 additions & 2 deletions onnxruntime/core/providers/cuda/reduction/reduction_ops.cc
@@ -396,8 +396,9 @@ Status ReduceComputeCore(CUDAExecutionProvider& cuda_ep, const Tensor& input, Pr
   }
 
   CudnnReduceDescriptor reduce_desc;
-  if (std::is_same<T, MLFloat16>::value)
+  if (std::is_same<T, MLFloat16>::value) {
     ORT_RETURN_IF_ERROR(reduce_desc.Set(cudnn_reduce_op, CudnnTensor::GetDataType<float>(), ReduceTensorIndices));
+  }
   else
     ORT_RETURN_IF_ERROR(reduce_desc.Set(cudnn_reduce_op, cudnn_type_X, ReduceTensorIndices));
   const auto one = Consts<CudaT>::One;
@@ -438,7 +439,11 @@ Status ReduceComputeCore(CUDAExecutionProvider& cuda_ep, const Tensor& input, Pr
   } else {
     // Reduce max -- Max/Min will output indices data
     CudnnReduceDescriptor reduce_max_desc;
-    ORT_RETURN_IF_ERROR(reduce_max_desc.Set(CUDNN_REDUCE_TENSOR_MAX, cudnn_type_X, CUDNN_REDUCE_TENSOR_NO_INDICES));
+    cudnnDataType_t cudnn_reduce_max_type = cudnn_type_X;
+    if((std::is_same<T, MLFloat16>::value)) {
+      cudnn_reduce_max_type = CUDNN_DATA_FLOAT;
+    }
+    ORT_RETURN_IF_ERROR(reduce_max_desc.Set(CUDNN_REDUCE_TENSOR_MAX, cudnn_reduce_max_type, CUDNN_REDUCE_TENSOR_NO_INDICES));
     size_t indices_bytes_max = 0;
     CUDNN_RETURN_IF_ERROR(cudnnGetReductionIndicesSize(cuda_ep.PerThreadCudnnHandle(), reduce_max_desc,
                                                        input_tensor, output_tensor, &indices_bytes_max));
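Both hunks above switch the cuDNN reduction descriptor to a float32 compute type when T is MLFloat16 instead of reducing directly in half precision. A minimal NumPy sketch of why that matters, assuming a long sum of small values (illustrative only, not the ORT/cuDNN code path):

```python
import numpy as np

x = np.full(10000, 0.1, dtype=np.float16)

# Accumulate in fp16: once the running sum grows, adding 0.1 rounds away entirely.
acc16 = np.float16(0.0)
for v in x:
    acc16 = np.float16(acc16 + v)

# Accumulate in fp32 and only cast the final result back to fp16.
acc32 = np.float32(0.0)
for v in x:
    acc32 += np.float32(v)

print(float(acc16))              # far below the exact sum (~1000)
print(float(np.float16(acc32)))  # close to the exact sum
```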
29 changes: 29 additions & 0 deletions orttraining/orttraining/core/graph/gradient_builder.cc
@@ -788,6 +788,35 @@ IMPLEMENT_GRADIENT_BUILDER(GetReduceMeanGradient) {
return result;
}

IMPLEMENT_GRADIENT_BUILDER(GetReduceLogSumExpGradient) {
std::vector<NodeDef> result;
auto attributes = SrcNodeAttributes();
bool keepdims = true;
if (attributes.find("keepdims") != attributes.end() &&
attributes.at("keepdims").has_i()) {
keepdims = static_cast<bool>(attributes.at("keepdims").i());
}

ArgDef grad = GO(0);
if (!keepdims && attributes.find("axes") != attributes.end()) {
std::vector<int64_t> axes_values = RetrieveValues<int64_t>(attributes.at("axes"));
grad = IA("Unsqueezed_Grad");
result.push_back(NodeDef("Unsqueeze", {GO(0)}, {grad}, {MakeAttribute("axes", axes_values)}));

result.push_back(NodeDef("Unsqueeze", {O(0)}, {IA("Unsqueezed_Output")}, {MakeAttribute("axes", axes_values)}));
result.push_back(NodeDef("Sub", {I(0), IA("Unsqueezed_Output")}, {IA("Self_Sub_Result")}));
}
else {
result.push_back(NodeDef("Sub", {I(0), O(0)}, {IA("Self_Sub_Result")}));
}

result.push_back(NodeDef("Exp", {IA("Self_Sub_Result")}, {IA("Self_Sub_Result_Exp")}));

result.push_back(NodeDef("Mul", {IA("Self_Sub_Result_Exp"), grad}, {GI(0)}));

return result;
}

IMPLEMENT_GRADIENT_BUILDER(GetReduceSumGradient) {
std::vector<NodeDef> result;
auto attributes = SrcNodeAttributes();
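For y = ReduceLogSumExp(x), the derivative is dy/dx = exp(x - y) broadcast over the reduced axes, so the input gradient is exp(x - y) * dY. The builder above realizes exactly that with Sub, Exp, and Mul nodes, inserting Unsqueeze nodes when keepdims = 0 so the reduced output and gradient broadcast against the input. A hedged NumPy sketch of the same graph (single-axis case for brevity; names are illustrative, not ORT APIs):

```python
import numpy as np

def reduce_logsumexp_grad(x, grad_y, axis, keepdims):
    # Forward value of ReduceLogSumExp over the given axis.
    y = np.log(np.sum(np.exp(x), axis=axis, keepdims=keepdims))
    if not keepdims:
        # Mirrors the Unsqueeze nodes: restore the reduced axis so shapes broadcast.
        y = np.expand_dims(y, axis)
        grad_y = np.expand_dims(grad_y, axis)
    # Sub -> Exp -> Mul, the node sequence emitted by GetReduceLogSumExpGradient:
    # dL/dx = exp(x - y) * dL/dy
    return np.exp(x - y) * grad_y

x = np.random.randn(4, 3, 2)
grad_y = np.ones((4, 2))  # upstream gradient for axes=[1], keepdims=0
print(reduce_logsumexp_grad(x, grad_y, axis=1, keepdims=False).shape)  # (4, 3, 2)
```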
1 change: 1 addition & 0 deletions orttraining/orttraining/core/graph/gradient_builder.h
@@ -25,6 +25,7 @@ DECLARE_GRADIENT_BUILDER(GetMulGradient)
DECLARE_GRADIENT_BUILDER(GetDivGradient)
DECLARE_GRADIENT_BUILDER(GetReduceMeanGradient)
DECLARE_GRADIENT_BUILDER(GetReduceSumGradient)
DECLARE_GRADIENT_BUILDER(GetReduceLogSumExpGradient)
DECLARE_GRADIENT_BUILDER(GetPowGradient)
DECLARE_GRADIENT_BUILDER(GetConcatGradient)
DECLARE_GRADIENT_BUILDER(GetReshapeGradient)
@@ -51,6 +51,7 @@ void GradientBuilderRegistry::RegisterGradientBuilders() {
REGISTER_GRADIENT_BUILDER("Pow", GetPowGradient);
REGISTER_GRADIENT_BUILDER("ReduceMean", GetReduceMeanGradient);
REGISTER_GRADIENT_BUILDER("ReduceSum", GetReduceSumGradient);
REGISTER_GRADIENT_BUILDER("ReduceLogSumExp", GetReduceLogSumExpGradient);
REGISTER_GRADIENT_BUILDER("Add", GetAddSubGradient);
REGISTER_GRADIENT_BUILDER("Sub", GetAddSubGradient);
REGISTER_GRADIENT_BUILDER("Mul", GetMulGradient);
8 changes: 7 additions & 1 deletion orttraining/orttraining/python/ort_trainer.py
@@ -773,7 +773,13 @@ def state_dict(self):
             if n.name not in torch_state:
                 torch_state[n.name] = torch.from_numpy(numpy_helper.to_array(n))
 
-        return torch_state
+        # Need to remove redundant initializers and name suffixes to map back to original torch state names
+        torch_state_to_return = {}
+        for name, value in torch_state.items():
+            if not (("Moment" in name) or ("Update_Count" in name)):
+                name = name.replace('_fp16', '')
+                torch_state_to_return[name] = value
+        return torch_state_to_return
 
     def load_state_dict(self, state_dict, strict=False):
         # Note: It may happen ONNX model has not yet been initialized
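The returned dictionary now keeps only model weights: optimizer state entries (Moment_*, Update_Count_*) are dropped and the '_fp16' suffix is stripped so keys match the original torch parameter names. A toy illustration with hypothetical key names (not taken from a real checkpoint):

```python
# Hypothetical mixed-precision training state as ORT might accumulate it.
torch_state = {
    "encoder.layer.0.weight_fp16": "w0",
    "Moment_1_encoder.layer.0.weight": "m1",
    "Update_Count_encoder.layer.0.weight": "uc",
}

torch_state_to_return = {}
for name, value in torch_state.items():
    # Skip optimizer state; strip the fp16 suffix from weight names.
    if not (("Moment" in name) or ("Update_Count" in name)):
        name = name.replace("_fp16", "")
        torch_state_to_return[name] = value

print(torch_state_to_return)  # {'encoder.layer.0.weight': 'w0'}
```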
68 changes: 68 additions & 0 deletions orttraining/orttraining/test/gradient/gradient_ops_test.cc
@@ -546,6 +546,74 @@ TEST(GradientCheckerTest, ReduceSumGrad) {
}
}

TEST(GradientCheckerTest, ReduceLogSumExpGrad) {
float max_error;
GradientChecker<float, float, float> gradient_checker;
// Attribute axes supports negative values from opset 11.
OpDef op_def{"ReduceLogSumExp", kOnnxDomain, 11};

// default
{
gradient_checker.ComputeGradientError(op_def, {{4, 3, 2}}, {{1, 1, 1}}, &max_error);
EXPECT_IS_TINY(max_error);
}

// axes = [0, 1, 2], keepdims = 0
{
gradient_checker.ComputeGradientError(op_def, {{4, 3, 2}}, {{}}, &max_error,
{MakeAttribute("axes", std::vector<int64_t>{0, 1, 2}),
MakeAttribute("keepdims", int64_t(0))});
EXPECT_IS_TINY(max_error);
}

// axes = [0, 2], keepdims = 1
{
gradient_checker.ComputeGradientError(op_def, {{4, 3, 2}}, {{1, 3, 1}}, &max_error,
{MakeAttribute("axes", std::vector<int64_t>{0, 2})});
EXPECT_IS_TINY(max_error);
}

// axes = [0, 1], keepdims = 0
{
gradient_checker.ComputeGradientError(op_def, {{4, 3, 2}}, {{2}}, &max_error,
{MakeAttribute("axes", std::vector<int64_t>{0, 1}),
MakeAttribute("keepdims", int64_t(0))});
EXPECT_IS_TINY(max_error);
}

// axes = [1], keepdims = 1
{
gradient_checker.ComputeGradientError(op_def, {{4, 3, 2}}, {{4, 1, 2}}, &max_error,
{MakeAttribute("axes", std::vector<int64_t>{1}),
MakeAttribute("keepdims", int64_t(1))});
EXPECT_IS_TINY(max_error);
}

// axes = [2], keepdims = 0
{
gradient_checker.ComputeGradientError(op_def, {{4, 3, 2}}, {{4, 3}}, &max_error,
{MakeAttribute("axes", std::vector<int64_t>{2}),
MakeAttribute("keepdims", int64_t(0))});
EXPECT_IS_TINY(max_error);
}

// axes = [-2], keepdims = 1
{
gradient_checker.ComputeGradientError(op_def, {{4, 3, 2}}, {{4, 1, 2}}, &max_error,
{MakeAttribute("axes", std::vector<int64_t>{-2}),
MakeAttribute("keepdims", int64_t(1))});
EXPECT_IS_TINY(max_error);
}

// axes = [-1, -3], keepdims = 0
{
gradient_checker.ComputeGradientError(op_def, {{4, 3, 2}}, {{3}}, &max_error,
{MakeAttribute("axes", std::vector<int64_t>{-1, -3}),
MakeAttribute("keepdims", int64_t(0))});
EXPECT_IS_TINY(max_error);
}
}

#ifndef USE_CUDA
TEST(GradientCheckerTest, CastGrad) {
// A dummy test that cast float to float
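Each ReduceLogSumExpGrad case above compares the analytic gradient produced by the new builder against a numerically estimated one and expects the maximum difference to be tiny. A rough NumPy sketch of that kind of check using central differences (an assumed mechanism; the actual GradientChecker lives in the ORT test framework):

```python
import numpy as np

def logsumexp(x, axes, keepdims):
    return np.log(np.sum(np.exp(x), axis=axes, keepdims=keepdims))

def analytic_grad(x, axes):
    # Gradient of sum(ReduceLogSumExp(x)) w.r.t. x, i.e. upstream gradient of ones.
    return np.exp(x - logsumexp(x, axes, keepdims=True))

def numeric_grad(x, axes, keepdims, eps=1e-5):
    g = np.zeros_like(x)
    it = np.nditer(x, flags=["multi_index"])
    for _ in it:
        idx = it.multi_index
        xp, xm = x.copy(), x.copy()
        xp[idx] += eps
        xm[idx] -= eps
        g[idx] = np.sum(logsumexp(xp, axes, keepdims) - logsumexp(xm, axes, keepdims)) / (2 * eps)
    return g

x = np.random.randn(4, 3, 2)
max_error = np.max(np.abs(analytic_grad(x, (0, 2)) - numeric_grad(x, (0, 2), keepdims=False)))
print(max_error)  # expected to be tiny, e.g. below 1e-6
```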