Skip to content

Commit

Permalink
Added logging and made dtype adjustments in histogram modules
Browse files Browse the repository at this point in the history
- Introduced `logging` in `fate/arch/histogram/histogram.py` to provide better traceability.
- Added a warning log when the update value data type does not match the histogram data type.
- Adjusted the data type of "cnt" from `torch.float32` to `torch.int32` in `fate/ml/ensemble/learner/decision_tree/tree_core/hist.py` for better type consistency.

Signed-off-by: sagewe <wbwmat@gmail.com>
  • Loading branch information
sagewe committed Aug 21, 2023
1 parent 807629c commit 8a61bcf
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 3 deletions.
5 changes: 4 additions & 1 deletion python/fate/arch/histogram/histogram.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,14 @@
import typing
import logging
from typing import List, MutableMapping, Tuple

import torch

# from fate_utils.histogram import HistogramIndexer, Shuffler
from .indexer import HistogramIndexer, Shuffler

loggger = logging.getLogger(__name__)


class HistogramValues:
def iadd_slice(self, value, sa, sb, size):
Expand Down Expand Up @@ -192,10 +195,10 @@ def i_update(self, value, positions):
index = index.flatten()

if self.data.dtype != value.dtype:
loggger.warning(f"update value dtype {value.dtype} is not equal to data dtype {self.data.dtype}")
value = value.to(self.data.dtype)
self.data.scatter_add_(0, index, value)


def i_shuffle(self, shuffler: "Shuffler", reverse=False):
indices = shuffler.get_shuffle_index(step=self.stride, reverse=reverse)
self.data = self.data[indices]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -113,14 +113,14 @@ def _get_plain_text_schema(self):
return {
"g": {"type": "tensor", "stride": 1, "dtype": torch.float32},
"h": {"type": "tensor", "stride": 1, "dtype": torch.float32},
"cnt": {"type": "tensor", "stride": 1, "dtype": torch.float32},
"cnt": {"type": "tensor", "stride": 1, "dtype": torch.int32},
}

def _get_enc_hist_schema(self, pk, evaluator):
return {
"g":{"type": "paillier", "stride": 1, "pk": pk, "evaluator": evaluator},
"h":{"type": "paillier", "stride": 1, "pk": pk, "evaluator": evaluator},
"cnt": {"type": "tensor", "stride": 1, "dtype": torch.float32},
"cnt": {"type": "tensor", "stride": 1, "dtype": torch.int32},
}

def compute_hist(self, ctx: Context, nodes: List[Node], bin_train_data: DataFrame, gh: DataFrame, sample_pos: DataFrame = None, node_map={}, pk=None, evaluator=None):
Expand Down

0 comments on commit 8a61bcf

Please sign in to comment.