Migrate Fate torch from 1.x version

Signed-off-by: weijingchen <talkingwallace@sohu.com>
Signed-off-by: cwj <talkingwallace@sohu.com>
talkingwallace committed Jun 6, 2023
1 parent 35df013 commit 57f55c6
Showing 8 changed files with 2,476 additions and 11 deletions.
7 changes: 3 additions & 4 deletions python/fate/components/components/homo_nn.py
@@ -30,9 +30,8 @@
from fate.interface import Context
from fate.components.components.nn.setup.fate_setup import FateSetup
from fate.components.components.nn.nn_setup import NNSetup
from fate.components.components.nn.loader import Loader, _Source
from fate.components.components.nn.loader import Loader
from fate.arch.dataframe._dataframe import DataFrame
from fate.ml.nn.algo.homo.fedavg import FedAVGArguments, TrainingArguments
import logging


@@ -77,9 +76,9 @@ def train(
if setup_module != 'fate_setup':
if source == None:
# load from default folder
setup = Loader('fate.components.components.nn.setup.' + setup_module, setup_class, **setup_conf).load_inst()
setup = Loader('fate.components.components.nn.setup.' + setup_module, setup_class, **setup_conf).call_item()
else:
setup = Loader(setup_module, setup_class, source=source, **setup_conf).load_inst()
setup = Loader(setup_module, setup_class, source=source, **setup_conf).call_item()
assert isinstance(setup, NNSetup), 'loaded class must be a subclass of NNSetup class, but got {}'.format(type(setup))
else:
print('using default fate setup')
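
The hunk above keeps the Loader call pattern and only renames the final call from load_inst() to call_item(). A minimal hedged sketch of that pattern, with a made-up setup module, class, and keyword argument (only the Loader(...) arguments and .call_item() are taken from the diff):

from fate.components.components.nn.loader import Loader

# 'my_setup', 'MySetup' and lr=0.01 are hypothetical placeholders, not names from the repository
setup = Loader('fate.components.components.nn.setup.my_setup', 'MySetup', lr=0.01).call_item()
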
137 changes: 137 additions & 0 deletions python/fate/components/components/nn/fate_torch/base.py
@@ -0,0 +1,137 @@
import torch as t
from fate.components.components.nn.loader import Loader
from torch.nn import Sequential as tSequential
import json


class FateTorch(object):
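    # NOTE: FateTorch appears to be used as a mixin alongside torch.nn.Module subclasses
    # (defined in other files of this commit), which is why __init__ below calls
    # t.nn.Module.__init__ explicitly even though FateTorch itself does not inherit from it.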

def __init__(self):
t.nn.Module.__init__(self)
self.param_dict = dict()
self.optimizer = None

def to_dict(self):
        ret_dict = {
'module_name': 'torch.nn',
'item_name': str(type(self).__name__),
'kwargs': self.param_dict
}
return ret_dict


class FateTorchOptimizer(object):

def __init__(self):
self.param_dict = dict()
self.torch_class = None

def to_dict(self):
        ret_dict = {
'module_name': 'torch.optim',
'item_name': type(self).__name__,
'kwargs': self.param_dict
}
return ret_dict

    def check_params(self, params):

        if isinstance(params, (FateTorch, Sequential)):
            params.add_optimizer(self)
            params = params.parameters()

        l_param = list(params)
        if len(l_param) == 0:
            # fake parameter, for the case where only a custom model is used
            return [t.nn.Parameter(t.Tensor([0]))]

        return l_param

    def register_optimizer(self, input_):

        if input_ is None:
            return
        if isinstance(input_, (FateTorch, Sequential)):
            input_.add_optimizer(self)

def to_torch_instance(self, parameters):
return self.torch_class(parameters, **self.param_dict)


def load_seq(seq_conf: dict) -> tSequential:
    # rebuild a torch Sequential from the ordered layer confs produced by Sequential.to_dict()
    confs = list(dict(sorted(seq_conf.items())).values())
model_list = []
for conf in confs:
layer = Loader.from_dict(conf)()
model_list.append(layer)

return tSequential(*model_list)


class Sequential(tSequential):

    def to_dict(self):
        """
        Get the structure of the current sequential as a serializable dict
        """
        layer_confs = {}
        for idx, k in enumerate(self._modules):
            layer_confs[idx] = self._modules[k].to_dict()
        ret_dict = {
'module_name': 'fate.components.components.nn.fate_torch.base',
'item_name': load_seq.__name__,
'kwargs': {'seq_conf': layer_confs}
}
return ret_dict

def to_json(self):
return json.dumps(self.to_dict(), indent=4)

def add_optimizer(self, opt):
setattr(self, 'optimizer', opt)

def add(self, layer):

if isinstance(layer, Sequential):
self._modules = layer._modules
# copy optimizer
if hasattr(layer, 'optimizer'):
setattr(self, 'optimizer', layer.optimizer)
elif isinstance(layer, FateTorch):
self.add_module(str(len(self)), layer)
            # inherit the layer's optimizer if this Sequential does not have one yet
if not hasattr(self, 'optimizer') and hasattr(layer, 'optimizer'):
setattr(self, 'optimizer', layer.optimizer)
else:
raise ValueError(
'unknown input layer type {}, this type is not supported'.format(
type(layer)))

@staticmethod
def get_loss_config(loss: FateTorch):
return loss.to_dict()

def get_optimizer_config(self, optimizer=None):
if hasattr(self, 'optimizer'):
return self.optimizer.to_dict()
else:
return optimizer.to_dict()

def get_network_config(self):
return self.to_dict()
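
For orientation, a hedged usage sketch (not part of this commit): the Linear wrapper below is a toy stand-in for the FateTorch layer wrappers this commit presumably generates in other files, and it assumes that calling Loader.from_dict(conf)() instantiates the referenced torch class with the recorded kwargs.

import torch as t
from fate.components.components.nn.fate_torch.base import FateTorch, Sequential, load_seq


class Linear(FateTorch, t.nn.Linear):
    # toy wrapper: records its constructor kwargs so to_dict() can describe how to rebuild it
    def __init__(self, in_features, out_features):
        FateTorch.__init__(self)
        t.nn.Linear.__init__(self, in_features, out_features)
        self.param_dict.update({'in_features': in_features, 'out_features': out_features})


seq = Sequential()
seq.add(Linear(10, 4))
seq.add(Linear(4, 1))

conf = seq.to_dict()       # {'module_name': ..., 'item_name': 'load_seq', 'kwargs': {'seq_conf': ...}}
print(seq.to_json())       # same structure, serialized with json.dumps
rebuilt = load_seq(conf['kwargs']['seq_conf'])  # plain torch.nn.Sequential of torch.nn.Linear layers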

