Merge pull request #5147 from FederatedAI/dev-2.0.0-beta-improve-cipher-param

improve cipher setup
mgqa34 authored Sep 8, 2023
2 parents f4672be + f575313 commit bbc2a45
Showing 7 changed files with 61 additions and 64 deletions.
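
In short, this commit moves PHE configuration off the individual modules and onto the context: components register `he_param` once via `ctx.cipher.set_phe(...)`, and the ML modules then call `ctx.cipher.phe.setup()` with no arguments instead of threading `he_param` through their constructors. A condensed before/after sketch of that caller-side flow (`ctx` and `he_param` are the objects visible in the component code below; this is an illustration, not code from the commit):

# Before this commit: he_param travelled with every module.
#   module = CoordinatedLRModuleArbiter(..., he_param=he_param)   # component layer
#   kit = ctx.cipher.phe.setup(options=self.he_param)             # inside fit()

# After this commit: the component registers the options on the context once ...
ctx.cipher.set_phe(ctx.device, he_param.dict())
# ... and modules obtain a fully configured cipher without passing any options.
kit = ctx.cipher.phe.setup()
encryptor = kit.get_tensor_encryptor()
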
71 changes: 43 additions & 28 deletions python/fate/arch/context/_cipher.py
@@ -14,42 +14,57 @@
# limitations under the License.

import logging
import typing

from ..unify import device

logger = logging.getLogger(__name__)


class CipherKit:
def __init__(self, device: device, cipher_mapping=None) -> None:
def __init__(self, device: device, cipher_mapping: typing.Optional[dict] = None) -> None:
self._device = device
self._cipher_mapping = cipher_mapping
if cipher_mapping is None:
self._cipher_mapping = {}
else:
self._cipher_mapping = cipher_mapping

@property
def phe(self):
if self._cipher_mapping is None:
if self._device == device.CPU:
return PHECipherBuilder("paillier")
else:
logger.warning(f"no impl exists for device {self._device}, fallback to CPU")
return PHECipherBuilder("paillier")
def set_phe(self, device: device, options: typing.Optional[dict]):
if "phe" not in self._cipher_mapping:
self._cipher_mapping["phe"] = {}
self._cipher_mapping["phe"][device] = options

def _set_default_phe(self):
if "phe" not in self._cipher_mapping:
raise ValueError("phe is not set")
self._cipher_mapping["phe"] = {}
if self._device == device.CPU:
self._cipher_mapping["phe"][device.CPU] = {"kind": "paillier", "key_length": 1024}
else:
logger.warning(f"no impl exists for device {self._device}, fallback to CPU")
self._cipher_mapping["phe"][device.CPU] = self._cipher_mapping["phe"].get(
device.CPU, {"kind": "paillier", "key_length": 1024}
)

@property
def phe(self):
self._set_default_phe()
if self._device not in self._cipher_mapping["phe"]:
raise ValueError(f"phe is not set for device {self._device}")

return PHECipherBuilder(self._cipher_mapping["phe"][self._device])
raise ValueError(f"no impl exists for device {self._device}")
return PHECipherBuilder(**self._cipher_mapping["phe"][self._device])


class PHECipherBuilder:
def __init__(self, kind) -> None:
def __init__(self, kind, key_length) -> None:
self.kind = kind
self.key_length = key_length

def setup(self, options):
kind = options.get("kind", self.kind)
key_size = options.get("key_length", 1024)
def setup(self, options: typing.Optional[dict] = None):
if options is None:
kind = self.kind
key_size = self.key_length
else:
kind = options.get("kind", self.kind)
key_size = options.get("key_length", 1024)

if kind == "paillier":
from fate.arch.protocol.phe.paillier import evaluator, keygen
@@ -81,16 +96,16 @@ def setup(self, options):

class PHECipher:
def __init__(
self,
key_size,
pk,
sk,
evaluator,
coder,
tensor_cipher,
can_support_negative_number,
can_support_squeeze,
can_support_pack,
self,
key_size,
pk,
sk,
evaluator,
coder,
tensor_cipher,
can_support_negative_number,
can_support_squeeze,
can_support_pack,
) -> None:
self._key_size = key_size
self._pk = pk
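
For reference, a minimal standalone sketch of driving the reworked `CipherKit` directly (it assumes the class definitions above plus a FATE 2.x install providing `fate.arch.unify.device` and the Paillier backend; treat it as an approximation rather than code from this repository):

from fate.arch.context._cipher import CipherKit
from fate.arch.unify import device

# Build a kit for CPU and register PHE options once, mirroring what the
# components below now do via ctx.cipher.set_phe(ctx.device, he_param.dict()).
kit = CipherKit(device.CPU)
kit.set_phe(device.CPU, {"kind": "paillier", "key_length": 2048})

# Consumers no longer pass options: the builder already carries kind/key_length.
builder = kit.phe                  # PHECipherBuilder(kind="paillier", key_length=2048)
phe = builder.setup()              # generates a Paillier key pair of the stored length
encryptor = phe.get_tensor_encryptor()

Note that setup() still accepts an options dict; per the code above, an explicit dict overrides the stored kind, but a missing "key_length" in that dict falls back to 1024 rather than to the builder's stored key_length.
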
12 changes: 5 additions & 7 deletions python/fate/components/components/coordinated_linr.py
@@ -75,7 +75,7 @@ def train(
optimizer = optimizer.dict()
learning_rate_scheduler = learning_rate_scheduler.dict()
init_param = init_param.dict()
he_param = he_param.dict()
ctx.cipher.set_phe(ctx.device, he_param.dict())
# temp code end
if role.is_guest:
train_guest(
@@ -91,7 +91,7 @@
)
elif role.is_arbiter:
train_arbiter(ctx, epochs, early_stop, tol, batch_size, optimizer, learning_rate_scheduler,
he_param, output_model, warm_start_model)
output_model, warm_start_model)


@coordinated_linr.predict()
@@ -160,7 +160,7 @@ def cross_validation(
optimizer = optimizer.dict()
learning_rate_scheduler = learning_rate_scheduler.dict()
init_param = init_param.dict()
he_param = he_param.dict()
ctx.cipher.set_phe(ctx.device, he_param.dict())
# temp code end
if role.is_arbiter:
i = 0
@@ -173,7 +173,6 @@
batch_size=batch_size,
optimizer_param=optimizer,
learning_rate_param=learning_rate_scheduler,
he_param=he_param,
)
module.fit(fold_ctx)
i += 1
@@ -301,7 +300,7 @@ def train_host(ctx, train_data, validate_data, train_output_data, output_model,


def train_arbiter(ctx, epochs, early_stop, tol, batch_size, optimizer_param,
learning_rate_param, he_param, output_model, input_model):
learning_rate_param, output_model, input_model):
if input_model is not None:
logger.info(f"warm start model provided")
model = input_model.read()
@@ -310,8 +309,7 @@ def train_arbiter(ctx, epochs, early_stop, tol, batch_size, optimizer_param,
module.set_batch_size(batch_size)
else:
module = CoordinatedLinRModuleArbiter(epochs=epochs, early_stop=early_stop, tol=tol, batch_size=batch_size,
optimizer_param=optimizer_param, learning_rate_param=learning_rate_param,
he_param=he_param)
optimizer_param=optimizer_param, learning_rate_param=learning_rate_param)
logger.info(f"coordinated linr arbiter start train")

sub_ctx = ctx.sub_ctx("train")
13 changes: 4 additions & 9 deletions python/fate/components/components/coordinated_lr.py
@@ -83,8 +83,7 @@ def train(
optimizer = optimizer.dict()
learning_rate_scheduler = learning_rate_scheduler.dict()
init_param = init_param.dict()
he_param = he_param.dict()
# temp code end
ctx.cipher.set_phe(ctx.device, he_param.dict())

if role.is_guest:
train_guest(
@@ -124,7 +123,6 @@ def train(
tol, batch_size,
optimizer,
learning_rate_scheduler,
he_param,
output_model,
warm_start_model)

@@ -195,12 +193,11 @@ def cross_validation(
output_cv_data: cpn.parameter(type=bool, default=True, desc="whether output prediction result per cv fold"),
cv_output_datas: cpn.dataframe_outputs(roles=[GUEST, HOST], optional=True),
):
# temp code start
optimizer = optimizer.dict()
learning_rate_scheduler = learning_rate_scheduler.dict()
init_param = init_param.dict()
he_param = he_param.dict()
# temp code end
ctx.cipher.set_phe(ctx.device, he_param.dict())

if role.is_arbiter:
i = 0
for fold_ctx, _ in ctx.on_cross_validations.ctxs_zip(zip(range(cv_param.n_splits))):
@@ -212,7 +209,6 @@
batch_size=batch_size,
optimizer_param=optimizer,
learning_rate_param=learning_rate_scheduler,
he_param=he_param,
)
module.fit(fold_ctx)
i += 1
@@ -386,7 +382,7 @@ def train_host(
module.predict(sub_ctx, validate_data)


def train_arbiter(ctx, epochs, early_stop, tol, batch_size, optimizer_param, learning_rate_scheduler, he_param,
def train_arbiter(ctx, epochs, early_stop, tol, batch_size, optimizer_param, learning_rate_scheduler,
output_model, input_model):
if input_model is not None:
logger.info(f"warm start model provided")
@@ -402,7 +398,6 @@ def train_arbiter(ctx, epochs, early_stop, tol, batch_size, optimizer_param, lea
batch_size=batch_size,
optimizer_param=optimizer_param,
learning_rate_param=learning_rate_scheduler,
he_param=he_param
)
logger.info(f"coordinated lr arbiter start train")
sub_ctx = ctx.sub_ctx("train")
9 changes: 4 additions & 5 deletions python/fate/components/components/hetero_feature_binning.py
@@ -87,10 +87,10 @@ def feature_binning_train(
train_output_data: cpn.dataframe_output(roles=[GUEST, HOST]),
output_model: cpn.json_model_output(roles=[GUEST, HOST]),
):
he_param = he_param.dict()
ctx.cipher.set_phe(ctx.device, he_param.dict())
train(ctx, train_data, train_output_data, output_model, role, method, n_bins, split_pt_dict,
bin_col, bin_idx, category_col, category_idx, use_anonymous, transform_method,
skip_metrics, local_only, relative_error, adjustment_factor, he_param)
skip_metrics, local_only, relative_error, adjustment_factor)


@hetero_feature_binning.predict()
@@ -114,7 +114,7 @@ def feature_binning_predict(

def train(ctx, train_data, train_output_data, output_model, role, method, n_bins, split_pt_dict,
bin_col, bin_idx, category_col, category_idx, use_anonymous, transform_method,
skip_metrics, local_only, relative_error, adjustment_factor, he_param):
skip_metrics, local_only, relative_error, adjustment_factor):
logger.info(f"start binning train")
sub_ctx = ctx.sub_ctx("train")
train_data = train_data.read()
@@ -130,8 +130,7 @@ def train(ctx, train_data, train_output_data, output_model, role, method, n_bins

if role.is_guest:
binning = HeteroBinningModuleGuest(method, n_bins, split_pt_dict, to_bin_cols, transform_method,
merged_category_col, local_only, relative_error, adjustment_factor,
he_param)
merged_category_col, local_only, relative_error, adjustment_factor)
elif role.is_host:
binning = HeteroBinningModuleHost(method, n_bins, split_pt_dict, to_bin_cols, transform_method,
merged_category_col, local_only, relative_error, adjustment_factor)
6 changes: 1 addition & 5 deletions python/fate/ml/feature_binning/hetero_feature_binning.py
@@ -38,12 +38,10 @@ def __init__(
local_only=False,
error_rate=1e-6,
adjustment_factor=0.5,
he_param=None
):
self.method = method
self.bin_col = bin_col
self.category_col = category_col
self.he_param = he_param
self.n_bins = n_bins
self._federation_bin_obj = None
# param check
@@ -79,7 +77,7 @@ def compute_metrics(self, ctx: Context, binned_data):

def compute_federated_metrics(self, ctx: Context, binned_data):
logger.info(f"Start computing federated metrics.")
kit = ctx.cipher.phe.setup(options=self.he_param)
kit = ctx.cipher.phe.setup()
encryptor = kit.get_tensor_encryptor()
sk, pk, evaluator, coder = kit.sk, kit.pk, kit.evaluator, kit.coder

@@ -119,7 +117,6 @@ def get_model(self):
"category_col": self.category_col,
"model_type": "binning",
"n_bins": self.n_bins,
"he_param": self.he_param,
},
}
return model
@@ -133,7 +130,6 @@ def from_model(cls, model) -> "HeteroBinningModuleGuest":
method=model["meta"]["method"],
bin_col=model["meta"]["bin_col"],
category_col=model["meta"]["category_col"],
he_param=model["meta"]["he_param"],
n_bins=model["meta"]["n_bins"],
)
bin_obj.restore(model["data"])
7 changes: 2 additions & 5 deletions python/fate/ml/glm/hetero/coordinated_linr/arbiter.py
@@ -26,14 +26,13 @@


class CoordinatedLinRModuleArbiter(HeteroModule):
def __init__(self, epochs, early_stop, tol, batch_size, optimizer_param, learning_rate_param, he_param):
def __init__(self, epochs, early_stop, tol, batch_size, optimizer_param, learning_rate_param):
self.epochs = epochs
self.batch_size = batch_size
self.early_stop = early_stop
self.tol = tol
self.learning_rate_param = learning_rate_param
self.optimizer_param = optimizer_param
self.he_param = he_param

self.estimator = None

@@ -46,7 +45,7 @@ def set_epochs(self, epochs):
self.estimator.epochs = epochs

def fit(self, ctx: Context) -> None:
kit = ctx.cipher.phe.setup(options=self.he_param)
kit = ctx.cipher.phe.setup()
encryptor = kit.get_tensor_encryptor()
decryptor = kit.get_tensor_decryptor()
ctx.hosts("encryptor").put(encryptor)
@@ -82,7 +81,6 @@ def get_model(self):
"batch_size": self.batch_size,
"learning_rate_param": self.learning_rate_param,
"optimizer_param": self.optimizer_param,
"he_param": self.he_param
},
}

@@ -95,7 +93,6 @@ def from_model(cls, model):
model["meta"]["batch_size"],
model["meta"]["optimizer_param"],
model["meta"]["learning_rate_param"],
model["meta"]["he_param"],
)
estimator = HeteroLinREstimatorArbiter()
estimator.restore(model["data"]["estimator"])
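
The arbiter modules in this file and the next follow the same pattern: `he_param` disappears from `__init__`, `get_model`, and `from_model`, and the cipher is obtained from the context inside `fit`. A compressed sketch of that module-side shape (the class name is a hypothetical stand-in; `ctx` is the FATE `Context` passed to `fit`, assumed to have been configured via `set_phe` by the calling component):

class ArbiterModuleSketch:  # stand-in for CoordinatedLinRModuleArbiter / CoordinatedLRModuleArbiter
    def fit(self, ctx) -> None:
        # PHE options now live on the context, so no self.he_param is needed.
        kit = ctx.cipher.phe.setup()
        encryptor = kit.get_tensor_encryptor()
        decryptor = kit.get_tensor_decryptor()
        ctx.hosts("encryptor").put(encryptor)  # share the encryptor with hosts, as above
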
7 changes: 2 additions & 5 deletions python/fate/ml/glm/hetero/coordinated_lr/arbiter.py
@@ -26,15 +26,14 @@


class CoordinatedLRModuleArbiter(HeteroModule):
def __init__(self, epochs, early_stop, tol, batch_size, optimizer_param, learning_rate_param, he_param):
def __init__(self, epochs, early_stop, tol, batch_size, optimizer_param, learning_rate_param):
self.epochs = epochs
self.batch_size = batch_size
self.early_stop = early_stop
self.tol = tol
self.learning_rate_param = learning_rate_param
self.optimizer_param = optimizer_param
self.lr_param = learning_rate_param
self.he_param = he_param

self.estimator = None
self.ovr = False
@@ -56,7 +55,7 @@ def set_epochs(self, epochs):
self.estimator.epochs = epochs

def fit(self, ctx: Context) -> None:
kit = ctx.cipher.phe.setup(options=self.he_param)
kit = ctx.cipher.phe.setup()
encryptor = kit.get_tensor_encryptor()
decryptor = kit.get_tensor_decryptor()
ctx.hosts("encryptor").put(encryptor)
@@ -138,7 +137,6 @@ def get_model(self):
"batch_size": self.batch_size,
"learning_rate_param": self.learning_rate_param,
"optimizer_param": self.optimizer_param,
"he_param": self.he_param,
},
}

@@ -151,7 +149,6 @@ def from_model(cls, model) -> "CoordinatedLRModuleArbiter":
batch_size=model["meta"]["batch_size"],
optimizer_param=model["meta"]["optimizer_param"],
learning_rate_param=model["meta"]["learning_rate_param"],
he_param=model["meta"]["he_param"],
)
all_estimator = model["data"]["estimator"]
lr.estimator = {}
