Commit
lr & linr support asynchronous gradient compute (#4659)
Signed-off-by: Yu Wu <yolandawu131@gmail.com>
Signed-off-by: sagewe <wbwmat@gmail.com>
nemirorox authored and sagewe committed Jul 21, 2023
1 parent 834d52d commit e5717cc
Showing 6 changed files with 252 additions and 18 deletions.
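
In outline, the change has three coordinated parts: the arbiter now puts the PHE encryptor to the guest as well as to the hosts; the guest and host estimators gain a pair of per-batch helpers, asynchronous_compute_gradient and centralized_compute_gradient; and the training loops choose between them with is_centralized = len(ctx.hosts) > 1, so the new asynchronous path only runs in the single-host setting. The identity that path relies on, read off the new helpers (notation inferred from the diff, not taken from FATE documentation; sample weights and regularization terms omitted), is that the joint residual splits into a guest half and a host half, and each party's gradient needs only the other party's encrypted half:

LinR residual:        d = X_g w_g + X_h w_h - y
LR (Taylor) residual: d = 0.25 * (X_g w_g + X_h w_h) - 0.5 * y
guest half (half_d):  X_g w_g - y   (LinR)       0.25 * X_g w_g - 0.5 * y   (LR)
host half (Xw_h):     X_h w_h       (LinR)       the matching scaled host term (LR)
guest gradient:       g = 1/h * X_g^T d = 1/h * (X_g^T half_d + X_g^T Xw_h)
host gradient:        g = 1/h * X_h^T d = 1/h * (X_h^T Xw_h + X_h^T half_d)

Instead of the guest reconstructing the full residual d and shipping it back to the host (the centralized path), each side sends its own encrypted half once and finishes its gradient locally, which appears to be what "asynchronous" refers to here.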
7 changes: 4 additions & 3 deletions python/fate/ml/glm/hetero/coordinated_linr/arbiter.py
@@ -56,6 +56,7 @@ def set_epochs(self, epochs):
def fit(self, ctx: Context) -> None:
encryptor, decryptor = ctx.cipher.phe.keygen(options=dict(key_length=2048))
ctx.hosts("encryptor").put(encryptor)
ctx.guest("encryptor").put(encryptor)
if self.estimator is None:
optimizer = Optimizer(
self.optimizer_param["method"],
@@ -65,7 +66,7 @@ def fit(self, ctx: Context) -> None:
)
lr_scheduler = LRScheduler(self.learning_rate_param["method"],
self.learning_rate_param["scheduler_params"])
single_estimator = HeteroLinrEstimatorArbiter(epochs=self.epochs,
single_estimator = HeteroLinREstimatorArbiter(epochs=self.epochs,
early_stop=self.early_stop,
tol=self.tol,
batch_size=self.batch_size,
@@ -93,13 +94,13 @@ def from_model(cls, model):
model["meta"]["batch_size"],
model["meta"]["optimizer_param"],
model["meta"]["learning_rate_param"])
estimator = HeteroLinrEstimatorArbiter()
estimator = HeteroLinREstimatorArbiter()
estimator.restore(model["data"]["estimator"])
linr.estimator = estimator
return linr


class HeteroLinrEstimatorArbiter(HeteroModule):
class HeteroLinREstimatorArbiter(HeteroModule):
def __init__(
self,
epochs=None,
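
For the arbiter the diff is a single added line: the encryptor from ctx.cipher.phe.keygen is now also put to the guest. Until now only the hosts needed the public key (they encrypt Xw_h, Xw2_h and their regularization loss); in the asynchronous path below, the guest likewise sends an encrypted intermediate, its half-residual half_d, so it picks the key up in guest.py via encryptor = ctx.arbiter("encryptor").get() and threads it into fit_model. As a minimal, self-contained illustration of the additively homomorphic pattern this enables, here is a toy exchange built on the python-paillier package; this is an assumption for illustration only and not FATE's PHE layer (FATE's operates on tensors, and a single scalar stands in for one entry of half_d):

from phe import paillier

# "arbiter": generate a key pair and share only the public part with guest and host
public_key, private_key = paillier.generate_paillier_keypair(n_length=2048)

# "guest": encrypt one entry of its half-residual with the shared public key
enc_half_d = public_key.encrypt(0.7)

# "host": add its own plaintext partial prediction to the ciphertext; the
# additive homomorphism keeps the combined residual encrypted end to end
enc_residual = enc_half_d + (-0.2)

# "arbiter": the only party holding the private key can decrypt the result
assert abs(private_key.decrypt(enc_residual) - 0.5) < 1e-9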
77 changes: 73 additions & 4 deletions python/fate/ml/glm/hetero/coordinated_linr/guest.py
@@ -67,7 +67,8 @@ def fit(self, ctx: Context, train_data, validate_data=None) -> None:
learning_rate_scheduler=lr_scheduler,
init_param=self.init_param)
self.estimator = estimator
self.estimator.fit_model(ctx, train_data, validate_data)
encryptor = ctx.arbiter("encryptor").get()
self.estimator.fit_model(ctx, encryptor, train_data, validate_data)

def predict(self, ctx, test_data):
prob = self.estimator.predict(ctx, test_data)
@@ -117,7 +118,70 @@ def __init__(
self.end_epoch = -1
self.is_converged = False

def fit_model(self, ctx, train_data, validate_data=None):
def asynchronous_compute_gradient(self, batch_ctx, encryptor, w, X, Y, weight):
h = X.shape[0]
Xw = torch.matmul(X, w.detach())
half_d = Xw - Y
if weight:
half_d = half_d * weight
batch_ctx.hosts.put("half_d", encryptor.encrypt(half_d))
half_g = torch.matmul(X.T, half_d)

Xw_h = batch_ctx.hosts.get("Xw_h")[0]
if weight:
Xw_h = Xw_h * weight
host_half_g = torch.matmul(X.T, Xw_h)

loss = 0.5 / h * torch.matmul(half_d.T, half_d)
if self.optimizer.l1_penalty or self.optimizer.l2_penalty:
loss_norm = self.optimizer.loss_norm(w)
loss += loss_norm

for Xw2_h in batch_ctx.hosts.get("Xw2_h"):
loss += 0.5 / h * Xw2_h
h_loss_list = batch_ctx.hosts.get("h_loss")
for h_loss in h_loss_list:
if h_loss is not None:
loss += h_loss

batch_ctx.arbiter.put(loss=loss)

# gradient
g = 1 / h * (half_g + host_half_g)
return g

def centralized_compute_gradient(self, batch_ctx, w, X, Y, weight):
h = X.shape[0]
Xw = torch.matmul(X, w.detach())
d = Xw - Y
loss = 0.5 / h * torch.matmul(d.T, d)
if self.optimizer.l1_penalty or self.optimizer.l2_penalty:
loss_norm = self.optimizer.loss_norm(w)
loss += loss_norm
Xw_h_all = batch_ctx.hosts.get("Xw_h")
for Xw_h in Xw_h_all:
d += Xw_h
loss += 1 / h * torch.matmul(Xw.T, Xw_h)

if weight:
d = d * weight
batch_ctx.hosts.put(d=d)

for Xw2_h in batch_ctx.hosts.get("Xw2_h"):
loss += 0.5 / h * Xw2_h
h_loss_list = batch_ctx.hosts.get("h_loss")
for h_loss in h_loss_list:
if h_loss is not None:
loss += h_loss

if len(Xw_h_all) == 1:
batch_ctx.arbiter.put(loss=loss)

# gradient
g = 1 / h * torch.matmul(X.T, d)
return g

def fit_model(self, ctx, encryptor, train_data, validate_data=None):
coef_count = train_data.shape[1]
logger.debug(f"init param: {self.init_param}")
if self.init_param.get("fit_intercept"):
@@ -133,6 +197,7 @@ def fit_model(self, ctx, train_data, validate_data=None):
)
# if self.end_epoch >= 0:
# self.start_epoch = self.end_epoch + 1
is_centralized = len(ctx.hosts) > 1

for i, iter_ctx in ctx.on_iterations.ctxs_range(self.epochs):
self.optimizer.set_iters(i)
@@ -141,7 +206,11 @@ def fit_model(self, ctx, train_data, validate_data=None):
X = batch_data.x
Y = batch_data.label
weight = batch_data.weight
h = X.shape[0]
if is_centralized:
g = self.centralized_compute_gradient(batch_ctx, w, X, Y, weight)
else:
g = self.asynchronous_compute_gradient(batch_ctx, encryptor, w, X, Y, weight)
"""h = X.shape[0]
Xw = torch.matmul(X, w.detach())
d = Xw - Y
loss = 0.5 / h * torch.matmul(d.T, d)
@@ -168,7 +237,7 @@ def fit_model(self, ctx, train_data, validate_data=None):
batch_ctx.arbiter.put(loss=loss)
# gradient
g = 1 / h * torch.matmul(X.T, d)
g = 1 / h * torch.matmul(X.T, d)"""
g = self.optimizer.add_regular_to_grad(g, w, self.init_param.get("fit_intercept"))
batch_ctx.arbiter.put("g_enc", g)
g = batch_ctx.arbiter.get("g")
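
Both new guest helpers produce the same gradient; they differ in what crosses the wire. The centralized path pulls every host's Xw_h, folds them into the full residual d and puts d back to the hosts, and only reports the assembled loss to the arbiter when there is exactly one host (len(Xw_h_all) == 1); the asynchronous path, taken when is_centralized is false (a single host), instead pushes the guest's encrypted half_d, pulls the host's Xw_h, always reports the loss, and finishes the gradient locally. The old inline version of this logic is kept above as a commented-out string block. A small, self-contained torch sketch of the gradient equivalence for LinR, with plaintext tensors standing in for ciphertexts and hypothetical toy data (not FATE code):

import torch

torch.manual_seed(0)
h = 8                                              # batch size
X_g, X_h = torch.randn(h, 3), torch.randn(h, 2)    # guest / host features
w_g, w_h = torch.randn(3, 1), torch.randn(2, 1)    # guest / host weights
y = torch.randn(h, 1)

# the two quantities exchanged per batch (encrypted in the real protocol)
half_d = X_g @ w_g - y        # guest's half-residual
Xw_h = X_h @ w_h              # host's partial prediction

# asynchronous path: each party combines its own half with the other's half
g_guest = (X_g.T @ half_d + X_g.T @ Xw_h) / h
g_host = (X_h.T @ Xw_h + X_h.T @ half_d) / h

# centralized path: the guest assembles the full residual and ships it back
d = half_d + Xw_h
assert torch.allclose(g_guest, X_g.T @ d / h)
assert torch.allclose(g_host, X_h.T @ d / h)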
43 changes: 41 additions & 2 deletions python/fate/ml/glm/hetero/coordinated_linr/host.py
@@ -119,6 +119,40 @@ def __init__(
self.end_epoch = -1
self.is_converged = False

def asynchronous_compute_gradient(self, batch_ctx, encryptor, w, X):
h = X.shape[0]
Xw_h = torch.matmul(X, w.detach())
batch_ctx.guest.put("Xw_h", encryptor.encrypt(Xw_h))
half_g = torch.matmul(X.T, Xw_h)
guest_half_d = batch_ctx.guest.get("half_d")
guest_half_g = torch.matmul(X.T, guest_half_d)

batch_ctx.guest.put("Xw2_h", encryptor.encrypt(torch.matmul(Xw_h.T, Xw_h)))
loss_norm = self.optimizer.loss_norm(w)
if loss_norm is not None:
batch_ctx.guest.put("h_loss", encryptor.encrypt(loss_norm))
else:
batch_ctx.guest.put(h_loss=loss_norm)

g = 1 / h * (half_g + guest_half_g)
return g

def centralized_compute_gradient(self, batch_ctx, encryptor, w, X):
h = X.shape[0]
Xw_h = torch.matmul(X, w.detach())
batch_ctx.guest.put("Xw_h", encryptor.encrypt(Xw_h))
batch_ctx.guest.put("Xw2_h", encryptor.encrypt(torch.matmul(Xw_h.T, Xw_h)))

loss_norm = self.optimizer.loss_norm(w)
if loss_norm is not None:
batch_ctx.guest.put("h_loss", encryptor.encrypt(loss_norm))
else:
batch_ctx.guest.put(h_loss=loss_norm)

d = batch_ctx.guest.get("d")
g = 1 / h * torch.matmul(X.T, d)
return g

def fit_model(self, ctx: Context, encryptor, train_data, validate_data=None) -> None:
batch_loader = DataLoader(train_data, ctx=ctx, batch_size=self.batch_size, mode="hetero", role="host")

@@ -130,12 +164,13 @@ def fit_model(self, ctx: Context, encryptor, train_data, validate_data=None) ->
self.lr_scheduler.init_scheduler(optimizer=self.optimizer.optimizer)
# if self.end_epoch >= 0:
# self.start_epoch = self.end_epoch + 1
is_centralized = len(ctx.hosts) > 1
for i, iter_ctx in ctx.on_iterations.ctxs_range(self.epochs):
self.optimizer.set_iters(i)
logger.info(f"self.optimizer set epoch {i}")
for batch_ctx, batch_data in iter_ctx.on_batches.ctxs_zip(batch_loader):
X = batch_data.x
h = X.shape[0]
"""h = X.shape[0]
Xw_h = torch.matmul(X, w.detach())
batch_ctx.guest.put("Xw_h", encryptor.encrypt(Xw_h))
batch_ctx.guest.put("Xw2_h", encryptor.encrypt(torch.matmul(Xw_h.T, Xw_h)))
@@ -147,7 +182,11 @@ def fit_model(self, ctx: Context, encryptor, train_data, validate_data=None) ->
batch_ctx.guest.put(h_loss=loss_norm)
d = batch_ctx.guest.get("d")
g = 1 / h * torch.matmul(X.T, d)
g = 1 / h * torch.matmul(X.T, d)"""
if is_centralized:
g = self.centralized_compute_gradient(batch_ctx, encryptor, w, X)
else:
g = self.asynchronous_compute_gradient(batch_ctx, encryptor, w, X)
g = self.optimizer.add_regular_to_grad(g, w, False)
batch_ctx.arbiter.put("g_enc", g)
g = batch_ctx.arbiter.get("g")
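
The host-side helpers mirror the guest's. In both paths the host still puts its encrypted Xw_h, the encrypted squared term Xw2_h = Xw_h^T Xw_h and, when it carries a regularization penalty, an encrypted h_loss, so the guest can fold the host's contribution into the loss it reports to the arbiter; what changes is how the host obtains its gradient. Restated in the notation above (a paraphrase of the new code, not an official derivation):

centralized:   d received from the guest        g = 1/h * X_h^T d
asynchronous:  half_d received from the guest   g = 1/h * (X_h^T Xw_h + X_h^T half_d)
                                                  = 1/h * X_h^T (X_h w_h + X_g w_g - y)
                                                  = 1/h * X_h^T d

So the asynchronous host no longer waits for the guest to assemble and return the full residual; in both variants it still puts the encrypted gradient to the arbiter as g_enc and gets g back from the arbiter.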
1 change: 1 addition & 0 deletions python/fate/ml/glm/hetero/coordinated_lr/arbiter.py
@@ -57,6 +57,7 @@ def set_epochs(self, epochs):
def fit(self, ctx: Context) -> None:
encryptor, decryptor = ctx.cipher.phe.keygen(options=dict(key_length=2048))
ctx.hosts("encryptor").put(encryptor)
ctx.guest("encryptor").put(encryptor)
label_count = ctx.guest("label_count").get()
if label_count > 2 or self.ovr:
self.ovr = True
98 changes: 90 additions & 8 deletions python/fate/ml/glm/hetero/coordinated_lr/guest.py
@@ -69,6 +69,7 @@ def fit(self, ctx: Context, train_data, validate_data=None) -> None:
label_count = train_data_binarized_label.shape[1]
ctx.arbiter.put("label_count", label_count)
ctx.hosts.put("label_count", label_count)
encryptor = ctx.arbiter("encryptor").get()
labels = [label_name.split("_")[1] for label_name in train_data_binarized_label.columns]
if self.labels is None:
self.labels = labels
@@ -110,7 +111,7 @@ def fit(self, ctx: Context, train_data, validate_data=None) -> None:
if validate_data:
class_validate_data = validate_data.copy()
class_train_data.label = train_data_binarized_label[train_data_binarized_label.columns[i]]
single_estimator.fit_single_model(class_ctx, class_train_data, class_validate_data)
single_estimator.fit_single_model(class_ctx, encryptor, class_train_data, class_validate_data)
self.estimator[i] = single_estimator

else:
@@ -135,7 +136,7 @@ def fit(self, ctx: Context, train_data, validate_data=None) -> None:
single_estimator = self.estimator
single_estimator.epochs = self.epochs
single_estimator.batch_size = self.batch_size
single_estimator.fit_single_model(ctx, train_data, validate_data)
single_estimator.fit_single_model(ctx, encryptor, train_data, validate_data)
self.estimator = single_estimator
train_data.label = original_label

@@ -223,7 +224,79 @@ def __init__(self, epochs=None, batch_size=None, optimizer=None, learning_rate_s
self.end_epoch = -1
self.is_converged = False

def fit_single_model(self, ctx: Context, train_data, validate_data=None):
def asynchronous_compute_gradient(self, batch_ctx, encryptor, w, X, Y, weight):
h = X.shape[0]
# logger.info(f"h: {h}")
Xw = torch.matmul(X, w.detach())
half_d = 0.25 * Xw - 0.5 * Y
if weight:
half_d = half_d * weight
batch_ctx.hosts.put("half_d", encryptor.encrypt(half_d))
half_g = torch.matmul(X.T, half_d)

Xw_h = batch_ctx.hosts.get("Xw_h")[0]
if weight:
Xw_h = Xw_h * weight
host_half_g = torch.matmul(X.T, Xw_h)

loss = 0.125 / h * torch.matmul(Xw.T, Xw) - 0.5 / h * torch.matmul(Xw.T, Y)

if self.optimizer.l1_penalty or self.optimizer.l2_penalty:
loss_norm = self.optimizer.loss_norm(w)
loss += loss_norm

loss += torch.matmul((0.25 / h * Xw - 0.5 / h * Y).T, Xw_h)

for Xw2_h in batch_ctx.hosts.get("Xw2_h"):
loss += 0.125 / h * Xw2_h
h_loss_list = batch_ctx.hosts.get("h_loss")
for h_loss in h_loss_list:
if h_loss is not None:
loss += h_loss

batch_ctx.arbiter.put(loss=loss)
# gradient
g = 1 / h * (half_g + host_half_g)
return g

def centralized_compute_gradient(self, batch_ctx, w, X, Y, weight):
h = X.shape[0]
# logger.info(f"h: {h}")
Xw = torch.matmul(X, w.detach())
d = 0.25 * Xw - 0.5 * Y
loss = 0.125 / h * torch.matmul(Xw.T, Xw) - 0.5 / h * torch.matmul(Xw.T, Y)

if self.optimizer.l1_penalty or self.optimizer.l2_penalty:
loss_norm = self.optimizer.loss_norm(w)
loss += loss_norm

Xw_h_all = batch_ctx.hosts.get("Xw_h")

for Xw_h in Xw_h_all:
d += Xw_h
"""loss -= 0.5 / h * torch.matmul(Y.T, Xw_h)
loss += 0.25 / h * torch.matmul(Xw.T, Xw_h)"""
loss += torch.matmul((0.25 / h * Xw - 0.5 / h * Y).T, Xw_h)
if weight:
# logger.info(f"weight: {weight.tolist()}")
d = d * weight
batch_ctx.hosts.put("d", d)

for Xw2_h in batch_ctx.hosts.get("Xw2_h"):
loss += 0.125 / h * Xw2_h
h_loss_list = batch_ctx.hosts.get("h_loss")
for h_loss in h_loss_list:
if h_loss is not None:
loss += h_loss

if len(Xw_h_all) == 1:
batch_ctx.arbiter.put(loss=loss)

# gradient
g = 1 / h * torch.matmul(X.T, d)
return g

def fit_single_model(self, ctx: Context, encryptor, train_data, validate_data=None):
"""
l(w) = 1/h * Σ(log(2) - 0.5 * y * xw + 0.125 * (wx)^2)
∇l(w) = 1/h * Σ(0.25 * xw - 0.5 * y)x = 1/h * Σdx
@@ -250,14 +323,21 @@ def fit_single_model(self, ctx: Context, train_data, validate_data=None):
# if self.end_epoch >= 0:
# self.start_epoch = self.end_epoch + 1

is_centralized = len(ctx.hosts) > 1

for i, iter_ctx in ctx.on_iterations.ctxs_range(self.epochs):
self.optimizer.set_iters(i)
logger.info(f"self.optimizer set epoch {i}")
for batch_ctx, batch_data in iter_ctx.on_batches.ctxs_zip(batch_loader):
X = batch_data.x
Y = batch_data.label
weight = batch_data.weight
h = X.shape[0]
if is_centralized:
g = self.centralized_compute_gradient(batch_ctx, w, X, Y, weight)
else:
g = self.asynchronous_compute_gradient(batch_ctx, encryptor, w, X, Y, weight)

"""h = X.shape[0]
# logger.info(f"h: {h}")
Xw = torch.matmul(X, w.detach())
d = 0.25 * Xw - 0.5 * Y
@@ -268,14 +348,16 @@ def fit_single_model(self, ctx: Context, train_data, validate_data=None):
loss += loss_norm
Xw_h_all = batch_ctx.hosts.get("Xw_h")
for Xw_h in Xw_h_all:
d += Xw_h
loss -= 0.5 / h * torch.matmul(Y.T, Xw_h)
loss += 0.25 / h * torch.matmul(Xw.T, Xw_h)
#loss -= 0.5 / h * torch.matmul(Y.T, Xw_h)
# loss += 0.25 / h * torch.matmul(Xw.T, Xw_h)
loss += torch.matmul((0.25 / h * Xw - 0.5 / h * Y).T, Xw_h)
if weight:
# logger.info(f"weight: {weight.tolist()}")
d = d * weight
batch_ctx.hosts.put(d=d)
batch_ctx.hosts.put("d", d)
for Xw2_h in batch_ctx.hosts.get("Xw2_h"):
loss += 0.125 / h * Xw2_h
@@ -288,7 +370,7 @@ def fit_single_model(self, ctx: Context, train_data, validate_data=None):
batch_ctx.arbiter.put(loss=loss)
# gradient
g = 1 / h * torch.matmul(X.T, d)
g = 1 / h * torch.matmul(X.T, d)"""
g = self.optimizer.add_regular_to_grad(g, w, self.init_param.get("fit_intercept"))
batch_ctx.arbiter.put("g_enc", g)
g = batch_ctx.arbiter.get("g")
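
The LR estimator gets the same restructuring, with the residual replaced by the quadratic (Taylor) approximation of the logistic loss that the new docstring spells out: l(w) = 1/h * Σ(log(2) - 0.5 * y * xw + 0.125 * (wx)^2) and ∇l(w) = 1/h * Σ(0.25 * xw - 0.5 * y)x. A quick, self-contained autograd check of those two lines on toy data (one party, no encryption; the ±1 labels are only for illustration, the identity holds for any y):

import torch

torch.manual_seed(0)
h, p = 16, 4
X = torch.randn(h, p)
y = (torch.rand(h, 1) > 0.5).float() * 2 - 1   # labels in {-1, +1}, illustration only
w = torch.randn(p, 1, requires_grad=True)

Xw = X @ w
# l(w) = 1/h * sum(log(2) - 0.5 * y * xw + 0.125 * (xw)^2), as in the docstring
loss = (torch.log(torch.tensor(2.0)) - 0.5 * y * Xw + 0.125 * Xw ** 2).mean()
loss.backward()

# closed form used by both gradient helpers: 1/h * X^T (0.25 * Xw - 0.5 * y)
d = 0.25 * Xw.detach() - 0.5 * y
g = X.T @ d / h
assert torch.allclose(w.grad, g, atol=1e-6)

In the asynchronous LR path the guest's share of that d is half_d = 0.25 * Xw - 0.5 * Y, sent to the host under encryption, while the host's share arrives as Xw_h.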
(The diff for the sixth changed file did not load on this page.)