fix lr multi load model (#4659)
Signed-off-by: Yu Wu <yolandawu131@gmail.com>
Signed-off-by: weiwee <wbwmat@gmail.com>
nemirorox authored and sagewe committed Jul 21, 2023
1 parent cc35127 commit 532e32d
Showing 3 changed files with 25 additions and 8 deletions.
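In short: the one-vs-rest (ovr) loading path in all three roles rebuilt its estimators with a bare dict comprehension. The fix replaces it with an explicit loop that passes the saved training meta (epochs, batch_size, and, for guest/host, init_param) into each estimator's constructor and keys the estimator dict by int(label). The int cast matters because serialized model data typically comes back with string keys; a minimal sketch of that effect (the JSON round trip here is an assumption about how the exported model dict is stored, not FATE code):

    import json

    # An ovr model keeps one estimator per class label (int keys).
    saved = {0: {"w": 0.1}, 1: {"w": 0.2}}

    # After a JSON round trip, the dict comes back with *string* keys.
    loaded = json.loads(json.dumps(saved))
    assert list(loaded) == ["0", "1"]

    # Hence the fix indexes the restored estimators as int(label).
    restored = {int(label): d for label, d in loaded.items()}
    assert list(restored) == [0, 1]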
8 changes: 6 additions & 2 deletions python/fate/ml/glm/hetero/coordinated_lr/arbiter.py
@@ -148,9 +148,13 @@ def from_model(cls, model) -> "CoordinatedLRModuleArbiter":
             learning_rate_param=model["meta"]["learning_rate_param"],
         )
         all_estimator = model["data"]["estimator"]
-
+        lr.estimator = {}
         if lr.ovr:
-            lr.estimator = {label: CoordinatedLREstimatorArbiter().restore(d) for label, d in all_estimator.items()}
+            for label, d in all_estimator.items():
+                estimator = CoordinatedLREstimatorArbiter(epochs=model["meta"]["epochs"],
+                                                          batch_size=model["meta"]["batch_size"])
+                estimator.restore(d)
+                lr.estimator[int(label)] = estimator
         else:
             estimator = CoordinatedLREstimatorArbiter()
             estimator.restore(all_estimator)
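Note what the removed comprehension stored: the return value of restore(d), not the estimator. If restore() mutates in place and returns None, the ovr dict ends up mapping labels to None; the explicit loop no longer depends on restore()'s return value either way. A standalone illustration of the hazard (Est is a hypothetical stand-in, not the FATE class):

    class Est:
        """Hypothetical estimator whose restore() works in place."""
        def restore(self, d):
            self.state = d  # returns None, like most in-place setters

    saved = {"0": {"w": 0.1}, "1": {"w": 0.2}}

    # Old-style comprehension keeps restore()'s return value:
    broken = {label: Est().restore(d) for label, d in saved.items()}
    assert broken == {"0": None, "1": None}  # the estimators are lost

    # The fixed loop keeps the estimator object itself:
    fixed = {}
    for label, d in saved.items():
        est = Est()
        est.restore(d)
        fixed[int(label)] = est
    assert all(isinstance(e, Est) for e in fixed.values())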
11 changes: 7 additions & 4 deletions python/fate/ml/glm/hetero/coordinated_lr/guest.py
@@ -192,11 +192,14 @@ def from_model(cls, model) -> "CoordinatedLRModuleGuest":
         lr.labels = model["meta"]["labels"]

         all_estimator = model["data"]["estimator"]
+        lr.estimator = {}
         if lr.ovr:
-            lr.estimator = {label: CoordinatedLREstimatorGuest(epochs=model["meta"]["epochs"],
-                                                               batch_size=model["meta"]["batch_size"],
-                                                               init_param=model["meta"]["init_param"]). \
-                restore(d) for label, d in all_estimator.items()}
+            for label, d in all_estimator.items():
+                estimator = CoordinatedLREstimatorGuest(epochs=model["meta"]["epochs"],
+                                                        batch_size=model["meta"]["batch_size"],
+                                                        init_param=model["meta"]["init_param"])
+                estimator.restore(d)
+                lr.estimator[int(label)] = estimator
         else:
             estimator = CoordinatedLREstimatorGuest(epochs=model["meta"]["epochs"],
                                                     batch_size=model["meta"]["batch_size"],
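The guest constructors also now receive the saved training meta, so a restored estimator is initialized with the same epochs, batch_size, and init_param it was trained with, rather than constructor defaults. A minimal sketch of that pattern, with Est again a hypothetical stand-in for CoordinatedLREstimatorGuest:

    class Est:
        """Hypothetical stand-in; the real class is CoordinatedLREstimatorGuest."""
        def __init__(self, epochs, batch_size, init_param):
            self.epochs = epochs
            self.batch_size = batch_size
            self.init_param = init_param

        def restore(self, d):
            self.state = d

    model = {
        "meta": {"epochs": 5, "batch_size": 32, "init_param": {"fit_intercept": True}},
        "data": {"estimator": {"0": {"w": [0.1]}, "1": {"w": [0.2]}}},
    }

    estimators = {}
    for label, d in model["data"]["estimator"].items():
        est = Est(epochs=model["meta"]["epochs"],
                  batch_size=model["meta"]["batch_size"],
                  init_param=model["meta"]["init_param"])
        est.restore(d)
        estimators[int(label)] = est

    assert estimators[0].epochs == 5  # restored with training meta, not defaults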
14 changes: 12 additions & 2 deletions python/fate/ml/glm/hetero/coordinated_lr/host.py
@@ -155,10 +155,20 @@ def from_model(cls, model) -> "CoordinatedLRModuleHost":
         lr.ovr = model["meta"]["ovr"]

         all_estimator = model["data"]["estimator"]
+        lr.estimator = {}
+
         if lr.ovr:
-            lr.estimator = {label: CoordinatedLREstimatorHost().restore(d) for label, d in all_estimator.items()}
+            lr.estimator = {}
+            for label, d in all_estimator.items():
+                estimator = CoordinatedLREstimatorHost(epochs=model["meta"]["epochs"],
+                                                       batch_size=model["meta"]["batch_size"],
+                                                       init_param=model["meta"]["init_param"])
+                estimator.restore(d)
+                lr.estimator[int(label)] = estimator
         else:
-            estimator = CoordinatedLREstimatorHost()
+            estimator = CoordinatedLREstimatorHost(epochs=model["meta"]["epochs"],
+                                                   batch_size=model["meta"]["batch_size"],
+                                                   init_param=model["meta"]["init_param"])
             estimator.restore(all_estimator)
             lr.estimator = estimator
         logger.info(f"finish from model")
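host.py gets the same ovr rewrite plus one more change: the binary (non-ovr) branch now also builds its single estimator from the saved meta rather than a bare CoordinatedLREstimatorHost(). Putting both branches together, the loading shape after this commit looks roughly like the sketch below (Est and load are illustrative stand-ins, not the FATE API):

    class Est:
        """Hypothetical stand-in for CoordinatedLREstimatorHost."""
        def __init__(self, epochs, batch_size, init_param):
            self.epochs, self.batch_size, self.init_param = epochs, batch_size, init_param

        def restore(self, d):
            self.state = d

    def load(model):
        meta, data = model["meta"], model["data"]["estimator"]

        def restored(d):
            est = Est(epochs=meta["epochs"], batch_size=meta["batch_size"],
                      init_param=meta["init_param"])
            est.restore(d)
            return est

        if meta["ovr"]:
            # multi-class: one estimator per label, keyed by int(label)
            return {int(label): restored(d) for label, d in data.items()}
        # binary: a single estimator, now also built with the saved meta
        return restored(data)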
