Note: This guide will not work on Mac M1/M2 systems, as they do not support Keras with TensorFlow v1. Refer to the PyTorch guide instead, which works on Mac M1/M2 systems.
All commands are assumed to be run from the base directory at the top of this repository. We also assume that other folders, such as examples, are located under the same directory.
In this example, we train the Keras (TensorFlow v1) CNN model shown below on MNIST data in a federated learning fashion.
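from keras import backend as K
from keras.layers import Conv2D, Dense, Flatten, MaxPooling2D
from keras.models import Sequential

num_classes = 10  # MNIST digits 0-9
img_rows, img_cols = 28, 28

# Put the single grayscale channel where the backend expects it.
if K.image_data_format() == 'channels_first':
    input_shape = (1, img_rows, img_cols)
else:
    input_shape = (img_rows, img_cols, 1)

model = Sequential()
model.add(Conv2D(32, (3, 3), activation='relu', input_shape=input_shape))
model.add(Conv2D(64, (3, 3), activation='relu'))
model.add(MaxPooling2D(pool_size=(2, 2)))
model.add(Flatten())  # flatten the feature maps before the fully connected layers
model.add(Dense(128, activation='relu'))
model.add(Dense(num_classes, activation='softmax'))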
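The branch on K.image_data_format() only decides where the single grayscale channel goes. The generated party data holds images shaped (N, 28, 28) (see nb_x_train in the data-preparation output). The party logs during training, however, report x_train shape: (200, 28, 28, 1). A channel axis is therefore added before the data reaches the model. The sketch below shows that shape logic without Keras. add_channel_axis is a hypothetical helper and is not part of Keras or IBM FL:

```python
def add_channel_axis(shape, data_format="channels_last"):
    """Return the 4-D shape a Conv2D model expects for a batch of grayscale images.

    shape: (num_samples, rows, cols), e.g. (200, 28, 28).
    data_format: 'channels_first' or 'channels_last', mirroring the two
    values keras.backend.image_data_format() can return.
    """
    if len(shape) != 3:
        raise ValueError("expected (num_samples, rows, cols), got %r" % (shape,))
    n, rows, cols = shape
    if data_format == "channels_first":
        return (n, 1, rows, cols)
    if data_format == "channels_last":
        return (n, rows, cols, 1)
    raise ValueError("unknown data_format: %r" % (data_format,))


print(add_channel_axis((200, 28, 28)))                    # (200, 28, 28, 1)
print(add_channel_axis((200, 28, 28), "channels_first"))  # (200, 1, 28, 28)
```

Keras uses channels_last unless configured otherwise, which matches the trailing 1 in the training logs.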
We highly recommend installing IBM federated learning in a Conda environment. If you don't have Conda, install it first.
Once Conda is installed, create a new environment for IBM federated learning by running the following. To use Keras with TensorFlow v1, you must use Python 3.7.
conda create -n <env_name> python=3.7
Once the conda environment is created, activate it by running the following.
conda activate <env_name>
Then install the IBM federated learning package with a Keras backend by running the following. The wheel file is located in the federated-learning-lib directory.
pip install "/path/to/federated_learning_lib.whl[keras]"
- The quotes are required if you use the Zsh shell (the default shell on macOS).
- The latest IBM FL library supports model training using Scikit-learn, PyTorch, Keras (with TensorFlow v1), and TensorFlow v2. You must specify the desired model training backend when installing the IBM FL library; in this example, we install with the Keras backend. See the IBM FL installation documentation for more details.
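Before moving on, you can confirm that the active environment matches the one set up above. The following is a small standard-library sketch. It assumes the wheel installs a top-level package named ibmfl, which is the module prefix seen in the logs later in this guide. It also assumes the [keras] extra pulls in a package named keras. check_environment is a hypothetical helper:

```python
import importlib.util
import sys


def check_environment(required=("ibmfl", "keras")):
    """Report whether the interpreter and packages match this guide's setup.

    Returns a dict mapping each check name to True/False instead of raising,
    so it can be run safely in any environment.
    """
    report = {"python_3_7": sys.version_info[:2] == (3, 7)}
    for name in required:
        # find_spec returns None when a top-level package is not importable.
        report[name] = importlib.util.find_spec(name) is not None
    return report


if __name__ == "__main__":
    for check, ok in check_environment().items():
        print("%-10s %s" % (check, "ok" if ok else "FAILED"))
```

The function only inspects the environment and does not import ibmfl or keras. This keeps the check fast and avoids TensorFlow start-up messages.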
To prepare the data for each party, run, for example:
python examples/ -n 2 -d mnist -pp 200
This command generates 2 parties with 200 data points each, randomly sampled from the MNIST dataset. By default, the data is stored under the examples/data/mnist/random directory. The output looks similar to the following:
Warning: test set and train set contain different labels
Party_ 0
nb_x_train: (200, 28, 28) nb_x_test: (5000, 28, 28)
* Label 0 samples: 22
* Label 1 samples: 30
* Label 2 samples: 16
* Label 3 samples: 22
* Label 4 samples: 17
* Label 5 samples: 25
* Label 6 samples: 12
* Label 7 samples: 15
* Label 8 samples: 19
* Label 9 samples: 22
Finished! :) Data saved in examples/data/mnist/random
Party_ 1
nb_x_train: (200, 28, 28) nb_x_test: (5000, 28, 28)
* Label 0 samples: 22
* Label 1 samples: 18
* Label 2 samples: 22
* Label 3 samples: 23
* Label 4 samples: 14
* Label 5 samples: 23
* Label 6 samples: 22
* Label 7 samples: 20
* Label 8 samples: 18
* Label 9 samples: 18
Finished! :) Data saved in examples/data/mnist/random
For a full description of the different options to prepare datasets, run python examples/ -h
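The data step draws a random sample of points for each party and reports how many samples of each label it received. The following is a minimal sketch of that idea using only the standard library. The function names are hypothetical, and the labels are stand-ins for MNIST labels. The sampling strategy is also an assumption and not necessarily what the script does: each party samples without replacement, independently of the other parties.

```python
import random
from collections import Counter


def sample_party_indices(num_points, num_parties, points_per_party, seed=None):
    """Draw points_per_party distinct indices from range(num_points) for each party."""
    rng = random.Random(seed)
    return [rng.sample(range(num_points), points_per_party) for _ in range(num_parties)]


def label_counts(labels, indices):
    """Count how many sampled points carry each label, like the '* Label k samples' lines."""
    return Counter(labels[i] for i in indices)


if __name__ == "__main__":
    # Stand-in labels for a 60,000-point training set (MNIST labels are 0-9).
    rng = random.Random(0)
    labels = [rng.randrange(10) for _ in range(60000)]
    parties = sample_party_indices(len(labels), num_parties=2, points_per_party=200, seed=1)
    for p, idx in enumerate(parties):
        counts = label_counts(labels, idx)
        print("Party_", p, "nb_x_train:", len(idx))
        for label in sorted(counts):
            print("* Label %d samples: %d" % (label, counts[label]))
```

Because each party samples independently in this sketch, two parties can share points. For disjoint partitions, shuffle one index list and hand out consecutive slices instead.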
To generate the configuration files, run, for example:
python examples/ -n <num_parties> -f iter_avg -m keras -d mnist -p <path>
This command performs two tasks:
- It specifies the machine learning model to be trained, in this case a Keras CNN classifier.
- It generates the configuration files necessary to train the model via the fusion algorithm iter_avg, assuming <num_parties> parties join the federated learning training.
You must also specify the dataset name via -d and the party data path via -p.
In this example, we run:
python examples/ -n 2 -f iter_avg -m keras -d mnist -p examples/data/mnist/random/
Hence, our example generates configurations for 2 parties, using the mnist dataset and examples/data/mnist/random/ as the data path. The output looks like:
Finished generating config file for aggregator. Files can be found in: <whl_directory>/examples/configs/iter_avg/keras/config_agg.yml
Finished generating config file for parties. Files can be found in: <whl_directory>/examples/configs/iter_avg/keras/config_party*.yml
You may also see some warning messages; these are expected and can be ignored. For a full description of the different options, run python examples/ -h
Below are sample configuration files.
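The iter_avg fusion algorithm combines the model updates sent by the parties in each global round. The following is a minimal sketch that assumes a plain, unweighted element-wise mean of each party's weights, flattened here to lists of floats. IBM FL's IterAvgFusionHandler operates on ModelUpdate objects and may differ in detail. iter_avg below is a hypothetical helper:

```python
def iter_avg(party_weights):
    """Element-wise, unweighted mean of the parties' weight vectors.

    party_weights: list with one entry per party; each entry is a list of
    floats of the same length (a flattened model).
    """
    if not party_weights:
        raise ValueError("need at least one party update")
    length = len(party_weights[0])
    if any(len(w) != length for w in party_weights):
        raise ValueError("all parties must send weights of the same shape")
    n = len(party_weights)
    # zip(*...) groups the i-th weight of every party together.
    return [sum(column) / n for column in zip(*party_weights)]


print(iter_avg([[1.0, 2.0, 3.0], [3.0, 4.0, 5.0]]))  # [2.0, 3.0, 4.0]
```

An unweighted mean treats every party equally, regardless of how many points it holds. In this example each party holds 200 points, so an average weighted by sample count would give the same result.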
- Aggregator's configuration file:
port: 5000
enable: false
name: FlaskConnection
path: ibmfl.connection.flask_connection
sync: false
npz_file: examples/datasets/mnist.npz
name: MnistKerasDataHandler
path: ibmfl.util.data_handlers.mnist_keras_data_handler
name: IterAvgFusionHandler
path: ibmfl.aggregator.fusion.iter_avg_fusion_handler
max_timeout: 60
parties: 2
rounds: 3
termination_accuracy: 0.9
lr: 0.01
epochs: 3
name: ProtoHandler
path: ibmfl.aggregator.protohandler.proto_handler
- Party's configuration file:
port: 5000
port: 8085
enable: false
name: FlaskConnection
path: ibmfl.connection.flask_connection
sync: false
npz_file: examples/data/mnist/random/data_party0.npz
name: MnistKerasDataHandler
path: ibmfl.util.data_handlers.mnist_keras_data_handler
name: LocalTrainingHandler
name: KerasFLModel
path: ibmfl.model.keras_fl_model
model_definition: examples/configs/iter_avg/keras/compiled_keras.h5
model_name: keras-cnn
name: PartyProtocolHandler
Notice that the configuration files contain a data section that is different for each party. In fact, each party's data section points to its own data file, generated by the command in step 2.
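Only a few entries differ between the party configuration files. The following sketch shows how those per-party values could be derived, using plain dicts rather than YAML. Two values come from the samples above: the data file pattern data_party<idx>.npz and the aggregator port 5000. The local port rule is a hypothetical assumption for running several parties on one machine: each additional party listens on 8085 + idx. party_overrides is likewise a hypothetical helper.

```python
AGGREGATOR_PORT = 5000   # from the sample configuration files
FIRST_PARTY_PORT = 8085  # party 0's port in the sample; later ports are assumed


def party_overrides(idx, data_dir="examples/data/mnist/random"):
    """Return the per-party values that differ between config_party<idx>.yml files."""
    return {
        "aggregator_port": AGGREGATOR_PORT,     # identical for every party
        "local_port": FIRST_PARTY_PORT + idx,   # assumed: one distinct port per party
        "npz_file": "%s/data_party%d.npz" % (data_dir, idx),
    }


for idx in range(2):
    print(party_overrides(idx))
```

Parties on the same host cannot share a listening port, which is why the sketch gives each one its own.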
To start the aggregator, open a terminal window running the IBM federated learning environment set up beforehand, and check that you are in the correct directory. In the terminal, run:
python -m ibmfl.aggregator.aggregator examples/configs/iter_avg/keras/config_agg.yml
where the path provided is the aggregator configuration file path.
2020-06-29 11:38:38,058 - ibmfl.util.config - INFO - Getting details from config file.
2020-06-29 11:38:45,001 - ibmfl.util.config - INFO - No model config provided for this setup.
2020-06-29 11:38:45,353 - ibmfl.util.config - INFO - No local training config provided for this setup.
2020-06-29 11:38:45,353 - ibmfl.connection.flask_connection - INFO - RestSender initialized
2020-06-29 11:38:45,353 - ibmfl.aggregator.protohandler.proto_handler - INFO - State: States.START
<ibmfl.connection.router_handler.Router object at 0x106cb9630>
2020-06-29 11:38:45,353 - ibmfl.connection.flask_connection - INFO - Receiver Initialized
2020-06-29 11:38:45,353 - ibmfl.connection.flask_connection - INFO - Initializing Flask application
2020-06-29 11:38:45,356 - __main__ - INFO - Aggregator initialization successful
Then, in the terminal, type START and press Enter.
2020-06-29 11:39:24,349 - root - INFO - State: States.CLI_WAIT
2020-06-29 11:39:24,349 - __main__ - INFO - Aggregator start successful
* Serving Flask app "ibmfl.connection.flask_connection" (lazy loading)
* Environment: production
WARNING: This is a development server. Do not use it in a production deployment.
Use a production WSGI server instead.
* Debug mode: off
2020-06-29 11:39:24,356 - werkzeug - INFO - * Running on (Press CTRL+C to quit)
To start and register a new party, open a new terminal window for each party, running the IBM federated learning environment set up beforehand, and make sure you are in the correct directory. In the terminal, run:
python -m examples/configs/iter_avg/keras/config_party<idx>.yml
where the path provided is the path to the party's configuration file.
Note: Each party has its own configuration file; in our example, the files are distinguished by the index in config_party<idx>.yml. For instance, to start the first party, run:
python -m examples/configs/iter_avg/keras/config_party0.yml
2020-06-29 11:40:30,420 - ibmfl.util.config - INFO - Getting config from file
2020-06-29 11:40:30,420 - ibmfl.util.config - INFO - Getting details from config file.
2020-06-29 11:40:33,617 - ibmfl.util.config - INFO - No fusion config provided for this setup.
2020-06-29 11:40:33,987 - ibmfl.connection.flask_connection - INFO - RestSender initialized
<ibmfl.connection.router_handler.Router object at 0x135de64a8>
2020-06-29 11:40:33,997 - ibmfl.connection.flask_connection - INFO - Receiver Initialized
2020-06-29 11:40:33,997 - ibmfl.connection.flask_connection - INFO - Initializing Flask application
2020-06-29 11:40:34,000 - __main__ - INFO - Party initialization successful
You may also see warning messages, which are fine. In the terminal for each party, type START and press Enter to start the party. Then type REGISTER and press Enter to register the party for the federated learning task.
2020-06-29 11:41:35,732 - __main__ - INFO - Party start successful
* Serving Flask app "ibmfl.connection.flask_connection" (lazy loading)
* Environment: production
WARNING: This is a development server. Do not use it in a production deployment.
Use a production WSGI server instead.
* Debug mode: off
2020-06-29 11:41:35,735 - werkzeug - INFO - * Running on (Press CTRL+C to quit)
2020-06-29 11:41:47,056 - __main__ - INFO - Registering party...
2020-06-29 11:41:47,071 - __main__ - INFO - Registration Successful
The aggregator terminal also prints INFO messages showing that it has received the party's registration message (as shown in the third figure on the right).
2020-06-29 11:41:47,067 - ibmfl.connection.flask_connection - INFO - Request received for path :6
2020-06-29 11:41:47,068 - ibmfl.aggregator.protohandler.proto_handler - INFO - Adding party with id 78da36fb-2443-45d0-80d2-b438bf8c01a5
2020-06-29 11:41:47,069 - werkzeug - INFO - - - [29/Jun/2020 11:41:47] "POST /6 HTTP/1.1" 200 -
To initiate federated training, type TRAIN in your aggregator terminal and press Enter.
Note: In this example, 2 parties join the training, and we run 3 global rounds with 3 local epochs per round.
The output in the aggregator terminal after running the above command will look like:
2020-06-29 11:43:43,982 - root - INFO - State: States.PROC_TRAIN
2020-06-29 11:43:43,982 - __main__ - INFO - Initiating Global Training.
2020-06-29 11:43:43,982 - ibmfl.aggregator.fusion.fusion_handler - INFO - Warm start disabled.
2020-06-29 11:43:43,982 - ibmfl.aggregator.fusion.iter_avg_fusion_handler - INFO - Model updateNone
2020-06-29 11:43:43,983 - ibmfl.aggregator.protohandler.proto_handler - INFO - State: States.SND_REQ
2020-06-29 11:43:44,090 - ibmfl.aggregator.protohandler.proto_handler - INFO - Total number of success responses :2
2020-06-29 11:43:44,090 - ibmfl.aggregator.protohandler.proto_handler - INFO - State: States.QUORUM_WAIT
2020-06-29 11:43:44,090 - ibmfl.aggregator.protohandler.proto_handler - INFO - Target Qorum: 2
2020-06-29 11:43:47,423 - ibmfl.connection.flask_connection - INFO - Request received for path :7
2020-06-29 11:43:47,429 - ibmfl.connection.flask_connection - INFO - Request received for path :7
2020-06-29 11:43:47,480 - werkzeug - INFO - - - [29/Jun/2020 11:43:47] "POST /7 HTTP/1.1" 200 -
2020-06-29 11:43:47,518 - werkzeug - INFO - - - [29/Jun/2020 11:43:47] "POST /7 HTTP/1.1" 200 -
2020-06-29 11:43:49,094 - ibmfl.aggregator.protohandler.proto_handler - INFO - Timeout:60 Time spent:5
2020-06-29 11:43:49,094 - ibmfl.aggregator.protohandler.proto_handler - INFO - Target Qorum: 2
2020-06-29 11:43:49,094 - ibmfl.aggregator.protohandler.proto_handler - INFO - State: States.PROC_RSP
2020-06-29 11:43:49,108 - ibmfl.aggregator.fusion.iter_avg_fusion_handler - INFO - Model update<ibmfl.model.model_update.ModelUpdate object at 0x111835828>
2020-06-29 11:43:49,108 - ibmfl.aggregator.protohandler.proto_handler - INFO - State: States.SND_REQ
2020-06-29 11:43:49,318 - ibmfl.aggregator.protohandler.proto_handler - INFO - Total number of success responses :2
2020-06-29 11:43:49,318 - ibmfl.aggregator.protohandler.proto_handler - INFO - State: States.QUORUM_WAIT
2020-06-29 11:43:49,318 - ibmfl.aggregator.protohandler.proto_handler - INFO - Target Qorum: 2
2020-06-29 11:43:51,096 - ibmfl.connection.flask_connection - INFO - Request received for path :7
2020-06-29 11:43:51,131 - werkzeug - INFO - - - [29/Jun/2020 11:43:51] "POST /7 HTTP/1.1" 200 -
2020-06-29 11:43:51,133 - ibmfl.connection.flask_connection - INFO - Request received for path :7
2020-06-29 11:43:51,176 - werkzeug - INFO - - - [29/Jun/2020 11:43:51] "POST /7 HTTP/1.1" 200 -
2020-06-29 11:43:54,321 - ibmfl.aggregator.protohandler.proto_handler - INFO - Timeout:60 Time spent:5
2020-06-29 11:43:54,321 - ibmfl.aggregator.protohandler.proto_handler - INFO - Target Qorum: 2
2020-06-29 11:43:54,321 - ibmfl.aggregator.protohandler.proto_handler - INFO - State: States.PROC_RSP
2020-06-29 11:43:54,341 - ibmfl.aggregator.fusion.iter_avg_fusion_handler - INFO - Model update<ibmfl.model.model_update.ModelUpdate object at 0x120ceb128>
2020-06-29 11:43:54,341 - ibmfl.aggregator.protohandler.proto_handler - INFO - State: States.SND_REQ
2020-06-29 11:43:54,550 - ibmfl.aggregator.protohandler.proto_handler - INFO - Total number of success responses :2
2020-06-29 11:43:54,550 - ibmfl.aggregator.protohandler.proto_handler - INFO - State: States.QUORUM_WAIT
2020-06-29 11:43:54,550 - ibmfl.aggregator.protohandler.proto_handler - INFO - Target Qorum: 2
2020-06-29 11:43:56,553 - ibmfl.connection.flask_connection - INFO - Request received for path :7
2020-06-29 11:43:56,583 - werkzeug - INFO - - - [29/Jun/2020 11:43:56] "POST /7 HTTP/1.1" 200 -
2020-06-29 11:43:56,584 - ibmfl.connection.flask_connection - INFO - Request received for path :7
2020-06-29 11:43:56,615 - werkzeug - INFO - - - [29/Jun/2020 11:43:56] "POST /7 HTTP/1.1" 200 -
2020-06-29 11:43:59,553 - ibmfl.aggregator.protohandler.proto_handler - INFO - Timeout:60 Time spent:5
2020-06-29 11:43:59,553 - ibmfl.aggregator.protohandler.proto_handler - INFO - Target Qorum: 2
2020-06-29 11:43:59,553 - ibmfl.aggregator.protohandler.proto_handler - INFO - State: States.PROC_RSP
2020-06-29 11:43:59,571 - ibmfl.aggregator.fusion.iter_avg_fusion_handler - INFO - Reached maximum global rounds. Finish training :)
2020-06-29 11:43:59,572 - __main__ - INFO - Finished Global Training
The output in the first party's (party 1) terminal after running the above command will look like:
2020-06-29 11:43:43,998 - ibmfl.connection.flask_connection - INFO - Request received for path :7
2020-06-29 11:43:43,999 - - INFO - received a async request
2020-06-29 11:43:43,999 - - INFO - finished async request
2020-06-29 11:43:44,000 - werkzeug - INFO - - - [29/Jun/2020 11:43:44] "POST /7 HTTP/1.1" 200 -
2020-06-29 11:43:44,000 - - INFO - Handling async request in a separate thread
2020-06-29 11:43:44,001 - - INFO - Received request from aggregator
2020-06-29 11:43:44,001 - - INFO - Received request in with message_type: 7
2020-06-29 11:43:44,001 - - INFO - Received request in PH 7
2020-06-29 11:43:44,001 - ibmfl.util.data_handlers.mnist_keras_data_handler - INFO - Loaded training data from examples/data/mnist/random/data_party0.npz
x_train shape: (200, 28, 28, 1)
200 train samples
5000 test samples
2020-06-29 11:43:44,052 - - INFO - No model update was provided.
2020-06-29 11:43:44,053 - - INFO - Local training started...
2020-06-29 11:43:44,053 - ibmfl.model.keras_fl_model - INFO - Using default hyperparameters: batch_size:128
2020-06-29 11:43:44,080 - tensorflow - WARNING - From /Users/ to_int32 (from tensorflow.python.ops.math_ops) is deprecated and will be removed in a future version.
Instructions for updating:
Use tf.cast instead.
Epoch 1/3
200/200 [==============================] - 1s 5ms/step - loss: 2.2764 - acc: 0.1150
Epoch 2/3
200/200 [==============================] - 1s 4ms/step - loss: 2.0754 - acc: 0.2400
Epoch 3/3
200/200 [==============================] - 1s 6ms/step - loss: 1.7741 - acc: 0.4000
2020-06-29 11:43:47,353 - - INFO - Local training done, generating model update...
2020-06-29 11:43:47,359 - - INFO - successfully finished async request
2020-06-29 11:43:49,200 - ibmfl.connection.flask_connection - INFO - Request received for path :7
2020-06-29 11:43:49,236 - - INFO - received a async request
2020-06-29 11:43:49,236 - - INFO - finished async request
2020-06-29 11:43:49,237 - werkzeug - INFO - - - [29/Jun/2020 11:43:49] "POST /7 HTTP/1.1" 200 -
2020-06-29 11:43:49,238 - - INFO - Handling async request in a separate thread
2020-06-29 11:43:49,238 - - INFO - Received request from aggregator
2020-06-29 11:43:49,238 - - INFO - Received request in with message_type: 7
2020-06-29 11:43:49,238 - - INFO - Received request in PH 7
2020-06-29 11:43:49,238 - ibmfl.util.data_handlers.mnist_keras_data_handler - INFO - Loaded training data from examples/data/mnist/random/data_party0.npz
x_train shape: (200, 28, 28, 1)
200 train samples
5000 test samples
2020-06-29 11:43:49,277 - - INFO - Local model updated.
2020-06-29 11:43:49,277 - - INFO - Local training started...
2020-06-29 11:43:49,277 - ibmfl.model.keras_fl_model - INFO - Using default hyperparameters: batch_size:128
Epoch 1/3
200/200 [==============================] - 1s 4ms/step - loss: 1.4924 - acc: 0.6300
Epoch 2/3
200/200 [==============================] - 1s 3ms/step - loss: 1.8330 - acc: 0.4100
Epoch 3/3
200/200 [==============================] - 0s 2ms/step - loss: 1.0651 - acc: 0.6800
2020-06-29 11:43:51,074 - - INFO - Local training done, generating model update...
2020-06-29 11:43:51,078 - - INFO - successfully finished async request
2020-06-29 11:43:54,418 - ibmfl.connection.flask_connection - INFO - Request received for path :7
2020-06-29 11:43:54,455 - - INFO - received a async request
2020-06-29 11:43:54,456 - - INFO - finished async request
2020-06-29 11:43:54,458 - werkzeug - INFO - - - [29/Jun/2020 11:43:54] "POST /7 HTTP/1.1" 200 -
2020-06-29 11:43:54,459 - - INFO - Handling async request in a separate thread
2020-06-29 11:43:54,459 - - INFO - Received request from aggregator
2020-06-29 11:43:54,459 - - INFO - Received request in with message_type: 7
2020-06-29 11:43:54,459 - - INFO - Received request in PH 7
2020-06-29 11:43:54,459 - ibmfl.util.data_handlers.mnist_keras_data_handler - INFO - Loaded training data from examples/data/mnist/random/data_party0.npz
x_train shape: (200, 28, 28, 1)
200 train samples
5000 test samples
2020-06-29 11:43:54,502 - - INFO - Local model updated.
2020-06-29 11:43:54,502 - - INFO - Local training started...
2020-06-29 11:43:54,502 - ibmfl.model.keras_fl_model - INFO - Using default hyperparameters: batch_size:128
Epoch 1/3
200/200 [==============================] - 0s 2ms/step - loss: 0.9385 - acc: 0.7250
Epoch 2/3
200/200 [==============================] - 1s 4ms/step - loss: 1.0080 - acc: 0.7200
Epoch 3/3
200/200 [==============================] - 1s 4ms/step - loss: 0.7465 - acc: 0.7950
2020-06-29 11:43:56,514 - - INFO - Local training done, generating model update...
2020-06-29 11:43:56,518 - - INFO - successfully finished async request
Outputs from party 2 will be similar to those from party 1.
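The logs above follow a fixed pattern. In each global round, the aggregator sends a training request to every party and waits for the quorum of 2 responses. It then fuses the updates and stops once it reaches the configured number of rounds. The following toy sketch covers that control flow and the per-party workload. Both functions are hypothetical. The early stop at termination_accuracy is an assumption about what that configuration key means. The batch size of 128 comes from the "Using default hyperparameters" log line.

```python
import math


def local_batches_per_round(num_samples, batch_size, epochs):
    """Mini-batch updates one party performs in a single global round."""
    return math.ceil(num_samples / batch_size) * epochs


def run_rounds(max_rounds, termination_accuracy, evaluate):
    """Drive global rounds until max_rounds or the accuracy threshold is reached.

    evaluate(round_idx) stands in for training, fusing, and evaluating the
    global model in that round; it returns the resulting accuracy.
    Returns (rounds_completed, reason).
    """
    for r in range(1, max_rounds + 1):
        accuracy = evaluate(r)
        if accuracy >= termination_accuracy:
            return r, "termination accuracy reached"
    return max_rounds, "reached maximum global rounds"


per_round = local_batches_per_round(num_samples=200, batch_size=128, epochs=3)
print(per_round, per_round * 3)  # 6 18
```

With 200 points and a batch size of 128, each epoch takes 2 batches, so each party runs 6 updates per round and 18 over the 3 rounds. The Keras progress bar counts samples (200/200) rather than batches, so the batch count does not appear in the log.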
For a full list of supported commands, see examples/. Below are sample outputs from issuing the EVAL command in one of the parties' terminals after the global training:
2020-06-29 11:46:05,003 - ibmfl.util.data_handlers.mnist_keras_data_handler - INFO - Loaded training data from examples/data/mnist/random/data_party0.npz
x_train shape: (200, 28, 28, 1)
200 train samples
5000 test samples
5000/5000 [==============================] - 2s 303us/step
2020-06-29 11:46:06,542 - - INFO - {'loss': 0.6104391970634461, 'acc': 0.8152}
Users can also enter TRAIN again at the aggregator's terminal to continue the FL training. Entering SYNC at the aggregator's terminal triggers synchronization of the current global model with the parties, and SAVE triggers the parties to save their models in their local working directories.
Remember to use STOP to terminate the aggregator's and parties' processes and exit. The output in the aggregator terminal after running STOP looks like:
2020-06-29 11:46:44,476 - root - INFO - State: States.PROC_STOP
2020-06-29 11:46:44,476 - ibmfl.aggregator.protohandler.proto_handler - INFO - State: States.SND_REQ
2020-06-29 11:46:44,581 - ibmfl.aggregator.protohandler.proto_handler - INFO - Total number of success responses :2
2020-06-29 11:46:44,582 - ibmfl.connection.flask_connection - INFO - Stopping Receiver and Sender
2020-06-29 11:46:44,584 - werkzeug - INFO - - - [29/Jun/2020 11:46:44] "POST /shutdown HTTP/1.1" 200 -
2020-06-29 11:46:44,585 - __main__ - INFO - Aggregator stop successful
The output in the party's terminal after running STOP looks like:
2020-06-29 11:47:01,587 - ibmfl.connection.flask_connection - INFO - Stopping Receiver and Sender
2020-06-29 11:47:01,591 - werkzeug - INFO - - - [29/Jun/2020 11:47:01] "POST /shutdown HTTP/1.1" 200 -