diff --git a/nvflare/app_opt/xgboost/histogram_based_v2/adaptors/xgb_adaptor.py b/nvflare/app_opt/xgboost/histogram_based_v2/adaptors/xgb_adaptor.py index 33af74967d..181144336e 100644 --- a/nvflare/app_opt/xgboost/histogram_based_v2/adaptors/xgb_adaptor.py +++ b/nvflare/app_opt/xgboost/histogram_based_v2/adaptors/xgb_adaptor.py @@ -240,6 +240,14 @@ def _send_request(self, op: str, req: Shareable) -> Tuple[bytes, Shareable]: "seq": req[Constant.PARAM_KEY_SEQ], } fl_ctx.set_prop(key=PROP_KEY_DEBUG_INFO, value=debug_info, private=True, sticky=False) + + send_buf = req[Constant.PARAM_KEY_SEND_BUF] + try: + length = len(send_buf) + except: + length = -1 + + self.log_info(fl_ctx, f"Sending GRPC payload size: {length} Info: {debug_info}") reply = ReliableMessage.send_request( target=FQCN.ROOT_SERVER, topic=Constant.TOPIC_XGB_REQUEST, diff --git a/nvflare/app_opt/xgboost/histogram_based_v2/aggr.py b/nvflare/app_opt/xgboost/histogram_based_v2/aggr.py index 9da4088611..34fe2acd0a 100644 --- a/nvflare/app_opt/xgboost/histogram_based_v2/aggr.py +++ b/nvflare/app_opt/xgboost/histogram_based_v2/aggr.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from multiprocessing.shared_memory import ShareableList class Aggregator: @@ -20,9 +21,12 @@ def __init__(self, initial_value=0): def add(self, a, b): return a + b - def _update_aggregation(self, gh_values, sample_bin_assignment, sample_id, aggr): + def _update_aggregation(self, gh_values, sample_bin_assignment, sample_id, aggr, accessor): bin_id = sample_bin_assignment[sample_id] - sample_value = gh_values[sample_id] + if accessor: + sample_value = accessor(gh_values, sample_id) + else: + sample_value = gh_values[sample_id] current_value = aggr[bin_id] if current_value == 0: # avoid add since sample_value may be cypher-text! @@ -30,12 +34,12 @@ def _update_aggregation(self, gh_values, sample_bin_assignment, sample_id, aggr) else: aggr[bin_id] = self.add(current_value, sample_value) - def aggregate(self, gh_values: list, sample_bin_assignment, num_bins, sample_ids): + def aggregate(self, gh_values: list | ShareableList, sample_bin_assignment, num_bins, sample_ids, accessor=None): aggr_result = [self.initial_value] * num_bins if not sample_ids: for sample_id in range(len(gh_values)): - self._update_aggregation(gh_values, sample_bin_assignment, sample_id, aggr_result) + self._update_aggregation(gh_values, sample_bin_assignment, sample_id, aggr_result, accessor) else: for sample_id in sample_ids: - self._update_aggregation(gh_values, sample_bin_assignment, sample_id, aggr_result) + self._update_aggregation(gh_values, sample_bin_assignment, sample_id, aggr_result, accessor) return aggr_result diff --git a/nvflare/app_opt/xgboost/histogram_based_v2/controller.py b/nvflare/app_opt/xgboost/histogram_based_v2/controller.py index cce4e5c3d3..983b49e433 100644 --- a/nvflare/app_opt/xgboost/histogram_based_v2/controller.py +++ b/nvflare/app_opt/xgboost/histogram_based_v2/controller.py @@ -413,8 +413,12 @@ def _process_xgb_request(self, topic: str, request: Shareable, fl_ctx: FLContext self.log_exception(fl_ctx, f"exception processing {op}: {secure_format_exception(ex)}") self._trigger_stop(fl_ctx, process_error) return make_reply(ReturnCode.EXECUTION_EXCEPTION) - - self.log_info(fl_ctx, f"received reply for '{op}'") + rcv_buf = reply[Constant.PARAM_KEY_RCV_BUF] + try: + length = len(rcv_buf) + except: + length = -1 + self.log_info(fl_ctx, f"received reply for '{op}' GRPC payload size: {length}") reply.set_header(Constant.MSG_KEY_XGB_OP, op) return reply diff --git a/nvflare/app_opt/xgboost/histogram_based_v2/sec/client_handler.py b/nvflare/app_opt/xgboost/histogram_based_v2/sec/client_handler.py index 764b1312a7..b00e0a2e2f 100644 --- a/nvflare/app_opt/xgboost/histogram_based_v2/sec/client_handler.py +++ b/nvflare/app_opt/xgboost/histogram_based_v2/sec/client_handler.py @@ -109,6 +109,7 @@ def _process_before_broadcast(self, fl_ctx: FLContext): # encrypt clear-text gh pairs and send to server self.clear_ghs = [combine(clear_ghs[i][0], clear_ghs[i][1]) for i in range(len(clear_ghs))] t = time.time() + self.info(fl_ctx, f"encrypting {len(self.clear_ghs)} gh pairs") encrypted_values = self.encryptor.encrypt(self.clear_ghs) self.info(fl_ctx, f"encrypted gh pairs: {len(encrypted_values)}, took {time.time() - t} secs") @@ -216,9 +217,12 @@ def _process_before_all_gather_v_vertical(self, fl_ctx: FLContext): self.info( fl_ctx, f"_process_before_all_gather_v: non-label client - do encrypted aggr for {len(groups)} groups" ) + + samples_to_add = sum([len(id_list) for _, id_list in groups]) + self.info(fl_ctx, f"Adding encrypted values for {samples_to_add} samples") start = time.time() aggr_result = self.adder.add(self.encrypted_ghs, self.feature_masks, groups, encode_sum=True) - self.info(fl_ctx, f"got aggr result for {len(aggr_result)} features in {time.time() - start} secs") + self.info(fl_ctx, f"got aggr result for {len(aggr_result)} features, took {time.time() - start} secs") start = time.time() encoded_str = encode_feature_aggregations(aggr_result) self.info(fl_ctx, f"encoded aggr result len {len(encoded_str)} in {time.time() - start} secs") @@ -284,8 +288,9 @@ def _decrypt_aggr_result(self, encoded, fl_ctx: FLContext): t = time.time() aggrs_to_decrypt = [decoded_aggrs[i][2] for i in range(len(decoded_aggrs))] + self.info(fl_ctx, f"decrypting {len(aggrs_to_decrypt)} numbers") decrypted_aggrs = self.decrypter.decrypt(aggrs_to_decrypt) # this is a list of clear-text GH numbers - self.info(fl_ctx, f"decrypted {len(aggrs_to_decrypt)} numbers in {time.time() - t} secs") + self.info(fl_ctx, f"decrypted {len(aggrs_to_decrypt)} numbers, took {time.time() - t} secs") aggr_result = [] for i in range(len(decoded_aggrs)): diff --git a/nvflare/app_opt/xgboost/histogram_based_v2/sec/partial_he/adder.py b/nvflare/app_opt/xgboost/histogram_based_v2/sec/partial_he/adder.py index 0e50b4880f..57b2ff3fcf 100644 --- a/nvflare/app_opt/xgboost/histogram_based_v2/sec/partial_he/adder.py +++ b/nvflare/app_opt/xgboost/histogram_based_v2/sec/partial_he/adder.py @@ -13,10 +13,23 @@ # limitations under the License. import concurrent.futures +from functools import partial +from multiprocessing import shared_memory from nvflare.app_opt.xgboost.histogram_based_v2.aggr import Aggregator -from .util import encode_encrypted_numbers_to_str +from .util import ( + bytes_to_int, + ciphertext_to_int, + encode_encrypted_numbers_to_str, + encrypt_number, + get_exponent, + int_to_bytes, + int_to_ciphertext, +) + +SUFFIX = b"\xff" +SHARED_MEM_NAME = "encrypted_gh" class Adder: @@ -40,40 +53,76 @@ def add(self, encrypted_numbers, features, sample_groups=None, encode_sum=True): samples in the group for the feature. """ + + shared_gh = shared_memory.ShareableList(self._shared_list(encrypted_numbers), name=SHARED_MEM_NAME) items = [] for f in features: fid, mask, num_bins = f if not sample_groups: - items.append((encode_sum, fid, encrypted_numbers, mask, num_bins, 0, None)) + items.append((encode_sum, fid, mask, num_bins, 0, None)) else: for g in sample_groups: gid, sample_id_list = g - items.append((encode_sum, fid, encrypted_numbers, mask, num_bins, gid, sample_id_list)) + items.append((encode_sum, fid, mask, num_bins, gid, sample_id_list)) + pubkey = encrypted_numbers[0].public_key chunk_size = int((len(items) - 1) / self.num_workers) + 1 - results = self.exe.map(_do_add, items, chunksize=chunk_size) + results = self.exe.map(partial(_do_add, shared_gh.shm.name, pubkey), items, chunksize=chunk_size) rl = [] for r in results: rl.append(r) + + shared_gh.shm.close() + shared_gh.shm.unlink() + return rl + def _shared_list(self, encrypted_numbers: list) -> list: + result = [] + for ciphertext in encrypted_numbers: + # Due to a Python bug, a non-zero suffix is needed + # See https://github.com/python/cpython/issues/10693 + result.append(int_to_bytes(ciphertext_to_int(ciphertext)) + SUFFIX) + result.append(int_to_bytes(get_exponent(ciphertext)) + SUFFIX) + + return result + -def _do_add(item): - encode_sum, fid, encrypted_numbers, mask, num_bins, gid, sample_id_list = item +def shared_list_accessor(pubkey, shared_gh, index): + """ + shared_gh contains ciphertext and exponent in bytes so each + encrypted number takes 2 slots + + Due to the ShareableList bug, a non-zero byte is appended to the bytes + """ + n = bytes_to_int(shared_gh[index * 2][:-1]) + exp = bytes_to_int(shared_gh[index * 2 + 1][:-1]) + ciphertext = int_to_ciphertext(n, pubkey=pubkey) + return encrypt_number(pubkey, ciphertext, exp) + + +def _do_add(shared_mem_name, pubkey, item): + + shared_gh = shared_memory.ShareableList(name=shared_mem_name) + encode_sum, fid, mask, num_bins, gid, sample_id_list = item # bins = [0 for _ in range(num_bins)] aggr = Aggregator() bins = aggr.aggregate( - gh_values=encrypted_numbers, + gh_values=shared_gh, sample_bin_assignment=mask, num_bins=num_bins, sample_ids=sample_id_list, + accessor=partial(shared_list_accessor, pubkey), ) if encode_sum: sums = encode_encrypted_numbers_to_str(bins) else: sums = bins + + shared_gh.shm.close() + return fid, gid, sums diff --git a/nvflare/app_opt/xgboost/histogram_based_v2/sec/partial_he/encryptor.py b/nvflare/app_opt/xgboost/histogram_based_v2/sec/partial_he/encryptor.py index 72eb7360ee..4dd5221bc2 100644 --- a/nvflare/app_opt/xgboost/histogram_based_v2/sec/partial_he/encryptor.py +++ b/nvflare/app_opt/xgboost/histogram_based_v2/sec/partial_he/encryptor.py @@ -13,6 +13,7 @@ # limitations under the License. import concurrent.futures +from functools import partial class Encryptor: @@ -28,18 +29,16 @@ def encrypt(self, numbers): numbers: clear text numbers to be encrypted Returns: list of encrypted numbers """ - items = [(self.pubkey, numbers[i]) for i in range(len(numbers))] - chunk_size = int(len(items) / self.max_workers) + chunk_size = int(len(numbers) / self.max_workers) if chunk_size == 0: chunk_size = 1 - results = self.exe.map(_do_enc, items, chunksize=chunk_size) + results = self.exe.map(partial(_do_enc, self.pubkey), numbers, chunksize=chunk_size) rl = [] for r in results: rl.append(r) return rl -def _do_enc(item): - pubkey, num = item - return pubkey.encrypt(num) +def _do_enc(pubkey, item): + return pubkey.encrypt(item) diff --git a/nvflare/app_opt/xgboost/histogram_based_v2/sec/partial_he/util.py b/nvflare/app_opt/xgboost/histogram_based_v2/sec/partial_he/util.py index 25d70cd75c..f4f8cd11c8 100644 --- a/nvflare/app_opt/xgboost/histogram_based_v2/sec/partial_he/util.py +++ b/nvflare/app_opt/xgboost/histogram_based_v2/sec/partial_he/util.py @@ -14,7 +14,6 @@ import json from base64 import urlsafe_b64decode, urlsafe_b64encode -from binascii import hexlify, unhexlify # ipcl_python is not a required dependency. The import error causes unit test failure so make it optional try: @@ -74,14 +73,20 @@ def base64url_decode(payload): return urlsafe_b64decode(payload.encode("utf-8")) +def int_to_bytes(num: int) -> bytes: + return num.to_bytes((max(num.bit_length(), 1) + 7) // 8, "big") + + +def bytes_to_int(buf: bytes) -> int: + return int.from_bytes(buf, "big") + + def base64_to_int(source): - return int(hexlify(base64url_decode(source)), 16) + return bytes_to_int(base64url_decode(source)) def int_to_base64(source): - assert source != 0 - I = hex(source).rstrip("L").lstrip("0x") - return base64url_encode(unhexlify((len(I) % 2) * "0" + I)) + return base64url_encode(int_to_bytes(source)) def combine(g, h):