Skip to content

Commit

Permalink
Merge branch 'dev-2.0.0-beta' into dev-2.0.0-gh
Browse files Browse the repository at this point in the history
Signed-off-by: cwj <talkingwallace@sohu.com>
  • Loading branch information
talkingwallace committed Aug 28, 2023
2 parents 1f01c7c + 4c40c13 commit 34dd6ba
Show file tree
Hide file tree
Showing 3 changed files with 66 additions and 2 deletions.
5 changes: 3 additions & 2 deletions python/fate/arch/histogram/histogram.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,7 @@ def decode(self, coder, dtype):

def unpack(self, coder, pack_num, offset_bit, precision, total_num):
    """Unpack this value's packed data into a new ``HistogramPlainValues``.

    All five arguments are forwarded to ``coder.unpack_floats`` together
    with ``self.data``; note that the coder takes ``offset_bit`` *before*
    ``pack_num``, the reverse of this method's parameter order.

    The result is built with ``pack_num`` as its second constructor
    argument instead of ``self.stride``.
    NOTE(review): presumably each packed element expands into ``pack_num``
    floats, making ``pack_num`` the stride of the unpacked data -- confirm
    against ``HistogramPlainValues``.
    """
    unpacked = coder.unpack_floats(self.data, offset_bit, pack_num, precision, total_num)
    return HistogramPlainValues(unpacked, pack_num)

def slice(self, start, end):
if hasattr(self.data, "slice"):
Expand Down Expand Up @@ -418,7 +418,8 @@ def i_squeeze(self, squeeze_map):
def i_unpack_decode(self, coder_map):
    """Unpack, in place, every stored value that has an entry in ``coder_map``.

    Args:
        coder_map: mapping ``name -> (coder, pack_num, offset_bit, precision)``.
            Names in ``self._data`` that are absent from the map are left
            untouched.

    Returns:
        ``self``, so calls can be chained.

    ``total_num`` is no longer supplied by the caller; it is derived from
    this histogram's own extent as
    ``(self.end - self.start) * self.num_node * pack_num``.
    """
    # Only values of existing keys are replaced, so the dict size never
    # changes while it is being iterated.
    for name, value in self._data.items():
        if name not in coder_map:
            continue
        coder, pack_num, offset_bit, precision = coder_map[name]
        total_num = (self.end - self.start) * self.num_node * pack_num
        self._data[name] = value.unpack(coder, pack_num, offset_bit, precision, total_num)
    return self

Expand Down
37 changes: 37 additions & 0 deletions python/fate/ml/ensemble/learner/decision_tree/hetero/guest.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,12 +29,21 @@

logger = logging.getLogger(__name__)

# Fixed-point precision expressed as a number of bits; the corresponding
# scale factor is 2 ** FIX_POINT_PRECISION. (Merge conflict resolved in
# favour of HEAD; the dev-2.0.0-beta side stored the scale factor 2**52
# itself under this name.)
FIX_POINT_PRECISION = 52


class HeteroDecisionTreeGuest(DecisionTree):

def __init__(self, max_depth=3, valid_features=None, use_missing=False, zero_as_missing=False, goss=False, l1=0.1, l2=0,
<<<<<<< HEAD
min_impurity_split=1e-2, min_sample_split=2, min_leaf_node=1, min_child_weight=1, gh_pack=True, objective=None):
=======
min_impurity_split=1e-2, min_sample_split=2, min_leaf_node=1, min_child_weight=1, gh_pack=False, objective=None):
>>>>>>> dev-2.0.0-beta

