diff --git a/fedeca/strategies/bootstraper.py b/fedeca/strategies/bootstraper.py index 5ad18b02..789e14f7 100644 --- a/fedeca/strategies/bootstraper.py +++ b/fedeca/strategies/bootstraper.py @@ -62,7 +62,7 @@ def make_bootstrap_strategy(strategy: Strategy, n_bootstraps: Union[int, None] = obj = strategy.algo key = "algo" method_args_dict = inspect.signature(getattr(obj, method_name)).parameters - if not (("shared_states" in method_args_dict) or ("shared_state" in method_args_dict)): + if not (("shared_states" in method_args_dict) or ("shared_state" in method_args_dict)) or (method_name in ["save_local_state", "load_local_state"]): continue # We create a copy of all methods with the original suffix to avoid name # collision and infinite recursion when decorating the old methods @@ -74,7 +74,8 @@ def make_bootstrap_strategy(strategy: Strategy, n_bootstraps: Union[int, None] = elif "shared_states" in method_args_dict: aggregations_names[key].append(method_name) else: - raise ValueError("Method {} has a shared_state.s argument but isn't respecting conventions".format(method_name)) + if not method_name in ["save_local_state", "load_local_state"]: + raise ValueError("Method {} has a shared_state.s argument but isn't respecting conventions".format(method_name)) # Now we are totally free to modify the original methods inplace # We need to differentiate between aggregations and local computations @@ -124,8 +125,8 @@ def make_bootstrap_strategy(strategy: Strategy, n_bootstraps: Union[int, None] = # this stems from the algo reinstantiating itself using its class # doing sthg like my_algo.__class__(**my_algo.kwargs) for local_computation, local_name in zip(local_computations_fct[key], local_functions_names[key]): - setattr(obj, local_name, types.MethodType(local_computation, obj)) setattr(obj_class, local_name, types.MethodType(local_computation, obj_class)) + setattr(obj, local_name, types.MethodType(local_computation, obj)) for agg_fct, agg_name in zip(aggregations_fct[key], aggregations_names[key]): @@ -134,10 +135,10 @@ def make_bootstrap_strategy(strategy: Strategy, n_bootstraps: Union[int, None] = # We need to hook the load and save state methods to be able to save load # all bootstrapped states - setattr(strategy.algo, "save_local_state", types.MethodType(_save_all_bootstraps_states, strategy.algo)) - setattr(strategy.algo, "save_local_state", types.MethodType(_save_all_bootstraps_states, strategy.algo.__class__)) - setattr(strategy.algo, "load_local_state", types.MethodType(_load_all_bootstraps_states, strategy.algo)) - setattr(strategy.algo, "load_local_state", types.MethodType(_load_all_bootstraps_states, strategy.algo.__class__)) + setattr(strategy.algo, "save_local_state", types.MethodType(_save_all_bootstraps_states(), strategy.algo)) + setattr(strategy.algo, "save_local_state", types.MethodType(_save_all_bootstraps_states(), strategy.algo.__class__)) + setattr(strategy.algo, "load_local_state", types.MethodType(_load_all_bootstraps_states(), strategy.algo)) + setattr(strategy.algo, "load_local_state", types.MethodType(_load_all_bootstraps_states(), strategy.algo.__class__)) if not inplace: return strategy @@ -301,6 +302,7 @@ def load_local_state(self, path: Path) -> "TorchAlgo": # This first call is needed when no bootstrap has been done # self.load_local_state_original(path) raise ValueError("No bootstrap has been done yet we cannot load states") + return load_local_state def _save_all_bootstraps_states():