Skip to content

Commit

Permalink
Cifar10 split learning (NVIDIA#1168)
Browse files Browse the repository at this point in the history
run training with higher timeout value

rename classes

add training accuracy

remove printouts

add figure

update printouts

update readme

vertical split data

run psi

add notebook

update notebook

update notebook

take intersection.txt as input for split-learning

configure overlap

add todo

refactor to use FCI

update requirements

formatting

add validation

unify gitignore

revert network; remove unnecessary check; use stats pool for computation time

introduce cifar10 data utils

move splitnn example to vertical_federated_learning

move more files

deleted moved files

move to tutorials

address comments
  • Loading branch information
holgerroth authored Feb 16, 2023
1 parent b0a5b05 commit 468e0b6
Show file tree
Hide file tree
Showing 29 changed files with 3,238 additions and 49 deletions.
File renamed without changes.
23 changes: 0 additions & 23 deletions examples/cifar10/cifar10-sim/.gitignore

This file was deleted.

32 changes: 6 additions & 26 deletions examples/cifar10/pt/utils/cifar10_data_splitter.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,24 +41,12 @@
import os

import numpy as np
import torchvision.datasets as datasets
from cifar10_data_utils import get_site_class_summary, load_cifar10_data

from nvflare.apis.event_type import EventType
from nvflare.apis.fl_component import FLComponent
from nvflare.apis.fl_context import FLContext

CIFAR10_ROOT = "/tmp/cifar10" # will be used for all CIFAR-10 experiments


def _get_site_class_summary(train_label, site_idx):
class_sum = {}

for site, data_idx in site_idx.items():
unq, unq_cnt = np.unique(train_label[data_idx], return_counts=True)
tmp = {int(unq[i]): int(unq_cnt[i]) for i in range(len(unq))}
class_sum[site] = tmp
return class_sum


class Cifar10DataSplitter(FLComponent):
def __init__(self, split_dir: str = None, num_sites: int = 8, alpha: float = 0.5, seed: int = 0):
Expand All @@ -68,8 +56,10 @@ def __init__(self, split_dir: str = None, num_sites: int = 8, alpha: float = 0.5
self.alpha = alpha
self.seed = seed

if self.split_dir is None:
raise ValueError("You need to define a valid `split_dir` when splitting the data.")
if alpha < 0.0:
raise ValueError(f"Alpha should be larger 0.0 but was {alpha}!")
raise ValueError(f"Alpha should be larger or equal 0.0 but was" f" {alpha}!")

def handle_event(self, event_type: str, fl_ctx: FLContext):
if event_type == EventType.START_RUN:
Expand All @@ -85,8 +75,6 @@ def split(self, fl_ctx: FLContext):
site_idx, class_sum = self._partition_data()

# write to files
if self.split_dir is None:
raise ValueError("You need to define a valid `split_dir` when splitting the data.")
if not os.path.isdir(self.split_dir):
os.makedirs(self.split_dir)
sum_file_name = os.path.join(self.split_dir, "summary.txt")
Expand All @@ -101,16 +89,8 @@ def split(self, fl_ctx: FLContext):
site_file_name = site_file_path + str(site + 1) + ".npy"
np.save(site_file_name, np.array(site_idx[site]))

def load_cifar10_data(self):
# download data
train_dataset = datasets.CIFAR10(root=CIFAR10_ROOT, train=True, download=True)

# only training label is needed for doing split
train_label = np.array(train_dataset.targets)
return train_label

def _partition_data(self):
train_label = self.load_cifar10_data()
train_label = load_cifar10_data()

min_size = 0
K = 10
Expand Down Expand Up @@ -140,6 +120,6 @@ def _partition_data(self):
site_idx[j] = idx_batch[j]

# collect class summary
class_sum = _get_site_class_summary(train_label, site_idx)
class_sum = get_site_class_summary(train_label, site_idx)

return site_idx, class_sum
62 changes: 62 additions & 0 deletions examples/cifar10/pt/utils/cifar10_data_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
# Copyright (c) 2022, NVIDIA CORPORATION.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# This Dirichlet sampling strategy for creating a heterogeneous partition is adopted
# from FedMA (https://github.com/IBM/FedMA).

# MIT License

# Copyright (c) 2020 International Business Machines

# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:

# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.

# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.

import numpy as np
import torchvision.datasets as datasets

CIFAR10_ROOT = "/tmp/cifar10" # will be used for all CIFAR-10 experiments


def load_cifar10_data():
# download data
train_dataset = datasets.CIFAR10(root=CIFAR10_ROOT, train=True, download=True)

# only training label is needed for doing split
train_label = np.array(train_dataset.targets)
return train_label


def get_site_class_summary(train_label, site_idx):
class_sum = {}

for site, data_idx in site_idx.items():
unq, unq_cnt = np.unique(train_label[data_idx], return_counts=True)
tmp = {int(unq[i]): int(unq_cnt[i]) for i in range(len(unq))}
class_sum[site] = tmp
return class_sum
11 changes: 11 additions & 0 deletions examples/tutorial/vertical_federated_learning/.gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
# ide
.idea/
.ipynb_checkpoints/

# python
__pycache__
.pyc

# virtual environments
nvflare_cifar10

7 changes: 7 additions & 0 deletions examples/tutorial/vertical_federated_learning/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
# Vertical Federated Learning

## Split Learning
### Split learning with CIFAR-10
This [example](./cifar10-splitnn/README.md) includes instructions on how to run
[split learning](https://arxiv.org/abs/1810.06060) using the CIFAR-10 dataset
and the FL simulator.
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
# Split Learning with CIFAR-10

This example includes instructions on how to run [split learning](https://arxiv.org/abs/1810.06060) (SL) using the CIFAR-10 dataset and the FL simulator.

We assume one client holds the images, and the other clients holds the labels to compute losses and accuracy metrics.
Activations and corresponding gradients are being exchanged between the clients through the NVFlare server.

<img src="./figs/split_learning.svg" alt="Split learning setup" width="300"/>

For instructions of how to run CIFAR-10 in real-world deployment settings,
see the example on ["Real-world Federated Learning with CIFAR-10"](../cifar10-real-world/README.md).

## (Optional) Set up a virtual environment
```
python3 -m pip install --user --upgrade pip
python3 -m pip install --user virtualenv
```
(If needed) make all shell scripts executable using
```
find . -name ".sh" -exec chmod +x {} \;
```
initialize virtual environment.
```
source ./virtualenv/set_env.sh
```
install required packages for training
```
pip install --upgrade pip
pip install -r ./virtualenv/min-requirements.txt
```

## Start Jupyter notebook
Set `PYTHONPATH` to include custom files of this example and some reused files from the [CIFAR-10](../cifar10) examples:
```
export PYTHONPATH=${PWD}/src:${PWD}/../../../cifar10
```
Start a Jupyter Lab
```
jupyter lab .
```
and open [cifar10_split_learning.ipynb](./cifar10_split_learning.ipynb).

See [here](https://jupyterlab.readthedocs.io/en/stable/getting_started/installation.html) for installing Jupyter Lab.
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
# Copyright (c) 2022, NVIDIA CORPORATION.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
import sys

from splitnn.cifar10_vertical_data_splitter import Cifar10VerticalDataSplitter

from nvflare.apis.fl_context import FLContext

logging.basicConfig(stream=sys.stdout, level=logging.DEBUG)
import argparse

from nvflare.apis.fl_constant import ReservedKey


def main():
parser = argparse.ArgumentParser()
parser.add_argument("--split_dir", type=str, default="/tmp/cifar10_vert_splits", help="output folder")
parser.add_argument("--overlap", type=int, default=10_000, help="number of overlapping samples")
args = parser.parse_args()

splitter = Cifar10VerticalDataSplitter(split_dir=args.split_dir, overlap=args.overlap)

# set up a dummy context for logging
fl_ctx = FLContext()
fl_ctx.set_prop(ReservedKey.IDENTITY_NAME, "local")
fl_ctx.set_prop(ReservedKey.RUN_NUM, "_")

splitter.split(fl_ctx) # will download to CIFAR10_ROOT defined in
# Cifar10DataSplitter


if __name__ == "__main__":
main()
Loading

0 comments on commit 468e0b6

Please sign in to comment.