Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Root finding in TauHybridCSolver #867

Merged
merged 6 commits into from
Oct 5, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 8 additions & 2 deletions gillespy2/solvers/cpp/c_base/arg_parser.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -95,10 +95,12 @@ char ArgParser::match_arg(std::string &token)
return 'M';
}

else
if (!token.compare("--use_root_finding"))
{
return 0;
return 'u';
}

return 0;
}

ArgParser::ArgParser(int argc, char *argv[])
Expand Down Expand Up @@ -176,6 +178,10 @@ ArgParser::ArgParser(int argc, char *argv[])
verbose = true;
break;

case 'u':
use_root_finding = true;
break;

case 'R':
std::stringstream(argv[i + 1]) >> rtol;
break;
Expand Down
1 change: 1 addition & 0 deletions gillespy2/solvers/cpp/c_base/arg_parser.h
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ class ArgParser
double atol = 1e-12;

bool verbose = false;
bool use_root_finding = false;

ArgParser(int argc, char *argv[]);
~ArgParser();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -78,9 +78,10 @@ int main(int argc, char* argv[])
parser.rtol,
parser.atol,
parser.max_step,

};

TauHybrid::TauHybridCSolver(&simulation, events, logger, tau_tol, config);
TauHybrid::TauHybridCSolver(&simulation, events, logger, tau_tol, config, parser.use_root_finding);
simulation.output_buffer_final(std::cout);
return simulation.get_status();
}
129 changes: 54 additions & 75 deletions gillespy2/solvers/cpp/c_base/tau_hybrid_cpp_solver/TauHybridSolver.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,7 @@ namespace Gillespy
// Temporary variable for the reaction's state.
// Does not get updated unless the changes are deemed valid.
double rxn_state = result.reactions[rxn_i];
double old_rxn_state = rxn_state;

if (simulation->reaction_state[rxn_i].mode == SimulationState::DISCRETE) {
unsigned int rxn_count = 0;
Expand All @@ -109,15 +110,14 @@ namespace Gillespy
}
}

