diff --git a/bayeux/_src/optimize/optimistix.py b/bayeux/_src/optimize/optimistix.py index 4be810d..de4ccdc 100644 --- a/bayeux/_src/optimize/optimistix.py +++ b/bayeux/_src/optimize/optimistix.py @@ -70,41 +70,11 @@ class BFGS(_OptimistixOptimizer): optimizer = "BFGS" -class Chord(_OptimistixOptimizer): - name = "optimistix_chord" - optimizer = "Chord" - - -class Dogleg(_OptimistixOptimizer): - name = "optimistix_dogleg" - optimizer = "Dogleg" - - -class GaussNewton(_OptimistixOptimizer): - name = "optimistix_gauss_newton" - optimizer = "GaussNewton" - - -class IndirectLevenbergMarquardt(_OptimistixOptimizer): - name = "optimistix_indirect_levenberg_marquardt" - optimizer = "IndirectLevenbergMarquardt" - - -class LevenbergMarquardt(_OptimistixOptimizer): - name = "optimistix_levenberg_marquardt" - optimizer = "LevenbergMarquardt" - - class NelderMead(_OptimistixOptimizer): name = "optimistix_nelder_mead" optimizer = "NelderMead" -class Newton(_OptimistixOptimizer): - name = "optimistix_newton" - optimizer = "Newton" - - class NonlinearCG(_OptimistixOptimizer): name = "optimistix_nonlinear_cg" optimizer = "NonlinearCG" diff --git a/bayeux/_src/optimize/shared.py b/bayeux/_src/optimize/shared.py index 17119d9..c6d787f 100644 --- a/bayeux/_src/optimize/shared.py +++ b/bayeux/_src/optimize/shared.py @@ -51,6 +51,8 @@ def get_optimizer_kwargs(optimizer, kwargs, ignore_required=None): f"{','.join(optimizer_required)}. Probably file a bug, but " "you can try to manually supply them as keywords." ) + optimizer_kwargs.update( + {k: kwargs[k] for k in optimizer_kwargs if k in kwargs}) return optimizer_kwargs diff --git a/bayeux/optimize/__init__.py b/bayeux/optimize/__init__.py index 458feb9..df1fb5b 100644 --- a/bayeux/optimize/__init__.py +++ b/bayeux/optimize/__init__.py @@ -26,28 +26,6 @@ from bayeux._src.optimize.jaxopt import NonlinearCG __all__.extend(["BFGS", "GradientDescent", "LBFGS", "NonlinearCG"]) -if importlib.util.find_spec("optimistix") is not None: - from bayeux._src.optimize.optimistix import BFGS as optimistix_BFGS - from bayeux._src.optimize.optimistix import Chord - from bayeux._src.optimize.optimistix import Dogleg - from bayeux._src.optimize.optimistix import GaussNewton - from bayeux._src.optimize.optimistix import IndirectLevenbergMarquardt - from bayeux._src.optimize.optimistix import LevenbergMarquardt - from bayeux._src.optimize.optimistix import NelderMead - from bayeux._src.optimize.optimistix import Newton - from bayeux._src.optimize.optimistix import NonlinearCG as optimistix_NonlinearCG - - __all__.extend([ - "optimistix_BFGS", - "Chord", - "Dogleg", - "GaussNewton", - "IndirectLevenbergMarquardt", - "LevenbergMarquardt", - "NelderMead", - "Newton", - "optimistix_NonlinearCG"]) - if importlib.util.find_spec("optax") is not None: from bayeux._src.optimize.optax import AdaBelief from bayeux._src.optimize.optax import Adafactor @@ -92,3 +70,13 @@ "Sm3", "Yogi", ]) + +if importlib.util.find_spec("optimistix") is not None: + from bayeux._src.optimize.optimistix import BFGS as optimistix_BFGS + from bayeux._src.optimize.optimistix import NelderMead + from bayeux._src.optimize.optimistix import NonlinearCG as optimistix_NonlinearCG + + __all__.extend([ + "optimistix_BFGS", + "NelderMead", + "optimistix_NonlinearCG"])