Skip to content

Commit

Permalink
fix binning when he kind is mock(#4660)
Browse files Browse the repository at this point in the history
Signed-off-by: Yu Wu <yolandawu131@gmail.com>
  • Loading branch information
nemirorox committed Sep 12, 2023
1 parent 0c2ab73 commit b7315c4
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 12 deletions.
10 changes: 5 additions & 5 deletions python/fate/arch/histogram/values/_encoded.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,16 +16,16 @@ def __init__(self, data, size: int, dtype: torch.dtype, stride: int):
self.stride = stride

def decode_f64(self, coder):
return HistogramPlainValues(coder.decode_f64_vec(self.data), self.size, self.stride)
return HistogramPlainValues(coder.decode_f64_vec(self.data), self.dtype, self.size, self.stride)

def decode_i64(self, coder):
return HistogramPlainValues(coder.decode_i64_vec(self.data), self.size, self.stride)
return HistogramPlainValues(coder.decode_i64_vec(self.data), self.dtype, self.size, self.stride)

def decode_f32(self, coder):
return HistogramPlainValues(coder.decode_f32_vec(self.data), self.size, self.stride)
return HistogramPlainValues(coder.decode_f32_vec(self.data), self.dtype, self.size, self.stride)

def decode_i32(self, coder):
return HistogramPlainValues(coder.decode_i32_vec(self.data), self.size, self.stride)
return HistogramPlainValues(coder.decode_i32_vec(self.data), self.dtype, self.size, self.stride)

def decode(self, coder, dtype):
if dtype is None:
Expand All @@ -43,7 +43,7 @@ def decode(self, coder, dtype):

def unpack(self, coder, pack_num, offset_bit, precision, total_num, stride):
data = coder.unpack_floats(self.data, offset_bit, pack_num, precision, total_num)
return HistogramPlainValues(data, self.size, stride)
return HistogramPlainValues(data, self.dtype, self.size, stride)

def slice(self, start, end):
if hasattr(self.data, "slice"):
Expand Down
16 changes: 9 additions & 7 deletions python/fate/ml/feature_binning/hetero_feature_binning.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@

import numpy as np
import pandas as pd
import torch

from fate.arch import Context
from fate.arch.histogram import HistogramBuilder
Expand Down Expand Up @@ -94,8 +93,8 @@ def compute_federated_metrics(self, ctx: Context, binned_data):
host_event_non_event_count)):
host_event_non_event_count_hist = en_host_count_res.decrypt({"event_count": sk,
"non_event_count": sk},
{"event_count": (coder, torch.int32),
"non_event_count": (coder, torch.int32)})
{"event_count": (coder, None),
"non_event_count": (coder, None)})
host_event_non_event_count_hist = host_event_non_event_count_hist.reshape(bin_sizes)
summary_metrics, _ = self._bin_obj.compute_all_col_metrics(host_event_non_event_count_hist,
col_bin_list)
Expand Down Expand Up @@ -196,16 +195,18 @@ def compute_federated_metrics(self, ctx: Context, binned_data):
hist_targets = binned_data.create_frame()
hist_targets["event_count"] = encrypt_y
hist_targets["non_event_count"] = 1
dtypes = hist_targets.dtypes

hist_schema = {"event_count": {"type": "ciphertext",
"stride": 1,
"pk": pk,
"evaluator": evaluator,
"coder": coder,
"dtype": torch.int32,
"dtype": dtypes["event_count"],
},
"non_event_count": {"type": "plaintext",
"stride": 1,
"dtype": torch.int32}
"dtype": dtypes["non_event_count"]}
}
hist = HistogramBuilder(num_node=1,
feature_bin_sizes=feature_bin_sizes,
Expand Down Expand Up @@ -372,12 +373,13 @@ def compute_metrics(self, binned_data):
hist_targets = binned_data.create_frame()
hist_targets["event_count"] = binned_data.label
hist_targets["non_event_count"] = 1
dtypes = hist_targets.dtypes
hist_schema = {"event_count": {"type": "plaintext",
"stride": 1,
"dtype": torch.int32},
"dtype": dtypes["event_count"]},
"non_event_count": {"type": "plaintext",
"stride": 1,
"dtype": torch.int32}
"dtype": dtypes["non_event_count"]}
}
hist = HistogramBuilder(num_node=1,
feature_bin_sizes=feature_bin_sizes,
Expand Down

0 comments on commit b7315c4

Please sign in to comment.