[Phi]Add mean/momentum yaml #41319

Merged
merged 14 commits on Apr 5, 2022
143 changes: 143 additions & 0 deletions paddle/phi/api/lib/api_custom_impl.cc
@@ -123,6 +123,149 @@ std::vector<Tensor> split_impl(const Tensor& x,
return out;
}

std::tuple<Tensor, Tensor, Tensor> momentum_impl(
const Tensor& param,
const Tensor& grad,
const Tensor& velocity,
const Tensor& learning_rate,
paddle::optional<const Tensor&> master_param,
float mu,
bool use_nesterov,
const std::string& regularization_method,
float regularization_coeff,
bool multi_precision,
float rescale_grad) {
Backend kernel_backend = Backend::UNDEFINED;
DataLayout kernel_layout = DataLayout::UNDEFINED;
DataType kernel_data_type = DataType::UNDEFINED;
if (kernel_backend == Backend::UNDEFINED ||
kernel_layout == DataLayout::UNDEFINED ||
kernel_data_type == DataType::UNDEFINED) {
auto kernel_key_set = ParseKernelKeyByInputArgs(param);
auto kernel_key = kernel_key_set.GetHighestPriorityKernelKey();
if (kernel_backend == Backend::UNDEFINED) {
kernel_backend = kernel_key.backend();
}
if (kernel_layout == DataLayout::UNDEFINED) {
kernel_layout = kernel_key.layout();
}
if (kernel_data_type == DataType::UNDEFINED) {
kernel_data_type = kernel_key.dtype();
}
}
std::string kernel_name = "momentum";
if (grad.is_selected_rows()) {
kernel_name = "momentum_dense_param_sparse_grad";
}
const auto& kernel = phi::KernelFactory::Instance().SelectKernelOrThrowError(
kernel_name, {kernel_backend, kernel_layout, kernel_data_type});
VLOG(6) << kernel_name << " API kernel key: [" << kernel_backend << ", "
<< kernel_layout << ", " << kernel_data_type << "]";
VLOG(6) << kernel_name << " API kernel: " << kernel;

auto* dev_ctx = GetDeviceContextByBackend(kernel_backend);

auto input_param = PrepareData(param, kernel.InputAt(0), {});
auto input_grad = PrepareData(grad, kernel.InputAt(1), {});
auto input_velocity = PrepareData(velocity, kernel.InputAt(2), {});
auto input_learning_rate = PrepareData(learning_rate, kernel.InputAt(3), {});
paddle::optional<const phi::DenseTensor&> input_master_param(paddle::none);
auto input_master_param_ptr =
PrepareData(master_param, kernel.InputAt(4), {});

std::tuple<Tensor, Tensor, Tensor> api_output;
auto kernel_out_0 = input_param.get();
auto kernel_out_1 = input_velocity.get();
phi::DenseTensor* kernel_out_2 = nullptr;
if (input_master_param_ptr) {
input_master_param =
paddle::make_optional<const phi::DenseTensor&>(*input_master_param_ptr);
kernel_out_2 =
paddle::make_optional<phi::DenseTensor&>(*input_master_param_ptr)
.get_ptr();
}

paddle::optional<const phi::MetaTensor&> input_meta_ref_master_param(
paddle::none);
phi::DenseTensor dt;
phi::MetaTensor input_meta_tmp_master_param(dt);
if (input_master_param_ptr) {
input_meta_tmp_master_param.set_dtype(input_master_param_ptr->dtype());
input_meta_tmp_master_param.set_dims(input_master_param_ptr->dims());
input_meta_tmp_master_param.set_layout(input_master_param_ptr->layout());
input_meta_ref_master_param = input_meta_tmp_master_param;
}
phi::MetaTensor meta_out_0(kernel_out_0);
phi::MetaTensor meta_out_1(kernel_out_1);
if (kernel_out_2) {
phi::MetaTensor meta_out_2(kernel_out_2);
phi::MomentumInferMeta(MakeMetaTensor(*input_param),
MakeMetaTensor(*input_grad),
MakeMetaTensor(*input_velocity),
MakeMetaTensor(*input_learning_rate),
input_meta_ref_master_param,
mu,
use_nesterov,
regularization_method,
regularization_coeff,
multi_precision,
rescale_grad,
&meta_out_0,
&meta_out_1,
&meta_out_2);
} else {
phi::MomentumInferMeta(MakeMetaTensor(*input_param),
MakeMetaTensor(*input_grad),
MakeMetaTensor(*input_velocity),
MakeMetaTensor(*input_learning_rate),
input_meta_ref_master_param,
mu,
use_nesterov,
regularization_method,
regularization_coeff,
multi_precision,
rescale_grad,
&meta_out_0,
&meta_out_1,
nullptr);
}

using kernel_signature = void (*)(const platform::DeviceContext&,
const phi::DenseTensor&,
const phi::DenseTensor&,
const phi::DenseTensor&,
const phi::DenseTensor&,
paddle::optional<const phi::DenseTensor&>,
float,
bool,
const std::string&,
float,
bool,
float,
phi::DenseTensor*,
phi::DenseTensor*,
phi::DenseTensor*);
auto* kernel_fn = kernel.GetVariadicKernelFn<kernel_signature>();

(*kernel_fn)(*dev_ctx,
*input_param,
*input_grad,
*input_velocity,
*input_learning_rate,
input_master_param,
mu,
use_nesterov,
regularization_method,
regularization_coeff,
multi_precision,
rescale_grad,
kernel_out_0,
kernel_out_1,
kernel_out_2);

return api_output;
}

////////////////// Backward(grad) api impls //////////////////////

// TODO(chenweihang): the original sum grad op can support higher-level
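The hand-written `momentum_impl` above reproduces what the yaml code generator would otherwise emit: it resolves the kernel key from `param`, switches to the `momentum_dense_param_sparse_grad` kernel when the gradient is a SelectedRows tensor, prepares the inputs, runs `MomentumInferMeta` to fill the output metadata, and finally invokes the selected kernel in place on `param` and `velocity` (plus `master_param` when multi-precision is enabled). As a rough illustration of what reaches this code path, here is a minimal Python sketch of the eager call whose argument order mirrors the C++ signature; it needs the new eager dygraph mode (for example inside `_test_eager_guard()`, as the tests below do), and the tensor values and hyper-parameters are illustrative rather than taken from the PR.

```python
# Illustrative sketch only: eager-mode call into the new final-state momentum op.
# Argument order mirrors momentum_impl(param, grad, velocity, learning_rate,
# master_param, mu, use_nesterov, regularization_method, regularization_coeff,
# multi_precision, rescale_grad).
import paddle
from paddle import _C_ops

param = paddle.randn([10, 10])
grad = paddle.randn([10, 10])
velocity = paddle.zeros([10, 10])
lr = paddle.full([1], 0.01)

param_out, velocity_out, master_out = _C_ops.final_state_momentum(
    param, grad, velocity, lr,
    None,        # master_param, only used when multi_precision=True
    0.9,         # mu
    False,       # use_nesterov
    'l2_decay',  # regularization_method
    1e-4,        # regularization_coeff
    False,       # multi_precision
    1.0)         # rescale_grad
```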
14 changes: 14 additions & 0 deletions paddle/phi/api/lib/api_custom_impl.h
@@ -18,6 +18,7 @@ limitations under the License. */
#include "paddle/phi/common/int_array.h"
#include "paddle/phi/common/place.h"
#include "paddle/phi/common/scalar.h"
#include "paddle/utils/optional.h"

namespace paddle {
namespace experimental {
@@ -33,6 +34,19 @@ std::vector<Tensor> split_impl(const Tensor& x,
const IntArray& num_or_sections,
const Scalar& axis);

std::tuple<Tensor, Tensor, Tensor> momentum_impl(
const Tensor& param,
const Tensor& grad,
const Tensor& velocity,
const Tensor& learning_rate,
paddle::optional<const Tensor&> master_param,
float mu,
bool use_nesterov,
const std::string& regularization_method,
float regularization_coeff,
bool multi_precision,
float rescale_grad);

////////////////// Backward(grad) api impls //////////////////////

std::vector<Tensor> add_n_grad_impl(const std::vector<Tensor>& x,
47 changes: 47 additions & 0 deletions paddle/phi/infermeta/multiary.cc
@@ -1504,6 +1504,53 @@ void MeshgridInferMeta(const std::vector<MetaTensor*>& inputs,
}
}

void MomentumInferMeta(const MetaTensor& param,
const MetaTensor& grad,
const MetaTensor& velocity,
const MetaTensor& learning_rate,
paddle::optional<const MetaTensor&> master_param,
float mu,
bool use_nesterov,
const std::string& regularization_method,
float regularization_coeff,
bool multi_precision,
float rescale_grad,
MetaTensor* param_out,
MetaTensor* velocity_out,
MetaTensor* master_param_out) {
PADDLE_ENFORCE_NE(
param_out,
nullptr,
errors::NotFound("Output(ParamOut) of Momentum should not be null."));
PADDLE_ENFORCE_NE(
velocity_out,
nullptr,
errors::NotFound("Output(VelocityOut) of Momentum should not be null."));

auto lr_dims = learning_rate.dims();
PADDLE_ENFORCE_NE(
phi::product(lr_dims),
0,
errors::InvalidArgument("Maybe the Input variable LearningRate has not "
"been initialized. You may need to confirm "
"if you put exe.run(startup_program) "
"after optimizer.minimize function."));
PADDLE_ENFORCE_EQ(
phi::product(lr_dims),
1,
errors::InvalidArgument("Learning_rate should be a scalar. But Received "
"LearningRate's dim [%s]",
phi::product(lr_dims)));

auto param_dim = param.dims();
param_out->set_dims(param_dim);
velocity_out->set_dims(param_dim);

if (master_param_out) {
master_param_out->set_dims(param_dim);
}
}

void MultiDotInferMeta(const std::vector<MetaTensor*>& x, MetaTensor* out) {
auto inputs_dims = GetMetaTensorsDim(x);

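`MomentumInferMeta` is purely a shape and metadata pass: it insists that `ParamOut` and `VelocityOut` are provided, that `LearningRate` is an initialized one-element tensor, and then gives every output the dims of `param`. A small hedged illustration of the learning-rate requirement (illustrative values, not from the PR):

```python
# Sketch: the learning rate fed to the momentum op must be a 1-element tensor.
import paddle

lr_ok = paddle.full([1], 0.01, dtype='float32')  # product of dims == 1, passes the check
lr_bad = paddle.zeros([2], dtype='float32')      # product of dims == 2; passing this as
                                                 # LearningRate would trip the
                                                 # "Learning_rate should be a scalar" error
```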
15 changes: 15 additions & 0 deletions paddle/phi/infermeta/multiary.h
@@ -230,6 +230,21 @@ void InterpolateInferMeta(
void MeshgridInferMeta(const std::vector<MetaTensor*>& inputs,
std::vector<MetaTensor*> outputs);

void MomentumInferMeta(const MetaTensor& param,
const MetaTensor& grad,
const MetaTensor& velocity,
const MetaTensor& learning_rate,
paddle::optional<const MetaTensor&> master_param,
float mu,
bool use_nesterov,
const std::string& regularization_method,
float regularization_coeff,
bool multi_precision,
float rescale_grad,
MetaTensor* param_out,
MetaTensor* velocity_out,
MetaTensor* master_param_out);

void MultiDotInferMeta(const std::vector<MetaTensor*>& x, MetaTensor* out);

void MultiplexInferMeta(const std::vector<MetaTensor*>& ins,
4 changes: 3 additions & 1 deletion python/paddle/fluid/layers/nn.py
@@ -12806,8 +12806,10 @@ def mean(x, name=None):
mean = fluid.layers.mean(input)
"""

if _non_static_mode():
if _in_legacy_dygraph():
return _C_ops.mean(x)
if in_dygraph_mode():
return _C_ops.final_state_mean_all(x)

helper = LayerHelper("mean", **locals())
check_variable_and_dtype(x, 'x', ['float16', 'float32', 'float64'], 'mean')
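The change to `fluid.layers.mean` splits the dynamic-graph fast path: legacy dygraph still dispatches to the old `_C_ops.mean`, while the new eager mode dispatches to `_C_ops.final_state_mean_all`, the yaml-generated final-state op this PR introduces. Either way the layer reduces the whole input to a single scalar mean; a minimal usage sketch (values illustrative, not from the PR):

```python
# Minimal dygraph usage sketch for fluid.layers.mean (mean over all elements).
import numpy as np
import paddle
import paddle.fluid as fluid

paddle.disable_static()
x = paddle.to_tensor(np.arange(6, dtype='float32').reshape(2, 3))
out = fluid.layers.mean(x)
print(float(out))  # 2.5
```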
10 changes: 5 additions & 5 deletions python/paddle/fluid/tests/unittests/test_mean_op.py
@@ -21,7 +21,7 @@
import paddle.fluid.core as core
import paddle.fluid as fluid
from paddle.fluid import Program, program_guard

from paddle.fluid.framework import _test_eager_guard
np.random.seed(10)


@@ -40,7 +40,7 @@ def reduce_mean_wrapper(x, axis=0, keepdim=False, reduce_all=False):
class TestMeanOp(OpTest):
def setUp(self):
self.op_type = "mean"
self.python_api = mean_wrapper
self.python_api = fluid.layers.mean
self.dtype = np.float64
self.init_dtype_type()
self.inputs = {'X': np.random.random((10, 10)).astype(self.dtype)}
@@ -81,7 +81,7 @@ def init_dtype_type(self):
def test_check_output(self):
place = core.CUDAPlace(0)
if core.is_float16_supported(place):
self.check_output_with_place(place)
self.check_output_with_place(place, check_eager=True)

def test_checkout_grad(self):
place = core.CUDAPlace(0)
@@ -104,11 +104,11 @@ def init_dtype_type(self):

def test_check_output(self):
paddle.enable_static()
self.check_output_with_place(core.CPUPlace())
self.check_output_with_place(core.CPUPlace(), check_eager=True)

def test_checkout_grad(self):
place = core.CPUPlace()
self.check_grad_with_place(place, ['X'], 'Out')
self.check_grad_with_place(place, ['X'], 'Out', check_eager=True)


def ref_reduce_mean(x, axis=None, keepdim=False, reduce_all=False):
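In the mean op tests, setting `python_api` and passing `check_eager=True` asks `OpTest` to additionally verify the operator through the eager final-state path; `_test_eager_guard` is imported so other dygraph cases can be re-run under eager mode the same way. A rough sketch of that pattern (the guard comes from the diff; the class and method names here are hypothetical):

```python
# Sketch of the eager re-check pattern used by the updated tests;
# SomeDygraphTest and test_some_dygraph_case are hypothetical names.
import unittest
from paddle.fluid.framework import _test_eager_guard

class SomeDygraphTest(unittest.TestCase):
    def test_some_dygraph_case(self):
        pass  # placeholder for an existing dygraph assertion

    def test_api_eager_dygraph(self):
        with _test_eager_guard():          # switch this block to eager dygraph
            self.test_some_dygraph_case()  # re-run the existing dygraph check
```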
10 changes: 10 additions & 0 deletions python/paddle/fluid/tests/unittests/test_momentum_op.py
@@ -22,6 +22,7 @@
import paddle
import paddle.fluid as fluid
import numpy
from paddle.fluid.framework import _test_eager_guard


def calculate_momentum_by_numpy(param,
Expand Down Expand Up @@ -528,6 +529,11 @@ def test_raise_error(self):
ValueError, paddle.optimizer.Momentum, learning_rate=None)
self.assertRaises(ValueError, paddle.optimizer.Momentum, momentum=None)

def test_api_eager_dygraph(self):
with _test_eager_guard():
self.test_momentum_dygraph()
self.test_raise_error()


class TestMomentumOpWithDecay(OpTest):
def setUp(self):
@@ -921,6 +927,10 @@ def test_main(self):
self._check_with_param_arrt(place, use_amp)
self._check_with_param_group(place, use_amp)

def test_api_eager_dygraph(self):
with _test_eager_guard():
self.test_main()


class TestMultiTensorMomentumStatic(unittest.TestCase):
def _momentum_optimize_static(self,
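The momentum tests compare the op outputs against `calculate_momentum_by_numpy`. For orientation, here is a hedged NumPy sketch of the update rule being checked, written from the standard momentum formulation with optional L2-decay regularization and Nesterov correction rather than copied from the test file:

```python
# Reference sketch of one momentum update step (NumPy).
import numpy as np

def momentum_step(param, grad, velocity, lr, mu,
                  use_nesterov=False,
                  regularization_method='', regularization_coeff=0.0):
    if regularization_method == 'l2_decay':
        grad = grad + regularization_coeff * param   # L2 term folded into the gradient
    velocity_out = mu * velocity + grad
    if use_nesterov:
        param_out = param - lr * (grad + mu * velocity_out)
    else:
        param_out = param - lr * velocity_out
    return param_out, velocity_out
```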
12 changes: 10 additions & 2 deletions python/paddle/optimizer/momentum.py
@@ -25,6 +25,7 @@
from paddle.fluid.regularizer import L2DecayRegularizer
from paddle import _C_ops
import paddle
from paddle.fluid.framework import in_dygraph_mode, _in_legacy_dygraph

__all__ = []

@@ -313,7 +314,7 @@ def _append_optimize_op(self, block, param_and_grad):
master_weight = (self._master_weights[param_and_grad[0].name]
if find_master else None)

if framework._non_static_mode():
if _in_legacy_dygraph():
if isinstance(param_and_grad, dict):
self._update_regularization(param_and_grad['weight_decay'])
_, _, _ = _C_ops.momentum(
Expand All @@ -323,8 +324,15 @@ def _append_optimize_op(self, block, param_and_grad):
'regularization_method', regularization_method,
'regularization_coeff', regularization_coeff, 'multi_precision',
find_master)

return None
if in_dygraph_mode():
if isinstance(param_and_grad, dict):
self._update_regularization(param_and_grad['weight_decay'])
return _C_ops.final_state_momentum(
param_and_grad[0], param_and_grad[1], velocity_acc, lr,
master_weight, self._momentum, self._use_nesterov,
regularization_method, regularization_coeff, find_master,
self._rescale_grad)

attrs = {
"mu": self._momentum,
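In `_append_optimize_op`, legacy dygraph keeps the old attribute-style `_C_ops.momentum(...)` call, while eager mode now returns the result of `_C_ops.final_state_momentum` with plain positional arguments; static graph falls through to the `attrs` path below. User-level code is unchanged in either mode. A minimal end-to-end usage sketch (sizes and hyper-parameters are illustrative):

```python
# End-to-end dygraph sketch; the optimizer internally picks the legacy or
# final-state momentum op depending on which dygraph mode is active.
import paddle

linear = paddle.nn.Linear(10, 1)
opt = paddle.optimizer.Momentum(learning_rate=0.01, momentum=0.9,
                                parameters=linear.parameters())

x = paddle.rand([4, 10])
loss = paddle.mean(linear(x))
loss.backward()
opt.step()
opt.clear_grad()
```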