support he param
Signed-off-by: cwj <talkingwallace@sohu.com>
talkingwallace committed Sep 8, 2023
1 parent f27c4ae commit 0cbfb1a
Showing 3 changed files with 33 additions and 14 deletions.
14 changes: 10 additions & 4 deletions python/fate/components/components/hetero_sbt.py
@@ -16,9 +16,8 @@
 import logging
 
 from fate.arch import Context
-from fate.arch.dataframe import DataFrame
-from fate.components.components.utils import consts, tools
-from fate.components.core import ARBITER, GUEST, HOST, Role, cpn, params
+from fate.components.components.utils import consts
+from fate.components.core import GUEST, HOST, Role, cpn, params
 from fate.ml.ensemble import HeteroSecureBoostGuest, HeteroSecureBoostHost, BINARY_BCE, MULTI_CE, REGRESSION_L2
 from fate.components.components.utils.tools import add_dataset_type
 from fate.components.components.utils import consts
@@ -55,6 +54,7 @@ def train(
     gh_pack: cpn.parameter(type=bool, default=True, desc='whether to pack gradient and hessian together'),
     split_info_pack: cpn.parameter(type=bool, default=True, desc='for host side, whether to pack split info together'),
     hist_sub: cpn.parameter(type=bool, default=True, desc='whether to use histogram subtraction'),
+    he_param: cpn.parameter(type=params.he_param, default=params.HEParam(kind='paillier', key_length=1024), desc='homomorphic encryption param, support paillier, ou and mock in current version'),
     train_data_output: cpn.dataframe_output(roles=[GUEST, HOST], optional=True),
     train_model_output: cpn.json_model_output(roles=[GUEST, HOST], optional=True),
     train_model_input: cpn.json_model_input(roles=[GUEST, HOST], optional=True)
@@ -69,10 +69,16 @@ def train(

     if role.is_guest:
 
+        # initialize encrypt kit
+        option = {"kind": he_param.kind, "key_length": he_param.key_length}
+        en_kit = ctx.cipher.phe.setup(options=option)
+
         booster = HeteroSecureBoostGuest(num_trees=num_trees, max_depth=max_depth, learning_rate=learning_rate, max_bin=max_bin,
                                          l2=l2, min_impurity_split=min_impurity_split, min_sample_split=min_sample_split,
                                          min_leaf_node=min_leaf_node, min_child_weight=min_child_weight, encrypt_key_length=encrypt_key_length,
-                                         objective=objective, num_class=num_class, gh_pack=gh_pack, split_info_pack=split_info_pack, hist_sub=hist_sub)
+                                         objective=objective, num_class=num_class, gh_pack=gh_pack, split_info_pack=split_info_pack, hist_sub=hist_sub,
+                                         encrypt_kit=en_kit
+                                         )
         if train_model_input is not None:
             booster.from_model(train_model_input)
             logger.info('sbt input model loaded, will start warmstarting')
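To make the data flow above easier to follow, here is a minimal sketch (not part of the commit) of how the new `he_param` is meant to reach the guest-side booster: the component parameter carries the cipher kind and key length, a PHE kit is built from it once via `ctx.cipher.phe.setup`, and the kit is passed through the new `encrypt_kit` argument. `ctx` is assumed to be an already-initialized `fate.arch.Context`; everything else mirrors the diff.

```python
from fate.components.core import params
from fate.ml.ensemble import HeteroSecureBoostGuest

# he_param defaults to paillier with a 1024-bit key in the component signature above;
# 'paillier', 'ou' and 'mock' are the kinds named in its description.
he_param = params.HEParam(kind="paillier", key_length=1024)

# the same mapping the guest branch performs: HEParam -> options dict -> PHE kit
option = {"kind": he_param.kind, "key_length": he_param.key_length}
en_kit = ctx.cipher.phe.setup(options=option)  # ctx: an initialized Context (assumed)

# the kit, not a raw key length, is now what the booster receives
booster = HeteroSecureBoostGuest(num_trees=3, objective="binary:bce", encrypt_kit=en_kit)
```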
31 changes: 22 additions & 9 deletions python/fate/ml/ensemble/algo/secureboost/hetero/guest.py
@@ -31,10 +31,10 @@

 class HeteroSecureBoostGuest(HeteroBoostingTree):
 
-    def __init__(self, num_trees=3, learning_rate=0.3, max_depth=3, objective='binary:bce', num_class=3,
-                 max_bin=32, encrypt_key_length=2048, l2=0.1, l1=0, min_impurity_split=1e-2, min_sample_split=2, min_leaf_node=1, min_child_weight=1, gh_pack=True, split_info_pack=True,
-                 hist_sub=True
-                 ) -> None:
+    def __init__(self, num_trees=3, learning_rate=0.3, max_depth=3, objective='binary:bce', num_class=3, max_bin=32,
+                 l2=0.1, l1=0, min_impurity_split=1e-2, min_sample_split=2, min_leaf_node=1, min_child_weight=1, gh_pack=True, split_info_pack=True,
+                 hist_sub=True, encrypt_kit=None
+                 ):
 
         super().__init__()
         self.num_trees = num_trees
@@ -60,8 +60,7 @@ def __init__(self, num_trees=3, learning_rate=0.3, max_depth=3, objective='binar
         self._hist_sub = hist_sub
 
         # encryption
-        self._encrypt_kit = None
-        self._encrypt_key_length = encrypt_key_length
+        self._encrypt_kit = encrypt_kit
         self._gh_pack = gh_pack
         self._split_info_pack = split_info_pack

@@ -91,8 +90,22 @@ def _compute_gh(self, data: DataFrame, scores: DataFrame, loss_func):
         loss_func.compute_hess(gh, label, predict)
         return gh
 
-    def _init_encrypt_kit(self, ctx):
-        kit = ctx.cipher.phe.setup(options={"kind": "paillier", "key_length": self._encrypt_key_length})
+    def _check_encrypt_kit(self, ctx: Context):
+        if self._encrypt_kit is None:
+            # make sure cipher is initialized; fall back to default paillier settings
+            self._encrypt_kit = ctx.cipher.phe.setup(options={"kind": "paillier", "key_length": 1024})
+        kit = self._encrypt_kit
+        # check the kit's capabilities and adjust packing options accordingly
+        if not kit.can_support_negative_number:
+            self._gh_pack = True
+            logger.info('current encrypt method cannot support negative numbers, gh_pack is forced to True')
+        if not kit.can_support_squeeze:
+            self._split_info_pack = False
+            logger.info('current encrypt method cannot support squeeze, split_info_pack is forced to False')
+        if not kit.can_support_pack:
+            self._gh_pack = False
+            self._split_info_pack = False
+            logger.info('current encrypt method cannot support pack, gh_pack and split_info_pack are forced to False')
         return kit
 
     def get_train_predict(self):
@@ -172,7 +185,7 @@ def fit(self, ctx: Context, train_data: DataFrame, validate_data: DataFrame = No
         self._init_sample_scores(ctx, label, train_data)
 
         # init encryption kit
-        self._encrypt_kit = self._init_encrypt_kit(ctx)
+        self._encrypt_kit = self._check_encrypt_kit(ctx)
 
         # start tree fitting
         for tree_idx, tree_ctx in ctx.on_iterations.ctxs_range(len(self._trees), len(self._trees)+self.num_trees):
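Taken together, the guest-side changes allow either relying on a built-in default or injecting a caller-supplied cipher. The sketch below (not part of the commit) illustrates both paths; `ctx` and `train_df` are assumed placeholders for an initialized `Context` and a labeled training `DataFrame`, and `'ou'` is simply one of the kinds listed in the `he_param` description.

```python
from fate.ml.ensemble import HeteroSecureBoostGuest

# Path 1: no kit supplied -> _check_encrypt_kit falls back to paillier, key_length=1024.
booster = HeteroSecureBoostGuest(num_trees=3, objective="binary:bce")
booster.fit(ctx, train_df)

# Path 2: an explicit kit built from he_param (here the 'ou' kind).
kit = ctx.cipher.phe.setup(options={"kind": "ou", "key_length": 1024})
booster = HeteroSecureBoostGuest(num_trees=3, objective="binary:bce", encrypt_kit=kit)
# During fit(), _check_encrypt_kit inspects the kit's capability flags:
#   no negative-number support -> gh_pack forced to True
#   no squeeze support         -> split_info_pack forced to False
#   no pack support            -> both packing options forced to False
booster.fit(ctx, train_df)
scores = booster.get_train_predict()
```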
2 changes: 1 addition & 1 deletion (third changed file)
@@ -90,7 +90,7 @@ def set_encrypt_kit(self, kit):
         self._decryptor = kit.get_tensor_decryptor()
         logger.info('encrypt kit setup through setter')
 
-    def _init_encrypt_kit(self, ctx):
+    def _init_encrypt_kit(self, ctx: Context):
         kit = ctx.cipher.phe.setup(options={"kind": "paillier", "key_length": 1024})
         self._en_key_length = kit.key_size
         self._sk, self._pk, self._coder, self._evaluator, self._encryptor = kit.sk, kit.pk, kit.coder, kit.evaluator, kit.get_tensor_encryptor()
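The third changed file only adds a type hint to `_init_encrypt_kit`, but the `set_encrypt_kit` setter visible in its context lines is the hook for supplying a kit built elsewhere (for example from `he_param`) instead of the default paillier setup. A hedged sketch, where `booster` stands for an instance of the class this method belongs to and `ctx`/`he_param` are assumed to exist as in the component code:

```python
# Build the kit from he_param, just as the guest branch of hetero_sbt.py does,
# then hand it over through the setter instead of letting _init_encrypt_kit
# create a default paillier kit.
option = {"kind": he_param.kind, "key_length": he_param.key_length}
kit = ctx.cipher.phe.setup(options=option)
booster.set_encrypt_kit(kit)  # per the diff, this stores the kit's tensor decryptor and logs the setup
```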
