initial cifar10 split learning
run training with higher timeout value

rename classes

add training accuracy

remove printouts

add figure

update printouts

update readme

vertical split data

run psi

add notebook

update notebook

update notebook

take intersection.txt as input for split-learning

configure overlap

add todo

refactor to use FCI
holgerroth committed Feb 1, 2023

1 parent c3082d1 commit cf580dc
Showing 28 changed files with 3,160 additions and 15 deletions.
4 changes: 4 additions & 0 deletions examples/cifar10/README.md
@@ -12,3 +12,7 @@ This example runs you through the process and includes instructions on running
[FedAvg](https://arxiv.org/abs/1602.05629) with streaming of TensorBoard metrics to the server during training
and [homomorphic encryption](https://developer.nvidia.com/blog/federated-learning-with-homomorphic-encryption/)
for secure server-side aggregation.

### [Split learning with CIFAR-10](./cifar10-splitnn/README.md)
This example includes instructions on how to run [split learning](https://arxiv.org/abs/1810.06060)
using the CIFAR-10 dataset and the FL simulator in a vertical FL scenario.
23 changes: 23 additions & 0 deletions examples/cifar10/cifar10-splitnn/.gitignore
@@ -0,0 +1,23 @@
# ide
.idea/
.ipynb_checkpoints/

# nvflare artifacts
log.txt
client_token.txt
*.fl
audit.log
transfer
workspaces

# python
__pycache__
.pyc

# virtual environments
nvflare_cifar10

# data
dataset
dataset*
*results*
43 changes: 43 additions & 0 deletions examples/cifar10/cifar10-splitnn/README.md
@@ -0,0 +1,43 @@
# Split Learning with CIFAR-10

This example includes instructions on how to run [split learning](https://arxiv.org/abs/1810.06060) (SL) using the CIFAR-10 dataset and the FL simulator.

We assume one client holds the images, and the other client holds the labels used to compute losses and accuracy metrics.
Activations and the corresponding gradients are exchanged between the clients through the NVFlare server.

<img src="./figs/split_learning.svg" alt="Split learning setup" width="300"/>
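
For intuition, a single training step can be sketched as follows. This is a condensed sketch mirroring the standalone test in [split_nn.py](../pt/networks/split_nn.py); in the actual NVFlare workflow, the tensors are exchanged as messages via the server rather than as local variables, and the optimizer steps are omitted here for brevity.
```
import torch
from pt.networks.split_nn import SplitNN

net1 = SplitNN(split_id=0)  # site-1: convolutional half (holds the images)
net2 = SplitNN(split_id=1)  # site-2: fully connected half (holds the labels)
criterion = torch.nn.CrossEntropyLoss()

images = torch.randn(8, 3, 32, 32)      # dummy CIFAR-10 batch
labels = torch.randint(0, 10, (8,))

x = net1(images)                        # site-1: forward through conv layers
x_sent = x.detach().requires_grad_()    # activations "sent" to site-2
loss = criterion(net2(x_sent), labels)  # site-2: forward + loss
loss.backward()                         # site-2: backward, yields x_sent.grad
x.backward(gradient=x_sent.grad)        # site-1: backward with returned gradients
```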

For instructions on how to run CIFAR-10 in real-world deployment settings,
see the example on ["Real-world Federated Learning with CIFAR-10"](../cifar10-real-world/README.md).

## (Optional) Set up a virtual environment
```
python3 -m pip install --user --upgrade pip
python3 -m pip install --user virtualenv
```
(If needed) make all shell scripts executable using
```
find . -name "*.sh" -exec chmod +x {} \;
```
Initialize the virtual environment:
```
source ./virtualenv/set_env.sh
```
Install the required packages for training:
```
pip install --upgrade pip
pip install -r ./virtualenv/min-requirements.txt
```

## Start Jupyter notebook
Set `PYTHONPATH` to include custom files of this example:
```
export PYTHONPATH=${PWD}/..
```
Start JupyterLab
```
jupyter lab .
```
and open [cifar10_split_learning.ipynb](./cifar10_split_learning.ipynb).

See [here](https://jupyterlab.readthedocs.io/en/stable/getting_started/installation.html) for installing Jupyter Lab.
1,439 changes: 1,439 additions & 0 deletions examples/cifar10/cifar10-splitnn/cifar10_split_learning.ipynb

Large diffs are not rendered by default.

1 change: 1 addition & 0 deletions examples/cifar10/cifar10-splitnn/figs/split_learning.svg
@@ -0,0 +1,9 @@
{
"name": "cifar10_psi",
"deploy_map": {
"server": ["server"],
"site-1": ["site-1"],
"site-2": ["site-2"]
},
"min_clients": 2
}
@@ -0,0 +1,11 @@
{
"format_version": 2,
"workflows": [
{
"id": "DhPSIController",
"path": "nvflare.app_common.workflows.dh_psi_controller.DhPSIController",
"args": {
}
}
]
}
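
This server workflow runs the Diffie-Hellman-based private set intersection (PSI). The two client configurations below wire up the matching `PSIExecutor` and point it at each site's local `.npy` split file; the resulting overlap is written to `psi/intersection.txt`.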
@@ -0,0 +1,34 @@
{
"format_version": 2,
"executors": [
{
"tasks": [
"PSI"
],
"executor": {
"id": "Executor",
"path": "nvflare.app_common.executors.psi.psi_executor.PSIExecutor",
"args": {
"local_psi_id": "local_psi"
}
}
}
],
"components": [
{
"id": "local_psi",
"path": "pt.utils.cifar10_local_psi.Cifar10LocalPSI",
"args": {
"psi_writer_id": "psi_writer",
"data_path": "/tmp/cifar10_vert_splits/site-1.npy"
}
},
{
"id": "psi_writer",
"path": "nvflare.app_common.psi.psi_file_writer.FilePsiWriter",
"args": {
"output_path": "psi/intersection.txt"
}
}
]
}
@@ -0,0 +1,34 @@
{
"format_version": 2,
"executors": [
{
"tasks": [
"PSI"
],
"executor": {
"id": "Executor",
"path": "nvflare.app_common.executors.psi.psi_executor.PSIExecutor",
"args": {
"local_psi_id": "local_psi"
}
}
}
],
"components": [
{
"id": "local_psi",
"path": "pt.utils.cifar10_local_psi.Cifar10LocalPSI",
"args": {
"psi_writer_id": "psi_writer",
"data_path": "/tmp/cifar10_vert_splits/site-2.npy"
}
},
{
"id": "psi_writer",
"path": "nvflare.app_common.psi.psi_file_writer.FilePsiWriter",
"args": {
"output_path": "psi/intersection.txt"
}
}
]
}
@@ -0,0 +1,9 @@
{
"name": "cifar10_splitnn",
"deploy_map": {
"server": ["server"],
"site-1": ["site-1"],
"site-2": ["site-2"]
},
"min_clients": 2
}
@@ -0,0 +1,47 @@
{
"format_version": 2,

"num_rounds": 15625,
"batch_size": 64,

"server": {
"heart_beat_timeout": 600
},
"task_data_filters": [],
"task_result_filters": [],
"components": [
{
"id": "persistor",
"path": "nvflare.app_common.pt.pt_file_model_persistor.PTFileModelPersistor",
"args": {
"model": {
"path": "pt.networks.cifar10_nets.ModerateCNN"
}
}
},
{
"id": "shareable_generator",
"name": "FullModelShareableGenerator",
"args": {}
},
{
"id": "json_generator",
"name": "ValidationJsonGenerator",
"args": {}
}
],
"workflows": [
{
"id": "splitnn_ctl",
"path": "nvflare.app_common.workflows.splitnn_workflow.SplitNNController",
"args": {
"num_rounds" : "{num_rounds}",
"batch_size": "{batch_size}",
"start_round": 0,
"persistor_id": "persistor",
"task_timeout": 0,
"shareable_generator_id": "shareable_generator"
}
}
]
}
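
A note on the round count: each round here appears to correspond to one mini-batch, so assuming the full CIFAR-10 training set of 50,000 images, 15,625 rounds at a batch size of 64 amount to 1,000,000 sampled images, i.e., roughly 20 epochs (more effective epochs when training is restricted to the smaller PSI intersection).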
@@ -0,0 +1,41 @@
{
"format_version": 2,

"DATASET_ROOT": "/tmp/cifar10",
"INTERSECTION_FILE": "site-1-intersection.txt",

"executors": [
{
"tasks": [
"_splitnn_task_init_model_",
"_splitnn_task_train_"
],
"executor": {
"id": "Executor",
"path": "nvflare.app_common.executors.splitnn_learner_executor.SplitNNLearnerExecutor",
"args": {
"learner_id": "cifar10-learner"
}
}
}
],

"task_result_filters": [
],
"task_data_filters": [
],

"components": [
{
"id": "cifar10-learner",
"path": "pt.learners.cifar10_learner_splitnn.CIFAR10LearnerSplitNN",
"args": {
"dataset_root": "{DATASET_ROOT}",
"intersection_file": "{INTERSECTION_FILE}",
"lr": 1e-2,
"model": {"path": "pt.networks.split_nn.SplitNN", "args": {"split_id": 0}},
"timeit": true
}
}
]
}
@@ -0,0 +1,41 @@
{
"format_version": 2,

"DATASET_ROOT": "/tmp/cifar10",
"INTERSECTION_FILE": "site-2-intersection.txt",

"executors": [
{
"tasks": [
"_splitnn_task_init_model_",
"_splitnn_task_train_"
],
"executor": {
"id": "Executor",
"path": "nvflare.app_common.executors.splitnn_learner_executor.SplitNNLearnerExecutor",
"args": {
"learner_id": "cifar10-learner"
}
}
}
],

"task_result_filters": [
],
"task_data_filters": [
],

"components": [
{
"id": "cifar10-learner",
"path": "pt.learners.cifar10_learner_splitnn.CIFAR10LearnerSplitNN",
"args": {
"dataset_root": "{DATASET_ROOT}",
"intersection_file": "{INTERSECTION_FILE}",
"lr": 1e-2,
"model": {"path": "pt.networks.split_nn.SplitNN", "args": {"split_id": 1}},
"timeit": true
}
}
]
}
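
The `split_id` argument selects which half of the network a site trains: `0` maps to the convolutional, image-side layers (site-1, which holds the images) and `1` to the fully connected, label-side layers (site-2, which holds the labels); see `pt/networks/split_nn.py` below.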
@@ -0,0 +1,6 @@
nvflare>=2.3.0
torch
torchvision
tensorboard
openmined.psi
pandas
6 changes: 6 additions & 0 deletions examples/cifar10/cifar10-splitnn/virtualenv/set_env.sh
@@ -0,0 +1,6 @@
#!/usr/bin/env bash

export projectname='nvflare_cifar10'

python3 -m venv ${projectname}
source ${projectname}/bin/activate
480 changes: 480 additions & 0 deletions examples/cifar10/pt/learners/cifar10_learner_splitnn.py

Large diffs are not rendered by default.

1 change: 1 addition & 0 deletions examples/cifar10/pt/networks/cifar10_nets.py
@@ -85,6 +85,7 @@ def __init__(self):
nn.Conv2d(in_channels=256, out_channels=256, kernel_size=3, padding=1),
nn.ReLU(inplace=True),
nn.MaxPool2d(kernel_size=2, stride=2),
nn.Flatten(), # added to support split learning
)

self.fc_layer = nn.Sequential(
196 changes: 196 additions & 0 deletions examples/cifar10/pt/networks/split_nn.py
@@ -0,0 +1,196 @@
import numpy as np
import torch
from pt.networks.cifar10_nets import ModerateCNN
from pt.utils.cifar10_dataset import CIFAR10SplitNN
from torch import optim
from torch.utils.tensorboard import SummaryWriter
from torchvision import datasets, transforms


# TODO: maybe only use the part net that is being used
# rather than inheriting the full net
class SplitNN(ModerateCNN):
def __init__(self, split_id):
super().__init__()
if split_id not in [0, 1]:
raise ValueError(f"Only supports split_id '0' or '1' but was {self.split_id}")
self.split_id = split_id

if self.split_id == 0:
self.split_forward = self.conv_layer
elif self.split_id == 1:
self.split_forward = self.fc_layer
else:
raise ValueError(f"Expected split_id to be '0' or '1' but was {self.split_id}")

def forward(self, x):
x = self.split_forward(x)
return x

def get_split_id(self):
return self.split_id


""" TESTING """


def print_grads(net):
for name, param in net.named_parameters():
if param.grad is not None:
print(name, "grad", param.grad.shape, torch.sum(param.grad).item())
else:
print(name, "grad", None)


def test_splitnn(): # TODO: move to unit testing
"""Test SplitNN"""

lr = 1e-2
epoch_max = 20
bs = 64

train_size = 50000

criterion = torch.nn.CrossEntropyLoss()
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

net1 = SplitNN(split_id=0).to(device)
net2 = SplitNN(split_id=1).to(device)

optim1 = optim.SGD(net1.parameters(), lr=lr, momentum=0.9)
optim2 = optim.SGD(net2.parameters(), lr=lr, momentum=0.9)

transform_train = transforms.Compose(
[
transforms.ToTensor(),
transforms.ToPILImage(),
transforms.Pad(4, padding_mode="reflect"),
transforms.RandomCrop(32),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
transforms.Normalize(
mean=[x / 255.0 for x in [125.3, 123.0, 113.9]],
std=[x / 255.0 for x in [63.0, 62.1, 66.7]],
),
]
)
transform_valid = transforms.Compose(
[
transforms.ToTensor(),
transforms.Normalize(
mean=[x / 255.0 for x in [125.3, 123.0, 113.9]],
std=[x / 255.0 for x in [63.0, 62.1, 66.7]],
),
]
)
train_dataset = datasets.CIFAR10(
root="/tmp/cifar10_vertical",
train=True,
download=True,
transform=transform_train,
)
valid_dataset = datasets.CIFAR10(
root="/tmp/cifar10_vertical",
train=False,
download=False,
transform=transform_valid,
)

train_image_dataset = CIFAR10SplitNN(
root="/tmp/cifar10_vertical", train=True, download=True, transform=transform_train, returns="image"
)
train_label_dataset = CIFAR10SplitNN(
root="/tmp/cifar10_vertical", train=True, download=True, transform=transform_train, returns="label"
)

train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=bs, shuffle=True, num_workers=2)
valid_loader = torch.utils.data.DataLoader(valid_dataset, batch_size=bs, shuffle=False, num_workers=2)

def valid(model1, model2, data_loader, device):
model1.eval()
model2.eval()
with torch.no_grad():
correct, total = 0, 0
for _, (inputs, labels) in enumerate(data_loader):
inputs, labels = inputs.to(device), labels.to(device)
outputs = model1(inputs)
outputs = model2(outputs)
_, pred_label = torch.max(outputs.data, 1)

total += inputs.data.size()[0]
correct += (pred_label == labels.data).sum().item()
metric = correct / float(total)
return metric

def train(inputs, targets, debug=False):
# See also
# https://github.com/Koukyosyumei/Attack_SplitNN/blob/main/src/attacksplitnn/splitnn/model.py

"""Compute on site-1"""
net1.train()
optim1.zero_grad()

x = net1.forward(inputs) # keep on site-1
x_sent = x.detach().requires_grad_()
# send to site-2

""" Compute on site-2 """
net2.train()
optim2.zero_grad()

pred = net2.forward(x_sent)

loss = criterion(pred, targets)

loss.backward()
optim2.step()

return_grad = x_sent.grad

if debug:
print("return_grad", return_grad.shape, torch.sum(return_grad))

print("====== net2 grad: ======")
print_grads(net2)

# send gradients to site-1

""" Compute on site-1 """
x.backward(gradient=return_grad)
optim1.step()

if debug:
print("====== net1 grad: ======")
print_grads(net1)

return loss.item()

# main training loop
writer = SummaryWriter("./")
# epoch_len = len(train_loader)
epoch_len = int(train_size / bs)
for e in range(epoch_max):
epoch_loss = 0
# for i, (inputs, targets) in enumerate(train_loader):
# epoch_loss += train(inputs=inputs.to(device), targets=targets.to(device))

for i in range(epoch_len):
batch_indices = np.random.randint(0, train_size, bs)  # high bound of randint is exclusive
inputs = train_image_dataset.get_batch(batch_indices)
targets = train_label_dataset.get_batch(batch_indices)
loss = train(inputs=inputs.to(device), targets=targets.to(device))
epoch_loss += loss
writer.add_scalar("loss", loss, e * epoch_len + i)

train_acc = valid(net1, net2, train_loader, device)
val_acc = valid(net1, net2, valid_loader, device)
print(
f"Epoch {e+1}/{epoch_max}. loss: {epoch_loss/epoch_len:.4f}, "
f"train_acc: {train_acc:.4f}, val_acc: {val_acc:.4f}"
)
writer.add_scalar("train_acc", train_acc, e)
writer.add_scalar("val_acc", val_acc, e)


if __name__ == "__main__":
test_splitnn()
31 changes: 16 additions & 15 deletions examples/cifar10/pt/utils/cifar10_data_splitter.py
@@ -50,16 +50,6 @@
CIFAR10_ROOT = "/tmp/cifar10" # will be used for all CIFAR-10 experiments


def _get_site_class_summary(train_label, site_idx):
class_sum = {}

for site, data_idx in site_idx.items():
unq, unq_cnt = np.unique(train_label[data_idx], return_counts=True)
tmp = {int(unq[i]): int(unq_cnt[i]) for i in range(len(unq))}
class_sum[site] = tmp
return class_sum


class Cifar10DataSplitter(FLComponent):
def __init__(self, split_dir: str = None, num_sites: int = 8, alpha: float = 0.5, seed: int = 0):
super().__init__()
@@ -68,8 +58,10 @@ def __init__(self, split_dir: str = None, num_sites: int = 8, alpha: float = 0.5
self.alpha = alpha
self.seed = seed

if self.split_dir is None:
raise ValueError("You need to define a valid `split_dir` when splitting the data.")
if alpha < 0.0:
raise ValueError(f"Alpha should be larger 0.0 but was {alpha}!")
raise ValueError(f"Alpha should be larger or equal 0.0 but was" f" {alpha}!")

def handle_event(self, event_type: str, fl_ctx: FLContext):
if event_type == EventType.START_RUN:
@@ -85,8 +77,6 @@ def split(self, fl_ctx: FLContext):
site_idx, class_sum = self._partition_data()

# write to files
if self.split_dir is None:
raise ValueError("You need to define a valid `split_dir` when splitting the data.")
if not os.path.isdir(self.split_dir):
os.makedirs(self.split_dir)
sum_file_name = os.path.join(self.split_dir, "summary.txt")
@@ -101,14 +91,25 @@ def split(self, fl_ctx: FLContext):
site_file_name = site_file_path + str(site + 1) + ".npy"
np.save(site_file_name, np.array(site_idx[site]))

def load_cifar10_data(self):
@staticmethod
def load_cifar10_data():
# download data
train_dataset = datasets.CIFAR10(root=CIFAR10_ROOT, train=True, download=True)

# only training label is needed for doing split
train_label = np.array(train_dataset.targets)
return train_label

@staticmethod
def get_site_class_summary(train_label, site_idx):
class_sum = {}

for site, data_idx in site_idx.items():
unq, unq_cnt = np.unique(train_label[data_idx], return_counts=True)
tmp = {int(unq[i]): int(unq_cnt[i]) for i in range(len(unq))}
class_sum[site] = tmp
return class_sum

def _partition_data(self):
train_label = self.load_cifar10_data()

@@ -140,6 +141,6 @@ def _partition_data(self):
site_idx[j] = idx_batch[j]

# collect class summary
class_sum = _get_site_class_summary(train_label, site_idx)
class_sum = self.get_site_class_summary(train_label, site_idx)

return site_idx, class_sum
73 changes: 73 additions & 0 deletions examples/cifar10/pt/utils/cifar10_dataset.py
@@ -56,3 +56,76 @@ def __getitem__(self, index):

def __len__(self):
return len(self.data)


class CIFAR10SplitNN(object): # TODO: use torch.utils.data.Dataset with batch sampling
def __init__(self, root, train=True, transform=None, download=False, returns="all", intersect_idx=None):
"""CIFAR-10 dataset with index to extract a mini-batch based on given batch indices
Useful for SplitNN training
Args:
root: data root
data_idx: to specify the data for a particular client site.
If index provided, extract subset, otherwise use the whole set
train: whether to use the training or validation split (default: True)
transform: image transforms
download: whether to download the data (default: False)
returns: specify which data the client has
intersect_idx: indices of samples intersecting between both
participating sites. Intersection indices will be sorted to
ensure that data is aligned on both sites.
Returns:
A PyTorch dataset
"""
self.root = root
self.train = train
self.transform = transform
self.download = download
self.returns = returns
self.intersect_idx = intersect_idx
self.orig_size = 0

if self.intersect_idx is not None:
self.intersect_idx = np.sort(self.intersect_idx).astype(np.int64)

self.data, self.target = self.__build_cifar_subset__()

def __build_cifar_subset__(self):
# if intersect index provided, extract subset, otherwise use the whole
# set
cifar_dataobj = datasets.CIFAR10(self.root, self.train, self.transform, self.download)
data = cifar_dataobj.data
target = np.array(cifar_dataobj.targets)
self.orig_size = len(data)
if self.intersect_idx is not None:
data = data[self.intersect_idx]
target = target[self.intersect_idx]
return data, target

def __getitem__(self, index):
img, target = self.data[index], self.target[index]
if self.transform is not None:
img = self.transform(img)
return img, target

# TODO: this can probably be made more efficient using batch_sampler
def get_batch(self, batch_indices):
img_batch = []
target_batch = []
for idx in batch_indices:
img, target = self.__getitem__(idx)
img_batch.append(img)
target_batch.append(torch.tensor(target, dtype=torch.long))
img_batch = torch.stack(img_batch, dim=0)
target_batch = torch.stack(target_batch, dim=0)
if self.returns == "all":
return img_batch, target_batch
elif self.returns == "image":
return img_batch
elif self.returns == "label":
return target_batch
else:
raise ValueError(f"Expected `returns` to be 'all', 'image', or 'label', but got '{self.returns}'")

def __len__(self):
return len(self.data)
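
A sketch of how the two sites could draw an aligned mini-batch from this dataset (illustrative only; the placeholder `intersect_idx` would normally come from the PSI step):
```
import numpy as np
from torchvision import transforms
from pt.utils.cifar10_dataset import CIFAR10SplitNN

intersect_idx = np.arange(1000)  # placeholder for the PSI intersection indices
to_tensor = transforms.ToTensor()

img_ds = CIFAR10SplitNN(root="/tmp/cifar10", train=True, download=True,
                        transform=to_tensor, returns="image", intersect_idx=intersect_idx)
label_ds = CIFAR10SplitNN(root="/tmp/cifar10", train=True, download=True,
                          transform=to_tensor, returns="label", intersect_idx=intersect_idx)

batch_indices = np.random.randint(0, len(img_ds), 64)
images = img_ds.get_batch(batch_indices)    # site-1 only ever materializes images
labels = label_ds.get_batch(batch_indices)  # site-2 only ever materializes labels
```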
40 changes: 40 additions & 0 deletions examples/cifar10/pt/utils/cifar10_local_psi.py
@@ -0,0 +1,40 @@
# Copyright (c) 2021-2022, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os.path
from typing import List

import numpy as np

from nvflare.app_common.psi.psi_spec import PSI


class Cifar10LocalPSI(PSI):
def __init__(self, psi_writer_id: str, data_path: str = "/tmp/data.csv"):
super().__init__(psi_writer_id)
self.data_path = data_path
self.data = {}

if not os.path.isfile(self.data_path):
raise RuntimeError(f"invalid data path {data_path}")

def load_items(self) -> List[str]:
_ext = os.path.splitext(self.data_path)[1]

items = np.load(self.data_path)

# important: the PSI algorithm requires that all items are unique strings
if len(np.unique(items)) != len(items):
raise ValueError("Expected all items to be unique!")

return [str(i) for i in items]
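
As a quick illustration of the input this component expects (a sketch; the file path is the one produced by the vertical splitter below):
```
import numpy as np

indices = np.load("/tmp/cifar10_vert_splits/site-1.npy")  # written by the splitter
assert len(np.unique(indices)) == len(indices)            # PSI requires unique items
items = [str(i) for i in indices]                         # PSI operates on strings
```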
46 changes: 46 additions & 0 deletions examples/cifar10/pt/utils/cifar10_split_data_vertical.py
@@ -0,0 +1,46 @@
# Copyright (c) 2022, NVIDIA CORPORATION.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
import sys

from cifar10_vertical_data_splitter import Cifar10VerticalDataSplitter

from nvflare.apis.fl_context import FLContext

logging.basicConfig(stream=sys.stdout, level=logging.DEBUG)
import argparse

from nvflare.apis.fl_constant import ReservedKey


def main():
parser = argparse.ArgumentParser()
parser.add_argument("--split_dir", type=str, default="/tmp/cifar10_vert_splits", help="output folder")
parser.add_argument("--overlap", type=int, default=10_000, help="number of overlapping samples")
args = parser.parse_args()

splitter = Cifar10VerticalDataSplitter(split_dir=args.split_dir, overlap=args.overlap)

# set up a dummy context for logging
fl_ctx = FLContext()
fl_ctx.set_prop(ReservedKey.IDENTITY_NAME, "local")
fl_ctx.set_prop(ReservedKey.RUN_NUM, "_")

splitter.split(fl_ctx) # will download to CIFAR10_ROOT defined in
# Cifar10DataSplitter


if __name__ == "__main__":
main()
94 changes: 94 additions & 0 deletions examples/cifar10/pt/utils/cifar10_vertical_data_splitter.py
@@ -0,0 +1,94 @@
# Copyright (c) 2022, NVIDIA CORPORATION.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import json
import os

import numpy as np
from cifar10_data_splitter import Cifar10DataSplitter as Splitter

from nvflare.apis.event_type import EventType
from nvflare.apis.fl_component import FLComponent
from nvflare.apis.fl_context import FLContext

CIFAR10_ROOT = "/tmp/cifar10" # will be used for all CIFAR-10 experiments


class Cifar10VerticalDataSplitter(FLComponent):
def __init__(self, split_dir: str = None, overlap: int = 10_000, seed: int = 0):
super().__init__()
self.split_dir = split_dir
self.overlap = overlap
self.seed = seed

if self.split_dir is None:
raise ValueError("You need to define a valid `split_dir` when splitting the data.")
if overlap <= 0:
raise ValueError(f"Alpha should be larger 0 but was {overlap}!")

def handle_event(self, event_type: str, fl_ctx: FLContext):
if event_type == EventType.START_RUN:
self.split(fl_ctx)

def split(self, fl_ctx: FLContext):
np.random.seed(self.seed)

self.log_info(fl_ctx, f"Partition CIFAR-10 dataset into vertically with {self.overlap} overlapping samples.")
site_idx, class_sum = self._split_data()

# write to files
if not os.path.isdir(self.split_dir):
os.makedirs(self.split_dir)
sum_file_name = os.path.join(self.split_dir, "summary.txt")
with open(sum_file_name, "w") as sum_file:
sum_file.write("Class counts for overlap: \n")
sum_file.write(json.dumps(class_sum))

for _site, _idx in site_idx.items():
site_file_name = os.path.join(self.split_dir, f"{_site}.npy")
self.log_info(fl_ctx, f"save {site_file_name}")
np.save(site_file_name, _idx)

def _split_data(self):
train_label = Splitter.load_cifar10_data()

n_samples = len(train_label)

if self.overlap > n_samples:
raise ValueError(
f"Chosen overlap of {self.overlap} is larger than " f"train dataset with {n_samples} entries."
)

sample_idx = np.arange(0, n_samples)

overlap_idx = np.random.choice(sample_idx, size=np.int64(self.overlap), replace=False)

remain_idx = list(set(sample_idx) - set(overlap_idx))

idx_1 = np.concatenate((overlap_idx, np.array(remain_idx)))
# adding n_samples to remain_idx of site-2 to make sure no overlap
# with idx_1
idx_2 = np.concatenate((overlap_idx, np.array(remain_idx) + n_samples))

# shuffle indexes again for client sites to simulate real world
# scenario
np.random.shuffle(idx_1)
np.random.shuffle(idx_2)

site_idx = {"overlap": overlap_idx, "site-1": idx_1, "site-2": idx_2}

# collect class summary
class_sum = Splitter.get_site_class_summary(train_label, {"overlap": overlap_idx})

return site_idx, class_sum
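
This construction guarantees that exactly the `overlap` indices are shared between the two sites, since site-2's non-overlapping indices are shifted by `n_samples`. A small standalone check of the logic (a sketch with toy numbers):
```
import numpy as np

n, overlap = 10, 4
sample_idx = np.arange(n)
overlap_idx = np.random.choice(sample_idx, size=overlap, replace=False)
remain_idx = np.array(list(set(sample_idx) - set(overlap_idx)))

idx_1 = np.concatenate((overlap_idx, remain_idx))      # site-1: original ids only
idx_2 = np.concatenate((overlap_idx, remain_idx + n))  # site-2: non-overlap shifted by n

assert set(idx_1) & set(idx_2) == set(overlap_idx)     # only the overlap is shared
```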
46 changes: 46 additions & 0 deletions examples/cifar10/pt/utils/set_intersection_file.py
@@ -0,0 +1,46 @@
# Copyright (c) 2021-2022, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
import json


def main():
parser = argparse.ArgumentParser()
parser.add_argument(
"--config_file",
type=str,
default="./config_fed_client.json",
help="config file in JSON format",
)
parser.add_argument(
"--intersection_file",
type=str,
help="Intersection file with overlapping data indices",
)
args = parser.parse_args()

with open(args.config_file, "r") as f:
config = json.load(f)

config["INTERSECTION_FILE"] = args.intersection_file

with open(args.config_file, "w") as f:
json.dump(config, f, indent=4)

print(f"Modified {args.config_file} to use INTERSECTION_FILE={config['INTERSECTION_FILE']}")


if __name__ == "__main__":
main()
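
For example, after PSI has produced an intersection file, a client config can be updated with `python set_intersection_file.py --config_file app_site-1/config/config_fed_client.json --intersection_file intersection.txt` (the paths here are illustrative).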
5 changes: 5 additions & 0 deletions examples/vertical_federated_learning/README.md
@@ -0,0 +1,5 @@
# Vertical Federated Learning

## Split Learning
### Split learning with CIFAR-10
This [example](../cifar10/cifar10-splitnn/README.md) includes instructions on how to run [split learning](https://arxiv.org/abs/1810.06060) using the CIFAR-10 dataset and the FL simulator.
93 changes: 93 additions & 0 deletions nvflare/app_common/executors/splitnn_learner_executor.py
@@ -0,0 +1,93 @@
# Copyright (c) 2021-2022, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from pt.learners.cifar10_learner_splitnn import SplitNNConstants

from nvflare.apis.event_type import EventType
from nvflare.apis.executor import Executor
from nvflare.apis.fl_constant import ReturnCode
from nvflare.apis.fl_context import FLContext
from nvflare.apis.shareable import Shareable, make_reply
from nvflare.apis.signal import Signal
from nvflare.app_common.abstract.learner_spec import Learner


class SplitNNLearnerExecutor(Executor):
def __init__(
self,
learner_id,
init_model_task_name=SplitNNConstants.TASK_INIT_MODEL,
train_task_name=SplitNNConstants.TASK_TRAIN,
):
"""Key component to run learner on clients.
Args:
learner_id (str): id pointing to the learner object
train_task_name (str, optional): label to dispatch train task. Defaults to AppConstants.TASK_TRAIN.
submit_model_task_name (str, optional): label to dispatch submit model task. Defaults to AppConstants.TASK_SUBMIT_MODEL.
validate_task_name (str, optional): label to dispatch validation task. Defaults to AppConstants.TASK_VALIDATION.
"""
super().__init__()
self.learner_id = learner_id
self.learner = None
self.init_model_task_name = init_model_task_name
self.train_task_name = train_task_name

def handle_event(self, event_type: str, fl_ctx: FLContext):
if event_type == EventType.START_RUN:
self.initialize(fl_ctx)
elif event_type == EventType.ABORT_TASK:
try:
if self.learner:
self.learner.abort(fl_ctx)
except Exception as e:
self.log_exception(fl_ctx, f"learner abort exception: {e}")
elif event_type == EventType.END_RUN:
self.finalize(fl_ctx)

def initialize(self, fl_ctx: FLContext):
try:
engine = fl_ctx.get_engine()
self.learner = engine.get_component(self.learner_id)
if not isinstance(self.learner, Learner):
raise TypeError(f"learner must be Learner type. Got: {type(self.learner)}")
self.learner.initialize(engine.get_all_components(), fl_ctx)
except Exception as e:
self.log_exception(fl_ctx, f"learner initialize exception: {e}")

def execute(self, task_name: str, shareable: Shareable, fl_ctx: FLContext, abort_signal: Signal) -> Shareable:
self.log_info(fl_ctx, f"Client trainer got task: {task_name}")

self.log_info(fl_ctx, f"Executing task {task_name}...")
try:
if task_name == self.init_model_task_name:
self.log_info(fl_ctx, "Initializing model...")
return self.learner.init_model(shareable=shareable, fl_ctx=fl_ctx, abort_signal=abort_signal)
elif task_name == self.train_task_name:
self.log_info(fl_ctx, "Running training...")
return self.learner.train(shareable=shareable, fl_ctx=fl_ctx, abort_signal=abort_signal)
else:
self.log_error(fl_ctx, f"Could not handle task: {task_name}")
return make_reply(ReturnCode.TASK_UNKNOWN)
except Exception as e:
# Task execution error, return EXECUTION_EXCEPTION Shareable
self.log_exception(fl_ctx, f"learner execute exception: {e}")
return make_reply(ReturnCode.EXECUTION_EXCEPTION)

def finalize(self, fl_ctx: FLContext):
try:
if self.learner:
self.learner.finalize(fl_ctx)
except Exception as e:
self.log_exception(fl_ctx, f"learner finalize exception: {e}")
37 changes: 37 additions & 0 deletions nvflare/app_common/pt/pt_decomposers.py
@@ -0,0 +1,37 @@
# Copyright (c) 2021-2022, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from io import BytesIO
from typing import Any

import numpy as np
import torch

from nvflare.fuel.utils import fobs


class TensorDecomposer(fobs.Decomposer):
def supported_type(self):
return torch.Tensor

def decompose(self, target: torch.Tensor) -> Any:
stream = BytesIO()
# torch.save uses Pickle so converting Tensor to ndarray first
array = target.detach().cpu().numpy()
np.save(stream, array, allow_pickle=False)
return stream.getvalue()

def recompose(self, data: Any) -> torch.Tensor:
stream = BytesIO(data)
array = np.load(stream, allow_pickle=False)
return torch.from_numpy(array)
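
A minimal round-trip sketch for this decomposer (assuming the `fobs.register`, `fobs.dumps`, and `fobs.loads` helpers from `nvflare.fuel.utils.fobs`):
```
import torch

from nvflare.app_common.pt.pt_decomposers import TensorDecomposer
from nvflare.fuel.utils import fobs

fobs.register(TensorDecomposer)  # make fobs serialization aware of torch.Tensor

tensor = torch.randn(4, 3)
buffer = fobs.dumps(tensor)      # tensor payload goes through np.save, not Pickle
restored = fobs.loads(buffer)
assert torch.equal(tensor, restored)
```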
285 changes: 285 additions & 0 deletions nvflare/app_common/workflows/splitnn_workflow.py
@@ -0,0 +1,285 @@
# Copyright (c) 2021-2022, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from nvflare.apis.client import Client
from nvflare.apis.fl_constant import ReturnCode
from nvflare.apis.fl_context import FLContext
from nvflare.apis.impl.controller import ClientTask, Controller, Task
from nvflare.apis.shareable import Shareable
from nvflare.apis.signal import Signal
from nvflare.app_common.abstract.learnable_persistor import LearnablePersistor
from nvflare.app_common.abstract.shareable_generator import ShareableGenerator
from nvflare.app_common.app_constant import AppConstants
from nvflare.app_common.app_event_type import AppEventType
from nvflare.widgets.info_collector import GroupInfoCollector, InfoCollector


class SplitNNDataKind(object):
ACTIVATIONS = "_splitnn_activations_"
GRADIENT = "_splitnn_gradient_"


class SplitNNConstants(object):
BATCH_INDICES = "_splitnn_batch_indices_"
DATA = "_splitnn_data_"
BATCH_SIZE = "_splitnn_batch_size_"
TARGET_NAMES = "_splitnn_target_names_"

TASK_INIT_MODEL = "_splitnn_task_init_model_"
TASK_LABEL_STEP = "_splitnn_task_label_step_"
TASK_TRAIN = "_splitnn_task_train_"

TASK_RESULT = "_splitnn_task_result_"
TIMEOUT = 60.0 # timeout for waiting for reply from aux message request


class SplitNNController(Controller):
def __init__(
self,
num_rounds: int = 5000,
start_round: int = 0,
persistor_id=AppConstants.DEFAULT_PERSISTOR_ID, # used to init the models on both clients
shareable_generator_id=AppConstants.DEFAULT_SHAREABLE_GENERATOR_ID,
init_model_task_name=SplitNNConstants.TASK_INIT_MODEL,
train_task_name=SplitNNConstants.TASK_TRAIN,
task_timeout: int = 10,
ignore_result_error: bool = True,
batch_size: int = 256,
):
"""The controller for Split Learning Workflow.
The SplitNNController workflow defines Federated training on all clients.
The model persistor (persistor_id) is used to load the initial global model which is sent to all clients.
Each clients sends it's updated weights after local training which is aggregated (aggregator_id). The
shareable generator is used to convert the aggregated weights to shareable and shareable back to weights.
The model_persistor also saves the model after training.
Args:
num_rounds (int, optional): The total number of training rounds. Defaults to 5.
start_round (int, optional): Start round for training. Defaults to 0.
persistor_id (str, optional): ID of the persistor component. Defaults to "persistor".
shareable_generator_id (str, optional): ID of the shareable generator. Defaults to "shareable_generator".
init_model_task_name: Task name used to initialize the local models.
train_task_name: Task name used for split learning.
task_timeout (int, optional): timeout (in sec) to determine if one client fails
to request the task which it is assigned to. Defaults to 10.
ignore_result_error (bool, optional): whether this controller can proceed if result has errors. Defaults to True.
Raises:
TypeError: when any of input arguments does not have correct type
ValueError: when any of input arguments is out of range
"""
Controller.__init__(self)

# Check arguments
if not isinstance(num_rounds, int):
raise TypeError("`num_rounds` must be int but got {}".format(type(num_rounds)))
if not isinstance(start_round, int):
raise TypeError("`start_round` must be int but got {}".format(type(start_round)))
if not isinstance(task_timeout, int):
raise TypeError("`train_timeout` must be int but got {}".format(type(task_timeout)))
if not isinstance(persistor_id, str):
raise TypeError("`persistor_id` must be a string but got {}".format(type(persistor_id)))
if not isinstance(shareable_generator_id, str):
raise TypeError("`shareable_generator_id` must be a string but got {}".format(type(shareable_generator_id)))
if not isinstance(init_model_task_name, str):
raise TypeError("`init_model_task_name` must be a string but got {}".format(type(init_model_task_name)))
if not isinstance(train_task_name, str):
raise TypeError("`train_task_name` must be a string but got {}".format(type(train_task_name)))
if num_rounds < 0:
raise ValueError("num_rounds must be greater than or equal to 0.")
if start_round < 0:
raise ValueError("start_round must be greater than or equal to 0.")

self.persistor_id = persistor_id
self.shareable_generator_id = shareable_generator_id
self.persistor = None
self.shareable_generator = None

# config data
self._num_rounds = num_rounds
self._start_round = start_round
self._task_timeout = task_timeout
self.ignore_result_error = ignore_result_error

# workflow phases: init, train, validate
self._phase = AppConstants.PHASE_INIT
self._global_weights = None
self._current_round = None

# task names
self.init_model_task_name = init_model_task_name
self.train_task_name = train_task_name

self.targets_names = ["site-1", "site-2"]
self.nr_supported_clients = 2
self.batch_size = batch_size

def start_controller(self, fl_ctx: FLContext):
self.log_debug(fl_ctx, "starting controller")
self.persistor = fl_ctx.get_engine().get_component(self.persistor_id)
self.shareable_generator = fl_ctx.get_engine().get_component(self.shareable_generator_id)
if not isinstance(self.persistor, LearnablePersistor):
self.system_panic(
f"Persistor {self.persistor_id} must be a Persistor instance, but got {type(self.persistor)}", fl_ctx
)
if not isinstance(self.shareable_generator, ShareableGenerator):
self.system_panic(
f"Shareable generator {self.shareable_generator_id} must be a Shareable Generator instance, "
f"but got {type(self.shareable_generator)}",
fl_ctx,
)

# initialize global model
fl_ctx.set_prop(AppConstants.START_ROUND, self._start_round, private=True, sticky=True)
fl_ctx.set_prop(AppConstants.NUM_ROUNDS, self._num_rounds, private=True, sticky=False)
self._global_weights = self.persistor.load(fl_ctx)
fl_ctx.set_prop(AppConstants.GLOBAL_MODEL, self._global_weights, private=True, sticky=True)
self.fire_event(AppEventType.INITIAL_MODEL_LOADED, fl_ctx)

def _process_result(self, client_task: ClientTask, fl_ctx: FLContext) -> bool:
# submitted shareable is stored in client_task.result
# we need to update task.data with that shareable so the next target
# will get the updated shareable
task = client_task.task
result = client_task.result
rc = result.get_return_code()

if rc and rc != ReturnCode.OK:
if self.ignore_result_error:
self.log_error(fl_ctx, f"Ignore the task {task} result. Train result error code: {rc}")
return False
else:
if rc in [ReturnCode.MISSING_PEER_CONTEXT, ReturnCode.BAD_PEER_CONTEXT]:
self.system_panic(
f"Peer context for task {task} is bad or missing. SplitNNController exiting.", fl_ctx=fl_ctx
)
return False
elif rc in [ReturnCode.EXECUTION_EXCEPTION, ReturnCode.TASK_UNKNOWN]:
self.system_panic(
f"Execution Exception in client task {task}. SplitNNController exiting.", fl_ctx=fl_ctx
)
return False
elif rc in [
ReturnCode.EXECUTION_RESULT_ERROR,
ReturnCode.TASK_DATA_FILTER_ERROR,
ReturnCode.TASK_RESULT_FILTER_ERROR,
]:
self.system_panic(
f"Execution result for task {task} is not a shareable. SplitNNController exiting.",
fl_ctx=fl_ctx,
)
return False

# assign result to current task
if result:
task.set_prop(SplitNNConstants.TASK_RESULT, result)

return True

def _check_targets(self, fl_ctx: FLContext):
engine = fl_ctx.get_engine()
targets = engine.get_clients()
for t in targets:
if t.name not in self.targets_names:
self.system_panic(f"Client {t.name} not in expected target names: {self.targets_names}", fl_ctx)

def _init_models(self, abort_signal: Signal, fl_ctx: FLContext):
self._check_targets(fl_ctx)
self.log_debug(fl_ctx, f"SplitNN initializing model {self.targets_names}.")

# Create init_model_task_name
data_shareable: Shareable = self.shareable_generator.learnable_to_shareable(self._global_weights, fl_ctx)
task = Task(
name=self.init_model_task_name,
data=data_shareable,
result_received_cb=self._process_result,
)

self.broadcast_and_wait(
task=task,
min_responses=self.nr_supported_clients,
wait_time_after_min_received=0,
fl_ctx=fl_ctx,
abort_signal=abort_signal,
)

def _train(self, abort_signal: Signal, fl_ctx: FLContext):
self._check_targets(fl_ctx)
self.log_debug(fl_ctx, f"SplitNN training starting with {self.targets_names}.")

# Create train_task
data_shareable: Shareable = Shareable()
data_shareable.set_header(AppConstants.NUM_ROUNDS, self._num_rounds)
data_shareable.set_header(SplitNNConstants.BATCH_SIZE, self.batch_size)
data_shareable.set_header(SplitNNConstants.TARGET_NAMES, self.targets_names)

task = Task(
name=self.train_task_name,
data=data_shareable,
result_received_cb=self._process_result,
)

self.broadcast_and_wait(
task=task,
min_responses=self.nr_supported_clients,
wait_time_after_min_received=0,
fl_ctx=fl_ctx,
abort_signal=abort_signal,
)

def control_flow(self, abort_signal: Signal, fl_ctx: FLContext):
try:
self._check_targets(fl_ctx)
self.log_debug(fl_ctx, f"Train with on {self.targets_names}")

# 1. initialize models on clients
self._init_models(abort_signal=abort_signal, fl_ctx=fl_ctx)

# 2. Start split learning
self._phase = AppConstants.PHASE_TRAIN
self._train(abort_signal=abort_signal, fl_ctx=fl_ctx)

self._phase = AppConstants.PHASE_FINISHED
self.log_debug(fl_ctx, "SplitNN training ended.")
except BaseException as e:
error_msg = f"SplitNN control_flow exception {e}"
self.log_error(fl_ctx, error_msg)
self.system_panic(str(e), fl_ctx)

def stop_controller(self, fl_ctx: FLContext):
self._phase = AppConstants.PHASE_FINISHED
self.log_debug(fl_ctx, "controller stopped")

def process_result_of_unknown_task(
self,
client: Client,
task_name: str,
client_task_id: str,
result: Shareable,
fl_ctx: FLContext,
):
self.log_warning(fl_ctx, f"Dropped result of unknown task: {task_name} from client {client.name}.")

def handle_event(self, event_type: str, fl_ctx: FLContext):
super().handle_event(event_type, fl_ctx)
if event_type == InfoCollector.EVENT_TYPE_GET_STATS:
collector = fl_ctx.get_prop(InfoCollector.CTX_KEY_STATS_COLLECTOR, None)
if collector:
if not isinstance(collector, GroupInfoCollector):
raise TypeError("collector must be GroupInfoCollector but got {}".format(type(collector)))

collector.add_info(
group_name=self._name,
info={"phase": self._phase, "current_round": self._current_round, "num_rounds": self._num_rounds},
)
