Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use cloudpickle for serializing NegLogParameterPriors #1467

Merged
merged 7 commits into from
Sep 26, 2024
9 changes: 9 additions & 0 deletions pypesto/objective/priors.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from collections.abc import Sequence
from typing import Callable, Union

import cloudpickle
import numpy as np

from .. import C
Expand Down Expand Up @@ -67,6 +68,14 @@ def __init__(
self.prior_list = prior_list
super().__init__(x_names)

def __getstate__(self):
"""Get state using cloudpickle."""
return cloudpickle.dumps(self.__dict__)

def __setstate__(self, state):
"""Set state using cloudpickle."""
self.__dict__.update(cloudpickle.loads(state))

def call_unprocessed(
self,
x: np.ndarray,
Expand Down
16 changes: 16 additions & 0 deletions test/optimize/test_optimize.py
Original file line number Diff line number Diff line change
Expand Up @@ -524,6 +524,22 @@ def test_ess_multiprocess(problem, request):

from pypesto.optimize.ess import ESSOptimizer, FunctionEvaluatorMP, RefSet

# augment objective with parameter prior to check it's copyable
# https://github.com/ICB-DCM/pyPESTO/issues/1465
# https://github.com/ICB-DCM/pyPESTO/pull/1467
problem.objective = pypesto.objective.AggregatedObjective(
[
problem.objective,
pypesto.objective.NegLogParameterPriors(
[
pypesto.objective.get_parameter_prior_dict(
0, "uniform", [0, 1], "lin"
)
]
),
]
)

ess = ESSOptimizer(
max_iter=20,
# also test passing a callable as local_optimizer
Expand Down