Commit

add init method
Signed-off-by: sagewe <wbwmat@gmail.com>
sagewe committed Nov 27, 2023
1 parent ea7536d commit cae7ceb
Showing 5 changed files with 36 additions and 24 deletions.
5 changes: 2 additions & 3 deletions python/fate/arch/context/_mpc.py
@@ -203,12 +203,11 @@ def sshe(self):

         return SSHE

-    def random_tensor(self, shape, src=0, generator=None):
-        import torch
+    def init_tensor(self, shape, init_func, src):
         from fate.arch.protocol.mpc.primitives import ArithmeticSharedTensor

         if self.rank == src:
-            return ArithmeticSharedTensor(self._ctx, torch.rand(shape, generator=generator), src=src)
+            return ArithmeticSharedTensor(self._ctx, init_func(shape), src=src)

         else:
             return ArithmeticSharedTensor(self._ctx, None, size=shape, src=src)
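
A minimal caller sketch of the new method (illustrative, not part of this commit): init_tensor is invoked collectively with the same shape and src on every rank, but only the rank that matches src evaluates init_func; the resulting tensor is then secret-shared as an ArithmeticSharedTensor. The helper name make_shared_weight and the seeding are assumptions for illustration; ctx is a FATE Context from the job runtime.

import torch

def make_shared_weight(ctx, shape, src, seed=0):
    # Only the party whose rank equals `src` runs the init function; the
    # other parties receive arithmetic shares of the resulting tensor.
    generator = torch.Generator().manual_seed(seed)
    return ctx.mpc.init_tensor(
        shape=shape,
        init_func=lambda s: torch.rand(s, generator=generator),
        src=src,
    )
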
13 changes: 9 additions & 4 deletions python/fate/arch/protocol/mpc/nn/sshe/linr_layer.py
@@ -1,7 +1,11 @@
+import typing
+
 import torch

 from fate.arch.context import Context
-from fate.arch.utils.trace import auto_trace
 from fate.arch.protocol.mpc.common.encoding import IgnoreEncodings
 from fate.arch.protocol.mpc.mpc import FixedPointEncoder
+from fate.arch.utils.trace import auto_trace


 class SSHELinearRegressionLayer:
@@ -13,9 +17,10 @@ def __init__(
         out_features,
         rank_a,
         rank_b,
+        wa_init_fn: typing.Callable[[typing.Tuple], torch.Tensor],
+        wb_init_fn: typing.Callable[[typing.Tuple], torch.Tensor],
         precision_bits=None,
         sync_shape=True,
-        generator=None,
     ):
         self.ctx = ctx
         self.rank_a = rank_a
@@ -41,8 +46,8 @@ def __init__(
             in_features_b is not None, "in_features_b must be specified when sync_shape is False", dst=rank_b
         )

-        self.wa = ctx.mpc.random_tensor(shape=(in_features_a, out_features), src=rank_a, generator=generator)
-        self.wb = ctx.mpc.random_tensor(shape=(in_features_b, out_features), src=rank_b, generator=generator)
+        self.wa = ctx.mpc.init_tensor(shape=(in_features_a, out_features), init_func=wa_init_fn, src=rank_a)
+        self.wb = ctx.mpc.init_tensor(shape=(in_features_b, out_features), init_func=wb_init_fn, src=rank_b)
         self.phe_cipher = ctx.cipher.phe.setup()
         self.precision_bits = precision_bits

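For reference, a sketch of how the updated SSHELinearRegressionLayer constructor might be called after this change (hypothetical; the keyword names before out_features and the feature sizes are assumptions, and ctx is a FATE Context provided by the job runtime):

import torch
from fate.arch.protocol.mpc.nn.sshe.linr_layer import SSHELinearRegressionLayer

def build_linr_layer(ctx, seed=0):
    # Only the party at rank_a evaluates wa_init_fn, and only the party at
    # rank_b evaluates wb_init_fn; the other side receives shares.
    generator = torch.Generator().manual_seed(seed)
    return SSHELinearRegressionLayer(
        ctx,
        in_features_a=4,   # assumed illustrative sizes
        in_features_b=3,
        out_features=1,
        rank_a=0,
        rank_b=1,
        wa_init_fn=lambda shape: torch.rand(shape, generator=generator),
        wb_init_fn=lambda shape: torch.rand(shape, generator=generator),
    )
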
12 changes: 8 additions & 4 deletions python/fate/arch/protocol/mpc/nn/sshe/lr_layer.py
@@ -1,8 +1,11 @@
+import typing
+
 import torch

 from fate.arch.context import Context
-from fate.arch.utils.trace import auto_trace
 from fate.arch.protocol.mpc.common.encoding import IgnoreEncodings
 from fate.arch.protocol.mpc.mpc import FixedPointEncoder
+from fate.arch.utils.trace import auto_trace


 class SSHELogisticRegressionLayer:
@@ -14,9 +17,10 @@ def __init__(
         out_features,
         rank_a,
         rank_b,
+        wa_init_fn: typing.Callable[[typing.Tuple], torch.Tensor],
+        wb_init_fn: typing.Callable[[typing.Tuple], torch.Tensor],
         precision_bits=None,
         sync_shape=True,
-        generator=None,
     ):
         self.ctx = ctx
         self.rank_a = rank_a
@@ -42,8 +46,8 @@ def __init__(
             in_features_b is not None, "in_features_b must be specified when sync_shape is False", dst=rank_b
         )

-        self.wa = ctx.mpc.random_tensor(shape=(in_features_a, out_features), src=rank_a, generator=generator)
-        self.wb = ctx.mpc.random_tensor(shape=(in_features_b, out_features), src=rank_b, generator=generator)
+        self.wa = ctx.mpc.init_tensor(shape=(in_features_a, out_features), init_func=wa_init_fn, src=rank_a)
+        self.wb = ctx.mpc.init_tensor(shape=(in_features_b, out_features), init_func=wb_init_fn, src=rank_b)
         self.phe_cipher = ctx.cipher.phe.setup()
         self.precision_bits = precision_bits

27 changes: 15 additions & 12 deletions python/fate/arch/protocol/mpc/nn/sshe/nn_layer.py
@@ -1,13 +1,13 @@
+import typing
 from typing import Any, Iterator

 import torch
 from torch.nn import Parameter

 from fate.arch.context import Context
 from fate.arch.protocol.mpc.common.encoding import IgnoreEncodings
-from fate.arch.utils.trace import auto_trace
-
 from fate.arch.protocol.mpc.mpc import FixedPointEncoder
+from fate.arch.utils.trace import auto_trace


 class SSHENeuralNetworkAggregatorLayer(torch.nn.Module):
@@ -19,18 +19,20 @@ def __init__(
         out_features,
         rank_a,
         rank_b,
+        wa_init_fn: typing.Callable[[typing.Tuple], torch.Tensor],
+        wb_init_fn: typing.Callable[[typing.Tuple], torch.Tensor],
         precision_bits=None,
-        generator=None,
     ):
         self.aggregator = SSHENeuralNetworkAggregator(
             ctx,
-            in_features_a,
-            in_features_b,
-            out_features,
-            rank_a,
-            rank_b,
+            in_features_a=in_features_a,
+            in_features_b=in_features_b,
+            out_features=out_features,
+            rank_a=rank_a,
+            rank_b=rank_b,
             encoder=FixedPointEncoder(precision_bits),
-            generator=generator,
+            wa_init_fn=wa_init_fn,
+            wb_init_fn=wb_init_fn,
         )
         super().__init__()
@@ -75,13 +77,14 @@ def __init__(
         rank_a,
         rank_b,
         encoder,
+        wa_init_fn,
+        wb_init_fn,
         precision_bits=None,
-        generator=None,
     ):
         self.ctx = ctx
         self.group = ctx.mpc.communicator.new_group([rank_a, rank_b], "sshe_nn_aggregator_layer")
-        self.wa = ctx.mpc.random_tensor(shape=(in_features_a, out_features), src=rank_a, generator=generator)
-        self.wb = ctx.mpc.random_tensor(shape=(in_features_b, out_features), src=rank_b, generator=generator)
+        self.wa = ctx.mpc.init_tensor(shape=(in_features_a, out_features), init_func=wa_init_fn, src=rank_a)
+        self.wb = ctx.mpc.init_tensor(shape=(in_features_b, out_features), init_func=wb_init_fn, src=rank_b)
         self.phe_cipher = ctx.cipher.phe.setup()
         self.rank_a = rank_a
         self.rank_b = rank_b
3 changes: 2 additions & 1 deletion python/fate/ml/mpc/sshe_nn.py
@@ -36,7 +36,8 @@ def validate_run(self, ctx: Context):
             out_features=out_features,
             rank_a=0,
             rank_b=1,
-            generator=generator,
+            wa_init_fn=lambda shape: torch.rand(shape, generator=generator),
+            wb_init_fn=lambda shape: torch.rand(shape, generator=generator),
         )
         optimizer = SSHENeuralNetworkOptimizerSGD(ctx, layer.parameters(), lr=lr)
         z = layer(h)
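
The validate_run example above passes torch.rand closures for both parties. Because the initializer is now an arbitrary callable from a shape tuple to a tensor, other initializations can be plugged in as well; below is a minimal sketch of two alternatives (hypothetical, not part of this commit) that could be passed as wa_init_fn or wb_init_fn:

import torch

def xavier_init(shape):
    # Xavier/Glorot uniform initialization for a 2-D weight of shape
    # (in_features, out_features).
    w = torch.empty(shape)
    torch.nn.init.xavier_uniform_(w)
    return w

def zeros_init(shape):
    # All-zero initialization, e.g. for a party that should start from zero weights.
    return torch.zeros(shape)
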
