Commit 57f55c6
Migrate Fate torch from 1.x version

Signed-off-by: weijingchen <talkingwallace@sohu.com>
Signed-off-by: cwj <talkingwallace@sohu.com>

1 parent: 35df013
Showing 8 changed files with 2,476 additions and 11 deletions.
python/fate/components/components/nn/fate_torch/base.py (new file, 137 additions, 0 deletions)

@@ -0,0 +1,137 @@
import json

import torch as t
from fate.components.components.nn.loader import Loader
from torch.nn import Sequential as tSequential


class FateTorch(object):
    """
    Mixin for wrapped torch.nn layers: initializes nn.Module state so the
    class can be combined with a concrete torch layer, and keeps the
    constructor kwargs in param_dict so the layer can be exported as a
    Loader config via to_dict().
    """

    def __init__(self):
        t.nn.Module.__init__(self)
        self.param_dict = dict()
        self.optimizer = None

    def to_dict(self):
        ret_dict = {
            'module_name': 'torch.nn',
            'item_name': str(type(self).__name__),
            'kwargs': self.param_dict
        }
        return ret_dict


class FateTorchOptimizer(object):
    """
    Mixin for wrapped torch.optim optimizers: keeps the constructor kwargs
    in param_dict and instantiates the real optimizer on demand via
    to_torch_instance().
    """

    def __init__(self):
        self.param_dict = dict()
        self.torch_class = None

    def to_dict(self):
        ret_dict = {
            'module_name': 'torch.optim',
            'item_name': type(self).__name__,
            'kwargs': self.param_dict
        }
        return ret_dict

    def check_params(self, params):
        # accept either a FateTorch layer / Sequential or a plain iterable
        # of parameters
        if isinstance(params, (FateTorch, Sequential)):
            params.add_optimizer(self)
            params = params.parameters()

        l_param = list(params)
        if len(l_param) == 0:
            # fake parameter, for the case where there is only a custom model
            return [t.nn.Parameter(t.Tensor([0]))]

        return l_param

    def register_optimizer(self, input_):
        if input_ is None:
            return
        if isinstance(input_, (FateTorch, Sequential)):
            input_.add_optimizer(self)

    def to_torch_instance(self, parameters):
        return self.torch_class(parameters, **self.param_dict)
def load_seq(seq_conf: dict) -> tSequential:
    # rebuild a plain torch Sequential from the per-layer Loader configs
    # produced by Sequential.to_dict(); sort keys numerically so the layer
    # order is preserved even after a JSON round trip turns them into strings
    confs = [conf for _, conf in sorted(seq_conf.items(), key=lambda kv: int(kv[0]))]
    model_list = []
    for conf in confs:
        layer = Loader.from_dict(conf)()
        model_list.append(layer)

    return tSequential(*model_list)

class Sequential(tSequential):

    def to_dict(self):
        """
        Get the structure of the current sequential as a Loader config
        that points back at load_seq.
        """
        layer_confs = {}
        idx = 0
        for k in self._modules:
            layer_confs[idx] = self._modules[k].to_dict()
            idx += 1
        ret_dict = {
            'module_name': 'fate.components.components.nn.fate_torch.base',
            'item_name': load_seq.__name__,
            'kwargs': {'seq_conf': layer_confs}
        }
        return ret_dict

    def to_json(self):
        return json.dumps(self.to_dict(), indent=4)

    def add_optimizer(self, opt):
        setattr(self, 'optimizer', opt)

    def add(self, layer):
        if isinstance(layer, Sequential):
            self._modules = layer._modules
            # copy optimizer
            if hasattr(layer, 'optimizer'):
                setattr(self, 'optimizer', layer.optimizer)
        elif isinstance(layer, FateTorch):
            self.add_module(str(len(self)), layer)
            # take over the layer's optimizer if this Sequential has none yet
            if not hasattr(self, 'optimizer') and hasattr(layer, 'optimizer'):
                setattr(self, 'optimizer', layer.optimizer)
        else:
            raise ValueError(
                'unknown input layer type {}, this type is not supported'.format(
                    type(layer)))

    @staticmethod
    def get_loss_config(loss: FateTorch):
        return loss.to_dict()

    def get_optimizer_config(self, optimizer=None):
        if hasattr(self, 'optimizer'):
            return self.optimizer.to_dict()
        else:
            return optimizer.to_dict()

    def get_network_config(self):
        return self.to_dict()
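
The concrete layer wrappers built on FateTorch are not part of this file (they presumably live in the other changed files of this commit), so the following is only a minimal sketch of how the pieces above are meant to fit together, assuming the definitions in base.py are in scope. The Linear wrapper below is a hypothetical illustration with an assumed signature, not the generated class:

# Hypothetical wrapper (illustration only): a torch.nn.Linear combined with
# the FateTorch mixin so its constructor kwargs end up in param_dict.
class Linear(FateTorch, t.nn.Linear):

    def __init__(self, in_features, out_features, bias=True):
        FateTorch.__init__(self)
        t.nn.Linear.__init__(self, in_features, out_features, bias=bias)
        self.param_dict.update(
            {'in_features': in_features, 'out_features': out_features, 'bias': bias})


model = Sequential()
model.add(Linear(10, 4))
model.add(Linear(4, 1))

network_conf = model.get_network_config()               # Loader config pointing at load_seq
model.to_json()                                          # same structure, serialized as JSON
rebuilt = load_seq(network_conf['kwargs']['seq_conf'])   # plain torch.nn.Sequential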
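
The optimizer wrappers built on FateTorchOptimizer are likewise defined elsewhere; here is a minimal sketch, assuming a hypothetical Adam wrapper whose torch_class points at torch.optim.Adam and reusing the model from the sketch above:

# Hypothetical wrapper (illustration only): records kwargs and points
# torch_class at the real torch.optim.Adam.
class Adam(FateTorchOptimizer):

    def __init__(self, lr=0.001):
        super().__init__()
        self.param_dict['lr'] = lr
        self.torch_class = t.optim.Adam


opt = Adam(lr=0.01)
opt.to_dict()                       # {'module_name': 'torch.optim', 'item_name': 'Adam', 'kwargs': {'lr': 0.01}}
params = opt.check_params(model)    # attaches opt to the Sequential and returns its parameters
torch_opt = opt.to_torch_instance(params)   # real torch.optim.Adam over those parameters
model.get_optimizer_config()        # same dict as opt.to_dict(), read back from the attached optimizer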