Skip to content

Commit

Permalink
update codes
Browse files Browse the repository at this point in the history
Signed-off-by: cwj <talkingwallace@sohu.com>
  • Loading branch information
talkingwallace committed Sep 8, 2023
1 parent 5ab3676 commit f7293b3
Show file tree
Hide file tree
Showing 3 changed files with 12 additions and 6 deletions.
6 changes: 3 additions & 3 deletions python/fate/arch/context/_cipher.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,15 +119,15 @@ def __init__(

@property
def can_support_negative_number(self):
return self._tensor_cipher.can_support_negative_number
return self._can_support_negative_number

@property
def can_support_squeeze(self):
return self._tensor_cipher.can_support_squeeze
return self._can_support_squeeze

@property
def can_support_pack(self):
return self._tensor_cipher.can_support_pack
return self._can_support_pack

@property
def key_size(self):
Expand Down
7 changes: 4 additions & 3 deletions python/fate/components/components/hetero_sbt.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,6 @@ def train(
objective: cpn.parameter(type=params.string_choice(choice=[BINARY_BCE, MULTI_CE, REGRESSION_L2]), default=BINARY_BCE, \
desc='objective function, available: {}'.format([BINARY_BCE, MULTI_CE, REGRESSION_L2])),
num_class: cpn.parameter(type=params.conint(gt=0), default=2, desc='class number of multi classification, active when objective is {}'.format(MULTI_CE)),
encrypt_key_length: cpn.parameter(type=params.conint(gt=0), default=2048, desc='paillier encrypt key length'),
l2: cpn.parameter(type=params.confloat(gt=0), default=0.1, desc='L2 regularization'),
min_impurity_split: cpn.parameter(type=params.confloat(gt=0), default=1e-2, desc='min impurity when splitting a tree node'),
min_sample_split: cpn.parameter(type=params.conint(gt=0), default=2, desc='min sample to split a tree node'),
Expand All @@ -70,12 +69,14 @@ def train(
if role.is_guest:

# initialize encrypt kit

logger.info('cwj he param is {}'.format(he_param.dict()))
ctx.cipher.set_phe(ctx.device, he_param.dict())

booster = HeteroSecureBoostGuest(num_trees=num_trees, max_depth=max_depth, learning_rate=learning_rate, max_bin=max_bin,
l2=l2, min_impurity_split=min_impurity_split, min_sample_split=min_sample_split,
min_leaf_node=min_leaf_node, min_child_weight=min_child_weight, encrypt_key_length=encrypt_key_length,
objective=objective, num_class=num_class, gh_pack=gh_pack, split_info_pack=split_info_pack, hist_sub=hist_sub
min_leaf_node=min_leaf_node, min_child_weight=min_child_weight, objective=objective, num_class=num_class,
gh_pack=gh_pack, split_info_pack=split_info_pack, hist_sub=hist_sub
)
if train_model_input is not None:
booster.from_model(train_model_input)
Expand Down
5 changes: 5 additions & 0 deletions python/fate/ml/ensemble/algo/secureboost/hetero/guest.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ def __init__(self, num_trees=3, learning_rate=0.3, max_depth=3, objective='binar
self._hist_sub = hist_sub

# encryption
self._encrypt_kit = None
self._gh_pack = gh_pack
self._split_info_pack = split_info_pack

Expand Down Expand Up @@ -96,6 +97,10 @@ def _check_encrypt_kit(self, ctx: Context):
kit = ctx.cipher.phe.setup()
self._encrypt_kit = kit
# check encrypt function
logger.info('encrypt kit info: can support negative number: {}, can support compress: {}, can support pack: {}'.format(
self._encrypt_kit.can_support_negative_number, self._encrypt_kit.can_support_squeeze, self._encrypt_kit.can_support_pack
))

if not self._encrypt_kit.can_support_negative_number:
self._gh_pack = True
logger.info('current encrypt method cannot support neg num, gh pack is forced to be True')
Expand Down

0 comments on commit f7293b3

Please sign in to comment.