Skip to content

Commit

Permalink
initial cifar10 split learning
Browse files Browse the repository at this point in the history
run training with higher timeout value

rename classes

add training accuracy

remove printouts

add figure

update printouts

update readme

vertical split data

run psi

add notebook

update notebook

update notebook

take intersection.txt as input for split-learning

configure overlap

add todo

refactor to use FCI

update requirements

formatting

add validation

unify gitignore
  • Loading branch information
holgerroth committed Feb 6, 2023
1 parent d30fd7f commit 6b91ea8
Show file tree
Hide file tree
Showing 29 changed files with 3,133 additions and 38 deletions.
File renamed without changes.
4 changes: 4 additions & 0 deletions examples/cifar10/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,3 +12,7 @@ This example runs you through the process and includes instructions on running
[FedAvg](https://arxiv.org/abs/1602.05629) with streaming of TensorBoard metrics to the server during training
and [homomorphic encryption](https://developer.nvidia.com/blog/federated-learning-with-homomorphic-encryption/)
for secure server-side aggregation.

### [Split learning with CIFAR-10](./cifar10-splitnn/README.md)
This example includes instructions on how to run [split learning](https://arxiv.org/abs/1810.06060)
using the CIFAR-10 dataset and the FL simulator in a vertical FL scenario.
23 changes: 0 additions & 23 deletions examples/cifar10/cifar10-sim/.gitignore

This file was deleted.

43 changes: 43 additions & 0 deletions examples/cifar10/cifar10-splitnn/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
# Split Learning with CIFAR-10

This example includes instructions on how to run [split learning](https://arxiv.org/abs/1810.06060) (SL) using the CIFAR-10 dataset and the FL simulator.

We assume one client holds the images, and the other clients holds the labels to compute losses and accuracy metrics.
Activations and corresponding gradients are being exchanged between the clients through the NVFlare server.

<img src="./figs/split_learning.svg" alt="Split learning setup" width="300"/>

For instructions of how to run CIFAR-10 in real-world deployment settings,
see the example on ["Real-world Federated Learning with CIFAR-10"](../cifar10-real-world/README.md).

## (Optional) Set up a virtual environment
```
python3 -m pip install --user --upgrade pip
python3 -m pip install --user virtualenv
```
(If needed) make all shell scripts executable using
```
find . -name ".sh" -exec chmod +x {} \;
```
initialize virtual environment.
```
source ./virtualenv/set_env.sh
```
install required packages for training
```
pip install --upgrade pip
pip install -r ./virtualenv/min-requirements.txt
```

## Start Jupyter notebook
Set `PYTHONPATH` to include custom files of this example:
```
export PYTHONPATH=${PWD}/..
```
Start a Jupyter Lab
```
jupyter lab .
```
and open [cifar10_split_learning.ipynb](./cifar10_split_learning.ipynb).

See [here](https://jupyterlab.readthedocs.io/en/stable/getting_started/installation.html) for installing Jupyter Lab.
1,439 changes: 1,439 additions & 0 deletions examples/cifar10/cifar10-splitnn/cifar10_split_learning.ipynb

Large diffs are not rendered by default.

1 change: 1 addition & 0 deletions examples/cifar10/cifar10-splitnn/figs/split_learning.svg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
{
"name": "cifar10_psi",
"deploy_map": {
"server": ["server"],
"site-1": ["site-1"],
"site-2": ["site-2"]
},
"min_clients": 2
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
{
"format_version": 2,
"workflows": [
{
"id": "DhPSIController",
"path": "nvflare.app_common.workflows.dh_psi_controller.DhPSIController",
"args": {
}
}
]
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
{
"format_version": 2,
"executors": [
{
"tasks": [
"PSI"
],
"executor": {
"id": "Executor",
"path": "nvflare.app_common.executors.psi.psi_executor.PSIExecutor",
"args": {
"local_psi_id": "local_psi"
}
}
}
],
"components": [
{
"id": "local_psi",
"path": "pt.utils.cifar10_local_psi.Cifar10LocalPSI",
"args": {
"psi_writer_id": "psi_writer",
"data_path": "/tmp/cifar10_vert_splits/site-1.npy"
}
},
{
"id": "psi_writer",
"path": "nvflare.app_common.psi.psi_file_writer.FilePsiWriter",
"args": {
"output_path": "psi/intersection.txt"
}
}
]
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
{
"format_version": 2,
"executors": [
{
"tasks": [
"PSI"
],
"executor": {
"id": "Executor",
"path": "nvflare.app_common.executors.psi.psi_executor.PSIExecutor",
"args": {
"local_psi_id": "local_psi"
}
}
}
],
"components": [
{
"id": "local_psi",
"path": "pt.utils.cifar10_local_psi.Cifar10LocalPSI",
"args": {
"psi_writer_id": "psi_writer",
"data_path": "/tmp/cifar10_vert_splits/site-2.npy"
}
},
{
"id": "psi_writer",
"path": "nvflare.app_common.psi.psi_file_writer.FilePsiWriter",
"args": {
"output_path": "psi/intersection.txt"
}
}
]
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
{
"name": "cifar10_splitnn",
"deploy_map": {
"server": ["server"],
"site-1": ["site-1"],
"site-2": ["site-2"]
},
"min_clients": 2
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
{
"format_version": 2,

"num_rounds": 15625,
"batch_size": 64,

"server": {
"heart_beat_timeout": 600
},
"task_data_filters": [],
"task_result_filters": [],
"components": [
{
"id": "persistor",
"path": "nvflare.app_common.pt.pt_file_model_persistor.PTFileModelPersistor",
"args": {
"model": {
"path": "pt.networks.cifar10_nets.ModerateCNN"
}
}
},
{
"id": "shareable_generator",
"name": "FullModelShareableGenerator",
"args": {}
},
{
"id": "json_generator",
"name": "ValidationJsonGenerator",
"args": {}
}
],
"workflows": [
{
"id": "splitnn_ctl",
"path": "nvflare.app_common.workflows.splitnn_workflow.SplitNNController",
"args": {
"num_rounds" : "{num_rounds}",
"batch_size": "{batch_size}",
"start_round": 0,
"persistor_id": "persistor",
"task_timeout": 0,
"shareable_generator_id": "shareable_generator"
}
}
]
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
{
"format_version": 2,

"DATASET_ROOT": "/tmp/cifar10",
"INTERSECTION_FILE": "site-1-intersection.txt",

"executors": [
{
"tasks": [
"_splitnn_task_init_model_",
"_splitnn_task_train_"
],
"executor": {
"id": "Executor",
"path": "nvflare.app_common.executors.splitnn_learner_executor.SplitNNLearnerExecutor",
"args": {
"learner_id": "cifar10-learner"
}
}
}
],

"task_result_filters": [
],
"task_data_filters": [
],

"components": [
{
"id": "cifar10-learner",
"path": "pt.learners.cifar10_learner_splitnn.CIFAR10LearnerSplitNN",
"args": {
"dataset_root": "{DATASET_ROOT}",
"intersection_file": "{INTERSECTION_FILE}",
"lr": 1e-2,
"model": {"path": "pt.networks.split_nn.SplitNN", "args": {"split_id": 0}},
"timeit": true
}
}
]
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
{
"format_version": 2,

"DATASET_ROOT": "/tmp/cifar10",
"INTERSECTION_FILE": "site-2-intersection.txt",

"executors": [
{
"tasks": [
"_splitnn_task_init_model_",
"_splitnn_task_train_"
],
"executor": {
"id": "Executor",
"path": "nvflare.app_common.executors.splitnn_learner_executor.SplitNNLearnerExecutor",
"args": {
"learner_id": "cifar10-learner"
}
}
}
],

"task_result_filters": [
],
"task_data_filters": [
],

"components": [
{
"id": "cifar10-learner",
"path": "pt.learners.cifar10_learner_splitnn.CIFAR10LearnerSplitNN",
"args": {
"dataset_root": "{DATASET_ROOT}",
"intersection_file": "{INTERSECTION_FILE}",
"lr": 1e-2,
"model": {"path": "pt.networks.split_nn.SplitNN", "args": {"split_id": 1}},
"timeit": true
}
}
]
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
nvflare>=2.3.0
torch
torchvision
tensorboard
openmined.psi
pandas
jupyterlab
6 changes: 6 additions & 0 deletions examples/cifar10/cifar10-splitnn/virtualenv/set_env.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
#!/usr/bin/env bash

export projectname='nvflare_cifar10'

python3 -m venv ${projectname}
source ${projectname}/bin/activate
Loading

0 comments on commit 6b91ea8

Please sign in to comment.