Skip to content

Commit

Permalink
Solve non-linear systems that are not kinetic schemes.
Browse files Browse the repository at this point in the history
  • Loading branch information
noraabiakar committed Oct 18, 2021
1 parent 489dc20 commit 7db8ece
Showing 1 changed file with 47 additions and 52 deletions.
99 changes: 47 additions & 52 deletions modcc/module.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,7 @@ bool Module::semantic() {
// symbol table.
// Returns false if a symbol name clashes with the name of a symbol that
// is already in the symbol table.
bool linear = true;
bool linear_homogeneous = true;
std::vector<std::string> state_vars;
auto move_symbols = [this] (std::vector<symbol_ptr>& symbol_list) {
for(auto& symbol: symbol_list) {
Expand Down Expand Up @@ -388,61 +388,56 @@ bool Module::semantic() {
continue;
}
found_solve = true;
std::unique_ptr<SolverVisitorBase> solver;

// If the derivative block is a kinetic block, perform kinetic rewrite first.
auto deriv = solve_expression->procedure();
auto solve_body = deriv->body()->clone();
if (deriv->kind()==procedureKind::kinetic) {
solve_body = kinetic_rewrite(deriv->body());
}
else if (deriv->kind()==procedureKind::linear) {
solve_body = linear_rewrite(deriv->body(), state_vars);
}

// Calculate linearity and homogeneity of the statements in the derivative block.
bool linear = true;
bool homogeneous = true;
for (auto& s: solve_body->is_block()->statements()) {
if(s->is_assignment() && !state_vars.empty()) {
linear_test_result r = linear_test(s->is_assignment()->rhs(), state_vars);
linear &= r.is_linear;
homogeneous &= r.is_homogeneous;
}
}
linear_homogeneous &= (linear & homogeneous);

// Construct solver based on system kind, linearity and solver method.
std::unique_ptr<SolverVisitorBase> solver;
switch(solve_expression->method()) {
case solverMethod::cnexp:
solver = std::make_unique<CnexpSolverVisitor>();
break;
case solverMethod::sparse: {
solver = std::make_unique<SparseSolverVisitor>(solve_expression->variant());
break;
}
case solverMethod::none:
solver = std::make_unique<DirectSolverVisitor>();
break;
}

// If the derivative block is a kinetic block, perform kinetic
// rewrite first.

auto deriv = solve_expression->procedure();

if (deriv->kind()==procedureKind::kinetic) {
auto rewrite_body = kinetic_rewrite(deriv->body());
bool linear_kinetic = true;

for (auto& s: rewrite_body->is_block()->statements()) {
if(s->is_assignment() && !state_vars.empty()) {
linear_test_result r = linear_test(s->is_assignment()->rhs(), state_vars);
linear_kinetic &= r.is_linear;
}
if (linear) {
solver = std::make_unique<SparseSolverVisitor>(solve_expression->variant());
}

if (!linear_kinetic) {
else {
solver = std::make_unique<SparseNonlinearSolverVisitor>();
}

rewrite_body->semantic(advance_state_scope);
rewrite_body->accept(solver.get());
}
else if (deriv->kind()==procedureKind::linear) {
solver = std::make_unique<LinearSolverVisitor>(state_vars);
auto rewrite_body = linear_rewrite(deriv->body(), state_vars);

rewrite_body->semantic(advance_state_scope);
rewrite_body->accept(solver.get());
break;
}
else {
deriv->body()->accept(solver.get());
for (auto& s: deriv->body()->statements()) {
if(s->is_assignment() && !state_vars.empty()) {
linear_test_result r = linear_test(s->is_assignment()->rhs(), state_vars);
linear &= r.is_linear;
linear &= r.is_homogeneous;
}
case solverMethod::none:
if (deriv->kind()==procedureKind::linear) {
solver = std::make_unique<LinearSolverVisitor>(state_vars);
}
else {
solver = std::make_unique<DirectSolverVisitor>();
}
break;
}
// Perform semantic analysis on the solve block statements and solve them.
solve_body->semantic(advance_state_scope);
solve_body->accept(solver.get());

if (auto solve_block = solver->as_block(false)) {
// Check that we didn't solve an already solved variable.
Expand Down Expand Up @@ -490,8 +485,8 @@ bool Module::semantic() {
for (auto& s: breakpoint->body()->statements()) {
if(s->is_assignment() && !state_vars.empty()) {
linear_test_result r = linear_test(s->is_assignment()->rhs(), state_vars);
linear &= r.is_linear;
linear &= r.is_homogeneous;
linear_homogeneous &= r.is_linear;
linear_homogeneous &= r.is_homogeneous;
}
}

Expand All @@ -517,27 +512,27 @@ bool Module::semantic() {
for (const auto &id: state_vars) {
auto coef = symbolic_pdiff(s->is_assignment()->rhs(), id);
if(!coef) {
linear = false;
linear_homogeneous = false;
continue;
}
if(coef->is_number()) {
if (!s->is_assignment()->lhs()->is_identifier()) {
error(pprintf("Left hand side of assignment is not an identifier"));
return false;
}
linear &= s->is_assignment()->lhs()->is_identifier()->name() == id ?
coef->is_number()->value() == 1 :
coef->is_number()->value() == 0;
linear_homogeneous &= s->is_assignment()->lhs()->is_identifier()->name() == id ?
coef->is_number()->value() == 1 :
coef->is_number()->value() == 0;
}
else {
linear = false;
linear_homogeneous = false;
}
}
}
}
}
}
linear_ = linear;
linear_ = linear_homogeneous;

post_events_ = has_symbol("post_event", symbolKind::procedure);
if (post_events_) {
Expand Down

0 comments on commit 7db8ece

Please sign in to comment.