Skip to content

Commit

Permalink
Update Workflow_Interface_VFL_Two_Party
Browse files Browse the repository at this point in the history
Added actor related changed to Workflow_Interface_VFL_Two_Party

Signed-off-by: Parth Mandaliya <parthx.mandaliya@intel.com>
  • Loading branch information
ParthM-GitHub committed Sep 28, 2023
1 parent 5a89b19 commit 4545663
Showing 1 changed file with 37 additions and 11 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -152,21 +152,47 @@
"source": [
"# Setup participants\n",
"aggregator = Aggregator()\n",
"aggregator.private_attributes['trainloader'] = trainloader\n",
"aggregator.private_attributes['label_model'] = label_model\n",
"aggregator.private_attributes['label_model_optimizer'] = label_model_optimizer\n",
"\n",
"# Setup collaborators with private attributes\n",
"def callable_to_initialize_aggregator_private_attributes(train_loader,label_model,label_model_optimizer):\n",
" return {\"trainloader\": train_loader,\n",
" \"label_model\" : label_model,\n",
" \"label_model_optimizer\":label_model_optimizer\n",
" } \n",
"\n",
"# Setup aggregator private attributes via callable function\n",
"aggregator = Aggregator(\n",
" name=\"agg\",\n",
" private_attributes_callable=callable_to_initialize_aggregator_private_attributes,\n",
" train_loader = trainloader,\n",
" label_model=label_model,\n",
" label_model_optimizer=label_model_optimizer\n",
")\n",
"\n",
"# Setup collaborators private attributes via callable function\n",
"collaborator_names = ['Portland']\n",
"collaborators = [Collaborator(name=name) for name in collaborator_names]\n",
"\n",
"for idx, collaborator in enumerate(collaborators):\n",
" collaborator.private_attributes['data_model'] = data_model\n",
" collaborator.private_attributes['data_model_optimizer'] = data_model_optimizer\n",
" collaborator.private_attributes['trainloader'] = deepcopy(trainloader)\n",
"def callable_to_initialize_collaborator_private_attributes(index,data_model,data_model_optimizer,train_loader):\n",
" return {\n",
" \"data_model\": data_model,\n",
" \"data_model_optimizer\": data_model_optimizer,\n",
" \"trainloader\" : deepcopy(train_loader)\n",
" }\n",
"\n",
"collaborators = []\n",
"for idx, collaborator_name in enumerate(collaborator_names):\n",
" collaborators.append(\n",
" Collaborator(\n",
" name=collaborator_name,\n",
" private_attributes_callable=callable_to_initialize_collaborator_private_attributes,\n",
" index=idx,\n",
" data_model = data_model,\n",
" data_model_optimizer = data_model_optimizer,\n",
" train_loader = trainloader\n",
" )\n",
" )\n",
"\n",
"local_runtime = LocalRuntime(\n",
" aggregator=aggregator, collaborators=collaborators, backend='single_process')\n",
" aggregator=aggregator, collaborators=collaborators, backend='ray')\n",
"print(f'Local runtime collaborators = {local_runtime.collaborators}')\n",
"\n",
"epochs = 100\n",
Expand Down Expand Up @@ -252,7 +278,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.13"
"version": "3.8.17"
}
},
"nbformat": 4,
Expand Down

0 comments on commit 4545663

Please sign in to comment.