bool TakeIntegrationStep(Integrator&sol, IntegrationResults&result, double next_time, int*population_changes,
bool TakeIntegrationStep(Integrator&sol, IntegrationResults&result, double *next_time, int*population_changes,
std::vector<double> current_state, std::set<unsigned int>&rxn_roots,
std::set<int>&event_roots, HybridSimulation*simulation, URNGenerator&urn,
int only_reaction_to_fire){
// Integration Step
// For deterministic reactions, the concentrations are updated directly.
// For stochastic reactions, integration updates the rxn_offsets vector.
//IntegrationResults result = sol.integrate(&next_time, event_roots, rxn_roots);
result = sol.integrate(&next_time, event_roots, rxn_roots);
result = sol.integrate(next_time, event_roots, rxn_roots);
if (sol.status == IntegrationStatus::BAD_STEP_SIZE)
{
simulation->set_status(HybridSimulation::INTEGRATOR_FAILED);
Expand All @@ -138,7 +138,7 @@ namespace Gillespy
// Explicitly check for invalid population state, now that changes have been tallied.
// Note: this should only check species that are reactants or products
for (const auto &r : tau_args_reactants) {
if (current_state[r.id] + population_changes[r.id] < 0) {
if (population_changes[r.id] != 0 && current_state[r.id] + population_changes[r.id] < 0) {
return true;
}
}
Expand All @@ -151,7 +151,8 @@ namespace Gillespy
std::vector<Event> &events,
Logger &logger,
double tau_tol,
SolverConfiguration config)
SolverConfiguration config,
bool default_use_root_finding)
{
if (simulation == NULL)
{
Expand All @@ -164,6 +165,9 @@ namespace Gillespy
int num_reactions = model.number_reactions;
int num_trajectories = simulation->number_trajectories;
std::unique_ptr<Species<double>[]> &species = model.species;
bool use_root_finding = default_use_root_finding;
bool in_event_handling = false;
unsigned int neg_state_loop_cnt = 0;

generator = std::mt19937_64(simulation->random_seed);
URNGenerator urn(simulation->random_seed);
Expand Down Expand Up @@ -320,8 +324,23 @@ namespace Gillespy

IntegrationResults result;

if(in_event_handling){
sol.use_events(events, simulation->reaction_state);
sol.enable_root_finder();
}else if(use_root_finding){
sol.use_reactions(simulation->reaction_state);
sol.enable_root_finder();
if(neg_state_loop_cnt > 0){
neg_state_loop_cnt--;
}else{
use_root_finding = default_use_root_finding;
}
}else{
sol.disable_root_finder();
}

if(!TauHybrid::TakeIntegrationStep(sol, result, next_time, population_changes, current_state, rxn_roots, event_roots, simulation, urn, -1)){

if(!TauHybrid::TakeIntegrationStep(sol, result, &next_time, population_changes, current_state, rxn_roots, event_roots, simulation, urn, -1)){
return;
}

Expand All @@ -330,92 +349,33 @@ namespace Gillespy
// If state is invalid, we took too agressive tau step and need to take a single SSA step forward
// Restore the solver to the intial step state
sol.restore_state();

// Calculate floor()'ed state for use in SSA step
for(int spec_i = 0; spec_i < num_species; ++spec_i){
floored_current_state[spec_i] = floor(current_state[spec_i]);
}
// estimate the time to the first stochatic reaction by assuming constant propensities
double min_tau = 0.0;
int rxn_selected = -1;

double *rxn_state = sol.get_reaction_state();

for (int rxn_k = 0; rxn_k < num_reactions; ++rxn_k) {
HybridReaction &rxn = simulation->reaction_state[rxn_k];
double propensity_value = rxn.ssa_propensity(current_state.data());
double floored_propensity_value = rxn.ssa_propensity(floored_current_state);
//estimate the zero crossing time
if(floored_propensity_value > 0.0){
double est_tau = -1* rxn_state[rxn_k] / propensity_value;

if(rxn_selected == -1 || est_tau < min_tau ){
min_tau = est_tau;
rxn_selected = rxn_k;
}
}
}
if(rxn_selected == -1){
simulation->set_status(HybridSimulation::NEGATIVE_STATE_NO_SSA_REACTION);
return;
}
// if min_tau < 1e-10, we can't take an ODE step that small.
if( min_tau < 1e-10 ){
// instead we will fire the reaction
CalculateSpeciesChangeAfterStep(result, population_changes, current_state, rxn_roots, event_roots, simulation, urn, rxn_selected);
// re-attempt the step at the same time
next_time = simulation->current_time;

}else{
// Use the found tau-step for single SSA
next_time = simulation->current_time + min_tau;

//***********************************
//***********************************

// Integreate the system forward
if(!TauHybrid::TakeIntegrationStep(sol, result, next_time, population_changes, current_state, rxn_roots, event_roots, simulation, urn, rxn_selected)){
return;
}
}
// check for invalid state again
if (TauHybrid::IsStateNegativeCheck(num_species, population_changes, current_state, tau_args.reactants)) {
//Got an invalid state after the SSA step
simulation->set_status(HybridSimulation::INVALID_AFTER_SSA);
return;
}
use_root_finding=true;
neg_state_loop_cnt = 2; // How many single SSA events should we find before we go back to tau steping
continue;
}

// "Permanently" update the rxn_state and populations.

// Update solver object with stochastic changes
for (int p_i = 0; p_i < num_species; ++p_i)
{
if (!simulation->species_state[p_i].boundary_condition)
{
// Boundary conditions are not modified directly by reactions.
// As such, population dx in stochastic regime is not considered.
// For deterministic species, their effective dy/dt should always be 0.
HybridSpecies *spec = &simulation->species_state[p_i];
if( spec->partition_mode == SimulationState::CONTINUOUS ){
current_state[p_i] = result.concentrations[p_i] + population_changes[p_i];
result.concentrations[p_i] = result.concentrations[p_i] + population_changes[p_i];
}else if( spec->partition_mode == SimulationState::DISCRETE ){
current_state[p_i] += population_changes[p_i];
result.concentrations[p_i] = current_state[p_i] + population_changes[p_i];
}
result.concentrations[p_i] = current_state[p_i];
}
}

if (interrupted){
break;
}

// ===== <EVENT HANDLING> =====
if (!event_list.has_active_events())
{
if (event_list.evaluate_triggers(N_VGetArrayPointer(sol.y), next_time))
{
sol.restore_state();
sol.use_events(events, simulation->reaction_state);
sol.enable_root_finder();
use_root_finding=true;
in_event_handling=true;
continue;
}
}
Expand All @@ -424,12 +384,31 @@ namespace Gillespy
double *event_state = N_VGetArrayPointer(sol.y);
if (!event_list.evaluate(event_state, num_species, next_time, event_roots))
{
sol.disable_root_finder();
in_event_handling=false;
use_root_finding = default_use_root_finding; // set to default
}
std::copy(event_state, event_state + num_species, current_state.begin());
}
// ===== </EVENT HANDLING> =====

// "Permanently" update species populations.
// (needs to be below event handling)
for (int p_i = 0; p_i < num_species; ++p_i)
{
if (!simulation->species_state[p_i].boundary_condition)
{
// Boundary conditions are not modified directly by reactions.
// As such, population dx in stochastic regime is not considered.
// For deterministic species, their effective dy/dt should always be 0.
current_state[p_i] = result.concentrations[p_i];
}
}

if (interrupted){
break;
}


// Output the results for this time step.
sol.refresh_state();
simulation->current_time = next_time;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,6 @@ namespace Gillespy
{
namespace TauHybrid
{
void TauHybridCSolver(HybridSimulation* simulation, std::vector<Event> &events, Logger &logger, double tau_tol, SolverConfiguration config);
void TauHybridCSolver(HybridSimulation* simulation, std::vector<Event> &events, Logger &logger, double tau_tol, SolverConfiguration config, bool use_root_finding);
}
}
47 changes: 27 additions & 20 deletions gillespy2/solvers/cpp/c_base/tau_hybrid_cpp_solver/integrator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -125,15 +125,17 @@ Integrator::~Integrator()

IntegrationResults Integrator::integrate(double *t)
{
if (!validate(this, CVode(cvode_mem, *t, y, &this->t, CV_NORMAL)))
int retcode = CVode(cvode_mem, *t, y, &this->t, CV_NORMAL);
if (!validate(this, retcode ))
{
return { nullptr, nullptr };
return { nullptr, nullptr, 0 };
}
*t = this->t;

return {
NV_DATA_S(y),
NV_DATA_S(y) + num_species
NV_DATA_S(y) + num_species,
retcode
};
}

Expand Down Expand Up @@ -162,32 +164,37 @@ IntegrationResults Integrator::integrate(double *t, std::set<int> &event_roots,
return results;
}

unsigned long long num_triggers = data.active_triggers.size();
unsigned long long num_rxn_roots = data.active_reaction_ids.size();
unsigned long long root_size = data.active_triggers.size() + data.active_reaction_ids.size();
int *root_results = new int[root_size];
// check to see if any root we found by the solver
if( results.retcode == CV_ROOT_RETURN ){
// find which roots were found and return them
unsigned long long num_triggers = data.active_triggers.size();
unsigned long long num_rxn_roots = data.active_reaction_ids.size();
unsigned long long root_size = data.active_triggers.size() + data.active_reaction_ids.size();
int *root_results = new int[root_size];

if (validate(this, CVodeGetRootInfo(cvode_mem, root_results)))
{
unsigned long long root_id;
for (root_id = 0; root_id < num_triggers; ++root_id)
if (validate(this, CVodeGetRootInfo(cvode_mem, root_results)))
{
if (root_results[root_id] != 0)
unsigned long long root_id;
for (root_id = 0; root_id < num_triggers; ++root_id)
{
event_roots.insert((int) root_id);
if (root_results[root_id] != 0)
{
event_roots.insert((int) root_id);
}
}
}

for (; root_id < num_rxn_roots; ++root_id)
{
if (root_results[root_id] < 0)
for (; root_id < root_size; ++root_id) // reaction roots
{
reaction_roots.insert(data.active_reaction_ids[root_id]);
if (root_results[root_id] != 0)
{
int rxn_id = root_id - num_triggers;
reaction_roots.insert(data.active_reaction_ids[rxn_id]);
}
}
}
}

delete[] root_results;
delete[] root_results;
}
return results;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,7 @@ namespace Gillespy
realtype *concentrations;
// reactions: bounded by [num_species, num_species + num_reactions)
realtype *reactions;
int retcode;
};

struct URNGenerator
Expand Down
5 changes: 4 additions & 1 deletion gillespy2/solvers/cpp/tau_hybrid_c_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,7 +201,7 @@ def get_solver_settings(cls):

def run(self=None, model: Model = None, t: int = None, number_of_trajectories: int = 1, timeout: int = 0,
increment: int = None, seed: int = None, debug: bool = False, profile: bool = False, variables={},
resume=None, live_output: str = None, live_output_options: dict = {}, tau_step: int = .03, tau_tol=0.03, integrator_options: "dict[str, float]" = None, **kwargs):
resume=None, live_output: str = None, live_output_options: dict = {}, tau_step: int = .03, tau_tol=0.03, integrator_options: "dict[str, float]" = None, use_root_finding=False, **kwargs):

"""
:param model: The model on which the solver will operate. (Deprecated)
Expand Down Expand Up @@ -327,6 +327,9 @@ def run(self=None, model: Model = None, t: int = None, number_of_trajectories: i
args = self._make_args(args)
if debug:
args.append("--verbose")
if use_root_finding:
args.append("--use_root_finding")

decoder = IterativeSimDecoder.create_default(number_of_trajectories, number_timesteps, len(self.model.listOfSpecies))

sim_exec = self._build(self.model, self.target, self.variable, False)
Expand Down