Skip to content

Commit

Permalink
add dtype to cipher value in histogram
Browse files Browse the repository at this point in the history
Signed-off-by: sagewe <wbwmat@gmail.com>
  • Loading branch information
sagewe committed Sep 12, 2023
1 parent 21df8fc commit 0933775
Show file tree
Hide file tree
Showing 5 changed files with 46 additions and 27 deletions.
24 changes: 15 additions & 9 deletions python/fate/arch/histogram/values/_cipher.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
import typing
from typing import List, Tuple

import torch

from ._encoded import HistogramEncodedValues
from ._value import HistogramValues
from ..indexer import Shuffler
Expand All @@ -10,16 +12,18 @@


class HistogramEncryptedValues(HistogramValues):
def __init__(self, pk, evaluator, data, coder, stride=1):
def __init__(self, pk, evaluator, data, coder, dtype: torch.dtype, size: int, stride: int):
self.stride = stride
self.data = data
self.pk = pk
self.coder = coder
self.dtype = dtype
self.size = size
self.evaluator = evaluator

@classmethod
def zeros(cls, pk, evaluator, size: int, coder, dtype, stride: int = 1):
return cls(pk, evaluator, evaluator.zeros(size * stride, dtype), coder, stride)
def zeros(cls, pk, evaluator, size: int, coder, dtype, stride: int):
return cls(pk, evaluator, evaluator.zeros(size * stride, dtype), coder, dtype, size, stride)

def i_update(self, value, positions):
from fate.arch.tensor.phe import PHETensor
Expand Down Expand Up @@ -47,13 +51,15 @@ def slice(self, start, end):
self.evaluator,
self.evaluator.slice(self.data, start * self.stride, (end - start) * self.stride),
self.coder,
self.dtype,
self.size,
self.stride,
)

def intervals_slice(self, intervals: typing.List[typing.Tuple[int, int]]) -> "HistogramEncryptedValues":
intervals = [(start * self.stride, end * self.stride) for start, end in intervals]
data = self.evaluator.intervals_slice(self.data, intervals)
return HistogramEncryptedValues(self.pk, self.evaluator, data, self.coder, self.stride)
return HistogramEncryptedValues(self.pk, self.evaluator, data, self.coder, self.dtype, self.size, self.stride)

def i_shuffle(self, shuffler: "Shuffler", reverse=False):
indices = shuffler.get_shuffle_index(step=self.stride, reverse=reverse)
Expand All @@ -63,15 +69,15 @@ def i_shuffle(self, shuffler: "Shuffler", reverse=False):
def shuffle(self, shuffler: "Shuffler", reverse=False):
indices = shuffler.get_shuffle_index(step=self.stride, reverse=reverse)
data = self.evaluator.shuffle(self.pk, self.data, indices)
return HistogramEncryptedValues(self.pk, self.evaluator, data, self.stride)
return HistogramEncryptedValues(self.pk, self.evaluator, data, self.coder, self.dtype, self.size, self.stride)

def chunking_sum(self, intervals: typing.List[typing.Tuple[int, int]]):
"""
sum bins in the given logical intervals
"""
intervals = [(start * self.stride, end * self.stride) for start, end in intervals]
data = self.evaluator.intervals_sum_with_step(self.pk, self.data, intervals, self.stride)
return HistogramEncryptedValues(self.pk, self.evaluator, data, self.coder, self.stride)
return HistogramEncryptedValues(self.pk, self.evaluator, data, self.coder, self.dtype, self.size, self.stride)

