diff --git a/python/fate/ml/aggregator/plaintext_aggregator.py b/python/fate/ml/aggregator/plaintext_aggregator.py index 493abfb97f..b8d09cfc04 100644 --- a/python/fate/ml/aggregator/plaintext_aggregator.py +++ b/python/fate/ml/aggregator/plaintext_aggregator.py @@ -4,6 +4,8 @@ from typing import Union from .base import Aggregator import logging +from fate.arch.protocol._dh import SecureAggregatorClient as sa_client +from fate.arch.protocol._dh import SecureAggregatorServer as sa_server logger = logging.getLogger(__name__) @@ -18,12 +20,13 @@ class PlainTextAggregatorClient(Aggregator): PlainTextAggregatorClient is used to aggregate plain text data """ - def __init__(self, ctx: Context, aggregator_name=None, aggregate_type='mean', sample_num=1) -> None: + def __init__(self, ctx: Context, aggregator_name: str = None, aggregate_type='mean', sample_num=1) -> None: super().__init__(ctx, aggregator_name) self.ctx = ctx self._weight = 1.0 - + self.aggregator_name = 'default' if aggregator_name is None else aggregator_name + if sample_num <= 0 and not isinstance(sample_num, int): raise ValueError("sample_num should be int greater than 0") @@ -44,25 +47,29 @@ def __init__(self, ctx: Context, aggregator_name=None, aggregate_type='mean', sa logger.info("aggregate weight is {}".format(self._weight)) + self.model_aggregator = sa_client(prefix=self.aggregator_name+'_model', is_mock=True) + self.model_aggregator.dh_exchange(ctx, [ctx.guest.rank, *ctx.hosts.ranks]) + self.loss_aggregator = sa_client(prefix=self.aggregator_name+'_loss', is_mock=True) + self.loss_aggregator.dh_exchange(ctx, [ctx.guest.rank, *ctx.hosts.ranks]) + def _process_model(self, model): to_agg = None if isinstance(model, np.ndarray) or isinstance(model, t.Tensor): to_agg = model * self._weight - return to_agg + return [to_agg] if isinstance(model, t.nn.Module): parameters = list(model.parameters()) - tmp_list = [[p.cpu().detach().numpy() for p in parameters if p.requires_grad]] + agg_list = [p.cpu().detach().numpy() for p in parameters if p.requires_grad] + elif isinstance(model, list): for p in model: assert isinstance( p, np.ndarray), 'expecting List[np.ndarray], but got {}'.format(p) - tmp_list = [model] + agg_list = model - to_agg = [[arr * self._weight for arr in arr_list] - for arr_list in tmp_list] - return to_agg + return agg_list def _recover_model(self, model, agg_model): @@ -75,48 +82,25 @@ def _recover_model(self, model, agg_model): else: return agg_model - def _send_loss(self, loss): - assert isinstance(loss, float) or isinstance( - loss, np.ndarray), 'illegal loss type {}, loss should be a float or a np array'.format(type(loss)) - loss_suffix = self.suffix['local_loss']() - self.ctx.arbiter.put(loss_suffix, loss) - - def _send_model(self, model: Union[np.ndarray, t.Tensor, t.nn.Module]): - """Sending model to arbiter for aggregation - - Parameters - ---------- - model : model can be: - A numpy array - A Weight instance(or subclass of Weights), see federatedml.framework.weights - List of numpy array - A pytorch model, is the subclass of torch.nn.Module - A pytorch optimizer, will extract param group from this optimizer as weights to aggregate - """ - # judge model type - to_agg_model = self._process_model(model) - suffix = self.suffix['local_model']() - self.ctx.arbiter.put(suffix, to_agg_model) - - def _get_aggregated_model(self): - return self.ctx.arbiter.get(self.suffix['agg_model']())[0] - - def _get_aggregated_loss(self): - return self.ctx.arbiter.get(self.suffix['agg_loss']())[0] - """ User API """ def model_aggregation(self, model): - self._send_model(model) - agg_model = self._get_aggregated_model() + to_send = self._process_model(model) + print('model is ', to_send) + agg_model = self.model_aggregator.secure_aggregate(self.ctx, to_send, self._weight) return self._recover_model(model, agg_model) def loss_aggregation(self, loss): - self._send_loss(loss) - + if isinstance(loss, t.Tensor): + loss = loss.detach.cpu().numpy() + else: + loss = np.array(loss) + loss = [loss] + agg_loss = self.loss_aggregator.secure_aggregate(self.ctx, loss, self._weight) + return agg_loss class PlainTextAggregatorServer(Aggregator): @@ -125,9 +109,10 @@ class PlainTextAggregatorServer(Aggregator): PlainTextAggregatorServer is used to aggregate plain text data """ - def __init__(self, ctx: Context, aggregator_name=None) -> None: + def __init__(self, ctx: Context, aggregator_name: str = None) -> None: super().__init__(ctx, aggregator_name) + weight_list = self._collect(self.suffix["local_weight"]()) weight_sum = sum(weight_list) ret_weight = [] @@ -138,6 +123,10 @@ def __init__(self, ctx: Context, aggregator_name=None) -> None: for idx, w in enumerate(ret_weight): self._broadcast(w, ret_suffix, idx) + self.aggregator_name = 'default' if aggregator_name is None else aggregator_name + self.model_aggregator = sa_server(prefix=self.aggregator_name+'_model', is_mock=True, ranks=[ctx.guest.rank, *ctx.hosts.ranks]) + self.loss_aggregator = sa_server(prefix=self.aggregator_name+'_loss', is_mock=True, ranks=[ctx.guest.rank, *ctx.hosts.ranks]) + def _check_party_id(self, party_id): # party idx >= -1, int if not isinstance(party_id, int): @@ -145,15 +134,11 @@ def _check_party_id(self, party_id): if party_id < -1: raise ValueError("party_id should be greater than -1") - def _collect(self, suffix, party_idx=-1): - self._check_party_id(party_idx) + def _collect(self, suffix): guest_item = [self.ctx.guest.get(suffix)] host_item = self.ctx.hosts.get(suffix) combine_list = guest_item + host_item - if party_idx == -1: - return combine_list - else: - return combine_list[party_idx] + return combine_list def _broadcast(self, data, suffix, party_idx=-1): self._check_party_id(party_idx) @@ -165,54 +150,13 @@ def _broadcast(self, data, suffix, party_idx=-1): else: self.ctx.hosts[party_idx - 1].put(suffix, data) - def _aggregate_model(self, party_idx=-1): - - # get suffix - suffix = self.suffix['local_model']() - # recv params for aggregation - models = self._collect(suffix=suffix, party_idx=party_idx) - agg_result = None - # Aggregate numpy groups - if isinstance(models[0], list): - # aggregation - agg_result = models[0] - # aggregate numpy model weights from all clients - for params_group in models[1:]: - for agg_params, params in zip( - agg_result, params_group): - for agg_p, p in zip(agg_params, params): - # agg_p: NumpyWeights or numpy array - agg_p += p - else: - raise ValueError('invalid aggregation format: {}'.format(models)) - - if agg_result is None: - raise ValueError( - 'can not aggregate receive model, format is illegal: {}'.format(models)) - - return agg_result - - def _aggregate_loss(self, party_idx=-1): - - # get loss - loss_suffix = self.suffix['local_loss']() - losses = self._collect(suffix=loss_suffix, party_idx=-1) - total_loss = losses[0] - for loss in losses[1:]: - total_loss += loss - - return total_loss - """ User API """ - def model_aggregation(self, party_idx=-1): - agg_model = self._aggregate_model(party_idx=party_idx) - suffix = self.suffix['agg_model']() - self._broadcast(agg_model, suffix=suffix, party_idx=party_idx) - return agg_model + def model_aggregation(self, ranks=None): + self.model_aggregator.secure_aggregate(self.ctx, ranks=ranks) - def loss_aggregation(self, party_idx=-1): - agg_loss = self._aggregate_loss(party_idx=party_idx) - return agg_loss \ No newline at end of file + def loss_aggregation(self, ranks=None): + self.loss_aggregator.secure_aggregate(self.ctx, ranks=ranks) + \ No newline at end of file diff --git a/python/fate/ml/aggregator/test/test_aggregator.py b/python/fate/ml/aggregator/test/test_aggregator.py new file mode 100644 index 0000000000..177e8e210c --- /dev/null +++ b/python/fate/ml/aggregator/test/test_aggregator.py @@ -0,0 +1,72 @@ +import sys +import torch as t + + +arbiter = ("arbiter", 10000) +guest = ("guest", 10000) +host = ("host", 9999) +name = "fed" + + +def create_ctx(local): + from fate.arch import Context + from fate.arch.computing.standalone import CSession + from fate.arch.federation.standalone import StandaloneFederation + import logging + logger = logging.getLogger() + logger.setLevel(logging.DEBUG) + + console_handler = logging.StreamHandler() + console_handler.setLevel(logging.DEBUG) + + formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s') + console_handler.setFormatter(formatter) + + logger.addHandler(console_handler) + computing = CSession() + return Context(computing=computing, + federation=StandaloneFederation(computing, name, local, [guest, host, arbiter])) + + +if __name__ == "__main__": + + epoch = 10 + + if sys.argv[1] == "guest": + from fate.ml.aggregator.plaintext_aggregator import PlainTextAggregatorClient + + ctx = create_ctx(guest) + client = PlainTextAggregatorClient(ctx, sample_num=100, aggregate_type='weighted_mean') + model = t.nn.Sequential( + t.nn.Linear(10, 10), + t.nn.ReLU(), + t.nn.Linear(10, 1), + t.nn.Sigmoid() + ) + + for i in range(epoch): + client.model_aggregation(model) + elif sys.argv[1] == "host": + from fate.ml.aggregator.plaintext_aggregator import PlainTextAggregatorClient + + ctx = create_ctx(host) + client = PlainTextAggregatorClient(ctx, sample_num=100, aggregate_type='weighted_mean') + model = t.nn.Sequential( + t.nn.Linear(10, 10), + t.nn.ReLU(), + t.nn.Linear(10, 1), + t.nn.Sigmoid() + ) + + for i in range(epoch): + client.model_aggregation(model) + + else: + + from fate.ml.aggregator.plaintext_aggregator import PlainTextAggregatorServer + ctx = create_ctx(arbiter) + server = PlainTextAggregatorServer(ctx) + + for i in range(epoch): + server.model_aggregation() + diff --git a/python/fate/ml/aggregator/test/test_fate_utils.py b/python/fate/ml/aggregator/test/test_fate_utils.py index 4a3c2fab6b..51ddd34311 100644 --- a/python/fate/ml/aggregator/test/test_fate_utils.py +++ b/python/fate/ml/aggregator/test/test_fate_utils.py @@ -32,7 +32,7 @@ def create_ctx(local): import numpy as np ctx = create_ctx(guest) - client = SecureAggregatorClient() + client = SecureAggregatorClient(is_mock=True) client.dh_exchange(ctx, [ctx.guest.rank, *ctx.hosts.ranks]) print('ranks are {}'.format([ctx.guest.rank, *ctx.hosts.ranks])) print(client.secure_aggregate(ctx, [np.zeros((3, 4)), np.ones((2, 3))])) @@ -41,12 +41,12 @@ def create_ctx(local): import numpy as np ctx = create_ctx(host) - client = SecureAggregatorClient() + client = SecureAggregatorClient(is_mock=True) client.dh_exchange(ctx, [ctx.guest.rank, *ctx.hosts.ranks]) print(client.secure_aggregate(ctx, [np.zeros((3, 4)), np.ones((2, 3))])) else: from fate.arch.protocol import SecureAggregatorServer ctx = create_ctx(arbiter) - server = SecureAggregatorServer([ctx.guest.rank, *ctx.hosts.ranks]) + server = SecureAggregatorServer([ctx.guest.rank, *ctx.hosts.ranks], is_mock=True) server.secure_aggregate(ctx) \ No newline at end of file