Skip to content

Commit

Permalink
Adapt new aggregator
Browse files Browse the repository at this point in the history
Signed-off-by: cwj <talkingwallace@sohu.com>
Signed-off-by: weiwee <wbwmat@gmail.com>
  • Loading branch information
talkingwallace authored and sagewe committed Jul 21, 2023
1 parent 4a32f58 commit e364e82
Show file tree
Hide file tree
Showing 3 changed files with 113 additions and 97 deletions.
132 changes: 38 additions & 94 deletions python/fate/ml/aggregator/plaintext_aggregator.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
from typing import Union
from .base import Aggregator
import logging
from fate.arch.protocol._dh import SecureAggregatorClient as sa_client
from fate.arch.protocol._dh import SecureAggregatorServer as sa_server


logger = logging.getLogger(__name__)
Expand All @@ -18,12 +20,13 @@ class PlainTextAggregatorClient(Aggregator):
PlainTextAggregatorClient is used to aggregate plain text data
"""

def __init__(self, ctx: Context, aggregator_name=None, aggregate_type='mean', sample_num=1) -> None:
def __init__(self, ctx: Context, aggregator_name: str = None, aggregate_type='mean', sample_num=1) -> None:

super().__init__(ctx, aggregator_name)
self.ctx = ctx
self._weight = 1.0

self.aggregator_name = 'default' if aggregator_name is None else aggregator_name

if sample_num <= 0 and not isinstance(sample_num, int):
raise ValueError("sample_num should be int greater than 0")

Expand All @@ -44,25 +47,29 @@ def __init__(self, ctx: Context, aggregator_name=None, aggregate_type='mean', sa

logger.info("aggregate weight is {}".format(self._weight))

self.model_aggregator = sa_client(prefix=self.aggregator_name+'_model', is_mock=True)
self.model_aggregator.dh_exchange(ctx, [ctx.guest.rank, *ctx.hosts.ranks])
self.loss_aggregator = sa_client(prefix=self.aggregator_name+'_loss', is_mock=True)
self.loss_aggregator.dh_exchange(ctx, [ctx.guest.rank, *ctx.hosts.ranks])

def _process_model(self, model):

to_agg = None
if isinstance(model, np.ndarray) or isinstance(model, t.Tensor):
to_agg = model * self._weight
return to_agg
return [to_agg]

if isinstance(model, t.nn.Module):
parameters = list(model.parameters())
tmp_list = [[p.cpu().detach().numpy() for p in parameters if p.requires_grad]]
agg_list = [p.cpu().detach().numpy() for p in parameters if p.requires_grad]

elif isinstance(model, list):
for p in model:
assert isinstance(
p, np.ndarray), 'expecting List[np.ndarray], but got {}'.format(p)
tmp_list = [model]
agg_list = model

to_agg = [[arr * self._weight for arr in arr_list]
for arr_list in tmp_list]
return to_agg
return agg_list

def _recover_model(self, model, agg_model):

Expand All @@ -75,48 +82,25 @@ def _recover_model(self, model, agg_model):
else:
return agg_model

def _send_loss(self, loss):
assert isinstance(loss, float) or isinstance(
loss, np.ndarray), 'illegal loss type {}, loss should be a float or a np array'.format(type(loss))
loss_suffix = self.suffix['local_loss']()
self.ctx.arbiter.put(loss_suffix, loss)

def _send_model(self, model: Union[np.ndarray, t.Tensor, t.nn.Module]):
"""Sending model to arbiter for aggregation
Parameters
----------
model : model can be:
A numpy array
A Weight instance(or subclass of Weights), see federatedml.framework.weights
List of numpy array
A pytorch model, is the subclass of torch.nn.Module
A pytorch optimizer, will extract param group from this optimizer as weights to aggregate
"""
# judge model type
to_agg_model = self._process_model(model)
suffix = self.suffix['local_model']()
self.ctx.arbiter.put(suffix, to_agg_model)

def _get_aggregated_model(self):
return self.ctx.arbiter.get(self.suffix['agg_model']())[0]

def _get_aggregated_loss(self):
return self.ctx.arbiter.get(self.suffix['agg_loss']())[0]

"""
User API
"""

def model_aggregation(self, model):

self._send_model(model)
agg_model = self._get_aggregated_model()
to_send = self._process_model(model)
print('model is ', to_send)
agg_model = self.model_aggregator.secure_aggregate(self.ctx, to_send, self._weight)
return self._recover_model(model, agg_model)

def loss_aggregation(self, loss):
self._send_loss(loss)

if isinstance(loss, t.Tensor):
loss = loss.detach.cpu().numpy()
else:
loss = np.array(loss)
loss = [loss]
agg_loss = self.loss_aggregator.secure_aggregate(self.ctx, loss, self._weight)
return agg_loss


class PlainTextAggregatorServer(Aggregator):
Expand All @@ -125,9 +109,10 @@ class PlainTextAggregatorServer(Aggregator):
PlainTextAggregatorServer is used to aggregate plain text data
"""

