Fix LAMB Beta1Pow and Beta2Pow update error #38518

Merged: 1 commit, Dec 29, 2021
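This change removes the `beta1_pow_out` / `beta2_pow_out` output pointers from the four moment-update functors (`LambMomentREGUpdateFunctor`, `LambMomentMENUpdateFunctor`, and their sparse counterparts) and moves the accumulator update into a new `LambBetaPowUpdateFunctor`, which `LambParamUpateFunctor` now inherits from. The write is guarded by `i == 0`, so device-resident accumulators are advanced exactly once per step, in the final parameter-update kernel, and are no longer written inside the same kernels that read them; the CPU-resident path still updates them directly in the kernel code, as before.

For reference only (not part of the diff), the accumulators hold beta1^t and beta2^t and feed the bias correction computed by the moment functors below. A sketch of the math, assuming the standard LAMB/Adam conventions:

```latex
% beta1_pow and beta2_pow accumulate beta1^t and beta2^t across steps:
%   beta1_pow_t = beta1_pow_{t-1} * beta1,   beta2_pow_t = beta2_pow_{t-1} * beta2
\hat{m}_t = \frac{m_t}{1 - \beta_1^{\,t}}, \qquad
\hat{v}_t = \frac{v_t}{1 - \beta_2^{\,t}}, \qquad
r_t = \frac{\hat{m}_t}{\sqrt{\hat{v}_t} + \epsilon} + \lambda\,\theta_{t-1}
```

If an accumulator is advanced more than once in a step, or read after it has already been advanced, the 1 - beta^t denominators no longer match the step count and the bias correction is wrong, which is the kind of error the title refers to.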
180 changes: 108 additions & 72 deletions paddle/fluid/operators/optimizers/lamb_op.h
@@ -52,19 +52,16 @@ struct LambMomentREGUpdateFunctor {
const bool* skip_update_;

LambMomentREGUpdateFunctor(MT weight_decay, MT beta1, MT beta2, MT epsilon,
MT beta1_pow, MT* beta1_pow_out, MT beta2_pow,
MT* beta2_pow_out, const MT* mom1, MT* mom1_out,
const MT* mom2, MT* mom2_out, const T* grad,
const MT* param, MT* trust_ratio_div,
const bool* skip_update)
MT beta1_pow, MT beta2_pow, const MT* mom1,
MT* mom1_out, const MT* mom2, MT* mom2_out,
const T* grad, const MT* param,
MT* trust_ratio_div, const bool* skip_update)
: weight_decay_(weight_decay),
beta1_(beta1),
beta2_(beta2),
epsilon_(epsilon),
beta1_pow_(beta1_pow),
beta1_pow_out_(beta1_pow_out),
beta2_pow_(beta2_pow),
beta2_pow_out_(beta2_pow_out),
moment1_(mom1),
moment1_out_(mom1_out),
moment2_(mom2),
@@ -95,10 +92,6 @@ struct LambMomentREGUpdateFunctor {
trust_ratio_div_[i] =
mom1_unbiased / (Eigen::numext::sqrt(mom2_unbiased) + epsilon_) +
weight_decay_ * p;
if (beta1_pow_out_ && beta2_pow_out_) {
beta1_pow_out_[0] = beta1_pow * beta1_;
beta2_pow_out_[0] = beta2_pow * beta2_;
}
}
};

@@ -113,9 +106,7 @@ struct LambMomentMENUpdateFunctor {
MT epsilon_;

const MT* beta1_pow_;
MT* beta1_pow_out_;
const MT* beta2_pow_;
MT* beta2_pow_out_;
const MT* moment1_;
MT* moment1_out_;
const MT* moment2_;
@@ -126,8 +117,7 @@ struct LambMomentMENUpdateFunctor {
const bool* skip_update_;

LambMomentMENUpdateFunctor(MT weight_decay, MT beta1, MT beta2, MT epsilon,
const MT* beta1_pow, MT* beta1_pow_out,
const MT* beta2_pow, MT* beta2_pow_out,
const MT* beta1_pow, const MT* beta2_pow,
const MT* mom1, MT* mom1_out, const MT* mom2,
MT* mom2_out, const T* grad, const MT* param,
MT* trust_ratio_div, const bool* skip_update)
@@ -136,9 +126,7 @@ struct LambMomentMENUpdateFunctor {
beta2_(beta2),
epsilon_(epsilon),
beta1_pow_(beta1_pow),
beta1_pow_out_(beta1_pow_out),
beta2_pow_(beta2_pow),
beta2_pow_out_(beta2_pow_out),
moment1_(mom1),
moment1_out_(mom1_out),
moment2_(mom2),
@@ -168,10 +156,6 @@ struct LambMomentMENUpdateFunctor {
trust_ratio_div_[i] =
mom1_unbiased / (Eigen::numext::sqrt(mom2_unbiased) + epsilon_) +
weight_decay_ * p;
if (beta1_pow_out_ && beta2_pow_out_) {
beta1_pow_out_[0] = beta1_pow * beta1_;
beta2_pow_out_[0] = beta2_pow * beta2_;
}
}
};

@@ -183,9 +167,7 @@ struct SparseLambMomentREGUpdateFunctor {
T epsilon_;

T beta1_pow_;
T* beta1_pow_out_;
T beta2_pow_;
T* beta2_pow_out_;
const T* moment1_;
T* moment1_out_;
const T* moment2_;
@@ -201,20 +183,18 @@ struct SparseLambMomentREGUpdateFunctor {
const bool* skip_update_;

SparseLambMomentREGUpdateFunctor(T weight_decay, T beta1, T beta2, T epsilon,
T beta1_pow, T* beta1_pow_out, T beta2_pow,
T* beta2_pow_out, const T* mom1, T* mom1_out,
const T* mom2, T* mom2_out, const T* grad,
const T* param, T* trust_ratio_div,
const int64_t* rows, int64_t row_numel,
int64_t row_count, const bool* skip_update)
T beta1_pow, T beta2_pow, const T* mom1,
T* mom1_out, const T* mom2, T* mom2_out,
const T* grad, const T* param,
T* trust_ratio_div, const int64_t* rows,
int64_t row_numel, int64_t row_count,
const bool* skip_update)
: weight_decay_(weight_decay),
beta1_(beta1),
beta2_(beta2),
epsilon_(epsilon),
beta1_pow_(beta1_pow),
beta1_pow_out_(beta1_pow_out),
beta2_pow_(beta2_pow),
beta2_pow_out_(beta2_pow_out),
moment1_(mom1),
moment1_out_(mom1_out),
moment2_(mom2),
@@ -246,10 +226,6 @@ struct SparseLambMomentREGUpdateFunctor {
trust_ratio_div_[i] =
mom1_unbiased / (Eigen::numext::sqrt(mom2_unbiased) + epsilon_) +
weight_decay_ * p;
if (beta1_pow_out_ && beta1_pow_out_) {
beta1_pow_out_[0] = beta1_pow * beta1_;
beta2_pow_out_[0] = beta2_pow * beta2_;
}
}

inline HOSTDEVICE void operator()(size_t i) const {
@@ -270,9 +246,7 @@ struct SparseLambMomentMENUpdateFunctor {
T epsilon_;

const T* beta1_pow_;
T* beta1_pow_out_;
const T* beta2_pow_;
T* beta2_pow_out_;
const T* moment1_;
T* moment1_out_;
const T* moment2_;
@@ -288,8 +262,7 @@ struct SparseLambMomentMENUpdateFunctor {
const bool* skip_update_;

SparseLambMomentMENUpdateFunctor(T weight_decay, T beta1, T beta2, T epsilon,
const T* beta1_pow, T* beta1_pow_out,
const T* beta2_pow, T* beta2_pow_out,
const T* beta1_pow, const T* beta2_pow,
const T* mom1, T* mom1_out, const T* mom2,
T* mom2_out, const T* grad, const T* param,
T* trust_ratio_div, const int64_t* rows,
@@ -300,9 +273,7 @@ struct SparseLambMomentMENUpdateFunctor {
beta2_(beta2),
epsilon_(epsilon),
beta1_pow_(beta1_pow),
beta1_pow_out_(beta1_pow_out),
beta2_pow_(beta2_pow),
beta2_pow_out_(beta2_pow_out),
moment1_(mom1),
moment1_out_(mom1_out),
moment2_(mom2),
@@ -334,10 +305,6 @@ struct SparseLambMomentMENUpdateFunctor {
trust_ratio_div_[i] =
mom1_unbiased / (Eigen::numext::sqrt(mom2_unbiased) + epsilon_) +
weight_decay_ * p;
if (beta1_pow_out_ && beta1_pow_out_) {
beta1_pow_out_[0] = beta1_pow * beta1_;
beta2_pow_out_[0] = beta2_pow * beta2_;
}
}

inline HOSTDEVICE void operator()(size_t i) const {
@@ -350,11 +317,44 @@ struct SparseLambMomentMENUpdateFunctor {
}
};

template <typename T, bool IsMultiPrecision>
struct LambParamUpateFunctor {
using MT = typename std::conditional<
IsMultiPrecision, typename details::MPTypeTrait<T>::Type, T>::type;
template <typename MT, bool NeedUpdateBetaPow /*=true*/>
struct LambBetaPowUpdateFunctor {
void SetBetaPows(const MT* beta1pow, const MT* beta2pow, MT* beta1pow_out,
MT* beta2pow_out, MT beta1, MT beta2) {
beta1pow_ = beta1pow;
beta2pow_ = beta2pow;
beta1pow_out_ = beta1pow_out;
beta2pow_out_ = beta2pow_out;
beta1_ = beta1;
beta2_ = beta2;
}

HOSTDEVICE void UpdateBetaPow(size_t i) const {
if (i == 0) {
beta1pow_out_[0] = beta1pow_[0] * beta1_;
beta2pow_out_[0] = beta2pow_[0] * beta2_;
}
}

private:
const MT* beta1pow_;
const MT* beta2pow_;
MT* beta1pow_out_;
MT* beta2pow_out_;
MT beta1_;
MT beta2_;
};

template <typename MT>
struct LambBetaPowUpdateFunctor<MT, /*NeedUpdateBetaPow=*/false> {
void SetBetaPows(const MT* beta1pow, const MT* beta2pow, MT* beta1pow_out,
MT* beta2pow_out, MT beta1, MT beta2) {}
HOSTDEVICE void UpdateBetaPow(size_t) const {}
};

template <typename T, typename MT, bool IsMultiPrecision, bool UpdateBetaPow>
struct LambParamUpateFunctor
: public LambBetaPowUpdateFunctor<MT, UpdateBetaPow> {
const MT* lr_;
const T* param_;
const MT* master_param_;
@@ -396,6 +396,7 @@ struct LambParamUpateFunctor {
if (IsMultiPrecision) {
master_param_out_[i] = param_out;
}
this->UpdateBetaPow(i);
}
};
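The inheritance trick above lets the per-element parameter update either fold in the beta-pow write (device-resident accumulators) or compile it away entirely (CPU-resident accumulators, handled separately in the kernel). A minimal standalone sketch of the same pattern, with hypothetical names and none of Paddle's types:

```cpp
#include <cstddef>
#include <cstdio>

// Base selected by a compile-time flag: the "true" version writes the
// accumulator once (only for element 0), the "false" specialization is an
// empty no-op the compiler can remove entirely.
template <typename MT, bool NeedUpdate>
struct BetaPowUpdater {
  const MT* beta_pow_in = nullptr;
  MT* beta_pow_out = nullptr;
  MT beta = MT(1);
  void UpdateBetaPow(size_t i) const {
    if (i == 0) beta_pow_out[0] = beta_pow_in[0] * beta;  // single write per step
  }
};

template <typename MT>
struct BetaPowUpdater<MT, false> {
  void UpdateBetaPow(size_t) const {}  // compiles to nothing
};

// Per-element parameter update, mirroring how LambParamUpateFunctor calls
// this->UpdateBetaPow(i) at the end of operator().
template <typename MT, bool NeedUpdate>
struct ParamUpdater : public BetaPowUpdater<MT, NeedUpdate> {
  MT* param = nullptr;
  const MT* trust_ratio_div = nullptr;
  MT lr = MT(0);
  void operator()(size_t i) const {
    param[i] -= lr * trust_ratio_div[i];  // simplified step, no trust-ratio scaling
    this->UpdateBetaPow(i);
  }
};

int main() {
  float param[2] = {1.0f, 2.0f};
  float tr[2] = {0.1f, 0.2f};
  float pow_in = 0.9f, pow_out = 0.0f;

  ParamUpdater<float, true> u;
  u.param = param;
  u.trust_ratio_div = tr;
  u.lr = 0.01f;
  u.beta_pow_in = &pow_in;
  u.beta_pow_out = &pow_out;
  u.beta = 0.9f;

  for (size_t i = 0; i < 2; ++i) u(i);
  std::printf("beta_pow_out = %.2f\n", pow_out);  // 0.81: updated exactly once
  return 0;
}
```

Writing the accumulator only at `i == 0` keeps the per-element kernel free of redundant global-memory writes while still piggybacking the update on a launch that already exists.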

@@ -501,15 +502,19 @@ class LambOpKernel : public framework::OpKernel<T> {
: nullptr;

// Update moments
bool should_update_beta_pow_later = false;
const MT *beta1_pow_ptr = nullptr, *beta2_pow_ptr = nullptr;
MT *beta1_pow_out_ptr = nullptr, *beta2_pow_out_ptr = nullptr;
VLOG(10) << "Beta1Pow place: " << beta1_pow.place()
<< " , Beta2Pow place: " << beta2_pow.place();
if (grad_var->IsType<framework::LoDTensor>()) {
auto& grad = grad_var->Get<framework::LoDTensor>();
if (platform::is_gpu_place(ctx.GetPlace()) &&
beta1_pow.place() == platform::CPUPlace() &&
beta2_pow.place() == platform::CPUPlace()) {
LambMomentREGUpdateFunctor<T, IsMultiPrecision> moment_update_functor(
weight_decay, beta1, beta2, epsilon, *beta1_pow.template data<MT>(),
nullptr, *beta2_pow.template data<MT>(), nullptr,
mom1.template data<MT>(),
*beta2_pow.template data<MT>(), mom1.template data<MT>(),
mom1_out.template mutable_data<MT>(ctx.GetPlace()),
mom2.template data<MT>(),
mom2_out.template mutable_data<MT>(ctx.GetPlace()),
@@ -523,12 +528,17 @@ class LambOpKernel : public framework::OpKernel<T> {
beta2_pow_out.template mutable_data<MT>(platform::CPUPlace())[0] =
beta2 * beta2_pow.template data<MT>()[0];
} else {
beta1_pow_ptr = beta1_pow.template data<MT>();
beta2_pow_ptr = beta2_pow.template data<MT>();
beta1_pow_out_ptr =
beta1_pow_out.template mutable_data<MT>(ctx.GetPlace());
beta2_pow_out_ptr =
beta2_pow_out.template mutable_data<MT>(ctx.GetPlace());
should_update_beta_pow_later = true;
LambMomentMENUpdateFunctor<T, IsMultiPrecision> moment_update_functor(
weight_decay, beta1, beta2, epsilon, beta1_pow.template data<MT>(),
beta1_pow_out.template mutable_data<MT>(ctx.GetPlace()),
beta2_pow.template data<MT>(),
beta2_pow_out.template mutable_data<MT>(ctx.GetPlace()),
mom1.template data<MT>(),
weight_decay, beta1, beta2, epsilon,
static_cast<const MT*>(beta1_pow_ptr),
static_cast<const MT*>(beta2_pow_ptr), mom1.template data<MT>(),
mom1_out.template mutable_data<MT>(ctx.GetPlace()),
mom2.template data<MT>(),
mom2_out.template mutable_data<MT>(ctx.GetPlace()),
Expand All @@ -542,7 +552,12 @@ class LambOpKernel : public framework::OpKernel<T> {
PADDLE_ENFORCE_EQ(IsMultiPrecision, false,
platform::errors::Unimplemented(
"SelectedRows gradient is not supported when "
"multi_precision=True"));
"multi_precision=True."));
constexpr bool kIsSameType = std::is_same<T, MT>::value;
PADDLE_ENFORCE_EQ(kIsSameType, true,
platform::errors::Unimplemented(
"SelectedRows gradient is not supported when "
"multi_precision=True."));
auto& grad = GET_DATA_SAFELY(ctx.Input<framework::SelectedRows>("Grad"),
"Input", "Grad", "Lamb");
if (grad.rows().size() == 0) {
@@ -582,8 +597,8 @@ class LambOpKernel : public framework::OpKernel<T> {
SparseLambMomentREGUpdateFunctor<T> moment_update_functor(
static_cast<T>(weight_decay), static_cast<T>(beta1),
static_cast<T>(beta2), static_cast<T>(epsilon),
*beta1_pow.template data<T>(), nullptr,
*beta2_pow.template data<T>(), nullptr, mom1.template data<T>(),
*beta1_pow.template data<T>(), *beta2_pow.template data<T>(),
mom1.template data<T>(),
mom1_out.template mutable_data<T>(ctx.GetPlace()),
mom2.template data<T>(),
mom2_out.template mutable_data<T>(ctx.GetPlace()), grad_data,
@@ -595,14 +610,18 @@ class LambOpKernel : public framework::OpKernel<T> {
beta2_pow_out.template mutable_data<T>(platform::CPUPlace())[0] =
static_cast<T>(beta2) * beta2_pow.template data<T>()[0];
} else {
beta1_pow_ptr = beta1_pow.template data<MT>();
beta2_pow_ptr = beta2_pow.template data<MT>();
beta1_pow_out_ptr =
beta1_pow_out.template mutable_data<MT>(ctx.GetPlace());
beta2_pow_out_ptr =
beta2_pow_out.template mutable_data<MT>(ctx.GetPlace());
should_update_beta_pow_later = true;
SparseLambMomentMENUpdateFunctor<T> moment_update_functor(
static_cast<T>(weight_decay), static_cast<T>(beta1),
static_cast<T>(beta2), static_cast<T>(epsilon),
beta1_pow.template data<T>(),
beta1_pow_out.template mutable_data<T>(ctx.GetPlace()),
beta2_pow.template data<T>(),
beta2_pow_out.template mutable_data<T>(ctx.GetPlace()),
mom1.template data<T>(),
reinterpret_cast<const T*>(beta1_pow_ptr),
reinterpret_cast<const T*>(beta2_pow_ptr), mom1.template data<T>(),
mom1_out.template mutable_data<T>(ctx.GetPlace()),
mom2.template data<T>(),
mom2_out.template mutable_data<T>(ctx.GetPlace()), grad_data,
@@ -639,14 +658,31 @@ class LambOpKernel : public framework::OpKernel<T> {
}
trust_ratio_div_norm.device(*place) = t.square().sum().sqrt();

LambParamUpateFunctor<T, IsMultiPrecision> param_update_functor(
lr.template data<MT>(), static_cast<const T*>(param_ptr),
static_cast<const MT*>(master_param_ptr), p_norm_t.template data<MT>(),
trust_ratio_div.template data<MT>(),
trust_ratio_div_norm_t.template data<MT>(),
static_cast<T*>(param_out_ptr), static_cast<MT*>(master_param_out_ptr),
skip_update_flag);
for_range(param_update_functor);
#define CALL_PADDLE_UPDATE_LAMB_PARAM_FUNC(__should_update_beta_pow) \
do { \
LambParamUpateFunctor<T, MT, IsMultiPrecision, __should_update_beta_pow> \
param_update_functor( \
lr.template data<MT>(), static_cast<const T*>(param_ptr), \
static_cast<const MT*>(master_param_ptr), \
p_norm_t.template data<MT>(), trust_ratio_div.template data<MT>(), \
trust_ratio_div_norm_t.template data<MT>(), \
static_cast<T*>(param_out_ptr), \
static_cast<MT*>(master_param_out_ptr), skip_update_flag); \
if (__should_update_beta_pow) { \
param_update_functor.SetBetaPows(beta1_pow_ptr, beta2_pow_ptr, \
beta1_pow_out_ptr, beta2_pow_out_ptr, \
beta1, beta2); \
} \
for_range(param_update_functor); \
} while (0)

if (should_update_beta_pow_later) {
CALL_PADDLE_UPDATE_LAMB_PARAM_FUNC(true);
} else {
CALL_PADDLE_UPDATE_LAMB_PARAM_FUNC(false);
}

#undef CALL_PADDLE_UPDATE_LAMB_PARAM_FUNC
}
};
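The `CALL_PADDLE_UPDATE_LAMB_PARAM_FUNC` macro exists because `UpdateBetaPow` is a compile-time template argument while `should_update_beta_pow_later` is only known at runtime: the same construct-and-launch body has to be instantiated for both `true` and `false` and selected with an `if`. A generic sketch of that runtime-to-compile-time dispatch, using hypothetical names and no Paddle dependencies:

```cpp
#include <cstddef>

// Functor whose behaviour depends on a compile-time bool, as in the diff above.
template <bool UpdateBetaPow>
struct Step {
  void operator()(size_t i) const { (void)i; /* per-element work */ }
};

// Runtime flag -> two instantiations; each branch fixes the template argument.
template <typename Range>
void RunStep(bool update_beta_pow, Range&& for_range) {
  if (update_beta_pow) {
    Step<true> f;
    for_range(f);
  } else {
    Step<false> f;
    for_range(f);
  }
}

int main() {
  // Stand-in for platform::ForRange: applies the functor to n indices.
  auto for_range = [](auto&& f) {
    for (size_t i = 0; i < 4; ++i) f(i);
  };
  RunStep(true, for_range);
  RunStep(false, for_range);
  return 0;
}
```

Paddle expresses this with a `do { ... } while (0)` macro rather than a helper template, which keeps the long constructor argument list written only once in the source while still producing both instantiations.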
