Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Expose const value #6744

Merged
merged 4 commits into from
Dec 19, 2017
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion paddle/pybind/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
if(WITH_PYTHON)
cc_library(paddle_pybind SHARED
SRCS pybind.cc exception.cc protobuf.cc
SRCS pybind.cc exception.cc protobuf.cc const_value.cc
DEPS pybind python backward proto_desc paddle_memory executor prune init
${GLOB_OP_LIB})
endif(WITH_PYTHON)
Expand Down
29 changes: 29 additions & 0 deletions paddle/pybind/const_value.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "const_value.h"
#include "paddle/framework/operator.h"

namespace paddle {
namespace pybind {

// Expose selected framework-level name constants to Python.
// Each constant is registered as a zero-argument callable on the module
// (e.g. core.kEmptyVarName()) that returns the constant's string value.
void BindConstValue(pybind11::module& m) {
  // Register every constant through the same small helper so each one is
  // exposed identically: a nullary function returning the constant.
  auto expose = [&m](const char* py_name, const char* value) {
    m.def(py_name, [value] { return value; });
  };
  expose("kEmptyVarName", framework::kEmptyVarName);
  expose("kTempVarName", framework::kTempVarName);
  expose("kGradVarSuffix", framework::kGradVarSuffix);
  expose("kZeroVarSuffix", framework::kZeroVarSuffix);
}

} // namespace pybind
} // namespace paddle
26 changes: 26 additions & 0 deletions paddle/pybind/const_value.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once
#include <Python.h>
#include "paddle/platform/enforce.h"
#include "pybind11/pybind11.h"

// NOTE(review): a header-scope namespace alias leaks into every translation
// unit that includes this file; consider moving it into the .cc files that
// actually use `py`.
namespace py = pybind11;

namespace paddle {
namespace pybind {
// Registers the framework name constants (kEmptyVarName, kTempVarName,
// kGradVarSuffix, kZeroVarSuffix) as nullary functions on the given Python
// module. Defined in paddle/pybind/const_value.cc.
extern void BindConstValue(pybind11::module& m);
} // namespace pybind
} // namespace paddle
2 changes: 2 additions & 0 deletions paddle/pybind/pybind.cc
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ limitations under the License. */
#include "paddle/operators/net_op.h"
#include "paddle/platform/enforce.h"
#include "paddle/platform/place.h"
#include "paddle/pybind/const_value.h"
#include "paddle/pybind/exception.h"
#include "paddle/pybind/pybind.h"
#include "paddle/pybind/tensor_py.h"
Expand Down Expand Up @@ -431,6 +432,7 @@ All parameter, weight, gradient are variables in Paddle.
BindBlockDesc(m);
BindVarDsec(m);
BindOpDesc(m);
BindConstValue(m);

py::class_<framework::LoDRankTable>(m, "LodRankTable")
.def("items", [](framework::LoDRankTable &table) {
Expand Down
18 changes: 15 additions & 3 deletions python/paddle/v2/fluid/framework.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,29 @@
import collections
import contextlib

import numpy as np
from . import core

import proto.framework_pb2 as framework_pb2
import google.protobuf.message
import contextlib
from . import core

__all__ = [
'Block', 'Variable', 'Program', 'Operator', 'default_startup_program',
'default_main_program', 'program_guard', 'switch_startup_program',
'switch_main_program'
]

EMPTY_VAR_NAME = core.kEmptyVarName()
TEMP_VAR_NAME = core.kTempVarName()
GRAD_VAR_SUFFIX = core.kGradVarSuffix()
ZERO_VAR_SUFFIX = core.kZeroVarSuffix()


def grad_var_name(var_name):
    """Return the name of the gradient variable paired with ``var_name``.

    The framework-wide gradient suffix (``GRAD_VAR_SUFFIX``, e.g. ``"@GRAD"``)
    is appended to the original variable name.
    """
    return "%s%s" % (var_name, GRAD_VAR_SUFFIX)


def unique_name(prefix):
"""
Expand Down
5 changes: 1 addition & 4 deletions python/paddle/v2/fluid/tests/test_batch_norm_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,7 @@
from op_test import OpTest
import paddle.v2.fluid.core as core
from paddle.v2.fluid.op import Operator


def grad_var_name(var_name):
return var_name + "@GRAD"
from paddle.v2.fluid.framework import grad_var_name


def get_backward_op(scope, op, no_grad_set):
Expand Down
14 changes: 14 additions & 0 deletions python/paddle/v2/fluid/tests/test_const_value.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
import unittest
import paddle.v2.fluid.framework as framework


class TestConstValue(unittest.TestCase):
    # Renamed from the copy-pasted "ConditionalBlock": this case checks the
    # C++-side framework name constants exposed through pybind, not
    # conditional blocks.

    def test_const_value(self):
        """Each Python-side constant must mirror its C++ framework value."""
        # The original test asserted GRAD_VAR_SUFFIX twice and never covered
        # EMPTY_VAR_NAME, even though framework exposes it alongside the
        # others; the duplicate is replaced with the missing check.
        # NOTE(review): "@EMPTY@" follows the @TEMP@/@GRAD/@ZERO naming
        # pattern of kEmptyVarName — confirm against paddle/framework.
        self.assertEqual(framework.EMPTY_VAR_NAME, "@EMPTY@")
        self.assertEqual(framework.TEMP_VAR_NAME, "@TEMP@")
        self.assertEqual(framework.GRAD_VAR_SUFFIX, "@GRAD")
        self.assertEqual(framework.ZERO_VAR_SUFFIX, "@ZERO")


if __name__ == '__main__':
    unittest.main()
2 changes: 1 addition & 1 deletion python/paddle/v2/fluid/tests/test_operator.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import unittest

import paddle.v2.fluid.op as op
import paddle.v2.fluid.core as core
import paddle.v2.fluid.proto.framework_pb2 as framework_pb2


Expand Down
8 changes: 3 additions & 5 deletions python/paddle/v2/fluid/tests/test_program.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from __future__ import print_function
import unittest

from paddle.v2.fluid.framework import Program, default_main_program, program_guard
from paddle.v2.fluid.framework import Program, default_main_program, program_guard, grad_var_name
import paddle.v2.fluid.layers as layers

main_program = default_main_program()
Expand Down Expand Up @@ -109,12 +109,10 @@ def test_append_backward(self):
self.assertEqual(add_op.idx, 1)
param_to_grad = prog.append_backward(mean_out, set())

def grad_name(name):
return name + "@GRAD"

for var_name in ("mul.x", "mul.y", "mul.out", "add.y", "add.out",
"mean.out"):
self.assertEqual(param_to_grad[var_name][0], grad_name(var_name))
self.assertEqual(param_to_grad[var_name][0],
grad_var_name(var_name))
self.assertEqual(param_to_grad[var_name][1], 0)

expect_ops = [
Expand Down
4 changes: 2 additions & 2 deletions python/paddle/v2/fluid/tests/test_recurrent_op.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import unittest

import paddle.v2.fluid.layers as layers
from paddle.v2.fluid.framework import Program
from paddle.v2.fluid.framework import Program, grad_var_name
from paddle.v2.fluid.executor import Executor
from paddle.v2.fluid.backward import append_backward_ops
import numpy as np
Expand Down Expand Up @@ -164,7 +164,7 @@ def backward(self):
for x in self.data_field
}
fetch_list = [
self.main_program.global_block().var(x + "@GRAD")
self.main_program.global_block().var(grad_var_name(x))
for x in self.data_field
]

Expand Down