Skip to content

Commit

Permalink
fix histogram
Browse files Browse the repository at this point in the history
Signed-off-by: weiwee <wbwmat@gmail.com>
  • Loading branch information
sagewe committed Aug 9, 2023
1 parent 2474b84 commit 47a59c4
Show file tree
Hide file tree
Showing 4 changed files with 29 additions and 21 deletions.
14 changes: 10 additions & 4 deletions python/fate/arch/histogram/histogram.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,10 @@ def zeros(cls, pk, evaluator, size: int, stride: int = 1):
return cls(pk, evaluator, evaluator.zeros(size * stride), stride)

def iadd_slice(self, index, value):
from fate.arch.tensor.phe import PHETensor

if isinstance(value, PHETensor):
value = value.data
self.evaluator.i_add(self.pk, self.data, value, index * self.stride)
return self

Expand Down Expand Up @@ -311,12 +315,14 @@ def create(cls, node_size, feature_bin_sizes, values_schema: dict):
raise NotImplementedError
return cls(indexer, values_mapping)

def i_update(self, nids, fids, targets):
for nid, bins, target in zip(nids, fids, targets):
def i_update(self, fids, nids, targets):
for i in range(fids.shape[0]):
nid = nids[i][0]
bins = fids[i]
for fid, bid in enumerate(bins):
index = self._indexer.get_position(nid, fid, bid)
for name, value in target.items():
self._values_mapping[name].iadd_slice(index, value)
for name, value in targets.items():
self._values_mapping[name].iadd_slice(index, value[i])
return self

def iadd(self, hist: "Histogram"):
Expand Down
2 changes: 1 addition & 1 deletion python/fate/arch/tensor/phe/_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -236,7 +236,7 @@ def slice_f(input, item):
evaluator = input.evaluator
stride = input.shape[1]
start = stride * item
data = evaluator.slice(input, start, stride)
data = evaluator.slice(input._data, start, stride)
return PHETensor(input.pk, evaluator, input.coder, torch.Size([*input.shape[1:]]), data, input.dtype)


Expand Down
2 changes: 1 addition & 1 deletion python/fate/arch/tensor/phe/_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ def __init__(self, pk, evaluator, coder, shape: torch.Size, data, dtype) -> None
self._dtype = dtype

def __repr__(self) -> str:
return f"<PaillierTensor shape={self.shape}, dtype={self.dtype}>"
return f"<PHETensor shape={self.shape}, dtype={self.dtype}>"

def __str__(self) -> str:
return self.__repr__()
Expand Down
32 changes: 17 additions & 15 deletions python/fate/test/test_histogram.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,21 @@
import pickle
import random

import pandas as pd
import torch
from fate.arch import Context
from fate.arch.computing.standalone import CSession
from fate.arch.federation.standalone import StandaloneFederation
from fate.arch.histogram.histogram import DistributedHistogram, Histogram

ctx = Context()
computing = CSession()

arbiter = ("arbiter", "10000")
guest = ("guest", "10000")
host = ("host", "9999")
name = "fed"
ctx = Context(computing=computing, federation=StandaloneFederation(computing, name, guest, [guest, host, arbiter]))
kit = ctx.cipher.phe.setup(options={"kind": "paillier", "key_length": 1024})
sk, pk, coder, evaluator = kit.sk, kit.pk, kit.coder, kit.evaluator
sk, pk, coder, evaluator, encryptor = kit.sk, kit.pk, kit.coder, kit.evaluator, kit.get_tensor_encryptor()


def test_plain():
Expand Down Expand Up @@ -267,13 +273,13 @@ def test_distributed_hist():


def test_distributed_hist_calling_from_df():
import multiprocessing
import random

import pandas as pd
from fate.arch.dataframe import DataFrame, PandasReader

multiprocessing.set_start_method("fork")
# import multiprocessing
# multiprocessing.set_start_method("fork")

data_list = []
for i in range(100):
Expand All @@ -286,15 +292,6 @@ def test_distributed_hist_calling_from_df():
columns=["sample_id", "match_id", "node_id"],
)

computing = CSession()
from fate.arch.federation.standalone import StandaloneFederation

arbiter = ("arbiter", "10000")
guest = ("guest", "10000")
host = ("host", "9999")
name = "fed"
ctx = Context(computing=computing, federation=StandaloneFederation(computing, name, guest, [guest, host, arbiter]))

df_reader = PandasReader(
sample_id_name="sample_id",
match_id_name="match_id",
Expand All @@ -310,7 +307,6 @@ def test_distributed_hist_calling_from_df():
one_df = df.create_frame()
one_df["one"] = 1

encryptor = kit.get_tensor_encryptor()
# decryptor = kit.get_tensor_encryptor()

targets = dict(one=one_df["one"].as_tensor(), g=encryptor.encrypt_tensor(df.label.as_tensor()))
Expand All @@ -327,5 +323,11 @@ def test_distributed_hist_calling_from_df():

stat_obj = df.distributed_hist_stat(hist, pos_df, targets)

out = stat_obj.decrypt(sk_map={"g": sk, "h": sk}, coder_map={"g": (coder, torch.float32)})
print(out)
out = out.reshape([3, 2])
out.i_shuffle(seed=0, reverse=True)
print(out)


# test_distributed_hist_calling_from_df()

0 comments on commit 47a59c4

Please sign in to comment.