【PTen】Add dot and matmul grad kernel in pten #38713
Conversation
Thanks for your contribution!
@@ -560,17 +586,19 @@ static void PreparedOpRunPtImpl(
   pt_kernel_context->ClearData();

   // TODO(chenweihang): add debug flags later
   // TODO(chenweihang): deal with complex cases later
   if (framework::IsComplexType(kernel_type.data_type_)) {
Could the pten kernel's data type be used here instead?
Since neither the KernelSignature nor the Kernel data structures passed in carry data_type information, we have to use the data type from kernel_type.
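The point above can be illustrated with a minimal sketch. The types below (OpKernelType, KernelSignature, IsComplexType, NeedsComplexHandling) are hypothetical, simplified stand-ins for the Paddle structures under discussion, not the real API; the sketch only shows that when the kernel signature carries no dtype, the decision must fall back to the framework-side kernel_type:

```cpp
#include <cassert>
#include <string>

// Hypothetical stand-ins for the structures discussed above.
enum class DataType { kFloat32, kComplex64, kComplex128 };

struct OpKernelType {  // stand-in for the framework-side kernel_type
  DataType data_type_;
};

struct KernelSignature {  // stand-in: note it carries no data_type field
  std::string name;
};

bool IsComplexType(DataType t) {
  return t == DataType::kComplex64 || t == DataType::kComplex128;
}

// Since the signature has no dtype, the complex-type check can only
// consult kernel_type.data_type_.
bool NeedsComplexHandling(const KernelSignature& /*sig*/,
                          const OpKernelType& kernel_type) {
  return IsComplexType(kernel_type.data_type_);
}
```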
  if (current_vector_size > start_idx) {
    pt_kernel_context_->SetOutputWithoutSetRange(start_idx, {nullptr});
  } else {
    pt_kernel_context_->EmplaceBackOutputWithoutSetRange({nullptr});
  }
  end_idx = start_idx + 1;
Please add some comments here.
Done
  } else {
    kernel_ctx->SetOutputWithoutSetRange(
        start_idx + offset,
        experimental::MakePtenTensorBaseFromVar(
            outs_vector[offset]->MutableVar(), out_def));
  }
Is this branch actually reached?
Yes, it is reached in dynamic graph (dygraph) mode.
  } else {
    if (current_vector_size > start_idx) {
      kernel_ctx->SetOutputWithoutSetRange(start_idx, {nullptr});
    } else {
      kernel_ctx->EmplaceBackOutputWithoutSetRange(
          experimental::MakePtenTensorBaseFromVar(
              outs_vector[offset]->MutableVar(), out_def));
      kernel_ctx->EmplaceBackOutputWithoutSetRange({nullptr});
    }
    kernel_ctx->AssignOutputRange(std::make_pair(start_idx, start_idx + 1),
                                  i);
Suggestion: move this logic to the beginning, check iter == outs.end(), and continue directly after handling it. That optimizes the code structure, reduces nested if/else logic, and makes the code easier to maintain and understand.
Done
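The early-continue restructuring suggested above can be sketched as follows. This is a simplified illustration with hypothetical stand-in types (plain strings and a map instead of the Paddle output containers), not the actual refactored Paddle code:

```cpp
#include <cassert>
#include <map>
#include <string>
#include <vector>

// Sketch of the suggested refactor: handle the "output not found" case
// first and `continue`, instead of burying it in an else branch. The names
// below are hypothetical stand-ins, not Paddle types.
std::vector<std::string> CollectOutputs(
    const std::vector<std::string>& names,
    const std::map<std::string, int>& outs) {
  std::vector<std::string> collected;
  for (const auto& name : names) {
    auto iter = outs.find(name);
    if (iter == outs.end()) {
      // Output absent: record a placeholder and move on immediately,
      // avoiding one level of if/else nesting in the main path below.
      collected.push_back(name + ":<null>");
      continue;
    }
    // Main path: output present.
    collected.push_back(name + ":" + std::to_string(iter->second));
  }
  return collected;
}
```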
    paddle::platform::complex<float>,
    paddle::platform::complex<double>) {}

PT_REGISTER_CTX_KERNEL(matmul_grad_grad,
Suggestion: keep the name consistent with the function, i.e. matmul_double_grad; the same goes for alias_name.
Done
LGTM
We suspect this PR caused the linear backward pass to become twice as slow. The nvprof results for linear_2 are as follows:
The new linear backward computation has one extra kernel.
Got it, I will look into it.
PR types
Others
PR changes
Others
Describe
Migrate the first-, second-, and third-order backward computation kernels of dot and matmul to pten.
To adapt the PTen backward kernels to the framework, this PR also includes the following adjustments:
- Unlike forward Ops, grad Ops carry no OpProto information, so this PR adjusts the corresponding handling logic and configures GetExpectedPtenKernelArgs for the Op of each migrated backward kernel; this solution may be replaced later.
- Some backward kernel inputs may be empty, so paddle::optional<const DenseTensor&> is used to wrap such possibly-empty input variables; support for the paddle::optional<const DenseTensor&> input type was also added to pten.
- Added the move assignment operator DenseTensor& operator=(DenseTensor&& other).
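The possibly-empty-input pattern described above can be sketched with standard-library analogues. The DenseTensor struct and SumIfPresent function below are hypothetical; std::optional<std::reference_wrapper<const DenseTensor>> is used here as a stand-in for paddle::optional<const DenseTensor&>, whose actual API may differ:

```cpp
#include <cassert>
#include <functional>
#include <optional>
#include <vector>

// Hypothetical, minimal stand-in for pten's DenseTensor.
struct DenseTensor {
  std::vector<float> data;
};

// Sketch of a kernel that tolerates a possibly-empty input, mirroring the
// paddle::optional<const DenseTensor&> pattern. std::optional of a
// reference_wrapper plays the role of an optional const reference.
float SumIfPresent(
    std::optional<std::reference_wrapper<const DenseTensor>> maybe_t) {
  if (!maybe_t.has_value()) {
    return 0.0f;  // Input absent: backward kernels must handle empty inputs.
  }
  float sum = 0.0f;
  for (float v : maybe_t->get().data) sum += v;
  return sum;
}
```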