Skip to content

Commit

Permalink
rename fate_torch to patched_torch
Browse files Browse the repository at this point in the history
Signed-off-by: cwj <talkingwallace@sohu.com>
  • Loading branch information
talkingwallace committed Sep 8, 2023
1 parent f4672be commit 6d72e40
Show file tree
Hide file tree
Showing 10 changed files with 329 additions and 329 deletions.
1 change: 0 additions & 1 deletion python/fate/components/components/nn/loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,6 @@ def _load_item(self):
if spec is None:
# Search for similar module names
suggestion = self._find_similar_module_names()
print('suggestion is {}'.format(suggestion))
if suggestion:
raise ValueError(
"Module: {} not found in the import path. Do you mean {}?".format(
Expand Down
4 changes: 2 additions & 2 deletions python/fate/components/components/nn/nn_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from fate.arch.dataframe._dataframe import DataFrame
from fate.components.components.utils import consts
import logging
from fate.ml.utils.predict_tools import to_fate_df, array_to_predict_df
from fate.ml.utils.predict_tools import to_dist_df, array_to_predict_df
from fate.ml.utils.predict_tools import BINARY, MULTI, REGRESSION, OTHER, LABEL, PREDICT_SCORE


Expand Down Expand Up @@ -172,7 +172,7 @@ def get_nn_output_dataframe(
df[PREDICT_SCORE] = predictions.to_list()
df[match_id_name] = match_ids.flatten()
df[sample_id_name] = sample_ids.flatten()
df = to_fate_df(ctx, sample_id_name, match_id_name, df)
df = to_dist_df(ctx, sample_id_name, match_id_name, df)
return df
elif dataframe_format == 'fate_std' and task_type in [BINARY, MULTI, REGRESSION]:
df = array_to_predict_df(ctx, task_type, predictions, match_ids, sample_ids, match_id_name, sample_id_name, labels, threshold, classes)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ def convert_tuples_to_lists(data):
return data


class FateTorch(object):
class PatchedTorchModule(object):

def __init__(self):
t.nn.Module.__init__(self)
Expand All @@ -32,7 +32,7 @@ def to_dict(self):
return ret_dict


class FateTorchOptimizer(object):
class PatchedTorchOptimizer(object):

def __init__(self):
self.param_dict = dict()
Expand All @@ -50,7 +50,7 @@ def check_params(self, params):

if isinstance(
params,
FateTorch) or isinstance(
PatchedTorchModule) or isinstance(
params,
Sequential):
params.add_optimizer(self)
Expand All @@ -71,7 +71,7 @@ def register_optimizer(self, input_):
return
if isinstance(
input_,
FateTorch) or isinstance(
PatchedTorchModule) or isinstance(
input_,
Sequential):
input_.add_optimizer(self)
Expand Down Expand Up @@ -104,7 +104,7 @@ def to_dict(self):
layer_confs[ordered_name] = self._modules[k].to_dict()
idx += 1
ret_dict = {
'module_name': 'fate.components.components.nn.fate_torch.base',
'module_name': 'fate.components.components.nn.patched_torch.base',
'item_name': load_seq.__name__,
'kwargs': {'seq_conf': layer_confs}
}
Expand Down
Loading

0 comments on commit 6d72e40

Please sign in to comment.