Skip to content

Commit

Permalink
Merge branch 'main' into update_monai_example
Browse files Browse the repository at this point in the history
  • Loading branch information
YuanTingHsieh authored Jul 30, 2024
2 parents 7e95853 + 1ef5207 commit c9a4375
Show file tree
Hide file tree
Showing 6 changed files with 201 additions and 80 deletions.
43 changes: 43 additions & 0 deletions examples/getting_started/pt/fedavg_script_executor_cifar10_all.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from src.net import Net

from nvflare import FedAvg, FedJob, ScriptExecutor

if __name__ == "__main__":
n_clients = 2
num_rounds = 2
train_script = "src/cifar10_fl.py"

job = FedJob(name="cifar10_fedavg")

# Define the controller workflow and send to server
controller = FedAvg(
num_clients=n_clients,
num_rounds=num_rounds,
)
job.to_server(controller)

# Define the initial global model and send to server
job.to_server(Net())

# Send executor to all clients
executor = ScriptExecutor(
task_script_path=train_script, task_script_args="" # f"--batch_size 32 --data_path /tmp/data/site-{i}"
)
job.to_clients(executor)

# job.export_job("/tmp/nvflare/jobs/job_config")
job.simulator_run("/tmp/nvflare/jobs/workdir", n_clients=n_clients)
26 changes: 2 additions & 24 deletions nvflare/fuel/f3/cellnet/cell_cipher.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,15 +89,15 @@ def _verify(k, m, s):
)


def _sym_enc(k, n, m):
def _sym_enc(k: bytes, n: bytes, m: bytes):
cipher = ciphers.Cipher(ciphers.algorithms.AES(k), ciphers.modes.CBC(n))
encryptor = cipher.encryptor()
padder = padding.PKCS7(PADDING_LENGTH).padder()
padded_data = padder.update(m) + padder.finalize()
return encryptor.update(padded_data) + encryptor.finalize()


def _sym_dec(k, n, m):
def _sym_dec(k: bytes, n: bytes, m: bytes):
cipher = ciphers.Cipher(ciphers.algorithms.AES(k), ciphers.modes.CBC(n))
decryptor = cipher.decryptor()
plain_text = decryptor.update(m)
Expand Down Expand Up @@ -157,28 +157,6 @@ def get_latest_key(self):
return last_value


class CellCipher:
def __init__(self, session_key_manager: SessionKeyManager):
self.session_key_manager = session_key_manager

def encrypt(self, message):
key = self.session_key_manager.get_latest_key()
key_hash = get_hash(key)
nonce = os.urandom(NONCE_LENGTH)
return nonce + key_hash[-HASH_LENGTH:] + _sym_enc(key, nonce, message)

def decrypt(self, message):
nonce, key_hash, message = (
message[:NONCE_LENGTH],
message[NONCE_LENGTH:HEADER_LENGTH],
message[HEADER_LENGTH:],
)
key = self.session_key_manager.get_key(key_hash)
if key is None:
raise SessionKeyUnavailable("No session key found for received message")
return _sym_dec(key, nonce, message)


class SimpleCellCipher:
def __init__(self, root_ca: Certificate, pri_key: asymmetric.rsa.RSAPrivateKey, cert: Certificate):
self._root_ca = root_ca
Expand Down
4 changes: 4 additions & 0 deletions nvflare/fuel/f3/cellnet/core_cell.py
Original file line number Diff line number Diff line change
Expand Up @@ -942,6 +942,10 @@ def encrypt_payload(self, message: Message):

if message.payload is None:
message.payload = bytes(0)
elif isinstance(message.payload, memoryview) or isinstance(message.payload, bytearray):
message.payload = bytes(message.payload)
elif not isinstance(message.payload, bytes):
raise RuntimeError(f"Payload type of {type(message.payload)} is not supported.")

