【PTen】Add dot and matmul grad kernel in pten #38713
Changes from 23 commits
@@ -338,26 +338,48 @@ static void BuildDygraphPtenKernelContext(

   for (size_t i = 0; i < output_names.size(); ++i) {
     auto& out_def = output_defs.at(i);
-    auto& outs_vector = outs.at(output_names[i]);

     size_t start_idx = (i == 0 ? 0 : kernel_ctx->OutputRangeAt(i - 1).second);
-    size_t end_idx = start_idx + outs_vector.size();
     auto current_vector_size = kernel_ctx->OutputsSize();
-    // If the memory needed is less than the current memory allocated, we will
-    // reuse the current memory by using ReMakePtenDenseTensorFromVar.
-    // Otherwise,we will create new storage.
-    for (size_t offset = 0; offset < outs_vector.size(); ++offset) {
-      if (current_vector_size > start_idx + offset) {
-        experimental::ReMakePtenDenseTensorFromVar(
-            outs_vector[offset]->MutableVar(), out_def,
-            kernel_ctx->MutableOutputAt<pten::DenseTensor>(start_idx + offset));
-      } else {
-        kernel_ctx->EmplaceBackOutputWithoutSetRange(
-            experimental::MakePtenTensorBaseFromVar(
-                outs_vector[offset]->MutableVar(), out_def));
-      }
-    }
-    kernel_ctx->AssignOutputRange(std::make_pair(start_idx, end_idx), i);
+
+    auto iter = outs.find(output_names[i]);
+    if (iter != outs.end()) {
+      auto& outs_vector = iter->second;
+      size_t end_idx = start_idx + outs_vector.size();
+
+      // If the memory needed is less than the current memory allocated, we will
+      // reuse the current memory by using ReMakePtenDenseTensorFromVar.
+      // Otherwise,we will create new storage.
+      for (size_t offset = 0; offset < outs_vector.size(); ++offset) {
+        if (current_vector_size > start_idx + offset) {
+          auto* buffer_tensor = kernel_ctx->MutableOutputAt<pten::DenseTensor>(
+              start_idx + offset);
+          if (buffer_tensor) {
+            experimental::ReMakePtenDenseTensorFromVar(
+                outs_vector[offset]->MutableVar(), out_def, buffer_tensor);
+          } else {
+            kernel_ctx->SetOutputWithoutSetRange(
+                start_idx + offset,
+                experimental::MakePtenTensorBaseFromVar(
+                    outs_vector[offset]->MutableVar(), out_def));
+          }
Review comment: Is this branch actually used?
Reply: Yes, it is reached in dygraph mode.
+        } else {
+          kernel_ctx->EmplaceBackOutputWithoutSetRange(
+              experimental::MakePtenTensorBaseFromVar(
+                  outs_vector[offset]->MutableVar(), out_def));
+        }
+      }
+
+      kernel_ctx->AssignOutputRange(std::make_pair(start_idx, end_idx), i);
+    } else {
+      if (current_vector_size > start_idx) {
+        kernel_ctx->SetOutputWithoutSetRange(start_idx, {nullptr});
+      } else {
+        kernel_ctx->EmplaceBackOutputWithoutSetRange({nullptr});
+      }
+      kernel_ctx->AssignOutputRange(std::make_pair(start_idx, start_idx + 1),
+                                    i);
Review comment: Suggest moving this logic to the top of the loop: check `iter == outs.end()` and `continue` right away. That flattens the if/else nesting and makes the code easier to maintain and understand.
Reply: Done.
+    }
   }

   for (size_t i = 0; i < attr_names.size(); ++i) {
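To make the new control flow easier to follow, here is a minimal, self-contained sketch of the slot-management pattern this hunk implements, restructured with the early `continue` the reviewer suggested above: outputs missing from `outs` get a `nullptr` placeholder slot, an existing context slot is reused when it is already allocated, and a new slot is appended otherwise. `Tensor`, `KernelContextModel`, and `BuildOutputs` are simplified stand-ins invented for illustration; they are not Paddle's `pten::KernelContext` API.

```cpp
#include <cstddef>
#include <iostream>
#include <map>
#include <memory>
#include <string>
#include <utility>
#include <vector>

// Hypothetical stand-in for a dense tensor; not a Paddle type.
struct Tensor { std::string name; };

// Hypothetical, simplified model of the kernel context's output slots.
class KernelContextModel {
 public:
  size_t OutputsSize() const { return outputs_.size(); }
  void EmplaceBackOutput(std::shared_ptr<Tensor> t) { outputs_.push_back(std::move(t)); }
  void SetOutput(size_t idx, std::shared_ptr<Tensor> t) { outputs_[idx] = std::move(t); }
  Tensor* MutableOutputAt(size_t idx) { return outputs_[idx].get(); }
  void AssignOutputRange(std::pair<size_t, size_t> range, size_t i) {
    if (output_ranges_.size() <= i) output_ranges_.resize(i + 1);
    output_ranges_[i] = range;
  }
  std::pair<size_t, size_t> OutputRangeAt(size_t i) const { return output_ranges_[i]; }

 private:
  std::vector<std::shared_ptr<Tensor>> outputs_;
  std::vector<std::pair<size_t, size_t>> output_ranges_;
};

// Mirrors the control flow of the hunk above in simplified form.
void BuildOutputs(const std::vector<std::string>& output_names,
                  const std::map<std::string, std::vector<std::string>>& outs,
                  KernelContextModel* ctx) {
  for (size_t i = 0; i < output_names.size(); ++i) {
    size_t start_idx = (i == 0 ? 0 : ctx->OutputRangeAt(i - 1).second);
    size_t current_size = ctx->OutputsSize();

    auto iter = outs.find(output_names[i]);
    if (iter == outs.end()) {
      // Missing (optional) output: record a nullptr placeholder and continue,
      // which avoids the deeper if/else nesting the reviewer pointed out.
      if (current_size > start_idx) {
        ctx->SetOutput(start_idx, nullptr);
      } else {
        ctx->EmplaceBackOutput(nullptr);
      }
      ctx->AssignOutputRange({start_idx, start_idx + 1}, i);
      continue;
    }

    const auto& vars = iter->second;
    for (size_t offset = 0; offset < vars.size(); ++offset) {
      if (current_size > start_idx + offset && ctx->MutableOutputAt(start_idx + offset)) {
        // Slot already holds a tensor: reuse it in place (analogous to
        // ReMakePtenDenseTensorFromVar in the real code).
        ctx->MutableOutputAt(start_idx + offset)->name = vars[offset];
      } else if (current_size > start_idx + offset) {
        // Slot exists but is empty: fill it without appending.
        ctx->SetOutput(start_idx + offset, std::make_shared<Tensor>(Tensor{vars[offset]}));
      } else {
        // No slot yet: append a new one.
        ctx->EmplaceBackOutput(std::make_shared<Tensor>(Tensor{vars[offset]}));
      }
    }
    ctx->AssignOutputRange({start_idx, start_idx + vars.size()}, i);
  }
}

int main() {
  KernelContextModel ctx;
  std::map<std::string, std::vector<std::string>> outs = {{"Out", {"out_var"}}};
  BuildOutputs({"Out", "MissingOut"}, outs, &ctx);
  std::cout << "slots: " << ctx.OutputsSize() << "\n";  // 2: one tensor, one nullptr placeholder
}
```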
@@ -465,15 +487,18 @@ static void WriteBackToOutputs(
   auto& output_names = std::get<2>(pt_kernel_signature.args);

   for (size_t i = 0; i < output_names.size(); ++i) {
-    auto& outs_vector = outs.at(output_names[i]);
+    auto iter = outs.find(output_names[i]);
+    if (iter != outs.end()) {
+      auto& outs_vector = iter->second;

-    auto& range_pair = kernel_ctx->OutputRangeAt(i);
-    auto pten_outs = kernel_ctx->MutableOutputBetween<pten::DenseTensor>(
-        range_pair.first, range_pair.second);
+      auto& range_pair = kernel_ctx->OutputRangeAt(i);
+      auto pten_outs = kernel_ctx->MutableOutputBetween<pten::DenseTensor>(
+          range_pair.first, range_pair.second);

-    for (size_t j = 0; j < pten_outs.size(); ++j) {
-      experimental::MakeVariableFromPtenTensor(pten_outs[j],
-                                               outs_vector[j]->MutableVar());
+      for (size_t j = 0; j < pten_outs.size(); ++j) {
+        experimental::MakeVariableFromPtenTensor(pten_outs[j],
+                                                 outs_vector[j]->MutableVar());
+      }
     }
   }
 }
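The change from `outs.at(output_names[i])` to `outs.find(output_names[i])` mirrors the one above: presumably some kernel output names (for example, optional gradient outputs) are not always present in `outs`, and `at()` would throw on a missing key while `find()` lets the loop skip that output. A small standalone illustration of the difference, using a plain `std::map` instead of Paddle's `NameVarMap` (the `X@GRAD` / `Y@GRAD` names are made up for the example):

```cpp
#include <iostream>
#include <map>
#include <string>
#include <vector>

int main() {
  // Stand-in for the name-to-variables map: only "X@GRAD" was fed by the caller.
  std::map<std::string, std::vector<std::string>> outs = {{"X@GRAD", {"x_grad"}}};

  const std::vector<std::string> output_names = {"X@GRAD", "Y@GRAD"};
  for (const auto& name : output_names) {
    // outs.at(name) would throw std::out_of_range for "Y@GRAD";
    // find() lets us detect the missing output and skip (or placeholder) it.
    auto iter = outs.find(name);
    if (iter == outs.end()) {
      std::cout << name << ": not provided, skipping\n";
      continue;
    }
    std::cout << name << ": " << iter->second.size() << " variable(s)\n";
  }
}
```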
@@ -530,6 +555,7 @@ static void PreparedOpRunImpl(
 template <typename VarType>
 static void PreparedOpRunPtImpl(
     const framework::OperatorBase& op,
+    const framework::OpKernelType& kernel_type,
     const framework::KernelSignature& pt_kernel_signature,
     const pten::Kernel& pt_kernel, pten::KernelContext* pt_kernel_context,
     platform::DeviceContext* dev_ctx, const NameVarMap<VarType>& ins,
@@ -560,17 +586,19 @@ static void PreparedOpRunPtImpl(
   pt_kernel_context->ClearData();

   // TODO(chenweihang): add debug flags later
-  // TODO(chenweihang): deal with complex cases later
+  if (framework::IsComplexType(kernel_type.data_type_)) {
Review comment: Could the pten kernel's data type be used here instead?
Reply: Because the passed-in ...
+    HandleComplexGradToRealGrad<VarType>(outs);
+  }
 }

 void PreparedOp::Run(const NameVarMap<VarBase>& ins,
                      const NameVarMap<VarBase>& outs,
                      const framework::AttributeMap& attrs,
                      const framework::AttributeMap& default_attrs) {
   if (run_pten_kernel_) {
-    PreparedOpRunPtImpl<VarBase>(op_, pt_kernel_signature_, pt_kernel_,
-                                 pt_kernel_context_, dev_ctx_, ins, outs, attrs,
-                                 default_attrs);
+    PreparedOpRunPtImpl<VarBase>(op_, kernel_type_, pt_kernel_signature_,
+                                 pt_kernel_, pt_kernel_context_, dev_ctx_, ins,
+                                 outs, attrs, default_attrs);
   } else {
     PreparedOpRunImpl<VarBase>(op_, ctx_, kernel_type_, func_, dev_ctx_, ins,
                                outs, attrs, default_attrs);
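On the review thread above about whether the pten kernel's data type could be consulted: this change threads the fluid `kernel_type_` into `PreparedOpRunPtImpl` and calls `HandleComplexGradToRealGrad` when that kernel's data type is complex. As I understand it, the point of that step is to drop the imaginary part of a complex-valued gradient when the corresponding forward variable is real. The sketch below illustrates only that idea with `std::complex`; `ComplexGradToReal` is a made-up helper for illustration, not Paddle's implementation.

```cpp
#include <complex>
#include <iostream>
#include <vector>

// Toy illustration only: if a gradient was computed in a complex type but the
// corresponding forward variable is real-valued, keep just the real part.
std::vector<float> ComplexGradToReal(const std::vector<std::complex<float>>& grad) {
  std::vector<float> real_grad;
  real_grad.reserve(grad.size());
  for (const auto& g : grad) {
    real_grad.push_back(g.real());  // discard the imaginary component
  }
  return real_grad;
}

int main() {
  std::vector<std::complex<float>> grad = {{1.0f, 0.5f}, {2.0f, -0.25f}};
  auto real_grad = ComplexGradToReal(grad);
  for (float g : real_grad) std::cout << g << " ";  // prints: 1 2
  std::cout << "\n";
}
```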
@@ -582,9 +610,9 @@ void PreparedOp::Run(const NameVarMap<VariableWrapper>& ins,
                      const framework::AttributeMap& attrs,
                      const framework::AttributeMap& default_attrs) {
   if (run_pten_kernel_) {
-    PreparedOpRunPtImpl<VariableWrapper>(op_, pt_kernel_signature_, pt_kernel_,
-                                         pt_kernel_context_, dev_ctx_, ins,
-                                         outs, attrs, default_attrs);
+    PreparedOpRunPtImpl<VariableWrapper>(
+        op_, kernel_type_, pt_kernel_signature_, pt_kernel_, pt_kernel_context_,
+        dev_ctx_, ins, outs, attrs, default_attrs);
   } else {
     PreparedOpRunImpl<VariableWrapper>(op_, ctx_, kernel_type_, func_, dev_ctx_,
                                        ins, outs, attrs, default_attrs);
Review comment: Please add some comments here.
Reply: Done.