Merge pull request #5060 from FederatedAI/dev-2.0.0-beta-improve-iupdate
dev 2.0.0 beta improve iupdate
mgqa34 authored Aug 21, 2023
2 parents 6d28612 + 42d40af commit 09e25e5
Showing 4 changed files with 43 additions and 24 deletions.
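At a glance: the per-sample update loop moves out of Histogram.i_update and into the HistogramValues implementations, letting the plain-tensor path replace the nested Python loop with a single torch scatter_add_ call; the component CLI gains a --debug flag that is threaded through to LoggerConfig.install; and the "cnt" histogram schema switches from float32 to int32.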
54 changes: 36 additions & 18 deletions python/fate/arch/histogram/histogram.py
@@ -1,16 +1,22 @@
 import typing
+import logging
 from typing import List, MutableMapping, Tuple

 import torch

 # from fate_utils.histogram import HistogramIndexer, Shuffler
 from .indexer import HistogramIndexer, Shuffler

+logger = logging.getLogger(__name__)
+

 class HistogramValues:
     def iadd_slice(self, value, sa, sb, size):
         raise NotImplementedError

+    def i_update(self, value, positions):
+        raise NotImplementedError
+
     def iadd(self, other):
         raise NotImplementedError

@@ -61,6 +67,11 @@ def iadd_slice(self, value, sa, sb, size):
         self.evaluator.i_add(self.pk, self.data, value, sa * self.stride, sb, size * self.stride)
         return self

+    def i_update(self, value, positions):
+        for i, feature_positions in enumerate(positions):
+            for pos in feature_positions:
+                self.iadd_slice(value, pos, i * self.stride, self.stride)
+
     def iadd(self, other):
         self.evaluator.i_add(self.pk, self.data, other.data)
         return self
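For orientation: `positions` carries one row per sample, listing the global bin offset that sample hits in each feature, and i_update adds the sample's value into every bin it hits. The encrypted path keeps the nested loop because Paillier ciphertexts only support slice-wise accumulation through the evaluator; the plain path below vectorizes it. A minimal sketch of the contract with made-up numbers (stride 1, plain floats standing in for ciphertexts):

import torch

# Hypothetical toy layout: feature 0 owns global bins 0..2, feature 1 owns bins 3..4.
data = torch.zeros(5)                # one slot per (feature, bin), stride = 1
values = torch.tensor([0.5, 2.0])    # one value (e.g. a gradient) per sample
positions = [[1, 3],                 # sample 0: bin 1 of feature 0, bin 0 of feature 1
             [2, 4]]                 # sample 1: bin 2 of feature 0, bin 1 of feature 1

for i, feature_positions in enumerate(positions):
    for pos in feature_positions:
        data[pos] += values[i]       # what iadd_slice amounts to at stride 1

print(data)  # tensor([0.0000, 0.5000, 2.0000, 0.5000, 2.0000])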
@@ -163,21 +174,31 @@ def intervals_slice(self, intervals: typing.List[typing.Tuple[int, int]]):
         start = 0
         for s, e in intervals:
             end = start + (e - s) * self.stride
-            result[start:end] = self.data[s * self.stride : e * self.stride]
+            result[start:end] = self.data[s * self.stride: e * self.stride]
             start = end
         return HistogramPlainValues(result, self.stride)

     def iadd_slice(self, value, sa, sb, size):
         size = size * self.stride
         value = value.view(-1)
-        self.data[sa : sa + size] += value[sb : sb + size]
+        self.data[sa: sa + size] += value[sb: sb + size]

     def slice(self, start, end):
-        return HistogramPlainValues(self.data[start * self.stride : end * self.stride], self.stride)
+        return HistogramPlainValues(self.data[start * self.stride: end * self.stride], self.stride)

     def iadd(self, other):
         self.data += other.data

+    def i_update(self, value, positions):
+        index = torch.LongTensor(positions)
+        value = value.reshape(-1, self.stride).expand(-1, index.shape[1]).flatten()
+        index = index.flatten()
+
+        if self.data.dtype != value.dtype:
+            logger.warning(f"update value dtype {value.dtype} is not equal to data dtype {self.data.dtype}")
+            value = value.to(self.data.dtype)
+        self.data.scatter_add_(0, index, value)
+
     def i_shuffle(self, shuffler: "Shuffler", reverse=False):
         indices = shuffler.get_shuffle_index(step=self.stride, reverse=reverse)
         self.data = self.data[indices]
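The vectorized i_update above hinges on torch.Tensor.scatter_add_: flatten the sample-by-feature position matrix into one long index vector, repeat each sample's value once per feature hit, and let scatter_add_ accumulate duplicate indices. A minimal sketch reusing the toy numbers from the earlier example (stride 1, so the reshape/expand just broadcasts each sample's value across its features):

import torch

data = torch.zeros(5)
values = torch.tensor([0.5, 2.0])
index = torch.LongTensor([[1, 3], [2, 4]])  # same toy positions as above

# (n_samples, 1) -> repeated across the feature axis -> flattened, mirroring i_update
src = values.reshape(-1, 1).expand(-1, index.shape[1]).flatten()
data.scatter_add_(0, index.flatten(), src)

print(data)  # tensor([0.0000, 0.5000, 2.0000, 0.5000, 2.0000])

scatter_add_ requires index to be int64 and src to match data's dtype, which is what the new warning-and-cast guard is for.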
@@ -186,14 +207,14 @@ def i_chunking_cumsum(self, chunk_sizes: typing.List[int]):
         data_view = self.data.view(-1, self.stride)
         start = 0
         for num in chunk_sizes:
-            data_view[start : start + num, :] = data_view[start : start + num, :].cumsum(dim=0)
+            data_view[start: start + num, :] = data_view[start: start + num, :].cumsum(dim=0)
             start += num

     def chunking_sum(self, intervals: typing.List[typing.Tuple[int, int]]):
         result = torch.zeros(len(intervals) * self.stride, dtype=self.data.dtype)
         data_view = self.data.view(-1, self.stride)
         for i, (start, end) in enumerate(intervals):
-            result[i * self.stride : (i + 1) * self.stride] = data_view[start:end, :].sum(dim=0)
+            result[i * self.stride: (i + 1) * self.stride] = data_view[start:end, :].sum(dim=0)
         return HistogramPlainValues(result, self.stride)

     @classmethod
@@ -242,10 +263,7 @@ def i_update(self, fids, nids, targets):
             nids.flatten().detach().numpy().tolist(), fids.detach().numpy().tolist()
         )
         for name, value in targets.items():
-            shape = value.shape
-            for i, feature_positions in enumerate(positions):
-                for pos in feature_positions:
-                    self._values_mapping[name].iadd_slice(value, pos, i * shape[1], shape[1])
+            self._values_mapping[name].i_update(value, positions)
         return self

     def iadd(self, hist: "Histogram"):
@@ -375,9 +393,9 @@ def i_decode(self, coder_map):
         return self

     def decrypt(
-        self,
-        sk_map: MutableMapping[str, typing.Any],
-        coder_map: MutableMapping[str, typing.Tuple[typing.Any, torch.dtype]],
+            self,
+            sk_map: MutableMapping[str, typing.Any],
+            coder_map: MutableMapping[str, typing.Tuple[typing.Any, torch.dtype]],
     ):
         self.i_decrypt(sk_map)
         self.i_decode(coder_map)
@@ -418,15 +436,15 @@ def i_update(self, data, k=None):
             ShuffledHistogram, the shuffled(if seed is not None) histogram
         """
         if k is None:
-            k = data.count()
+            k = data.partitions
         mapper = get_partition_hist_build_mapper(
             self._node_size, self._feature_bin_sizes, self._value_schemas, self._seed, k
         )
         table = data.mapReducePartitions(mapper, lambda x, y: x.iadd(y))
         return ShuffledHistogram(table, self._node_size, self._node_data_size)

     def recover_feature_bins(
-        self, seed, split_points: typing.Dict[int, int]
+            self, seed, split_points: typing.Dict[int, int]
     ) -> typing.Dict[int, typing.Tuple[int, int]]:
         """
         Recover the feature bins from the split points.
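Worth noting: the default for k is now data.partitions (the number of partitions feeding mapReducePartitions) rather than data.count() (the number of rows), which reads as a fix to the default output-chunk count; callers that pass k explicitly are unaffected.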
@@ -456,9 +474,9 @@ def __init__(self, table, node_size, node_data_size):
         self._node_data_size = node_data_size

     def decrypt(
-        self,
-        sk_map: MutableMapping[str, typing.Any],
-        coder_map: MutableMapping[str, typing.Tuple[typing.Any, torch.dtype]],
+            self,
+            sk_map: MutableMapping[str, typing.Any],
+            coder_map: MutableMapping[str, typing.Tuple[typing.Any, torch.dtype]],
     ):
         out = list(self._table.map(lambda pid, split: (pid, split.decrypt(sk_map, coder_map))).collect())
         out.sort(key=lambda x: x[0])
@@ -470,7 +488,7 @@ def cat(self, hists: typing.List["HistogramSplits"]) -> "Histogram":


 def argmax_reducer(
-    max1: typing.Dict[int, typing.Tuple[int, int, float]], max2: typing.Dict[int, typing.Tuple[int, int, float]]
+        max1: typing.Dict[int, typing.Tuple[int, int, float]], max2: typing.Dict[int, typing.Tuple[int, int, float]]
 ):
     for nid, (pid, index, gain) in max2.items():
         if nid in max1:
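The reducer body is truncated above, but its shape is visible: per node id, keep whichever candidate split has the larger gain. A hedged reconstruction of the remainder (an assumption; the page cuts the function off):

def argmax_reducer(max1, max2):
    # Assumed completion of the truncated body: per node id, keep the
    # (pid, index, gain) candidate with the larger gain.
    for nid, (pid, index, gain) in max2.items():
        if nid in max1:
            if gain > max1[nid][2]:
                max1[nid] = (pid, index, gain)
        else:
            max1[nid] = (pid, index, gain)
    return max1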
4 changes: 2 additions & 2 deletions python/fate/components/core/spec/logger.py
@@ -23,8 +23,8 @@
 class LoggerConfig(pydantic.BaseModel):
     config: Optional[dict] = None

-    def install(self):
-        if self.config is None:
+    def install(self, debug=False):
+        if debug or self.config is None:
             handler_name = "rich_handler"
             self.config = dict(
                 version=1,
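The generated config is truncated above; for readers who have not seen one, a logging dictConfig wired to rich looks roughly like the sketch below (an illustration of the shape, not FATE's exact config; assumes the rich package is installed):

import logging
import logging.config

logging.config.dictConfig(
    {
        "version": 1,
        "handlers": {
            "rich_handler": {
                "class": "rich.logging.RichHandler",
                "level": "DEBUG",
            }
        },
        "root": {"handlers": ["rich_handler"], "level": "DEBUG"},
    }
)
logging.getLogger(__name__).debug("now rendered by RichHandler")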
@@ -35,7 +35,8 @@
     show_default=True,
     help="path for execution meta generated by component when execution finished",
 )
-def execute(process_tag, config, config_entrypoint, properties, env_prefix, env_name, execution_final_meta_path):
+@click.option("--debug", is_flag=True, help="enable debug mode")
+def execute(process_tag, config, config_entrypoint, properties, env_prefix, env_name, execution_final_meta_path, debug):
     """
     execute component
     """
@@ -70,7 +71,7 @@ def execute(process_tag, config, config_entrypoint, properties, env_prefix, env_
     task_config = TaskConfigSpec.parse_obj(configs)

     # install logger
-    task_config.conf.logger.install()
+    task_config.conf.logger.install(debug=debug)
     logger = logging.getLogger(__name__)
     logger.debug("logger installed")
     logger.debug(f"task config: {task_config}")
@@ -113,14 +113,14 @@ def _get_plain_text_schema(self):
         return {
             "g": {"type": "tensor", "stride": 1, "dtype": torch.float32},
             "h": {"type": "tensor", "stride": 1, "dtype": torch.float32},
-            "cnt": {"type": "tensor", "stride": 1, "dtype": torch.float32},
+            "cnt": {"type": "tensor", "stride": 1, "dtype": torch.int32},
         }

     def _get_enc_hist_schema(self, pk, evaluator):
         return {
             "g":{"type": "paillier", "stride": 1, "pk": pk, "evaluator": evaluator},
             "h":{"type": "paillier", "stride": 1, "pk": pk, "evaluator": evaluator},
-            "cnt": {"type": "tensor", "stride": 1, "dtype": torch.float32},
+            "cnt": {"type": "tensor", "stride": 1, "dtype": torch.int32},
         }

     def compute_hist(self, ctx: Context, nodes: List[Node], bin_train_data: DataFrame, gh: DataFrame, sample_pos: DataFrame = None, node_map={}, pk=None, evaluator=None):
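The cnt dtype change interacts with the dtype guard added to HistogramPlainValues.i_update: scatter_add_ insists that src and self share a dtype, so updating an int32 count histogram with default float32 ones would raise without the cast. A quick toy illustration:

import torch

cnt = torch.zeros(3, dtype=torch.int32)
index = torch.LongTensor([0, 2, 2])
ones = torch.ones(3)  # float32 by default

try:
    cnt.scatter_add_(0, index, ones)  # dtype mismatch: int32 histogram, float32 src
except RuntimeError as err:
    print("mismatch:", err)

cnt.scatter_add_(0, index, ones.to(cnt.dtype))  # the cast the new guard performs
print(cnt)  # tensor([1, 0, 2], dtype=torch.int32)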
