Skip to content

Commit

Permalink
Adding N-level arguments functionality
Browse files Browse the repository at this point in the history
Signed-off-by: Parth Mandaliya <parthx.mandaliya@intel.com>
  • Loading branch information
ParthM-GitHub committed Sep 28, 2023
1 parent 880b7a3 commit 52f0b52
Show file tree
Hide file tree
Showing 4 changed files with 64 additions and 42 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,11 @@ collaborator :


federated_flow:
template: src.flow.FederatedFlow
template: src.flow.MNISTFlow
settings:
model: src.flow.Net
model:
template: src.flow.Net
settings: {}
optimizer: null
rounds: 4
checkpoint: true
Expand Down
22 changes: 11 additions & 11 deletions openfl-workspace/experimental/101_torch_cnn_mnist/src/flow.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
# Copyright (C) 2020-2023 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

from openfl.experimental.interface import FLSpec
from openfl.experimental.placement import aggregator, collaborator
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
Expand Down Expand Up @@ -59,11 +61,7 @@ def inference(network, test_loader):
return accuracy


from openfl.experimental.interface import FLSpec
from openfl.experimental.placement import aggregator, collaborator


def FedAvg(models):
def fedavg(models):
new_model = models[0]
state_dicts = [model.state_dict() for model in models]
state_dict = new_model.state_dict()
Expand All @@ -75,7 +73,7 @@ def FedAvg(models):
return new_model


class FederatedFlow(FLSpec):
class MNISTFlow(FLSpec):
def __init__(self, model=None, optimizer=None, rounds=3, **kwargs):
super().__init__(**kwargs)
if model is not None:
Expand All @@ -90,7 +88,7 @@ def __init__(self, model=None, optimizer=None, rounds=3, **kwargs):

@aggregator
def start(self):
print(f"Performing initialization for model")
print("Performing initialization for model")
self.collaborators = self.runtime.collaborators
self.private = 10
self.current_round = 0
Expand Down Expand Up @@ -138,7 +136,8 @@ def train(self):
def local_model_validation(self):
self.local_validation_score = inference(self.model, self.test_loader)
print(
f"Doing local model validation for collaborator {self.input}: {self.local_validation_score}"
"Doing local model validation "
+ f"for collaborator {self.input}: {self.local_validation_score}"
)
self.next(self.join, exclude=["training_completed"])

Expand All @@ -152,11 +151,12 @@ def join(self, inputs):
input.local_validation_score for input in inputs
) / len(inputs)
print(
f"Average aggregated model validation values = {self.aggregated_model_accuracy}"
"Average aggregated model "
+ f"validation values = {self.aggregated_model_accuracy}"
)
print(f"Average training loss = {self.average_loss}")
print(f"Average local model validation values = {self.local_model_accuracy}")
self.model = FedAvg([input.model for input in inputs])
self.model = fedavg([input.model for input in inputs])
self.optimizer = [input.optimizer for input in inputs][0]
self.next(self.internal_loop)

Expand All @@ -174,4 +174,4 @@ def internal_loop(self):

@aggregator
def end(self):
print(f"This is the end of the flow")
print("This is the end of the flow")
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,14 @@

writer = None


def get_writer():
"""Create global writer object."""
global writer
if not writer:
writer = SummaryWriter('./logs/cnn_mnist', flush_secs=5)


def write_metric(node_name, task_name, metric_name, metric, round_number):
"""Write metric callback."""
get_writer()
Expand Down
76 changes: 47 additions & 29 deletions openfl/experimental/federated/plan/plan.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,21 @@ def parse(
)
raise

@staticmethod
def accept_args(cls):
"""
Determines whether a class's constructor (__init__ method) accepts
variable positional arguments (*args).
Returns:
Boolean: True or False
"""
init_signature = inspect.signature(cls.__init__)
for param in init_signature.parameters.values():
if param.kind == param.VAR_POSITIONAL:
return True
return False

@staticmethod
def build(template, settings, **override):
"""
Expand All @@ -166,9 +181,13 @@ def build(template, settings, **override):
Plan.logger.debug(f"Override [red]🡆[/] {override}", extra={"markup": True})

settings.update(**override)

module = import_module(module_path)
instance = getattr(module, class_name)(**settings)

if Plan.accept_args(getattr(module, class_name)):
args = list(settings.values())
instance = getattr(module, class_name)(*args)
else:
instance = getattr(module, class_name)(**settings)

return instance

Expand Down Expand Up @@ -386,33 +405,32 @@ def get_flow(self):
return self.flow_

def import_kwargs_modules(self, defaults):
for key in defaults[SETTINGS]:
value_defaults = defaults[SETTINGS][key]
if isinstance(value_defaults, str):
class_name = splitext(value_defaults)[1].strip(".")
if class_name:
module_path = splitext(value_defaults)[0]
try:
if import_module(module_path):
module = import_module(module_path)
value_defaults_data = {
TEMPLATE: value_defaults,
SETTINGS: {},
}
attr = getattr(module, class_name)

if not inspect.isclass(attr):
self.logger.info(
"Setting private attributes from variable"
)
defaults[SETTINGS][key] = attr
else:
self.logger.info("Setting private from class")
defaults[SETTINGS][key] = Plan.build(
**value_defaults_data
)
except Exception:
raise ImportError(f"Cannot import {value_defaults}.")
def import_nested_settings(settings):
for key, value in settings.items():
if isinstance(value, dict):
settings[key] = import_nested_settings(value)
elif isinstance(value, str):
class_name = splitext(value)[1].strip(".")
if class_name:
module_path = splitext(value)[0]
try:
if import_module(module_path):
module = import_module(module_path)
value_defaults_data = {
'template': value,
'settings': settings.get('settings', {}),
}
attr = getattr(module, class_name)

if not inspect.isclass(attr):
settings[key] = attr
else:
settings = Plan.build(**value_defaults_data)
except ImportError:
raise ImportError(f"Cannot import {value}.")
return settings

defaults[SETTINGS] = import_nested_settings(defaults[SETTINGS])
return defaults

def get_private_attr(self, private_attr_name=None):
Expand Down

0 comments on commit 52f0b52

Please sign in to comment.