add model param overflow check (#4659)
Signed-off-by: Yu Wu <yolandawu131@gmail.com>
Signed-off-by: weiwee <wbwmat@gmail.com>
nemirorox authored and sagewe committed Aug 9, 2023
1 parent 5806f84 commit b81083f
Showing 4 changed files with 16 additions and 65 deletions.
12 changes: 6 additions & 6 deletions python/fate/components/components/coordinated_lr.py
@@ -54,12 +54,12 @@ def train(
method="sgd", penalty="l2", alpha=1.0, optimizer_params={"lr": 1e-2, "weight_decay": 0}
),
),
-    tol: cpn.parameter(type=params.confloat(ge=0), default=1e-4),
-    early_stop: cpn.parameter(
-        type=params.string_choice(["weight_diff", "diff", "abs"]),
-        default="diff",
-        desc="early stopping criterion, choose from {weight_diff, diff, abs, val_metrics}",
-    ),
+    tol: cpn.parameter(type=params.confloat(ge=0), default=1e-4),
+    early_stop: cpn.parameter(
+        type=params.string_choice(["weight_diff", "diff", "abs"]),
+        default="diff",
+        desc="early stopping criterion, choose from {weight_diff, diff, abs, val_metrics}",
+    ),
init_param: cpn.parameter(
type=params.init_param(),
default=params.InitParam(method="zeros", fit_intercept=True),
46 changes: 2 additions & 44 deletions python/fate/ml/glm/hetero/coordinated_lr/guest.py
@@ -20,7 +20,7 @@
from fate.arch import Context, dataframe
from fate.ml.abc.module import HeteroModule
from fate.ml.utils import predict_tools
-from fate.ml.utils._model_param import initialize_param, serialize_param, deserialize_param
+from fate.ml.utils._model_param import initialize_param, serialize_param, deserialize_param, check_overflow
from fate.ml.utils._optimizer import LRScheduler, Optimizer

logger = logging.getLogger(__name__)
@@ -338,46 +338,13 @@ def fit_single_model(self, ctx: Context, encryptor, train_data, validate_data=No
else:
g = self.asynchronous_compute_gradient(batch_ctx, encryptor, w, X, Y, weight)

"""h = X.shape[0]
# logger.info(f"h: {h}")
Xw = torch.matmul(X, w.detach())
d = 0.25 * Xw - 0.5 * Y
loss = 0.125 / h * torch.matmul(Xw.T, Xw) - 0.5 / h * torch.matmul(Xw.T, Y)
if self.optimizer.l1_penalty or self.optimizer.l2_penalty:
loss_norm = self.optimizer.loss_norm(w)
loss += loss_norm
Xw_h_all = batch_ctx.hosts.get("Xw_h")
for Xw_h in Xw_h_all:
d += Xw_h
#loss -= 0.5 / h * torch.matmul(Y.T, Xw_h)
# loss += 0.25 / h * torch.matmul(Xw.T, Xw_h)
loss += torch.matmul((0.25 / h * Xw - 0.5 / h * Y).T, Xw_h)
if weight:
# logger.info(f"weight: {weight.tolist()}")
d = d * weight
batch_ctx.hosts.put("d", d)
for Xw2_h in batch_ctx.hosts.get("Xw2_h"):
loss += 0.125 / h * Xw2_h
h_loss_list = batch_ctx.hosts.get("h_loss")
for h_loss in h_loss_list:
if h_loss is not None:
loss += h_loss
if len(Xw_h_all) == 1:
batch_ctx.arbiter.put(loss=loss)
# gradient
g = 1 / h * torch.matmul(X.T, d)"""
g = self.optimizer.add_regular_to_grad(g, w, self.init_param.get("fit_intercept"))
batch_ctx.arbiter.put("g_enc", g)
g = batch_ctx.arbiter.get("g")

w = self.optimizer.update_weights(w, g, self.init_param.get("fit_intercept"), self.lr_scheduler.lr)
# logger.info(f"w={w}")
+check_overflow(w)

self.is_converged = iter_ctx.arbiter("converge_flag").get()
if self.is_converged:
@@ -403,11 +370,6 @@ def predict(self, ctx, test_data):
return pred

def get_model(self):
"""w = self.w.tolist()
intercept = None
if self.init_param.get("fit_intercept"):
w = w[:-1]
intercept = w[-1]"""
param = serialize_param(self.w, self.init_param.get("fit_intercept"))
return {
# "w": w,
@@ -421,10 +383,6 @@ def get_model(self):
}

def restore(self, model):
"""w = model["w"]
if model["fit_intercept"]:
w.append(model["intercept"])
self.w = torch.tensor(w)"""
self.w = deserialize_param(model["param"], model["fit_intercept"])
self.optimizer = Optimizer()
self.lr_scheduler = LRScheduler()
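The guest and host changes follow the same pattern: the large blocks of commented-out gradient code are dropped, and the new overflow guard runs immediately after each weight update. A minimal standalone sketch of that placement (a plain SGD step stands in for FATE's Optimizer.update_weights here, purely for illustration):

import torch
from fate.ml.utils._model_param import check_overflow  # helper added by this commit

def toy_update_step(w, g, lr=0.1):
    # Stand-in for Optimizer.update_weights: plain gradient descent, illustrative only.
    w = w - lr * g
    # Mirrors the placement in guest.py/host.py: guard right after the update,
    # so a diverging run fails fast instead of carrying huge weights into later epochs.
    check_overflow(w)
    return w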
18 changes: 3 additions & 15 deletions python/fate/ml/glm/hetero/coordinated_lr/host.py
@@ -19,7 +19,7 @@
from fate.arch import Context
from fate.arch.dataframe import DataLoader
from fate.ml.abc.module import HeteroModule
-from fate.ml.utils._model_param import initialize_param, serialize_param, deserialize_param
+from fate.ml.utils._model_param import initialize_param, serialize_param, deserialize_param, check_overflow
from fate.ml.utils._optimizer import LRScheduler, Optimizer

logger = logging.getLogger(__name__)
@@ -240,20 +240,6 @@ def fit_single_model(self, ctx: Context, encryptor, train_data, validate_data=No
logger.info(f"self.optimizer set epoch{i}")
for batch_ctx, batch_data in iter_ctx.on_batches.ctxs_zip(batch_loader):
X = batch_data.x
"""
h = X.shape[0]
Xw_h = 0.25 * torch.matmul(X, w.detach())
batch_ctx.guest.put("Xw_h", encryptor.encrypt(Xw_h))
batch_ctx.guest.put("Xw2_h", encryptor.encrypt(torch.matmul(Xw_h.T, Xw_h)))
loss_norm = self.optimizer.loss_norm(w)
if loss_norm is not None:
batch_ctx.guest.put("h_loss", encryptor.encrypt(loss_norm))
else:
batch_ctx.guest.put(h_loss=loss_norm)
d = batch_ctx.guest.get("d")
g = 1 / h * torch.matmul(X.T, d)"""
if is_centralized:
g = self.centralized_compute_gradient(batch_ctx, encryptor, w, X)
else:
@@ -264,6 +250,8 @@ def fit_single_model(self, ctx: Context, encryptor, train_data, validate_data=No
g = batch_ctx.arbiter.get("g")

w = self.optimizer.update_weights(w, g, False, self.lr_scheduler.lr)
+check_overflow(w)

self.is_converged = iter_ctx.arbiter("converge_flag").get()
if self.is_converged:
self.end_epoch = i
5 changes: 5 additions & 0 deletions python/fate/ml/utils/_model_param.py
@@ -53,3 +53,8 @@ def deserialize_param(param, fit_intercept=False):
dtype = param["dtype"]
w = torch.tensor(w, dtype=getattr(torch, dtype))
return w


+def check_overflow(param, threshold=1e8):
+    if (torch.abs(param) > threshold).any():
+        raise ValueError(f"Value(s) greater than {threshold} found in model param, please check.")
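For reference, a minimal sketch of how the new check_overflow helper behaves on its own (the tensor values below are made up for illustration):

import torch
from fate.ml.utils._model_param import check_overflow

# Well-behaved parameters pass silently.
check_overflow(torch.tensor([0.3, -1.2, 5.0]))

# Parameters past the default 1e8 threshold raise, aborting training
# rather than continuing with a diverged model.
try:
    check_overflow(torch.tensor([0.3, -1.2, 5.0e9]))
except ValueError as err:
    print(err)  # Value(s) greater than 100000000.0 found in model param, please check.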
