Fix master weight bug for multi_tensor optimizer(momentum, adam) #38991

Merged (2 commits) on Jan 20, 2022
paddle/fluid/operators/optimizers/merged_momentum_op.h (60 additions, 50 deletions)
@@ -48,13 +48,13 @@ struct MergedMomentumKernelParam
T *PADDLE_RESTRICT params[N];
const T *PADDLE_RESTRICT grads[N];
MT *PADDLE_RESTRICT velocitys[N];
- const MT *PADDLE_RESTRICT lr;
+ const MultiPrecisionType<MT> *PADDLE_RESTRICT lr;
MT mu;
MT rescale_grad;
uint32_t param_num;

HOSTDEVICE void operator()(size_t i) const {
- const auto lr_val = *lr;
+ const MT lr_val = static_cast<MT>(*lr);
for (uint32_t idx = 0; idx < param_num; ++idx) {
auto size = sizes[idx];
if (i >= size) continue;
@@ -81,8 +81,22 @@ struct MergedMomentumKernelParam

template <typename DeviceContext, typename T>
class MergedMomentumOpKernel : public framework::OpKernel<T> {
+ using MPType = typename operators::details::MPTypeTrait<T>::Type;
+
public:
void Compute(const framework::ExecutionContext &ctx) const override {
+ const bool multi_precision = ctx.Attr<bool>("multi_precision");
+ if (multi_precision) {
+ InnerCompute<MPType>(ctx, multi_precision);
+ } else {
+ InnerCompute<T>(ctx, multi_precision);
+ }
+ }
+
+ private:
+ template <typename MT>
+ void InnerCompute(const framework::ExecutionContext &ctx,
+ const bool multi_precision) const {
auto params = ctx.MultiInput<framework::Tensor>("Param");
auto params_out = ctx.MultiOutput<framework::Tensor>("ParamOut");
size_t n = params.size();
@@ -133,7 +147,6 @@ class MergedMomentumOpKernel : public framework::OpKernel<T> {
auto master_params = ctx.MultiInput<framework::Tensor>("MasterParam");
auto master_params_out =
ctx.MultiOutput<framework::Tensor>("MasterParamOut");
- auto multi_precision = ctx.Attr<bool>("multi_precision");
if (multi_precision) {
PADDLE_ENFORCE_EQ(
n, master_params.size(),
@@ -206,39 +219,37 @@ class MergedMomentumOpKernel : public framework::OpKernel<T> {
<< ", regularization_coeffs.size(): "
<< regularization_coeffs.size();

- using MPType = typename operators::details::MPTypeTrait<T>::Type;
-
auto &dev_ctx = ctx.template device_context<DeviceContext>();

if (lrs.size() == 1 && use_nesterov == false &&
regularization_methods.size() == 0) {
- #define PADDLE_LAUNCH_MERGED_MOMENTUM_KERNEL(kMultiPrecision) \
- MergedMomentumKernelParam<T, MPType, kMultiPrecision> kernel_params; \
- constexpr auto kMaxMergedNum = decltype(kernel_params)::N; \
- size_t kernel_num = (n + kMaxMergedNum - 1) / kMaxMergedNum; \
- kernel_params.mu = static_cast<MPType>(mu); \
- kernel_params.rescale_grad = static_cast<MPType>(rescale_grad); \
- kernel_params.lr = lrs[0]->data<MPType>(); \
- for (size_t i = 0; i < kernel_num; ++i) { \
- size_t start = i * kMaxMergedNum; \
- size_t end = std::min((i + 1) * kMaxMergedNum, n); \
- kernel_params.param_num = static_cast<uint32_t>(end - start); \
- size_t max_size = 0; \
- for (size_t j = 0; j < kernel_params.param_num; ++j) { \
- auto size = static_cast<size_t>(params_out[j + start]->numel()); \
- max_size = std::max(max_size, size); \
- kernel_params.sizes[j] = size; \
- kernel_params.params[j] = params_out[j + start]->data<T>(); \
- kernel_params.grads[j] = grads[j + start]->data<T>(); \
- kernel_params.velocitys[j] = velocitys_out[j + start]->data<MPType>(); \
- kernel_params.SetMasterParam( \
- j, kMultiPrecision ? master_params_out[j + start]->data<MPType>() \
- : nullptr); \
- } \
- platform::ForRange<DeviceContext> for_range(dev_ctx, max_size); \
- for_range(kernel_params); \
- VLOG(10) << "Launch MergedMomentum kernel " << i << " " \
- << kernel_params.param_num; \
+ #define PADDLE_LAUNCH_MERGED_MOMENTUM_KERNEL(kMultiPrecision) \
+ MergedMomentumKernelParam<T, MT, kMultiPrecision> kernel_params; \
+ constexpr auto kMaxMergedNum = decltype(kernel_params)::N; \
+ size_t kernel_num = (n + kMaxMergedNum - 1) / kMaxMergedNum; \
+ kernel_params.mu = static_cast<MT>(mu); \
+ kernel_params.rescale_grad = static_cast<MT>(rescale_grad); \
+ kernel_params.lr = lrs[0]->data<MPType>(); \
+ for (size_t i = 0; i < kernel_num; ++i) { \
+ size_t start = i * kMaxMergedNum; \
+ size_t end = std::min((i + 1) * kMaxMergedNum, n); \
+ kernel_params.param_num = static_cast<uint32_t>(end - start); \
+ size_t max_size = 0; \
+ for (size_t j = 0; j < kernel_params.param_num; ++j) { \
+ auto size = static_cast<size_t>(params_out[j + start]->numel()); \
+ max_size = std::max(max_size, size); \
+ kernel_params.sizes[j] = size; \
+ kernel_params.params[j] = params_out[j + start]->data<T>(); \
+ kernel_params.grads[j] = grads[j + start]->data<T>(); \
+ kernel_params.velocitys[j] = velocitys_out[j + start]->data<MT>(); \
+ kernel_params.SetMasterParam( \
+ j, kMultiPrecision ? master_params_out[j + start]->data<MT>() \
+ : nullptr); \
+ } \
+ platform::ForRange<DeviceContext> for_range(dev_ctx, max_size); \
+ for_range(kernel_params); \
+ VLOG(10) << "Launch MergedMomentum kernel " << i << " " \
+ << kernel_params.param_num; \
}
if (multi_precision) {
PADDLE_LAUNCH_MERGED_MOMENTUM_KERNEL(true);
@@ -254,34 +265,33 @@ class MergedMomentumOpKernel : public framework::OpKernel<T> {
? RegularizationType::kL2DECAY
: RegularizationType::kNONE;

- MPType regularization_coeff = static_cast<MPType>(0.0);
+ MT regularization_coeff = static_cast<MT>(0.0);
if (regularization_coeffs.size() != 0) {
- regularization_coeff =
- static_cast<MPType>(regularization_coeffs[idx]);
+ regularization_coeff = static_cast<MT>(regularization_coeffs[idx]);
}
auto lr_temp = lrs.size() > 1 ? lrs[idx] : lrs[0];

- const MPType *master_in_data =
- multi_precision ? master_params[idx]->data<MPType>() : nullptr;
- MPType *master_out_data =
- multi_precision ? master_params_out[idx]->data<MPType>() : nullptr;
+ const MT *master_in_data =
+ multi_precision ? master_params[idx]->data<MT>() : nullptr;
+ MT *master_out_data =
+ multi_precision ? master_params_out[idx]->data<MT>() : nullptr;
if (platform::is_cpu_place(ctx.GetPlace())) {
- CPUDenseMomentumFunctor<MPType> functor;
- functor(params[idx], grads[idx], velocitys[idx], lr_temp, mu,
- use_nesterov, regularization_flag, regularization_coeff,
- params_out[idx], velocitys_out[idx]);
+ CPUDenseMomentumFunctor<MT> functor;
+ functor(params[idx], grads[idx], velocitys[idx], lr_temp,
+ static_cast<MT>(mu), use_nesterov, regularization_flag,
+ regularization_coeff, params_out[idx], velocitys_out[idx]);
VLOG(10) << "Launch MergedMomentum cpu kernel.";
} else if (platform::is_gpu_place(ctx.GetPlace())) {
platform::ForRange<DeviceContext> for_range(
static_cast<const DeviceContext &>(ctx.device_context()),
params[idx]->numel());
- #define PADDLE_LAUNCH_DENSE_MTMOMENTUM_KERNEL(__nesterov, __reg_type) \
- DenseMomentumFunctor<T, MPType, __reg_type, __nesterov> functor( \
- params[idx]->data<T>(), grads[idx]->data<T>(), \
- velocitys[idx]->data<MPType>(), lr_temp->data<MPType>(), master_in_data, \
- mu, rescale_grad, params[idx]->numel(), regularization_coeff, \
- params_out[idx]->data<T>(), velocitys_out[idx]->data<MPType>(), \
- master_out_data); \
+ #define PADDLE_LAUNCH_DENSE_MTMOMENTUM_KERNEL(__nesterov, __reg_type) \
+ DenseMomentumFunctor<T, MT, __reg_type, __nesterov> functor( \
+ params[idx]->data<T>(), grads[idx]->data<T>(), \
+ velocitys[idx]->data<MT>(), lr_temp->data<MPType>(), master_in_data, \
+ static_cast<MT>(mu), static_cast<MT>(rescale_grad), \
+ params[idx]->numel(), regularization_coeff, params_out[idx]->data<T>(), \
+ velocitys_out[idx]->data<MT>(), master_out_data); \
for_range(functor);
if (use_nesterov) {
if (regularization_flag == RegularizationType::kL2DECAY) {
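For reference, the fp32-master-weight update the fixed kernel performs on the multi-precision path fits in a few lines. This is a minimal NumPy sketch, not Paddle code: the function name and the omission of Nesterov and regularization are my own simplifications. The point mirrored from the diff is that the learning-rate tensor is always stored in the multi-precision type (fp32), so the kernel now reads it as `MultiPrecisionType<MT>` and casts, while all arithmetic runs in fp32 against the master weights and only the final result is cast back to fp16.

```python
import numpy as np

def momentum_update_mp(param16, grad16, velocity32, master32, lr,
                       mu=0.9, rescale_grad=1.0):
    # The LR is stored as MultiPrecisionType<MT> (fp32) regardless of the
    # parameter dtype; the old kernel read it through an MT-typed pointer.
    lr = np.float32(lr)
    grad = grad16.astype(np.float32) * np.float32(rescale_grad)
    velocity32[:] = np.float32(mu) * velocity32 + grad  # fp32 math
    master32[:] = master32 - lr * velocity32            # fp32 master update
    param16[:] = master32.astype(np.float16)            # cast down once

p = np.zeros(4, np.float16); g = np.ones(4, np.float16)
v = np.zeros(4, np.float32); m = np.zeros(4, np.float32)
momentum_update_mp(p, g, v, m, lr=0.1)
```

Templating `InnerCompute` on `MT` also lets the same code run with `MT == T` when `multi_precision` is off, instead of unconditionally forcing fp32 math as the old `MPType`-only kernel did.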
python/paddle/optimizer/adam.py (4 additions, 5 deletions)
@@ -551,8 +551,7 @@ def _append_optimize_multi_tensor_op(self, target_block,
multi_tensor_list = ['FP32_LODTensor', 'FP16_LODTensor']
for key in multi_tensor_list:
if len(self._param_dict[key]) > 0:
- if key == 'FP32_LODTensor':
- self._multi_precision = False
+ find_master = self._multi_precision and key == 'FP16_LODTensor'

_beta1 = self._beta1 if not isinstance(
self._beta1, Variable) else self._beta1.numpy().item(0)
@@ -571,7 +570,7 @@ def _append_optimize_multi_tensor_op(self, target_block,
self._beta2_pow_acc_dict[key],
self._master_weight_dict[key], 'epsilon', self._epsilon,
'beta1', _beta1, 'beta2', _beta2, 'multi_precision',
- self._multi_precision)
+ find_master)
else:
inputs = {
"Param": self._param_dict[key],
@@ -594,11 +593,11 @@ def _append_optimize_multi_tensor_op(self, target_block,
"beta1": _beta1,
"beta2": _beta2
}
- if self._multi_precision:
+ if find_master:
inputs["MasterParam"] = self._master_weight_dict[key]
outputs["MasterParamOut"] = self._master_weight_dict[
key]
attrs["multi_precision"] = self._multi_precision
attrs["multi_precision"] = find_master
target_block.append_op(
type="merged_adam",
inputs=inputs,
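The Python-side bug is easiest to see in isolation. Below is a distilled sketch with hypothetical helper names, not the optimizer's actual methods: when both fp32 and fp16 parameter groups exist, the old code flipped `self._multi_precision` to False while processing the FP32 group, and since FP32_LODTensor is visited before FP16_LODTensor, the fp16 parameters then lost their master weights. The fix computes a local `find_master` per group and never touches the instance flag.

```python
# Distilled sketch of the control-flow bug; the group order matches the diff.
groups = ['FP32_LODTensor', 'FP16_LODTensor']

def buggy(multi_precision):
    flags = {}
    self_multi_precision = multi_precision
    for key in groups:
        if key == 'FP32_LODTensor':
            self_multi_precision = False   # mutates persistent state...
        flags[key] = self_multi_precision  # ...so FP16 loses master weights
    return flags

def fixed(multi_precision):
    # Local decision per group; persistent state is left alone.
    return {key: multi_precision and key == 'FP16_LODTensor' for key in groups}

print(buggy(True))  # {'FP32_LODTensor': False, 'FP16_LODTensor': False}  <- bug
print(fixed(True))  # {'FP32_LODTensor': False, 'FP16_LODTensor': True}
```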
python/paddle/optimizer/momentum.py (4 additions, 5 deletions)
@@ -464,8 +464,7 @@ def _append_optimize_multi_tensor_op(self, target_block,
multi_tensor_list = ['FP32_LODTensor', 'FP16_LODTensor']
for key in multi_tensor_list:
if len(self._param_dict[key]) > 0:
- if key == 'FP32_LODTensor':
- self._multi_precision = False
+ find_master = self._multi_precision and key == 'FP16_LODTensor'

if framework.in_dygraph_mode():
_, _, _ = _C_ops.merged_momentum(
@@ -478,7 +477,7 @@ def _append_optimize_multi_tensor_op(self, target_block,
self._regularization_method_dict[key],
'regularization_coeff',
self._regularization_coeff_dict[key], 'multi_precision',
- self._multi_precision)
+ find_master)
else:
inputs = {
"Param": self._param_dict[key],
@@ -498,11 +497,11 @@ def _append_optimize_multi_tensor_op(self, target_block,
"regularization_coeff":
self._regularization_coeff_dict[key],
}
- if self._multi_precision:
+ if find_master:
inputs["MasterParam"] = self._master_weight_dict[key]
outputs["MasterParamOut"] = self._master_weight_dict[
key]
attrs["multi_precision"] = self._multi_precision
attrs["multi_precision"] = find_master
target_block.append_op(
type="merged_momentum",
inputs=inputs,
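For completeness, here is a hedged sketch of the user-level path these two files serve: fp16 parameters with fp32 master weights going through the fused multi-tensor momentum op. The `multi_precision` and `use_multi_tensor` flags are the real switches these diffs read; the surrounding AMP boilerplate is a plausible setup for the Paddle 2.x of this PR's era and may need adjusting for other versions.

```python
import paddle

model = paddle.nn.Linear(16, 16)
# Level O2 keeps parameters in fp16, which is exactly when master weights matter.
model = paddle.amp.decorate(models=model, level='O2')

opt = paddle.optimizer.Momentum(
    learning_rate=0.1,
    momentum=0.9,
    parameters=model.parameters(),
    multi_precision=True,   # keep fp32 master copies of the fp16 params
    use_multi_tensor=True)  # take the merged_momentum path fixed above

scaler = paddle.amp.GradScaler()
x = paddle.randn([4, 16])
with paddle.amp.auto_cast(level='O2'):
    loss = model(x).mean()
scaled = scaler.scale(loss)
scaled.backward()
# Before this PR, the multi-tensor path dropped the master weights here,
# so fp16 parameters were updated without their fp32 copies.
scaler.minimize(opt, scaled)
opt.clear_grad()
```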