Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make network op #3067

Merged
merged 4 commits into from
Jul 27, 2017
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 1 addition & 3 deletions paddle/framework/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,5 @@ py_proto_compile(framework_py_proto SRCS attr_type.proto op_proto.proto op_desc.
add_custom_target(framework_py_proto_init ALL COMMAND ${CMAKE_COMMAND} -E touch __init__.py)
add_dependencies(framework_py_proto framework_py_proto_init)

proto_library(net_proto SRCS net_proto.proto DEPS op_proto)
# cc_library(net SRCS net.cc DEPS operator net_proto op_registry fc_op)
cc_library(net SRCS net.cc DEPS operator net_proto op_registry)
cc_library(net SRCS net.cc DEPS op_registry)
cc_test(net_op_test SRCS net_op_test.cc DEPS net add_op mul_op sigmoid_op softmax_op fc_op)
16 changes: 4 additions & 12 deletions paddle/framework/net.cc
Original file line number Diff line number Diff line change
Expand Up @@ -20,17 +20,7 @@
namespace paddle {
namespace framework {

std::shared_ptr<PlainNet> AddBackwardOp(std::shared_ptr<PlainNet> ForwardOps) {
auto grad_ops = std::make_shared<PlainNet>();
for (auto& op : ForwardOps->ops_) {
auto op_grad = OpRegistry::CreateGradOp(op);
grad_ops->AddOp(op_grad);
}
grad_ops->CompleteAddOp();
return grad_ops;
}

void PlainNet::CompleteAddOp(bool calc) {
void NetOp::CompleteAddOp(bool calc) {
add_op_done_ = true;
if (!calc) return;
std::unordered_set<std::string> input_set;
Expand Down Expand Up @@ -70,7 +60,7 @@ void PlainNet::CompleteAddOp(bool calc) {
attrs_["temporary_index"] = tmp_index;
}

std::string PlainNet::DebugString() const {
std::string NetOp::DebugString() const {
std::ostringstream os;
os << OperatorBase::DebugString() << std::endl;
for (auto& op : ops_) {
Expand All @@ -82,5 +72,7 @@ std::string PlainNet::DebugString() const {
return os.str();
}

bool NetOp::IsNetOp() const { return true; }

} // namespace framework
} // namespace paddle
24 changes: 5 additions & 19 deletions paddle/framework/net.h
Original file line number Diff line number Diff line change
Expand Up @@ -37,21 +37,7 @@ namespace framework {
* This is the base class of network, all the networks should implement the APIs
* it defines.
*/
class Net : public OperatorBase {
public:
virtual void AddOp(const std::shared_ptr<OperatorBase>& op) = 0;
virtual void CompleteAddOp(bool calc) = 0;
};

using NetPtr = std::shared_ptr<Net>;

/**
* @brief a basic implementation of Net.
*
* PlainNet is a very simple Net, it create a list of operators, and run them
* sequentially following the order they added.
*/
class PlainNet : public Net {
class NetOp : public OperatorBase {
public:
/**
* Infer all the operators' input and output variables' shapes, will be called
Expand Down Expand Up @@ -80,15 +66,17 @@ class PlainNet : public Net {
/**
* @brief Add an operator by ptr
*/
void AddOp(const std::shared_ptr<OperatorBase>& op) override {
void AddOp(const std::shared_ptr<OperatorBase>& op) {
PADDLE_ENFORCE(!add_op_done_, "Cannot AddOp when this network is sealed");
ops_.push_back(op);
}

void CompleteAddOp(bool calculate = true) override;
void CompleteAddOp(bool calculate = true);

std::string DebugString() const override;

bool IsNetOp() const override;

std::vector<std::shared_ptr<OperatorBase>> ops_;

private:
Expand All @@ -100,7 +88,5 @@ class PlainNet : public Net {
}
};

std::shared_ptr<PlainNet> AddBackwardOp(std::shared_ptr<PlainNet> ForwardOps);

} // namespace framework
} // namespace paddle
37 changes: 15 additions & 22 deletions paddle/framework/net_op_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ void AssertSameVectorWithoutOrder(const std::vector<T>& expected,
}

TEST(OpKernel, all) {
auto net = std::make_shared<PlainNet>();
auto net = std::make_shared<NetOp>();
ASSERT_NE(net, nullptr);

auto op1 = std::make_shared<TestOp>();
Expand Down Expand Up @@ -71,28 +71,21 @@ TEST(OpKernel, all) {
ASSERT_EQ(2, run_cnt);
ASSERT_THROW(net->AddOp(op2), paddle::platform::EnforceNotMet);
}
TEST(AddBackwardOp, TestGradOp) {
auto net = std::make_shared<PlainNet>();
ASSERT_NE(net, nullptr);
net->AddOp(framework::OpRegistry::CreateOp("mul", {"X", "Y"}, {"Out"}, {}));
net->AddOp(
framework::OpRegistry::CreateOp("add_two", {"X", "Y"}, {"Out"}, {}));
net->AddOp(framework::OpRegistry::CreateOp("add_two", {"X", "Y"}, {""}, {}));
auto grad_ops = AddBackwardOp(net);
for (auto& op : grad_ops->ops_) {
op->DebugString();
}
}

// TODO(zhihong): add fc grad without registering.
// TEST(AddBackwardOp, TestNoGradOp) {
// auto net = std::make_shared<PlainNet>();
// ASSERT_NE(net, nullptr);
// net->AddOp(framework::OpRegistry::CreateOp("fc", {"X", "W", "b"}, {"Y"},
// {})); auto grad_ops = AddBackwardOp(net); for (auto& op : grad_ops->ops_) {
// op->DebugString();
// }
// }
//! TODO(yuyang18): Refine Backward Op.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

should remove this in the future

// TEST(AddBackwardOp, TestGradOp) {
// auto net = std::make_shared<NetOp>();
// ASSERT_NE(net, nullptr);
// net->AddOp(framework::OpRegistry::CreateOp("mul", {"X", "Y"}, {"Out"}, {}));
// net->AddOp(
// framework::OpRegistry::CreateOp("add_two", {"X", "Y"}, {"Out"}, {}));
// net->AddOp(framework::OpRegistry::CreateOp("add_two", {"X", "Y"}, {""},
// {}));
// auto grad_ops = AddBackwardOp(net);
// for (auto& op : grad_ops->ops_) {
// op->DebugString();
// }
//}

} // namespace framework
} // namespace paddle
15 changes: 0 additions & 15 deletions paddle/framework/net_proto.proto

This file was deleted.

14 changes: 8 additions & 6 deletions paddle/framework/operator.h
Original file line number Diff line number Diff line change
Expand Up @@ -90,15 +90,17 @@ class OperatorBase {
virtual void Run(const std::shared_ptr<Scope>& scope,
const platform::DeviceContext& dev_ctx) const = 0;

// Get a input with argument's name described in `op_proto`
virtual bool IsNetOp() const { return false; }

//! Get a input with argument's name described in `op_proto`
const std::string& Input(const std::string& name) const;
// Get a input which has multiple variables.
// TODO add a vector_view to prevent memory copy.
//! Get a input which has multiple variables.
//! TODO add a vector_view to prevent memory copy.
std::vector<std::string> Inputs(const std::string& name) const;
// Get a output with argument's name described in `op_proto`
//! Get a output with argument's name described in `op_proto`
const std::string& Output(const std::string& name) const;
// Get an output which has multiple variables.
// TODO add a vector_view to prevent memory copy.
//! Get an output which has multiple variables.
//! TODO add a vector_view to prevent memory copy.
std::vector<std::string> Outputs(const std::string& name) const;

public:
Expand Down
29 changes: 12 additions & 17 deletions paddle/operators/add_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -13,17 +13,14 @@ See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/operators/add_op.h"
#include "paddle/framework/op_registry.h"
#include "paddle/framework/tensor.h"

namespace paddle {
namespace operators {

class AddOp : public framework::OperatorWithKernel {
class AddOp : public OperatorWithKernel {
protected:
void InferShape(
const std::vector<const framework::Tensor *> &inputs,
const std::vector<framework::Tensor *> &outputs) const override {
void InferShape(const std::vector<const Tensor *> &inputs,
const std::vector<Tensor *> &outputs) const override {
PADDLE_ENFORCE(inputs.size() == 2, "Input size of AddOp must be two");
PADDLE_ENFORCE(outputs.size() == 1, "Output size of AddOp must be one");
PADDLE_ENFORCE(
Expand All @@ -35,10 +32,10 @@ class AddOp : public framework::OperatorWithKernel {
}
};

class AddOpMaker : public framework::OpProtoAndCheckerMaker {
class AddOpMaker : public OpProtoAndCheckerMaker {
public:
AddOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker)
: framework::OpProtoAndCheckerMaker(proto, op_checker) {
AddOpMaker(OpProto *proto, OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "The first input of add op");
AddInput("Y", "The second input of add op");
AddOutput("Out", "The output of add op");
Expand All @@ -50,11 +47,10 @@ The equation is: Out = X + Y
}
};

class AddOpGrad : public framework::OperatorWithKernel {
class AddOpGrad : public OperatorWithKernel {
protected:
void InferShape(
const std::vector<const framework::Tensor *> &inputs,
const std::vector<framework::Tensor *> &outputs) const override {}
void InferShape(const std::vector<const Tensor *> &inputs,
const std::vector<Tensor *> &outputs) const override {}
std::string DebugString() const override {
LOG(INFO) << "AddOpGrad";
return "";
Expand All @@ -64,7 +60,6 @@ class AddOpGrad : public framework::OperatorWithKernel {
} // namespace operators
} // namespace paddle

REGISTER_OP(add_two, paddle::operators::AddOp, paddle::operators::AddOpMaker);
REGISTER_GRADIENT_OP(add_two, add_two_grad, paddle::operators::AddOpGrad);
REGISTER_OP_CPU_KERNEL(
add_two, paddle::operators::AddKernel<paddle::platform::CPUPlace, float>);
REGISTER_OP(add_two, ops::AddOp, ops::AddOpMaker);
REGISTER_GRADIENT_OP(add_two, add_two_grad, ops::AddOpGrad);
REGISTER_OP_CPU_KERNEL(add_two, ops::AddKernel<ops::CPUPlace, float>);
5 changes: 2 additions & 3 deletions paddle/operators/add_op.cu
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
#include "paddle/operators/add_op.h"
#include "paddle/framework/op_registry.h"
#include "paddle/operators/add_op.h"

REGISTER_OP_GPU_KERNEL(add_two,
paddle::operators::AddKernel<paddle::platform::GPUPlace, float>);
REGISTER_OP_GPU_KERNEL(add_two, ops::AddKernel<ops::GPUPlace, float>);
19 changes: 8 additions & 11 deletions paddle/operators/add_op.h
Original file line number Diff line number Diff line change
Expand Up @@ -13,27 +13,24 @@ See the License for the specific language governing permissions and
limitations under the License. */

#pragma once
#include "glog/logging.h"
#include "paddle/framework/eigen.h"
#include "paddle/framework/operator.h"
#include "paddle/operators/type_alias.h"

namespace paddle {
namespace operators {

template <typename Place, typename T>
class AddKernel : public framework::OpKernel {
class AddKernel : public OpKernel {
public:
void Compute(const framework::KernelContext& context) const override {
auto input0 = context.Input(0)->Get<framework::Tensor>();
auto input1 = context.Input(1)->Get<framework::Tensor>();
auto* output = context.Output(0)->GetMutable<framework::Tensor>();
void Compute(const KernelContext& context) const override {
auto input0 = context.Input(0)->Get<Tensor>();
auto input1 = context.Input(1)->Get<Tensor>();
auto output = context.Output(0)->GetMutable<Tensor>();

output->mutable_data<T>(context.GetPlace());

framework::EigenVector<T>::Flatten(*output).device(
EigenVector<T>::Flatten(*output).device(
*(context.GetEigenDevice<Place>())) =
framework::EigenVector<T>::Flatten(input0) +
framework::EigenVector<T>::Flatten(input1);
EigenVector<T>::Flatten(input0) + EigenVector<T>::Flatten(input1);
}
};

Expand Down
28 changes: 11 additions & 17 deletions paddle/operators/cross_entropy_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -13,17 +13,14 @@ See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/operators/cross_entropy_op.h"
#include "paddle/framework/op_registry.h"
#include "paddle/framework/tensor.h"

namespace paddle {
namespace operators {

class OnehotCrossEntropyOp : public framework::OperatorWithKernel {
class OnehotCrossEntropyOp : public OperatorWithKernel {
protected:
void InferShape(
const std::vector<const framework::Tensor *> &inputs,
const std::vector<framework::Tensor *> &outputs) const override {
void InferShape(const std::vector<const Tensor *> &inputs,
const std::vector<Tensor *> &outputs) const override {
PADDLE_ENFORCE(inputs.size() == 2,
"Input size of OnehotCrossEntropyOp must be two");
PADDLE_ENFORCE(outputs.size() == 1,
Expand All @@ -35,15 +32,14 @@ class OnehotCrossEntropyOp : public framework::OperatorWithKernel {
PADDLE_ENFORCE(inputs[0]->dims().size() == 2, "X's dimension must be 2.");
PADDLE_ENFORCE(outputs[0]->dims().size() == 1,
"label's dimension must be 1.");
outputs[0]->Resize(framework::make_ddim({inputs[0]->dims()[0]}));
outputs[0]->Resize({inputs[0]->dims()[0]});
}
};

class OnehotCrossEntropyOpMaker : public framework::OpProtoAndCheckerMaker {
class OnehotCrossEntropyOpMaker : public OpProtoAndCheckerMaker {
public:
OnehotCrossEntropyOpMaker(framework::OpProto *proto,
framework::OpAttrChecker *op_checker)
: framework::OpProtoAndCheckerMaker(proto, op_checker) {
OnehotCrossEntropyOpMaker(OpProto *proto, OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "The first input of OnehotCrossEntropyOp");
AddInput("label", "The second input of OnehotCrossEntropyOp");
AddOutput("Y", "The output of OnehotCrossEntropyOp");
Expand All @@ -59,9 +55,7 @@ OnehotCrossEntropy Operator.
} // namespace paddle

REGISTER_OP(onehot_cross_entropy,
paddle::operators::OnehotCrossEntropyOp,
paddle::operators::OnehotCrossEntropyOpMaker);
REGISTER_OP_CPU_KERNEL(
onehot_cross_entropy,
paddle::operators::OnehotCrossEntropyOpKernel<::paddle::platform::CPUPlace,
float>);
ops::OnehotCrossEntropyOp,
ops::OnehotCrossEntropyOpMaker);
REGISTER_OP_CPU_KERNEL(onehot_cross_entropy,
ops::OnehotCrossEntropyOpKernel<ops::CPUPlace, float>);
4 changes: 1 addition & 3 deletions paddle/operators/cross_entropy_op.cu
Original file line number Diff line number Diff line change
@@ -1,6 +1,4 @@
#include "paddle/operators/cross_entropy_op.h"
#include "paddle/framework/op_registry.h"

REGISTER_OP_GPU_KERNEL(onehot_cross_entropy,
paddle::operators::OnehotCrossEntropyOpKernel<
::paddle::platform::GPUPlace, float>);
ops::OnehotCrossEntropyOpKernel<ops::GPUPlace, float>);
14 changes: 6 additions & 8 deletions paddle/operators/cross_entropy_op.h
Original file line number Diff line number Diff line change
Expand Up @@ -13,23 +13,21 @@ See the License for the specific language governing permissions and
limitations under the License. */

#pragma once
#include "glog/logging.h"
#include "paddle/framework/operator.h"
#include "paddle/operators/type_alias.h"

namespace paddle {
namespace operators {

template <typename Place, typename T>
class OnehotCrossEntropyOpKernel : public framework::OpKernel {
class OnehotCrossEntropyOpKernel : public OpKernel {
public:
constexpr T LOG_THRESHOLD() const { return static_cast<T>(1e-20); }

void Compute(const framework::KernelContext& context) const override {
auto X = context.Input(0)->Get<framework::Tensor>();
void Compute(const KernelContext& context) const override {
auto X = context.Input(0)->Get<Tensor>();
const T* X_data = X.data<T>();
const int* label_data =
context.Input(1)->Get<framework::Tensor>().data<int>();
auto* Y = context.Output(0)->GetMutable<framework::Tensor>();
const int* label_data = context.Input(1)->Get<Tensor>().data<int>();
auto* Y = context.Output(0)->GetMutable<Tensor>();

Y->mutable_data<T>(context.GetPlace());

Expand Down
Loading