-
Notifications
You must be signed in to change notification settings - Fork 181
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
run training with higher timeout value rename classes add training accuracy remove printouts add figure update printouts update readme vertical split data run psi add notebook update notebook update notebook take intersection.txt as input for split-learning configure overlap add todo refactor to use FCI update requirements formatting add validation unify gitignore
- Loading branch information
1 parent
d30fd7f
commit 6b91ea8
Showing
29 changed files
with
3,133 additions
and
38 deletions.
There are no files selected for viewing
File renamed without changes.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file was deleted.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,43 @@ | ||
# Split Learning with CIFAR-10 | ||
|
||
This example includes instructions on how to run [split learning](https://arxiv.org/abs/1810.06060) (SL) using the CIFAR-10 dataset and the FL simulator. | ||
|
||
We assume one client holds the images, and the other clients holds the labels to compute losses and accuracy metrics. | ||
Activations and corresponding gradients are being exchanged between the clients through the NVFlare server. | ||
|
||
<img src="./figs/split_learning.svg" alt="Split learning setup" width="300"/> | ||
|
||
For instructions of how to run CIFAR-10 in real-world deployment settings, | ||
see the example on ["Real-world Federated Learning with CIFAR-10"](../cifar10-real-world/README.md). | ||
|
||
## (Optional) Set up a virtual environment | ||
``` | ||
python3 -m pip install --user --upgrade pip | ||
python3 -m pip install --user virtualenv | ||
``` | ||
(If needed) make all shell scripts executable using | ||
``` | ||
find . -name ".sh" -exec chmod +x {} \; | ||
``` | ||
initialize virtual environment. | ||
``` | ||
source ./virtualenv/set_env.sh | ||
``` | ||
install required packages for training | ||
``` | ||
pip install --upgrade pip | ||
pip install -r ./virtualenv/min-requirements.txt | ||
``` | ||
|
||
## Start Jupyter notebook | ||
Set `PYTHONPATH` to include custom files of this example: | ||
``` | ||
export PYTHONPATH=${PWD}/.. | ||
``` | ||
Start a Jupyter Lab | ||
``` | ||
jupyter lab . | ||
``` | ||
and open [cifar10_split_learning.ipynb](./cifar10_split_learning.ipynb). | ||
|
||
See [here](https://jupyterlab.readthedocs.io/en/stable/getting_started/installation.html) for installing Jupyter Lab. |
1,439 changes: 1,439 additions & 0 deletions
1,439
examples/cifar10/cifar10-splitnn/cifar10_split_learning.ipynb
Large diffs are not rendered by default.
Oops, something went wrong.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
9 changes: 9 additions & 0 deletions
9
examples/cifar10/cifar10-splitnn/job_configs/cifar10_psi/meta.json
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,9 @@ | ||
{ | ||
"name": "cifar10_psi", | ||
"deploy_map": { | ||
"server": ["server"], | ||
"site-1": ["site-1"], | ||
"site-2": ["site-2"] | ||
}, | ||
"min_clients": 2 | ||
} |
11 changes: 11 additions & 0 deletions
11
...ples/cifar10/cifar10-splitnn/job_configs/cifar10_psi/server/config/config_fed_server.json
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,11 @@ | ||
{ | ||
"format_version": 2, | ||
"workflows": [ | ||
{ | ||
"id": "DhPSIController", | ||
"path": "nvflare.app_common.workflows.dh_psi_controller.DhPSIController", | ||
"args": { | ||
} | ||
} | ||
] | ||
} |
34 changes: 34 additions & 0 deletions
34
...ples/cifar10/cifar10-splitnn/job_configs/cifar10_psi/site-1/config/config_fed_client.json
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,34 @@ | ||
{ | ||
"format_version": 2, | ||
"executors": [ | ||
{ | ||
"tasks": [ | ||
"PSI" | ||
], | ||
"executor": { | ||
"id": "Executor", | ||
"path": "nvflare.app_common.executors.psi.psi_executor.PSIExecutor", | ||
"args": { | ||
"local_psi_id": "local_psi" | ||
} | ||
} | ||
} | ||
], | ||
"components": [ | ||
{ | ||
"id": "local_psi", | ||
"path": "pt.utils.cifar10_local_psi.Cifar10LocalPSI", | ||
"args": { | ||
"psi_writer_id": "psi_writer", | ||
"data_path": "/tmp/cifar10_vert_splits/site-1.npy" | ||
} | ||
}, | ||
{ | ||
"id": "psi_writer", | ||
"path": "nvflare.app_common.psi.psi_file_writer.FilePsiWriter", | ||
"args": { | ||
"output_path": "psi/intersection.txt" | ||
} | ||
} | ||
] | ||
} |
34 changes: 34 additions & 0 deletions
34
...ples/cifar10/cifar10-splitnn/job_configs/cifar10_psi/site-2/config/config_fed_client.json
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,34 @@ | ||
{ | ||
"format_version": 2, | ||
"executors": [ | ||
{ | ||
"tasks": [ | ||
"PSI" | ||
], | ||
"executor": { | ||
"id": "Executor", | ||
"path": "nvflare.app_common.executors.psi.psi_executor.PSIExecutor", | ||
"args": { | ||
"local_psi_id": "local_psi" | ||
} | ||
} | ||
} | ||
], | ||
"components": [ | ||
{ | ||
"id": "local_psi", | ||
"path": "pt.utils.cifar10_local_psi.Cifar10LocalPSI", | ||
"args": { | ||
"psi_writer_id": "psi_writer", | ||
"data_path": "/tmp/cifar10_vert_splits/site-2.npy" | ||
} | ||
}, | ||
{ | ||
"id": "psi_writer", | ||
"path": "nvflare.app_common.psi.psi_file_writer.FilePsiWriter", | ||
"args": { | ||
"output_path": "psi/intersection.txt" | ||
} | ||
} | ||
] | ||
} |
9 changes: 9 additions & 0 deletions
9
examples/cifar10/cifar10-splitnn/job_configs/cifar10_splitnn/meta.json
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,9 @@ | ||
{ | ||
"name": "cifar10_splitnn", | ||
"deploy_map": { | ||
"server": ["server"], | ||
"site-1": ["site-1"], | ||
"site-2": ["site-2"] | ||
}, | ||
"min_clients": 2 | ||
} |
47 changes: 47 additions & 0 deletions
47
.../cifar10/cifar10-splitnn/job_configs/cifar10_splitnn/server/config/config_fed_server.json
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,47 @@ | ||
{ | ||
"format_version": 2, | ||
|
||
"num_rounds": 15625, | ||
"batch_size": 64, | ||
|
||
"server": { | ||
"heart_beat_timeout": 600 | ||
}, | ||
"task_data_filters": [], | ||
"task_result_filters": [], | ||
"components": [ | ||
{ | ||
"id": "persistor", | ||
"path": "nvflare.app_common.pt.pt_file_model_persistor.PTFileModelPersistor", | ||
"args": { | ||
"model": { | ||
"path": "pt.networks.cifar10_nets.ModerateCNN" | ||
} | ||
} | ||
}, | ||
{ | ||
"id": "shareable_generator", | ||
"name": "FullModelShareableGenerator", | ||
"args": {} | ||
}, | ||
{ | ||
"id": "json_generator", | ||
"name": "ValidationJsonGenerator", | ||
"args": {} | ||
} | ||
], | ||
"workflows": [ | ||
{ | ||
"id": "splitnn_ctl", | ||
"path": "nvflare.app_common.workflows.splitnn_workflow.SplitNNController", | ||
"args": { | ||
"num_rounds" : "{num_rounds}", | ||
"batch_size": "{batch_size}", | ||
"start_round": 0, | ||
"persistor_id": "persistor", | ||
"task_timeout": 0, | ||
"shareable_generator_id": "shareable_generator" | ||
} | ||
} | ||
] | ||
} |
41 changes: 41 additions & 0 deletions
41
.../cifar10/cifar10-splitnn/job_configs/cifar10_splitnn/site-1/config/config_fed_client.json
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,41 @@ | ||
{ | ||
"format_version": 2, | ||
|
||
"DATASET_ROOT": "/tmp/cifar10", | ||
"INTERSECTION_FILE": "site-1-intersection.txt", | ||
|
||
"executors": [ | ||
{ | ||
"tasks": [ | ||
"_splitnn_task_init_model_", | ||
"_splitnn_task_train_" | ||
], | ||
"executor": { | ||
"id": "Executor", | ||
"path": "nvflare.app_common.executors.splitnn_learner_executor.SplitNNLearnerExecutor", | ||
"args": { | ||
"learner_id": "cifar10-learner" | ||
} | ||
} | ||
} | ||
], | ||
|
||
"task_result_filters": [ | ||
], | ||
"task_data_filters": [ | ||
], | ||
|
||
"components": [ | ||
{ | ||
"id": "cifar10-learner", | ||
"path": "pt.learners.cifar10_learner_splitnn.CIFAR10LearnerSplitNN", | ||
"args": { | ||
"dataset_root": "{DATASET_ROOT}", | ||
"intersection_file": "{INTERSECTION_FILE}", | ||
"lr": 1e-2, | ||
"model": {"path": "pt.networks.split_nn.SplitNN", "args": {"split_id": 0}}, | ||
"timeit": true | ||
} | ||
} | ||
] | ||
} |
41 changes: 41 additions & 0 deletions
41
.../cifar10/cifar10-splitnn/job_configs/cifar10_splitnn/site-2/config/config_fed_client.json
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,41 @@ | ||
{ | ||
"format_version": 2, | ||
|
||
"DATASET_ROOT": "/tmp/cifar10", | ||
"INTERSECTION_FILE": "site-2-intersection.txt", | ||
|
||
"executors": [ | ||
{ | ||
"tasks": [ | ||
"_splitnn_task_init_model_", | ||
"_splitnn_task_train_" | ||
], | ||
"executor": { | ||
"id": "Executor", | ||
"path": "nvflare.app_common.executors.splitnn_learner_executor.SplitNNLearnerExecutor", | ||
"args": { | ||
"learner_id": "cifar10-learner" | ||
} | ||
} | ||
} | ||
], | ||
|
||
"task_result_filters": [ | ||
], | ||
"task_data_filters": [ | ||
], | ||
|
||
"components": [ | ||
{ | ||
"id": "cifar10-learner", | ||
"path": "pt.learners.cifar10_learner_splitnn.CIFAR10LearnerSplitNN", | ||
"args": { | ||
"dataset_root": "{DATASET_ROOT}", | ||
"intersection_file": "{INTERSECTION_FILE}", | ||
"lr": 1e-2, | ||
"model": {"path": "pt.networks.split_nn.SplitNN", "args": {"split_id": 1}}, | ||
"timeit": true | ||
} | ||
} | ||
] | ||
} |
7 changes: 7 additions & 0 deletions
7
examples/cifar10/cifar10-splitnn/virtualenv/min-requirements.txt
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,7 @@ | ||
nvflare>=2.3.0 | ||
torch | ||
torchvision | ||
tensorboard | ||
openmined.psi | ||
pandas | ||
jupyterlab |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,6 @@ | ||
#!/usr/bin/env bash | ||
|
||
export projectname='nvflare_cifar10' | ||
|
||
python3 -m venv ${projectname} | ||
source ${projectname}/bin/activate |
Oops, something went wrong.