diff --git a/python/fate/components/components/hetero_sbt.py b/python/fate/components/components/hetero_sbt.py index 3c8274c109..312bdb2f1d 100644 --- a/python/fate/components/components/hetero_sbt.py +++ b/python/fate/components/components/hetero_sbt.py @@ -54,7 +54,7 @@ def train( gh_pack: cpn.parameter(type=bool, default=True, desc='whether to pack gradient and hessian together'), split_info_pack: cpn.parameter(type=bool, default=True, desc='for host side, whether to pack split info together'), hist_sub: cpn.parameter(type=bool, default=True, desc='whether to use histogram subtraction'), - he_param: cpn.parameter(type=params.he_param, default=params.HEParam(kind='paillier', key_length=1024), desc='homomorphic encryption param, support paillier, ou and mock in current version'), + he_param: cpn.parameter(type=params.he_param(), default=params.HEParam(kind='paillier', key_length=1024), desc='homomorphic encryption param, support paillier, ou and mock in current version'), train_data_output: cpn.dataframe_output(roles=[GUEST, HOST], optional=True), train_model_output: cpn.json_model_output(roles=[GUEST, HOST], optional=True), train_model_input: cpn.json_model_input(roles=[GUEST, HOST], optional=True) @@ -70,14 +70,12 @@ def train( if role.is_guest: # initialize encrypt kit - option = {"kind": he_param.kind, "key_length": he_param.key_length} - en_kit = ctx.cipher.phe.setup(options=option) + ctx.cipher.set_phe(ctx.device, he_param.dict()) booster = HeteroSecureBoostGuest(num_trees=num_trees, max_depth=max_depth, learning_rate=learning_rate, max_bin=max_bin, l2=l2, min_impurity_split=min_impurity_split, min_sample_split=min_sample_split, min_leaf_node=min_leaf_node, min_child_weight=min_child_weight, encrypt_key_length=encrypt_key_length, - objective=objective, num_class=num_class, gh_pack=gh_pack, split_info_pack=split_info_pack, hist_sub=hist_sub, - encrypt_kit=en_kit + objective=objective, num_class=num_class, gh_pack=gh_pack, split_info_pack=split_info_pack, hist_sub=hist_sub ) if train_model_input is not None: booster.from_model(train_model_input) diff --git a/python/fate/ml/ensemble/algo/secureboost/hetero/guest.py b/python/fate/ml/ensemble/algo/secureboost/hetero/guest.py index 886f0c9290..e65654bfda 100644 --- a/python/fate/ml/ensemble/algo/secureboost/hetero/guest.py +++ b/python/fate/ml/ensemble/algo/secureboost/hetero/guest.py @@ -33,7 +33,7 @@ class HeteroSecureBoostGuest(HeteroBoostingTree): def __init__(self, num_trees=3, learning_rate=0.3, max_depth=3, objective='binary:bce', num_class=3, max_bin=32, l2=0.1, l1=0, min_impurity_split=1e-2, min_sample_split=2, min_leaf_node=1, min_child_weight=1, gh_pack=True, split_info_pack=True, - hist_sub=True, encrypt_kit=None + hist_sub=True ): super().__init__() @@ -60,7 +60,6 @@ def __init__(self, num_trees=3, learning_rate=0.3, max_depth=3, objective='binar self._hist_sub = hist_sub # encryption - self._encrypt_kit = encrypt_kit self._gh_pack = gh_pack self._split_info_pack = split_info_pack @@ -94,7 +93,8 @@ def _check_encrypt_kit(self, ctx: Context): if self._encrypt_kit is None: # make sure cipher is initialized - kit = ctx.cipher.phe.setup(options={"kind": "paillier", "key_length": 1024}) + kit = ctx.cipher.phe.setup() + self._encrypt_kit = kit # check encrypt function if not self._encrypt_kit.can_support_negative_number: self._gh_pack = True