Skip to content

Commit

Permalink
update mpc nn
Browse files Browse the repository at this point in the history
Signed-off-by: sagewe <wbwmat@gmail.com>
  • Loading branch information
sagewe committed Nov 21, 2023
1 parent b8bfade commit 16df77c
Show file tree
Hide file tree
Showing 2 changed files with 2 additions and 3 deletions.
1 change: 0 additions & 1 deletion python/fate/arch/protocol/mpc/nn/sshe/sa_layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,6 @@ def forward(ctx, input, aggregator: "SSHEAggregator"):
@staticmethod
def backward(ctx: Any, *grad_outputs: Any) -> Any:
aggregator: "SSHEAggregator" = ctx.aggregator
aggregator.ctx.mpc.info(f"grad_outputs={grad_outputs}", dst=[0, 1])
ha = ctx.saved_tensors[0] if aggregator.ctx.rank == aggregator.rank_a else None
hb = ctx.saved_tensors[0] if aggregator.ctx.rank == aggregator.rank_b else None
dz = grad_outputs[0] if aggregator.ctx.rank == aggregator.rank_b else None
Expand Down
4 changes: 2 additions & 2 deletions python/fate/ml/mpc/sshe_nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ def validate_run(self, ctx: Context):
hb = torch.rand(num_samples, in_features_b, requires_grad=True, generator=torch.Generator().manual_seed(1))
h = ctx.mpc.cond_call(lambda: ha, lambda: hb, dst=0)

generator = torch.Generator().manual_seed(0)
# generator = torch.Generator().manual_seed(0)
layer = SSHEAggregatorLayer(
ctx,
in_features_a=in_features_a,
Expand All @@ -38,7 +38,7 @@ def validate_run(self, ctx: Context):
rank_a=0,
rank_b=1,
lr=lr,
generator=generator,
# generator=generator,
)
z = layer(h)
ctx.mpc.info(f"forward={z}", dst=[1])
Expand Down

0 comments on commit 16df77c

Please sign in to comment.