Skip to content

Commit

Permalink
Changes to PSparamManager updating of global model:
Browse files Browse the repository at this point in the history
	The merging of the worker models into the global model now uses a new virtual function of ParamInterface that takes a pointer array (which defaults to a simple loop). This provides support for more sophisticated merge strategies.
	The merge is now performed into a temporary object that is then moved to replace the global model. The workers are only locked while the pointers are updated, improving parallelization.
  • Loading branch information
giltirn committed Jun 28, 2022
1 parent 9f4cecd commit 8369129
Show file tree
Hide file tree
Showing 4 changed files with 29 additions and 6 deletions.
7 changes: 7 additions & 0 deletions include/chimbuko/param.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,13 @@ namespace chimbuko {
*/
virtual void update(const ParamInterface &other) = 0;

/**
* @brief Update the internal run statistics with those from multiple other instances
*
* The instance will be dynamically cast to the derived type internally, and will throw an error if the types do not match
* The other instance will be locked during the process for thread safety
*/
virtual void update(const std::vector<ParamInterface*> &other);

/**
* @brief Set the internal run statistics to match those included in the serialized input map. Overwrite performed only for those keys in input.
Expand Down
1 change: 1 addition & 0 deletions include/chimbuko/pserver/PSparamManager.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,7 @@ namespace chimbuko{
ParamType & getWorkerParams(const int i){ return dynamic_cast<ParamType&>(*m_worker_params[i]); }

private:
std::string m_ad_algorithm; /**< The AD algorithm*/
int m_agg_freq_ms; /**< The frequence in ms at which the global model is updated. Default 1000ms*/
ParamInterface *m_global_params; /**< The global model*/
std::string m_latest_global_params; /**< Cache of the serialized form the the latest global model*/
Expand Down
6 changes: 6 additions & 0 deletions src/param.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,3 +25,9 @@ ParamInterface *ParamInterface::set_AdParam(const std::string & ad_algorithm) {
fatal_error("Invalid algorithm: \"" + ad_algorithm + "\". Available options: HBOS, SSTD, COPOD");
}
}

void ParamInterface::update(const std::vector<ParamInterface*> &other){
for(auto p : other){
this->update(*p);
}
}
21 changes: 15 additions & 6 deletions src/pserver/PSparamManager.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

using namespace chimbuko;

PSparamManager::PSparamManager(const int nworker, const std::string &ad_algorithm): m_agg_freq_ms(1000), m_updater_thread(nullptr), m_worker_params(nworker,nullptr), m_global_params(nullptr), m_updater_exit(false), m_force_update(false){
PSparamManager::PSparamManager(const int nworker, const std::string &ad_algorithm): m_agg_freq_ms(1000), m_updater_thread(nullptr), m_worker_params(nworker,nullptr), m_global_params(nullptr), m_updater_exit(false), m_force_update(false), m_ad_algorithm(ad_algorithm){
for(int i=0;i<nworker;i++)
m_worker_params[i] = ParamInterface::set_AdParam(ad_algorithm);
m_global_params = ParamInterface::set_AdParam(ad_algorithm);
Expand All @@ -13,11 +13,20 @@ PSparamManager::PSparamManager(const int nworker, const std::string &ad_algorith

void PSparamManager::updateGlobalModel(){
verboseStream << "PSparamManager::updateGlobalModel updating global model" << std::endl;
std::unique_lock<std::shared_mutex> _(m_mutex); //unique lock to prevent read/write from other threads
m_global_params->clear(); //reset the global params and reform from worker params which have been aggregating since the start of the run
for(auto p: m_worker_params)
m_global_params->update(*p); //locks the worker params temporarily
m_latest_global_params = m_global_params->serialize();

//Avoid needing to lock out worker threads while updating by merging into a new location and moving after
ParamInterface* new_glob_params = ParamInterface::set_AdParam(m_ad_algorithm);
new_glob_params->update(m_worker_params);
std::string new_glob_params_ser = new_glob_params->serialize();

ParamInterface *tmp;
{
std::unique_lock<std::shared_mutex> _(m_mutex); //unique lock to prevent read/write from other threads
tmp = m_global_params;
m_global_params = new_glob_params;
m_latest_global_params = std::move(new_glob_params_ser);
}
delete tmp;
}


Expand Down

0 comments on commit 8369129

Please sign in to comment.