fix metrics
Signed-off-by: weiwee <wbwmat@gmail.com>
sagewe committed Jul 11, 2023
1 parent 12452c3 commit baa257c
Showing 2 changed files with 54 additions and 60 deletions.
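
The substantive change in both files is the loss-metric call inside the training loop: the explicit step=i argument is dropped from metrics.log_loss(...), leaving the iteration-scoped context to supply the step on its own; the remaining hunks are import-order and formatting cleanups. Below is a minimal, standalone sketch of a metrics handler that binds the step to the iteration — an illustrative assumption about the design, not FATE's actual implementation.

# Hypothetical stand-in for an iteration-scoped metrics handler. FATE's real
# handler may differ; this only illustrates why an explicit step can be dropped.
class IterationMetrics:
    def __init__(self, step):
        self._step = step  # the enclosing iteration binds its own step index

    def log_loss(self, name, value, step=None):
        # Fall back to the step bound to this iteration when none is passed.
        step = self._step if step is None else step
        print(f"metric={name} step={step} value={value}")


# Usage mirroring the shape of the arbiter's training loop.
for i in range(3):
    iter_metrics = IterationMetrics(step=i)
    iter_metrics.log_loss("lr_loss", 0.5 / (i + 1))  # no explicit step=i needed
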
python/fate/ml/glm/coordinated_linr/arbiter.py (62 changes: 25 additions & 37 deletions)
@@ -15,37 +15,28 @@
import logging

import torch

from fate.arch import Context
from fate.arch.dataframe import DataLoader
from fate.ml.abc.module import HeteroModule
from fate.ml.utils._convergence import converge_func_factory
from fate.ml.utils._optimizer import separate, Optimizer, LRScheduler
from fate.ml.utils._optimizer import LRScheduler, Optimizer, separate

logger = logging.getLogger(__name__)


class CoordinatedLinRModuleArbiter(HeteroModule):
def __init__(
self,
epochs,
early_stop,
tol,
batch_size,
optimizer_param,
learning_rate_param

):
def __init__(self, epochs, early_stop, tol, batch_size, optimizer_param, learning_rate_param):
self.epochs = epochs
self.batch_size = batch_size
self.early_stop = early_stop
self.tol = tol
self.optimizer = Optimizer(optimizer_param["method"],
optimizer_param["penalty"],
optimizer_param["alpha"],
optimizer_param["optimizer_params"])
self.lr_scheduler = LRScheduler(learning_rate_param["method"],
learning_rate_param["scheduler_params"])
self.optimizer = Optimizer(
optimizer_param["method"],
optimizer_param["penalty"],
optimizer_param["alpha"],
optimizer_param["optimizer_params"],
)
self.lr_scheduler = LRScheduler(learning_rate_param["method"], learning_rate_param["scheduler_params"])
"""self.optimizer = Optimizer(optimizer_param.method,
optimizer_param.penalty,
optimizer_param.alpha,
@@ -58,12 +49,14 @@ def __init__(
def fit(self, ctx: Context) -> None:
encryptor, decryptor = ctx.cipher.phe.keygen(options=dict(key_length=2048))
ctx.hosts("encryptor").put(encryptor)
single_estimator = HeteroLinrEstimatorArbiter(epochs=self.epochs,
early_stop=self.early_stop,
tol=self.tol,
batch_size=self.batch_size,
optimizer=self.optimizer,
learning_rate_scheduler=self.lr_scheduler)
single_estimator = HeteroLinrEstimatorArbiter(
epochs=self.epochs,
early_stop=self.early_stop,
tol=self.tol,
batch_size=self.batch_size,
optimizer=self.optimizer,
learning_rate_scheduler=self.lr_scheduler,
)
single_estimator.fit_model(ctx, decryptor)
self.estimator = single_estimator

@@ -82,14 +75,7 @@ def from_model(cls, model):

class HeteroLinrEstimatorArbiter(HeteroModule):
def __init__(
self,
epochs=None,
early_stop=None,
tol=None,
batch_size=None,
optimizer=None,
learning_rate_scheduler=None

self, epochs=None, early_stop=None, tol=None, batch_size=None, optimizer=None, learning_rate_scheduler=None
):
self.epochs = epochs
self.batch_size = batch_size
@@ -160,13 +146,15 @@ def fit_model(self, ctx, decryptor):
logger.info("Multiple hosts exist, do not compute loss.")

if iter_loss is not None:
iter_ctx.metrics.log_loss("linr_loss", iter_loss.tolist(), step=i)
if self.early_stop == 'weight_diff':
iter_ctx.metrics.log_loss("linr_loss", iter_loss.tolist())
if self.early_stop == "weight_diff":
self.is_converged = self.converge_func.is_converge(iter_g)
else:
if iter_loss is None:
raise ValueError("Multiple host situation, loss early stop function is not available."
"You should use 'weight_diff' instead")
raise ValueError(
"Multiple host situation, loss early stop function is not available."
"You should use 'weight_diff' instead"
)
self.is_converged = self.converge_func.is_converge(iter_loss)

iter_ctx.hosts.put("converge_flag", self.is_converged)
@@ -186,7 +174,7 @@ def to_model(self):
"optimizer": self.optimizer.state_dict(),
"lr_scheduler": self.lr_scheduler.state_dict(),
"end_epoch": self.end_epoch,
"converged": self.is_converged
"converged": self.is_converged,
}

def restore(self, model):
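
For context, the hunk above centres on the arbiter's convergence check: with a single host a loss is computed and either criterion can be used, while with multiple hosts only the 'weight_diff' criterion is available. A simplified, self-contained sketch of that branching follows; the convergence test is reduced to a plain threshold comparison here, whereas FATE builds it via converge_func_factory.

# Simplified illustration of the early-stop branching in fit_model; the
# threshold tests below are stand-ins for the real convergence functions.
def check_converged(early_stop, iter_loss, iter_g, prev_loss=None, tol=1e-4):
    if early_stop == "weight_diff":
        # Gradient-based check works whether or not a loss was computed.
        return sum(abs(float(g)) for g in iter_g) < tol
    if iter_loss is None:
        raise ValueError(
            "Multiple host situation, loss early stop function is not available. "
            "You should use 'weight_diff' instead"
        )
    # Loss-based check: converge once the loss stops moving between iterations.
    return prev_loss is not None and abs(prev_loss - float(iter_loss)) < tol


print(check_converged("weight_diff", iter_loss=None, iter_g=[1e-6, -2e-6]))  # True
print(check_converged("abs", iter_loss=0.31, iter_g=[0.1], prev_loss=0.32))  # False
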
python/fate/ml/glm/coordinated_lr/arbiter.py (52 changes: 29 additions & 23 deletions)
@@ -15,7 +15,6 @@
import logging

import torch

from fate.arch import Context
from fate.arch.dataframe import DataLoader
from fate.ml.abc.module import HeteroModule
@@ -52,8 +51,9 @@ def fit(self, ctx: Context) -> None:
self.optimizer_param["alpha"],
self.optimizer_param["optimizer_params"],
)
lr_scheduler = LRScheduler(self.learning_rate_param["method"],
self.learning_rate_param["scheduler_params"])
lr_scheduler = LRScheduler(
self.learning_rate_param["method"], self.learning_rate_param["scheduler_params"]
)
single_estimator = CoordinatedLREstimatorArbiter(
epochs=self.epochs,
early_stop=self.early_stop,
@@ -71,8 +71,9 @@ def fit(self, ctx: Context) -> None:
self.optimizer_param["alpha"],
self.optimizer_param["optimizer_params"],
)
lr_scheduler = LRScheduler(self.learning_rate_param["method"],
self.learning_rate_param["scheduler_params"])
lr_scheduler = LRScheduler(
self.learning_rate_param["method"], self.learning_rate_param["scheduler_params"]
)
single_estimator = CoordinatedLREstimatorArbiter(
epochs=self.epochs,
early_stop=self.early_stop,
@@ -91,23 +92,28 @@ def to_model(self):
all_estimator[label] = estimator.get_model()
else:
all_estimator = self.estimator.get_model()
return {"data": {"estimator": all_estimator},
"meta": {"epochs": self.epochs,
"ovr": self.ovr,
"early_stop": self.early_stop,
"tol": self.tol,
"batch_size": self.batch_size,
"learning_rate_param": self.learning_rate_param,
"optimizer_param": self.optimizer_param},
}
return {
"data": {"estimator": all_estimator},
"meta": {
"epochs": self.epochs,
"ovr": self.ovr,
"early_stop": self.early_stop,
"tol": self.tol,
"batch_size": self.batch_size,
"learning_rate_param": self.learning_rate_param,
"optimizer_param": self.optimizer_param,
},
}

def from_model(cls, model):
lr = CoordinatedLRModuleArbiter(epochs=model["meta"]["epochs"],
early_stop=model["meta"]["early_stop"],
tol=model["meta"]["tol"],
batch_size=model["meta"]["batch_size"],
optimizer_param=model["meta"]["optimizer_param"],
learning_rate_param=model["meta"]["learning_rate_param"])
lr = CoordinatedLRModuleArbiter(
epochs=model["meta"]["epochs"],
early_stop=model["meta"]["early_stop"],
tol=model["meta"]["tol"],
batch_size=model["meta"]["batch_size"],
optimizer_param=model["meta"]["optimizer_param"],
learning_rate_param=model["meta"]["learning_rate_param"],
)
all_estimator = model["data"]["estimator"]
if lr.ovr:
lr.estimator = {label: CoordinatedLREstimatorArbiter().restore(d) for label, d in all_estimator.items()}
@@ -122,7 +128,7 @@ def from_model(cls, model):

class CoordinatedLREstimatorArbiter(HeteroModule):
def __init__(
self, epochs=None, early_stop=None, tol=None, batch_size=None, optimizer=None, learning_rate_scheduler=None
self, epochs=None, early_stop=None, tol=None, batch_size=None, optimizer=None, learning_rate_scheduler=None
):
self.epochs = epochs
self.batch_size = batch_size
@@ -191,7 +197,7 @@ def fit_single_model(self, ctx: Context, decryptor):

if iter_loss is not None:
logger.info(f"step={i}: lr_loss={iter_loss.tolist()}")
iter_ctx.metrics.log_loss("lr_loss", iter_loss.tolist(), step=i)
iter_ctx.metrics.log_loss("lr_loss", iter_loss.tolist())
if self.early_stop == "weight_diff":
self.is_converged = self.converge_func.is_converge(iter_g)
else:
@@ -218,7 +224,7 @@ def to_model(self):
"optimizer": self.optimizer.state_dict(),
"lr_scheduler": self.lr_scheduler.state_dict(),
"end_epoch": self.end_epoch,
"is_converged": self.is_converged
"is_converged": self.is_converged,
}

def restore(self, model):
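
The to_model/from_model hunks above are pure reformatting of one serialization pattern: hyperparameters are stored under "meta", trained estimator state under "data", and from_model rebuilds the module from those two keys. A stripped-down sketch of that round trip, with placeholder fields in place of the real estimator state:

# Minimal stand-in for the {"data": ..., "meta": ...} model-dict round trip.
class TinyModule:
    def __init__(self, epochs, tol, batch_size):
        self.epochs = epochs
        self.tol = tol
        self.batch_size = batch_size
        self.estimator = None  # filled in by training or by from_model

    def to_model(self):
        return {
            "data": {"estimator": self.estimator},
            "meta": {"epochs": self.epochs, "tol": self.tol, "batch_size": self.batch_size},
        }

    @classmethod
    def from_model(cls, model):
        module = cls(
            epochs=model["meta"]["epochs"],
            tol=model["meta"]["tol"],
            batch_size=model["meta"]["batch_size"],
        )
        module.estimator = model["data"]["estimator"]
        return module


saved = TinyModule(epochs=10, tol=1e-4, batch_size=32)
saved.estimator = {"w": [0.1, 0.2]}
restored = TinyModule.from_model(saved.to_model())
assert restored.estimator == saved.estimator
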
