Skip to content

Commit

Permalink
Merge pull request #5156 from FederatedAI/dev-2.0.0-sbt-param
Browse files Browse the repository at this point in the history
Dev 2.0.0 sbt param
  • Loading branch information
mgqa34 authored Sep 8, 2023
2 parents d81362d + 06dfcaa commit e5d3e40
Show file tree
Hide file tree
Showing 22 changed files with 432 additions and 400 deletions.
15 changes: 8 additions & 7 deletions python/fate/arch/context/_cipher.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,13 +37,14 @@ def set_phe(self, device: device, options: typing.Optional[dict]):
def _set_default_phe(self):
if "phe" not in self._cipher_mapping:
self._cipher_mapping["phe"] = {}
if self._device == device.CPU:
self._cipher_mapping["phe"][device.CPU] = {"kind": "paillier", "key_length": 1024}
else:
logger.warning(f"no impl exists for device {self._device}, fallback to CPU")
self._cipher_mapping["phe"][device.CPU] = self._cipher_mapping["phe"].get(
device.CPU, {"kind": "paillier", "key_length": 1024}
)
if self._device not in self._cipher_mapping["phe"]:
if self._device == device.CPU:
self._cipher_mapping["phe"][device.CPU] = {"kind": "paillier", "key_length": 1024}
else:
logger.warning(f"no impl exists for device {self._device}, fallback to CPU")
self._cipher_mapping["phe"][device.CPU] = self._cipher_mapping["phe"].get(
device.CPU, {"kind": "paillier", "key_length": 1024}
)

@property
def phe(self):
Expand Down
22 changes: 15 additions & 7 deletions python/fate/components/components/hetero_sbt.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,8 @@
import logging

