Skip to content

Commit

Permalink
make diffrax the default solver type (faster)
Browse files Browse the repository at this point in the history
  • Loading branch information
etch4966 committed Dec 6, 2024
1 parent bdb9c4f commit f83e67b
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 15 deletions.
14 changes: 2 additions & 12 deletions sbmltoodejax/modulegeneration.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ def GenerateModel(modelData, outputFilePath,
atol: float=1e-6,
rtol: float = 1e-12,
mxstep: int = 5000000,
solver_type: str = 'odeint',
solver_type: str = 'diffrax',
diffrax_solver: str = 'Tsit5'
):
"""
Expand Down Expand Up @@ -591,14 +591,4 @@ def ParseRHS(rawRateLaw, extended_param_names=[], reaction_name=None, yvar="y",
outputFile.write("\t\treturn ys, ws, ts\n\n")

# ================================================================================================================================
outputFile.close()










outputFile.close()
8 changes: 5 additions & 3 deletions sbmltoodejax/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
def generate_biomodel(model_idx, model_fp="jax_model.py",
vary_constant_reactants=False, vary_boundary_reactants=False,
deltaT=0.1, atol=1e-6, rtol=1e-12, mxstep=5000000,
solver_type='odeint', diffrax_solver='Tsit5'):
solver_type='diffrax', diffrax_solver='Tsit5'):
"""Calls the `sbmltoodejax.modulegeneration_3.GenerateModel` for a SBML model hosted on the BioModel website and indexed by the provided `model_idx`.
Args:
Expand All @@ -16,7 +16,7 @@ def generate_biomodel(model_idx, model_fp="jax_model.py",
atol (float, optional): parameter passed to `sbmltoodejax.modulegeneration.GenerateModel`. Default to 1e-6.
rtol (float, optional): parameter passed to `sbmltoodejax.modulegeneration.GenerateModel`. Default to 1e-12.
mxstep (int, optional): parameter passed to `sbmltoodejax.modulegeneration.GenerateModel`. Default to 5000000.
solver_type (str, optional): parameter passed to `sbmltoodejax.modulegeneration.GenerateModel`. Default to 'odeint'.
solver_type (str, optional): parameter passed to `sbmltoodejax.modulegeneration.GenerateModel`. Default to 'diffrax'.
diffrax_solver (str, optional): parameter passed to `sbmltoodejax.modulegeneration.GenerateModel`. Default to 'Tsit5'.
Returns:
Expand All @@ -35,7 +35,7 @@ def generate_biomodel(model_idx, model_fp="jax_model.py",
def load_biomodel(model_idx, model_fp="jax_model.py",
vary_constant_reactants=False, vary_boundary_reactants=False,
deltaT=0.1, atol=1e-6, rtol=1e-12, mxstep=5000000,
solver_type='odeint', diffrax_solver='Tsit5'):
solver_type='diffrax', diffrax_solver='Tsit5'):
"""Calls the generate_biomodel function for a SBML model hosted on the BioModel website and indexed by the provided `model_idx`,
then loads and returns the generated `model` module and `y0`, `w0`, `c` variables.
Expand All @@ -46,6 +46,8 @@ def load_biomodel(model_idx, model_fp="jax_model.py",
atol (float, optional): parameter passed to `generate_biomodel`. Default to 1e-6.
rtol (float, optional): parameter passed to `generate_biomodel`. Default to 1e-12.
mxstep (int, optional): parameter passed to `generate_biomodel`. Default to 5000000.
solver_type (str, optional): parameter passed to `sbmltoodejax.modulegeneration.GenerateModel`. Default to 'diffrax'.
diffrax_solver (str, optional): parameter passed to `sbmltoodejax.modulegeneration.GenerateModel`. Default to 'Tsit5'.
Returns:
tuple containing
Expand Down

0 comments on commit f83e67b

Please sign in to comment.