Premature stopping
The premature stopping feature lets us write algorithms that can be stopped before they finish and still return a meaningful result from the computation done so far. This short guide summarizes what the feature consists of and how it is implemented, so that developers have better insight into what was done and how.
At present, this feature is implemented only for CMachine-derived classes, so only algorithms that inherit from CMachine can benefit from premature stopping.
The code snippets below, taken directly from src/shogun/base/machine/Machine.h, show which methods and member variables were added.
To handle the user's signals we used the RxCpp library, which made it easy to implement the observer pattern. There is a global observable that every algorithm can observe. For instance, when the user decides to terminate the execution of a machine, the global observable sends a signal to that machine, which then acts according to the message received.
To make a machine observe the global observable, we used this method:
```cpp
rxcpp::subscription CMachine::connect_to_signal_handler()
{
    // Subscribe this algorithm to the signal handler
    auto subscriber = rxcpp::make_subscriber<int>(
        // on_next: a pause signal pauses the machine,
        // any other value requests premature stopping
        [this](int i) {
            if (i == SG_PAUSE_COMP)
                this->on_pause();
            else
                this->on_next();
        },
        // on_completed: the user asked to terminate the program
        [this]() { this->on_complete(); });
    return get_global_signal()->get_observable()->subscribe(subscriber);
}
```
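To illustrate the observer pattern this method relies on, here is a minimal sketch that uses only the C++ standard library instead of RxCpp. SignalObservable, global_signal(), ToyMachine and the two signal codes are hypothetical stand-ins for get_global_signal(), CMachine and the constants in Shogun's Signal.h. They mirror the routing in connect_to_signal_handler() (a pause signal goes to on_pause(), any other value to on_next(), completion to on_complete()), but they are not Shogun's implementation.

```cpp
#include <atomic>
#include <functional>
#include <mutex>
#include <utility>
#include <vector>

// Hypothetical signal codes; the real values (e.g. SG_PAUSE_COMP)
// are defined by Shogun and are not reproduced here.
constexpr int kCancelSignal = 1;
constexpr int kPauseSignal = 2;

// Minimal observable: subscribers register a "next" and a "done"
// callback, and emit()/complete() fan out to all of them.
class SignalObservable {
    struct Subscriber {
        std::function<void(int)> next;
        std::function<void()> done;
    };

public:
    void subscribe(std::function<void(int)> on_next,
                   std::function<void()> on_complete) {
        std::lock_guard<std::mutex> lock(m_mutex);
        m_subscribers.push_back(Subscriber{std::move(on_next), std::move(on_complete)});
    }

    // Callbacks run on a snapshot, outside the lock, so a callback
    // may safely subscribe again without deadlocking.
    void emit(int signal) {
        for (auto& s : snapshot()) s.next(signal);
    }

    void complete() {
        for (auto& s : snapshot()) s.done();
    }

private:
    std::vector<Subscriber> snapshot() {
        std::lock_guard<std::mutex> lock(m_mutex);
        return m_subscribers;
    }

    std::mutex m_mutex;
    std::vector<Subscriber> m_subscribers;
};

// Process-wide observable, playing the role of get_global_signal().
SignalObservable& global_signal() {
    static SignalObservable instance;
    return instance;
}

// Toy machine routing signals like connect_to_signal_handler() does.
class ToyMachine {
public:
    void connect_to_signal_handler() {
        global_signal().subscribe(
            [this](int i) {
                if (i == kPauseSignal)
                    on_pause();
                else
                    on_next();
            },
            [this]() { on_complete(); });
    }

    std::atomic<bool> cancelled{false};
    std::atomic<bool> paused{false};
    std::atomic<bool> completed{false};

private:
    void on_next() { cancelled.store(true); }
    void on_pause() { paused.store(true); }
    void on_complete() { completed.store(true); }
};
```

Unlike RxCpp, this sketch returns no subscription handle, so a subscriber cannot unsubscribe. A real implementation has to make sure a machine stops observing before it is destroyed, which is what the rxcpp::subscription returned by connect_to_signal_handler() allows.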
The following member variables were added to CMachine:
- std::atomic<bool> m_cancel_computation: atomic flag used to indicate that the algorithm must be stopped prematurely;
- std::atomic<bool> m_pause_computation_flag: atomic flag used to indicate that the algorithm must be paused;
- std::condition_variable m_pause_computation: condition variable used to make threads wait while the computation is paused.
Three methods manage these variables and implement the stopping, pausing and resuming behaviour of algorithms. They are fairly self-explanatory and their usage is quite standard, so a macro was defined to simplify developers' lives: to add premature stopping capabilities to a method, it is enough to put this macro at the start of the body of its loops.
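```cpp
// Returns true if the computation has been cancelled
SG_FORCED_INLINE bool cancel_computation() const;
// Makes the calling thread wait while the computation is paused
SG_FORCED_INLINE void pause_computation();
// Resumes a paused computation
SG_FORCED_INLINE void resume_computation();

/** Macro which combines the methods above: skip the rest of the
 * iteration if cancelled, otherwise wait here while paused **/
#define COMPUTATION_CONTROLLERS\
    if (cancel_computation())\
        continue;\
    pause_computation();
```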
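The guide does not show the bodies of these three methods. The sketch below is one plausible implementation under that assumption, using the member names listed above plus an assumed mutex, together with the macro exactly as defined. ControlledMachine stands in for CMachine, and PartialSum is a hypothetical algorithm that calls on_next() itself to simulate a stop request, then returns the partial result computed so far.

```cpp
#include <atomic>
#include <condition_variable>
#include <mutex>

// Stand-in for CMachine; the member names mirror the guide,
// the method bodies are an assumption.
class ControlledMachine {
public:
    virtual ~ControlledMachine() = default;

    bool cancel_computation() const { return m_cancel_computation.load(); }

    // Blocks the calling thread while the pause flag is set.
    void pause_computation() {
        if (!m_pause_computation_flag.load()) return;  // fast path
        std::unique_lock<std::mutex> lock(m_mutex);
        m_pause_computation.wait(
            lock, [this] { return !m_pause_computation_flag.load(); });
    }

    // Clears the pause flag and wakes every waiting thread.
    void resume_computation() {
        {
            std::lock_guard<std::mutex> lock(m_mutex);
            m_pause_computation_flag.store(false);
        }
        m_pause_computation.notify_all();
    }

    // Same body as the default on_next() shown in this guide.
    virtual void on_next() { m_cancel_computation.store(true); }

protected:
    std::atomic<bool> m_cancel_computation{false};
    std::atomic<bool> m_pause_computation_flag{false};
    std::condition_variable m_pause_computation;
    std::mutex m_mutex;  // assumed: the condition variable needs a mutex
};

// The macro as defined in the guide.
#define COMPUTATION_CONTROLLERS\
    if (cancel_computation())\
        continue;\
    pause_computation();

// Hypothetical algorithm: sums 0..n-1, but simulates a stop signal
// after iteration stop_at and returns the partial sum.
class PartialSum : public ControlledMachine {
public:
    long run(int n, int stop_at) {
        long sum = 0;
        for (int i = 0; i < n; ++i) {
            COMPUTATION_CONTROLLERS
            sum += i;
            if (i == stop_at) on_next();  // simulate the user's request
        }
        return sum;
    }
};
```

For example, PartialSum().run(10, 4) adds 0 through 4 and then skips the remaining iterations, returning 10. Because the macro uses continue rather than break, cancelled iterations are skipped instead of leaving the loop, so the remaining iterations finish almost instantly; this also keeps the macro legal inside an OpenMP parallel for loop, which must not be exited with break.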
To keep things customizable, there are also three virtual callbacks, called when the user wants to prematurely stop the algorithm's execution (on_next()), pause it (on_pause()), or terminate the program (on_complete()):
```cpp
/** The action performed when the user decides to
 * prematurely stop the CMachine execution */
virtual void on_next()
{
    m_cancel_computation.store(true);
}

/** The action performed when the user decides to
 * pause the CMachine execution */
virtual void on_pause()
{
    m_pause_computation_flag.store(true);
    /* Custom pause handling goes here */
    resume_computation();
}

/** The action performed when the user decides to
 * return to the prompt and terminate the program execution */
virtual void on_complete()
{
}
```
Algorithm developers can override these methods with whatever actions they need, for example to add extra behaviour to the machine such as printing diagnostic information.
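As an example of such extra behaviour, the sketch below overrides on_next() to record a diagnostic message and then calls the base implementation, so the default cancellation still happens. MachineBase is a simplified stand-in for CMachine (its resume_computation() only clears the flag), and VerboseMachine, step() and the log member are hypothetical.

```cpp
#include <atomic>
#include <string>
#include <vector>

// Simplified stand-in for CMachine with the default callbacks above.
class MachineBase {
public:
    virtual ~MachineBase() = default;

    virtual void on_next() { m_cancel_computation.store(true); }
    virtual void on_pause() {
        m_pause_computation_flag.store(true);
        resume_computation();
    }
    virtual void on_complete() {}

    // Simplified: only clears the flag (no waiting threads to wake).
    void resume_computation() { m_pause_computation_flag.store(false); }
    bool cancel_computation() const { return m_cancel_computation.load(); }

protected:
    std::atomic<bool> m_cancel_computation{false};
    std::atomic<bool> m_pause_computation_flag{false};
};

// Hypothetical machine that logs diagnostics when it is stopped.
class VerboseMachine : public MachineBase {
public:
    std::vector<std::string> log;

    // Hypothetical progress counter, advanced once per iteration.
    void step() { ++m_iterations; }

    void on_next() override {
        log.push_back("stopped after " + std::to_string(m_iterations) +
                      " iterations");
        MachineBase::on_next();  // keep the default cancel behaviour
    }

    void on_complete() override { log.push_back("terminating"); }

private:
    int m_iterations = 0;
};
```

Calling the base method from the override is the design choice that matters here: an override that only logs would silently disable premature stopping for that machine.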
The code below shows a custom implementation of the premature stopping feature, and how to use it when writing (or extending) an algorithm.
```cpp
#include <shogun/base/init.h>
#include <shogun/base/some.h>
#include <shogun/labels/BinaryLabels.h>
#include <shogun/lib/SGVector.h>
#include <shogun/lib/Signal.h>
#include <shogun/machine/Machine.h>

#include <chrono>
#include <iostream>
#include <thread>

using namespace shogun;
using namespace std;

// Mock algorithm which implements a fake train_machine() method
class MockAlg : public CMachine
{
public:
    MockAlg() : CMachine() {}
    ~MockAlg() {}

    virtual const char* get_name() const { return "MockAlg"; }

protected:
    virtual bool train_machine(CFeatures* feat)
    {
        cout << "Training machine..." << endl;
#pragma omp parallel for
        for (int i = 0; i < 10000; i++)
        {
            // Skip the iteration if cancelled, wait here if paused
            COMPUTATION_CONTROLLERS
            cout << i << endl;
            // Sleep to simulate a long computation
            this_thread::sleep_for(chrono::milliseconds(1000));
        }
        // train_machine() must return a value; report success
        return true;
    }

    // This on_pause() method prints "PAUSED" to the screen,
    // waits five seconds and then resumes the computation.
    virtual void on_pause()
    {
        m_pause_computation_flag = true;
        cout << "PAUSED" << endl;
        this_thread::sleep_for(chrono::milliseconds(5000));
        resume_computation();
    }
};

int main()
{
    init_shogun_with_defaults();

    // Set up binary labels; SGVector allocates and owns the memory
    SGVector<int32_t> labs_v(2);
    labs_v[0] = -1;
    labs_v[1] = 1;
    auto train_labs = some<CBinaryLabels>(labs_v);

    // Enable the signal handler, so that pressing CTRL+C
    // makes Shogun's handler catch the SIGINT.
    get_global_signal()->enable_handler();

    MockAlg a, b;
    a.set_labels(train_labs);
    a.train();
    b.set_labels(train_labs);
    b.train();
    return 0;
}
```
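How get_global_signal()->enable_handler() intercepts CTRL+C is not covered in this guide. As a rough standard-library sketch of the general mechanism, and not Shogun's code, a SIGINT handler can set a flag that long-running loops poll, much as COMPUTATION_CONTROLLERS polls cancel_computation(). g_interrupted, install_sigint_handler() and count_until_interrupted() are hypothetical names.

```cpp
#include <csignal>

// Flag written by the handler: volatile std::sig_atomic_t is a type
// the C++ standard allows a signal handler to write safely.
volatile std::sig_atomic_t g_interrupted = 0;

// Hypothetical handler: only records that SIGINT arrived.
void on_sigint(int /*signal*/) { g_interrupted = 1; }

// Installs the handler; returns false if std::signal fails.
bool install_sigint_handler() {
    return std::signal(SIGINT, on_sigint) != SIG_ERR;
}

// Hypothetical long-running loop: counts the iterations it actually
// performs, skipping the remaining work once interrupted.
int count_until_interrupted(int max_iters) {
    int done = 0;
    for (int i = 0; i < max_iters; ++i) {
        if (g_interrupted) continue;  // same idea as the macro's continue
        ++done;
    }
    return done;
}
```

The handler does no work beyond setting the flag; reacting to the interruption (notifying observers, printing, cleaning up) is left to ordinary code that reads it, since very little else is safe to do inside a signal handler.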