Skip to content

Commit

Permalink
Update hetero-sbt param
Browse files Browse the repository at this point in the history
Signed-off-by: cwj <talkingwallace@sohu.com>
  • Loading branch information
talkingwallace committed Sep 8, 2023
1 parent 0cbfb1a commit 99a8ab8
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 8 deletions.
8 changes: 3 additions & 5 deletions python/fate/components/components/hetero_sbt.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ def train(
gh_pack: cpn.parameter(type=bool, default=True, desc='whether to pack gradient and hessian together'),
split_info_pack: cpn.parameter(type=bool, default=True, desc='for host side, whether to pack split info together'),
hist_sub: cpn.parameter(type=bool, default=True, desc='whether to use histogram subtraction'),
he_param: cpn.parameter(type=params.he_param, default=params.HEParam(kind='paillier', key_length=1024), desc='homomorphic encryption param, support paillier, ou and mock in current version'),
he_param: cpn.parameter(type=params.he_param(), default=params.HEParam(kind='paillier', key_length=1024), desc='homomorphic encryption param, support paillier, ou and mock in current version'),
train_data_output: cpn.dataframe_output(roles=[GUEST, HOST], optional=True),
train_model_output: cpn.json_model_output(roles=[GUEST, HOST], optional=True),
train_model_input: cpn.json_model_input(roles=[GUEST, HOST], optional=True)
Expand All @@ -70,14 +70,12 @@ def train(
if role.is_guest:

# initialize encrypt kit
option = {"kind": he_param.kind, "key_length": he_param.key_length}
en_kit = ctx.cipher.phe.setup(options=option)
ctx.cipher.set_phe(ctx.device, he_param.dict())

booster = HeteroSecureBoostGuest(num_trees=num_trees, max_depth=max_depth, learning_rate=learning_rate, max_bin=max_bin,
l2=l2, min_impurity_split=min_impurity_split, min_sample_split=min_sample_split,
min_leaf_node=min_leaf_node, min_child_weight=min_child_weight, encrypt_key_length=encrypt_key_length,
objective=objective, num_class=num_class, gh_pack=gh_pack, split_info_pack=split_info_pack, hist_sub=hist_sub,
encrypt_kit=en_kit
objective=objective, num_class=num_class, gh_pack=gh_pack, split_info_pack=split_info_pack, hist_sub=hist_sub
)
if train_model_input is not None:
booster.from_model(train_model_input)
Expand Down
6 changes: 3 additions & 3 deletions python/fate/ml/ensemble/algo/secureboost/hetero/guest.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ class HeteroSecureBoostGuest(HeteroBoostingTree):

def __init__(self, num_trees=3, learning_rate=0.3, max_depth=3, objective='binary:bce', num_class=3, max_bin=32,
l2=0.1, l1=0, min_impurity_split=1e-2, min_sample_split=2, min_leaf_node=1, min_child_weight=1, gh_pack=True, split_info_pack=True,
hist_sub=True, encrypt_kit=None
hist_sub=True
):

super().__init__()
Expand All @@ -60,7 +60,6 @@ def __init__(self, num_trees=3, learning_rate=0.3, max_depth=3, objective='binar
self._hist_sub = hist_sub

# encryption
self._encrypt_kit = encrypt_kit
self._gh_pack = gh_pack
self._split_info_pack = split_info_pack

Expand Down Expand Up @@ -94,7 +93,8 @@ def _check_encrypt_kit(self, ctx: Context):

if self._encrypt_kit is None:
# make sure cipher is initialized
kit = ctx.cipher.phe.setup(options={"kind": "paillier", "key_length": 1024})
kit = ctx.cipher.phe.setup()
self._encrypt_kit = kit
# check encrypt function
if not self._encrypt_kit.can_support_negative_number:
self._gh_pack = True
Expand Down

0 comments on commit 99a8ab8

Please sign in to comment.