super().__init__(max_depth, use_missing=use_missing, zero_as_missing=zero_as_missing, valid_features=valid_features)
self.host_sitenames = None
Expand Down Expand Up @@ -74,7 +83,10 @@ def __init__(self, max_depth=3, valid_features=None, use_missing=False, zero_as_
if gh_pack:
if objective is None:
raise ValueError('objective must be specified when gh_pack is True')
<<<<<<< HEAD
self._pack_info = {}
=======
>>>>>>> dev-2.0.0-beta


def set_encrypt_kit(self, kit):
Expand Down Expand Up @@ -147,6 +159,7 @@ def _g_h_process(self, grad_and_hess: DataFrame):

en_grad_hess = grad_and_hess.create_frame()

<<<<<<< HEAD
def make_long_tensor(s: pd.Series, coder, pk, offset, shift_bit, precision, encryptor, pack_num=2):
pack_tensor = t.Tensor(s.values)
pack_tensor[0] = pack_tensor[0] + offset
Expand All @@ -158,6 +171,17 @@ def make_long_tensor(s: pd.Series, coder, pk, offset, shift_bit, precision, encr
def compute_offset_bit(sample_num, g_max, h_max):
g_bit = int(np.log2(2**FIX_POINT_PRECISION * sample_num * g_max) + 1) # add 1 more bit for safety
h_bit = int(np.log2(2**FIX_POINT_PRECISION * sample_num * h_max) + 1)
=======
def make_long_tensor(s: pd.Series, coder, pk, encryptor, offset=0, pack_num=2, shift_bit=52):
gh = t.LongTensor([int((s['g']+offset)*FIX_POINT_PRECISION), int(s['h']*FIX_POINT_PRECISION)])
pack_vec = coder.pack_vec(gh, num_shift_bit=shift_bit, num_elem_each_pack=pack_num)
en = pk.encrypt_encoded(pack_vec, obfuscate=True)
return encryptor.lift(en, (len(en), 1), t.long, gh.device)

def compute_offset_bit(sample_num, g_max, h_max):
g_bit = int(np.log2(FIX_POINT_PRECISION * sample_num * g_max) + 1) # add 1 more bit for safety
h_bit = int(np.log2(FIX_POINT_PRECISION * sample_num * h_max) + 1)
>>>>>>> dev-2.0.0-beta
return max(g_bit, h_bit)

if self._gh_pack:
Expand All @@ -174,6 +198,7 @@ def compute_offset_bit(sample_num, g_max, h_max):
self._g_abs_max = abs(float(grad_and_hess['g'].max()['g'])) + self._g_offset
self._h_abs_max = 2

<<<<<<< HEAD
pack_num, total_num = 2, 2
shift_bit = compute_offset_bit(len(grad_and_hess), self._g_abs_max, self._h_abs_max)
partial_func = functools.partial(make_long_tensor, coder=self._coder, offset=self._g_offset, pk=self._pk,
Expand All @@ -185,6 +210,14 @@ def compute_offset_bit(sample_num, g_max, h_max):
self._pack_info['precision'] = FIX_POINT_PRECISION
self._pack_info['pack_num'] = pack_num
self._pack_info['total_num'] = total_num
=======
shift_bit = compute_offset_bit(len(grad_and_hess), self._g_abs_max, self._h_abs_max)

partial_func = functools.partial(make_long_tensor, coder=self._coder, offset=self._g_offset, pk=self._pk,
shift_bit=shift_bit, pack_num=2, encryptor=self._encryptor)

en_grad_hess['gh'] = grad_and_hess.apply_row(partial_func)
>>>>>>> dev-2.0.0-beta
else:
en_grad_hess['g'] = self._encryptor.encrypt_tensor(grad_and_hess['g'].as_tensor())
en_grad_hess['h'] = self._encryptor.encrypt_tensor(grad_and_hess['h'].as_tensor())
Expand Down Expand Up @@ -266,7 +299,11 @@ def booster_fit(self, ctx: Context, bin_train_data: DataFrame, grad_and_hess: Da
# compute histogram
hist_inst, statistic_result = self.hist_builder.compute_hist(sub_ctx, cur_layer_node, train_df, grad_and_hess, sample_pos, node_map)
# compute best splits
<<<<<<< HEAD
split_info = self.splitter.split(sub_ctx, statistic_result, cur_layer_node, node_map, self._sk, self._coder, self._gh_pack, self._pack_info)
=======
split_info = self.splitter.split(sub_ctx, statistic_result, cur_layer_node, node_map, self._sk, self._coder, self._gh_pack)
>>>>>>> dev-2.0.0-beta
# update tree with best splits
next_layer_nodes = self._update_tree(sub_ctx, cur_layer_node, split_info, train_df)
# update feature importance
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -481,6 +481,7 @@ def _merge_splits(self, guest_splits, host_splits_list):

return splits

<<<<<<< HEAD
def _recover_pack_split(self, hist: ShuffledHistogram, schema, decode_schema=None):

if decode_schema is not None:
Expand All @@ -492,6 +493,14 @@ def _recover_pack_split(self, hist: ShuffledHistogram, schema, decode_schema=Non
return host_hist

def _guest_split(self, ctx: Context, stat_rs, cur_layer_node, node_map, sk, coder, gh_pack, pack_info):
=======
def _recover_pack_split(self, hist, schema):
print('schema is {}'.format(schema))
host_hist = hist.decrypt(schema[0], schema[1])
return host_hist

def _guest_split(self, ctx: Context, stat_rs, cur_layer_node, node_map, sk, coder, gh_pack):
>>>>>>> dev-2.0.0-beta

if sk is None or coder is None:
raise ValueError('sk or coder is None, not able to decode host split points')
Expand All @@ -508,6 +517,7 @@ def _guest_split(self, ctx: Context, stat_rs, cur_layer_node, node_map, sk, code
host_splits = []
if gh_pack:
decrypt_schema = ({"gh":sk}, {"gh": (coder, torch.int64)})
<<<<<<< HEAD
# (coder, pack_num, offset_bit, precision, total_num)
if pack_info is not None:
decode_schema = {"gh": (coder, pack_info['pack_num'], pack_info['shift_bit'], pack_info['precision'])}
Expand All @@ -523,6 +533,14 @@ def _guest_split(self, ctx: Context, stat_rs, cur_layer_node, node_map, sk, code
print('host_hist is ', host_hist)
# coder.unpack_floats(host_hist, )
raise ValueError('cwj debug')
=======
else:
decrypt_schema = ({"g":sk, "h":sk}, {"g": (coder, torch.float32), "h": (coder, torch.float32)})

for idx, hist in enumerate(host_histograms):
host_sitename = ctx.hosts[idx].party[0] + '_' + ctx.hosts[idx].party[1]
host_hist = self._recover_pack_split(hist, decrypt_schema)
>>>>>>> dev-2.0.0-beta
logger.debug('splitting host')
host_split = self._find_best_splits(host_hist, host_sitename, cur_layer_node, reverse_node_map, recover_bucket=False)
host_splits.append(host_split)
Expand All @@ -536,13 +554,21 @@ def _guest_split(self, ctx: Context, stat_rs, cur_layer_node, node_map, sk, code
def _host_split(self, ctx: Context, en_histogram, cur_layer_node):
    """Host-side step: send ``en_histogram`` to the guest under the key 'hist'.

    ``cur_layer_node`` is accepted but not used here. Returns ``None``.
    """
    guest = ctx.guest
    guest.put('hist', en_histogram)

<<<<<<< HEAD
def split(self, ctx: Context, histogram_statistic_result, cur_layer_node, node_map, sk=None, coder=None, gh_pack=None, pack_info=None):
=======
def split(self, ctx: Context, histogram_statistic_result, cur_layer_node, node_map, sk=None, coder=None, gh_pack=None):
>>>>>>> dev-2.0.0-beta

if ctx.is_on_guest:
if sk is None or coder is None:
raise ValueError('sk or coder is None, not able to decode host split points')
assert gh_pack is not None and isinstance(gh_pack, bool), 'gh_pack should be bool, indicating if the gh is packed'
<<<<<<< HEAD
return self._guest_split(ctx, histogram_statistic_result, cur_layer_node, node_map, sk, coder, gh_pack, pack_info)
=======
return self._guest_split(ctx, histogram_statistic_result, cur_layer_node, node_map, sk, coder, gh_pack)
>>>>>>> dev-2.0.0-beta
elif ctx.is_on_host:
return self._host_split(ctx, histogram_statistic_result, cur_layer_node)
else:
Expand Down

0 comments on commit 34dd6ba

Please sign in to comment.