diff --git a/examples/getting_started/pt/fedavg_script_executor_cifar10_all.py b/examples/getting_started/pt/fedavg_script_executor_cifar10_all.py new file mode 100644 index 0000000000..af15043785 --- /dev/null +++ b/examples/getting_started/pt/fedavg_script_executor_cifar10_all.py @@ -0,0 +1,43 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from src.net import Net + +from nvflare import FedAvg, FedJob, ScriptExecutor + +if __name__ == "__main__": + n_clients = 2 + num_rounds = 2 + train_script = "src/cifar10_fl.py" + + job = FedJob(name="cifar10_fedavg") + + # Define the controller workflow and send to server + controller = FedAvg( + num_clients=n_clients, + num_rounds=num_rounds, + ) + job.to_server(controller) + + # Define the initial global model and send to server + job.to_server(Net()) + + # Send executor to all clients + executor = ScriptExecutor( + task_script_path=train_script, task_script_args="" # f"--batch_size 32 --data_path /tmp/data/site-{i}" + ) + job.to_clients(executor) + + # job.export_job("/tmp/nvflare/jobs/job_config") + job.simulator_run("/tmp/nvflare/jobs/workdir", n_clients=n_clients) diff --git a/nvflare/job_config/fed_job.py b/nvflare/job_config/fed_job.py index e006729fe8..bd72a8fce6 100644 --- a/nvflare/job_config/fed_job.py +++ b/nvflare/job_config/fed_job.py @@ -19,12 +19,14 @@ from nvflare.apis.executor import Executor from nvflare.apis.filter import Filter from nvflare.apis.impl.controller import Controller +from nvflare.apis.job_def import ALL_SITES, SERVER_SITE_NAME from nvflare.app_common.executors.script_executor import ScriptExecutor from nvflare.app_common.widgets.convert_to_fed_event import ConvertToFedEvent from nvflare.app_common.widgets.intime_model_selector import IntimeModelSelector from nvflare.app_common.widgets.validation_json_generator import ValidationJsonGenerator from nvflare.fuel.utils.class_utils import get_component_init_parameters from nvflare.fuel.utils.import_utils import optional_import +from nvflare.fuel.utils.validation_utils import check_positive_int from nvflare.job_config.fed_app_config import ClientAppConfig, FedAppConfig, ServerAppConfig from nvflare.job_config.fed_job_config import FedJobConfig @@ -103,6 +105,56 @@ def add_external_scripts(self, external_scripts: List): self.app.add_ext_script(_script) +class ExecutorApp(FedApp): + def __init__(self): + """Wrapper around `ClientAppConfig`.""" + super().__init__() + self._create_client_app() + + def add_executor(self, executor, tasks=None): + if tasks is None: + tasks = ["*"] # Add executor for any task by default + self.app.add_executor(tasks, executor) + + def _create_client_app(self): + self.app = ClientAppConfig() + + component = ConvertToFedEvent(events_to_convert=["analytix_log_stats"], fed_event_prefix="fed.") + self.app.add_component("event_to_fed", component) + + +class ControllerApp(FedApp): + """Wrapper around `ServerAppConfig`. + + Args: + """ + + def __init__(self, key_metric="accuracy"): + super().__init__() + self.key_metric = key_metric + self._create_server_app() + + def add_controller(self, controller, id=None): + if id is None: + id = "controller" + self.app.add_workflow(self._gen_tracked_id(id), controller) + + def _create_server_app(self): + self.app: ServerAppConfig = ServerAppConfig() + + component = ValidationJsonGenerator() + self.app.add_component("json_generator", component) + + if self.key_metric: + component = IntimeModelSelector(key_metric=self.key_metric) + self.app.add_component("model_selector", component) + + # TODO: make different tracking receivers configurable + if torch_ok and tb_ok: + component = TBAnalyticsReceiver(events=["fed.analytix_log_stats"]) + self.app.add_component("receiver", component) + + class FedJob: def __init__(self, name="fed_job", min_clients=1, mandatory_clients=None, key_metric="accuracy") -> None: """FedJob allows users to generate job configurations in a Pythonic way. @@ -136,7 +188,7 @@ def to( filter_type: FilterType = None, id=None, ): - """assign an `obj` to a target (server or clients). + """assign an object to a target (server or clients). Args: obj: The object to be assigned. The obj will be given a default `id` if non is provided based on its type. @@ -218,6 +270,51 @@ def to( if self._components: self._add_referenced_components(obj, target) + def to_server( + self, + obj: Any, + filter_type: FilterType = None, + id=None, + ): + """assign an object to the server. + + Args: + obj: The object to be assigned. The obj will be given a default `id` if non is provided based on its type. + filter_type: The type of filter used. Either `FilterType.TASK_RESULT` or `FilterType.TASK_DATA`. + id: Optional user-defined id for the object. Defaults to `None` and ID will automatically be assigned. + + Returns: + + """ + if isinstance(obj, Executor): + raise ValueError("Use `job.to(executor, )` or `job.to_clients(executor)` for Executors.") + + self.to(obj=obj, target=SERVER_SITE_NAME, filter_type=filter_type, id=id) + + def to_clients( + self, + obj: Any, + tasks: List[str] = None, + filter_type: FilterType = None, + id=None, + ): + """assign an object to all clients. + + Args: + obj (Any): Object to be deployed. + tasks: In case object is an `Executor`, optional list of tasks the executor should handle. + Defaults to `None`. If `None`, all tasks will be handled using `[*]`. + filter_type: The type of filter used. Either `FilterType.TASK_RESULT` or `FilterType.TASK_DATA`. + id: Optional user-defined id for the object. Defaults to `None` and ID will automatically be assigned. + + Returns: + + """ + if isinstance(obj, Controller): + raise ValueError('Use `job.to(controller, "server")` or `job.to_server(controller)` for Controllers.') + + self.to(obj=obj, target=ALL_SITES, tasks=tasks, filter_type=filter_type, id=id) + def as_id(self, obj: Any): id = str(uuid.uuid4()) self._components[id] = obj @@ -260,10 +357,30 @@ def _set_site_app(self, app: FedApp, target: str): self.job.add_fed_app(app_name, app_config) self.job.set_site_app(target, app_name) + def _set_all_app(self, client_app: ExecutorApp, server_app: ControllerApp): + if not isinstance(client_app, ExecutorApp): + raise ValueError(f"`client_app` needs to be of type `ExecutorApp` but was type {type(client_app)}") + if not isinstance(server_app, ControllerApp): + raise ValueError(f"`server_app` needs to be of type `ControllerApp` but was type {type(server_app)}") + + client_config = client_app.get_app_config() + server_config = server_app.get_app_config() + + app_config = FedAppConfig(server_app=server_config, client_app=client_config) + app_name = "app" + + self.job.add_fed_app(app_name, app_config) + self.job.set_site_app(ALL_SITES, app_name) + def _set_all_apps(self): if not self._deployed: - for target in self._deploy_map: - self._set_site_app(self._deploy_map[target], target) + if ALL_SITES in self._deploy_map: + if SERVER_SITE_NAME not in self._deploy_map: + raise ValueError('Missing server components! Deploy using `to(obj, "server") or `to_server(obj)`') + self._set_all_app(client_app=self._deploy_map[ALL_SITES], server_app=self._deploy_map[SERVER_SITE_NAME]) + else: + for target in self._deploy_map: + self._set_site_app(self._deploy_map[target], target) self._deployed = True @@ -271,10 +388,19 @@ def export_job(self, job_root): self._set_all_apps() self.job.generate_job_config(job_root) - def simulator_run(self, workspace, threads: int = None): + def simulator_run(self, workspace, n_clients: int = None, threads: int = None): self._set_all_apps() + if ALL_SITES in self.clients and not n_clients: + raise ValueError("Clients were not specified using to(). Please provide the number of clients to simulate.") + elif ALL_SITES in self.clients and n_clients: + check_positive_int("n_clients", n_clients) + self.clients = [f"site-{i}" for i in range(1, n_clients + 1)] + elif self.clients and n_clients: + raise ValueError("You already specified clients using `to()`. Don't use `n_clients` in simulator_run.") + n_clients = len(self.clients) + if threads is None: threads = n_clients @@ -290,56 +416,6 @@ def _validate_target(self, target): if not target: raise ValueError("Must provide a valid target name") - if any(c in SPECIAL_CHARACTERS for c in target): + if any(c in SPECIAL_CHARACTERS for c in target) and target != ALL_SITES: raise ValueError(f"target {target} name contains invalid character") pass - - -class ExecutorApp(FedApp): - def __init__(self): - """Wrapper around `ClientAppConfig`.""" - super().__init__() - self._create_client_app() - - def add_executor(self, executor, tasks=None): - if tasks is None: - tasks = ["*"] # Add executor for any task by default - self.app.add_executor(tasks, executor) - - def _create_client_app(self): - self.app = ClientAppConfig() - - component = ConvertToFedEvent(events_to_convert=["analytix_log_stats"], fed_event_prefix="fed.") - self.app.add_component("event_to_fed", component) - - -class ControllerApp(FedApp): - """Wrapper around `ServerAppConfig`. - - Args: - """ - - def __init__(self, key_metric="accuracy"): - super().__init__() - self.key_metric = key_metric - self._create_server_app() - - def add_controller(self, controller, id=None): - if id is None: - id = "controller" - self.app.add_workflow(self._gen_tracked_id(id), controller) - - def _create_server_app(self): - self.app: ServerAppConfig = ServerAppConfig() - - component = ValidationJsonGenerator() - self.app.add_component("json_generator", component) - - if self.key_metric: - component = IntimeModelSelector(key_metric=self.key_metric) - self.app.add_component("model_selector", component) - - # TODO: make different tracking receivers configurable - if torch_ok and tb_ok: - component = TBAnalyticsReceiver(events=["fed.analytix_log_stats"]) - self.app.add_component("receiver", component) diff --git a/nvflare/job_config/fed_job_config.py b/nvflare/job_config/fed_job_config.py index 509e35d382..f98bc9ce9f 100644 --- a/nvflare/job_config/fed_job_config.py +++ b/nvflare/job_config/fed_job_config.py @@ -65,6 +65,15 @@ def add_fed_app(self, app_name: str, fed_app: FedAppConfig): self.fed_apps[app_name] = fed_app def set_site_app(self, site_name: str, app_name: str): + """assign an app to a certain site. + + Args: + site_name: The target site name. + app_name: The app name. + + Returns: + + """ if app_name not in self.fed_apps.keys(): raise RuntimeError(f"fed_app {app_name} does not exist.")