diff --git a/python/fate/arch/histogram/histogram.py b/python/fate/arch/histogram/histogram.py
index 897f53f7cb..74a705f96d 100644
--- a/python/fate/arch/histogram/histogram.py
+++ b/python/fate/arch/histogram/histogram.py
@@ -1,18 +1,98 @@
+import pprint
 import typing
+from typing import List, MutableMapping, Tuple
 
 import numpy as np
 import torch
 
 
+class Shuffler:
+    def __init__(self, num_node, node_size, seed):
+        self.num_node = num_node
+        self.node_size = node_size
+        self.perm_indexes = [
+            torch.randperm(node_size, generator=torch.Generator().manual_seed(seed)) for _ in range(num_node)
+        ]
+
+    def get_global_perm_index(self):
+        index = torch.hstack([index + (nid * self.node_size) for nid, index in enumerate(self.perm_indexes)])
+        return index
+
+    #
+    # def reverse_index(self, index):
+    #     return torch.argsort(self.perm_index)[index]
+
+    def get_shuffle_index(self, step, reverse=False):
+        """
+        get chunk shuffle index
+        """
+        stepped = torch.arange(0, self.num_node * self.node_size * step).reshape(self.num_node * self.node_size, step)
+        indexes = stepped[self.get_global_perm_index(), :].flatten()
+        if reverse:
+            indexes = torch.argsort(indexes)
+        return indexes
+
+
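`Shuffler` is the core of the new privacy mechanism: every value container is permuted slot-wise with seeded per-node permutations, so decrypted aggregates can be handed over without revealing which bin each entry came from, and the seed holder can undo the permutation afterwards. A quick round-trip sketch of `get_shuffle_index`, using a plain tensor in place of a ciphertext vector:

```python
import torch

s = Shuffler(num_node=2, node_size=3, seed=42)
x = torch.arange(12)  # 2 nodes x 3 slots, stride (step) 2
idx = s.get_shuffle_index(step=2)  # slot permutation expanded to strided positions
shuffled = x[idx]
restored = shuffled[s.get_shuffle_index(step=2, reverse=True)]  # argsort inverts it
assert torch.equal(restored, x)
```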
 class HistogramIndexer:
-    def __init__(self, feature_bin_sizes):
-        feature_size = len(feature_bin_sizes)
-        self.feature_axis_stride = np.cumsum([0] + [feature_bin_sizes[i] for i in range(feature_size)])
+    def __init__(self, node_size: int, feature_bin_sizes: List[int]):
+        self.node_size = node_size
+        self.feature_bin_size = feature_bin_sizes
+        self.feature_size = len(feature_bin_sizes)
+        self.feature_axis_stride = np.cumsum([0] + [feature_bin_sizes[i] for i in range(self.feature_size)])
         self.node_axis_stride = sum(feature_bin_sizes)
 
+        self._shuffler = None
+
     def get_position(self, nid, fid, bid):
         return nid * self.node_axis_stride + self.feature_axis_stride[fid] + bid
 
+    def get_bin_num(self, fid):
+        return self.feature_bin_size[fid]
+
+    def get_bin_interval(self, nid, fid):
+        node_stride = nid * self.node_axis_stride
+        return node_stride + self.feature_axis_stride[fid], node_stride + self.feature_axis_stride[fid + 1]
+
+    def get_global_feature_intervals(self):
+        intervals = []
+        for nid in range(self.node_size):
+            for fid in range(self.feature_size):
+                intervals.append(self.get_bin_interval(nid, fid))
+        return intervals
+
+    def splits_into_k(self, k):
+        n = self.node_axis_stride
+        split_sizes = [n // k + (1 if i < n % k else 0) for i in range(k)]
+        start = 0
+        for pid, size in enumerate(split_sizes):
+            end = start + size
+            shift = self.node_axis_stride
+            yield pid, (start, end), [(start + nid * shift, end + nid * shift) for nid in range(self.node_size)]
+            start += size
+
+    def total_data_size(self):
+        return self.node_size * self.node_axis_stride
+
+    def one_node_data_size(self):
+        return self.node_axis_stride
+
+    def global_flatten_bin_sizes(self):
+        return self.feature_bin_size * self.node_size
+
+    def flatten_in_node(self):
+        return HistogramIndexer(self.node_size, [self.one_node_data_size()])
+
+    def squeeze_bins(self):
+        return HistogramIndexer(self.node_size, [1] * self.feature_size)
+
+    def get_shuffler(self, seed):
+        if self._shuffler is None:
+            self._shuffler = Shuffler(self.node_size, self.one_node_data_size(), seed)
+        return self._shuffler
+
+    def reshape(self, feature_bin_sizes):
+        return HistogramIndexer(self.node_size, feature_bin_sizes)
+
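`HistogramIndexer` lays `(node, feature, bin)` coordinates out on one flat axis, node-major, then feature, then bin. A small worked example of the arithmetic, assuming the class above is in scope:

```python
indexer = HistogramIndexer(node_size=2, feature_bin_sizes=[3, 2])
# node stride is 3 + 2 = 5: positions 0..4 belong to node 0, positions 5..9 to node 1
assert indexer.get_position(0, 1, 1) == 4  # node 0, feature 1, bin 1
assert indexer.get_position(1, 0, 2) == 7  # node 1, feature 0, bin 2
assert indexer.get_bin_interval(1, 1) == (8, 10)

# splits_into_k(2) cuts each node's 5 positions into sizes [3, 2]; split pid=0
# carries offsets (0, 3) of every node, i.e. the flat ranges (0, 3) and (5, 8)
assert list(indexer.splits_into_k(2)) == [
    (0, (0, 3), [(0, 3), (5, 8)]),
    (1, (3, 5), [(3, 5), (8, 10)]),
]
```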
 
 class HistogramValues:
     def iadd_slice(self, index, value):
@@ -21,10 +101,16 @@ def iadd_slice(self, index, value):
     def iadd(self, other):
         raise NotImplementedError
 
-    def get_stride(self, index):
+    def chunking_sum(self, intervals: typing.List[typing.Tuple[int, int]]):
+        raise NotImplementedError
+
+    def intervals_slice(self, intervals: typing.List[typing.Tuple[int, int]]):
         raise NotImplementedError
 
-    def chunking_sum(self, intervals: typing.List[typing.Tuple[int, int]]):
+    def i_shuffle(self, shuffler: "Shuffler", reverse=False):
+        raise NotImplementedError
+
+    def slice(self, start, end):
         raise NotImplementedError
 
     def decrypt(self, sk):
@@ -38,7 +124,7 @@ def decode(self, coder, dtype):
 
 
 class HistogramEncryptedValues(HistogramValues):
-    def __init__(self, pk: "PK", evaluator, data, stride=1):
+    def __init__(self, pk, evaluator, data, stride=1):
         self.stride = stride
         self.data = data
         self.pk = pk
@@ -56,8 +142,23 @@ def iadd(self, other):
         self.evaluator.i_add(self.pk, self.data, other.data)
         return self
 
-    def get_stride(self, index):
-        return self.evaluator.slice(self.data, index * self.stride, self.stride)
+    def slice(self, start, end):
+        return HistogramEncryptedValues(
+            self.pk,
+            self.evaluator,
+            self.evaluator.slice(self.data, start * self.stride, end * self.stride),
+            self.stride,
+        )
+
+    def intervals_slice(self, intervals: typing.List[typing.Tuple[int, int]]) -> "HistogramEncryptedValues":
+        intervals = [(start * self.stride, end * self.stride) for start, end in intervals]
+        data = self.evaluator.intervals_slice(self.data, intervals)
+        return HistogramEncryptedValues(self.pk, self.evaluator, data, self.stride)
+
+    def i_shuffle(self, shuffler: "Shuffler", reverse=False):
+        indices = shuffler.get_shuffle_index(step=self.stride, reverse=reverse)
+        self.evaluator.i_shuffle(self.pk, self.data, indices)
+        return self
 
     def chunking_sum(self, intervals: typing.List[typing.Tuple[int, int]]):
         """
@@ -67,7 +168,7 @@ def chunking_sum(self, intervals: typing.List[typing.Tuple[int, int]]):
         data = self.evaluator.intervals_sum_with_step(self.pk, self.data, intervals, self.stride)
         return HistogramEncryptedValues(self.pk, self.evaluator, data, self.stride)
 
-    def decrypt(self, sk: "SK"):
+    def decrypt(self, sk):
         data = sk.decrypt_to_encoded(self.data)
         return HistogramEncodedValues(data, self.stride)
 
@@ -77,7 +178,7 @@ def i_chunking_cumsum(self, chunk_sizes: typing.List[int]):
         return self
 
     def __str__(self):
-        return str(self.data)
+        return f"<HistogramEncryptedValues stride={self.stride}, data={self.data}>"
 
 
 class HistogramEncodedValues(HistogramValues):
@@ -85,19 +186,19 @@ def __init__(self, data, stride=1):
         self.data = data
         self.stride = stride
 
-    def decode_f64(self, coder: "Coder"):
+    def decode_f64(self, coder):
         return HistogramPlainValues(coder.decode_f64_vec(self.data), self.stride)
 
-    def decode_i64(self, coder: "Coder"):
+    def decode_i64(self, coder):
         return HistogramPlainValues(coder.decode_i64_vec(self.data), self.stride)
 
-    def decode_f32(self, coder: "Coder"):
+    def decode_f32(self, coder):
         return HistogramPlainValues(coder.decode_f32_vec(self.data), self.stride)
 
-    def decode_i32(self, coder: "Coder"):
+    def decode_i32(self, coder):
         return HistogramPlainValues(coder.decode_i32_vec(self.data), self.stride)
 
-    def decode(self, coder: "Coder", dtype):
+    def decode(self, coder, dtype):
         if dtype == torch.float64:
             return self.decode_f64(coder)
         elif dtype == torch.float32:
@@ -109,8 +210,8 @@ def decode(self, coder, dtype):
         else:
             raise NotImplementedError
 
-    def get_stride(self, index):
-        return self.data.get_stride(index, self.stride)
+    def slice(self, start, end):
+        return self.data.slice(start * self.stride, end * self.stride)
 
 
 class HistogramPlainValues(HistogramValues):
@@ -118,21 +219,40 @@ def __init__(self, data, stride=1):
         self.data = data
         self.stride = stride
 
+    def __str__(self):
+        return f"<HistogramPlainValues stride={self.stride}, data={self.data}>"
+
+    def __repr__(self):
+        return str(self)
+
     @classmethod
     def zeros(cls, size, stride, dtype=torch.float64):
         return cls(torch.zeros(size * stride, dtype=dtype), stride)
 
-    def get_stride(self, index):
-        return self.data[index * self.stride : index * self.stride + self.stride]
+    def intervals_slice(self, intervals: typing.List[typing.Tuple[int, int]]):
+        result = torch.zeros(sum(e - s for s, e in intervals) * self.stride, dtype=self.data.dtype)
+        start = 0
+        for s, e in intervals:
+            end = start + (e - s) * self.stride
+            result[start:end] = self.data[s * self.stride : e * self.stride]
+            start = end
+        return HistogramPlainValues(result, self.stride)
 
     def iadd_slice(self, index, value):
         start = index * self.stride
         end = index * self.stride + len(value)
         self.data[start:end] += value
 
+    def slice(self, start, end):
+        return HistogramPlainValues(self.data[start * self.stride : end * self.stride], self.stride)
+
     def iadd(self, other):
         self.data += other.data
 
+    def i_shuffle(self, shuffler: "Shuffler", reverse=False):
+        indices = shuffler.get_shuffle_index(step=self.stride, reverse=reverse)
+        self.data = self.data[indices]
+
     def i_chunking_cumsum(self, chunk_sizes: typing.List[int]):
         data_view = self.data.view(-1, self.stride)
         start = 0
@@ -147,35 +267,40 @@ def chunking_sum(self, intervals: typing.List[typing.Tuple[int, int]]):
             result[i * self.stride : (i + 1) * self.stride] = data_view[start:end, :].sum(dim=0)
         return HistogramPlainValues(result, self.stride)
 
-    def __str__(self):
-        return str(self.data)
+    @classmethod
+    def cat(cls, chunks_info: List[Tuple[int, int]], values: List["HistogramPlainValues"]):
+        data = []
+        for (num_chunk, chunk_size), value in zip(chunks_info, values):
+            data.append(value.data.reshape(num_chunk, chunk_size, value.stride))
+        data = torch.cat(data, dim=1).flatten()
+        return cls(data, values[0].stride)
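The plain container makes the chunked operations easy to sanity-check. A stride-1 sketch of the intended semantics of `i_chunking_cumsum` (whose loop body is elided in this hunk; a per-chunk prefix sum is assumed, which is how `i_cumsum_bins` uses it) and `chunking_sum`:

```python
import torch

v = HistogramPlainValues(torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0]), stride=1)
v.i_chunking_cumsum([3, 2])  # prefix sums within chunks [0:3] and [3:5]
assert torch.equal(v.data, torch.tensor([1.0, 3.0, 6.0, 4.0, 9.0]))

w = HistogramPlainValues(torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0]), stride=1)
s = w.chunking_sum([(0, 3), (3, 5)])  # one total per interval
assert torch.equal(s.data, torch.tensor([6.0, 9.0]))
```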
 
 
 class Histogram:
-    def __init__(self, node_size, feature_bin_sizes):
-        self.node_size = node_size
-        self.feature_bin_sizes = feature_bin_sizes
-        self._indexer = HistogramIndexer(feature_bin_sizes)
-        self._num_data_unit = self.node_size * self._indexer.node_axis_stride
-
-        self._values_mapping: typing.MutableMapping[str, HistogramValues] = {}
+    def __init__(self, indexer: "HistogramIndexer", values: MutableMapping[str, "HistogramValues"]):
+        self._indexer = indexer
+        self._values_mapping = values
 
-    def set_value_schema(self, schema: dict):
-        for name, items in schema.items():
+    @classmethod
+    def create(cls, node_size, feature_bin_sizes, values_schema: dict):
+        indexer = HistogramIndexer(node_size, feature_bin_sizes)
+        values_mapping = {}
+        for name, items in values_schema.items():
             stride = items.get("stride", 1)
             if items["type"] == "paillier":
                 pk = items["pk"]
                 evaluator = items["evaluator"]
-                self._values_mapping[name] = HistogramEncryptedValues.zeros(pk, evaluator, self._num_data_unit, stride)
+                values_mapping[name] = HistogramEncryptedValues.zeros(pk, evaluator, indexer.total_data_size(), stride)
             elif items["type"] == "tensor":
                 dtype = items.get("dtype", torch.float64)
-                self._values_mapping[name] = HistogramPlainValues.zeros(
-                    self._num_data_unit, stride=stride, dtype=dtype
+                values_mapping[name] = HistogramPlainValues.zeros(
+                    indexer.total_data_size(), stride=stride, dtype=dtype
                 )
             else:
                 raise NotImplementedError
+        return cls(indexer, values_mapping)
 
-    def update(self, nids, fids, targets):
+    def i_update(self, nids, fids, targets):
         for nid, bins, target in zip(nids, fids, targets):
             for fid, bid in enumerate(bins):
                 index = self._indexer.get_position(nid, fid, bid)
@@ -183,7 +308,7 @@ def update(self, nids, fids, targets):
                     self._values_mapping[name].iadd_slice(index, value)
         return self
 
-    def merge(self, hist: "Histogram"):
+    def iadd(self, hist: "Histogram"):
         for name, value_container in hist._values_mapping.items():
             if name in self._values_mapping:
                 self._values_mapping[name].iadd(value_container)
@@ -192,34 +317,41 @@ def merge(self, hist: "Histogram"):
         return self
 
     def decrypt(self, sk_map: dict):
-        result = Histogram(self.node_size, self.feature_bin_sizes)
+        values_mapping = {}
         for name, value_container in self._values_mapping.items():
             if name in sk_map:
-                result._values_mapping[name] = value_container.decrypt(sk_map[name])
+                values_mapping[name] = value_container.decrypt(sk_map[name])
             else:
-                result._values_mapping[name] = value_container
-        return result
+                values_mapping[name] = value_container
+        return Histogram(self._indexer, values_mapping)
 
     def decode(self, coder_map: dict):
-        result = Histogram(self.node_size, self.feature_bin_sizes)
+        values_mapping = {}
         for name, value_container in self._values_mapping.items():
             if name in coder_map:
                 coder, dtype = coder_map[name]
-                result._values_mapping[name] = value_container.decode(coder, dtype)
+                values_mapping[name] = value_container.decode(coder, dtype)
             else:
-                result._values_mapping[name] = value_container
-        return result
+                values_mapping[name] = value_container
+        return Histogram(self._indexer, values_mapping)
+
+    def i_shuffle(self, seed, reverse=False):
+        shuffler = self._indexer.get_shuffler(seed)
+        for name, value_container in self._values_mapping.items():
+            value_container.i_shuffle(shuffler, reverse=reverse)
+        return self
 
     def __str__(self):
         result = ""
-        for nid in range(self.node_size):
+        for nid in range(self._indexer.node_size):
             result += f"node-{nid}:\n"
-            for fid in range(len(self.feature_bin_sizes)):
+            for fid in range(self._indexer.feature_size):
                 result += f"\tfeature-{fid}:\n"
-                for bid in range(self.feature_bin_sizes[fid]):
+                for bid in range(self._indexer.get_bin_num(fid)):
                     for name, value_container in self._values_mapping.items():
-                        values = value_container.get_stride(self._indexer.get_position(nid, fid, bid))
-                        result += f"\t\t{name}: {values}\t"
+                        start = self._indexer.get_position(nid, fid, bid)
+                        values = value_container.slice(start, start + 1)
+                        result += f"\t\t{name}: {values}"
                     result += "\n"
         return result
@@ -231,50 +362,152 @@ def flatten_all_feature_bins(self):
         affect the result of this method and vice versa.
         :return:
         """
-        result = Histogram(self.node_size, [sum(self.feature_bin_sizes)])
-        for name, value_container in self._values_mapping.items():
-            result._values_mapping[name] = value_container
-        return result
+        indexer = self._indexer.flatten_in_node()
+        values = {name: value_container for name, value_container in self._values_mapping.items()}
+        return Histogram(indexer, values)
 
-    def cumsum_bins(self):
-        feature_bin_nums = []
-        for nid in range(self.node_size):
-            feature_bin_nums.extend(self.feature_bin_sizes)
+    def i_cumsum_bins(self):
         for name, value_container in self._values_mapping.items():
-            value_container.i_chunking_cumsum(feature_bin_nums)
+            value_container.i_chunking_cumsum(self._indexer.global_flatten_bin_sizes())
 
     def sum_bins(self):
-        result = Histogram(self.node_size, [1] * len(self.feature_bin_sizes))
-        intervals = []
-        for nid in range(self.node_size):
-            for fid in range(len(self.feature_bin_sizes)):
-                intervals.append(
-                    (
-                        self._indexer.get_position(nid, fid, 0),
-                        self._indexer.get_position(nid, fid, self.feature_bin_sizes[fid]),
-                    )
-                )
+        indexer = self._indexer.squeeze_bins()
+        values_mapping = {}
+        intervals = self._indexer.get_global_feature_intervals()
         for name, value_container in self._values_mapping.items():
-            result._values_mapping[name] = value_container.chunking_sum(intervals)
+            values_mapping[name] = value_container.chunking_sum(intervals)
+        return Histogram(indexer, values_mapping)
+
+    def to_splits(self, k) -> typing.Iterator[typing.Tuple[(int, "HistogramSplits")]]:
+        for pid, (start, end), indexes in self._indexer.splits_into_k(k):
+            data = {}
+            for name, value_container in self._values_mapping.items():
+                data[name] = value_container.intervals_slice(indexes)
+            yield pid, HistogramSplits(self._indexer.node_size, start, end, data)
+
+    def reshape(self, feature_bin_sizes):
+        indexer = self._indexer.reshape(feature_bin_sizes)
+        return Histogram(indexer, self._values_mapping)
+
+
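Putting the pieces together, a `Histogram` accumulates per-record targets and then turns each feature's bins into prefix sums, which is the form split finding consumes. A minimal plain-tensor walk-through with hypothetical values (one node, features with 3 and 2 bins):

```python
import torch

hist = Histogram.create(1, [3, 2], {"c0": {"type": "tensor", "stride": 1}})
# one record falls into bin 1 of feature 0 and bin 0 of feature 1
hist.i_update([0], [[1, 0]], [{"c0": torch.tensor([1.0])}])
hist.i_cumsum_bins()  # bin b now aggregates every record with bin id <= b
print(hist)
```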
+class HistogramSplits:
+    def __init__(self, num_node, start, end, data):
+        self.num_node = num_node
+        self.start = start
+        self.end = end
+        self._data: typing.MutableMapping[str, HistogramValues] = data
+
+    def __str__(self):
+        result = f"HistogramSplits(start={self.start}, end={self.end}):\n"
+        for name, value in self._data.items():
+            result += f"{name}: {value}\n"
         return result
 
-    def to_splits(self, num_splits) -> typing.Iterator[typing.Tuple[(int, "Histogram")]]:
-        ...
+    def __repr__(self):
+        return self.__str__()
 
+    def iadd(self, other: "HistogramSplits"):
+        for name, value in other._data.items():
+            self._data[name].iadd(value)
+        return self
+
+    def i_decrypt(self, sk_map):
+        for name, value in self._data.items():
+            if name in sk_map:
+                self._data[name] = value.decrypt(sk_map[name])
+        return self
+
+    def i_decode(self, coder_map):
+        for name, value in self._data.items():
+            if name in coder_map:
+                coder, dtype = coder_map[name]
+                self._data[name] = value.decode(coder, dtype)
+        return self
+
+    def decrypt(
+        self,
+        sk_map: MutableMapping[str, typing.Any],
+        coder_map: MutableMapping[str, typing.Tuple[typing.Any, torch.dtype]],
+    ):
+        self.i_decrypt(sk_map)
+        self.i_decode(coder_map)
+        return self
 
-class DistributedHistogram:
     @classmethod
-    def create(cls, data):
-        data.mapPartions(DistributedHistogram._hist_build_partition_mapper).reduceByKey(lambda x, y: x.merge(y))
-
-    @staticmethod
-    def _repartion_by_nid_and_fid(part):
-        # -> (nid, fid), data
-        ...
-
-    @staticmethod
-    def _hist_build_partition_mapper(part):
-        hist = Histogram()
-        hist.update()
-        hist.cumsum_bins()
-        return hist.to_splits()
+    def cat(cls, splits: typing.List["HistogramSplits"]):
+        data = {}
+        chunks_info = []
+        for split in splits:
+            chunks_info.append((split.num_node, split.end - split.start))
+            for name, value in split._data.items():
+                if name not in data:
+                    data[name] = [value]
+                else:
+                    data[name].append(value)
+        for name, values in data.items():
+            data[name] = values[0].cat(chunks_info, values)
+        return data
+
+
+class DistributedHistogram:
+    def __init__(self, node_size, feature_bin_sizes, value_schemas, seed):
+        self._node_size = node_size
+        self._feature_bin_sizes = feature_bin_sizes
+        self._node_data_size = sum(feature_bin_sizes)
+        self._value_schemas = value_schemas
+        self._seed = seed
+
+    def i_update(self, data, k=None):
+        if k is None:
+            k = data.count()
+        mapper = get_partition_hist_build_mapper(
+            self._node_size, self._feature_bin_sizes, self._value_schemas, self._seed, k
+        )
+        table = data.mapReducePartitions(mapper, lambda x, y: x.iadd(y))
+        return ShuffledHistogram(table, self._node_size, self._node_data_size)
+
+
+class ShuffledHistogram:
+    def __init__(self, table, node_size, node_data_size):
+        self._table = table
+        self._node_size = node_size
+        self._node_data_size = node_data_size
+
+    def decrypt(
+        self,
+        sk_map: MutableMapping[str, typing.Any],
+        coder_map: MutableMapping[str, typing.Tuple[typing.Any, torch.dtype]],
+    ):
+        out = list(self._table.map(lambda pid, split: (pid, split.decrypt(sk_map, coder_map))).collect())
+        out.sort(key=lambda x: x[0])
+        return self.cat([split for _, split in out])
+
+    def cat(self, hists: typing.List["HistogramSplits"]) -> "Histogram":
+        data = HistogramSplits.cat(hists)
+        return Histogram(HistogramIndexer(self._node_size, [self._node_data_size]), data)
+
+
+def argmax_reducer(
+    max1: typing.Dict[int, typing.Tuple[int, int, float]], max2: typing.Dict[int, typing.Tuple[int, int, float]]
+):
+    for nid, (pid, index, gain) in max2.items():
+        if nid in max1:
+            if gain > max1[nid][2]:
+                max1[nid] = (pid, index, gain)
+        else:
+            max1[nid] = (pid, index, gain)
+    return max1
+
+
+def get_partition_hist_build_mapper(node_size, feature_bin_sizes, value_schemas, seed, k):
+    def _partition_hist_build_mapper(part):
+        hist = Histogram.create(node_size, feature_bin_sizes, value_schemas)
+        for _, raw in part:
+            nids, fids, targets = raw
+            hist.i_update(nids, fids, targets)
+        hist.i_cumsum_bins()
+        hist.i_shuffle(seed)
+        splits = hist.to_splits(k)
+        return list(splits)
+
+    return _partition_hist_build_mapper
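`DistributedHistogram.i_update` is where this meets the computing layer: every partition builds a full local histogram, cumsums, shuffles with the shared seed, and slices itself into `k` keyed `HistogramSplits`; `mapReducePartitions` then merges splits that share a `pid` via `iadd`. A toy model of that reduce-by-key dataflow (the `Split` class is only a stand-in for `HistogramSplits`):

```python
class Split:
    def __init__(self, v):
        self.v = v

    def iadd(self, other):
        self.v += other.v
        return self

# two partitions, each emitting (pid, split) pairs, as the mapper does
parts = [[(0, Split(1)), (1, Split(2))], [(0, Split(10)), (1, Split(20))]]
merged = {}
for pid, split in (kv for part in parts for kv in part):
    merged[pid] = split if pid not in merged else merged[pid].iadd(split)
assert [merged[p].v for p in sorted(merged)] == [11, 22]
```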
diff --git a/python/fate/arch/protocol/phe/paillier.py b/python/fate/arch/protocol/phe/paillier.py
index 7c0d2035fb..ed922881fe 100644
--- a/python/fate/arch/protocol/phe/paillier.py
+++ b/python/fate/arch/protocol/phe/paillier.py
@@ -259,6 +259,28 @@ def slice(a: EV, start: int, size: int) -> EV:
         """
         return a.slice(start, size)
 
+    @staticmethod
+    def i_shuffle(pk: PK, a: EV, indices: torch.LongTensor) -> None:
+        """
+        inplace shuffle, a = a[indices]
+        Args:
+            pk: public key, not used
+            a: the vector to shuffle
+            indices: the indices to shuffle
+        """
+        a.i_shuffle(indices.detach().tolist())
+
+    @staticmethod
+    def intervals_slice(a: EV, intervals: List[Tuple[int, int]]) -> EV:
+        """
+        slice in the given intervals
+
+        for example:
+            intervals=[(0, 4), (6, 12)], a = [a0, a1, a2, a3, a4, a5, a6, a7, ...]
+            then the result is [a0, a1, a2, a3, a6, a7, a8, a9, a10, a11]
+        """
+        return a.intervals_slice(intervals)
+
     @staticmethod
     def intervals_sum_with_step(pk: PK, a: EV, intervals: List[Tuple[int, int]], step: int):
         """
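For reference, the two new evaluator primitives have these plain-tensor semantics (a sketch; the real `i_shuffle` mutates the ciphertext vector in place, shown here functionally):

```python
import torch

a = torch.arange(12)
indices = torch.randperm(12, generator=torch.Generator().manual_seed(0))
shuffled = a[indices]  # i_shuffle: a = a[indices]
assert torch.equal(shuffled[torch.argsort(indices)], a)  # argsort(indices) inverts it

sliced = torch.cat([a[s:e] for s, e in [(0, 4), (6, 12)]])  # intervals_slice
assert torch.equal(sliced, torch.tensor([0, 1, 2, 3, 6, 7, 8, 9, 10, 11]))
```

diff --git a/python/fate/test/test_histogram.py b/python/fate/test/test_histogram.py
index 1b57cf19fd..1728e5f087 100644
--- a/python/fate/test/test_histogram.py
+++ b/python/fate/test/test_histogram.py
@@ -1,8 +1,10 @@
 import pickle
+import random
 
 import torch
 from fate.arch import Context
-from fate.arch.histogram.histogram import Histogram
+from fate.arch.computing.standalone import CSession
+from fate.arch.histogram.histogram import DistributedHistogram, Histogram
 
 ctx = Context()
 kit = ctx.cipher.phe.setup(options={"kind": "paillier", "key_length": 1024})
@@ -11,10 +13,9 @@
 
 def test_plain():
     # plain
-    hist = Histogram(1, [3, 2])
-    hist.set_value_schema({"c0": {"type": "tensor", "stride": 2}})
+    hist = Histogram.create(1, [3, 2], {"c0": {"type": "tensor", "stride": 2}})
     print(f"created:\n {hist}")
-    hist.update(
+    hist.i_update(
         [0, 0, 0, 0],
         [[1, 0], [0, 1], [2, 1], [2, 0]],
         [
@@ -31,10 +32,9 @@ def test_plain():
 
 def test_tensor():
     # paillier
-    hist = Histogram(1, [3, 2])
-    hist.set_value_schema({"c0": {"type": "paillier", "stride": 2, "pk": pk, "evaluator": evaluator}})
+    hist = Histogram.create(1, [3, 2], {"c0": {"type": "paillier", "stride": 2, "pk": pk, "evaluator": evaluator}})
     print(f"created:\n {hist}")
-    hist.update(
+    hist.i_update(
         [0, 0, 0, 0],
         [[1, 0], [0, 1], [2, 1], [2, 0]],
         [
@@ -54,16 +54,17 @@ def test_tensor():
 
 
 def create_mixed_hist():
-    hist = Histogram(1, [3, 2])
-    hist.set_value_schema(
+    hist = Histogram.create(
+        1,
+        [3, 2],
         {
             "g": {"type": "paillier", "stride": 1, "pk": pk, "evaluator": evaluator},
             "h": {"type": "paillier", "stride": 2, "pk": pk, "evaluator": evaluator},
             "1": {"type": "tensor", "stride": 2, "dtype": torch.int64},
-        }
+        },
     )
     print(f"created:\n {hist}")
-    hist.update(
+    hist.i_update(
         [0, 0, 0, 0],
         [[1, 0], [0, 1], [2, 1], [2, 0]],
         [
@@ -117,7 +118,7 @@ def test_flatten():
 
 def test_cumsum():
     hist = create_mixed_hist()
-    hist.cumsum_bins()
+    hist.i_cumsum_bins()
     print(f"cumsum: \n: {hist}")
     hist = hist.decrypt({"g": sk, "h": sk})
     print(f"decrypt: \n: {hist}")
@@ -143,3 +144,122 @@ def test_serde():
     hist2 = hist2.decrypt({"g": sk, "h": sk})
     hist2 = hist2.decode({"g": (coder, torch.float64), "h": (coder, torch.float64)})
     print(f"hist2: \n: {hist2}")
+
+
+def create_complex_hist(num_nodes, feature_bins, count):
+    hist = Histogram.create(
+        num_nodes,
+        feature_bins,
+        {
+            "g": {"type": "paillier", "stride": 1, "pk": pk, "evaluator": evaluator},
+            "h": {"type": "paillier", "stride": 2, "pk": pk, "evaluator": evaluator},
+            "1": {"type": "tensor", "stride": 2, "dtype": torch.int64},
+        },
+    )
+    print(f"created:\n {hist}")
+
+    for i in range(count):
+        hist.i_update(
+            [random.randint(0, num_nodes - 1)],
+            [[random.randint(0, bins - 1) for bins in feature_bins]],
+            [
+                {
+                    "g": pk.encrypt_encoded(coder.encode_f32_vec(torch.rand(1)), False),
+                    "h": pk.encrypt_encoded(coder.encode_f32_vec(torch.rand(2)), False),
+                    "1": torch.tensor([1, -1]),
+                },
+            ],
+        )
+    print(f"update: \n: {hist}")
+    return hist
+
+
+def test_split():
+    hist = create_complex_hist(2, [3, 2, 4, 5], 100)
+    for pid, split in hist.to_splits(3):
+        print(f"split: \n: {split._data}")
+
+
+def test_i_shuffle():
+    hist = create_complex_hist(2, [3, 2, 4, 5], 100)
+    hist = hist.i_shuffle(0)
+    print(f"i_shuffle: \n: {hist}")
+
+
+def test_distributed_hist():
+    print()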
+    computing = CSession()
+    ctx = Context(computing=computing)
+    fake_data = [
+        (
+            [0, 0, 0, 0],
+            [[1, 0], [0, 1], [2, 1], [2, 0]],
+            [
+                {
+                    "g": pk.encrypt_encoded(coder.encode_f32_vec(torch.tensor([0.0])), False),
+                    "h": pk.encrypt_encoded(coder.encode_f32_vec(torch.tensor([1.0, -1.0])), False),
+                    "1": torch.tensor([0.0, 1.0, -1.0]),
+                },
+                {
+                    "g": pk.encrypt_encoded(coder.encode_f32_vec(torch.tensor([1.0])), False),
+                    "h": pk.encrypt_encoded(coder.encode_f32_vec(torch.tensor([0.0, -0.0])), False),
+                    "1": torch.tensor([1.0, 0.0, -0.0]),
+                },
+                {
+                    "g": pk.encrypt_encoded(coder.encode_f32_vec(torch.tensor([0.0])), False),
+                    "h": pk.encrypt_encoded(coder.encode_f32_vec(torch.tensor([1.0, -1.0])), False),
+                    "1": torch.tensor([0.0, 1.0, -1.0]),
+                },
+                {
+                    "g": pk.encrypt_encoded(coder.encode_f32_vec(torch.tensor([0.0])), False),
+                    "h": pk.encrypt_encoded(coder.encode_f32_vec(torch.tensor([1.0, -1.0])), False),
+                    "1": torch.tensor([0.0, 1.0, -1.0]),
+                },
+            ],
+        ),
+        (
+            [0, 0, 0, 0],
+            [[1, 0], [0, 1], [2, 1], [2, 0]],
+            [
+                {
+                    "g": pk.encrypt_encoded(coder.encode_f32_vec(torch.tensor([0.0])), False),
+                    "h": pk.encrypt_encoded(coder.encode_f32_vec(torch.tensor([1.0, -1.0])), False),
+                    "1": torch.tensor([0.0, 1.0, -1.0]),
+                },
+                {
+                    "g": pk.encrypt_encoded(coder.encode_f32_vec(torch.tensor([1.0])), False),
+                    "h": pk.encrypt_encoded(coder.encode_f32_vec(torch.tensor([0.0, -0.0])), False),
+                    "1": torch.tensor([1.0, 0.0, -0.0]),
+                },
+                {
+                    "g": pk.encrypt_encoded(coder.encode_f32_vec(torch.tensor([0.0])), False),
+                    "h": pk.encrypt_encoded(coder.encode_f32_vec(torch.tensor([1.0, -1.0])), False),
+                    "1": torch.tensor([0.0, 1.0, -1.0]),
+                },
+                {
+                    "g": pk.encrypt_encoded(coder.encode_f32_vec(torch.tensor([0.0])), False),
+                    "h": pk.encrypt_encoded(coder.encode_f32_vec(torch.tensor([1.0, -1.0])), False),
+                    "1": torch.tensor([0.0, 1.0, -1.0]),
+                },
+            ],
+        ),
+    ]
+    table = ctx.computing.parallelize(fake_data, 2, include_key=False)
+    hist = DistributedHistogram(
+        node_size=2,
+        feature_bin_sizes=[3, 2],
+        value_schemas={
+            "g": {"type": "paillier", "stride": 1, "pk": pk, "evaluator": evaluator},
+            "h": {"type": "paillier", "stride": 2, "pk": pk, "evaluator": evaluator},
+            "1": {"type": "tensor", "stride": 3, "dtype": torch.float32},
+        },
+        seed=0,
+    )
+    shuffled = hist.i_update(table)
+    out = shuffled.decrypt(
+        sk_map={"g": sk, "h": sk}, coder_map={"g": (coder, torch.float32), "h": (coder, torch.float32)}
+    )
+    print(out)
+    out = out.reshape([3, 2])
+    out.i_shuffle(seed=0, reverse=True)
+    print(out)
diff --git a/rust/fate_utils/crates/fate_utils/src/histogram.rs b/rust/fate_utils/crates/fate_utils/src/histogram.rs
index 27987b16bc..90d44181b0 100644
--- a/rust/fate_utils/crates/fate_utils/src/histogram.rs
+++ b/rust/fate_utils/crates/fate_utils/src/histogram.rs
@@ -268,6 +268,39 @@ impl FixedpointPaillierVector {
         let data = self.data[start..start + size].to_vec();
         FixedpointPaillierVector { data }
     }
+    fn i_shuffle(&mut self, indexes: Vec<usize>) {
+        let mut visited = vec![false; self.data.len()];
+        for i in 0..self.data.len() {
+            if visited[i] || indexes[i] == i {
+                continue;
+            }
+
+            let mut current = i;
+            let mut next = indexes[current];
+            while !visited[next] && next != i {
+                self.data.swap(current, next);
+                visited[current] = true;
+                current = next;
+                next = indexes[current];
+            }
+            visited[current] = true;
+        }
+    }
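The Rust `i_shuffle` deliberately avoids allocating a second vector of ciphertexts: it walks each cycle of the permutation once, swapping elements into place, O(n) time with only a `Vec<bool>` of scratch. The same algorithm in Python for reference; it is equivalent to `data = [data[i] for i in indexes]`:

```python
def i_shuffle(data, indexes):
    visited = [False] * len(data)
    for i in range(len(data)):
        if visited[i] or indexes[i] == i:
            continue
        cur, nxt = i, indexes[i]
        while not visited[nxt] and nxt != i:
            data[cur], data[nxt] = data[nxt], data[cur]  # drag the cycle one step
            visited[cur] = True
            cur, nxt = nxt, indexes[nxt]
        visited[cur] = True


xs = list("abcdef")
i_shuffle(xs, [2, 0, 1, 4, 5, 3])
assert xs == ["c", "a", "b", "e", "f", "d"]
```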
+    fn intervals_slice(&mut self, intervals: Vec<(usize, usize)>) -> PyResult<FixedpointPaillierVector> {
+        let mut data = vec![];
+        for (start, end) in intervals {
+            if end > self.data.len() {
+                return Err(PyRuntimeError::new_err(format!(
+                    "end index out of range: start={}, end={}, data_size={}",
+                    start,
+                    end,
+                    self.data.len()
+                )));
+            }
+            data.extend_from_slice(&self.data[start..end]);
+        }
+        Ok(FixedpointPaillierVector { data })
+    }
     fn iadd_slice(&mut self, pk: &PK, position: usize, other: Vec<PyRef<CT>>) {
         for (i, x) in other.iter().enumerate() {
             self.data[position + i] = self.data[position + i].add(&x.ct, &pk.pk);