payload_len = len(message.payload)
message.add_headers(
Expand Down
186 changes: 131 additions & 55 deletions nvflare/job_config/fed_job.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,14 @@
from nvflare.apis.executor import Executor
from nvflare.apis.filter import Filter
from nvflare.apis.impl.controller import Controller
from nvflare.apis.job_def import ALL_SITES, SERVER_SITE_NAME
from nvflare.app_common.executors.script_executor import ScriptExecutor
from nvflare.app_common.widgets.convert_to_fed_event import ConvertToFedEvent
from nvflare.app_common.widgets.intime_model_selector import IntimeModelSelector
from nvflare.app_common.widgets.validation_json_generator import ValidationJsonGenerator
from nvflare.fuel.utils.class_utils import get_component_init_parameters
from nvflare.fuel.utils.import_utils import optional_import
from nvflare.fuel.utils.validation_utils import check_positive_int
from nvflare.job_config.fed_app_config import ClientAppConfig, FedAppConfig, ServerAppConfig
from nvflare.job_config.fed_job_config import FedJobConfig

Expand Down Expand Up @@ -103,6 +105,56 @@ def add_external_scripts(self, external_scripts: List):
self.app.add_ext_script(_script)


class ExecutorApp(FedApp):
def __init__(self):
"""Wrapper around `ClientAppConfig`."""
super().__init__()
self._create_client_app()

def add_executor(self, executor, tasks=None):
if tasks is None:
tasks = ["*"] # Add executor for any task by default
self.app.add_executor(tasks, executor)

def _create_client_app(self):
self.app = ClientAppConfig()

component = ConvertToFedEvent(events_to_convert=["analytix_log_stats"], fed_event_prefix="fed.")
self.app.add_component("event_to_fed", component)


class ControllerApp(FedApp):
"""Wrapper around `ServerAppConfig`.
Args:
"""

def __init__(self, key_metric="accuracy"):
super().__init__()
self.key_metric = key_metric
self._create_server_app()

def add_controller(self, controller, id=None):
if id is None:
id = "controller"
self.app.add_workflow(self._gen_tracked_id(id), controller)

def _create_server_app(self):
self.app: ServerAppConfig = ServerAppConfig()

component = ValidationJsonGenerator()
self.app.add_component("json_generator", component)

if self.key_metric:
component = IntimeModelSelector(key_metric=self.key_metric)
self.app.add_component("model_selector", component)

# TODO: make different tracking receivers configurable
if torch_ok and tb_ok:
component = TBAnalyticsReceiver(events=["fed.analytix_log_stats"])
self.app.add_component("receiver", component)


class FedJob:
def __init__(self, name="fed_job", min_clients=1, mandatory_clients=None, key_metric="accuracy") -> None:
"""FedJob allows users to generate job configurations in a Pythonic way.
Expand Down Expand Up @@ -136,7 +188,7 @@ def to(
filter_type: FilterType = None,
id=None,
):
"""assign an `obj` to a target (server or clients).
"""assign an object to a target (server or clients).
Args:
obj: The object to be assigned. The obj will be given a default `id` if non is provided based on its type.
Expand Down Expand Up @@ -218,6 +270,51 @@ def to(
if self._components:
self._add_referenced_components(obj, target)

def to_server(
self,
obj: Any,
filter_type: FilterType = None,
id=None,
):
"""assign an object to the server.
Args:
obj: The object to be assigned. The obj will be given a default `id` if non is provided based on its type.
filter_type: The type of filter used. Either `FilterType.TASK_RESULT` or `FilterType.TASK_DATA`.
id: Optional user-defined id for the object. Defaults to `None` and ID will automatically be assigned.
Returns:
"""
if isinstance(obj, Executor):
raise ValueError("Use `job.to(executor, <client_name>)` or `job.to_clients(executor)` for Executors.")

self.to(obj=obj, target=SERVER_SITE_NAME, filter_type=filter_type, id=id)

def to_clients(
self,
obj: Any,
tasks: List[str] = None,
filter_type: FilterType = None,
id=None,
):
"""assign an object to all clients.
Args:
obj (Any): Object to be deployed.
tasks: In case object is an `Executor`, optional list of tasks the executor should handle.
Defaults to `None`. If `None`, all tasks will be handled using `[*]`.
filter_type: The type of filter used. Either `FilterType.TASK_RESULT` or `FilterType.TASK_DATA`.
id: Optional user-defined id for the object. Defaults to `None` and ID will automatically be assigned.
Returns:
"""
if isinstance(obj, Controller):
raise ValueError('Use `job.to(controller, "server")` or `job.to_server(controller)` for Controllers.')

self.to(obj=obj, target=ALL_SITES, tasks=tasks, filter_type=filter_type, id=id)

def as_id(self, obj: Any):
id = str(uuid.uuid4())
self._components[id] = obj
Expand Down Expand Up @@ -260,21 +357,50 @@ def _set_site_app(self, app: FedApp, target: str):
self.job.add_fed_app(app_name, app_config)
self.job.set_site_app(target, app_name)

def _set_all_app(self, client_app: ExecutorApp, server_app: ControllerApp):
if not isinstance(client_app, ExecutorApp):
raise ValueError(f"`client_app` needs to be of type `ExecutorApp` but was type {type(client_app)}")
if not isinstance(server_app, ControllerApp):
raise ValueError(f"`server_app` needs to be of type `ControllerApp` but was type {type(server_app)}")

client_config = client_app.get_app_config()
server_config = server_app.get_app_config()

app_config = FedAppConfig(server_app=server_config, client_app=client_config)
app_name = "app"

self.job.add_fed_app(app_name, app_config)
self.job.set_site_app(ALL_SITES, app_name)

def _set_all_apps(self):
if not self._deployed:
for target in self._deploy_map:
self._set_site_app(self._deploy_map[target], target)
if ALL_SITES in self._deploy_map:
if SERVER_SITE_NAME not in self._deploy_map:
raise ValueError('Missing server components! Deploy using `to(obj, "server") or `to_server(obj)`')
self._set_all_app(client_app=self._deploy_map[ALL_SITES], server_app=self._deploy_map[SERVER_SITE_NAME])
else:
for target in self._deploy_map:
self._set_site_app(self._deploy_map[target], target)

self._deployed = True

def export_job(self, job_root):
self._set_all_apps()
self.job.generate_job_config(job_root)

def simulator_run(self, workspace, threads: int = None):
def simulator_run(self, workspace, n_clients: int = None, threads: int = None):
self._set_all_apps()

if ALL_SITES in self.clients and not n_clients:
raise ValueError("Clients were not specified using to(). Please provide the number of clients to simulate.")
elif ALL_SITES in self.clients and n_clients:
check_positive_int("n_clients", n_clients)
self.clients = [f"site-{i}" for i in range(1, n_clients + 1)]
elif self.clients and n_clients:
raise ValueError("You already specified clients using `to()`. Don't use `n_clients` in simulator_run.")

n_clients = len(self.clients)

if threads is None:
threads = n_clients

Expand All @@ -290,56 +416,6 @@ def _validate_target(self, target):
if not target:
raise ValueError("Must provide a valid target name")

if any(c in SPECIAL_CHARACTERS for c in target):
if any(c in SPECIAL_CHARACTERS for c in target) and target != ALL_SITES:
raise ValueError(f"target {target} name contains invalid character")
pass


class ExecutorApp(FedApp):
def __init__(self):
"""Wrapper around `ClientAppConfig`."""
super().__init__()
self._create_client_app()

def add_executor(self, executor, tasks=None):
if tasks is None:
tasks = ["*"] # Add executor for any task by default
self.app.add_executor(tasks, executor)

def _create_client_app(self):
self.app = ClientAppConfig()

component = ConvertToFedEvent(events_to_convert=["analytix_log_stats"], fed_event_prefix="fed.")
self.app.add_component("event_to_fed", component)


class ControllerApp(FedApp):
"""Wrapper around `ServerAppConfig`.
Args:
"""

def __init__(self, key_metric="accuracy"):
super().__init__()
self.key_metric = key_metric
self._create_server_app()

def add_controller(self, controller, id=None):
if id is None:
id = "controller"
self.app.add_workflow(self._gen_tracked_id(id), controller)

def _create_server_app(self):
self.app: ServerAppConfig = ServerAppConfig()

component = ValidationJsonGenerator()
self.app.add_component("json_generator", component)

if self.key_metric:
component = IntimeModelSelector(key_metric=self.key_metric)
self.app.add_component("model_selector", component)

# TODO: make different tracking receivers configurable
if torch_ok and tb_ok:
component = TBAnalyticsReceiver(events=["fed.analytix_log_stats"])
self.app.add_component("receiver", component)
9 changes: 9 additions & 0 deletions nvflare/job_config/fed_job_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,15 @@ def add_fed_app(self, app_name: str, fed_app: FedAppConfig):
self.fed_apps[app_name] = fed_app

def set_site_app(self, site_name: str, app_name: str):
"""assign an app to a certain site.
Args:
site_name: The target site name.
app_name: The app name.
Returns:
"""
if app_name not in self.fed_apps.keys():
raise RuntimeError(f"fed_app {app_name} does not exist.")

Expand Down
Loading

0 comments on commit c9a4375

Please sign in to comment.