
Commit

Merge pull request #5175 from FederatedAI/feature-2.0.0-beta-dataframe_add_dtypes

secureboost: hist fix, use dtypes from gh frame
mgqa34 authored Sep 12, 2023
2 parents 030a4d2 + 8cdc115 commit a90deb2
Showing 1 changed file with 15 additions and 15 deletions.
30 changes: 15 additions & 15 deletions python/fate/ml/ensemble/learner/decision_tree/tree_core/hist.py
@@ -83,24 +83,24 @@ def __init__(
         self._last_layer_node_map = None
         self._hist_sub = hist_sub

-    def _get_plain_text_schema(self):
+    def _get_plain_text_schema(self, dtypes):
         return {
-            "g": {"type": "plaintext", "stride": 1, "dtype": torch.float32},
-            "h": {"type": "plaintext", "stride": 1, "dtype": torch.float32},
-            "cnt": {"type": "plaintext", "stride": 1, "dtype": torch.int32},
+            "g": {"type": "plaintext", "stride": 1, "dtype": dtypes["g"]},
+            "h": {"type": "plaintext", "stride": 1, "dtype": dtypes["h"]},
+            "cnt": {"type": "plaintext", "stride": 1, "dtype": dtypes["cnt"]},
         }

-    def _get_enc_hist_schema(self, pk, evaluator):
+    def _get_enc_hist_schema(self, pk, evaluator, dtypes):
         return {
-            "g": {"type": "ciphertext", "stride": 1, "pk": pk, "evaluator": evaluator, "dtype": torch.float32},
-            "h": {"type": "ciphertext", "stride": 1, "pk": pk, "evaluator": evaluator, "dtype": torch.float32},
-            "cnt": {"type": "plaintext", "stride": 1, "dtype": torch.int32},
+            "g": {"type": "ciphertext", "stride": 1, "pk": pk, "evaluator": evaluator, "dtype": dtypes["g"]},
+            "h": {"type": "ciphertext", "stride": 1, "pk": pk, "evaluator": evaluator, "dtype": dtypes["h"]},
+            "cnt": {"type": "plaintext", "stride": 1, "dtype": dtypes["cnt"]},
         }

-    def _get_pack_en_hist_schema(self, pk, evaluator):
+    def _get_pack_en_hist_schema(self, pk, evaluator, dtypes):
         return {
-            "gh": {"type": "ciphertext", "stride": 1, "pk": pk, "evaluator": evaluator, "dtype": torch.float32},
-            "cnt": {"type": "plaintext", "stride": 1, "dtype": torch.int32},
+            "gh": {"type": "ciphertext", "stride": 1, "pk": pk, "evaluator": evaluator, "dtype": dtypes["gh"]},
+            "cnt": {"type": "plaintext", "stride": 1, "dtype": dtypes["cnt"]},
         }

     def _prepare_hist_sub(self, nodes: List[Node], cur_layer_node_map: dict, parent_node_map: dict):
@@ -178,15 +178,15 @@ def compute_hist(
         logger.debug("weak nodes {}, new_node_map {}, mapping {}".format(weak_nodes, new_node_map, mapping))

         if ctx.is_on_guest:
-            schema = self._get_plain_text_schema()
+            schema = self._get_plain_text_schema(gh.dtypes)
         elif ctx.is_on_host:
             if pk is None or evaluator is None:
-                schema = self._get_plain_text_schema()
+                schema = self._get_plain_text_schema(gh.dtypes)
             else:
                 if gh_pack:
-                    schema = self._get_pack_en_hist_schema(pk, evaluator)
+                    schema = self._get_pack_en_hist_schema(pk, evaluator, gh.dtypes)
                 else:
-                    schema = self._get_enc_hist_schema(pk, evaluator)
+                    schema = self._get_enc_hist_schema(pk, evaluator, gh.dtypes)
         else:
             raise ValueError("not support called on role: {}".format(ctx.local))
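For context, the sketch below shows the dtype-propagation idea from this commit in isolation: the histogram schema builders read their dtypes from the gh frame's dtypes mapping instead of hard-coding torch.float32 / torch.int32, so the histogram follows whatever precision the frame carries. This is a minimal, self-contained sketch, not the FATE API: get_plain_text_schema and mock_dtypes are illustrative stand-ins, and it assumes gh.dtypes behaves like a mapping from column name ("g", "h", "cnt") to a torch dtype. The encrypted variants in the diff follow the same pattern, only adding pk/evaluator to the ciphertext columns.

import torch

# Illustrative stand-in for the schema builder in hist.py: the dtype of
# each column now comes from the frame's dtypes instead of being fixed.
def get_plain_text_schema(dtypes):
    return {
        "g": {"type": "plaintext", "stride": 1, "dtype": dtypes["g"]},
        "h": {"type": "plaintext", "stride": 1, "dtype": dtypes["h"]},
        "cnt": {"type": "plaintext", "stride": 1, "dtype": dtypes["cnt"]},
    }

# Assumed shape of gh.dtypes, mocked as a plain dict with float64
# gradients/hessians to show the schema tracking the frame's precision.
mock_dtypes = {"g": torch.float64, "h": torch.float64, "cnt": torch.int32}
schema = get_plain_text_schema(mock_dtypes)
print(schema["g"]["dtype"])  # torch.float64 here, not a fixed torch.float32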