def __init__(self, ctx: Context, aggregator_name=None) -> None:
def __init__(self, ctx: Context, aggregator_name: str = None) -> None:

super().__init__(ctx, aggregator_name)

weight_list = self._collect(self.suffix["local_weight"]())
weight_sum = sum(weight_list)
ret_weight = []
Expand All @@ -138,22 +123,22 @@ def __init__(self, ctx: Context, aggregator_name=None) -> None:
for idx, w in enumerate(ret_weight):
self._broadcast(w, ret_suffix, idx)

self.aggregator_name = 'default' if aggregator_name is None else aggregator_name
self.model_aggregator = sa_server(prefix=self.aggregator_name+'_model', is_mock=True, ranks=[ctx.guest.rank, *ctx.hosts.ranks])
self.loss_aggregator = sa_server(prefix=self.aggregator_name+'_loss', is_mock=True, ranks=[ctx.guest.rank, *ctx.hosts.ranks])

def _check_party_id(self, party_id):
# party idx >= -1, int
if not isinstance(party_id, int):
raise ValueError("party_id should be int")
if party_id < -1:
raise ValueError("party_id should be greater than -1")

def _collect(self, suffix, party_idx=-1):
self._check_party_id(party_idx)
def _collect(self, suffix):
guest_item = [self.ctx.guest.get(suffix)]
host_item = self.ctx.hosts.get(suffix)
combine_list = guest_item + host_item
if party_idx == -1:
return combine_list
else:
return combine_list[party_idx]
return combine_list

def _broadcast(self, data, suffix, party_idx=-1):
self._check_party_id(party_idx)
Expand All @@ -165,54 +150,13 @@ def _broadcast(self, data, suffix, party_idx=-1):
else:
self.ctx.hosts[party_idx - 1].put(suffix, data)

def _aggregate_model(self, party_idx=-1):

# get suffix
suffix = self.suffix['local_model']()
# recv params for aggregation
models = self._collect(suffix=suffix, party_idx=party_idx)
agg_result = None
# Aggregate numpy groups
if isinstance(models[0], list):
# aggregation
agg_result = models[0]
# aggregate numpy model weights from all clients
for params_group in models[1:]:
for agg_params, params in zip(
agg_result, params_group):
for agg_p, p in zip(agg_params, params):
# agg_p: NumpyWeights or numpy array
agg_p += p
else:
raise ValueError('invalid aggregation format: {}'.format(models))

if agg_result is None:
raise ValueError(
'can not aggregate receive model, format is illegal: {}'.format(models))

return agg_result

def _aggregate_loss(self, party_idx=-1):

# get loss
loss_suffix = self.suffix['local_loss']()
losses = self._collect(suffix=loss_suffix, party_idx=-1)
total_loss = losses[0]
for loss in losses[1:]:
total_loss += loss

return total_loss

"""
User API
"""

