Skip to content

Commit

Permalink
Updates of Hetero-NN Model
Browse files Browse the repository at this point in the history
1. Optimizer grad computation
2. Support 1 side training
3. Support No guest bottom model
Signed-off-by: weijingchen <talkingwallace@sohu.com>

Signed-off-by: cwj <talkingwallace@sohu.com>
  • Loading branch information
talkingwallace committed Oct 19, 2023
1 parent f0508b4 commit bf52b03
Show file tree
Hide file tree
Showing 15 changed files with 587 additions and 285 deletions.
1 change: 1 addition & 0 deletions python/fate/arch/_standalone.py
Original file line number Diff line number Diff line change
Expand Up @@ -632,6 +632,7 @@ def get(self, name: str, tag: str, parties: List[PartyMeta]) -> List:

for party in parties:
_tagged_key = self._federation_object_key(name, tag, party, self._party)

results.append(self._meta.wait_status_set(_tagged_key))

rtn = []
Expand Down
1 change: 1 addition & 0 deletions python/fate/arch/federation/standalone/_federation.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,7 @@ def pull(
if (name, tag, party) in self._get_history:
raise ValueError(f"get from {party} with duplicate tag: {name}.{tag}")
self._get_history.add((name, tag, party))

rtn = self._federation.get(name=name, tag=tag, parties=parties)
return [Table(r) if isinstance(r, RawTable) else r for r in rtn]

Expand Down
128 changes: 0 additions & 128 deletions python/fate/ml/nn/hetero/agg_layer/plaintext_agg_layer.py

This file was deleted.

File renamed without changes.
77 changes: 0 additions & 77 deletions python/fate/ml/nn/hetero/model/guest.py

This file was deleted.

56 changes: 0 additions & 56 deletions python/fate/ml/nn/hetero/model/host.py

This file was deleted.

File renamed without changes.
Original file line number Diff line number Diff line change
Expand Up @@ -5,15 +5,16 @@ def backward_loss(z, backward_error):
return t.sum(z * backward_error)

class InteractiveLayer(t.nn.Module):
def __init__(self, ctx: Context):
def __init__(self):
super().__init__()
self.ctx = ctx
self._ctx = None
self._fw_suffix = "interactive_fw_{}"
self._bw_suffix = "interactive_bw_{}"
self._pred_suffix = "interactive_pred_{}"
self._fw_count = 0
self._bw_count = 0
self._pred_count = 0
self._has_ctx = False

def forward(self, x):
raise NotImplementedError()
Expand All @@ -24,6 +25,17 @@ def backward(self, error):
def predict(self, x):
raise NotImplementedError()

def set_context(self, ctx: Context):
self._ctx = ctx
self._has_ctx = True
def has_context(self):
return self._has_ctx
@property
def ctx(self):
if self._ctx is None or self._has_ctx == False:
raise ValueError("Context is not set yet, please call set_context() first")
return self._ctx

def _clear_state(self):
pass

Expand Down
Loading

0 comments on commit bf52b03

Please sign in to comment.