Commit 644cf90

edit glm(#4659)

Signed-off-by: Yu Wu <yolandawu131@gmail.com>
nemirorox committed Jul 5, 2023
1 parent c0d84d9 commit 644cf90
Showing 9 changed files with 74 additions and 68 deletions.
@@ -20,11 +20,11 @@


@cpn.component(roles=[GUEST, HOST, ARBITER])
def hetero_linr(ctx, role):
def coordinated_linr(ctx, role):
...


@hetero_linr.train()
@coordinated_linr.train()
def train(
ctx: Context,
role: Role,
@@ -67,7 +67,7 @@ def train(
train_arbiter(ctx, max_iter, early_stop, tol, batch_size, optimizer, learning_rate_scheduler)


@hetero_linr.predict()
@coordinated_linr.predict()
def predict(
ctx,
role: Role,
@@ -83,20 +83,19 @@

def train_guest(ctx, train_data, validate_data, train_output_data, output_model, max_iter,
batch_size, optimizer_param, learning_rate_param, init_param):
from fate.ml.glm.hetero_linr import HeteroLinRModuleGuest
from fate.ml.glm.coordinated_linr import CoordinatedLinRModuleGuest
# optimizer = optimizer_factory(optimizer_param)

with ctx.sub_ctx("train") as sub_ctx:
module = HeteroLinRModuleGuest(max_iter=max_iter, batch_size=batch_size,
optimizer_param=optimizer_param, learning_rate_param=learning_rate_param,
init_param=init_param)
module = CoordinatedLinRModuleGuest(max_iter=max_iter, batch_size=batch_size,
optimizer_param=optimizer_param, learning_rate_param=learning_rate_param,
init_param=init_param)
train_data = train_data.read()
if validate_data is not None:
validate_data = validate_data.read()
module.fit(sub_ctx, train_data, validate_data)
model = module.get_model()
with output_model as model_writer:
model_writer.write_model("hetero_linr_host", model, metadata={})
output_model.write(model, metadata={})

with ctx.sub_ctx("predict") as sub_ctx:
predict_score = module.predict(sub_ctx, validate_data)
@@ -106,55 +105,54 @@ def train_guest(ctx, train_data, validate_data, train_output_data, output_model,

def train_host(ctx, train_data, validate_data, train_output_data, output_model, max_iter, batch_size,
optimizer_param, learning_rate_param, init_param):
from fate.ml.glm.hetero_linr import HeteroLinRModuleHost
from fate.ml.glm.coordinated_linr import CoordinatedLinRModuleHost
# optimizer = optimizer_factory(optimizer_param)

with ctx.sub_ctx("train") as sub_ctx:
module = HeteroLinRModuleHost(max_iter=max_iter, batch_size=batch_size,
optimizer_param=optimizer_param, learning_rate_param=learning_rate_param,
init_param=init_param)
module = CoordinatedLinRModuleHost(max_iter=max_iter, batch_size=batch_size,
optimizer_param=optimizer_param, learning_rate_param=learning_rate_param,
init_param=init_param)
train_data = train_data.read()
if validate_data is not None:
validate_data = validate_data.read()
module.fit(sub_ctx, train_data, validate_data)
model = module.get_model()
with output_model as model_writer:
model_writer.write_model("hetero_linr_host", model, metadata={})
output_model.write(model, metadata={})
with ctx.sub_ctx("predict") as sub_ctx:
module.predict(sub_ctx, validate_data)


def train_arbiter(ctx, max_iter, early_stop, tol, batch_size, optimizer_param,
learning_rate_param):
from fate.ml.glm.hetero_linr import HeteroLinRModuleArbiter
from fate.ml.glm.coordinated_linr import CoordinatedLinRModuleArbiter

with ctx.sub_ctx("train") as sub_ctx:
module = HeteroLinRModuleArbiter(max_iter=max_iter, early_stop=early_stop, tol=tol, batch_size=batch_size,
optimizer_param=optimizer_param, learning_rate_param=learning_rate_param)
module = CoordinatedLinRModuleArbiter(max_iter=max_iter, early_stop=early_stop, tol=tol, batch_size=batch_size,
optimizer_param=optimizer_param, learning_rate_param=learning_rate_param)
module.fit(sub_ctx)


def predict_guest(ctx, input_model, test_data, test_output_data):
from fate.ml.glm.hetero_linr import HeteroLinRModuleGuest
from fate.ml.glm.coordinated_linr import CoordinatedLinRModuleGuest

with ctx.sub_ctx("predict") as sub_ctx:
with input_model as model_reader:
model = model_reader.read_model()

module = HeteroLinRModuleGuest.from_model(model)
module = CoordinatedLinRModuleGuest.from_model(model)
test_data = test_data.read()
predict_score = module.predict(sub_ctx, test_data)
predict_result = transform_to_predict_result(test_data, predict_score, data_type="predict")
sub_ctx.writer(test_output_data).write_dataframe(predict_result)


def predict_host(ctx, input_model, test_data, test_output_data):
from fate.ml.glm.hetero_linr import HeteroLinRModuleHost
from fate.ml.glm.coordinated_linr import CoordinatedLinRModuleHost

with ctx.sub_ctx("predict") as sub_ctx:
with input_model as model_reader:
model = model_reader.read_model()
module = HeteroLinRModuleHost.from_model(model)
module = CoordinatedLinRModuleHost.from_model(model)
test_data = test_data.read()
module.predict(sub_ctx, test_data)

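The component file above follows FATE's role-dispatch pattern: one logical `coordinated_linr` component whose `train` and `predict` entry points fan out to the role-specific routines shown in the hunks. A minimal, framework-free sketch of that shape (illustrative only, not FATE's actual API):

```python
# Illustrative sketch of the role dispatch used by the component above.
# `role` here is a stand-in object; FATE's Role type differs in detail.
def train(ctx, role, **params):
    if role.is_guest:        # label holder: fits and writes its model partition
        train_guest(ctx, **params)
    elif role.is_host:       # feature-only party: fits and writes its model partition
        train_host(ctx, **params)
    elif role.is_arbiter:    # coordinator: holds the PHE private key, drives convergence
        train_arbiter(ctx, **params)
```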
2 changes: 1 addition & 1 deletion python/fate/ml/glm/__init__.py
@@ -1,2 +1,2 @@
from .coordinated_linr import CoordinatedLinRModuleHost, CoordinatedLinRModuleGuest, CoordinatedLinRModuleArbiter
from .coordinated_lr import CoordinatedLRModuleHost, CoordinatedLRModuleGuest, CoordinatedLRModuleArbiter
from .hetero_linr import HeteroLinRModuleHost, HeteroLinRModuleGuest, HeteroLinRModuleArbiter
python/fate/ml/glm/coordinated_linr/__init__.py
@@ -13,6 +13,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from .arbiter import HeteroLinRModuleArbiter
from .guest import HeteroLinRModuleGuest
from .host import HeteroLinRModuleHost
from .arbiter import CoordinatedLinRModuleArbiter
from .guest import CoordinatedLinRModuleGuest
from .host import CoordinatedLinRModuleHost
python/fate/ml/glm/coordinated_linr/arbiter.py
@@ -25,7 +25,7 @@
logger = logging.getLogger(__name__)


class HeteroLinRModuleArbiter(HeteroModule):
class CoordinatedLinRModuleArbiter(HeteroModule):
def __init__(
self,
max_iter,
@@ -58,6 +58,7 @@ def __init__(
def fit(self, ctx: Context) -> None:
encryptor, decryptor = ctx.cipher.phe.keygen(options=dict(key_length=2048))
ctx.hosts("encryptor").put(encryptor)
ctx.guest("encryptor").put(encryptor)
single_estimator = HeteroLinrEstimatorArbiter(max_iter=self.max_iter,
early_stop=self.early_stop,
tol=self.tol,
@@ -73,7 +74,7 @@ def to_model(self):
}

def from_model(cls, model):
linr = HeteroLinRModuleArbiter(**model["metadata"])
linr = CoordinatedLinRModuleArbiter(**model["metadata"])
estimator = HeteroLinrEstimatorArbiter()
estimator.restore(model["estimator"])
linr.estimator = estimator
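The one functional change in this file is that the arbiter now ships the encryptor to the guest as well as the hosts (`ctx.guest("encryptor").put(encryptor)`); the decryptor never leaves the arbiter. The same one-line change appears in coordinated_lr/arbiter.py below. A standalone sketch of that public/private split, using the `phe` (python-paillier) package rather than FATE's internal cipher API:

```python
# Keygen/distribution split sketched with python-paillier (pip install phe).
# FATE uses ctx.cipher.phe.keygen(options=dict(key_length=2048)) internally.
from phe import paillier

# Arbiter: generate the keypair once per training job.
public_key, private_key = paillier.generate_paillier_keypair(n_length=2048)

# The "encryptor" (public key) goes to guest and hosts; with it they can
# encrypt residuals and do additive/scalar homomorphic arithmetic.
enc_d = public_key.encrypt(0.73)
enc_g = enc_d * 2.0 + public_key.encrypt(-0.1)

# Only the arbiter, holding the "decryptor" (private key), sees plaintext.
assert abs(private_key.decrypt(enc_g) - 1.36) < 1e-9
```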
python/fate/ml/glm/coordinated_linr/guest.py
@@ -25,7 +25,7 @@
logger = logging.getLogger(__name__)


class HeteroLinRModuleGuest(HeteroModule):
class CoordinatedLinRModuleGuest(HeteroModule):
def __init__(
self,
max_iter=None,
@@ -49,13 +49,14 @@ def __init__(

def fit(self, ctx: Context, train_data, validate_data=None) -> None:
with_weight = train_data.weight is not None

estimator = HeteroLinrEstimatorGuest(max_iter=self.max_iter,
batch_size=self.batch_size,
optimizer=self.optimizer,
learning_rate_scheduler=self.lr_scheduler,
init_param=self.init_param)
estimator.fit_model(ctx, train_data, validate_data, with_weight=with_weight)
encryptor = ctx.arbiter("encryptor").get()

estimator = CoordinatedLinREstimatorGuest(max_iter=self.max_iter,
batch_size=self.batch_size,
optimizer=self.optimizer,
learning_rate_scheduler=self.lr_scheduler,
init_param=self.init_param)
estimator.fit_model(ctx, encryptor, train_data, validate_data, with_weight=with_weight)
self.estimator = estimator

def predict(self, ctx, test_data):
@@ -78,16 +79,16 @@ def get_model(self):
}

@classmethod
def from_model(cls, model) -> "HeteroLinRModuleGuest":
linr = HeteroLinRModuleGuest(**model["metadata"])
estimator = HeteroLinrEstimatorGuest()
def from_model(cls, model) -> "CoordinatedLinRModuleGuest":
linr = CoordinatedLinRModuleGuest()
estimator = CoordinatedLinREstimatorGuest()
estimator.restore(model["estimator"])
linr.estimator = estimator

return linr


class HeteroLinrEstimatorGuest(HeteroModule):
class CoordinatedLinREstimatorGuest(HeteroModule):
def __init__(
self,
max_iter=None,
@@ -107,7 +108,7 @@ def __init__(
self.end_iter = -1
self.is_converged = False

def fit_model(self, ctx, train_data, validate_data=None, with_weight=False):
def fit_model(self, ctx, encryptor, train_data, validate_data=None, with_weight=False):
coef_count = train_data.shape[1]
if self.init_param.fit_intercept:
train_data["intercept"] = 1
@@ -135,6 +136,7 @@ def fit_model(self, ctx, train_data, validate_data=None, with_weight=False):
h = X.shape[0]
Xw = torch.matmul(X, w)
d = Xw - Y
encryptor.encrypt(d).to(batch_ctx.hosts, "d")
loss = 1 / 2 / h * torch.matmul(d.T, d)
if self.optimizer.l1_penalty or self.optimizer.l2_penalty:
loss_norm = self.optimizer.loss_norm(w)
@@ -155,10 +157,11 @@ def fit_model(self, ctx, train_data, validate_data=None, with_weight=False):
batch_ctx.arbiter.put(loss=loss)

# gradient
g = self.optimizer.add_regular_to_grad(X.T @ d, w, self.init_param.fit_intercept)
g = X.T @ d
batch_ctx.arbiter.put("g_enc", g)
g = batch_ctx.arbiter.get("g")

g = self.optimizer.add_regular_to_grad(g, w, self.init_param.fit_intercept)
w = self.optimizer.update_weights(w, g, self.init_param.fit_intercept, self.lr_scheduler.lr)
logger.info(f"w={w}")
j += 1
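Two things change in the training loop above: the guest now pushes the encrypted residual to the hosts (`encryptor.encrypt(d).to(batch_ctx.hosts, "d")`), and regularization moves after the arbiter round-trip, so it is applied to the decrypted gradient instead of being folded in before encryption. A plain-numpy sketch of the quantities being exchanged, ignoring encryption and assuming a single host:

```python
import numpy as np

h, p_guest, p_host = 8, 3, 2
rng = np.random.default_rng(0)
X_g = rng.normal(size=(h, p_guest))   # guest features
X_h = rng.normal(size=(h, p_host))    # host features
y = rng.normal(size=(h, 1))           # labels live only at the guest
w_g, w_h = np.zeros((p_guest, 1)), np.zeros((p_host, 1))

# Full residual d = Xw - Y; in the protocol the host's share arrives encrypted.
d = (X_g @ w_g + X_h @ w_h) - y
loss = (d.T @ d).item() / (2 * h)     # matches loss = 1/2/h * d^T d above

# Each party's gradient is its own X^T @ d: sent up as "g_enc", returned as "g".
g_g, g_h = X_g.T @ d, X_h.T @ d

# Regularization is now added only after decryption, on plaintext:
l2 = 0.01
g_g = g_g / h + l2 * w_g              # rough stand-in for add_regular_to_grad
w_g = w_g - 0.1 * g_g                 # update_weights with lr = 0.1
```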
python/fate/ml/glm/coordinated_linr/host.py
@@ -25,7 +25,7 @@
logger = logging.getLogger(__name__)


class HeteroLinRModuleHost(HeteroModule):
class CoordinatedLinRModuleHost(HeteroModule):
def __init__(
self,
max_iter,
@@ -49,11 +49,11 @@ def __init__(

def fit(self, ctx: Context, train_data, validate_data=None) -> None:
encryptor = ctx.arbiter("encryptor").get()
estimator = HeteroLinrEstimatorHost(max_iter=self.max_iter,
batch_size=self.batch_size,
optimizer=self.optimizer,
learning_rate_scheduler=self.lr_scheduler,
init_param=self.init_param)
estimator = CoordinatedLinREstimatorHost(max_iter=self.max_iter,
batch_size=self.batch_size,
optimizer=self.optimizer,
learning_rate_scheduler=self.lr_scheduler,
init_param=self.init_param)
estimator.fit_model(ctx, encryptor, train_data, validate_data)
self.estimator = estimator

@@ -66,16 +66,16 @@ def get_model(self):
}

@classmethod
def from_model(cls, model) -> "HeteroLinRModuleHost":
linr = HeteroLinRModuleHost(**model["metadata"])
estimator = HeteroLinrEstimatorHost()
def from_model(cls, model) -> "CoordinatedLinRModuleHost":
linr = CoordinatedLinRModuleHost()
estimator = CoordinatedLinREstimatorHost()
estimator.restore(model["estimator"])
linr.estimator = estimator

return linr


class HeteroLinrEstimatorHost(HeteroModule):
class CoordinatedLinREstimatorHost(HeteroModule):
def __init__(
self,
max_iter=None,
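The host mirrors the guest estimator but never sees labels or plaintext residuals: it receives the encryptor from the arbiter, contributes its partial prediction, and computes its gradient directly on the ciphertext residual. A runnable sketch of that host step with python-paillier (illustrative; FATE's tensor operations differ):

```python
import numpy as np
from phe import paillier

public_key, private_key = paillier.generate_paillier_keypair(n_length=2048)

h, p_host = 4, 2
rng = np.random.default_rng(1)
X_h = rng.normal(size=(h, p_host))    # host features; the host holds no labels
d = rng.normal(size=h)                # residual, computed at the guest

# Guest side: encrypt the residual before sharing it (the "d" channel above).
enc_d = [public_key.encrypt(float(v)) for v in d]

# Host side: X_h^T @ enc_d uses only plaintext-ciphertext multiplies and
# ciphertext additions, which Paillier supports; d itself stays hidden.
enc_g_h = [sum(float(X_h[i, j]) * enc_d[i] for i in range(h))
           for j in range(p_host)]

# Arbiter side: decrypt "g_enc" and return the plaintext gradient.
g_h = np.array([private_key.decrypt(c) for c in enc_g_h])
assert np.allclose(g_h, X_h.T @ d)
```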
1 change: 1 addition & 0 deletions python/fate/ml/glm/coordinated_lr/arbiter.py
@@ -63,6 +63,7 @@ def __init__(
def fit(self, ctx: Context) -> None:
encryptor, decryptor = ctx.cipher.phe.keygen(options=dict(key_length=2048))
ctx.hosts("encryptor").put(encryptor)
ctx.guest("encryptor").put(encryptor)
""" label_count = ctx.guest("label_count").get()"""
label_count = 2
if label_count > 2:
16 changes: 10 additions & 6 deletions python/fate/ml/glm/coordinated_lr/guest.py
@@ -62,6 +62,7 @@ def __init__(
self.labels = []

def fit(self, ctx: Context, train_data, validate_data=None) -> None:
encryptor = ctx.arbiter("encryptor").get()
train_data_binarized_label = train_data.label.get_dummies()
label_count = train_data_binarized_label.shape[1]
ctx.arbiter.put("label_count", label_count)
@@ -88,7 +89,8 @@ def fit(self, ctx: Context, train_data, validate_data=None) -> None:
learning_rate_scheduler=self.lr_scheduler,
init_param=self.init_param)
train_data.label = train_data_binarized_label[self.labels[i]]
single_estimator.fit_single_model(class_ctx, train_data, validate_data, with_weight=with_weight)
single_estimator.fit_single_model(class_ctx, encryptor, train_data, validate_data,
with_weight=with_weight)
self.estimator[i] = single_estimator
else:
single_estimator = CoordinatedLREstimatorGuest(max_iter=self.max_iter,
@@ -102,7 +104,7 @@ def fit(self, ctx: Context, train_data, validate_data=None) -> None:
def predict(self, ctx, test_data):
if self.ovr:
predict_score = test_data.create_dataframe(with_label=False, with_weight=False)
for i, class_ctx in ctx.range(len(self.labels)):
for i, class_ctx in ctx.ctxs_range(len(self.labels)):
estimator = self.estimator[i]
pred = estimator.predict(test_data)
predict_score[self.labels[i]] = pred
@@ -147,7 +149,7 @@ def get_model(self):

@classmethod
def from_model(cls, model) -> "CoordinatedLRModuleGuest":
lr = CoordinatedLRModuleGuest(**model["metadata"])
lr = CoordinatedLRModuleGuest()
lr.ovr = model["ovr"]
lr.labels = model["labels"]
lr.threshold = model["threshold"]
@@ -186,7 +188,7 @@ def __init__(
self.is_converged = False
self.with_weight = False

def fit_single_model(self, ctx, train_data, validate_data=None, with_weight=False):
def fit_single_model(self, ctx: Context, encryptor, train_data, validate_data=None, with_weight=False):
"""
l(w) = 1/h * Σ(log(2) - 0.5 * y * xw + 0.125 * (wx)^2)
∇l(w) = 1/h * Σ(0.25 * xw - 0.5 * y)x = 1/h * Σdx
@@ -234,11 +236,12 @@ def fit_single_model(self, ctx, train_data, validate_data=None, with_weight=False):

Xw = torch.matmul(X, w)
d = 0.25 * Xw - 0.5 * Y
encryptor.encrypt(d).to(batch_ctx.hosts, "d")
loss = 0.125 / h * torch.matmul(Xw.T, Xw) - 0.5 / h * torch.matmul(Xw.T, Y)

if self.optimizer.l1_penalty or self.optimizer.l2_penalty:
loss_norm = self.optimizer.loss_norm(w)
loss += loss_norm
d = self.optimizer.add_regular_to_grad(d, w)
for Xw_h in batch_ctx.hosts.get("Xw_h"):
d += Xw_h
loss -= 0.5 / h * torch.matmul(Y.T, Xw_h)
@@ -256,9 +259,10 @@ def fit_single_model(self, ctx, train_data, validate_data=None, with_weight=False):
batch_ctx.arbiter.put(loss=loss)

# gradient
g = self.optimizer.add_regular_to_grad(X.T @ d, w, self.init_param.fit_intercept)
g = X.T @ d
batch_ctx.arbiter.put("g_enc", g)
g = batch_ctx.arbiter.get("g")
g = self.optimizer.add_regular_to_grad(g, w, self.init_param.fit_intercept)
# self.optimizer.step(g)
w = self.optimizer.update_weights(w, g, self.init_param.fit_intercept, self.lr_scheduler.lr)

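The docstring formulas in this file come from the second-order Taylor expansion of the logistic loss log(1 + e^(-y·xw)) around xw = 0; the usual derivation takes labels y ∈ {-1, +1}. That approximation is what makes the residual d = 0.25·Xw - 0.5·y linear in w and therefore cheap to handle under PHE. A quick numpy check that the stated gradient matches the stated loss:

```python
import numpy as np

rng = np.random.default_rng(0)
h, p = 6, 3
X = rng.normal(size=(h, p))
y = rng.choice([-1.0, 1.0], size=(h, 1))
w = rng.normal(size=(p, 1)) * 0.1

def approx_loss(w):
    Xw = X @ w
    # l(w) = 1/h * sum(log(2) - 0.5*y*xw + 0.125*(xw)^2), as in the docstring
    return np.mean(np.log(2) - 0.5 * y * Xw + 0.125 * Xw ** 2)

# grad l(w) = 1/h * sum((0.25*xw - 0.5*y) * x) = 1/h * X^T @ d
d = 0.25 * (X @ w) - 0.5 * y
grad = X.T @ d / h

# Finite-difference check of the closed form above
eps, num = 1e-6, np.zeros_like(w)
for k in range(p):
    wp, wm = w.copy(), w.copy()
    wp[k] += eps
    wm[k] -= eps
    num[k] = (approx_loss(wp) - approx_loss(wm)) / (2 * eps)
assert np.allclose(grad, num, atol=1e-5)
```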