Skip to content

Commit

Permalink
Add support for l1 regularization
Browse files Browse the repository at this point in the history
Signed-off-by: weijingchen <talkingwallace@sohu.com>

Signed-off-by: cwj <talkingwallace@sohu.com>
  • Loading branch information
talkingwallace committed Oct 10, 2023
1 parent e334353 commit d1795cf
Show file tree
Hide file tree
Showing 6 changed files with 37 additions and 12 deletions.
6 changes: 3 additions & 3 deletions python/fate/components/components/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,10 +79,10 @@ def homo_lr(self):
return homo_lr

@_lazy_cpn
def hetero_sbt(self):
from .hetero_secureboost import hetero_sbt
def hetero_secureboost(self):
from .hetero_secureboost import hetero_secureboost

return hetero_sbt
return hetero_secureboost

@_lazy_cpn
def dataframe_transformer(self):
Expand Down
5 changes: 3 additions & 2 deletions python/fate/components/components/hetero_secureboost.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,8 @@ def train(
objective: cpn.parameter(type=params.string_choice(choice=[BINARY_BCE, MULTI_CE, REGRESSION_L2]), default=BINARY_BCE, \
desc='objective function, available: {}'.format([BINARY_BCE, MULTI_CE, REGRESSION_L2])),
num_class: cpn.parameter(type=params.conint(gt=0), default=2, desc='class number of multi classification, active when objective is {}'.format(MULTI_CE)),
l2: cpn.parameter(type=params.confloat(gt=0), default=0.1, desc='L2 regularization'),
l1: cpn.parameter(type=params.confloat(ge=0), default=0, desc='L1 regularization'),
l2: cpn.parameter(type=params.confloat(ge=0), default=0.1, desc='L2 regularization'),
min_impurity_split: cpn.parameter(type=params.confloat(gt=0), default=1e-2, desc='min impurity when splitting a tree node'),
min_sample_split: cpn.parameter(type=params.conint(gt=0), default=2, desc='min sample to split a tree node'),
min_leaf_node: cpn.parameter(type=params.conint(gt=0), default=1, desc='mininum sample contained in a leaf node'),
Expand All @@ -73,7 +74,7 @@ def train(
ctx.cipher.set_phe(ctx.device, he_param.dict())

booster = HeteroSecureBoostGuest(num_trees=num_trees, max_depth=max_depth, complete_secure=complete_secure,
learning_rate=learning_rate, max_bin=max_bin,
learning_rate=learning_rate, max_bin=max_bin, l1=l1,
l2=l2, min_impurity_split=min_impurity_split, min_sample_split=min_sample_split,
min_leaf_node=min_leaf_node, min_child_weight=min_child_weight, objective=objective, num_class=num_class,
gh_pack=gh_pack, split_info_pack=split_info_pack, hist_sub=hist_sub
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ def create_ctx(local, context_name):

party = sys.argv[1]
max_depth = 3
num_tree = 5
num_tree = 3

if party == "guest":

Expand All @@ -48,15 +48,17 @@ def create_ctx(local, context_name):
df["sample_id"] = [i for i in range(len(df))]

reader = PandasReader(sample_id_name="sample_id", match_id_name="id", label_name="y", dtype="float32")

data_guest = reader.to_frame(ctx, df)

trees = HeteroSecureBoostGuest(num_tree, max_depth=max_depth)
trees = HeteroSecureBoostGuest(num_tree, max_depth=max_depth, l1=1.0)
trees.fit(ctx, data_guest)
pred = trees.get_train_predict().as_pd_df()

pred_ = trees.predict(ctx, data_guest).as_pd_df()

# compute auc
from sklearn.metrics import roc_auc_score
auc = roc_auc_score(pred["label"], pred["predict_score"])
print(auc)

elif party == "host":

ctx = create_ctx(host, get_current_datetime_str())
Expand Down
3 changes: 3 additions & 0 deletions python/fate/ml/ensemble/learner/decision_tree/hetero/guest.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,9 @@ def __init__(
raise ValueError("objective must be specified when gh_pack is True")
self._pack_info = {}

# param checking
assert l1 >= 0 and l2 >= 0, "l1 and l2 should be non-negative, got l1: {}, l2: {}".format(l1, l2)

def set_encrypt_kit(self, kit):
self._encrypt_kit = kit
self._en_key_length = kit.key_size
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -115,14 +115,33 @@ def get_bucket(self, idx):
def truncate(f, n=TREE_DECIMAL_ROUND):
return np.floor(f * 10 ** n) / 10 ** n

def _l1_reg(self, g):

if self.l1 == 0:
return g
if isinstance(g, torch.Tensor):
g[g < -self.l1] += self.l1
g[g > self.l1] -= self.l1
g[(g <= self.l1) & (g >= -self.l1)] = 0
else:
if g < - self.l1:
return g + self.l1
elif g > self.l1:
return g - self.l1
else:
return 0
return g

def node_gain(self, g, h):
g, h = self.truncate(g), self.truncate(h)
g = self._l1_reg(g)
if isinstance(h, np.ndarray):
h[h == 0] = np.nan
score = g * g / (h + self.l2)
score = (g * g ) / (h + self.l2)
return score

def node_weight(self, sum_grad, sum_hess):
sum_grad = self._l1_reg(sum_grad)
weight = -(sum_grad / (sum_hess + self.l2))
return self.truncate(weight)

Expand Down

0 comments on commit d1795cf

Please sign in to comment.