Skip to content

Commit

Permalink
Merge pull request #693 from StochSS/expression-error-handling
Browse files Browse the repository at this point in the history
Improve expression error handling
  • Loading branch information
seanebum authored Jan 27, 2022
2 parents fbb02f8 + 70cd5c0 commit eae9251
Show file tree
Hide file tree
Showing 7 changed files with 60 additions and 18 deletions.
4 changes: 2 additions & 2 deletions gillespy2/core/events.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ def __init__(self, name=None, variable=None, expression=None):
'valid string expression')

def __str__(self):
return self.variable.name + ': ' + self.expression
return f"{self.variable}: {self.expression}"

class EventTrigger(Jsonify):
"""
Expand Down Expand Up @@ -261,4 +261,4 @@ def add_assignment(self, assignment):
else:
raise ModelError("Unexpected parameter for add_assignment. Parameter must be EventAssignment or list of "
"EventAssignments")
return assignment
return assignment
2 changes: 1 addition & 1 deletion gillespy2/core/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -841,7 +841,7 @@ def get_element(self, ename):
return self.get_assignment_rule(ename)
if ename in self.listOfFunctionDefinitions:
return self.get_function_definition(ename)
return 'Element not found!'
raise ModelError(f"model.get_element(): element={ename} not found")


def get_best_solver(self):
Expand Down
32 changes: 21 additions & 11 deletions gillespy2/solvers/cpp/build/expression.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,25 +50,25 @@ def __init__(self,
self.namespace = dict({}) if namespace is None else namespace
self.blacklist = dict({}) if blacklist is None else blacklist
self.sanitize = sanitize
self.invalid_names = []
self.invalid_operators = []
self.invalid_names = set()
self.invalid_operators = set()

def check_blacklist(self, operator):
operator = type(operator)
if operator in self.blacklist:
self.invalid_operators.append(str(self.blacklist.get(operator)))
self.invalid_operators.add(str(self.blacklist.get(operator)))

def visit_Name(self, node: "ast.Name"):
if node.id not in self.namespace:
self.invalid_names.append(node.id)
self.invalid_names.add(node.id)
elif self.sanitize:
node.id = self.namespace.get(node.id)
self.generic_visit(node)
return node

def visit_Call(self, node: "ast.Call"):
if node.func.id not in self.namespace:
self.invalid_names.append(node.func.id)
self.invalid_names.add(node.func.id)
elif self.sanitize:
node.func.id = self.namespace.get(node.func.id)
self.generic_visit(node)
Expand Down Expand Up @@ -199,15 +199,25 @@ def validate(self, statement: "str") -> "ExpressionResults":

return ExpressionResults(invalid_names=validator.invalid_names, invalid_operators=validator.invalid_operators)

def __get_expr(self, converter: "ExpressionConverter") -> "Optional[str]":
def __get_expr(self, statement: "str", converter: "ExpressionConverter") -> "Optional[str]":
validator = Expression.ValidationVisitor(self.namespace, self.blacklist, self.sanitize)
validator.visit(converter.tree)

failures_found = []

if validator.invalid_operators:
return None
base_msg = "Blacklisted operator"
base_msg = f"{base_msg}s" if len(validator.invalid_operators) > 1 else base_msg
failures_found.append(f"{base_msg}: {','.join(validator.invalid_operators)}")

if validator.invalid_names:
return None
base_msg = "Cannot resolve species name"
base_msg = f"{base_msg}s" if len(validator.invalid_names) > 1 else base_msg
failures_found.append(f"{base_msg}: {','.join(validator.invalid_names)}")

if len(failures_found) > 0:
raise SyntaxError(f"Invalid GillesPy2 expression \"{statement}\"\n"
+ "\n".join([f"* {msg}" for msg in failures_found]))

return converter.get_str()

Expand All @@ -220,7 +230,7 @@ def getexpr_python(self, statement: "str") -> "Optional[str]":
:returns: Python expression string, if valid. Returns None if validation fails.
"""
expr = ast.parse(statement)
return self.__get_expr(PythonConverter(expr))
return self.__get_expr(statement, PythonConverter(expr))

def getexpr_cpp(self, statement: "str") -> "Optional[str]":
"""
Expand All @@ -232,7 +242,7 @@ def getexpr_cpp(self, statement: "str") -> "Optional[str]":
"""
statement = ExpressionConverter.convert_str(statement)
expr = ast.parse(statement)
return self.__get_expr(CppConverter(expr))
return self.__get_expr(statement, CppConverter(expr))


class ExpressionResults:
Expand All @@ -241,7 +251,7 @@ class ExpressionResults:
Any expression items which indicate an invalid expression are listed on an ExpressionResults instance.
Empty lists indicate that the expression is valid.
"""
def __init__(self, invalid_names: "list[str]" = None, invalid_operators: "list[str]" = None, is_valid=True):
def __init__(self, invalid_names: "set[str]" = None, invalid_operators: "set[str]" = None, is_valid=True):
"""
Container struct for returning the results of expression validation.
Expand Down
7 changes: 7 additions & 0 deletions gillespy2/solvers/cpp/build/template_gen.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,12 @@ class SanitizedModel:
"t": "t",
}

