Add split points compression for classification & regression
Signed-off-by: cwj <talkingwallace@sohu.com>
talkingwallace committed Aug 29, 2023
1 parent 528bb0b commit e5dbc13
Showing 5 changed files with 25 additions and 13 deletions.
4 changes: 2 additions & 2 deletions python/fate/arch/context/_cipher.py
@@ -63,7 +63,7 @@ def setup(self, options):
         from fate.arch.protocol.phe.paillier import evaluator, keygen
         from fate.arch.tensor.phe import PHETensorCipher

-        sk, pk, coder = keygen(key_length)
+        sk, pk, coder = keygen(key_size)
         tensor_cipher = PHETensorCipher.from_raw_cipher(pk, coder, sk, evaluator)
         return PHECipher(key_size, pk, sk, evaluator, coder, tensor_cipher)
@@ -95,7 +95,7 @@ def __init__(self, key_size, pk, sk, evaluator, coder, tensor_cipher) -> None:
        self._tensor_cipher = tensor_cipher

    @property
-   def key_size():
+   def key_size(self):
        return self._key_size

    def get_tensor_encryptor(self):
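Both _cipher.py fixes are name fixes: keygen now receives key_size, the parameter that actually exists in the enclosing scope, and the key_size property gains the self argument every instance property getter needs. A minimal standalone sketch of the property bug (CipherDemo is a hypothetical class, not the FATE API):

    # Hypothetical repro: a @property getter defined without `self` raises
    # TypeError on attribute access, because Python passes the instance anyway.
    class CipherDemo:
        def __init__(self, key_size):
            self._key_size = key_size

        @property
        def key_size(self):  # written as `def key_size():`, demo.key_size raises
            return self._key_size

    demo = CipherDemo(1024)
    assert demo.key_size == 1024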
2 changes: 1 addition & 1 deletion python/fate/arch/histogram/histogram.py
@@ -113,7 +113,7 @@ def decrypt(self, sk):
        return HistogramEncodedValues(data, self.stride)

    def squeeze(self, pack_num, offset_bit):
-       data = self.evaluator.squeeze(pack_num, offset_bit, self.pk)
+       data = self.evaluator.pack_squeeze(self.data, pack_num, offset_bit, self.pk)
        return HistogramEncryptedValues(self.pk, self.evaluator, data, self.stride)

    def i_chunking_cumsum(self, chunk_sizes: typing.List[int]):
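The corrected call makes the compression explicit: pack_squeeze takes the histogram's own ciphertext data and packs several fixed-width values into fewer ciphertexts. A plaintext analogue of the packing arithmetic (illustrative only; the real evaluator.pack_squeeze operates on PHE ciphertexts):

    # Plaintext analogue: each value occupies offset_bit bits of one big
    # integer, so pack_num values travel as a single "ciphertext".
    def pack(values, offset_bit):
        packed = 0
        for v in values:
            packed = (packed << offset_bit) | v
        return packed

    def unpack(packed, pack_num, offset_bit):
        mask = (1 << offset_bit) - 1
        fields = [(packed >> (i * offset_bit)) & mask for i in range(pack_num)]
        return fields[::-1]  # restore original order

    assert unpack(pack([3, 5, 7], 16), pack_num=3, offset_bit=16) == [3, 5, 7]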
@@ -38,8 +38,8 @@ def create_ctx(local):
 if __name__ == '__main__':

     party = sys.argv[1]
-    max_depth = 2
-    num_tree = 1
+    max_depth = 3
+    num_tree = 3
     from sklearn.metrics import roc_auc_score as auc
     if party == 'guest':
@@ -96,11 +96,11 @@ def create_ctx(local):
         # load tree
         # tree_dict = pickle.load(open('host_tree.pkl', 'rb'))
         # trees.from_model(tree_dict)
-        trees.predict(ctx, data_host)
+        # trees.predict(ctx, data_host)

         # fit again
-        new_tree = HeteroSecureBoostHost(1, max_depth=3)
-        new_tree.from_model(trees.get_model())
-        new_tree.fit(ctx, data_host)
+        # new_tree = HeteroSecureBoostHost(1, max_depth=3)
+        # new_tree.from_model(trees.get_model())
+        # new_tree.fit(ctx, data_host)


11 changes: 9 additions & 2 deletions python/fate/ml/ensemble/learner/decision_tree/hetero/guest.py
@@ -25,6 +25,7 @@
 import pandas as pd
 import torch as t
 import numpy as np
+import math


 logger = logging.getLogger(__name__)
@@ -159,8 +160,8 @@ def make_long_tensor(s: pd.Series, coder, pk, offset, shift_bit, precision, encr
            return ret

        def compute_offset_bit(sample_num, g_max, h_max):
-           g_bit = int(np.log2(2**FIX_POINT_PRECISION * sample_num * g_max) + 1)  # add 1 more bit for safety
-           h_bit = int(np.log2(2**FIX_POINT_PRECISION * sample_num * h_max) + 1)
+           g_bit = int(math.log2(2**FIX_POINT_PRECISION * sample_num * g_max) + 1)  # add 1 more bit for safety
+           h_bit = int(math.log2(2**FIX_POINT_PRECISION * sample_num * h_max) + 1)
            return max(g_bit, h_bit)

        if self._gh_pack:
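Switching np.log2 to math.log2 keeps this scalar bit-width computation in plain Python; unlike np.log2, math.log2 also accepts arbitrary-precision integers, which 2**FIX_POINT_PRECISION * sample_num can easily become. A worked example of the computation (FIX_POINT_PRECISION = 52 is an assumed illustrative value, not necessarily the constant defined in guest.py):

    import math

    FIX_POINT_PRECISION = 52  # assumed for illustration
    sample_num, g_max, h_max = 10_000, 2.0, 1.0

    # worst-case bit width of a summed fixed-point value, plus 1 safety bit
    g_bit = int(math.log2(2**FIX_POINT_PRECISION * sample_num * g_max) + 1)
    h_bit = int(math.log2(2**FIX_POINT_PRECISION * sample_num * h_max) + 1)

    print(g_bit, h_bit, max(g_bit, h_bit))  # 67 66 67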
@@ -190,6 +191,7 @@ def compute_offset_bit(sample_num, g_max, h_max):
            self._pack_info['precision'] = FIX_POINT_PRECISION
            self._pack_info['pack_num'] = pack_num
            self._pack_info['total_pack_num'] = total_pack_num
+           self._pack_info['split_point_shift_bit'] = shift_bit * pack_num
        else:
            logger.info('not using gh pack')
            en_grad_hess['g'] = self._encryptor.encrypt_tensor(grad_and_hess['g'].as_tensor())
@@ -245,6 +247,11 @@ def booster_fit(self, ctx: Context, bin_train_data: DataFrame, grad_and_hess: Da
        # Send Encrypted Grad and Hess
        self._send_gh(ctx, grad_and_hess)

+       # send pack info
+       send_pack_info = {'total_pack_num': self._pack_info['total_pack_num'], 'split_point_shift_bit': self._pack_info['split_point_shift_bit']} \
+           if self._gh_pack else {}
+       ctx.hosts.put('pack_info', send_pack_info)
+
        # init histogram builder
        self.hist_builder = SBTHistogramBuilder(bin_train_data, binning_dict, None)
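With the packing layout decided at the guest, the host only needs two integers, total_pack_num and split_point_shift_bit, to squeeze its encrypted histogram consistently. A minimal sketch of the handshake using a stand-in transport (the dict-backed Channel below is hypothetical; FATE's real ctx.hosts.put / ctx.guest.get carry the same payload):

    # Stand-in in-process transport (hypothetical; not the FATE API).
    class Channel:
        def __init__(self):
            self._box = {}

        def put(self, name, value):
            self._box[name] = value

        def get(self, name):
            return self._box[name]

    channel = Channel()

    # Guest side: publish packing metadata alongside the encrypted G/H.
    channel.put('pack_info', {'total_pack_num': 2, 'split_point_shift_bit': 134})

    # Host side: fetch it and squeeze the encrypted histogram with the same
    # layout (mirrors en_statistic_result.squeeze({'gh': (...)}) in host.py).
    pack_info = channel.get('pack_info')
    print(pack_info['total_pack_num'], pack_info['split_point_shift_bit'])  # 2 134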
9 changes: 7 additions & 2 deletions python/fate/ml/ensemble/learner/decision_tree/hetero/host.py
@@ -15,6 +15,7 @@
 from fate.ml.ensemble.learner.decision_tree.tree_core.decision_tree import DecisionTree, Node, _get_sample_on_local_nodes, _update_sample_pos, FeatureImportance
 from fate.ml.ensemble.learner.decision_tree.tree_core.hist import SBTHistogramBuilder, DistributedHistogram
 from fate.ml.ensemble.learner.decision_tree.tree_core.splitter import FedSBTSplitter
+from fate.arch.histogram.histogram import ShuffledHistogram
 from fate.arch import Context
 from fate.arch.dataframe import DataFrame
 import numpy as np
@@ -38,6 +39,7 @@ def __init__(self, max_depth=3, valid_features=None, use_missing=False, zero_as_
        self._pk = None
        self._evaluator = None
        self._gh_pack = True
+       self._pack_info = None

    def _convert_split_id(self, ctx: Context, cur_layer_nodes: List[Node], node_map: dict, hist_builder: SBTHistogramBuilder, hist_inst: DistributedHistogram, splitter: FedSBTSplitter, data: DataFrame):

@@ -133,6 +135,7 @@ def booster_fit(self, ctx: Context, bin_train_data: DataFrame, binning_dict: dic
        en_grad_and_hess: DataFrame = ret[0]
        self._gh_pack = ret[1]
        self._pk, self._evaluator = ctx.guest.get('en_kit')
+       self._pack_info = ctx.guest.get('pack_info')
        root_node = self._initialize_root_node(ctx, train_df)

        # init histogram builder
@@ -151,8 +154,10 @@ def booster_fit(self, ctx: Context, bin_train_data: DataFrame, binning_dict: dic
            node_map = {n.nid: idx for idx, n in enumerate(cur_layer_node)}
            # compute histogram with encrypted grad and hess
            logger.info('train_df is {} grad hess is {}, {}, gh pack {}'.format(train_df, en_grad_and_hess, en_grad_and_hess.columns, self._gh_pack))
-           hist_inst, en_statistic_result = self.hist_builder.compute_hist(sub_ctx, cur_layer_node, train_df, en_grad_and_hess, sample_pos, node_map, \
-               pk=self._pk, evaluator=self._evaluator, gh_pack=self._gh_pack)
+           hist_inst, en_statistic_result = self.hist_builder.compute_hist(sub_ctx, cur_layer_node, train_df, en_grad_and_hess, sample_pos, node_map, pk=self._pk, evaluator=self._evaluator, gh_pack=self._gh_pack)
+           if self._gh_pack:
+               print(type(en_statistic_result))
+               en_statistic_result.squeeze({'gh': (self._pack_info['total_pack_num'], self._pack_info['split_point_shift_bit'])})
            self.splitter.split(sub_ctx, en_statistic_result, cur_layer_node, node_map)
            cur_layer_node, next_layer_nodes = self._sync_nodes(sub_ctx)
            self._convert_split_id(sub_ctx, cur_layer_node, node_map, self.hist_builder, hist_inst, self.splitter, train_df)