from fate.arch import Context
from fate.arch.dataframe import DataFrame
from fate.components.components.utils import consts, tools
from fate.components.core import ARBITER, GUEST, HOST, Role, cpn, params
from fate.components.components.utils import consts
from fate.components.core import GUEST, HOST, Role, cpn, params
from fate.ml.ensemble import HeteroSecureBoostGuest, HeteroSecureBoostHost, BINARY_BCE, MULTI_CE, REGRESSION_L2
from fate.components.components.utils.tools import add_dataset_type
from fate.components.components.utils import consts
Expand Down Expand Up @@ -46,12 +45,15 @@ def train(
objective: cpn.parameter(type=params.string_choice(choice=[BINARY_BCE, MULTI_CE, REGRESSION_L2]), default=BINARY_BCE, \
desc='objective function, available: {}'.format([BINARY_BCE, MULTI_CE, REGRESSION_L2])),
num_class: cpn.parameter(type=params.conint(gt=0), default=2, desc='class number of multi classification, active when objective is {}'.format(MULTI_CE)),
encrypt_key_length: cpn.parameter(type=params.conint(gt=0), default=2048, desc='paillier encrypt key length'),
l2: cpn.parameter(type=params.confloat(gt=0), default=0.1, desc='L2 regularization'),
min_impurity_split: cpn.parameter(type=params.confloat(gt=0), default=1e-2, desc='min impurity when splitting a tree node'),
min_sample_split: cpn.parameter(type=params.conint(gt=0), default=2, desc='min sample to split a tree node'),
min_leaf_node: cpn.parameter(type=params.conint(gt=0), default=1, desc='mininum sample contained in a leaf node'),
min_child_weight: cpn.parameter(type=params.confloat(gt=0), default=1, desc='minumum hessian contained in a leaf node'),
gh_pack: cpn.parameter(type=bool, default=True, desc='whether to pack gradient and hessian together'),
split_info_pack: cpn.parameter(type=bool, default=True, desc='for host side, whether to pack split info together'),
hist_sub: cpn.parameter(type=bool, default=True, desc='whether to use histogram subtraction'),
he_param: cpn.parameter(type=params.he_param(), default=params.HEParam(kind='paillier', key_length=1024), desc='homomorphic encryption param, support paillier, ou and mock in current version'),
train_data_output: cpn.dataframe_output(roles=[GUEST, HOST], optional=True),
train_model_output: cpn.json_model_output(roles=[GUEST, HOST], optional=True),
train_model_input: cpn.json_model_input(roles=[GUEST, HOST], optional=True)
Expand All @@ -66,10 +68,16 @@ def train(

if role.is_guest:

# initialize encrypt kit

logger.info('cwj he param is {}'.format(he_param.dict()))
ctx.cipher.set_phe(ctx.device, he_param.dict())

booster = HeteroSecureBoostGuest(num_trees=num_trees, max_depth=max_depth, learning_rate=learning_rate, max_bin=max_bin,
l2=l2, min_impurity_split=min_impurity_split, min_sample_split=min_sample_split,
min_leaf_node=min_leaf_node, min_child_weight=min_child_weight, encrypt_key_length=encrypt_key_length,
objective=objective, num_class=num_class)
min_leaf_node=min_leaf_node, min_child_weight=min_child_weight, objective=objective, num_class=num_class,
gh_pack=gh_pack, split_info_pack=split_info_pack, hist_sub=hist_sub
)
if train_model_input is not None:
booster.from_model(train_model_input)
logger.info('sbt input model loaded, will start warmstarting')
Expand All @@ -84,7 +92,7 @@ def train(

elif role.is_host:

booster = HeteroSecureBoostHost(num_trees=num_trees, max_depth=max_depth, learning_rate=learning_rate, max_bin=max_bin)
booster = HeteroSecureBoostHost(num_trees=num_trees, max_depth=max_depth, learning_rate=learning_rate, max_bin=max_bin, hist_sub=hist_sub)
if train_model_input is not None:
booster.from_model(train_model_input)
logger.info('sbt input model loaded, will start warmstarting')
Expand Down
1 change: 0 additions & 1 deletion python/fate/components/components/nn/loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,6 @@ def _load_item(self):
if spec is None:
# Search for similar module names
suggestion = self._find_similar_module_names()
print('suggestion is {}'.format(suggestion))
if suggestion:
raise ValueError(
"Module: {} not found in the import path. Do you mean {}?".format(
Expand Down
4 changes: 2 additions & 2 deletions python/fate/components/components/nn/nn_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from fate.arch.dataframe._dataframe import DataFrame
from fate.components.components.utils import consts
import logging
from fate.ml.utils.predict_tools import to_fate_df, array_to_predict_df
from fate.ml.utils.predict_tools import to_dist_df, array_to_predict_df
from fate.ml.utils.predict_tools import BINARY, MULTI, REGRESSION, OTHER, LABEL, PREDICT_SCORE


Expand Down Expand Up @@ -172,7 +172,7 @@ def get_nn_output_dataframe(
df[PREDICT_SCORE] = predictions.to_list()
df[match_id_name] = match_ids.flatten()
df[sample_id_name] = sample_ids.flatten()
df = to_fate_df(ctx, sample_id_name, match_id_name, df)
df = to_dist_df(ctx, sample_id_name, match_id_name, df)
return df
elif dataframe_format == 'fate_std' and task_type in [BINARY, MULTI, REGRESSION]:
df = array_to_predict_df(ctx, task_type, predictions, match_ids, sample_ids, match_id_name, sample_id_name, labels, threshold, classes)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ def convert_tuples_to_lists(data):
return data


class FateTorch(object):
class PatchedTorchModule(object):

def __init__(self):
t.nn.Module.__init__(self)
Expand All @@ -32,7 +32,7 @@ def to_dict(self):
return ret_dict


class FateTorchOptimizer(object):
class PatchedTorchOptimizer(object):

def __init__(self):
self.param_dict = dict()
Expand All @@ -50,7 +50,7 @@ def check_params(self, params):

if isinstance(
params,
FateTorch) or isinstance(
PatchedTorchModule) or isinstance(
params,
Sequential):
params.add_optimizer(self)
Expand All @@ -71,7 +71,7 @@ def register_optimizer(self, input_):
return
if isinstance(
input_,
FateTorch) or isinstance(
PatchedTorchModule) or isinstance(
input_,
Sequential):
input_.add_optimizer(self)
Expand Down Expand Up @@ -104,7 +104,7 @@ def to_dict(self):
layer_confs[ordered_name] = self._modules[k].to_dict()
idx += 1
ret_dict = {
'module_name': 'fate.components.components.nn.fate_torch.base',
'module_name': 'fate.components.components.nn.patched_torch.base',
'item_name': load_seq.__name__,
'kwargs': {'seq_conf': layer_confs}
}
Expand Down
Loading

0 comments on commit e5d3e40

Please sign in to comment.