Skip to content

Commit

Permalink
Merge pull request #7 from DylanEsguerra/my-feature
Browse files Browse the repository at this point in the history
Diffrax Support for Explicit Runge--Kutta (ERK) methods
  • Loading branch information
mayalenE authored Dec 4, 2024
2 parents c709183 + f23a8aa commit 86e5f95
Show file tree
Hide file tree
Showing 4 changed files with 113 additions and 37 deletions.
61 changes: 46 additions & 15 deletions sbmltoodejax/modulegeneration.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,9 @@ def GenerateModel(modelData, outputFilePath,
deltaT: float =0.1,
atol: float=1e-6,
rtol: float = 1e-12,
mxstep: int = 5000000
mxstep: int = 5000000,
solver_type: str = 'odeint',
diffrax_solver: str = 'Tsit5'
):
"""
This function takes model data created by :func:`~sbmltoodejax.parse.ParseSBMLFile` and generates a python file containing
Expand Down Expand Up @@ -125,7 +127,9 @@ def GenerateModel(modelData, outputFilePath,
outputFile.write("from functools import partial\n")
outputFile.write("from jax import jit, lax, vmap\n")
outputFile.write("from jax.experimental.ode import odeint\n")
outputFile.write("import jax.numpy as jnp\n\n")
outputFile.write("import jax.numpy as jnp\n")
outputFile.write("from diffrax import ODETerm, Tsit5, Dopri5, Dopri8, Euler, Midpoint, Heun, Bosh3, Ralston\n")
outputFile.write("from typing import Any\n\n")
outputFile.write("from sbmltoodejax import jaxfuncs\n\n")


Expand Down Expand Up @@ -512,42 +516,59 @@ def ParseRHS(rawRateLaw, extended_param_names=[], reaction_name=None, yvar="y",
outputFile.write("\tatol: float = eqx.static_field()\n")
outputFile.write("\trtol: float = eqx.static_field()\n")
outputFile.write("\tmxstep: int = eqx.static_field()\n")
outputFile.write(f"\tassignmentfunc: {AssignmentRuleName}\n\n")
outputFile.write(f"\tassignmentfunc: {AssignmentRuleName}\n")
outputFile.write("\tsolver_type: str = eqx.static_field()\n")
outputFile.write("\tsolver: Any = eqx.static_field()\n\n")

outputFile.write(f"\tdef __init__(self, "
f"y_indexes={y_indexes}, "
f"w_indexes={w_indexes}, "
f"c_indexes={c_indexes}, "
f"atol={atol}, rtol={rtol}, mxstep={mxstep}):\n\n")
f"atol={atol}, rtol={rtol}, mxstep={mxstep}, "
f"solver_type='{solver_type}', diffrax_solver='{diffrax_solver}'):\n\n")

outputFile.write("\t\tself.y_indexes = y_indexes\n")
outputFile.write("\t\tself.w_indexes = w_indexes\n")
outputFile.write("\t\tself.c_indexes = c_indexes\n\n")

outputFile.write("\t\tself.c_indexes = c_indexes\n")
outputFile.write(f"\t\tself.ratefunc = {RateofSpeciesChangeName}()\n")
outputFile.write("\t\tself.rtol = rtol\n")
outputFile.write("\t\tself.atol = atol\n")
outputFile.write("\t\tself.mxstep = mxstep\n")

outputFile.write(f"\t\tself.assignmentfunc = {AssignmentRuleName}()\n\n")

outputFile.write(f"\t\tself.assignmentfunc = {AssignmentRuleName}()\n")
outputFile.write("\t\tself.solver_type = solver_type\n")
outputFile.write("\t\tif solver_type == 'odeint':\n")
outputFile.write("\t\t\tself.solver = odeint\n")
outputFile.write("\t\telif solver_type == 'diffrax':\n")
outputFile.write("\t\t\tfrom diffrax import ODETerm, Tsit5, Dopri5, Dopri8, Euler, Midpoint, Heun, Bosh3, Ralston\n")
outputFile.write("\t\t\tvalid_solvers = {'Tsit5', 'Dopri5', 'Dopri8', 'Euler', 'Midpoint', 'Heun', 'Bosh3', 'Ralston'}\n")
outputFile.write("\t\t\tif diffrax_solver not in valid_solvers:\n")
outputFile.write("\t\t\t\traise ValueError(f'Unknown diffrax solver: {diffrax_solver}')\n")
outputFile.write(f"\t\t\tself.solver = {diffrax_solver}()\n")
outputFile.write("\t\telse:\n")
outputFile.write("\t\t\traise ValueError(f'Unknown solver type: {solver_type}')\n\n")

outputFile.write("\t@jit\n")
outputFile.write("\tdef __call__(self, y, w, c, t, deltaT):\n")
outputFile.write("\t\ty_new = odeint(self.ratefunc, y, jnp.array([t, t + deltaT]), w, c, atol=self.atol, rtol=self.rtol, mxstep=self.mxstep)[-1]\t\n")
outputFile.write("\t\tt_new = t + deltaT\t\n")
outputFile.write("\t\tw_new = self.assignmentfunc(y_new, w, c, t_new)\t\n")
outputFile.write("\t\treturn y_new, w_new, c, t_new\t\n\n")
outputFile.write("\t\tif self.solver_type == 'odeint':\n")
outputFile.write("\t\t\ty_new = odeint(self.ratefunc, y, jnp.array([t, t + deltaT]), w, c, atol=self.atol, rtol=self.rtol, mxstep=self.mxstep)[-1]\n")
outputFile.write("\t\telse: # diffrax\n")
outputFile.write("\t\t\tterm = ODETerm(lambda t, y, args: self.ratefunc(y, t, *args))\n")
outputFile.write("\t\t\ttprev, tnext = t, t + deltaT\n")
outputFile.write("\t\t\tstate = self.solver.init(term, tprev, tnext, y, (w, c))\n")
outputFile.write("\t\t\ty_new, _, _, _, _ = self.solver.step(term, tprev, tnext, y, (w, c), state, made_jump=False)\n")
outputFile.write("\t\tt_new = t + deltaT\n")
outputFile.write("\t\tw_new = self.assignmentfunc(y_new, w, c, t_new)\n")
outputFile.write("\t\treturn y_new, w_new, c, t_new\n\n")

# ================================================================================================================================

outputFile.write("class " + ModelRolloutName + "(eqx.Module):\n")
outputFile.write("\tdeltaT: float = eqx.static_field()\n")
outputFile.write(f"\tmodelstepfunc: {ModelStepName}\n\n")

outputFile.write(f"\tdef __init__(self, deltaT={deltaT}, atol={atol}, rtol={rtol}, mxstep={mxstep}):\n\n")
outputFile.write(f"\tdef __init__(self, deltaT={deltaT}, atol={atol}, rtol={rtol}, mxstep={mxstep}, solver_type='{solver_type}', diffrax_solver='{diffrax_solver}'):\n\n")
outputFile.write("\t\tself.deltaT = deltaT\n")
outputFile.write(f"\t\tself.modelstepfunc = {ModelStepName}(atol=atol, rtol=rtol, mxstep=mxstep)\n\n")
outputFile.write(f"\t\tself.modelstepfunc = {ModelStepName}(atol=atol, rtol=rtol, mxstep=mxstep, solver_type=solver_type, diffrax_solver=diffrax_solver)\n\n")

outputFile.write("\t@partial(jit, static_argnames=(\"n_steps\",))\n")
outputFile.write("\tdef __call__(self, n_steps, "
Expand All @@ -571,3 +592,13 @@ def ParseRHS(rawRateLaw, extended_param_names=[], reaction_name=None, yvar="y",

# ================================================================================================================================
outputFile.close()










35 changes: 26 additions & 9 deletions sbmltoodejax/parse.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import libsbml
import os
import sbmltoodepy
from tempfile import NamedTemporaryFile
from sbmltoodepy.parse import *

def ParseSBMLFile(file: str):
"""
Expand Down Expand Up @@ -45,16 +44,10 @@ def ParseSBMLFile(file: str):
"""

if os.path.exists(file):
filePath = file
libsbml.readSBML(filePath)
doc = libsbml.readSBML(file)

else:
tmp_sbml_file = NamedTemporaryFile(suffix=".xml")
with open(tmp_sbml_file.name, 'w') as f:
f.write(file)
doc = libsbml.readSBMLFromString(file)
filePath = tmp_sbml_file.name

# Raise an Error if SBML error
if doc.getNumErrors() > 0:
Expand All @@ -65,7 +58,31 @@ def ParseSBMLFile(file: str):
if model.getNumEvents() > 0:
raise NotImplementedError("Events are not handled")

modelData = sbmltoodepy.parse.ParseSBMLFile(filePath)
modelData = sbmltoodepy.dataclasses.ModelData()
for i in range(model.getNumParameters()):
newParameter = ParseParameterAssignment(i, model.getParameter(i))
modelData.parameters[newParameter.Id] = newParameter
for i in range(model.getNumCompartments()):
newCompartment = ParseCompartment(i, model.getCompartment(i))
modelData.compartments[newCompartment.Id] = newCompartment
for i in range(model.getNumSpecies()):
newSpecies = ParseSpecies(i, model.getSpecies(i))
modelData.species[newSpecies.Id] = newSpecies
for i in range(model.getNumFunctionDefinitions()):
newFunction = ParseFunction(i, model.getFunctionDefinition(i))
modelData.functions[newFunction.Id] = newFunction
for i in range(model.getNumRules()):
newRule = ParseRule(i,model.getRule(i))
if type(newRule) == sbmltoodepy.dataclasses.AssignmentRuleData:
modelData.assignmentRules[newRule.Id] = newRule
elif type(newRule) == sbmltoodepy.dataclasses.RateRuleData:
modelData.rateRules[newRule.Id] = newRule
for i in range(model.getNumReactions()):
newReaction = ParseReaction(i, model.getReaction(i))
modelData.reactions[newReaction.Id] = newReaction
for i in range(model.getNumInitialAssignments()):
newAssignment = ParseInitialAssignment(i, model.getInitialAssignment(i))
modelData.initialAssignments[newAssignment.Id] = newAssignment

return modelData

Expand Down
18 changes: 13 additions & 5 deletions sbmltoodejax/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,9 @@

def generate_biomodel(model_idx, model_fp="jax_model.py",
vary_constant_reactants=False, vary_boundary_reactants=False,
deltaT=0.1, atol=1e-6, rtol=1e-12, mxstep=5000000):
"""Calls the `sbmltoodejax.modulegeneration.GenerateModel` for a SBML model hosted on the BioModel website and indexed by the provided `model_idx`.
deltaT=0.1, atol=1e-6, rtol=1e-12, mxstep=5000000,
solver_type='odeint', diffrax_solver='Tsit5'):
"""Calls the `sbmltoodejax.modulegeneration_3.GenerateModel` for a SBML model hosted on the BioModel website and indexed by the provided `model_idx`.
Args:
model_idx: either an integer, or a valid model id
Expand All @@ -15,6 +16,8 @@ def generate_biomodel(model_idx, model_fp="jax_model.py",
atol (float, optional): parameter passed to `sbmltoodejax.modulegeneration.GenerateModel`. Default to 1e-6.
rtol (float, optional): parameter passed to `sbmltoodejax.modulegeneration.GenerateModel`. Default to 1e-12.
mxstep (int, optional): parameter passed to `sbmltoodejax.modulegeneration.GenerateModel`. Default to 5000000.
solver_type (str, optional): parameter passed to `sbmltoodejax.modulegeneration.GenerateModel`. Default to 'odeint'.
diffrax_solver (str, optional): parameter passed to `sbmltoodejax.modulegeneration.GenerateModel`. Default to 'Tsit5'.
Returns:
model_fp (str): the filepath containing the generated python file
Expand All @@ -23,14 +26,16 @@ def generate_biomodel(model_idx, model_fp="jax_model.py",
model_data = ParseSBMLFile(model_xml_body)
GenerateModel(model_data, model_fp,
vary_constant_reactants=vary_constant_reactants, vary_boundary_reactants=vary_boundary_reactants,
deltaT=deltaT, atol=atol, rtol=rtol, mxstep=mxstep)
deltaT=deltaT, atol=atol, rtol=rtol, mxstep=mxstep,
solver_type=solver_type, diffrax_solver=diffrax_solver)

return model_fp


def load_biomodel(model_idx, model_fp="jax_model.py",
vary_constant_reactants=False, vary_boundary_reactants=False,
deltaT=0.1, atol=1e-6, rtol=1e-12, mxstep=5000000):
deltaT=0.1, atol=1e-6, rtol=1e-12, mxstep=5000000,
solver_type='odeint', diffrax_solver='Tsit5'):
"""Calls the generate_biomodel function for a SBML model hosted on the BioModel website and indexed by the provided `model_idx`,
then loads and returns the generated `model` module and `y0`, `w0`, `c` variables.
Expand All @@ -52,7 +57,8 @@ def load_biomodel(model_idx, model_fp="jax_model.py",
"""
model_fp = generate_biomodel(model_idx, model_fp=model_fp,
vary_constant_reactants=vary_constant_reactants, vary_boundary_reactants=vary_boundary_reactants,
deltaT=deltaT, atol=atol, rtol=rtol, mxstep=mxstep)
deltaT=deltaT, atol=atol, rtol=rtol, mxstep=mxstep,
solver_type=solver_type, diffrax_solver=diffrax_solver)
spec = importlib.util.spec_from_file_location("JaxModelSpec", model_fp)
module = importlib.util.module_from_spec(spec)
spec.loader.exec_module(module)
Expand All @@ -63,3 +69,5 @@ def load_biomodel(model_idx, model_fp="jax_model.py",
c = getattr(module, "c")

return model, y0, w0, c


36 changes: 28 additions & 8 deletions test/jax_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
from jax import jit, lax, vmap
from jax.experimental.ode import odeint
import jax.numpy as jnp
from diffrax import ODETerm, Tsit5, Dopri5, Dopri8, Euler, Midpoint, Heun, Bosh3, Ralston
from typing import Any

from sbmltoodejax import jaxfuncs

Expand Down Expand Up @@ -90,34 +92,52 @@ class ModelStep(eqx.Module):
rtol: float = eqx.static_field()
mxstep: int = eqx.static_field()
assignmentfunc: AssignmentRule
solver_type: str = eqx.static_field()
solver: Any = eqx.static_field()

def __init__(self, y_indexes={'MKKK': 0, 'MKKK_P': 1, 'MKK': 2, 'MKK_P': 3, 'MKK_PP': 4, 'MAPK': 5, 'MAPK_P': 6, 'MAPK_PP': 7}, w_indexes={}, c_indexes={'uVol': 0, 'J0_V1': 1, 'J0_Ki': 2, 'J0_n': 3, 'J0_K1': 4, 'J1_V2': 5, 'J1_KK2': 6, 'J2_k3': 7, 'J2_KK3': 8, 'J3_k4': 9, 'J3_KK4': 10, 'J4_V5': 11, 'J4_KK5': 12, 'J5_V6': 13, 'J5_KK6': 14, 'J6_k7': 15, 'J6_KK7': 16, 'J7_k8': 17, 'J7_KK8': 18, 'J8_V9': 19, 'J8_KK9': 20, 'J9_V10': 21, 'J9_KK10': 22}, atol=1e-06, rtol=1e-12, mxstep=5000000):
def __init__(self, y_indexes={'MKKK': 0, 'MKKK_P': 1, 'MKK': 2, 'MKK_P': 3, 'MKK_PP': 4, 'MAPK': 5, 'MAPK_P': 6, 'MAPK_PP': 7}, w_indexes={}, c_indexes={'uVol': 0, 'J0_V1': 1, 'J0_Ki': 2, 'J0_n': 3, 'J0_K1': 4, 'J1_V2': 5, 'J1_KK2': 6, 'J2_k3': 7, 'J2_KK3': 8, 'J3_k4': 9, 'J3_KK4': 10, 'J4_V5': 11, 'J4_KK5': 12, 'J5_V6': 13, 'J5_KK6': 14, 'J6_k7': 15, 'J6_KK7': 16, 'J7_k8': 17, 'J7_KK8': 18, 'J8_V9': 19, 'J8_KK9': 20, 'J9_V10': 21, 'J9_KK10': 22}, atol=1e-06, rtol=1e-12, mxstep=5000000, solver_type='diffrax', diffrax_solver='Dopri8'):

self.y_indexes = y_indexes
self.w_indexes = w_indexes
self.c_indexes = c_indexes

self.ratefunc = RateofSpeciesChange()
self.rtol = rtol
self.atol = atol
self.mxstep = mxstep
self.assignmentfunc = AssignmentRule()
self.solver_type = solver_type
if solver_type == 'odeint':
self.solver = odeint
elif solver_type == 'diffrax':
from diffrax import ODETerm, Tsit5, Dopri5, Dopri8, Euler, Midpoint, Heun, Bosh3, Ralston
valid_solvers = {'Tsit5', 'Dopri5', 'Dopri8', 'Euler', 'Midpoint', 'Heun', 'Bosh3', 'Ralston'}
if diffrax_solver not in valid_solvers:
raise ValueError(f'Unknown diffrax solver: {diffrax_solver}')
self.solver = Dopri8()
else:
raise ValueError(f'Unknown solver type: {solver_type}')

@jit
def __call__(self, y, w, c, t, deltaT):
y_new = odeint(self.ratefunc, y, jnp.array([t, t + deltaT]), w, c, atol=self.atol, rtol=self.rtol, mxstep=self.mxstep)[-1]
t_new = t + deltaT
w_new = self.assignmentfunc(y_new, w, c, t_new)
return y_new, w_new, c, t_new
if self.solver_type == 'odeint':
y_new = odeint(self.ratefunc, y, jnp.array([t, t + deltaT]), w, c, atol=self.atol, rtol=self.rtol, mxstep=self.mxstep)[-1]
else: # diffrax
term = ODETerm(lambda t, y, args: self.ratefunc(y, t, *args))
tprev, tnext = t, t + deltaT
state = self.solver.init(term, tprev, tnext, y, (w, c))
y_new, _, _, _, _ = self.solver.step(term, tprev, tnext, y, (w, c), state, made_jump=False)
t_new = t + deltaT
w_new = self.assignmentfunc(y_new, w, c, t_new)
return y_new, w_new, c, t_new

class ModelRollout(eqx.Module):
deltaT: float = eqx.static_field()
modelstepfunc: ModelStep

def __init__(self, deltaT=0.1, atol=1e-06, rtol=1e-12, mxstep=5000000):
def __init__(self, deltaT=0.1, atol=1e-06, rtol=1e-12, mxstep=5000000, solver_type='diffrax', diffrax_solver='Dopri8'):

self.deltaT = deltaT
self.modelstepfunc = ModelStep(atol=atol, rtol=rtol, mxstep=mxstep)
self.modelstepfunc = ModelStep(atol=atol, rtol=rtol, mxstep=mxstep, solver_type=solver_type, diffrax_solver=diffrax_solver)

@partial(jit, static_argnames=("n_steps",))
def __call__(self, n_steps, y0=jnp.array([90.0, 10.0, 280.0, 10.0, 10.0, 280.0, 10.0, 10.0]), w0=jnp.array([]), c=jnp.array([1.0, 2.5, 9.0, 1.0, 10.0, 0.25, 8.0, 0.025, 15.0, 0.025, 15.0, 0.75, 15.0, 0.75, 15.0, 0.025, 15.0, 0.025, 15.0, 0.5, 15.0, 0.5, 15.0]), t0=0.0):
Expand Down

0 comments on commit 86e5f95

Please sign in to comment.