diff --git a/python/fate/arch/histogram/histogram.py b/python/fate/arch/histogram/histogram.py index 34eaf0b5fc..4d9c0bdfed 100644 --- a/python/fate/arch/histogram/histogram.py +++ b/python/fate/arch/histogram/histogram.py @@ -159,7 +159,7 @@ def decode(self, coder, dtype): def unpack(self, coder, pack_num, offset_bit, precision, total_num): return HistogramPlainValues(coder.unpack_floats(self.data, offset_bit, pack_num, precision, total_num), - self.stride) + pack_num) def slice(self, start, end): if hasattr(self.data, "slice"): @@ -418,7 +418,8 @@ def i_squeeze(self, squeeze_map): def i_unpack_decode(self, coder_map): for name, value in self._data.items(): if name in coder_map: - coder, pack_num, offset_bit, precision, total_num = coder_map[name] + coder, pack_num, offset_bit, precision = coder_map[name] + total_num = (self.end - self.start) * self.num_node * pack_num self._data[name] = value.unpack(coder, pack_num, offset_bit, precision, total_num) return self diff --git a/python/fate/ml/ensemble/learner/decision_tree/hetero/guest.py b/python/fate/ml/ensemble/learner/decision_tree/hetero/guest.py index 441f4f23c8..c4313673bf 100644 --- a/python/fate/ml/ensemble/learner/decision_tree/hetero/guest.py +++ b/python/fate/ml/ensemble/learner/decision_tree/hetero/guest.py @@ -29,12 +29,21 @@ logger = logging.getLogger(__name__) +<<<<<<< HEAD FIX_POINT_PRECISION = 52 +======= +FIX_POINT_PRECISION = 2**52 + +>>>>>>> dev-2.0.0-beta class HeteroDecisionTreeGuest(DecisionTree): def __init__(self, max_depth=3, valid_features=None, use_missing=False, zero_as_missing=False, goss=False, l1=0.1, l2=0, +<<<<<<< HEAD min_impurity_split=1e-2, min_sample_split=2, min_leaf_node=1, min_child_weight=1, gh_pack=True, objective=None): +======= + min_impurity_split=1e-2, min_sample_split=2, min_leaf_node=1, min_child_weight=1, gh_pack=False, objective=None): +>>>>>>> dev-2.0.0-beta super().__init__(max_depth, use_missing=use_missing, zero_as_missing=zero_as_missing, valid_features=valid_features) self.host_sitenames = None @@ -74,7 +83,10 @@ def __init__(self, max_depth=3, valid_features=None, use_missing=False, zero_as_ if gh_pack: if objective is None: raise ValueError('objective must be specified when gh_pack is True') +<<<<<<< HEAD self._pack_info = {} +======= +>>>>>>> dev-2.0.0-beta def set_encrypt_kit(self, kit): @@ -147,6 +159,7 @@ def _g_h_process(self, grad_and_hess: DataFrame): en_grad_hess = grad_and_hess.create_frame() +<<<<<<< HEAD def make_long_tensor(s: pd.Series, coder, pk, offset, shift_bit, precision, encryptor, pack_num=2): pack_tensor = t.Tensor(s.values) pack_tensor[0] = pack_tensor[0] + offset @@ -158,6 +171,17 @@ def make_long_tensor(s: pd.Series, coder, pk, offset, shift_bit, precision, encr def compute_offset_bit(sample_num, g_max, h_max): g_bit = int(np.log2(2**FIX_POINT_PRECISION * sample_num * g_max) + 1) # add 1 more bit for safety h_bit = int(np.log2(2**FIX_POINT_PRECISION * sample_num * h_max) + 1) +======= + def make_long_tensor(s: pd.Series, coder, pk, encryptor, offset=0, pack_num=2, shift_bit=52): + gh = t.LongTensor([int((s['g']+offset)*FIX_POINT_PRECISION), int(s['h']*FIX_POINT_PRECISION)]) + pack_vec = coder.pack_vec(gh, num_shift_bit=shift_bit, num_elem_each_pack=pack_num) + en = pk.encrypt_encoded(pack_vec, obfuscate=True) + return encryptor.lift(en, (len(en), 1), t.long, gh.device) + + def compute_offset_bit(sample_num, g_max, h_max): + g_bit = int(np.log2(FIX_POINT_PRECISION * sample_num * g_max) + 1) # add 1 more bit for safety + h_bit = int(np.log2(FIX_POINT_PRECISION * sample_num * h_max) + 1) +>>>>>>> dev-2.0.0-beta return max(g_bit, h_bit) if self._gh_pack: @@ -174,6 +198,7 @@ def compute_offset_bit(sample_num, g_max, h_max): self._g_abs_max = abs(float(grad_and_hess['g'].max()['g'])) + self._g_offset self._h_abs_max = 2 +<<<<<<< HEAD pack_num, total_num = 2, 2 shift_bit = compute_offset_bit(len(grad_and_hess), self._g_abs_max, self._h_abs_max) partial_func = functools.partial(make_long_tensor, coder=self._coder, offset=self._g_offset, pk=self._pk, @@ -185,6 +210,14 @@ def compute_offset_bit(sample_num, g_max, h_max): self._pack_info['precision'] = FIX_POINT_PRECISION self._pack_info['pack_num'] = pack_num self._pack_info['total_num'] = total_num +======= + shift_bit = compute_offset_bit(len(grad_and_hess), self._g_abs_max, self._h_abs_max) + + partial_func = functools.partial(make_long_tensor, coder=self._coder, offset=self._g_offset, pk=self._pk, + shift_bit=shift_bit, pack_num=2, encryptor=self._encryptor) + + en_grad_hess['gh'] = grad_and_hess.apply_row(partial_func) +>>>>>>> dev-2.0.0-beta else: en_grad_hess['g'] = self._encryptor.encrypt_tensor(grad_and_hess['g'].as_tensor()) en_grad_hess['h'] = self._encryptor.encrypt_tensor(grad_and_hess['h'].as_tensor()) @@ -266,7 +299,11 @@ def booster_fit(self, ctx: Context, bin_train_data: DataFrame, grad_and_hess: Da # compute histogram hist_inst, statistic_result = self.hist_builder.compute_hist(sub_ctx, cur_layer_node, train_df, grad_and_hess, sample_pos, node_map) # compute best splits +<<<<<<< HEAD split_info = self.splitter.split(sub_ctx, statistic_result, cur_layer_node, node_map, self._sk, self._coder, self._gh_pack, self._pack_info) +======= + split_info = self.splitter.split(sub_ctx, statistic_result, cur_layer_node, node_map, self._sk, self._coder, self._gh_pack) +>>>>>>> dev-2.0.0-beta # update tree with best splits next_layer_nodes = self._update_tree(sub_ctx, cur_layer_node, split_info, train_df) # update feature importance diff --git a/python/fate/ml/ensemble/learner/decision_tree/tree_core/splitter.py b/python/fate/ml/ensemble/learner/decision_tree/tree_core/splitter.py index 7f5750c76d..b744214b14 100644 --- a/python/fate/ml/ensemble/learner/decision_tree/tree_core/splitter.py +++ b/python/fate/ml/ensemble/learner/decision_tree/tree_core/splitter.py @@ -481,6 +481,7 @@ def _merge_splits(self, guest_splits, host_splits_list): return splits +<<<<<<< HEAD def _recover_pack_split(self, hist: ShuffledHistogram, schema, decode_schema=None): if decode_schema is not None: @@ -492,6 +493,14 @@ def _recover_pack_split(self, hist: ShuffledHistogram, schema, decode_schema=Non return host_hist def _guest_split(self, ctx: Context, stat_rs, cur_layer_node, node_map, sk, coder, gh_pack, pack_info): +======= + def _recover_pack_split(self, hist, schema): + print('schema is {}'.format(schema)) + host_hist = hist.decrypt(schema[0], schema[1]) + return host_hist + + def _guest_split(self, ctx: Context, stat_rs, cur_layer_node, node_map, sk, coder, gh_pack): +>>>>>>> dev-2.0.0-beta if sk is None or coder is None: raise ValueError('sk or coder is None, not able to decode host split points') @@ -508,6 +517,7 @@ def _guest_split(self, ctx: Context, stat_rs, cur_layer_node, node_map, sk, code host_splits = [] if gh_pack: decrypt_schema = ({"gh":sk}, {"gh": (coder, torch.int64)}) +<<<<<<< HEAD # (coder, pack_num, offset_bit, precision, total_num) if pack_info is not None: decode_schema = {"gh": (coder, pack_info['pack_num'], pack_info['shift_bit'], pack_info['precision'])} @@ -523,6 +533,14 @@ def _guest_split(self, ctx: Context, stat_rs, cur_layer_node, node_map, sk, code print('host_hist is ', host_hist) # coder.unpack_floats(host_hist, ) raise ValueError('cwj debug') +======= + else: + decrypt_schema = ({"g":sk, "h":sk}, {"g": (coder, torch.float32), "h": (coder, torch.float32)}) + + for idx, hist in enumerate(host_histograms): + host_sitename = ctx.hosts[idx].party[0] + '_' + ctx.hosts[idx].party[1] + host_hist = self._recover_pack_split(hist, decrypt_schema) +>>>>>>> dev-2.0.0-beta logger.debug('splitting host') host_split = self._find_best_splits(host_hist, host_sitename, cur_layer_node, reverse_node_map, recover_bucket=False) host_splits.append(host_split) @@ -536,13 +554,21 @@ def _guest_split(self, ctx: Context, stat_rs, cur_layer_node, node_map, sk, code def _host_split(self, ctx: Context, en_histogram, cur_layer_node): ctx.guest.put('hist', en_histogram) +<<<<<<< HEAD def split(self, ctx: Context, histogram_statistic_result, cur_layer_node, node_map, sk=None, coder=None, gh_pack=None, pack_info=None): +======= + def split(self, ctx: Context, histogram_statistic_result, cur_layer_node, node_map, sk=None, coder=None, gh_pack=None): +>>>>>>> dev-2.0.0-beta if ctx.is_on_guest: if sk is None or coder is None: raise ValueError('sk or coder is None, not able to decode host split points') assert gh_pack is not None and isinstance(gh_pack, bool), 'gh_pack should be bool, indicating if the gh is packed' +<<<<<<< HEAD return self._guest_split(ctx, histogram_statistic_result, cur_layer_node, node_map, sk, coder, gh_pack, pack_info) +======= + return self._guest_split(ctx, histogram_statistic_result, cur_layer_node, node_map, sk, coder, gh_pack) +>>>>>>> dev-2.0.0-beta elif ctx.is_on_host: return self._host_split(ctx, histogram_statistic_result, cur_layer_node) else: