diff --git a/python/fate/arch/context/_cipher.py b/python/fate/arch/context/_cipher.py index e1cbf119df..484e741ac4 100644 --- a/python/fate/arch/context/_cipher.py +++ b/python/fate/arch/context/_cipher.py @@ -119,15 +119,15 @@ def __init__( @property def can_support_negative_number(self): - return self._tensor_cipher.can_support_negative_number + return self._can_support_negative_number @property def can_support_squeeze(self): - return self._tensor_cipher.can_support_squeeze + return self._can_support_squeeze @property def can_support_pack(self): - return self._tensor_cipher.can_support_pack + return self._can_support_pack @property def key_size(self): diff --git a/python/fate/components/components/hetero_sbt.py b/python/fate/components/components/hetero_sbt.py index 312bdb2f1d..19508848ce 100644 --- a/python/fate/components/components/hetero_sbt.py +++ b/python/fate/components/components/hetero_sbt.py @@ -45,7 +45,6 @@ def train( objective: cpn.parameter(type=params.string_choice(choice=[BINARY_BCE, MULTI_CE, REGRESSION_L2]), default=BINARY_BCE, \ desc='objective function, available: {}'.format([BINARY_BCE, MULTI_CE, REGRESSION_L2])), num_class: cpn.parameter(type=params.conint(gt=0), default=2, desc='class number of multi classification, active when objective is {}'.format(MULTI_CE)), - encrypt_key_length: cpn.parameter(type=params.conint(gt=0), default=2048, desc='paillier encrypt key length'), l2: cpn.parameter(type=params.confloat(gt=0), default=0.1, desc='L2 regularization'), min_impurity_split: cpn.parameter(type=params.confloat(gt=0), default=1e-2, desc='min impurity when splitting a tree node'), min_sample_split: cpn.parameter(type=params.conint(gt=0), default=2, desc='min sample to split a tree node'), @@ -70,12 +69,14 @@ def train( if role.is_guest: # initialize encrypt kit + + logger.info('cwj he param is {}'.format(he_param.dict())) ctx.cipher.set_phe(ctx.device, he_param.dict()) booster = HeteroSecureBoostGuest(num_trees=num_trees, max_depth=max_depth, learning_rate=learning_rate, max_bin=max_bin, l2=l2, min_impurity_split=min_impurity_split, min_sample_split=min_sample_split, - min_leaf_node=min_leaf_node, min_child_weight=min_child_weight, encrypt_key_length=encrypt_key_length, - objective=objective, num_class=num_class, gh_pack=gh_pack, split_info_pack=split_info_pack, hist_sub=hist_sub + min_leaf_node=min_leaf_node, min_child_weight=min_child_weight, objective=objective, num_class=num_class, + gh_pack=gh_pack, split_info_pack=split_info_pack, hist_sub=hist_sub ) if train_model_input is not None: booster.from_model(train_model_input) diff --git a/python/fate/ml/ensemble/algo/secureboost/hetero/guest.py b/python/fate/ml/ensemble/algo/secureboost/hetero/guest.py index e65654bfda..b03e944a48 100644 --- a/python/fate/ml/ensemble/algo/secureboost/hetero/guest.py +++ b/python/fate/ml/ensemble/algo/secureboost/hetero/guest.py @@ -60,6 +60,7 @@ def __init__(self, num_trees=3, learning_rate=0.3, max_depth=3, objective='binar self._hist_sub = hist_sub # encryption + self._encrypt_kit = None self._gh_pack = gh_pack self._split_info_pack = split_info_pack @@ -96,6 +97,10 @@ def _check_encrypt_kit(self, ctx: Context): kit = ctx.cipher.phe.setup() self._encrypt_kit = kit # check encrypt function + logger.info('encrypt kit info: can support negative number: {}, can support compress: {}, can support pack: {}'.format( + self._encrypt_kit.can_support_negative_number, self._encrypt_kit.can_support_squeeze, self._encrypt_kit.can_support_pack + )) + if not self._encrypt_kit.can_support_negative_number: self._gh_pack = True logger.info('current encrypt method cannot support neg num, gh pack is forced to be True')