-
Notifications
You must be signed in to change notification settings - Fork 5.6k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
add the transpose op #3920
add the transpose op #3920
Changes from 12 commits
17b4b98
d6651b9
828008e
5599182
4da89f2
61c7930
e129dcf
6b3ae01
5ede6fd
35967e8
9de45e1
a9a7ba3
0cd9b8c
1792e58
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,121 @@ | ||
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. | ||
|
||
Licensed under the Apache License, Version 2.0 (the "License"); | ||
you may not use this file except in compliance with the License. | ||
You may obtain a copy of the License at | ||
|
||
http://www.apache.org/licenses/LICENSE-2.0 | ||
|
||
Unless required by applicable law or agreed to in writing, software | ||
distributed under the License is distributed on an "AS IS" BASIS, | ||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
See the License for the specific language governing permissions and | ||
limitations under the License. */ | ||
|
||
#include "paddle/operators/transpose_op.h" | ||
|
||
namespace paddle { | ||
namespace operators { | ||
|
||
using framework::Tensor; | ||
|
||
class TransposeOp : public framework::OperatorWithKernel { | ||
public: | ||
using framework::OperatorWithKernel::OperatorWithKernel; | ||
|
||
protected: | ||
void InferShape(const framework::InferShapeContext &ctx) const override { | ||
PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("Input"), | ||
"Input(Input) should not be null"); | ||
PADDLE_ENFORCE_NOT_NULL(ctx.OutputVar("Output"), | ||
"Output(Output) should not be null"); | ||
auto input_dim = ctx.Input<Tensor>("Input")->dims(); | ||
std::vector<int> axis = ctx.Attr<std::vector<int>>("axis"); | ||
size_t input_rank = input_dim.size(); | ||
size_t axis_size = axis.size(); | ||
|
||
PADDLE_ENFORCE_EQ(input_rank, axis_size, | ||
"the input tensor's rank(%d) " | ||
"should be equal to the axis's size(%d)", | ||
input_rank, axis_size); | ||
|
||
std::vector<int> count(axis_size, 0); | ||
for (size_t i = 0; i < axis_size; i++) { | ||
PADDLE_ENFORCE( | ||
axis[i] < static_cast<int>(axis_size) && ++count[axis[i]] == 1, | ||
"Each element of Attribute axis should be a unique value " | ||
"range from 0 to (dims - 1), " | ||
"where the dims is the axis's size"); | ||
} | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think traversing std::vector<int> count(axis_size, 0);
for (size_t i = 0; i < axis.size(); i++) {
PADDLE_ENFORCE(axis[i] < axis_size && ++count[axis[i]] == 1,
"Attribute axis should be a permutation of [0, 1, ... dims - 1], "
"where the dims is the axis's size");
} There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Done |
||
|
||
framework::DDim output_dim(input_dim); | ||
for (size_t i = 0; i < axis_size; i++) { | ||
output_dim[i] = input_dim[axis[i]]; | ||
} | ||
ctx.Output<framework::LoDTensor>("Output")->Resize(output_dim); | ||
} | ||
}; | ||
|
||
class TransposeOpMaker : public framework::OpProtoAndCheckerMaker { | ||
public: | ||
TransposeOpMaker(framework::OpProto *proto, | ||
framework::OpAttrChecker *op_checker) | ||
: OpProtoAndCheckerMaker(proto, op_checker) { | ||
AddInput( | ||
"Input", | ||
"(Tensor)The input tensor, tensors with rank at most 6 are supported"); | ||
AddOutput("Output", "(Tensor)The output tensor"); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The names of a single in/out Op's input and output should be See the document: https://github.com/PaddlePaddle/Paddle/blob/develop/paddle/operators/name_convention.md |
||
AddAttr<std::vector<int>>( | ||
"axis", | ||
"(vector<int>)a list of values, and the size of the list should be " | ||
"the same with the input tensor rank, the tensor will " | ||
"permute the axes according the the values given"); | ||
AddComment(R"DOC( | ||
The Tensor will be permuted according to the axis values given. | ||
The op is very much like the numpy.transpose function in python | ||
For example: | ||
>> input = numpy.arange(6).reshape((2,3)) | ||
>> input | ||
array([[0, 1, 2], | ||
[3, 4, 5]]) | ||
>> axis = [1, 0] | ||
>> output = input.transpose(axis) | ||
>> output | ||
array([[0, 3], | ||
[1, 4], | ||
[2, 5]]) | ||
So, given a input tensor of shape(N, C, H, W) and the axis is {0, 2, 3, 1}, | ||
the output tensor shape will be (N, H, W, C) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Need more detail about the op. We'd better write a common equation here, then the example. |
||
)DOC"); | ||
} | ||
}; | ||
|
||
class TransposeOpGrad : public framework::OperatorWithKernel { | ||
public: | ||
using framework::OperatorWithKernel::OperatorWithKernel; | ||
|
||
protected: | ||
void InferShape(const framework::InferShapeContext &ctx) const override { | ||
PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("Input"), | ||
"Input(Input) should not be null"); | ||
PADDLE_ENFORCE_NOT_NULL(ctx.InputVar(framework::GradVarName("Output")), | ||
"Input(Output@GRAD) should not be null"); | ||
auto input_dim = ctx.Input<Tensor>("Input")->dims(); | ||
auto *input_grad = | ||
ctx.Output<framework::LoDTensor>(framework::GradVarName("Input")); | ||
|
||
if (input_grad) input_grad->Resize(input_dim); | ||
} | ||
}; | ||
|
||
} // namespace operators | ||
} // namespace paddle | ||
|
||
namespace ops = paddle::operators; | ||
REGISTER_OP(transpose, ops::TransposeOp, ops::TransposeOpMaker, transpose_grad, | ||
ops::TransposeOpGrad); | ||
REGISTER_OP_CPU_KERNEL(transpose, | ||
ops::TransposeKernel<paddle::platform::CPUPlace, float>); | ||
REGISTER_OP_CPU_KERNEL( | ||
transpose_grad, | ||
ops::TransposeGradKernel<paddle::platform::CPUPlace, float>); |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,22 @@ | ||
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. | ||
|
||
Licensed under the Apache License, Version 2.0 (the "License"); | ||
you may not use this file except in compliance with the License. | ||
You may obtain a copy of the License at | ||
|
||
http://www.apache.org/licenses/LICENSE-2.0 | ||
|
||
Unless required by applicable law or agreed to in writing, software | ||
distributed under the License is distributed on an "AS IS" BASIS, | ||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
See the License for the specific language governing permissions and | ||
limitations under the License. */ | ||
|
||
#include "paddle/operators/transpose_op.h" | ||
|
||
namespace ops = paddle::operators; | ||
REGISTER_OP_GPU_KERNEL(transpose, | ||
ops::TransposeKernel<paddle::platform::GPUPlace, float>); | ||
REGISTER_OP_GPU_KERNEL( | ||
transpose_grad, | ||
ops::TransposeGradKernel<paddle::platform::GPUPlace, float>); |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,128 @@ | ||
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. | ||
|
||
Licensed under the Apache License, Version 2.0 (the "License"); | ||
you may not use this file except in compliance with the License. | ||
You may obtain a copy of the License at | ||
|
||
http://www.apache.org/licenses/LICENSE-2.0 | ||
|
||
Unless required by applicable law or agreed to in writing, software | ||
distributed under the License is distributed on an "AS IS" BASIS, | ||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
See the License for the specific language governing permissions and | ||
limitations under the License. */ | ||
|
||
#pragma once | ||
|
||
#include "paddle/framework/eigen.h" | ||
#include "paddle/framework/op_registry.h" | ||
|
||
namespace paddle { | ||
namespace operators { | ||
|
||
template <typename Place, typename T, int Rank> | ||
void EigenTranspose(const framework::ExecutionContext& context, | ||
const framework::Tensor& in, framework::Tensor& out, | ||
std::vector<int> axis) { | ||
Eigen::array<int, Rank> permute; | ||
for (int i = 0; i < Rank; i++) { | ||
permute[i] = axis[i]; | ||
} | ||
auto in_dim = in.dims(); | ||
auto out_dim = out.dims(); | ||
|
||
auto eigen_in = framework::EigenTensor<T, Rank>::From(in); | ||
auto eigen_out = framework::EigenTensor<T, Rank>::From(out); | ||
auto& dev = context.GetEigenDevice<Place>(); | ||
eigen_out.device(dev) = eigen_in.shuffle(permute); | ||
} | ||
|
||
template <typename Place, typename T> | ||
class TransposeKernel : public framework::OpKernel { | ||
public: | ||
void Compute(const framework::ExecutionContext& context) const override { | ||
auto* input = context.Input<framework::Tensor>("Input"); | ||
auto* output = context.Output<framework::Tensor>("Output"); | ||
output->mutable_data<T>(context.GetPlace()); | ||
|
||
std::vector<int> axis = context.Attr<std::vector<int>>("axis"); | ||
int ndims = axis.size(); | ||
switch (ndims) { | ||
case 1: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 这个op是否支持1-D的输入呢?如果支持,这里应该是 There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 这是个bug, 已修复 |
||
EigenTranspose<Place, T, 1>(context, *input, *output, axis); | ||
break; | ||
case 2: | ||
EigenTranspose<Place, T, 2>(context, *input, *output, axis); | ||
break; | ||
case 3: | ||
EigenTranspose<Place, T, 3>(context, *input, *output, axis); | ||
break; | ||
case 4: | ||
EigenTranspose<Place, T, 4>(context, *input, *output, axis); | ||
break; | ||
case 5: | ||
EigenTranspose<Place, T, 5>(context, *input, *output, axis); | ||
break; | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 感觉rank 5也不差多了,觉得可以去掉NaiveCpuTranspose, 直接用Eigen::shuffle, 这个不支持GPU吗? rank > 5 时: PADDLE_THROW("Tensors with rank at most 6 are supported").
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. ok,如果对与维度大于5的不支持的话,是可以的,省了很多代码,包括NaiveCpuTranspose 以及 Gpu kernel的代码 |
||
case 6: | ||
EigenTranspose<Place, T, 6>(context, *input, *output, axis); | ||
break; | ||
default: | ||
PADDLE_THROW("Tensors with rank at most 6 are supported"); | ||
} | ||
} | ||
}; | ||
|
||
template <typename Place, typename T> | ||
class TransposeGradKernel : public framework::OpKernel { | ||
public: | ||
void Compute(const framework::ExecutionContext& context) const override { | ||
auto* output_grad = | ||
context.Input<framework::Tensor>(framework::GradVarName("Output")); | ||
auto* input_grad = | ||
context.Output<framework::Tensor>(framework::GradVarName("Input")); | ||
if (input_grad) { | ||
input_grad->mutable_data<T>(context.GetPlace()); | ||
|
||
std::vector<int> axis = context.Attr<std::vector<int>>("axis"); | ||
std::vector<int> reversed_axis(axis); | ||
|
||
for (size_t i = 0; i < axis.size(); i++) { | ||
reversed_axis[axis[i]] = i; | ||
} | ||
|
||
int ndims = axis.size(); | ||
|
||
switch (ndims) { | ||
case 1: | ||
EigenTranspose<Place, T, 1>(context, *output_grad, *input_grad, | ||
reversed_axis); | ||
break; | ||
case 2: | ||
EigenTranspose<Place, T, 2>(context, *output_grad, *input_grad, | ||
reversed_axis); | ||
break; | ||
case 3: | ||
EigenTranspose<Place, T, 3>(context, *output_grad, *input_grad, | ||
reversed_axis); | ||
break; | ||
case 4: | ||
EigenTranspose<Place, T, 4>(context, *output_grad, *input_grad, | ||
reversed_axis); | ||
break; | ||
case 5: | ||
EigenTranspose<Place, T, 5>(context, *output_grad, *input_grad, | ||
reversed_axis); | ||
break; | ||
case 6: | ||
EigenTranspose<Place, T, 6>(context, *output_grad, *input_grad, | ||
reversed_axis); | ||
break; | ||
default: | ||
PADDLE_THROW("Tensors with rank at most 6 are supported"); | ||
} | ||
} | ||
} | ||
}; | ||
|
||
} // namespace operators | ||
} // namespace paddle |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,56 @@ | ||
import unittest | ||
import numpy as np | ||
from op_test import OpTest | ||
|
||
|
||
class TestTransposeOp(OpTest): | ||
def setUp(self): | ||
self.initTestCase() | ||
self.op_type = "transpose" | ||
self.inputs = {'Input': np.random.random(self.shape).astype("float32")} | ||
self.attrs = {'axis': list(self.axis)} | ||
self.outputs = {'Output': self.inputs['Input'].transpose(self.axis)} | ||
|
||
def test_check_output(self): | ||
self.check_output() | ||
|
||
def test_check_grad(self): | ||
self.check_grad(['Input'], 'Output') | ||
|
||
def initTestCase(self): | ||
self.shape = (3, 4) | ||
self.axis = (1, 0) | ||
|
||
|
||
class TestCase0(TestTransposeOp): | ||
def initTestCase(self): | ||
self.shape = (3, ) | ||
self.axis = (0, ) | ||
|
||
|
||
class TestCase1(TestTransposeOp): | ||
def initTestCase(self): | ||
self.shape = (3, 4, 5) | ||
self.axis = (0, 2, 1) | ||
|
||
|
||
class TestCase2(TestTransposeOp): | ||
def initTestCase(self): | ||
self.shape = (2, 3, 4, 5) | ||
self.axis = (0, 2, 3, 1) | ||
|
||
|
||
class TestCase3(TestTransposeOp): | ||
def initTestCase(self): | ||
self.shape = (2, 3, 4, 5, 6) | ||
self.axis = (4, 2, 3, 1, 0) | ||
|
||
|
||
class TestCase4(TestTransposeOp): | ||
def initTestCase(self): | ||
self.shape = (2, 3, 4, 5, 6, 1) | ||
self.axis = (4, 2, 3, 1, 0, 5) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Here should be another test about whether our Op can throw an exception correctly. However, our framework can't support such a test right now. So I leave a comment here to remind us there is something TODO. I have created an issue about this: #4173 |
||
|
||
|
||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 多测一些case: 2维 3维 4维 5维 There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Done |
||
if __name__ == '__main__': | ||
unittest.main() |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
对输出也需要检查