def model_aggregation(self, party_idx=-1):
agg_model = self._aggregate_model(party_idx=party_idx)
suffix = self.suffix['agg_model']()
self._broadcast(agg_model, suffix=suffix, party_idx=party_idx)
return agg_model
def model_aggregation(self, ranks=None):
self.model_aggregator.secure_aggregate(self.ctx, ranks=ranks)

def loss_aggregation(self, party_idx=-1):
agg_loss = self._aggregate_loss(party_idx=party_idx)
return agg_loss
def loss_aggregation(self, ranks=None):
self.loss_aggregator.secure_aggregate(self.ctx, ranks=ranks)

72 changes: 72 additions & 0 deletions python/fate/ml/aggregator/test/test_aggregator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
import sys
import torch as t


arbiter = ("arbiter", 10000)
guest = ("guest", 10000)
host = ("host", 9999)
name = "fed"


def create_ctx(local):
from fate.arch import Context
from fate.arch.computing.standalone import CSession
from fate.arch.federation.standalone import StandaloneFederation
import logging
logger = logging.getLogger()
logger.setLevel(logging.DEBUG)

console_handler = logging.StreamHandler()
console_handler.setLevel(logging.DEBUG)

formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
console_handler.setFormatter(formatter)

logger.addHandler(console_handler)
computing = CSession()
return Context(computing=computing,
federation=StandaloneFederation(computing, name, local, [guest, host, arbiter]))


if __name__ == "__main__":

epoch = 10

if sys.argv[1] == "guest":
from fate.ml.aggregator.plaintext_aggregator import PlainTextAggregatorClient

ctx = create_ctx(guest)
client = PlainTextAggregatorClient(ctx, sample_num=100, aggregate_type='weighted_mean')
model = t.nn.Sequential(
t.nn.Linear(10, 10),
t.nn.ReLU(),
t.nn.Linear(10, 1),
t.nn.Sigmoid()
)

for i in range(epoch):
client.model_aggregation(model)
elif sys.argv[1] == "host":
from fate.ml.aggregator.plaintext_aggregator import PlainTextAggregatorClient

ctx = create_ctx(host)
client = PlainTextAggregatorClient(ctx, sample_num=100, aggregate_type='weighted_mean')
model = t.nn.Sequential(
t.nn.Linear(10, 10),
t.nn.ReLU(),
t.nn.Linear(10, 1),
t.nn.Sigmoid()
)

for i in range(epoch):
client.model_aggregation(model)

else:

from fate.ml.aggregator.plaintext_aggregator import PlainTextAggregatorServer
ctx = create_ctx(arbiter)
server = PlainTextAggregatorServer(ctx)

for i in range(epoch):
server.model_aggregation()

6 changes: 3 additions & 3 deletions python/fate/ml/aggregator/test/test_fate_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ def create_ctx(local):
import numpy as np

ctx = create_ctx(guest)
client = SecureAggregatorClient()
client = SecureAggregatorClient(is_mock=True)
client.dh_exchange(ctx, [ctx.guest.rank, *ctx.hosts.ranks])
print('ranks are {}'.format([ctx.guest.rank, *ctx.hosts.ranks]))
print(client.secure_aggregate(ctx, [np.zeros((3, 4)), np.ones((2, 3))]))
Expand All @@ -41,12 +41,12 @@ def create_ctx(local):
import numpy as np

ctx = create_ctx(host)
client = SecureAggregatorClient()
client = SecureAggregatorClient(is_mock=True)
client.dh_exchange(ctx, [ctx.guest.rank, *ctx.hosts.ranks])
print(client.secure_aggregate(ctx, [np.zeros((3, 4)), np.ones((2, 3))]))
else:
from fate.arch.protocol import SecureAggregatorServer

ctx = create_ctx(arbiter)
server = SecureAggregatorServer([ctx.guest.rank, *ctx.hosts.ranks])
server = SecureAggregatorServer([ctx.guest.rank, *ctx.hosts.ranks], is_mock=True)
server.secure_aggregate(ctx)

0 comments on commit e364e82

Please sign in to comment.