
Commit f5e3cff

Update codes

Signed-off-by: cwj <talkingwallace@sohu.com>
talkingwallace committed Aug 28, 2023
1 parent 3112729 commit f5e3cff
Showing 2 changed files with 43 additions and 22 deletions.
35 changes: 21 additions & 14 deletions python/fate/ml/ensemble/learner/decision_tree/hetero/guest.py
@@ -29,13 +29,12 @@

logger = logging.getLogger(__name__)

- FIX_POINT_PRECISION = 2**52
-
+ FIX_POINT_PRECISION = 52

class HeteroDecisionTreeGuest(DecisionTree):

def __init__(self, max_depth=3, valid_features=None, use_missing=False, zero_as_missing=False, goss=False, l1=0.1, l2=0,
-              min_impurity_split=1e-2, min_sample_split=2, min_leaf_node=1, min_child_weight=1, gh_pack=False, objective=None):
+              min_impurity_split=1e-2, min_sample_split=2, min_leaf_node=1, min_child_weight=1, gh_pack=True, objective=None):

super().__init__(max_depth, use_missing=use_missing, zero_as_missing=zero_as_missing, valid_features=valid_features)
self.host_sitenames = None
@@ -75,6 +74,7 @@ def __init__(self, max_depth=3, valid_features=None, use_missing=False, zero_as_
if gh_pack:
if objective is None:
raise ValueError('objective must be specified when gh_pack is True')
+ self._pack_info = {}


def set_encrypt_kit(self, kit):
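With this hunk, gh_pack defaults to True, so the objective guard above now fires in the default configuration. A minimal usage sketch (hypothetical call; the objective string is an assumed placeholder, not taken from this commit):

    # gh_pack now defaults to True, so an objective must be supplied up front,
    # otherwise __init__ raises ValueError as guarded above
    tree = HeteroDecisionTreeGuest(max_depth=3, objective='binary:bce')  # objective name assumed

    # opting out of packing restores the previous two-ciphertext behaviour
    tree_no_pack = HeteroDecisionTreeGuest(max_depth=3, gh_pack=False)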
Expand Down Expand Up @@ -147,15 +147,17 @@ def _g_h_process(self, grad_and_hess: DataFrame):

en_grad_hess = grad_and_hess.create_frame()

- def make_long_tensor(s: pd.Series, coder, pk, encryptor, offset=0, pack_num=2, shift_bit=52):
-     gh = t.LongTensor([int((s['g']+offset)*FIX_POINT_PRECISION), int(s['h']*FIX_POINT_PRECISION)])
-     pack_vec = coder.pack_vec(gh, num_shift_bit=shift_bit, num_elem_each_pack=pack_num)
+ def make_long_tensor(s: pd.Series, coder, pk, offset, shift_bit, precision, encryptor, pack_num=2):
+     pack_tensor = t.Tensor(s.values)
+     pack_tensor[0] = pack_tensor[0] + offset
+     pack_vec = coder.pack_floats(pack_tensor, shift_bit, pack_num, precision)
      en = pk.encrypt_encoded(pack_vec, obfuscate=True)
-     return encryptor.lift(en, (len(en), 1), t.long, gh.device)
+     ret = encryptor.lift(en, (len(en), 1), pack_tensor.dtype, pack_tensor.device)
+     return ret

def compute_offset_bit(sample_num, g_max, h_max):
-     g_bit = int(np.log2(FIX_POINT_PRECISION * sample_num * g_max) + 1)  # add 1 more bit for safety
-     h_bit = int(np.log2(FIX_POINT_PRECISION * sample_num * h_max) + 1)
+     g_bit = int(np.log2(2**FIX_POINT_PRECISION * sample_num * g_max) + 1)  # add 1 more bit for safety
+     h_bit = int(np.log2(2**FIX_POINT_PRECISION * sample_num * h_max) + 1)
return max(g_bit, h_bit)

if self._gh_pack:
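For orientation, a minimal standalone sketch (plain Python, not FATE's coder API) of the fixed-point scheme these helpers implement: a float is offset into the non-negative range and scaled by 2**FIX_POINT_PRECISION, and compute_offset_bit widens each pack slot so that a sum over sample_num rows cannot overflow into the neighbouring slot. All names and example numbers are illustrative:

    import math

    PRECISION = 52  # mirrors FIX_POINT_PRECISION (now a bit count, not 2**52)

    def encode(x: float, offset: float = 0.0) -> int:
        # fixed-point encode: shift non-negative, scale, truncate to int
        return int((x + offset) * 2 ** PRECISION)

    def slot_bits(sample_num: int, g_max: float, h_max: float) -> int:
        # mirrors compute_offset_bit: a slot must hold the sum of up to
        # sample_num encoded values, plus one extra bit for safety
        g_bit = int(math.log2(2 ** PRECISION * sample_num * g_max) + 1)
        h_bit = int(math.log2(2 ** PRECISION * sample_num * h_max) + 1)
        return max(g_bit, h_bit)

    # 10,000 samples with g_max = h_max = 2.0:
    # 52 fractional bits + floor(log2(10000 * 2)) + 1 = 67 bits per slot
    print(slot_bits(10_000, 2.0, 2.0))  # -> 67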
@@ -172,12 +174,17 @@ def compute_offset_bit(sample_num, g_max, h_max):
self._g_abs_max = abs(float(grad_and_hess['g'].max()['g'])) + self._g_offset
self._h_abs_max = 2

+ pack_num, total_num = 2, 2
shift_bit = compute_offset_bit(len(grad_and_hess), self._g_abs_max, self._h_abs_max)

partial_func = functools.partial(make_long_tensor, coder=self._coder, offset=self._g_offset, pk=self._pk,
-                                  shift_bit=shift_bit, pack_num=2, encryptor=self._encryptor)
-
+                                  shift_bit=shift_bit, pack_num=2, precision=FIX_POINT_PRECISION, encryptor=self._encryptor)
en_grad_hess['gh'] = grad_and_hess.apply_row(partial_func)

+ # record pack info
+ self._pack_info['shift_bit'] = shift_bit
+ self._pack_info['precision'] = FIX_POINT_PRECISION
+ self._pack_info['pack_num'] = pack_num
+ self._pack_info['total_num'] = total_num
else:
en_grad_hess['g'] = self._encryptor.encrypt_tensor(grad_and_hess['g'].as_tensor())
en_grad_hess['h'] = self._encryptor.encrypt_tensor(grad_and_hess['h'].as_tensor())
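The payoff of packing is that one additively homomorphic ciphertext then carries both statistics: adding packed values adds the g and h slots in parallel, so histogram aggregation needs one ciphertext addition per row instead of two. A toy round trip, with plain integers standing in for Paillier ciphertexts and the slot layout assumed from the code above:

    shift, precision, g_offset = 67, 52, 1.0  # as recorded in self._pack_info

    def pack(g: float, h: float) -> int:
        g_enc = int((g + g_offset) * 2 ** precision)  # high slot
        h_enc = int(h * 2 ** precision)               # low slot
        return (g_enc << shift) + h_enc

    # bucket aggregation = plain addition of packed values (the homomorphic add)
    rows = [(0.5, 1.0), (-0.25, 1.0), (1.0, 0.5)]
    bucket = sum(pack(g, h) for g, h in rows)

    # decode: split the slots, rescale, subtract one g_offset per aggregated row
    g_sum = (bucket >> shift) / 2 ** precision - len(rows) * g_offset
    h_sum = (bucket % (1 << shift)) / 2 ** precision
    assert abs(g_sum - 1.25) < 1e-9 and abs(h_sum - 2.5) < 1e-9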
@@ -259,7 +266,7 @@ def booster_fit(self, ctx: Context, bin_train_data: DataFrame, grad_and_hess: Da
# compute histogram
hist_inst, statistic_result = self.hist_builder.compute_hist(sub_ctx, cur_layer_node, train_df, grad_and_hess, sample_pos, node_map)
# compute best splits
- split_info = self.splitter.split(sub_ctx, statistic_result, cur_layer_node, node_map, self._sk, self._coder, self._gh_pack)
+ split_info = self.splitter.split(sub_ctx, statistic_result, cur_layer_node, node_map, self._sk, self._coder, self._gh_pack, self._pack_info)
# update tree with best splits
next_layer_nodes = self._update_tree(sub_ctx, cur_layer_node, split_info, train_df)
# update feature importance
Expand Down Expand Up @@ -306,4 +313,4 @@ def get_hyper_param(self):
@staticmethod
def from_model(model_dict):
return HeteroDecisionTreeGuest._from_model(model_dict, HeteroDecisionTreeGuest)


30 changes: 22 additions & 8 deletions python/fate/ml/ensemble/learner/decision_tree/tree_core/splitter.py
@@ -18,6 +18,7 @@
import logging
from fate.arch.dataframe import DataFrame
from fate.arch import Context
+ from fate.arch.histogram.histogram import ShuffledHistogram


logger = logging.getLogger(__name__)
@@ -480,12 +481,16 @@ def _merge_splits(self, guest_splits, host_splits_list):

return splits

- def _recover_pack_split(self, hist, schema):
-     print('schema is {}'.format(schema))
-     host_hist = hist.decrypt(schema[0], schema[1])
+ def _recover_pack_split(self, hist: ShuffledHistogram, schema, decode_schema=None):
+
+     if decode_schema is not None:
+         host_hist = hist.decrypt_(schema[0])
+         host_hist = host_hist.unpack_decode(decode_schema)
+     else:
+         host_hist = hist.decrypt(schema[0], schema[1])
      return host_hist

- def _guest_split(self, ctx: Context, stat_rs, cur_layer_node, node_map, sk, coder, gh_pack):
+ def _guest_split(self, ctx: Context, stat_rs, cur_layer_node, node_map, sk, coder, gh_pack, pack_info):

if sk is None or coder is None:
raise ValueError('sk or coder is None, not able to decode host split points')
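Below, the guest rebuilds a decode schema from the pack info recorded at encryption time; per the inline comment the tuple is (coder, pack_num, offset_bit, precision, total_num). A hypothetical sketch of the unpacking this drives (FATE's actual unpack_decode lives on the histogram; the correction of count * g_offset on the g slot is omitted here):

    def unpack(value: int, pack_num: int, offset_bit: int, precision: int) -> list:
        # split `value` into pack_num slots of offset_bit bits (low slot first),
        # then rescale each slot out of fixed point
        slots = []
        for _ in range(pack_num):
            slots.append((value % (1 << offset_bit)) / 2 ** precision)
            value >>= offset_bit
        return list(reversed(slots))  # g was packed into the high slot

    packed = (int(1.25 * 2 ** 52) << 67) + int(2.5 * 2 ** 52)
    print(unpack(packed, pack_num=2, offset_bit=67, precision=52))  # [1.25, 2.5]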
@@ -502,12 +507,21 @@ def _guest_split(self, ctx: Context, stat_rs, cur_layer_node, node_map, sk, code
host_splits = []
if gh_pack:
decrypt_schema = ({"gh":sk}, {"gh": (coder, torch.int64)})
+ # (coder, pack_num, offset_bit, precision, total_num)
+ if pack_info is not None:
+     decode_schema = {"gh": (coder, pack_info['pack_num'], pack_info['shift_bit'], pack_info['precision'], pack_info['total_num'])}
+ else:
+     raise ValueError('pack info is not provided')
else:
decrypt_schema = ({"g":sk, "h":sk}, {"g": (coder, torch.float32), "h": (coder, torch.float32)})
+ decode_schema = None

for idx, hist in enumerate(host_histograms):
host_sitename = ctx.hosts[idx].party[0] + '_' + ctx.hosts[idx].party[1]
- host_hist = self._recover_pack_split(hist, decrypt_schema)
+ host_hist = self._recover_pack_split(hist, decrypt_schema, decode_schema)
logger.debug('splitting host')
host_split = self._find_best_splits(host_hist, host_sitename, cur_layer_node, reverse_node_map, recover_bucket=False)
host_splits.append(host_split)
@@ -521,14 +535,14 @@ def _host_split(self, ctx: Context, en_histogram, cur_layer_node):
def _host_split(self, ctx: Context, en_histogram, cur_layer_node):
ctx.guest.put('hist', en_histogram)

- def split(self, ctx: Context, histogram_statistic_result, cur_layer_node, node_map, sk=None, coder=None, gh_pack=None):
+ def split(self, ctx: Context, histogram_statistic_result, cur_layer_node, node_map, sk=None, coder=None, gh_pack=None, pack_info=None):

if ctx.is_on_guest:
if sk is None or coder is None:
raise ValueError('sk or coder is None, not able to decode host split points')
assert gh_pack is not None and isinstance(gh_pack, bool), 'gh_pack should be bool, indicating if the gh is packed'
- return self._guest_split(ctx, histogram_statistic_result, cur_layer_node, node_map, sk, coder, gh_pack)
+ return self._guest_split(ctx, histogram_statistic_result, cur_layer_node, node_map, sk, coder, gh_pack, pack_info)
elif ctx.is_on_host:
return self._host_split(ctx, histogram_statistic_result, cur_layer_node)
else:
- raise ValueError('illegal role {}'.format(ctx.role))
+ raise ValueError('illegal role {}'.format(ctx.role))
\ No newline at end of file
