Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Solve non-linear systems that are not kinetic schemes. #1724

Merged
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
99 changes: 47 additions & 52 deletions modcc/module.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,7 @@ bool Module::semantic() {
// symbol table.
// Returns false if a symbol name clashes with the name of a symbol that
// is already in the symbol table.
bool linear = true;
bool linear_homogeneous = true;
std::vector<std::string> state_vars;
auto move_symbols = [this] (std::vector<symbol_ptr>& symbol_list) {
for(auto& symbol: symbol_list) {
Expand Down Expand Up @@ -388,61 +388,56 @@ bool Module::semantic() {
continue;
}
found_solve = true;
std::unique_ptr<SolverVisitorBase> solver;

// If the derivative block is a kinetic block, perform kinetic rewrite first.
auto deriv = solve_expression->procedure();
auto solve_body = deriv->body()->clone();
if (deriv->kind()==procedureKind::kinetic) {
solve_body = kinetic_rewrite(deriv->body());
}
else if (deriv->kind()==procedureKind::linear) {
solve_body = linear_rewrite(deriv->body(), state_vars);
}

// Calculate linearity and homogeneity of the statements in the derivative block.
bool linear = true;
bool homogeneous = true;
for (auto& s: solve_body->is_block()->statements()) {
if(s->is_assignment() && !state_vars.empty()) {
linear_test_result r = linear_test(s->is_assignment()->rhs(), state_vars);
linear &= r.is_linear;
homogeneous &= r.is_homogeneous;
}
}
linear_homogeneous &= (linear & homogeneous);

// Construct solver based on system kind, linearity and solver method.
std::unique_ptr<SolverVisitorBase> solver;
switch(solve_expression->method()) {
case solverMethod::cnexp:
solver = std::make_unique<CnexpSolverVisitor>();
break;
case solverMethod::sparse: {
solver = std::make_unique<SparseSolverVisitor>(solve_expression->variant());
break;
}
case solverMethod::none:
solver = std::make_unique<DirectSolverVisitor>();
break;
}

// If the derivative block is a kinetic block, perform kinetic
// rewrite first.

auto deriv = solve_expression->procedure();

if (deriv->kind()==procedureKind::kinetic) {
auto rewrite_body = kinetic_rewrite(deriv->body());
bool linear_kinetic = true;

for (auto& s: rewrite_body->is_block()->statements()) {
if(s->is_assignment() && !state_vars.empty()) {
linear_test_result r = linear_test(s->is_assignment()->rhs(), state_vars);
linear_kinetic &= r.is_linear;
}
if (linear) {
solver = std::make_unique<SparseSolverVisitor>(solve_expression->variant());
}

if (!linear_kinetic) {
else {
solver = std::make_unique<SparseNonlinearSolverVisitor>();
}

rewrite_body->semantic(advance_state_scope);
rewrite_body->accept(solver.get());
}
else if (deriv->kind()==procedureKind::linear) {
solver = std::make_unique<LinearSolverVisitor>(state_vars);
auto rewrite_body = linear_rewrite(deriv->body(), state_vars);

rewrite_body->semantic(advance_state_scope);
rewrite_body->accept(solver.get());
break;
}
else {
deriv->body()->accept(solver.get());
for (auto& s: deriv->body()->statements()) {
if(s->is_assignment() && !state_vars.empty()) {
linear_test_result r = linear_test(s->is_assignment()->rhs(), state_vars);
linear &= r.is_linear;
linear &= r.is_homogeneous;
}
case solverMethod::none:
if (deriv->kind()==procedureKind::linear) {
solver = std::make_unique<LinearSolverVisitor>(state_vars);
}
else {
solver = std::make_unique<DirectSolverVisitor>();
}
break;
}
// Perform semantic analysis on the solve block statements and solve them.
solve_body->semantic(advance_state_scope);
solve_body->accept(solver.get());

if (auto solve_block = solver->as_block(false)) {
// Check that we didn't solve an already solved variable.
Expand Down Expand Up @@ -490,8 +485,8 @@ bool Module::semantic() {
for (auto& s: breakpoint->body()->statements()) {
if(s->is_assignment() && !state_vars.empty()) {
linear_test_result r = linear_test(s->is_assignment()->rhs(), state_vars);
linear &= r.is_linear;
linear &= r.is_homogeneous;
linear_homogeneous &= r.is_linear;
linear_homogeneous &= r.is_homogeneous;
}
}

Expand All @@ -517,27 +512,27 @@ bool Module::semantic() {
for (const auto &id: state_vars) {
auto coef = symbolic_pdiff(s->is_assignment()->rhs(), id);
if(!coef) {
linear = false;
linear_homogeneous = false;
continue;
}
if(coef->is_number()) {
if (!s->is_assignment()->lhs()->is_identifier()) {
error(pprintf("Left hand side of assignment is not an identifier"));
return false;
}
linear &= s->is_assignment()->lhs()->is_identifier()->name() == id ?
coef->is_number()->value() == 1 :
coef->is_number()->value() == 0;
linear_homogeneous &= s->is_assignment()->lhs()->is_identifier()->name() == id ?
coef->is_number()->value() == 1 :
coef->is_number()->value() == 0;
}
else {
linear = false;
linear_homogeneous = false;
}
}
}
}
}
}
linear_ = linear;
linear_ = linear_homogeneous;

post_events_ = has_symbol("post_event", symbolKind::procedure);
if (post_events_) {
Expand Down