
add the transpose op #3920

Merged
merged 14 commits into PaddlePaddle:develop from NHZlX:op_transpose
Sep 21, 2017

Conversation

NHZlX
Contributor

@NHZlX NHZlX commented Sep 6, 2017

fix #4163

int* host_buffer_data = host_buffer.mutable_data<int>(buffer_dims, cpu_place);

auto offset_buffer =
    memory::Alloc(context.GetPlace(), ndims * 3 * sizeof(int));
Member

This interface is used incorrectly: context.GetPlace() returns a Place, which is a boost::variant<CPUPlace, GPUPlace>, while Alloc needs a concrete GPUPlace or CPUPlace.

Contributor Author

Done

}

memory::Copy(gpu_place, offset_buffer, cpu_place, host_buffer_data,
             ndims * 3 * sizeof(int));
Member

When doing the memory copy, the corresponding CUDA stream needs to be specified:

template <typename DstPlace, typename SrcPlace>
void Copy(DstPlace, void* dst, SrcPlace, const void* src, size_t num,
          cudaStream_t stream);

The CUDA stream can be obtained from the context.

Contributor Author

Done

Contributor

@Xreki Xreki left a comment

  • Merge the latest develop first
  • Need more unit test cases. In particular, we need a unit test with large Tensor to test the CUDA kernel.


#include "paddle/operators/transpose_op.h"
#include <vector>
#include "paddle/framework/ddim.h"
Contributor

Lines 16 and 17 can be removed.


PADDLE_ENFORCE_EQ(
    in_dim_size, axis_size,
    "the input tensor dimensions should be equal to the axis size");
Contributor

Print the values of in_dim_size and axis_size in the error message.
input tensor dimensions -> input tensor's dimension
axis size -> axis's size

std::vector<int> axis_sorted(axis);
std::sort(axis_sorted.begin(), axis_sorted.end());
for (size_t i = 0; i < axis_sorted.size(); i++) {
PADDLE_ENFORCE_EQ(axis_sorted[i], (int)i,
Contributor

Use static_cast<int>

for (size_t i = 0; i < axis_sorted.size(); i++) {
PADDLE_ENFORCE_EQ(axis_sorted[i], (int)i,
"the sorted axis should be [0, 1, ... dims - 1], "
"the dims equals to the input tensor dimensions");
Contributor

where the dims is the axis's size.

framework::OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "The input of transpose op");
AddOutput("Out", "The output of transpose op");

}

template <typename Place, typename T, int Dims>
void DoTranspose(const framework::ExecutionContext& context,
Contributor

Since this function calls Eigen to do the transpose, how about renaming it to EigenCpuTranspose?

Contributor Author

sounds great

out->mutable_data<T>(context.GetPlace());

auto axis_temp = context.Attr<std::vector<int>>("axis");
std::vector<int> axis(axis_temp);
Contributor

How about : axis_temp -> axis and axis -> axis_grad

Contributor Author

there is no grad for axis

public:
void Compute(const framework::ExecutionContext& context) const override {
auto* in = context.Input<framework::Tensor>(framework::GradVarName("Out"));
auto* out = context.Output<framework::Tensor>(framework::GradVarName("X"));
Contributor

in -> in_grad, out -> out_grad

case 5:
DoTranspose<Place, T, 5>(context, *in, *out, axis);
break;
default:
Contributor

When ndims is 1, calling NaiveCpuTranspose?

Contributor Author

When ndims > 5, NaiveCpuTranspose is called.
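For illustration, the per-element index remapping that such a naive transpose performs can be sketched in plain Python (hypothetical helper names, not the Paddle implementation): each element at an input flat offset is moved to the flat offset of its permuted coordinates.

```python
def strides_of(shape):
    # Row-major strides: stride[i] is the product of the dims to the right of i.
    strides = [1] * len(shape)
    for i in range(len(shape) - 2, -1, -1):
        strides[i] = strides[i + 1] * shape[i + 1]
    return strides

def naive_transpose(data, shape, axis):
    out_shape = [shape[a] for a in axis]
    in_strides = strides_of(shape)
    out_strides = strides_of(out_shape)
    out = [None] * len(data)
    for flat in range(len(data)):
        # Decompose the input flat index into multi-dimensional coordinates.
        rem, coords = flat, []
        for s in in_strides:
            coords.append(rem // s)
            rem %= s
        # Recompose the output flat index from the permuted coordinates.
        out_flat = sum(coords[axis[d]] * out_strides[d] for d in range(len(axis)))
        out[out_flat] = data[flat]
    return out, out_shape

# Transposing a 2x3 row-major buffer with axis (1, 0):
out, out_shape = naive_transpose(list(range(6)), [2, 3], [1, 0])
```

This scheme works for any rank, which is why the fallback has no upper dimension limit, only the switch over Eigen ranks does.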

namespace operators {

template <typename T>
__global__ void transpose_kernel(int nthreads, const T* in_data, T* out_data,
Contributor

I think this CUDA function can be renamed to NaiveCUDATranspose, and we may need some optimized implementation in the future.


protected:
void InferShape(const framework::InferShapeContext &ctx) const override {
auto in_dim = ctx.Input<Tensor>("X")->dims();
Contributor

The checks in InferShape need to be comprehensive; check the inputs with PADDLE_ENFORCE_NOT_NULL, for example: https://github.com/PaddlePaddle/Paddle/pull/4086/files#diff-1fcd5ee1c1e63ed40789a0e60fdb1bf6R29

for (size_t i = 0; i < axis.size(); i++) {
out_dim[i] = in_dim[axis[i]];
}
ctx.Output<Tensor>("Out")->Resize(out_dim);
Contributor

InferShape now needs to use LoDTensor for the output: Output<framework::LoDTensor>

Contributor Author

Done

PADDLE_ENFORCE_NOT_NULL(ctx.InputVar(framework::GradVarName("Out")),
"Input(Out@GRAD) should not be null");
auto x_dims = ctx.Input<Tensor>("X")->dims();
auto *x_grad = ctx.Output<Tensor>(framework::GradVarName("X"));
Contributor

Output< framework::LoDTensor>

Contributor Author

Done

namespace operators {

template <typename T>
__global__ void transpose_kernel(int nthreads, const T* in_data, T* out_data,
Contributor

This kernel is too inefficient; consider Eigen::shuffle instead:

https://github.com/RLovelett/eigen/tree/master/unsupported/Eigen/CXX11/src/Tensor#-shuffleconst-shuffle-shuffle

That way both CPU and GPU are supported.

Contributor Author

I have deleted this kernel and now use the Eigen interface.

break;
case 5:
DoTranspose<Place, T, 5>(context, *in, *out, axis);
break;
Contributor

Rank up to 5 already covers most cases. I think NaiveCpuTranspose can be removed and Eigen::shuffle used directly; doesn't it support GPU as well? For rank > 5:

PADDLE_THROW("Tensors with rank at most 6 are supported");

Contributor Author

OK. If tensors with more than 5 dimensions are not supported, that works and saves a lot of code, including NaiveCpuTranspose and the GPU kernel code.

import numpy as np
from gradient_checker import GradientChecker
from op_test_util import OpTestMeta
from paddle.v2.framework.op import Operator
Contributor

Use the new unit-test framework.

Contributor Author

Done


self.check_grad(op, inputs, set(["X"]), "Out", max_relative_error=0.5)


Contributor

Test more cases: 2-D, 3-D, 4-D, and 5-D.

Contributor Author

Done
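As a plain-Python illustration (not the Paddle test framework) of sweeping 2-D through 5-D cases: store each element in a dict keyed by its coordinates, and check that a transpose merely permutes the key.

```python
from itertools import product

def transpose(t, axis):
    # A transpose over a coordinate-keyed dict just permutes each key.
    return {tuple(k[a] for a in axis): v for k, v in t.items()}

cases = [
    ((3, 4), (1, 0)),
    ((2, 3, 4), (1, 2, 0)),
    ((2, 3, 4, 5), (0, 2, 3, 1)),
    ((2, 3, 4, 5, 6), (4, 2, 3, 1, 0)),
]
for shape, axis in cases:
    t = {c: i for i, c in enumerate(product(*(range(n) for n in shape)))}
    out = transpose(t, axis)
    # Every element must land at its permuted coordinates.
    for c, v in t.items():
        assert out[tuple(c[a] for a in axis)] == v
```

The shape/axis pairs mirror the ranks the review asks to cover; the real test would feed the same pairs to the operator and compare against a reference.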

Contributor

@Xreki Xreki left a comment

I think it is mostly OK overall. The main thing left is to polish the comments on the inputs, outputs, and the Op itself, since documentation will be generated from them later.

auto axis = context.Attr<std::vector<int>>("axis");
int ndims = axis.size();
switch (ndims) {
case 1:
Contributor

Does this op support 1-D input? If it does, this branch should be a copy operation; if it does not, InferShape should check the rank with PADDLE_ENFORCE.

Contributor Author

This was a bug; it has been fixed.

protected:
void InferShape(const framework::InferShapeContext &ctx) const override {
PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("Input"),
"Input(Input) should not be null");
Contributor

The output also needs to be checked.

: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput(
"Input",
"(Tensor)The input tensor, tensors with rank at most 7 are supported");
Contributor

Only up to 6-D is supported, right?

Contributor Author

Right, tensors with more than 6 dimensions are not supported.

ctx.Input<Tensor>(framework::GradVarName("Output"))->dims();
auto output_dims = ctx.Input<Tensor>("Output")->dims();

PADDLE_ENFORCE(output_grad_dims == output_dims,
Contributor

PADDLE_ENFORCE_EQ can be used here.


switch (ndims) {
case 1:
break;
Contributor

Same as above.

Contributor Author

Done

PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("Input"),
"Input(Input) should not be null");
auto input_dim = ctx.Input<Tensor>("Input")->dims();
auto axis = ctx.Attr<std::vector<int>>("axis");
Collaborator

Writing std::vector<int> is not complex. Don't overuse auto; it may confuse readers about a variable's type.

Contributor Author

Done

"Input(Input) should not be null");
auto input_dim = ctx.Input<Tensor>("Input")->dims();
auto axis = ctx.Attr<std::vector<int>>("axis");
size_t input_dim_size = input_dim.size();
Collaborator

The size of a tensor's dimensions is called rank.

Contributor Author

Done

size_t axis_size = axis.size();

PADDLE_ENFORCE_EQ(input_dim_size, axis_size,
"the input tensor's dimension(%d) "
Collaborator

"the input tensor's rank(%d) "

PADDLE_ENFORCE_EQ(axis_sorted[i], static_cast<int>(i),
"the sorted axis should be [0, 1, ... dims - 1], "
"where the dims is the axis's size");
}
Collaborator

I think traversing axis and recording how many times each value occurs is shorter and faster than the current implementation:

std::vector<int> count(axis_size, 0);
for (size_t i = 0; i < axis.size(); i++) {
  PADDLE_ENFORCE(axis[i] < axis_size && ++count[axis[i]] == 1,
         "Attribute axis should be a permutation of [0, 1, ... dims - 1], "
         "where the dims is the axis's size");
}

Contributor Author

Done
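The counting-based check suggested above can be sketched in Python (a stand-in for the PADDLE_ENFORCE version, not actual Paddle code): each value in axis must lie in range and occur exactly once, which is checked in a single pass.

```python
def check_axis(axis):
    # Each value must be in [0, len(axis)) and occur exactly once.
    count = [0] * len(axis)
    for a in axis:
        if not (0 <= a < len(axis)) or count[a] != 0:
            raise ValueError(
                "Attribute axis should be a permutation of [0, 1, ... dims - 1]")
        count[a] += 1

check_axis([2, 0, 1])  # a valid permutation passes silently
```

Unlike the sort-based version, this is O(n) and needs no extra sorted copy of axis.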

AddAttr<std::vector<int>>(
"axis",
"(vector<int>)a list of values, and the size of the list should be "
"the same with the input tensor dimensions, the tensor will "
Collaborator

"the input tensor's rank"

ctx.Output<framework::LoDTensor>(framework::GradVarName("Input"));

auto output_grad_dims =
ctx.Input<Tensor>(framework::GradVarName("Output"))->dims();
Collaborator

Assert output_grad_dims is not nullptr.

PADDLE_ENFORCE(output_grad_dims == output_dims,
"Output@GRAD dims must equal to Input(Input) dims");

input_grad->Resize(input_dims);
Collaborator

input_grad can be nullptr, which means it is useless for backward, and we don't need to resize and compute it.

context.Input<framework::Tensor>(framework::GradVarName("Output"));
auto* input_grad =
context.Output<framework::Tensor>(framework::GradVarName("Input"));
input_grad->mutable_data<T>(context.GetPlace());
Collaborator

Check input_grad first. If it is nullptr, we don't need to do the following computation.

An example: https://github.com/PaddlePaddle/Paddle/blob/develop/paddle/operators/cos_sim_op.h#L104

Contributor Author

Done

std::vector<int> axis(axis_temp);

for (size_t i = 0; i < axis.size(); i++) {
axis[axis_temp[i]] = i;
Collaborator

How about rename axis_temp to axis and rename current axis to reversed_axis?

Contributor Author

good choice
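The inversion in the snippet above (reversed_axis[axis[i]] = i) produces the inverse permutation, which is what the gradient kernel relies on: transposing by it undoes the forward transpose. A quick Python check of that property:

```python
axis = [2, 0, 1]                   # forward permutation
reversed_axis = [0] * len(axis)
for i in range(len(axis)):
    reversed_axis[axis[i]] = i     # same assignment as the C++ loop

# Composing the permutation with its inverse gives the identity.
assert [axis[reversed_axis[i]] for i in range(len(axis))] == [0, 1, 2]
assert [reversed_axis[axis[i]] for i in range(len(axis))] == [0, 1, 2]
```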

class TestCase4(TestTransposeOp):
def initTestCase(self):
self.shape = (2, 3, 4, 5, 6, 1)
self.axis = (4, 2, 3, 1, 0, 5)
Collaborator

@JiayiFeng JiayiFeng Sep 19, 2017

Here should be another test about whether our Op can throw an exception correctly. However, our framework can't support such a test right now. So I leave a comment here to remind us there is something TODO. I have created an issue about this: #4173
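A sketch of the kind of negative test meant here, written with plain unittest against a hypothetical stand-in validator (validate_axis is not a Paddle function; the framework hook for expecting operator exceptions did not exist yet, see #4173):

```python
import unittest

def validate_axis(shape, axis):
    # Stand-in for the operator's InferShape check (hypothetical helper).
    if sorted(axis) != list(range(len(shape))):
        raise ValueError("axis must be a permutation of [0, rank)")

class TestInvalidAxis(unittest.TestCase):
    def test_bad_axis_raises(self):
        with self.assertRaises(ValueError):
            validate_axis((2, 3, 4), (0, 1))     # wrong length
        with self.assertRaises(ValueError):
            validate_axis((2, 3, 4), (0, 1, 1))  # duplicate axis

suite = unittest.defaultTestLoader.loadTestsFromTestCase(TestInvalidAxis)
result = unittest.TextTestRunner(verbosity=0).run(suite)
```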

JiayiFeng
JiayiFeng previously approved these changes Sep 19, 2017
Collaborator

@JiayiFeng JiayiFeng left a comment

LGTM. Thank you for your work!

AddInput(
"Input",
"(Tensor)The input tensor, tensors with rank at most 6 are supported");
AddOutput("Output", "(Tensor)The output tensor");
Collaborator

@JiayiFeng JiayiFeng Sep 19, 2017

The names of a single in/out Op's input and output should be X and Out respectively.

See the document: https://github.com/PaddlePaddle/Paddle/blob/develop/paddle/operators/name_convention.md

@NHZlX NHZlX merged commit c003895 into PaddlePaddle:develop Sep 21, 2017
@NHZlX NHZlX deleted the op_transpose branch September 21, 2017 02:44
heavengate pushed a commit to heavengate/Paddle that referenced this pull request Aug 16, 2021
* add faq

* Update README_cn.md

* Update FAQ-README.md

* Update FAQ第一期.md

* Rename FAQ-README.md to README.md

* Update README_cn.md

* Update FAQ第一期.md

* delete 2 files

* Delete .DS_Store

Successfully merging this pull request may close these issues.

Transpose Operator
5 participants