# Global functions that aren't present in the `math` package,
# as well as functions in Python that have a different name in C++.
function_map = {
"abs": "abs",
}

def __init__(self, model: Model, variable=False):
self.model = model
self.variable = variable
Expand Down Expand Up @@ -68,6 +74,7 @@ def __init__(self, model: Model, variable=False):
# All "system" namespace entries should always be first.
# Otherwise, user-defined identifiers (like, for example, "gamma") might get overwritten.
**{name: name for name in math.__dict__.keys()},
**self.function_map,
**self.species_names,
**self.parameter_names,
**self.reserved_names,
Expand Down
14 changes: 13 additions & 1 deletion gillespy2/solvers/cpp/tau_hybrid_c_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,8 @@ def __create_options(cls, sanitized_model: "SanitizedModel") -> "SanitizedModel"
elif variable in sanitized_model.model.listOfParameters:
variable = sanitized_model.model.listOfParameters.get(variable)
else:
raise ValueError(f"Invalid event assignment {assign}: received name {variable} "
raise ValueError(f"Error in event={event} "
f"Invalid event assignment {assign}: received name {variable} "
f"Must match the name of a valid Species or Parameter.")

if isinstance(variable, gillespy2.Species):
Expand All @@ -108,6 +109,17 @@ def __create_options(cls, sanitized_model: "SanitizedModel") -> "SanitizedModel"
assignments.append(str(assign_id))
event_assignment_list.append(assign_str)
assign_id += 1
# Check for "None"s
for a in assignments:
if a is None: raise Exception(f"assignment={a} is None in event={event}")
if event_id is None: raise Exception(f"event_id is None in event={event}")
if trigger is None: raise Exception(f"trigger is None in event={event}")
if delay is None: raise Exception(f"delay is None in event={event}")
if priority is None: raise Exception(f"priority is None in event={event}")
if use_trigger is None: raise Exception(f"use_trigger is None in event={event}")
if use_persist is None: raise Exception(f"use_persist is None in event={event}")
if initial_value is None: raise Exception(f"initial_value is None in event={event}")

assignments: "str" = " AND ".join(assignments)
event_list.append(
f"EVENT("
Expand Down
8 changes: 6 additions & 2 deletions gillespy2/solvers/utilities/solverutils.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import ast # for dependency graphing
import numpy as np
from gillespy2.core import log, Species
from gillespy2.core import ModelError

"""
NUMPY SOLVER UTILITIES BELOW
Expand Down Expand Up @@ -143,8 +144,11 @@ def species_parse(model, custom_prop_fun):

class SpeciesParser(ast.NodeTransformer):
def visit_Name(self, node):
if isinstance(model.get_element(node.id), Species):
parsed_species.append(model.get_element(node.id))
try:
if isinstance(model.get_element(node.id), Species):
parsed_species.append(model.get_element(node.id))
except ModelError:
pass

expr = custom_prop_fun
expr = ast.parse(expr, mode='eval')
Expand Down
11 changes: 10 additions & 1 deletion test/test_c_solvers.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
from gillespy2.solvers.cpp import SSACSolver, ODECSolver, TauLeapingCSolver
from gillespy2.solvers.cpp import TauHybridCSolver
from gillespy2.solvers.cpp.build.expression import Expression, ExpressionConverter
from gillespy2.solvers.cpp.build.template_gen import SanitizedModel


class ExpressionTestCase:
Expand Down Expand Up @@ -85,6 +86,10 @@ class TestCSolvers(unittest.TestCase):
ExpressionTestCase({"x": "x", "y": "y", "z": "z"}, "(x^2/y^2/z^2)/x^2/y^2/z^2**1/x**1/y**1/z", [
[5.1, 0.1, 2.0], [0.1, 5.1, 2.0], [2.0, 0.1, 5.1], [2.0, 5.1, 0.1],
]),
# Known, builtin math expression functions work.
ExpressionTestCase({"x": "x"}, "abs(x)", [
[100.0], [100], [-100.0], [-100], [0],
]),
]
comparisons = [
# Asserts that single comparison expressions work.
Expand Down Expand Up @@ -189,7 +194,11 @@ def run(args: "list[str]") -> str:
def test_expressions(expressions: "list[ExpressionTestCase]", use_bool=False):
for entry in expressions:
expression = ExpressionConverter.convert_str(entry.expression)
expr = Expression(namespace=entry.args)
expr = Expression(namespace={
**SanitizedModel.reserved_names,
**SanitizedModel.function_map,
**entry.args,
})
cpp_expr = expr.getexpr_cpp(expression)
with self.subTest(msg="Evaluating converted C expressions",
expression=entry.expression,
Expand Down

0 comments on commit eae9251

Please sign in to comment.