diff --git a/python/fate/ml/ensemble/learner/decision_tree/tree_core/hist.py b/python/fate/ml/ensemble/learner/decision_tree/tree_core/hist.py index a07015e462..c4df2a6170 100644 --- a/python/fate/ml/ensemble/learner/decision_tree/tree_core/hist.py +++ b/python/fate/ml/ensemble/learner/decision_tree/tree_core/hist.py @@ -83,24 +83,24 @@ def __init__( self._last_layer_node_map = None self._hist_sub = hist_sub - def _get_plain_text_schema(self): + def _get_plain_text_schema(self, dtypes): return { - "g": {"type": "plaintext", "stride": 1, "dtype": torch.float32}, - "h": {"type": "plaintext", "stride": 1, "dtype": torch.float32}, - "cnt": {"type": "plaintext", "stride": 1, "dtype": torch.int32}, + "g": {"type": "plaintext", "stride": 1, "dtype": dtypes["g"]}, + "h": {"type": "plaintext", "stride": 1, "dtype": dtypes["h"]}, + "cnt": {"type": "plaintext", "stride": 1, "dtype": dtypes["cnt"]}, } - def _get_enc_hist_schema(self, pk, evaluator): + def _get_enc_hist_schema(self, pk, evaluator, dtypes): return { - "g": {"type": "ciphertext", "stride": 1, "pk": pk, "evaluator": evaluator, "dtype": torch.float32}, - "h": {"type": "ciphertext", "stride": 1, "pk": pk, "evaluator": evaluator, "dtype": torch.float32}, - "cnt": {"type": "plaintext", "stride": 1, "dtype": torch.int32}, + "g": {"type": "ciphertext", "stride": 1, "pk": pk, "evaluator": evaluator, "dtype": dtypes["g"]}, + "h": {"type": "ciphertext", "stride": 1, "pk": pk, "evaluator": evaluator, "dtype": dtypes["h"]}, + "cnt": {"type": "plaintext", "stride": 1, "dtype": dtypes["cnt"]}, } - def _get_pack_en_hist_schema(self, pk, evaluator): + def _get_pack_en_hist_schema(self, pk, evaluator, dtypes): return { - "gh": {"type": "ciphertext", "stride": 1, "pk": pk, "evaluator": evaluator, "dtype": torch.float32}, - "cnt": {"type": "plaintext", "stride": 1, "dtype": torch.int32}, + "gh": {"type": "ciphertext", "stride": 1, "pk": pk, "evaluator": evaluator, "dtype": dtypes["gh"]}, + "cnt": {"type": "plaintext", "stride": 1, "dtype": dtypes["cnt"]}, } def _prepare_hist_sub(self, nodes: List[Node], cur_layer_node_map: dict, parent_node_map: dict): @@ -178,15 +178,15 @@ def compute_hist( logger.debug("weak nodes {}, new_node_map {}, mapping {}".format(weak_nodes, new_node_map, mapping)) if ctx.is_on_guest: - schema = self._get_plain_text_schema() + schema = self._get_plain_text_schema(gh.dtypes) elif ctx.is_on_host: if pk is None or evaluator is None: - schema = self._get_plain_text_schema() + schema = self._get_plain_text_schema(gh.dtypes) else: if gh_pack: - schema = self._get_pack_en_hist_schema(pk, evaluator) + schema = self._get_pack_en_hist_schema(pk, evaluator, gh.dtypes) else: - schema = self._get_enc_hist_schema(pk, evaluator) + schema = self._get_enc_hist_schema(pk, evaluator, gh.dtypes) else: raise ValueError("not support called on role: {}".format(ctx.local))