diff --git a/python/fate/components/components/hetero_sbt.py b/python/fate/components/components/hetero_sbt.py index 714d73e2db..3c8274c109 100644 --- a/python/fate/components/components/hetero_sbt.py +++ b/python/fate/components/components/hetero_sbt.py @@ -16,9 +16,8 @@ import logging from fate.arch import Context -from fate.arch.dataframe import DataFrame -from fate.components.components.utils import consts, tools -from fate.components.core import ARBITER, GUEST, HOST, Role, cpn, params +from fate.components.components.utils import consts +from fate.components.core import GUEST, HOST, Role, cpn, params from fate.ml.ensemble import HeteroSecureBoostGuest, HeteroSecureBoostHost, BINARY_BCE, MULTI_CE, REGRESSION_L2 from fate.components.components.utils.tools import add_dataset_type from fate.components.components.utils import consts @@ -55,6 +54,7 @@ def train( gh_pack: cpn.parameter(type=bool, default=True, desc='whether to pack gradient and hessian together'), split_info_pack: cpn.parameter(type=bool, default=True, desc='for host side, whether to pack split info together'), hist_sub: cpn.parameter(type=bool, default=True, desc='whether to use histogram subtraction'), + he_param: cpn.parameter(type=params.he_param, default=params.HEParam(kind='paillier', key_length=1024), desc='homomorphic encryption param, support paillier, ou and mock in current version'), train_data_output: cpn.dataframe_output(roles=[GUEST, HOST], optional=True), train_model_output: cpn.json_model_output(roles=[GUEST, HOST], optional=True), train_model_input: cpn.json_model_input(roles=[GUEST, HOST], optional=True) @@ -69,10 +69,16 @@ def train( if role.is_guest: + # initialize encrypt kit + option = {"kind": he_param.kind, "key_length": he_param.key_length} + en_kit = ctx.cipher.phe.setup(options=option) + booster = HeteroSecureBoostGuest(num_trees=num_trees, max_depth=max_depth, learning_rate=learning_rate, max_bin=max_bin, l2=l2, min_impurity_split=min_impurity_split, min_sample_split=min_sample_split, min_leaf_node=min_leaf_node, min_child_weight=min_child_weight, encrypt_key_length=encrypt_key_length, - objective=objective, num_class=num_class, gh_pack=gh_pack, split_info_pack=split_info_pack, hist_sub=hist_sub) + objective=objective, num_class=num_class, gh_pack=gh_pack, split_info_pack=split_info_pack, hist_sub=hist_sub, + encrypt_kit=en_kit + ) if train_model_input is not None: booster.from_model(train_model_input) logger.info('sbt input model loaded, will start warmstarting') diff --git a/python/fate/ml/ensemble/algo/secureboost/hetero/guest.py b/python/fate/ml/ensemble/algo/secureboost/hetero/guest.py index b8384ad0a4..886f0c9290 100644 --- a/python/fate/ml/ensemble/algo/secureboost/hetero/guest.py +++ b/python/fate/ml/ensemble/algo/secureboost/hetero/guest.py @@ -31,10 +31,10 @@ class HeteroSecureBoostGuest(HeteroBoostingTree): - def __init__(self, num_trees=3, learning_rate=0.3, max_depth=3, objective='binary:bce', num_class=3, - max_bin=32, encrypt_key_length=2048, l2=0.1, l1=0, min_impurity_split=1e-2, min_sample_split=2, min_leaf_node=1, min_child_weight=1, gh_pack=True, split_info_pack=True, - hist_sub=True - ) -> None: + def __init__(self, num_trees=3, learning_rate=0.3, max_depth=3, objective='binary:bce', num_class=3, max_bin=32, + l2=0.1, l1=0, min_impurity_split=1e-2, min_sample_split=2, min_leaf_node=1, min_child_weight=1, gh_pack=True, split_info_pack=True, + hist_sub=True, encrypt_kit=None + ): super().__init__() self.num_trees = num_trees @@ -60,8 +60,7 @@ def __init__(self, num_trees=3, learning_rate=0.3, max_depth=3, objective='binar self._hist_sub = hist_sub # encryption - self._encrypt_kit = None - self._encrypt_key_length = encrypt_key_length + self._encrypt_kit = encrypt_kit self._gh_pack = gh_pack self._split_info_pack = split_info_pack @@ -91,8 +90,22 @@ def _compute_gh(self, data: DataFrame, scores: DataFrame, loss_func): loss_func.compute_hess(gh, label, predict) return gh - def _init_encrypt_kit(self, ctx): - kit = ctx.cipher.phe.setup(options={"kind": "paillier", "key_length": self._encrypt_key_length}) + def _check_encrypt_kit(self, ctx: Context): + + if self._encrypt_kit is None: + # make sure cipher is initialized + kit = ctx.cipher.phe.setup(options={"kind": "paillier", "key_length": 1024}) + # check encrypt function + if not self._encrypt_kit.can_support_negative_number: + self._gh_pack = True + logger.info('current encrypt method cannot support neg num, gh pack is forced to be True') + if not self._encrypt_kit.can_support_squeeze: + self._split_info_pack = False + logger.info('current encrypt method cannot support compress, split info pack is forced to be False') + if not self._encrypt_kit.can_support_pack: + self._gh_pack = False + self._split_info_pack = False + logger.info('current encrypt method cannot support pack, gh pack is forced to be False') return kit def get_train_predict(self): @@ -172,7 +185,7 @@ def fit(self, ctx: Context, train_data: DataFrame, validate_data: DataFrame = No self._init_sample_scores(ctx, label, train_data) # init encryption kit - self._encrypt_kit= self._init_encrypt_kit(ctx) + self._encrypt_kit= self._check_encrypt_kit(ctx) # start tree fittingf for tree_idx, tree_ctx in ctx.on_iterations.ctxs_range(len(self._trees), len(self._trees)+self.num_trees): diff --git a/python/fate/ml/ensemble/learner/decision_tree/hetero/guest.py b/python/fate/ml/ensemble/learner/decision_tree/hetero/guest.py index 440e1827ae..61ba041fdb 100644 --- a/python/fate/ml/ensemble/learner/decision_tree/hetero/guest.py +++ b/python/fate/ml/ensemble/learner/decision_tree/hetero/guest.py @@ -90,7 +90,7 @@ def set_encrypt_kit(self, kit): self._decryptor = kit.get_tensor_decryptor() logger.info('encrypt kit setup through setter') - def _init_encrypt_kit(self, ctx): + def _init_encrypt_kit(self, ctx: Context): kit = ctx.cipher.phe.setup(options={"kind": "paillier", "key_length": 1024}) self._en_key_length = kit.key_size self._sk, self._pk, self._coder, self._evaluator, self._encryptor = kit.sk, kit.pk, kit.coder, kit.evaluator, kit.get_tensor_encryptor()