def compute_child(
self,
Expand Down Expand Up @@ -111,15 +117,15 @@ def compute_child(
s,
)

return HistogramEncryptedValues(self.pk, self.evaluator, data, self.coder, self.stride)
return HistogramEncryptedValues(self.pk, self.evaluator, data, self.coder, self.dtype, self.size, self.stride)

def decrypt(self, sk):
data = sk.decrypt_to_encoded(self.data)
return HistogramEncodedValues(data, self.stride)
return HistogramEncodedValues(data, self.size, self.dtype, self.stride)

def squeeze(self, pack_num, offset_bit):
data = self.evaluator.pack_squeeze(self.data, pack_num, offset_bit, self.pk)
return HistogramEncryptedValues(self.pk, self.evaluator, data, self.coder, self.stride)
return HistogramEncryptedValues(self.pk, self.evaluator, data, self.coder, self.dtype, self.size, self.stride)

def i_chunking_cumsum(self, chunk_sizes: typing.List[int]):
chunk_sizes = [num * self.stride for num in chunk_sizes]
Expand Down
5 changes: 4 additions & 1 deletion python/fate/arch/histogram/values/_encoded.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,10 @@


class HistogramEncodedValues(HistogramValues):
def __init__(self, data, size, stride=1):
def __init__(self, data, size: int, dtype: torch.dtype, stride: int):
self.data = data
self.size = size
self.dtype = dtype
self.stride = stride

def decode_f64(self, coder):
Expand All @@ -27,6 +28,8 @@ def decode_i32(self, coder):
return HistogramPlainValues(coder.decode_i32_vec(self.data), self.size, self.stride)

def decode(self, coder, dtype):
if dtype is None:
dtype = self.dtype
if dtype == torch.float64:
return self.decode_f64(coder)
elif dtype == torch.float32:
Expand Down
20 changes: 12 additions & 8 deletions python/fate/arch/histogram/values/_plain.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,9 @@


class HistogramPlainValues(HistogramValues):
def __init__(self, data, size, stride):
def __init__(self, data, dtype: torch.dtype, size: int, stride: int):
self.data = data
self.dtype = dtype
self.size = size
self.stride = stride

Expand All @@ -24,7 +25,7 @@ def __repr__(self):

@classmethod
def zeros(cls, size, stride, dtype=torch.float64):
return cls(torch.zeros(size * stride, dtype=dtype), size, stride)
return cls(torch.zeros(size * stride, dtype=dtype), dtype, size, stride)

def intervals_slice(self, intervals: typing.List[typing.Tuple[int, int]]):
size = sum(e - s for s, e in intervals)
Expand All @@ -34,15 +35,17 @@ def intervals_slice(self, intervals: typing.List[typing.Tuple[int, int]]):
end = start + (e - s) * self.stride
result[start:end] = self.data[s * self.stride : e * self.stride]
start = end
return HistogramPlainValues(result, size, self.stride)
return HistogramPlainValues(result, self.dtype, size, self.stride)

def iadd_slice(self, value, sa, sb, size):
size = size * self.stride
value = value.view(-1)
self.data[sa : sa + size] += value[sb : sb + size]

def slice(self, start, end):
return HistogramPlainValues(self.data[start * self.stride : end * self.stride], end - start, self.stride)
return HistogramPlainValues(
self.data[start * self.stride : end * self.stride], self.dtype, end - start, self.stride
)

def iadd(self, other):
self.data += other.data
Expand Down Expand Up @@ -93,7 +96,7 @@ def i_shuffle(self, shuffler: "Shuffler", reverse=False):
def shuffle(self, shuffler: "Shuffler", reverse=False):
indices = shuffler.get_shuffle_index(step=self.stride, reverse=reverse)
data = self.data[indices]
return HistogramPlainValues(data, self.size, self.stride)
return HistogramPlainValues(data, self.dtype, self.size, self.stride)

def i_chunking_cumsum(self, chunk_sizes: typing.List[int]):
data_view = self.data.view(-1, self.stride)
Expand All @@ -108,7 +111,7 @@ def chunking_sum(self, intervals: typing.List[typing.Tuple[int, int]]):
data_view = self.data.view(-1, self.stride)
for i, (start, end) in enumerate(intervals):
result[i * self.stride : (i + 1) * self.stride] = data_view[start:end, :].sum(dim=0)
return HistogramPlainValues(result, size, self.stride)
return HistogramPlainValues(result, self.dtype, size, self.stride)

def compute_child(
self, weak_child: "HistogramPlainValues", positions: List[Tuple[int, int, int, int, int, int, int, int]], size
Expand Down Expand Up @@ -138,7 +141,7 @@ def compute_child(
parent_data_view[parent_data_start:parent_data_end]
- weak_child_data_view[weak_child_data_start:weak_child_data_end]
)
return HistogramPlainValues(data, size, self.stride)
return HistogramPlainValues(data, self.dtype, size, self.stride)

@classmethod
def cat(cls, chunks_info: List[Tuple[int, int]], values: List["HistogramPlainValues"]):
Expand All @@ -147,8 +150,9 @@ def cat(cls, chunks_info: List[Tuple[int, int]], values: List["HistogramPlainVal
data.append(value.data.reshape(num_chunk, chunk_size, value.stride))
data = torch.cat(data, dim=1)
size = data.shape[0]
dtype = data.dtype
data = data.flatten()
return cls(data, size, values[0].stride)
return cls(data, dtype, size, values[0].stride)

def extract_node_data(self, node_data_size, node_size):
return list(self.data.reshape(node_size, node_data_size, self.stride))
8 changes: 7 additions & 1 deletion python/fate/arch/histogram/values/_values.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,13 @@ def i_sub_on_key(self, from_key, to_key):
right_value.data, left_value.data, right_value.pk, right_value.coder
)
self._data[from_key] = HistogramEncryptedValues(
right_value.pk, right_value.evaluator, data, right_value.coder, right_value.stride
right_value.pk,
right_value.evaluator,
data,
right_value.coder,
right_value.dtype,
right_value.size,
right_value.stride,
)
elif isinstance(right_value, HistogramPlainValues):
assert left_value.stride == right_value.stride
Expand Down
16 changes: 8 additions & 8 deletions python/fate/arch/protocol/phe/mock.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,49 +122,49 @@ def encode_f64(self, val: float) -> FV:
return torch.tensor(val, dtype=torch.float64)

def decode_f64(self, val):
return val.item()
return float(val.item())

def encode_i64(self, val: int):
return torch.tensor(val, dtype=torch.int64)

def decode_i64(self, val):
return val.item()
return int(val.item())

def encode_f32(self, val: float):
return torch.tensor(val, dtype=torch.float32)

def decode_f32(self, val):
return val.item()
return float(val.item())

def encode_i32(self, val: int):
return torch.tensor(val, dtype=torch.int32)

def decode_i32(self, val):
return val.item()
return int(val.item())

def encode_f64_vec(self, vec: torch.Tensor):
return FV(vec.detach().flatten())

def decode_f64_vec(self, vec):
return vec.data
return vec.data.type(torch.float64)

def encode_i64_vec(self, vec: torch.Tensor):
return FV(vec.detach().flatten())

def decode_i64_vec(self, vec):
return vec.data
return vec.data.type(torch.int64)

def encode_f32_vec(self, vec: torch.Tensor):
return FV(vec.detach().flatten())

def decode_f32_vec(self, vec):
return vec.data
return vec.data.type(torch.float32)

def encode_i32_vec(self, vec: torch.Tensor):
return FV(vec.detach().flatten())

def decode_i32_vec(self, vec):
return vec.data
return vec.data.type(torch.int32)


def keygen(key_size):
Expand Down

0 comments on commit 0933775

Please sign in to comment.