Commit cf580dc (1 parent: c3082d1)

run training with higher timeout value; rename classes; add training accuracy; remove printouts; add figure; update printouts; update readme; vertical split data; run psi; add notebook; update notebook; update notebook; take intersection.txt as input for split-learning; configure overlap; add todo; refactor to use FCI
Showing 28 changed files with 3,160 additions and 15 deletions.
23 changes: 23 additions & 0 deletions
.gitignore
@@ -0,0 +1,23 @@
# ide
.idea/
.ipynb_checkpoints/

# nvflare artifacts
log.txt
client_token.txt
*.fl
audit.log
transfer
workspaces

# python
__pycache__
*.pyc

# virtual environments
nvflare_cifar10

# data
dataset
dataset*
*results*
43 changes: 43 additions & 0 deletions
examples/cifar10/cifar10-splitnn/README.md
@@ -0,0 +1,43 @@
# Split Learning with CIFAR-10

This example includes instructions on how to run [split learning](https://arxiv.org/abs/1810.06060) (SL) using the CIFAR-10 dataset and the FL simulator.

We assume one client holds the images and the other client holds the labels, which it uses to compute the loss and accuracy metrics.
Activations and the corresponding gradients are exchanged between the clients through the NVFlare server.

<img src="./figs/split_learning.svg" alt="Split learning setup" width="300"/>

For instructions on how to run CIFAR-10 in real-world deployment settings,
see the example on ["Real-world Federated Learning with CIFAR-10"](../cifar10-real-world/README.md).
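
The per-batch exchange can be sketched as follows; this is a simplified single-process illustration of the gradient hand-off (not the example's actual NVFlare message flow, and all names are hypothetical):
```python
import torch

def split_train_step(net1, net2, optim1, optim2, criterion, images, labels):
    """One split-learning step across the two parties (local sketch)."""
    optim1.zero_grad()
    optim2.zero_grad()

    act = net1(images)                        # site-1: forward through its model part
    act_sent = act.detach().requires_grad_()  # "sent" to site-2 via the server

    loss = criterion(net2(act_sent), labels)  # site-2: forward pass + loss on its labels
    loss.backward()                           # site-2: backward fills act_sent.grad
    optim2.step()

    act.backward(gradient=act_sent.grad)      # site-1: backward with the returned gradients
    optim1.step()
    return loss.item()
```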

## (Optional) Set up a virtual environment
```
python3 -m pip install --user --upgrade pip
python3 -m pip install --user virtualenv
```
(If needed) make all shell scripts executable using
```
find . -name "*.sh" -exec chmod +x {} \;
```
Initialize the virtual environment:
```
source ./virtualenv/set_env.sh
```
Install the required packages for training:
```
pip install --upgrade pip
pip install -r ./virtualenv/min-requirements.txt
```

## Start Jupyter notebook
Set `PYTHONPATH` to include the custom files of this example:
```
export PYTHONPATH=${PWD}/..
```
Start Jupyter Lab with
```
jupyter lab .
```
and open [cifar10_split_learning.ipynb](./cifar10_split_learning.ipynb).

See [here](https://jupyterlab.readthedocs.io/en/stable/getting_started/installation.html) for installing Jupyter Lab.
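
Alternatively to the notebook, the job folders can be passed to the FL simulator directly; a minimal sketch, assuming the `nvflare simulator` CLI and the two-client setup configured in `job_configs`:
```
nvflare simulator ./job_configs/cifar10_splitnn -w /tmp/nvflare/cifar10_splitnn -n 2 -t 2
```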
1,439 changes: 1,439 additions & 0 deletions
examples/cifar10/cifar10-splitnn/cifar10_split_learning.ipynb
Large diffs are not rendered by default.
9 changes: 9 additions & 0 deletions
examples/cifar10/cifar10-splitnn/job_configs/cifar10_psi/meta.json
@@ -0,0 +1,9 @@
{
  "name": "cifar10_psi",
  "deploy_map": {
    "server": ["server"],
    "site-1": ["site-1"],
    "site-2": ["site-2"]
  },
  "min_clients": 2
}
11 changes: 11 additions & 0 deletions
examples/cifar10/cifar10-splitnn/job_configs/cifar10_psi/server/config/config_fed_server.json
@@ -0,0 +1,11 @@
{
  "format_version": 2,
  "workflows": [
    {
      "id": "DhPSIController",
      "path": "nvflare.app_common.workflows.dh_psi_controller.DhPSIController",
      "args": {}
    }
  ]
}
34 changes: 34 additions & 0 deletions
examples/cifar10/cifar10-splitnn/job_configs/cifar10_psi/site-1/config/config_fed_client.json
@@ -0,0 +1,34 @@
{
  "format_version": 2,
  "executors": [
    {
      "tasks": ["PSI"],
      "executor": {
        "id": "Executor",
        "path": "nvflare.app_common.executors.psi.psi_executor.PSIExecutor",
        "args": {
          "local_psi_id": "local_psi"
        }
      }
    }
  ],
  "components": [
    {
      "id": "local_psi",
      "path": "pt.utils.cifar10_local_psi.Cifar10LocalPSI",
      "args": {
        "psi_writer_id": "psi_writer",
        "data_path": "/tmp/cifar10_vert_splits/site-1.npy"
      }
    },
    {
      "id": "psi_writer",
      "path": "nvflare.app_common.psi.psi_file_writer.FilePsiWriter",
      "args": {
        "output_path": "psi/intersection.txt"
      }
    }
  ]
}
34 changes: 34 additions & 0 deletions
examples/cifar10/cifar10-splitnn/job_configs/cifar10_psi/site-2/config/config_fed_client.json
@@ -0,0 +1,34 @@
{
  "format_version": 2,
  "executors": [
    {
      "tasks": ["PSI"],
      "executor": {
        "id": "Executor",
        "path": "nvflare.app_common.executors.psi.psi_executor.PSIExecutor",
        "args": {
          "local_psi_id": "local_psi"
        }
      }
    }
  ],
  "components": [
    {
      "id": "local_psi",
      "path": "pt.utils.cifar10_local_psi.Cifar10LocalPSI",
      "args": {
        "psi_writer_id": "psi_writer",
        "data_path": "/tmp/cifar10_vert_splits/site-2.npy"
      }
    },
    {
      "id": "psi_writer",
      "path": "nvflare.app_common.psi.psi_file_writer.FilePsiWriter",
      "args": {
        "output_path": "psi/intersection.txt"
      }
    }
  ]
}
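
With the server workflow and both client configs in place, the PSI job can be run with the FL simulator; a minimal sketch, assuming the `nvflare simulator` CLI (job folder, `-w` workspace, `-n` clients, `-t` threads):
```
nvflare simulator ./job_configs/cifar10_psi -w /tmp/nvflare/cifar10_psi -n 2 -t 2
```
Each site then writes the IDs common to both sites to `psi/intersection.txt` inside its simulated workspace, as configured by the `psi_writer` component above.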
9 changes: 9 additions & 0 deletions
examples/cifar10/cifar10-splitnn/job_configs/cifar10_splitnn/meta.json
@@ -0,0 +1,9 @@
{
  "name": "cifar10_splitnn",
  "deploy_map": {
    "server": ["server"],
    "site-1": ["site-1"],
    "site-2": ["site-2"]
  },
  "min_clients": 2
}
47 changes: 47 additions & 0 deletions
examples/cifar10/cifar10-splitnn/job_configs/cifar10_splitnn/server/config/config_fed_server.json
@@ -0,0 +1,47 @@
{
  "format_version": 2,

  "num_rounds": 15625,
  "batch_size": 64,

  "server": {
    "heart_beat_timeout": 600
  },
  "task_data_filters": [],
  "task_result_filters": [],
  "components": [
    {
      "id": "persistor",
      "path": "nvflare.app_common.pt.pt_file_model_persistor.PTFileModelPersistor",
      "args": {
        "model": {
          "path": "pt.networks.cifar10_nets.ModerateCNN"
        }
      }
    },
    {
      "id": "shareable_generator",
      "name": "FullModelShareableGenerator",
      "args": {}
    },
    {
      "id": "json_generator",
      "name": "ValidationJsonGenerator",
      "args": {}
    }
  ],
  "workflows": [
    {
      "id": "splitnn_ctl",
      "path": "nvflare.app_common.workflows.splitnn_workflow.SplitNNController",
      "args": {
        "num_rounds": "{num_rounds}",
        "batch_size": "{batch_size}",
        "start_round": 0,
        "persistor_id": "persistor",
        "task_timeout": 0,
        "shareable_generator_id": "shareable_generator"
      }
    }
  ]
}
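
As a sanity check on these numbers: each round appears to process one batch, so with CIFAR-10's 50,000 training images, 15,625 rounds at batch size 64 correspond to roughly 20 epochs (matching `epoch_max = 20` in the standalone test further down):
```python
train_size, batch_size, num_rounds = 50_000, 64, 15_625
print(num_rounds * batch_size / train_size)  # 20.0 epochs
```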
41 changes: 41 additions & 0 deletions
examples/cifar10/cifar10-splitnn/job_configs/cifar10_splitnn/site-1/config/config_fed_client.json
@@ -0,0 +1,41 @@
{
  "format_version": 2,

  "DATASET_ROOT": "/tmp/cifar10",
  "INTERSECTION_FILE": "site-1-intersection.txt",

  "executors": [
    {
      "tasks": [
        "_splitnn_task_init_model_",
        "_splitnn_task_train_"
      ],
      "executor": {
        "id": "Executor",
        "path": "nvflare.app_common.executors.splitnn_learner_executor.SplitNNLearnerExecutor",
        "args": {
          "learner_id": "cifar10-learner"
        }
      }
    }
  ],

  "task_result_filters": [],
  "task_data_filters": [],

  "components": [
    {
      "id": "cifar10-learner",
      "path": "pt.learners.cifar10_learner_splitnn.CIFAR10LearnerSplitNN",
      "args": {
        "dataset_root": "{DATASET_ROOT}",
        "intersection_file": "{INTERSECTION_FILE}",
        "lr": 1e-2,
        "model": {"path": "pt.networks.split_nn.SplitNN", "args": {"split_id": 0}},
        "timeit": true
      }
    }
  ]
}
41 changes: 41 additions & 0 deletions
examples/cifar10/cifar10-splitnn/job_configs/cifar10_splitnn/site-2/config/config_fed_client.json
@@ -0,0 +1,41 @@
{
  "format_version": 2,

  "DATASET_ROOT": "/tmp/cifar10",
  "INTERSECTION_FILE": "site-2-intersection.txt",

  "executors": [
    {
      "tasks": [
        "_splitnn_task_init_model_",
        "_splitnn_task_train_"
      ],
      "executor": {
        "id": "Executor",
        "path": "nvflare.app_common.executors.splitnn_learner_executor.SplitNNLearnerExecutor",
        "args": {
          "learner_id": "cifar10-learner"
        }
      }
    }
  ],

  "task_result_filters": [],
  "task_data_filters": [],

  "components": [
    {
      "id": "cifar10-learner",
      "path": "pt.learners.cifar10_learner_splitnn.CIFAR10LearnerSplitNN",
      "args": {
        "dataset_root": "{DATASET_ROOT}",
        "intersection_file": "{INTERSECTION_FILE}",
        "lr": 1e-2,
        "model": {"path": "pt.networks.split_nn.SplitNN", "args": {"split_id": 1}},
        "timeit": true
      }
    }
  ]
}
6 changes: 6 additions & 0 deletions
examples/cifar10/cifar10-splitnn/virtualenv/min-requirements.txt
@@ -0,0 +1,6 @@
nvflare>=2.3.0
torch
torchvision
tensorboard
openmined.psi
pandas
6 changes: 6 additions & 0 deletions
examples/cifar10/cifar10-splitnn/virtualenv/set_env.sh
@@ -0,0 +1,6 @@
#!/usr/bin/env bash

export projectname='nvflare_cifar10'

python3 -m venv ${projectname}
source ${projectname}/bin/activate
480 changes: 480 additions & 0 deletions
examples/cifar10/pt/learners/cifar10_learner_splitnn.py
Large diffs are not rendered by default.
196 changes: 196 additions & 0 deletions
examples/cifar10/pt/networks/split_nn.py
@@ -0,0 +1,196 @@
import numpy as np
import torch
from pt.networks.cifar10_nets import ModerateCNN
from pt.utils.cifar10_dataset import CIFAR10SplitNN
from torch import optim
from torch.utils.tensorboard import SummaryWriter
from torchvision import datasets, transforms


# TODO: maybe only use the part of the net that is being used
# rather than inheriting the full net
class SplitNN(ModerateCNN):
    def __init__(self, split_id):
        super().__init__()
        if split_id not in [0, 1]:
            raise ValueError(f"Only supports split_id '0' or '1' but was {split_id}")
        self.split_id = split_id

        if self.split_id == 0:
            self.split_forward = self.conv_layer
        elif self.split_id == 1:
            self.split_forward = self.fc_layer
        else:
            raise ValueError(f"Expected split_id to be '0' or '1' but was {self.split_id}")

    def forward(self, x):
        x = self.split_forward(x)
        return x

    def get_split_id(self):
        return self.split_id


""" TESTING """

def print_grads(net):
    for name, param in net.named_parameters():
        if param.grad is not None:
            print(name, "grad", param.grad.shape, torch.sum(param.grad).item())
        else:
            print(name, "grad", None)


def test_splitnn():  # TODO: move to unit testing
    """Test SplitNN"""

    lr = 1e-2
    epoch_max = 20
    bs = 64

    train_size = 50000

    criterion = torch.nn.CrossEntropyLoss()
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    net1 = SplitNN(split_id=0).to(device)
    net2 = SplitNN(split_id=1).to(device)

    optim1 = optim.SGD(net1.parameters(), lr=lr, momentum=0.9)
    optim2 = optim.SGD(net2.parameters(), lr=lr, momentum=0.9)

    transform_train = transforms.Compose(
        [
            transforms.ToTensor(),
            transforms.ToPILImage(),
            transforms.Pad(4, padding_mode="reflect"),
            transforms.RandomCrop(32),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize(
                mean=[x / 255.0 for x in [125.3, 123.0, 113.9]],
                std=[x / 255.0 for x in [63.0, 62.1, 66.7]],
            ),
        ]
    )
    transform_valid = transforms.Compose(
        [
            transforms.ToTensor(),
            transforms.Normalize(
                mean=[x / 255.0 for x in [125.3, 123.0, 113.9]],
                std=[x / 255.0 for x in [63.0, 62.1, 66.7]],
            ),
        ]
    )
    train_dataset = datasets.CIFAR10(
        root="/tmp/cifar10_vertical",
        train=True,
        download=True,
        transform=transform_train,
    )
    valid_dataset = datasets.CIFAR10(
        root="/tmp/cifar10_vertical",
        train=False,
        download=False,
        transform=transform_valid,
    )

    train_image_dataset = CIFAR10SplitNN(
        root="/tmp/cifar10_vertical", train=True, download=True, transform=transform_train, returns="image"
    )
    train_label_dataset = CIFAR10SplitNN(
        root="/tmp/cifar10_vertical", train=True, download=True, transform=transform_train, returns="label"
    )

    train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=bs, shuffle=True, num_workers=2)
    valid_loader = torch.utils.data.DataLoader(valid_dataset, batch_size=bs, shuffle=False, num_workers=2)

    def valid(model1, model2, data_loader, device):
        model1.eval()
        model2.eval()
        with torch.no_grad():
            correct, total = 0, 0
            for _, (inputs, labels) in enumerate(data_loader):
                inputs, labels = inputs.to(device), labels.to(device)
                outputs = model1(inputs)
                outputs = model2(outputs)
                _, pred_label = torch.max(outputs.data, 1)

                total += inputs.data.size()[0]
                correct += (pred_label == labels.data).sum().item()
            metric = correct / float(total)
        return metric

    def train(inputs, targets, debug=False):
        # See also
        # https://github.com/Koukyosyumei/Attack_SplitNN/blob/main/src/attacksplitnn/splitnn/model.py

        """Compute on site-1"""
        net1.train()
        optim1.zero_grad()

        x = net1.forward(inputs)  # keep on site-1
        x_sent = x.detach().requires_grad_()
        # send to site-2

        """ Compute on site-2 """
        net2.train()
        optim2.zero_grad()

        pred = net2.forward(x_sent)

        loss = criterion(pred, targets)

        loss.backward()
        optim2.step()

        return_grad = x_sent.grad

        if debug:
            print("return_grad", return_grad.shape, torch.sum(return_grad))

            print("====== net2 grad: ======")
            print_grads(net2)

        # send gradients to site-1

        """ Compute on site-1 """
        x.backward(gradient=return_grad)
        optim1.step()

        if debug:
            print("====== net1 grad: ======")
            print_grads(net1)

        return loss.item()

    # main training loop
    writer = SummaryWriter("./")
    # epoch_len = len(train_loader)
    epoch_len = int(train_size / bs)
    for e in range(epoch_max):
        epoch_loss = 0
        # for i, (inputs, targets) in enumerate(train_loader):
        #     epoch_loss += train(inputs=inputs.to(device), targets=targets.to(device))

        for i in range(epoch_len):
            # sample over the full [0, train_size) range (randint's high is exclusive)
            batch_indices = np.random.randint(0, train_size, size=bs)
            inputs = train_image_dataset.get_batch(batch_indices)
            targets = train_label_dataset.get_batch(batch_indices)
            loss = train(inputs=inputs.to(device), targets=targets.to(device))
            epoch_loss += loss
            writer.add_scalar("loss", loss, e * epoch_len + i)

        train_acc = valid(net1, net2, train_loader, device)
        val_acc = valid(net1, net2, valid_loader, device)
        print(
            f"Epoch {e + 1}/{epoch_max}. loss: {epoch_loss / epoch_len:.4f}, "
            f"train_acc: {train_acc:.4f}, val_acc: {val_acc:.4f}"
        )
        writer.add_scalar("train_acc", train_acc, e)
        writer.add_scalar("val_acc", val_acc, e)


if __name__ == "__main__":
    test_splitnn()
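
The standalone test can be run directly; a hedged sketch, assuming the file is saved as `pt/networks/split_nn.py` and `examples/cifar10` is on `PYTHONPATH`:
```
python3 -m pt.networks.split_nn
```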
40 changes: 40 additions & 0 deletions
examples/cifar10/pt/utils/cifar10_local_psi.py
@@ -0,0 +1,40 @@
# Copyright (c) 2021-2022, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os.path
from typing import List

import numpy as np

from nvflare.app_common.psi.psi_spec import PSI


class Cifar10LocalPSI(PSI):
    def __init__(self, psi_writer_id: str, data_path: str = "/tmp/data.csv"):
        super().__init__(psi_writer_id)
        self.data_path = data_path
        self.data = {}

        if not os.path.isfile(self.data_path):
            raise RuntimeError(f"invalid data path {data_path}")

    def load_items(self) -> List[str]:
        items = np.load(self.data_path)

        # important: the PSI algorithm requires the items to be unique strings
        if len(np.unique(items)) != len(items):
            raise ValueError("Expected all items to be unique!")

        return [str(i) for i in items]
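
For reference, the site files consumed here are plain NumPy arrays of sample IDs (produced in this example by `Cifar10VerticalDataSplitter`, below). A minimal hand-made stand-in:
```python
import numpy as np

# Hypothetical stand-in for a site's ID file; the values must be unique,
# otherwise Cifar10LocalPSI.load_items raises ValueError.
ids = np.array([3, 1, 4, 5, 9])
np.save("/tmp/cifar10_vert_splits/site-1.npy", ids)
```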
@@ -0,0 +1,46 @@
# Copyright (c) 2022, NVIDIA CORPORATION.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
import logging
import sys

from cifar10_vertical_data_splitter import Cifar10VerticalDataSplitter

from nvflare.apis.fl_constant import ReservedKey
from nvflare.apis.fl_context import FLContext

logging.basicConfig(stream=sys.stdout, level=logging.DEBUG)


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--split_dir", type=str, default="/tmp/cifar10_vert_splits", help="output folder")
    parser.add_argument("--overlap", type=int, default=10_000, help="number of overlapping samples")
    args = parser.parse_args()

    splitter = Cifar10VerticalDataSplitter(split_dir=args.split_dir, overlap=args.overlap)

    # set up a dummy context for logging
    fl_ctx = FLContext()
    fl_ctx.set_prop(ReservedKey.IDENTITY_NAME, "local")
    fl_ctx.set_prop(ReservedKey.RUN_NUM, "_")

    # will download to the CIFAR10_ROOT defined in the splitter module
    splitter.split(fl_ctx)


if __name__ == "__main__":
    main()
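
A hypothetical invocation (the script's file name is not shown in this diff; adjust the path to wherever it lives under `pt/utils`):
```
python3 pt/utils/<split_script>.py --split_dir /tmp/cifar10_vert_splits --overlap 10000
```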
94 changes: 94 additions & 0 deletions
examples/cifar10/pt/utils/cifar10_vertical_data_splitter.py
@@ -0,0 +1,94 @@
# Copyright (c) 2022, NVIDIA CORPORATION.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import json
import os

import numpy as np
from cifar10_data_splitter import Cifar10DataSplitter as Splitter

from nvflare.apis.event_type import EventType
from nvflare.apis.fl_component import FLComponent
from nvflare.apis.fl_context import FLContext

CIFAR10_ROOT = "/tmp/cifar10"  # will be used for all CIFAR-10 experiments


class Cifar10VerticalDataSplitter(FLComponent):
    def __init__(self, split_dir: str = None, overlap: int = 10_000, seed: int = 0):
        super().__init__()
        self.split_dir = split_dir
        self.overlap = overlap
        self.seed = seed

        if self.split_dir is None:
            raise ValueError("You need to define a valid `split_dir` when splitting the data.")
        if overlap <= 0:
            raise ValueError(f"`overlap` must be greater than 0 but was {overlap}!")

    def handle_event(self, event_type: str, fl_ctx: FLContext):
        if event_type == EventType.START_RUN:
            self.split(fl_ctx)

    def split(self, fl_ctx: FLContext):
        np.random.seed(self.seed)

        self.log_info(fl_ctx, f"Partitioning CIFAR-10 dataset vertically with {self.overlap} overlapping samples.")
        site_idx, class_sum = self._split_data()

        # write to files
        if not os.path.isdir(self.split_dir):
            os.makedirs(self.split_dir)
        sum_file_name = os.path.join(self.split_dir, "summary.txt")
        with open(sum_file_name, "w") as sum_file:
            sum_file.write("Class counts for overlap: \n")
            sum_file.write(json.dumps(class_sum))

        for _site, _idx in site_idx.items():
            site_file_name = os.path.join(self.split_dir, f"{_site}.npy")
            self.log_info(fl_ctx, f"save {site_file_name}")
            np.save(site_file_name, _idx)

    def _split_data(self):
        train_label = Splitter.load_cifar10_data()

        n_samples = len(train_label)

        if self.overlap > n_samples:
            raise ValueError(
                f"Chosen overlap of {self.overlap} is larger than train dataset with {n_samples} entries."
            )

        sample_idx = np.arange(0, n_samples)

        overlap_idx = np.random.choice(sample_idx, size=np.int64(self.overlap), replace=False)

        remain_idx = list(set(sample_idx) - set(overlap_idx))

        idx_1 = np.concatenate((overlap_idx, np.array(remain_idx)))
        # adding n_samples to remain_idx of site-2 to make sure there is no overlap with idx_1
        idx_2 = np.concatenate((overlap_idx, np.array(remain_idx) + n_samples))

        # shuffle indexes again for client sites to simulate a real-world scenario
        np.random.shuffle(idx_1)
        np.random.shuffle(idx_2)

        site_idx = {"overlap": overlap_idx, "site-1": idx_1, "site-2": idx_2}

        # collect class summary
        class_sum = Splitter.get_site_class_summary(train_label, {"overlap": overlap_idx})

        return site_idx, class_sum
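
A toy illustration of the index construction (`n_samples=6`, `overlap=2`), showing how site-2's non-overlapping indices are shifted by `n_samples` so that the two sites share only the overlap:
```python
import numpy as np

n_samples = 6
sample_idx = np.arange(n_samples)                               # [0 1 2 3 4 5]
overlap_idx = np.array([1, 4])                                  # chosen overlap
remain = np.array(sorted(set(sample_idx) - set(overlap_idx)))   # [0 2 3 5]
idx_1 = np.concatenate((overlap_idx, remain))                   # [1 4 0 2 3 5]
idx_2 = np.concatenate((overlap_idx, remain + n_samples))       # [1 4 6 8 9 11]
print(sorted(set(idx_1) & set(idx_2)))                          # [1, 4] -> only the overlap
```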
@@ -0,0 +1,46 @@
# Copyright (c) 2021-2022, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
import json


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--config_file",
        type=str,
        default="./config_fed_client.json",
        help="config file in JSON format",
    )
    parser.add_argument(
        "--intersection_file",
        type=str,
        help="Intersection file with overlapping data indices",
    )
    args = parser.parse_args()

    with open(args.config_file, "r") as f:
        config = json.load(f)

    config["INTERSECTION_FILE"] = args.intersection_file

    with open(args.config_file, "w") as f:
        json.dump(config, f, indent=4)

    print(f"Modified {args.config_file} to use INTERSECTION_FILE={config['INTERSECTION_FILE']}")


if __name__ == "__main__":
    main()
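
A hypothetical invocation (the script name is assumed), pointing a site's client config at the intersection file produced by the PSI job:
```
python3 <set_intersection_script>.py \
    --config_file ./job_configs/cifar10_splitnn/site-1/config/config_fed_client.json \
    --intersection_file site-1-intersection.txt
```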
@@ -0,0 +1,5 @@
# Vertical Federated Learning

## Split Learning
### Split learning with CIFAR-10
This [example](../cifar10/cifar10-splitnn/README.md) includes instructions on how to run [split learning](https://arxiv.org/abs/1810.06060) using the CIFAR-10 dataset and the FL simulator.
93 changes: 93 additions & 0 deletions
nvflare/app_common/executors/splitnn_learner_executor.py
@@ -0,0 +1,93 @@
# Copyright (c) 2021-2022, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from pt.learners.cifar10_learner_splitnn import SplitNNConstants

from nvflare.apis.event_type import EventType
from nvflare.apis.executor import Executor
from nvflare.apis.fl_constant import ReturnCode
from nvflare.apis.fl_context import FLContext
from nvflare.apis.shareable import Shareable, make_reply
from nvflare.apis.signal import Signal
from nvflare.app_common.abstract.learner_spec import Learner


class SplitNNLearnerExecutor(Executor):
    def __init__(
        self,
        learner_id,
        init_model_task_name=SplitNNConstants.TASK_INIT_MODEL,
        train_task_name=SplitNNConstants.TASK_TRAIN,
    ):
        """Key component to run the learner on clients.

        Args:
            learner_id (str): id pointing to the learner object
            init_model_task_name (str, optional): label to dispatch the model initialization task. Defaults to SplitNNConstants.TASK_INIT_MODEL.
            train_task_name (str, optional): label to dispatch the train task. Defaults to SplitNNConstants.TASK_TRAIN.
        """
        super().__init__()
        self.learner_id = learner_id
        self.learner = None
        self.init_model_task_name = init_model_task_name
        self.train_task_name = train_task_name

    def handle_event(self, event_type: str, fl_ctx: FLContext):
        if event_type == EventType.START_RUN:
            self.initialize(fl_ctx)
        elif event_type == EventType.ABORT_TASK:
            try:
                if self.learner:
                    self.learner.abort(fl_ctx)
            except Exception as e:
                self.log_exception(fl_ctx, f"learner abort exception: {e}")
        elif event_type == EventType.END_RUN:
            self.finalize(fl_ctx)

    def initialize(self, fl_ctx: FLContext):
        try:
            engine = fl_ctx.get_engine()
            self.learner = engine.get_component(self.learner_id)
            if not isinstance(self.learner, Learner):
                raise TypeError(f"learner must be Learner type. Got: {type(self.learner)}")
            self.learner.initialize(engine.get_all_components(), fl_ctx)
        except Exception as e:
            self.log_exception(fl_ctx, f"learner initialize exception: {e}")

    def execute(self, task_name: str, shareable: Shareable, fl_ctx: FLContext, abort_signal: Signal) -> Shareable:
        self.log_info(fl_ctx, f"Client trainer got task: {task_name}")

        self.log_info(fl_ctx, f"Executing task {task_name}...")
        try:
            if task_name == self.init_model_task_name:
                self.log_info(fl_ctx, "Initializing model...")
                return self.learner.init_model(shareable=shareable, fl_ctx=fl_ctx, abort_signal=abort_signal)
            elif task_name == self.train_task_name:
                self.log_info(fl_ctx, "Running training...")
                return self.learner.train(shareable=shareable, fl_ctx=fl_ctx, abort_signal=abort_signal)
            else:
                self.log_error(fl_ctx, f"Could not handle task: {task_name}")
                return make_reply(ReturnCode.TASK_UNKNOWN)
        except Exception as e:
            # Task execution error, return EXECUTION_EXCEPTION Shareable
            self.log_exception(fl_ctx, f"learner execute exception: {e}")
            return make_reply(ReturnCode.EXECUTION_EXCEPTION)

    def finalize(self, fl_ctx: FLContext):
        try:
            if self.learner:
                self.learner.finalize(fl_ctx)
        except Exception as e:
            self.log_exception(fl_ctx, f"learner finalize exception: {e}")
@@ -0,0 +1,37 @@
# Copyright (c) 2021-2022, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from io import BytesIO
from typing import Any

import numpy as np
import torch

from nvflare.fuel.utils import fobs


class TensorDecomposer(fobs.Decomposer):
    def supported_type(self):
        return torch.Tensor

    def decompose(self, target: torch.Tensor) -> Any:
        stream = BytesIO()
        # torch.save uses Pickle, so convert the Tensor to an ndarray first
        array = target.detach().cpu().numpy()
        np.save(stream, array, allow_pickle=False)
        return stream.getvalue()

    def recompose(self, data: Any) -> torch.Tensor:
        stream = BytesIO(data)
        array = np.load(stream, allow_pickle=False)
        return torch.from_numpy(array)
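
Before tensors can travel through NVFlare's serialization layer, the decomposer has to be registered with FOBS; a minimal sketch, assuming `fobs.register` accepts a `Decomposer` class as in NVFlare 2.3:
```python
from nvflare.fuel.utils import fobs

# Register once per process, e.g. during learner initialization, so that
# torch.Tensor activations/gradients can be (de)serialized by FOBS.
fobs.register(TensorDecomposer)
```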
285 changes: 285 additions & 0 deletions
nvflare/app_common/workflows/splitnn_workflow.py
@@ -0,0 +1,285 @@
# Copyright (c) 2021-2022, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from nvflare.apis.client import Client
from nvflare.apis.fl_constant import ReturnCode
from nvflare.apis.fl_context import FLContext
from nvflare.apis.impl.controller import ClientTask, Controller, Task
from nvflare.apis.shareable import Shareable
from nvflare.apis.signal import Signal
from nvflare.app_common.abstract.learnable_persistor import LearnablePersistor
from nvflare.app_common.abstract.shareable_generator import ShareableGenerator
from nvflare.app_common.app_constant import AppConstants
from nvflare.app_common.app_event_type import AppEventType
from nvflare.widgets.info_collector import GroupInfoCollector, InfoCollector


class SplitNNDataKind(object):
    ACTIVATIONS = "_splitnn_activations_"
    GRADIENT = "_splitnn_gradient_"


class SplitNNConstants(object):
    BATCH_INDICES = "_splitnn_batch_indices_"
    DATA = "_splitnn_data_"
    BATCH_SIZE = "_splitnn_batch_size_"
    TARGET_NAMES = "_splitnn_target_names_"

    TASK_INIT_MODEL = "_splitnn_task_init_model_"
    TASK_LABEL_STEP = "_splitnn_task_label_step_"
    TASK_TRAIN = "_splitnn_task_train_"

    TASK_RESULT = "_splitnn_task_result_"
    TIMEOUT = 60.0  # timeout for waiting for reply from aux message request

class SplitNNController(Controller):
    def __init__(
        self,
        num_rounds: int = 5000,
        start_round: int = 0,
        persistor_id=AppConstants.DEFAULT_PERSISTOR_ID,  # used to init the models on both clients
        shareable_generator_id=AppConstants.DEFAULT_SHAREABLE_GENERATOR_ID,
        init_model_task_name=SplitNNConstants.TASK_INIT_MODEL,
        train_task_name=SplitNNConstants.TASK_TRAIN,
        task_timeout: int = 10,
        ignore_result_error: bool = True,
        batch_size: int = 256,
    ):
        """The controller for the split learning workflow.

        The SplitNNController workflow defines federated training on all clients.
        The model persistor (persistor_id) is used to load the initial global model which is sent to all clients.
        The shareable generator is used to convert the model weights to a shareable, and a shareable back to weights.
        The model persistor also saves the model after training.

        Args:
            num_rounds (int, optional): The total number of training rounds. Defaults to 5000.
            start_round (int, optional): Start round for training. Defaults to 0.
            persistor_id (str, optional): ID of the persistor component. Defaults to "persistor".
            shareable_generator_id (str, optional): ID of the shareable generator. Defaults to "shareable_generator".
            init_model_task_name: Task name used to initialize the local models.
            train_task_name: Task name used for split learning.
            task_timeout (int, optional): timeout (in sec) to determine if one client fails
                to request the task which it is assigned to. Defaults to 10.
            ignore_result_error (bool, optional): whether this controller can proceed if a result has errors. Defaults to True.
            batch_size (int, optional): batch size used for training. Defaults to 256.

        Raises:
            TypeError: when any of the input arguments does not have the correct type
            ValueError: when any of the input arguments is out of range
        """
        Controller.__init__(self)

        # Check arguments
        if not isinstance(num_rounds, int):
            raise TypeError("`num_rounds` must be int but got {}".format(type(num_rounds)))
        if not isinstance(start_round, int):
            raise TypeError("`start_round` must be int but got {}".format(type(start_round)))
        if not isinstance(task_timeout, int):
            raise TypeError("`task_timeout` must be int but got {}".format(type(task_timeout)))
        if not isinstance(persistor_id, str):
            raise TypeError("`persistor_id` must be a string but got {}".format(type(persistor_id)))
        if not isinstance(shareable_generator_id, str):
            raise TypeError("`shareable_generator_id` must be a string but got {}".format(type(shareable_generator_id)))
        if not isinstance(init_model_task_name, str):
            raise TypeError("`init_model_task_name` must be a string but got {}".format(type(init_model_task_name)))
        if not isinstance(train_task_name, str):
            raise TypeError("`train_task_name` must be a string but got {}".format(type(train_task_name)))
        if num_rounds < 0:
            raise ValueError("num_rounds must be greater than or equal to 0.")
        if start_round < 0:
            raise ValueError("start_round must be greater than or equal to 0.")

        self.persistor_id = persistor_id
        self.shareable_generator_id = shareable_generator_id
        self.persistor = None
        self.shareable_generator = None

        # config data
        self._num_rounds = num_rounds
        self._start_round = start_round
        self._task_timeout = task_timeout
        self.ignore_result_error = ignore_result_error

        # workflow phases: init, train, validate
        self._phase = AppConstants.PHASE_INIT
        self._global_weights = None
        self._current_round = None

        # task names
        self.init_model_task_name = init_model_task_name
        self.train_task_name = train_task_name

        self.targets_names = ["site-1", "site-2"]
        self.nr_supported_clients = 2
        self.batch_size = batch_size

    def start_controller(self, fl_ctx: FLContext):
        self.log_debug(fl_ctx, "starting controller")
        self.persistor = fl_ctx.get_engine().get_component(self.persistor_id)
        self.shareable_generator = fl_ctx.get_engine().get_component(self.shareable_generator_id)
        if not isinstance(self.persistor, LearnablePersistor):
            self.system_panic(
                f"Persistor {self.persistor_id} must be a Persistor instance, but got {type(self.persistor)}", fl_ctx
            )
        if not isinstance(self.shareable_generator, ShareableGenerator):
            self.system_panic(
                f"Shareable generator {self.shareable_generator_id} must be a Shareable Generator instance, "
                f"but got {type(self.shareable_generator)}",
                fl_ctx,
            )

        # initialize global model
        fl_ctx.set_prop(AppConstants.START_ROUND, self._start_round, private=True, sticky=True)
        fl_ctx.set_prop(AppConstants.NUM_ROUNDS, self._num_rounds, private=True, sticky=False)
        self._global_weights = self.persistor.load(fl_ctx)
        fl_ctx.set_prop(AppConstants.GLOBAL_MODEL, self._global_weights, private=True, sticky=True)
        self.fire_event(AppEventType.INITIAL_MODEL_LOADED, fl_ctx)

    def _process_result(self, client_task: ClientTask, fl_ctx: FLContext) -> bool:
        # submitted shareable is stored in client_task.result
        # we need to update task.data with that shareable so the next target
        # will get the updated shareable
        task = client_task.task
        result = client_task.result
        rc = result.get_return_code()

        if rc and rc != ReturnCode.OK:
            if self.ignore_result_error:
                self.log_error(fl_ctx, f"Ignore the task {task} result. Train result error code: {rc}")
                return False
            else:
                if rc in [ReturnCode.MISSING_PEER_CONTEXT, ReturnCode.BAD_PEER_CONTEXT]:
                    self.system_panic(
                        f"Peer context for task {task} is bad or missing. SplitNNController exiting.", fl_ctx=fl_ctx
                    )
                    return False
                elif rc in [ReturnCode.EXECUTION_EXCEPTION, ReturnCode.TASK_UNKNOWN]:
                    self.system_panic(
                        f"Execution Exception in client task {task}. SplitNNController exiting.", fl_ctx=fl_ctx
                    )
                    return False
                elif rc in [
                    ReturnCode.EXECUTION_RESULT_ERROR,
                    ReturnCode.TASK_DATA_FILTER_ERROR,
                    ReturnCode.TASK_RESULT_FILTER_ERROR,
                ]:
                    self.system_panic(
                        f"Execution result for task {task} is not a shareable. SplitNNController exiting.",
                        fl_ctx=fl_ctx,
                    )
                    return False

        # assign result to current task
        if result:
            task.set_prop(SplitNNConstants.TASK_RESULT, result)

        return True

    def _check_targets(self, fl_ctx: FLContext):
        engine = fl_ctx.get_engine()
        targets = engine.get_clients()
        for t in targets:
            if t.name not in self.targets_names:
                self.system_panic(f"Client {t.name} not in expected target names: {self.targets_names}", fl_ctx)

    def _init_models(self, abort_signal: Signal, fl_ctx: FLContext):
        self._check_targets(fl_ctx)
        self.log_debug(fl_ctx, f"SplitNN initializing model {self.targets_names}.")

        # Create init_model_task_name
        data_shareable: Shareable = self.shareable_generator.learnable_to_shareable(self._global_weights, fl_ctx)
        task = Task(
            name=self.init_model_task_name,
            data=data_shareable,
            result_received_cb=self._process_result,
        )

        self.broadcast_and_wait(
            task=task,
            min_responses=self.nr_supported_clients,
            wait_time_after_min_received=0,
            fl_ctx=fl_ctx,
            abort_signal=abort_signal,
        )

    def _train(self, abort_signal: Signal, fl_ctx: FLContext):
        self._check_targets(fl_ctx)
        self.log_debug(fl_ctx, f"SplitNN training starting with {self.targets_names}.")

        # Create train_task
        data_shareable: Shareable = Shareable()
        data_shareable.set_header(AppConstants.NUM_ROUNDS, self._num_rounds)
        data_shareable.set_header(SplitNNConstants.BATCH_SIZE, self.batch_size)
        data_shareable.set_header(SplitNNConstants.TARGET_NAMES, self.targets_names)

        task = Task(
            name=self.train_task_name,
            data=data_shareable,
            result_received_cb=self._process_result,
        )

        self.broadcast_and_wait(
            task=task,
            min_responses=self.nr_supported_clients,
            wait_time_after_min_received=0,
            fl_ctx=fl_ctx,
            abort_signal=abort_signal,
        )

    def control_flow(self, abort_signal: Signal, fl_ctx: FLContext):
        try:
            self._check_targets(fl_ctx)
            self.log_debug(fl_ctx, f"Training on {self.targets_names}")

            # 1. initialize models on clients
            self._init_models(abort_signal=abort_signal, fl_ctx=fl_ctx)

            # 2. Start split learning
            self._phase = AppConstants.PHASE_TRAIN
            self._train(abort_signal=abort_signal, fl_ctx=fl_ctx)

            self._phase = AppConstants.PHASE_FINISHED
            self.log_debug(fl_ctx, "SplitNN training ended.")
        except BaseException as e:
            error_msg = f"SplitNN control_flow exception {e}"
            self.log_error(fl_ctx, error_msg)
            self.system_panic(str(e), fl_ctx)

    def stop_controller(self, fl_ctx: FLContext):
        self._phase = AppConstants.PHASE_FINISHED
        self.log_debug(fl_ctx, "controller stopped")

    def process_result_of_unknown_task(
        self,
        client: Client,
        task_name: str,
        client_task_id: str,
        result: Shareable,
        fl_ctx: FLContext,
    ):
        self.log_warning(fl_ctx, f"Dropped result of unknown task: {task_name} from client {client.name}.")

    def handle_event(self, event_type: str, fl_ctx: FLContext):
        super().handle_event(event_type, fl_ctx)
        if event_type == InfoCollector.EVENT_TYPE_GET_STATS:
            collector = fl_ctx.get_prop(InfoCollector.CTX_KEY_STATS_COLLECTOR, None)
            if collector:
                if not isinstance(collector, GroupInfoCollector):
                    raise TypeError("collector must be GroupInfoCollector but got {}".format(type(collector)))

                collector.add_info(
                    group_name=self._name,
                    info={"phase": self._phase, "current_round": self._current_round, "num_rounds": self._num_rounds},
                )