Skip to content

Commit

Permalink
Optimized XGB using SharableList
Browse files Browse the repository at this point in the history
  • Loading branch information
nvidianz committed Dec 5, 2024
1 parent 245f0f5 commit 63827a1
Show file tree
Hide file tree
Showing 7 changed files with 101 additions and 27 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -240,6 +240,14 @@ def _send_request(self, op: str, req: Shareable) -> Tuple[bytes, Shareable]:
"seq": req[Constant.PARAM_KEY_SEQ],
}
fl_ctx.set_prop(key=PROP_KEY_DEBUG_INFO, value=debug_info, private=True, sticky=False)

send_buf = req[Constant.PARAM_KEY_SEND_BUF]
try:
length = len(send_buf)
except:
length = -1

self.log_info(fl_ctx, f"Sending GRPC payload size: {length} Info: {debug_info}")
reply = ReliableMessage.send_request(
target=FQCN.ROOT_SERVER,
topic=Constant.TOPIC_XGB_REQUEST,
Expand Down
14 changes: 9 additions & 5 deletions nvflare/app_opt/xgboost/histogram_based_v2/aggr.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from multiprocessing.shared_memory import ShareableList


class Aggregator:
Expand All @@ -20,22 +21,25 @@ def __init__(self, initial_value=0):
def add(self, a, b):
return a + b

def _update_aggregation(self, gh_values, sample_bin_assignment, sample_id, aggr):
def _update_aggregation(self, gh_values, sample_bin_assignment, sample_id, aggr, accessor):
bin_id = sample_bin_assignment[sample_id]
sample_value = gh_values[sample_id]
if accessor:
sample_value = accessor(gh_values, sample_id)
else:
sample_value = gh_values[sample_id]
current_value = aggr[bin_id]
if current_value == 0:
# avoid add since sample_value may be cypher-text!
aggr[bin_id] = sample_value
else:
aggr[bin_id] = self.add(current_value, sample_value)

def aggregate(self, gh_values: list, sample_bin_assignment, num_bins, sample_ids):
def aggregate(self, gh_values: list | ShareableList, sample_bin_assignment, num_bins, sample_ids, accessor=None):
aggr_result = [self.initial_value] * num_bins
if not sample_ids:
for sample_id in range(len(gh_values)):
self._update_aggregation(gh_values, sample_bin_assignment, sample_id, aggr_result)
self._update_aggregation(gh_values, sample_bin_assignment, sample_id, aggr_result, accessor)
else:
for sample_id in sample_ids:
self._update_aggregation(gh_values, sample_bin_assignment, sample_id, aggr_result)
self._update_aggregation(gh_values, sample_bin_assignment, sample_id, aggr_result, accessor)
return aggr_result
8 changes: 6 additions & 2 deletions nvflare/app_opt/xgboost/histogram_based_v2/controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -413,8 +413,12 @@ def _process_xgb_request(self, topic: str, request: Shareable, fl_ctx: FLContext
self.log_exception(fl_ctx, f"exception processing {op}: {secure_format_exception(ex)}")
self._trigger_stop(fl_ctx, process_error)
return make_reply(ReturnCode.EXECUTION_EXCEPTION)

self.log_info(fl_ctx, f"received reply for '{op}'")
rcv_buf = reply[Constant.PARAM_KEY_RCV_BUF]
try:
length = len(rcv_buf)
except:
length = -1
self.log_info(fl_ctx, f"received reply for '{op}' GRPC payload size: {length}")
reply.set_header(Constant.MSG_KEY_XGB_OP, op)
return reply

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,7 @@ def _process_before_broadcast(self, fl_ctx: FLContext):
# encrypt clear-text gh pairs and send to server
self.clear_ghs = [combine(clear_ghs[i][0], clear_ghs[i][1]) for i in range(len(clear_ghs))]
t = time.time()
self.info(fl_ctx, f"encrypting {len(self.clear_ghs)} gh pairs")
encrypted_values = self.encryptor.encrypt(self.clear_ghs)
self.info(fl_ctx, f"encrypted gh pairs: {len(encrypted_values)}, took {time.time() - t} secs")

Expand Down Expand Up @@ -216,9 +217,12 @@ def _process_before_all_gather_v_vertical(self, fl_ctx: FLContext):
self.info(
fl_ctx, f"_process_before_all_gather_v: non-label client - do encrypted aggr for {len(groups)} groups"
)

samples_to_add = sum([len(id_list) for _, id_list in groups])
self.info(fl_ctx, f"Adding encrypted values for {samples_to_add} samples")
start = time.time()
aggr_result = self.adder.add(self.encrypted_ghs, self.feature_masks, groups, encode_sum=True)
self.info(fl_ctx, f"got aggr result for {len(aggr_result)} features in {time.time() - start} secs")
self.info(fl_ctx, f"got aggr result for {len(aggr_result)} features, took {time.time() - start} secs")
start = time.time()
encoded_str = encode_feature_aggregations(aggr_result)
self.info(fl_ctx, f"encoded aggr result len {len(encoded_str)} in {time.time() - start} secs")
Expand Down Expand Up @@ -284,8 +288,9 @@ def _decrypt_aggr_result(self, encoded, fl_ctx: FLContext):

t = time.time()
aggrs_to_decrypt = [decoded_aggrs[i][2] for i in range(len(decoded_aggrs))]
self.info(fl_ctx, f"decrypting {len(aggrs_to_decrypt)} numbers")
decrypted_aggrs = self.decrypter.decrypt(aggrs_to_decrypt) # this is a list of clear-text GH numbers
self.info(fl_ctx, f"decrypted {len(aggrs_to_decrypt)} numbers in {time.time() - t} secs")
self.info(fl_ctx, f"decrypted {len(aggrs_to_decrypt)} numbers, took {time.time() - t} secs")

aggr_result = []
for i in range(len(decoded_aggrs)):
Expand Down
63 changes: 56 additions & 7 deletions nvflare/app_opt/xgboost/histogram_based_v2/sec/partial_he/adder.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,23 @@
# limitations under the License.

import concurrent.futures
from functools import partial
from multiprocessing import shared_memory

from nvflare.app_opt.xgboost.histogram_based_v2.aggr import Aggregator

from .util import encode_encrypted_numbers_to_str
from .util import (
bytes_to_int,
ciphertext_to_int,
encode_encrypted_numbers_to_str,
encrypt_number,
get_exponent,
int_to_bytes,
int_to_ciphertext,
)

SUFFIX = b"\xff"
SHARED_MEM_NAME = "encrypted_gh"


class Adder:
Expand All @@ -40,40 +53,76 @@ def add(self, encrypted_numbers, features, sample_groups=None, encode_sum=True):
samples in the group for the feature.
"""

shared_gh = shared_memory.ShareableList(self._shared_list(encrypted_numbers), name=SHARED_MEM_NAME)
items = []

for f in features:
fid, mask, num_bins = f
if not sample_groups:
items.append((encode_sum, fid, encrypted_numbers, mask, num_bins, 0, None))
items.append((encode_sum, fid, mask, num_bins, 0, None))
else:
for g in sample_groups:
gid, sample_id_list = g
items.append((encode_sum, fid, encrypted_numbers, mask, num_bins, gid, sample_id_list))
items.append((encode_sum, fid, mask, num_bins, gid, sample_id_list))

pubkey = encrypted_numbers[0].public_key
chunk_size = int((len(items) - 1) / self.num_workers) + 1

results = self.exe.map(_do_add, items, chunksize=chunk_size)
results = self.exe.map(partial(_do_add, shared_gh.shm.name, pubkey), items, chunksize=chunk_size)
rl = []
for r in results:
rl.append(r)

shared_gh.shm.close()
shared_gh.shm.unlink()

return rl

def _shared_list(self, encrypted_numbers: list) -> list:
result = []
for ciphertext in encrypted_numbers:
# Due to a Python bug, a non-zero suffix is needed
# See https://github.com/python/cpython/issues/10693
result.append(int_to_bytes(ciphertext_to_int(ciphertext)) + SUFFIX)
result.append(int_to_bytes(get_exponent(ciphertext)) + SUFFIX)

return result


def _do_add(item):
encode_sum, fid, encrypted_numbers, mask, num_bins, gid, sample_id_list = item
def shared_list_accessor(pubkey, shared_gh, index):
"""
shared_gh contains ciphertext and exponent in bytes so each
encrypted number takes 2 slots
Due to the ShareableList bug, a non-zero byte is appended to the bytes
"""
n = bytes_to_int(shared_gh[index * 2][:-1])
exp = bytes_to_int(shared_gh[index * 2 + 1][:-1])
ciphertext = int_to_ciphertext(n, pubkey=pubkey)
return encrypt_number(pubkey, ciphertext, exp)


def _do_add(shared_mem_name, pubkey, item):

shared_gh = shared_memory.ShareableList(name=shared_mem_name)
encode_sum, fid, mask, num_bins, gid, sample_id_list = item
# bins = [0 for _ in range(num_bins)]
aggr = Aggregator()

bins = aggr.aggregate(
gh_values=encrypted_numbers,
gh_values=shared_gh,
sample_bin_assignment=mask,
num_bins=num_bins,
sample_ids=sample_id_list,
accessor=partial(shared_list_accessor, pubkey),
)

if encode_sum:
sums = encode_encrypted_numbers_to_str(bins)
else:
sums = bins

shared_gh.shm.close()

return fid, gid, sums
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
# limitations under the License.

import concurrent.futures
from functools import partial


class Encryptor:
Expand All @@ -28,18 +29,16 @@ def encrypt(self, numbers):
numbers: clear text numbers to be encrypted
Returns: list of encrypted numbers
"""
items = [(self.pubkey, numbers[i]) for i in range(len(numbers))]
chunk_size = int(len(items) / self.max_workers)
chunk_size = int(len(numbers) / self.max_workers)
if chunk_size == 0:
chunk_size = 1

results = self.exe.map(_do_enc, items, chunksize=chunk_size)
results = self.exe.map(partial(_do_enc, self.pubkey), numbers, chunksize=chunk_size)
rl = []
for r in results:
rl.append(r)
return rl


def _do_enc(item):
pubkey, num = item
return pubkey.encrypt(num)
def _do_enc(pubkey, item):
return pubkey.encrypt(item)
15 changes: 10 additions & 5 deletions nvflare/app_opt/xgboost/histogram_based_v2/sec/partial_he/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@

import json
from base64 import urlsafe_b64decode, urlsafe_b64encode
from binascii import hexlify, unhexlify

# ipcl_python is not a required dependency. The import error causes unit test failure so make it optional
try:
Expand Down Expand Up @@ -74,14 +73,20 @@ def base64url_decode(payload):
return urlsafe_b64decode(payload.encode("utf-8"))


def int_to_bytes(num: int) -> bytes:
return num.to_bytes((max(num.bit_length(), 1) + 7) // 8, "big")


def bytes_to_int(buf: bytes) -> int:
return int.from_bytes(buf, "big")


def base64_to_int(source):
return int(hexlify(base64url_decode(source)), 16)
return bytes_to_int(base64url_decode(source))


def int_to_base64(source):
assert source != 0
I = hex(source).rstrip("L").lstrip("0x")
return base64url_encode(unhexlify((len(I) % 2) * "0" + I))
return base64url_encode(int_to_bytes(source))


def combine(g, h):
Expand Down

0 comments on commit 63827a1

Please sign in to comment.