diff --git a/sbmltoodejax/modulegeneration.py b/sbmltoodejax/modulegeneration.py index db01eea..bb941ab 100644 --- a/sbmltoodejax/modulegeneration.py +++ b/sbmltoodejax/modulegeneration.py @@ -13,7 +13,9 @@ def GenerateModel(modelData, outputFilePath, deltaT: float =0.1, atol: float=1e-6, rtol: float = 1e-12, - mxstep: int = 5000000 + mxstep: int = 5000000, + solver_type: str = 'odeint', + diffrax_solver: str = 'Tsit5' ): """ This function takes model data created by :func:`~sbmltoodejax.parse.ParseSBMLFile` and generates a python file containing @@ -125,7 +127,9 @@ def GenerateModel(modelData, outputFilePath, outputFile.write("from functools import partial\n") outputFile.write("from jax import jit, lax, vmap\n") outputFile.write("from jax.experimental.ode import odeint\n") - outputFile.write("import jax.numpy as jnp\n\n") + outputFile.write("import jax.numpy as jnp\n") + outputFile.write("from diffrax import ODETerm, Tsit5, Dopri5, Dopri8, Euler, Midpoint, Heun, Bosh3, Ralston\n") + outputFile.write("from typing import Any\n\n") outputFile.write("from sbmltoodejax import jaxfuncs\n\n") @@ -512,32 +516,49 @@ def ParseRHS(rawRateLaw, extended_param_names=[], reaction_name=None, yvar="y", outputFile.write("\tatol: float = eqx.static_field()\n") outputFile.write("\trtol: float = eqx.static_field()\n") outputFile.write("\tmxstep: int = eqx.static_field()\n") - outputFile.write(f"\tassignmentfunc: {AssignmentRuleName}\n\n") + outputFile.write(f"\tassignmentfunc: {AssignmentRuleName}\n") + outputFile.write("\tsolver_type: str = eqx.static_field()\n") + outputFile.write("\tsolver: Any = eqx.static_field()\n\n") outputFile.write(f"\tdef __init__(self, " f"y_indexes={y_indexes}, " f"w_indexes={w_indexes}, " f"c_indexes={c_indexes}, " - f"atol={atol}, rtol={rtol}, mxstep={mxstep}):\n\n") + f"atol={atol}, rtol={rtol}, mxstep={mxstep}, " + f"solver_type='{solver_type}', diffrax_solver='{diffrax_solver}'):\n\n") outputFile.write("\t\tself.y_indexes = y_indexes\n") outputFile.write("\t\tself.w_indexes = w_indexes\n") - outputFile.write("\t\tself.c_indexes = c_indexes\n\n") - + outputFile.write("\t\tself.c_indexes = c_indexes\n") outputFile.write(f"\t\tself.ratefunc = {RateofSpeciesChangeName}()\n") outputFile.write("\t\tself.rtol = rtol\n") outputFile.write("\t\tself.atol = atol\n") outputFile.write("\t\tself.mxstep = mxstep\n") - - outputFile.write(f"\t\tself.assignmentfunc = {AssignmentRuleName}()\n\n") - + outputFile.write(f"\t\tself.assignmentfunc = {AssignmentRuleName}()\n") + outputFile.write("\t\tself.solver_type = solver_type\n") + outputFile.write("\t\tif solver_type == 'odeint':\n") + outputFile.write("\t\t\tself.solver = odeint\n") + outputFile.write("\t\telif solver_type == 'diffrax':\n") + outputFile.write("\t\t\tfrom diffrax import ODETerm, Tsit5, Dopri5, Dopri8, Euler, Midpoint, Heun, Bosh3, Ralston\n") + outputFile.write("\t\t\tvalid_solvers = {'Tsit5', 'Dopri5', 'Dopri8', 'Euler', 'Midpoint', 'Heun', 'Bosh3', 'Ralston'}\n") + outputFile.write("\t\t\tif diffrax_solver not in valid_solvers:\n") + outputFile.write("\t\t\t\traise ValueError(f'Unknown diffrax solver: {diffrax_solver}')\n") + outputFile.write(f"\t\t\tself.solver = {diffrax_solver}()\n") + outputFile.write("\t\telse:\n") + outputFile.write("\t\t\traise ValueError(f'Unknown solver type: {solver_type}')\n\n") outputFile.write("\t@jit\n") outputFile.write("\tdef __call__(self, y, w, c, t, deltaT):\n") - outputFile.write("\t\ty_new = odeint(self.ratefunc, y, jnp.array([t, t + deltaT]), w, c, atol=self.atol, rtol=self.rtol, mxstep=self.mxstep)[-1]\t\n") - outputFile.write("\t\tt_new = t + deltaT\t\n") - outputFile.write("\t\tw_new = self.assignmentfunc(y_new, w, c, t_new)\t\n") - outputFile.write("\t\treturn y_new, w_new, c, t_new\t\n\n") + outputFile.write("\t\tif self.solver_type == 'odeint':\n") + outputFile.write("\t\t\ty_new = odeint(self.ratefunc, y, jnp.array([t, t + deltaT]), w, c, atol=self.atol, rtol=self.rtol, mxstep=self.mxstep)[-1]\n") + outputFile.write("\t\telse: # diffrax\n") + outputFile.write("\t\t\tterm = ODETerm(lambda t, y, args: self.ratefunc(y, t, *args))\n") + outputFile.write("\t\t\ttprev, tnext = t, t + deltaT\n") + outputFile.write("\t\t\tstate = self.solver.init(term, tprev, tnext, y, (w, c))\n") + outputFile.write("\t\t\ty_new, _, _, _, _ = self.solver.step(term, tprev, tnext, y, (w, c), state, made_jump=False)\n") + outputFile.write("\t\tt_new = t + deltaT\n") + outputFile.write("\t\tw_new = self.assignmentfunc(y_new, w, c, t_new)\n") + outputFile.write("\t\treturn y_new, w_new, c, t_new\n\n") # ================================================================================================================================ @@ -545,9 +566,9 @@ def ParseRHS(rawRateLaw, extended_param_names=[], reaction_name=None, yvar="y", outputFile.write("\tdeltaT: float = eqx.static_field()\n") outputFile.write(f"\tmodelstepfunc: {ModelStepName}\n\n") - outputFile.write(f"\tdef __init__(self, deltaT={deltaT}, atol={atol}, rtol={rtol}, mxstep={mxstep}):\n\n") + outputFile.write(f"\tdef __init__(self, deltaT={deltaT}, atol={atol}, rtol={rtol}, mxstep={mxstep}, solver_type='{solver_type}', diffrax_solver='{diffrax_solver}'):\n\n") outputFile.write("\t\tself.deltaT = deltaT\n") - outputFile.write(f"\t\tself.modelstepfunc = {ModelStepName}(atol=atol, rtol=rtol, mxstep=mxstep)\n\n") + outputFile.write(f"\t\tself.modelstepfunc = {ModelStepName}(atol=atol, rtol=rtol, mxstep=mxstep, solver_type=solver_type, diffrax_solver=diffrax_solver)\n\n") outputFile.write("\t@partial(jit, static_argnames=(\"n_steps\",))\n") outputFile.write("\tdef __call__(self, n_steps, " @@ -571,3 +592,13 @@ def ParseRHS(rawRateLaw, extended_param_names=[], reaction_name=None, yvar="y", # ================================================================================================================================ outputFile.close() + + + + + + + + + + diff --git a/sbmltoodejax/parse.py b/sbmltoodejax/parse.py index 1aca702..4def921 100644 --- a/sbmltoodejax/parse.py +++ b/sbmltoodejax/parse.py @@ -1,7 +1,6 @@ import libsbml import os -import sbmltoodepy -from tempfile import NamedTemporaryFile +from sbmltoodepy.parse import * def ParseSBMLFile(file: str): """ @@ -45,16 +44,10 @@ def ParseSBMLFile(file: str): """ if os.path.exists(file): - filePath = file - libsbml.readSBML(filePath) doc = libsbml.readSBML(file) else: - tmp_sbml_file = NamedTemporaryFile(suffix=".xml") - with open(tmp_sbml_file.name, 'w') as f: - f.write(file) doc = libsbml.readSBMLFromString(file) - filePath = tmp_sbml_file.name # Raise an Error if SBML error if doc.getNumErrors() > 0: @@ -65,7 +58,31 @@ def ParseSBMLFile(file: str): if model.getNumEvents() > 0: raise NotImplementedError("Events are not handled") - modelData = sbmltoodepy.parse.ParseSBMLFile(filePath) + modelData = sbmltoodepy.dataclasses.ModelData() + for i in range(model.getNumParameters()): + newParameter = ParseParameterAssignment(i, model.getParameter(i)) + modelData.parameters[newParameter.Id] = newParameter + for i in range(model.getNumCompartments()): + newCompartment = ParseCompartment(i, model.getCompartment(i)) + modelData.compartments[newCompartment.Id] = newCompartment + for i in range(model.getNumSpecies()): + newSpecies = ParseSpecies(i, model.getSpecies(i)) + modelData.species[newSpecies.Id] = newSpecies + for i in range(model.getNumFunctionDefinitions()): + newFunction = ParseFunction(i, model.getFunctionDefinition(i)) + modelData.functions[newFunction.Id] = newFunction + for i in range(model.getNumRules()): + newRule = ParseRule(i,model.getRule(i)) + if type(newRule) == sbmltoodepy.dataclasses.AssignmentRuleData: + modelData.assignmentRules[newRule.Id] = newRule + elif type(newRule) == sbmltoodepy.dataclasses.RateRuleData: + modelData.rateRules[newRule.Id] = newRule + for i in range(model.getNumReactions()): + newReaction = ParseReaction(i, model.getReaction(i)) + modelData.reactions[newReaction.Id] = newReaction + for i in range(model.getNumInitialAssignments()): + newAssignment = ParseInitialAssignment(i, model.getInitialAssignment(i)) + modelData.initialAssignments[newAssignment.Id] = newAssignment return modelData diff --git a/sbmltoodejax/utils.py b/sbmltoodejax/utils.py index 6c77732..46bfc55 100644 --- a/sbmltoodejax/utils.py +++ b/sbmltoodejax/utils.py @@ -5,8 +5,9 @@ def generate_biomodel(model_idx, model_fp="jax_model.py", vary_constant_reactants=False, vary_boundary_reactants=False, - deltaT=0.1, atol=1e-6, rtol=1e-12, mxstep=5000000): - """Calls the `sbmltoodejax.modulegeneration.GenerateModel` for a SBML model hosted on the BioModel website and indexed by the provided `model_idx`. + deltaT=0.1, atol=1e-6, rtol=1e-12, mxstep=5000000, + solver_type='odeint', diffrax_solver='Tsit5'): + """Calls the `sbmltoodejax.modulegeneration_3.GenerateModel` for a SBML model hosted on the BioModel website and indexed by the provided `model_idx`. Args: model_idx: either an integer, or a valid model id @@ -15,6 +16,8 @@ def generate_biomodel(model_idx, model_fp="jax_model.py", atol (float, optional): parameter passed to `sbmltoodejax.modulegeneration.GenerateModel`. Default to 1e-6. rtol (float, optional): parameter passed to `sbmltoodejax.modulegeneration.GenerateModel`. Default to 1e-12. mxstep (int, optional): parameter passed to `sbmltoodejax.modulegeneration.GenerateModel`. Default to 5000000. + solver_type (str, optional): parameter passed to `sbmltoodejax.modulegeneration.GenerateModel`. Default to 'odeint'. + diffrax_solver (str, optional): parameter passed to `sbmltoodejax.modulegeneration.GenerateModel`. Default to 'Tsit5'. Returns: model_fp (str): the filepath containing the generated python file @@ -23,14 +26,16 @@ def generate_biomodel(model_idx, model_fp="jax_model.py", model_data = ParseSBMLFile(model_xml_body) GenerateModel(model_data, model_fp, vary_constant_reactants=vary_constant_reactants, vary_boundary_reactants=vary_boundary_reactants, - deltaT=deltaT, atol=atol, rtol=rtol, mxstep=mxstep) + deltaT=deltaT, atol=atol, rtol=rtol, mxstep=mxstep, + solver_type=solver_type, diffrax_solver=diffrax_solver) return model_fp def load_biomodel(model_idx, model_fp="jax_model.py", vary_constant_reactants=False, vary_boundary_reactants=False, - deltaT=0.1, atol=1e-6, rtol=1e-12, mxstep=5000000): + deltaT=0.1, atol=1e-6, rtol=1e-12, mxstep=5000000, + solver_type='odeint', diffrax_solver='Tsit5'): """Calls the generate_biomodel function for a SBML model hosted on the BioModel website and indexed by the provided `model_idx`, then loads and returns the generated `model` module and `y0`, `w0`, `c` variables. @@ -52,7 +57,8 @@ def load_biomodel(model_idx, model_fp="jax_model.py", """ model_fp = generate_biomodel(model_idx, model_fp=model_fp, vary_constant_reactants=vary_constant_reactants, vary_boundary_reactants=vary_boundary_reactants, - deltaT=deltaT, atol=atol, rtol=rtol, mxstep=mxstep) + deltaT=deltaT, atol=atol, rtol=rtol, mxstep=mxstep, + solver_type=solver_type, diffrax_solver=diffrax_solver) spec = importlib.util.spec_from_file_location("JaxModelSpec", model_fp) module = importlib.util.module_from_spec(spec) spec.loader.exec_module(module) @@ -63,3 +69,5 @@ def load_biomodel(model_idx, model_fp="jax_model.py", c = getattr(module, "c") return model, y0, w0, c + + diff --git a/test/jax_model.py b/test/jax_model.py index 4d4eb0f..337cce2 100644 --- a/test/jax_model.py +++ b/test/jax_model.py @@ -3,6 +3,8 @@ from jax import jit, lax, vmap from jax.experimental.ode import odeint import jax.numpy as jnp +from diffrax import ODETerm, Tsit5, Dopri5, Dopri8, Euler, Midpoint, Heun, Bosh3, Ralston +from typing import Any from sbmltoodejax import jaxfuncs @@ -90,34 +92,52 @@ class ModelStep(eqx.Module): rtol: float = eqx.static_field() mxstep: int = eqx.static_field() assignmentfunc: AssignmentRule + solver_type: str = eqx.static_field() + solver: Any = eqx.static_field() - def __init__(self, y_indexes={'MKKK': 0, 'MKKK_P': 1, 'MKK': 2, 'MKK_P': 3, 'MKK_PP': 4, 'MAPK': 5, 'MAPK_P': 6, 'MAPK_PP': 7}, w_indexes={}, c_indexes={'uVol': 0, 'J0_V1': 1, 'J0_Ki': 2, 'J0_n': 3, 'J0_K1': 4, 'J1_V2': 5, 'J1_KK2': 6, 'J2_k3': 7, 'J2_KK3': 8, 'J3_k4': 9, 'J3_KK4': 10, 'J4_V5': 11, 'J4_KK5': 12, 'J5_V6': 13, 'J5_KK6': 14, 'J6_k7': 15, 'J6_KK7': 16, 'J7_k8': 17, 'J7_KK8': 18, 'J8_V9': 19, 'J8_KK9': 20, 'J9_V10': 21, 'J9_KK10': 22}, atol=1e-06, rtol=1e-12, mxstep=5000000): + def __init__(self, y_indexes={'MKKK': 0, 'MKKK_P': 1, 'MKK': 2, 'MKK_P': 3, 'MKK_PP': 4, 'MAPK': 5, 'MAPK_P': 6, 'MAPK_PP': 7}, w_indexes={}, c_indexes={'uVol': 0, 'J0_V1': 1, 'J0_Ki': 2, 'J0_n': 3, 'J0_K1': 4, 'J1_V2': 5, 'J1_KK2': 6, 'J2_k3': 7, 'J2_KK3': 8, 'J3_k4': 9, 'J3_KK4': 10, 'J4_V5': 11, 'J4_KK5': 12, 'J5_V6': 13, 'J5_KK6': 14, 'J6_k7': 15, 'J6_KK7': 16, 'J7_k8': 17, 'J7_KK8': 18, 'J8_V9': 19, 'J8_KK9': 20, 'J9_V10': 21, 'J9_KK10': 22}, atol=1e-06, rtol=1e-12, mxstep=5000000, solver_type='diffrax', diffrax_solver='Dopri8'): self.y_indexes = y_indexes self.w_indexes = w_indexes self.c_indexes = c_indexes - self.ratefunc = RateofSpeciesChange() self.rtol = rtol self.atol = atol self.mxstep = mxstep self.assignmentfunc = AssignmentRule() + self.solver_type = solver_type + if solver_type == 'odeint': + self.solver = odeint + elif solver_type == 'diffrax': + from diffrax import ODETerm, Tsit5, Dopri5, Dopri8, Euler, Midpoint, Heun, Bosh3, Ralston + valid_solvers = {'Tsit5', 'Dopri5', 'Dopri8', 'Euler', 'Midpoint', 'Heun', 'Bosh3', 'Ralston'} + if diffrax_solver not in valid_solvers: + raise ValueError(f'Unknown diffrax solver: {diffrax_solver}') + self.solver = Dopri8() + else: + raise ValueError(f'Unknown solver type: {solver_type}') @jit def __call__(self, y, w, c, t, deltaT): - y_new = odeint(self.ratefunc, y, jnp.array([t, t + deltaT]), w, c, atol=self.atol, rtol=self.rtol, mxstep=self.mxstep)[-1] - t_new = t + deltaT - w_new = self.assignmentfunc(y_new, w, c, t_new) - return y_new, w_new, c, t_new + if self.solver_type == 'odeint': + y_new = odeint(self.ratefunc, y, jnp.array([t, t + deltaT]), w, c, atol=self.atol, rtol=self.rtol, mxstep=self.mxstep)[-1] + else: # diffrax + term = ODETerm(lambda t, y, args: self.ratefunc(y, t, *args)) + tprev, tnext = t, t + deltaT + state = self.solver.init(term, tprev, tnext, y, (w, c)) + y_new, _, _, _, _ = self.solver.step(term, tprev, tnext, y, (w, c), state, made_jump=False) + t_new = t + deltaT + w_new = self.assignmentfunc(y_new, w, c, t_new) + return y_new, w_new, c, t_new class ModelRollout(eqx.Module): deltaT: float = eqx.static_field() modelstepfunc: ModelStep - def __init__(self, deltaT=0.1, atol=1e-06, rtol=1e-12, mxstep=5000000): + def __init__(self, deltaT=0.1, atol=1e-06, rtol=1e-12, mxstep=5000000, solver_type='diffrax', diffrax_solver='Dopri8'): self.deltaT = deltaT - self.modelstepfunc = ModelStep(atol=atol, rtol=rtol, mxstep=mxstep) + self.modelstepfunc = ModelStep(atol=atol, rtol=rtol, mxstep=mxstep, solver_type=solver_type, diffrax_solver=diffrax_solver) @partial(jit, static_argnames=("n_steps",)) def __call__(self, n_steps, y0=jnp.array([90.0, 10.0, 280.0, 10.0, 10.0, 280.0, 10.0, 10.0]), w0=jnp.array([]), c=jnp.array([1.0, 2.5, 9.0, 1.0, 10.0, 0.25, 8.0, 0.025, 15.0, 0.025, 15.0, 0.75, 15.0, 0.75, 15.0, 0.025, 15.0, 0.025, 15.0, 0.5, 15.0, 0.5, 15.0]), t0=0.0):