Skip to content

Commit

Permalink
not working because of weird class vs instance issue
Browse files Browse the repository at this point in the history
  • Loading branch information
jeandut committed Jan 9, 2024
1 parent 167fa42 commit 095b67b
Showing 1 changed file with 9 additions and 7 deletions.
16 changes: 9 additions & 7 deletions fedeca/strategies/bootstraper.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ def make_bootstrap_strategy(strategy: Strategy, n_bootstraps: Union[int, None] =
obj = strategy.algo
key = "algo"
method_args_dict = inspect.signature(getattr(obj, method_name)).parameters
if not (("shared_states" in method_args_dict) or ("shared_state" in method_args_dict)):
if not (("shared_states" in method_args_dict) or ("shared_state" in method_args_dict)) or (method_name in ["save_local_state", "load_local_state"]):
continue
# We create a copy of all methods with the original suffix to avoid name
# collision and infinite recursion when decorating the old methods
Expand All @@ -74,7 +74,8 @@ def make_bootstrap_strategy(strategy: Strategy, n_bootstraps: Union[int, None] =
elif "shared_states" in method_args_dict:
aggregations_names[key].append(method_name)
else:
raise ValueError("Method {} has a shared_state.s argument but isn't respecting conventions".format(method_name))
if not method_name in ["save_local_state", "load_local_state"]:
raise ValueError("Method {} has a shared_state.s argument but isn't respecting conventions".format(method_name))

# Now we are totally free to modify the original methods inplace
# We need to differentiate between aggregations and local computations
Expand Down Expand Up @@ -124,8 +125,8 @@ def make_bootstrap_strategy(strategy: Strategy, n_bootstraps: Union[int, None] =
# this stems from the algo reinstantiating itself using its class
# doing sthg like my_algo.__class__(**my_algo.kwargs)
for local_computation, local_name in zip(local_computations_fct[key], local_functions_names[key]):
setattr(obj, local_name, types.MethodType(local_computation, obj))
setattr(obj_class, local_name, types.MethodType(local_computation, obj_class))
setattr(obj, local_name, types.MethodType(local_computation, obj))


for agg_fct, agg_name in zip(aggregations_fct[key], aggregations_names[key]):
Expand All @@ -134,10 +135,10 @@ def make_bootstrap_strategy(strategy: Strategy, n_bootstraps: Union[int, None] =

# We need to hook the load and save state methods to be able to save load
# all bootstrapped states
setattr(strategy.algo, "save_local_state", types.MethodType(_save_all_bootstraps_states, strategy.algo))
setattr(strategy.algo, "save_local_state", types.MethodType(_save_all_bootstraps_states, strategy.algo.__class__))
setattr(strategy.algo, "load_local_state", types.MethodType(_load_all_bootstraps_states, strategy.algo))
setattr(strategy.algo, "load_local_state", types.MethodType(_load_all_bootstraps_states, strategy.algo.__class__))
setattr(strategy.algo, "save_local_state", types.MethodType(_save_all_bootstraps_states(), strategy.algo))
setattr(strategy.algo, "save_local_state", types.MethodType(_save_all_bootstraps_states(), strategy.algo.__class__))
setattr(strategy.algo, "load_local_state", types.MethodType(_load_all_bootstraps_states(), strategy.algo))
setattr(strategy.algo, "load_local_state", types.MethodType(_load_all_bootstraps_states(), strategy.algo.__class__))
if not inplace:
return strategy

Expand Down Expand Up @@ -301,6 +302,7 @@ def load_local_state(self, path: Path) -> "TorchAlgo":
# This first call is needed when no bootstrap has been done
# self.load_local_state_original(path)
raise ValueError("No bootstrap has been done yet we cannot load states")
return load_local_state


def _save_all_bootstraps_states():
Expand Down

0 comments on commit 095b67b

Please sign in to comment.