Implementation of Quasi-Hyperbolic Momentum for Adam #81
Merged
Commits (23)
39af38d  First Commit
160c1f3  Implementation of QHAdam Update
9f481ab  Added inline to template Specialisation
60672c1  Fixes to Adam and addition of QHSGD
adb7efe  Added Tests for QHAdam
d1976ed  Style Fixes and additional documentation
30ffe57  Style Fixes and additional documentation
0434e7c  Merge remote-tracking branch 'niteya-shah/QHAdam' into QHAdam
e1d9917  Update qhadam_update.hpp
63238ad  Fix to Parameter type
d8a7274  Added documentation
6bc84bf  fix to resolve conflict
35d1868  Merge branch 'master' into QHAdam
32ab0d7  Added changes to reflect those of AdamW
9ec5eb4  Documentation Fixes and Added test for QHSGD
97b9494  added to function types
a40b2a1  documentation fix
38d913e  Changed some test parameters
692b692  documentation fixes
438aea6  Documentation Fixes and parameterisation change
6a81c91  Documentation Fixes
14cddea  doc fixes
0f17a19  documentation fixes wrt review
All commits by niteya-shah.
qhadam.hpp (new file, 164 lines)
Review comment: Can you add the license and a description.

```cpp
/**
 * @file qhadam.hpp
 * @author Niteya Shah
 *
 * Class wrapper for the QHAdam update policy. QHAdam is a variant of Adam
 * based on quasi-hyperbolic moments.
 *
 * ensmallen is free software; you may redistribute it and/or modify it under
 * the terms of the 3-clause BSD license. You should have received a copy of
 * the 3-clause BSD license along with ensmallen. If not, see
 * http://www.opensource.org/licenses/BSD-3-Clause for more information.
 */
#ifndef ENSMALLEN_ADAM_QHADAM_HPP
#define ENSMALLEN_ADAM_QHADAM_HPP

#include <ensmallen_bits/sgd/sgd.hpp>
#include "qhadam_update.hpp"

namespace ens {

/**
 * QHAdam is a variant of Adam with a quasi-hyperbolic step: the update is a
 * weighted mean of the plain gradient step and the momentum step. Due to its
 * parameterisation it can recover many other optimisation strategies.
 *
 * For more information, see the following.
 *
 * @code
 * @inproceedings{ma2019qh,
 *   title={Quasi-hyperbolic momentum and Adam for deep learning},
 *   author={Jerry Ma and Denis Yarats},
 *   booktitle={International Conference on Learning Representations},
 *   year={2019}
 * }
 * @endcode
 *
 * QHAdam can optimize differentiable separable functions. For more details,
 * see the documentation on function types included with this distribution or
 * on the ensmallen website.
 */
class QHAdam
{
 public:
  /**
   * Construct the QHAdam optimizer with the given function and parameters.
   * QHAdam is sensitive to its parameters, so careful hyperparameter
   * selection is necessary; the defaults may not fit every case.
   *
   * The maximum number of iterations refers to the maximum number of
   * points that are processed (i.e., one iteration equals one point; one
   * iteration does not equal one pass over the dataset).
   *
   * @param stepSize Step size for each iteration.
   * @param batchSize Number of points to process in a single step.
   * @param v1 The first quasi-hyperbolic term.
   * @param v2 The second quasi-hyperbolic term.
   * @param beta1 Exponential decay rate for the first moment estimates.
   * @param beta2 Exponential decay rate for the second moment estimates.
   * @param epsilon Value used to initialise the mean squared gradient
   *     parameter.
   * @param maxIterations Maximum number of iterations allowed (0 means no
   *     limit).
   * @param tolerance Maximum absolute tolerance to terminate algorithm.
   * @param shuffle If true, the function order is shuffled; otherwise, each
   *     function is visited in linear order.
   * @param resetPolicy If true, parameters are reset before every Optimize
   *     call; otherwise, their values are retained.
   */
  QHAdam(const double stepSize = 0.001,
         const size_t batchSize = 32,
         const double v1 = 0.7,
         const double v2 = 1,
         const double beta1 = 0.9,
         const double beta2 = 0.999,
         const double epsilon = 1e-8,
         const size_t maxIterations = 100000,
         const double tolerance = 1e-5,
         const bool shuffle = true,
         const bool resetPolicy = true);

  /**
   * Optimize the given function using QHAdam. The given starting point will be
   * modified to store the finishing point of the algorithm, and the final
   * objective value is returned.
   *
   * @tparam DecomposableFunctionType Type of the function to optimize.
   * @param function Function to optimize.
   * @param iterate Starting point (will be modified).
   * @return Objective value of the final point.
   */
  template<typename DecomposableFunctionType>
  double Optimize(DecomposableFunctionType& function, arma::mat& iterate)
  {
    return optimizer.Optimize(function, iterate);
  }

  //! Get the step size.
  double StepSize() const { return optimizer.StepSize(); }
  //! Modify the step size.
  double& StepSize() { return optimizer.StepSize(); }

  //! Get the batch size.
  size_t BatchSize() const { return optimizer.BatchSize(); }
  //! Modify the batch size.
  size_t& BatchSize() { return optimizer.BatchSize(); }

  //! Get the smoothing parameter.
  double Beta1() const { return optimizer.UpdatePolicy().Beta1(); }
  //! Modify the smoothing parameter.
  double& Beta1() { return optimizer.UpdatePolicy().Beta1(); }

  //! Get the second moment coefficient.
  double Beta2() const { return optimizer.UpdatePolicy().Beta2(); }
  //! Modify the second moment coefficient.
  double& Beta2() { return optimizer.UpdatePolicy().Beta2(); }

  //! Get the value used to initialise the mean squared gradient parameter.
  double Epsilon() const { return optimizer.UpdatePolicy().Epsilon(); }
  //! Modify the value used to initialise the mean squared gradient parameter.
  double& Epsilon() { return optimizer.UpdatePolicy().Epsilon(); }

  //! Get the maximum number of iterations (0 indicates no limit).
  size_t MaxIterations() const { return optimizer.MaxIterations(); }
  //! Modify the maximum number of iterations (0 indicates no limit).
  size_t& MaxIterations() { return optimizer.MaxIterations(); }

  //! Get the tolerance for termination.
  double Tolerance() const { return optimizer.Tolerance(); }
  //! Modify the tolerance for termination.
  double& Tolerance() { return optimizer.Tolerance(); }

  //! Get whether or not the individual functions are shuffled.
  bool Shuffle() const { return optimizer.Shuffle(); }
  //! Modify whether or not the individual functions are shuffled.
  bool& Shuffle() { return optimizer.Shuffle(); }

  //! Get whether or not the update policy parameters
  //! are reset before Optimize call.
  bool ResetPolicy() const { return optimizer.ResetPolicy(); }
  //! Modify whether or not the update policy parameters
  //! are reset before Optimize call.
  bool& ResetPolicy() { return optimizer.ResetPolicy(); }

  //! Get the first quasi-hyperbolic parameter.
  double V1() const { return optimizer.UpdatePolicy().V1(); }
  //! Modify the first quasi-hyperbolic parameter.
  double& V1() { return optimizer.UpdatePolicy().V1(); }

  //! Get the second quasi-hyperbolic parameter.
  double V2() const { return optimizer.UpdatePolicy().V2(); }
  //! Modify the second quasi-hyperbolic parameter.
  double& V2() { return optimizer.UpdatePolicy().V2(); }

 private:
  //! The Stochastic Gradient Descent object with QHAdam policy.
  SGD<QHAdamUpdate> optimizer;
};

} // namespace ens

// Include implementation.
#include "qhadam_impl.hpp"

#endif
```
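For reference, the update rule this policy wraps (from Ma & Yarats, 2019) can be sketched in plain C++ without the ensmallen/Armadillo dependencies. `QHAdamState` and `QHAdamStep` below are hypothetical names for illustration, not part of this PR; the actual update lives in `qhadam_update.hpp`:

```cpp
#include <cmath>

// Running state for one scalar parameter: first moment m, second moment v,
// and the iteration count t used for bias correction.
struct QHAdamState { double m = 0.0, v = 0.0; long t = 0; };

// One QHAdam step. The numerator is a weighted mean of the raw gradient and
// the bias-corrected first moment (weight v1); the denominator is built from
// a weighted mean of the squared gradient and the bias-corrected second
// moment (weight v2). With v1 = v2 = 1 this reduces to plain Adam.
double QHAdamStep(double param, double grad, QHAdamState& s,
                  const double stepSize = 0.001,
                  const double v1 = 0.7, const double v2 = 1.0,
                  const double beta1 = 0.9, const double beta2 = 0.999,
                  const double epsilon = 1e-8)
{
  ++s.t;
  s.m = beta1 * s.m + (1.0 - beta1) * grad;
  s.v = beta2 * s.v + (1.0 - beta2) * grad * grad;
  const double mHat = s.m / (1.0 - std::pow(beta1, (double) s.t));
  const double vHat = s.v / (1.0 - std::pow(beta2, (double) s.t));
  const double numerator = (1.0 - v1) * grad + v1 * mHat;
  const double denominator =
      std::sqrt((1.0 - v2) * grad * grad + v2 * vHat) + epsilon;
  return param - stepSize * numerator / denominator;
}
```

For example, minimising f(x) = x^2 by repeatedly calling `x = QHAdamStep(x, 2 * x, state)` drives x toward 0 with the default parameters.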
Review comment: Might be worth also linking to Adam, Nadam, and RMSprop? (And vice versa from those?)