add the transpose op #3920

Merged, 14 commits, Sep 21, 2017
4 changes: 3 additions & 1 deletion paddle/operators/CMakeLists.txt
@@ -51,14 +51,16 @@ list(REMOVE_ITEM GENERAL_OPS
minus_op
mul_op
recurrent_op
scale_op)
scale_op
transpose_op)

op_library(net_op SRCS net_op.cc)
op_library(minus_op SRCS minus_op.cc minus_op.cu DEPS scale_op)
op_library(mul_op SRCS mul_op.cc mul_op.cu DEPS math_function)
op_library(recurrent_op SRCS recurrent_op.cc rnn/recurrent_op_utils.cc
DEPS framework_proto tensor operator net_op)
op_library(scale_op SRCS scale_op.cc scale_op.cu DEPS net_op)
op_library(transpose_op SRCS transpose_op.cc transpose_op.cu DEPS paddle_memory device_context)

foreach(src ${GENERAL_OPS})
op_library(${src} SRCS ${src}.cc ${src}.cu)
106 changes: 106 additions & 0 deletions paddle/operators/transpose_op.cc
@@ -0,0 +1,106 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/operators/transpose_op.h"
#include <vector>
#include "paddle/framework/ddim.h"
Contributor: Lines 16 and 17 can be removed.


namespace paddle {
namespace operators {

using framework::Tensor;

class TransposeOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;

protected:
void InferShape(const framework::InferShapeContext &ctx) const override {
auto in_dim = ctx.Input<Tensor>("X")->dims();
Contributor: The checks in InferShape should be comprehensive; check the inputs with NOT_NULL, e.g.: https://github.com/PaddlePaddle/Paddle/pull/4086/files#diff-1fcd5ee1c1e63ed40789a0e60fdb1bf6R29

auto axis = ctx.GetAttr<std::vector<int>>("axis");
size_t in_dim_size = in_dim.size();
size_t axis_size = axis.size();
PADDLE_ENFORCE_EQ(
in_dim_size, axis_size,
"the input tensor dimensions should be equal to the axis size");
Contributor:
Print the values of in_dim_size and axis_size in the error message.
input tensor dimensions -> input tensor's dimension
axis size -> axis's size


std::vector<int> axis_sorted(axis);
std::sort(axis_sorted.begin(), axis_sorted.end());
for (size_t i = 0; i < axis_sorted.size(); i++) {
PADDLE_ENFORCE_EQ(axis_sorted[i], (int)i,
Contributor: Use static_cast<int> instead of the C-style cast.

"the sorted axis should be [0, 1, ... dims - 1], "
"the dims equals to the input tensor dimensions");
Contributor: where the dims is the axis's size.

}
Collaborator:
I think traversing axis and recording the number of times each value occurs is shorter and faster than the current implementation:

std::vector<int> count(axis_size, 0);
for (size_t i = 0; i < axis.size(); i++) {
  PADDLE_ENFORCE(axis[i] < axis_size && ++count[axis[i]] == 1,
         "Attribute axis should be a permutation of [0, 1, ... dims - 1], "
         "where the dims is the axis's size");
}

Contributor (author): Done
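For reference, the count-based check suggested above can be sketched as standalone Python (a hypothetical illustration, not the PR's C++ code): each value in `axis` must lie in `[0, ndims)` and occur exactly once.

```python
def is_permutation(axis):
    """Return True iff `axis` is a permutation of [0, 1, ..., len(axis) - 1].

    Mirrors the reviewer's single-pass suggestion: count occurrences of
    each value and fail on any out-of-range or repeated entry.
    """
    count = [0] * len(axis)
    for a in axis:
        if not (0 <= a < len(axis)) or count[a] == 1:
            return False
        count[a] += 1
    return True

print(is_permutation([0, 2, 3, 1]))  # True: valid transpose axis
print(is_permutation([0, 0, 1]))     # False: 0 occurs twice
print(is_permutation([1, 2, 3]))     # False: 3 is out of range, 0 missing
```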

//
framework::DDim out_dim(in_dim);
for (size_t i = 0; i < axis.size(); i++) {
out_dim[i] = in_dim[axis[i]];
}
ctx.Output<Tensor>("Out")->Resize(out_dim);
Contributor: InferShape now needs to use LoDTensor for the Output: Output<framework::LoDTensor>

Contributor (author): Done

}
};

class TransposeOpMaker : public framework::OpProtoAndCheckerMaker {
public:
TransposeOpMaker(framework::OpProto *proto,
framework::OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "The input of transpose op");
AddOutput("Out", "The output of transpose op");
AddAttr<std::vector<int>>(
"axis",
"a list of integers, and the num of integers should be "
"the same with the input tensor dimensions");
AddComment(R"DOC(
Transpose the input tensor.
For example, input tensor shape(N, C, H, W) and axis {0, 2, 3, 1},
the output tensor shape will be (N, H, W, C)
Contributor: Need more detail about the op. We'd better write a general equation here, then the example.

)DOC");
}
};
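The shape rule implied by the documented example can be sanity-checked with a short Python sketch (illustrative only, not part of the PR): the output shape is the input shape indexed by `axis`.

```python
def transposed_shape(in_shape, axis):
    # out_dim[i] = in_dim[axis[i]], exactly as in TransposeOp::InferShape.
    return tuple(in_shape[a] for a in axis)

# An input of shape (N, C, H, W) with axis {0, 2, 3, 1} becomes (N, H, W, C).
print(transposed_shape((8, 3, 32, 32), (0, 2, 3, 1)))  # (8, 32, 32, 3)
```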

class TransposeOpGrad : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;

protected:
void InferShape(const framework::InferShapeContext &ctx) const override {
PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X"), "Input(X) should not be null");
PADDLE_ENFORCE_NOT_NULL(ctx.InputVar(framework::GradVarName("Out")),
"Input(Out@GRAD) should not be null");
Contributor: Why is PADDLE_ENFORCE_NOT_NULL checked here but not in TransposeOp?

auto x_dims = ctx.Input<Tensor>("X")->dims();
auto *x_grad = ctx.Output<Tensor>(framework::GradVarName("X"));
Contributor: Output<framework::LoDTensor>

Contributor (author): Done


auto out_grad_dims =
ctx.Input<Tensor>(framework::GradVarName("Out"))->dims();
auto out_dims = ctx.Input<Tensor>("Out")->dims();

PADDLE_ENFORCE(out_grad_dims == out_dims,
"Out@GRAD dims must equal to Input(X) dims");
Contributor:
Use PADDLE_ENFORCE_EQ instead.
out_dims -> x_dims, or Input(X) -> Input(Out)?


x_grad->Resize(x_dims);
}
};

} // namespace operators
} // namespace paddle

namespace ops = paddle::operators;
REGISTER_OP(transpose, ops::TransposeOp, ops::TransposeOpMaker, transpose_grad,
ops::TransposeOpGrad);
REGISTER_OP_CPU_KERNEL(transpose,
ops::TransposeKernel<paddle::platform::CPUPlace, float>);
REGISTER_OP_CPU_KERNEL(
transpose_grad,
ops::TransposeGradKernel<paddle::platform::CPUPlace, float>);
123 changes: 123 additions & 0 deletions paddle/operators/transpose_op.cu
@@ -0,0 +1,123 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/memory/memcpy.h"
#include "paddle/memory/memory.h"
#include "paddle/operators/transpose_op.h"

namespace paddle {
namespace operators {

template <typename T>
__global__ void transpose_kernel(int nthreads, const T* in_data, T* out_data,
Contributor: I think this CUDA function can be renamed to NaiveCUDATranspose, and we may need an optimized implementation in the future.

Contributor: This kernel is too inefficient; consider Eigen::shuffle:

https://github.com/RLovelett/eigen/tree/master/unsupported/Eigen/CXX11/src/Tensor#-shuffleconst-shuffle-shuffle

That way both CPU and GPU are supported.

Contributor (author): Have deleted this kernel and used the Eigen interface.

int* offset_buffer, int ndims) {
int* in_offset = offset_buffer;
int* out_offset = offset_buffer + ndims;
int* axis = offset_buffer + ndims * 2;

int to_index = blockIdx.x * blockDim.x + threadIdx.x;

if (to_index < nthreads) {
int from_index = 0;
int temp = to_index;
for (size_t i = 0; i < ndims; i++) {
from_index += (temp / out_offset[i]) * in_offset[axis[i]];
temp = temp % out_offset[i];
}
out_data[to_index] = in_data[from_index];
}
}

template <typename T>
void TransposeCUDA(const framework::ExecutionContext& context,
const framework::Tensor& in, framework::Tensor& out,
std::vector<int> axis) {
auto* in_data = in.template data<T>();
auto* out_data = out.template mutable_data<T>(context.GetPlace());
auto in_dim = in.dims();
auto out_dim = out.dims();
auto data_size = product(in_dim);
size_t ndims = in_dim.size();
std::vector<int> in_offset(ndims, 1);
std::vector<int> out_offset(ndims, 1);
std::vector<int64_t> buffer_dim_shape(1, ndims * 3);

auto buffer_dims = framework::make_ddim(buffer_dim_shape);
framework::Tensor host_buffer;
platform::CPUPlace cpu_place;
platform::GPUPlace gpu_place;

int* host_buffer_data = host_buffer.mutable_data<int>(buffer_dims, cpu_place);

auto offset_buffer =
memory::Alloc(context.GetPlace(), ndims * 3 * sizeof(int));
Member: The interface is misused here: context.GetPlace() returns a Place, which is a boost::variant<CPUPlace, GPUPlace>, while Alloc needs to receive a concrete GPUPlace or CPUPlace.

Contributor (author): Done


for (int i = ndims - 2; i >= 0; i--) {
in_offset[i] = in_offset[i + 1] * in_dim[i + 1];
out_offset[i] = out_offset[i + 1] * out_dim[i + 1];
}

for (int i = 0; i < ndims; i++) {
host_buffer_data[i] = in_offset[i];
host_buffer_data[i + ndims] = out_offset[i];
host_buffer_data[i + ndims * 2] = axis[i];
}

memory::Copy(gpu_place, offset_buffer, cpu_place, host_buffer_data,
ndims * 3 * sizeof(int));
Member:
The corresponding CUDA stream needs to be specified when doing the memory copy:

template <typename DstPlace, typename SrcPlace>
void Copy(DstPlace, void* dst, SrcPlace, const void* src, size_t num,
          cudaStream_t stream);

The CUDA stream can be obtained from the context.

Contributor (author): Done

int block = 512;
int grid = (data_size + block - 1) / block;
transpose_kernel<T><<<grid, block>>>(data_size, in_data, out_data,
static_cast<int*>(offset_buffer), ndims);
memory::Free(gpu_place, offset_buffer);
}

template <typename T>
class TransposeCUDAKernel : public framework::OpKernel {
public:
void Compute(const framework::ExecutionContext& context) const override {
PADDLE_ENFORCE(platform::is_gpu_place(context.GetPlace()),
"It must use GPUPlace.");
auto* in = context.Input<framework::Tensor>("X");
auto* out = context.Output<framework::Tensor>("Out");
auto axis = context.GetAttr<std::vector<int>>("axis");
TransposeCUDA<T>(context, *in, *out, axis);
}
};

template <typename T>
class TransposeGradCUDAKernel : public framework::OpKernel {
public:
void Compute(const framework::ExecutionContext& context) const override {
PADDLE_ENFORCE(platform::is_gpu_place(context.GetPlace()),
"It must use GPUPlace.");
auto* in = context.Input<framework::Tensor>(framework::GradVarName("Out"));
auto* out = context.Output<framework::Tensor>(framework::GradVarName("X"));
auto axis_temp = context.GetAttr<std::vector<int>>("axis");

std::vector<int> axis(axis_temp);

for (size_t i = 0; i < axis.size(); i++) {
axis[axis_temp[i]] = i;
}
TransposeCUDA<T>(context, *in, *out, axis);
}
};

} // namespace operators
} // namespace paddle

namespace ops = paddle::operators;
REGISTER_OP_GPU_KERNEL(transpose, ops::TransposeCUDAKernel<float>);
REGISTER_OP_GPU_KERNEL(transpose_grad, ops::TransposeGradCUDAKernel<float>);
141 changes: 141 additions & 0 deletions paddle/operators/transpose_op.h
@@ -0,0 +1,141 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#include <iostream>
#include "paddle/framework/eigen.h"
#include "paddle/framework/op_registry.h"
#include "paddle/operators/math/math_function.h"

namespace paddle {
namespace operators {

template <typename Place, typename T>
void NaiveCpuTranspose(const framework::ExecutionContext& context,
const framework::Tensor& in, framework::Tensor& out,
std::vector<int> axis) {
auto in_data = in.data<T>();
auto out_data = out.mutable_data<T>(context.GetPlace());
auto in_dim = in.dims();
auto out_dim = out.dims();
size_t ndims = in_dim.size();

std::vector<int> in_offset(ndims, 1);
std::vector<int> out_offset(ndims, 1);

for (int i = ndims - 2; i >= 0; i--) {
in_offset[i] = in_offset[i + 1] * in_dim[i + 1];
out_offset[i] = out_offset[i + 1] * out_dim[i + 1];
}

size_t data_size = product(in_dim);

for (size_t to_index = 0; to_index < data_size; to_index++) {
int from_index = 0;
int temp = to_index;
for (size_t i = 0; i < ndims; i++) {
from_index += (temp / out_offset[i]) * in_offset[axis[i]];
temp = temp % out_offset[i];
}
out_data[to_index] = in_data[from_index];
}
}
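The stride arithmetic in NaiveCpuTranspose (and in the CUDA kernel above) can be checked against a reference transpose with a Python sketch. This is illustrative code under the same assumptions as the PR: row-major layout, with `in_offset`/`out_offset` as element strides; each flat output index is decomposed into coordinates and re-encoded with the permuted input strides.

```python
import numpy as np

def naive_transpose(in_data, in_shape, axis):
    """Offset-based transpose mirroring NaiveCpuTranspose."""
    ndims = len(in_shape)
    out_shape = [in_shape[a] for a in axis]
    in_offset = [1] * ndims   # row-major strides of the input
    out_offset = [1] * ndims  # row-major strides of the output
    for i in range(ndims - 2, -1, -1):
        in_offset[i] = in_offset[i + 1] * in_shape[i + 1]
        out_offset[i] = out_offset[i + 1] * out_shape[i + 1]
    out_data = [0] * len(in_data)
    for to_index in range(len(in_data)):
        from_index, temp = 0, to_index
        for i in range(ndims):
            # Output dimension i corresponds to input dimension axis[i].
            from_index += (temp // out_offset[i]) * in_offset[axis[i]]
            temp %= out_offset[i]
        out_data[to_index] = in_data[from_index]
    return out_data, out_shape

x = np.arange(24).reshape(2, 3, 4)
out, shape = naive_transpose(x.ravel().tolist(), x.shape, (2, 0, 1))
assert shape == [4, 2, 3]
assert out == x.transpose(2, 0, 1).ravel().tolist()
```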

template <typename Place, typename T, int Dims>
void DoTranspose(const framework::ExecutionContext& context,
Contributor: Since this function calls Eigen to do the transpose, how about renaming it to EigenCpuTranspose?

Contributor (author): Sounds great.

const framework::Tensor& in, framework::Tensor& out,
std::vector<int> axis) {
Eigen::array<int, Dims> permute;
for (int i = 0; i < Dims; i++) {
permute[i] = axis[i];
}
auto in_dim = in.dims();
auto out_dim = out.dims();

auto eigen_in = framework::EigenTensor<T, Dims>::From(in);
auto eigen_out = framework::EigenTensor<T, Dims>::From(out);
auto& dev = context.GetEigenDevice<Place>();
eigen_out.device(dev) = eigen_in.shuffle(permute);
}
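Eigen's `shuffle` used here has the same contract as NumPy's `transpose`: output dimension `i` takes input dimension `permute[i]`. A quick NumPy sketch of that contract (not part of the PR):

```python
import numpy as np

x = np.arange(24).reshape(2, 3, 4)   # a small rank-3 tensor
axis = (0, 2, 1)                     # swap the last two dimensions
y = x.transpose(axis)                # y.shape[i] == x.shape[axis[i]]
print(y.shape)                       # (2, 4, 3)
# An output coordinate maps back to the input with the permuted indices.
assert y[1, 3, 2] == x[1, 2, 3]
```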

template <typename Place, typename T>
class TransposeKernel : public framework::OpKernel {
public:
void Compute(const framework::ExecutionContext& context) const override {
auto* in = context.Input<framework::Tensor>("X");
auto* out = context.Output<framework::Tensor>("Out");
out->mutable_data<T>(context.GetPlace());

auto axis = context.GetAttr<std::vector<int>>("axis");
int ndims = axis.size();
switch (ndims) {
case 2:
DoTranspose<Place, T, 2>(context, *in, *out, axis);
break;
case 3:
DoTranspose<Place, T, 3>(context, *in, *out, axis);
break;
case 4:
DoTranspose<Place, T, 4>(context, *in, *out, axis);
break;
case 5:
DoTranspose<Place, T, 5>(context, *in, *out, axis);
break;
Contributor:
Rank 5 already seems close to enough; I think NaiveCpuTranspose can be removed and Eigen::shuffle used directly. Doesn't it support GPU as well? For rank > 5:

PADDLE_THROW("Tensors with rank at most 6 are supported").

Contributor (author): OK. If dimensions greater than 5 are not supported, that works and saves a lot of code, including NaiveCpuTranspose and the GPU kernel code.

default:
Contributor: When ndims is 1, calling NaiveCpuTranspose?

Contributor (author): When ndims > 5, calling NaiveCpuTranspose.

NaiveCpuTranspose<Place, T>(context, *in, *out, axis);
break;
}
}
};

template <typename Place, typename T>
class TransposeGradKernel : public framework::OpKernel {
public:
void Compute(const framework::ExecutionContext& context) const override {
auto* in = context.Input<framework::Tensor>(framework::GradVarName("Out"));
auto* out = context.Output<framework::Tensor>(framework::GradVarName("X"));
Contributor: in -> in_grad, out -> out_grad

out->mutable_data<T>(context.GetPlace());

auto axis_temp = context.GetAttr<std::vector<int>>("axis");
std::vector<int> axis(axis_temp);
Contributor: How about: axis_temp -> axis and axis -> axis_grad?

Contributor (author): There is no grad for axis.


for (size_t i = 0; i < axis.size(); i++) {
axis[axis_temp[i]] = i;
Collaborator: How about renaming axis_temp to axis and the current axis to reversed_axis?

Contributor (author): Good choice.

}

int ndims = axis.size();

switch (ndims) {
case 2:
DoTranspose<Place, T, 2>(context, *in, *out, axis);
break;
case 3:
DoTranspose<Place, T, 3>(context, *in, *out, axis);
break;
case 4:
DoTranspose<Place, T, 4>(context, *in, *out, axis);
break;
case 5:
DoTranspose<Place, T, 5>(context, *in, *out, axis);
break;
default:
NaiveCpuTranspose<Place, T>(context, *in, *out, axis);
break;
}
}
};
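The gradient kernels invert the permutation with `axis[axis_temp[i]] = i` and then reuse the forward transpose: applying a permutation followed by its inverse is the identity, so the gradient flows back into the input's layout. A Python sketch of that round trip (illustrative, assumes NumPy):

```python
import numpy as np

def inverse_axis(axis):
    # reversed_axis[axis[i]] = i, as computed in TransposeGradKernel.
    inv = [0] * len(axis)
    for i, a in enumerate(axis):
        inv[a] = i
    return inv

axis = [0, 2, 3, 1]
inv = inverse_axis(axis)
print(inv)  # [0, 3, 1, 2]

# Transposing by axis and then by its inverse recovers the original tensor,
# which is why the backward pass can reuse the forward transpose code.
x = np.arange(120).reshape(2, 3, 4, 5)
assert (x.transpose(axis).transpose(inv) == x).all()
```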

} // namespace operators
} // namespace paddle
1 change: 1 addition & 0 deletions paddle/pybind/pybind.cc
@@ -49,6 +49,7 @@ USE_OP(minus);
USE_OP(cos_sim);
USE_CPU_ONLY_OP(gather);
USE_CPU_ONLY_OP(scatter);
USE_OP(transpose);

namespace paddle {
namespace framework {