Add a new op: paddle.linalg.multi_dot #35224
Thanks for your contribution!
}

/**
 * @brief multi matrix dot by a chain order
Add some comments.
done
}
};

template <typename DeviceContext, typename T>
Add some comments explaining the computation logic.
done
auto order = GetOrder(ins, ins_dims);
auto n = ins.size();
std::vector<framework::Tensor> results(n * n);
MatChainMul<DeviceContext, T>(ctx, ins, ins_dims, order, 0, n - 1, true,
Could the forward results be reused here?
They could be marked AsIntermediate and saved as intermediate results during the forward pass; I'll leave that change for a later optimization.
ops::MultiDotOpDoubleGradMaker<paddle::framework::OpDesc>,
ops::MultiDotOpDoubleGradMaker<paddle::imperative::OpBase>);

REGISTER_OP_CPU_KERNEL(
fp16?
The CPU version does not support fp16.

def multi_dot(x, name=None):
    """
Reword this text.
done
paddle.enable_static()


class TestMultiDotOp(OpTest):
Add some comments describing what the functions do.
done
self.assertRaises(ValueError, paddle.multi_dot, [x5, x6, x7])


class API_TestMultiDot(unittest.TestCase):
Fix the naming format.
done
@@ -28,4 +28,5 @@
'cvm',
'cudnn_lstm',
'rnn',
'multi_dot',
Please check the whitelist.
Already confirmed with zhupengyang that it can be added.
Need to optimize the comments of API and add more comments for the get order algorithm.
python/paddle/tensor/linalg.py
Outdated

def multi_dot(x, name=None):
    """
    Compute the dot product of tow or more matrix in a single function call, while automatically selecting the fastest evaluation order.
tow -> two
done
python/paddle/tensor/linalg.py
Outdated

Supports inputs of float, double and float16 dtypes. This function does not support batched inputs.

Every tensor in x must be 2D, except for the first and last which may be 1D. if the first tensor is a 1D vector of shape(n, ) it is treated as row vector of shape(1, n), similarly if the last tensor is a 1D vector of shape(n, ), it is treated as a column vector of shape(n, 1).
Every tensor in x must be 2D
x should be wrapped in single brackets to mark it as a variable.
x
done
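The shape rules in the docstring under review can be sketched in plain Python. This is only an illustration, not code from the PR: `multi_dot_output_shape` is a hypothetical helper, and the case where both the first and the last tensor are 1-D (returned here as an empty shape, as NumPy's `multi_dot` does) is an assumption, since the quoted docstring does not spell it out.

```python
def multi_dot_output_shape(shapes):
    """Hypothetical helper: infer the result shape from a list of input shapes."""
    if len(shapes) < 2:
        raise ValueError("expected at least two tensors")

    first_is_1d = len(shapes[0]) == 1
    last_is_1d = len(shapes[-1]) == 1

    dims = [tuple(s) for s in shapes]
    if first_is_1d:
        dims[0] = (1, dims[0][0])    # (n,) is treated as a row vector (1, n)
    if last_is_1d:
        dims[-1] = (dims[-1][0], 1)  # (n,) is treated as a column vector (n, 1)

    # After promotion, every operand has to be a matrix.
    for d in dims:
        if len(d) != 2:
            raise ValueError("only the first and last tensors may be 1-D")

    # Inner dimensions of neighbouring matrices have to agree.
    for left, right in zip(dims, dims[1:]):
        if left[1] != right[0]:
            raise ValueError("shape mismatch between %r and %r" % (left, right))

    # Drop the dimensions that were only added by the 1-D promotion.
    out = []
    if not first_is_1d:
        out.append(dims[0][0])
    if not last_is_1d:
        out.append(dims[-1][1])
    return tuple(out)


print(multi_dot_output_shape([(10, 5), (5, 8), (8, 7)]))  # (10, 7)
print(multi_dot_output_shape([(5,), (5, 8), (8, 7)]))     # (7,)
```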
python/paddle/tensor/linalg.py
Outdated
If the first and last tensors are matrices, the output will be a matrix. However, if either is a 1D vector, then the output will be a 1D vector.

The cost of multiplying two matrices with shapes (a, b) and (b, c) is a * b * c. Given matrices A, B, C with shapes (10, 100), (100, 5), (5, 50) respectively, we can calculate the cost of different multiplication orders as follows:
- Cost((AB)C) = 10x100x5 + 10x5x50 = 7500
This example is the same as PyTorch's; could you replace it with a different one?
done
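For reference, the cost arithmetic in the quoted docstring can be checked with a few lines of Python. `matmul_cost` is a hypothetical helper introduced only for this illustration; the shapes are the ones from the docstring above.

```python
def matmul_cost(left, right):
    """Scalar multiplications needed for (a, b) @ (b, c), i.e. a * b * c."""
    a, b = left
    b2, c = right
    if b != b2:
        raise ValueError("inner dimensions do not match")
    return a * b * c


A, B, C = (10, 100), (100, 5), (5, 50)

# (AB)C: first A @ B, giving a (10, 5) matrix, then that result times C.
cost_ab_c = matmul_cost(A, B) + matmul_cost((A[0], B[1]), C)

# A(BC): first B @ C, giving a (100, 50) matrix, then A times that result.
cost_a_bc = matmul_cost(B, C) + matmul_cost(A, (B[0], C[1]))

print(cost_ab_c)  # 7500
print(cost_a_bc)  # 75000
```

With these shapes, evaluating (AB)C is ten times cheaper than A(BC), which is why the evaluation order matters.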
B = paddle.to_tensor(B_data)
C = paddle.to_tensor(C_data)
out = paddle.multi_dot([A, B, C])
print(out.numpy().shape)
Code examples should show the correct result, given after a comment marker.
done

std::vector<uint64_t> m(n * n, 0);
std::vector<uint64_t> order(n * n);

Add comments explaining the principle behind this algorithm.
done
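The principle the reviewer asks about is presumably the textbook matrix-chain-order dynamic program; the `m` and `order` vectors above look like flattened n x n cost and split tables, but that reading is an assumption, not something the snippet confirms. A minimal Python sketch of the textbook algorithm, with hypothetical names `chain_order` and `parenthesize`:

```python
def chain_order(dims):
    """Matrix i has shape (dims[i], dims[i + 1]).

    Returns (cost, split): cost[i][j] is the minimal number of scalar
    multiplications for the product of matrices i..j, and split[i][j] is
    the index k at which that optimal product is divided.
    """
    n = len(dims) - 1
    cost = [[0] * n for _ in range(n)]
    split = [[0] * n for _ in range(n)]

    # Solve sub-chains in order of increasing length.
    for length in range(2, n + 1):
        for i in range(n - length + 1):
            j = i + length - 1
            cost[i][j] = float("inf")
            for k in range(i, j):
                # Cost of (i..k) times (k+1..j), plus the final multiplication.
                c = cost[i][k] + cost[k + 1][j] + dims[i] * dims[k + 1] * dims[j + 1]
                if c < cost[i][j]:
                    cost[i][j] = c
                    split[i][j] = k
    return cost, split


def parenthesize(split, i, j, names):
    """Render the optimal order for matrices i..j as a parenthesized string."""
    if i == j:
        return names[i]
    k = split[i][j]
    return "(" + parenthesize(split, i, k, names) + parenthesize(split, k + 1, j, names) + ")"


cost, split = chain_order([10, 100, 5, 50])
print(cost[0][2])                        # 7500
print(parenthesize(split, 0, 2, "ABC"))  # ((AB)C)
```

The table is filled in O(n^3) time for n matrices, which is negligible next to the multiplications themselves for typical chain lengths.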
x0 = fluid.data(name='x0', shape=[3, 2], dtype="float64")
x1 = fluid.data(name='x1', shape=[2, 3], dtype='float64')
result = paddle.multi_dot([x0, x1])
exe = fluid.Executor(fluid.CPUPlace())
Here, fluid.Executor -> paddle.static.Executor.
Please change every fluid.xx: refer to other unit tests, or search the official docs for the same name, and replace it with paddle.xx.
done
python/paddle/tensor/linalg.py
Outdated
@@ -2,7 +2,6 @@
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
Here and below there are some changes unrelated to your PR; please revert them.
This one has been reverted; the others were auto-formatted by pre-commit.
python/paddle/tensor/linalg.py
Outdated

Supports inputs of float, double and float16 dtypes. This function does not support batched inputs.

The input tensor in [x] must be 2D except for the first and last can be 1D. If the first tensor is a 1D vector of shape(n, ) it is treated as row vector of shape(1, n), similarly if the last tensor is a 1D vector of shape(n, ), it is treated as a column vector of shape(n, 1).
- 2D -> 2-D, 1D -> 1-D
- Adjust the line breaks in this description; the lines are too long
Watch for similar issues in the text below as well.
done
LGTM
framework::DDim out_dim;

if (first_dim.size() > 2) {
PADDLE_THROW(platform::errors::InvalidArgument(
Just use PADDLE_ENFORCE_GT(first_dim.size(), 2, ... directly.
done
}

auto last_dim = inputs_dims[n - 1];
if (last_dim.size() > 2) {
Same as above.
done
const framework::OpKernelType& expected_kernel_type) const override {
return framework::OpKernelType(expected_kernel_type.data_type_,
tensor.place(), tensor.layout());
}
All inputs and outputs share the same data type, right? There is no need to override these two functions.
done
framework::OpKernelType GetKernelTypeForVar(
const std::string& var_name, const framework::Tensor& tensor,
const framework::OpKernelType& expected_kernel_type) const {
if (framework::IsComplexType(expected_kernel_type.data_type_)) {
No kernel is registered for complex types, is it?
done
void Make() override {
AddInput("X", "The input tensors of multi_dot operator.").AsDuplicable();
AddOutput("Out", "The output tensor of multi_dot operator");
AddAttr<bool>(
No MKLDNN OpKernel is implemented; I suggest removing all mkldnn-related code.
done
ba0a92c
LGTM

import unittest
import numpy as np
from op_test import OpTest, skip_check_grad_ci
skip_check_grad_ci is unused; remember to remove it.
OK, I'll remove it in the next PR.
LGTM
LGTM
PR types
New features
PR changes
OPs
Describe
Add the multi_dot to paddle linear algebra library: