From 34d3aad25d85e3c2c4e0b0e1991432ab04524fab Mon Sep 17 00:00:00 2001 From: ellabarkan Date: Sun, 11 Aug 2024 09:00:44 -0400 Subject: [PATCH 1/2] added fix for tuples --- fuse/eval/metrics/metrics_common.py | 29 +++++++++++++++++++++++++++-- 1 file changed, 27 insertions(+), 2 deletions(-) diff --git a/fuse/eval/metrics/metrics_common.py b/fuse/eval/metrics/metrics_common.py index 464503ab5..99c37aa49 100644 --- a/fuse/eval/metrics/metrics_common.py +++ b/fuse/eval/metrics/metrics_common.py @@ -887,6 +887,23 @@ def reset(self) -> None: self._metric.reset() return super().reset() + @staticmethod + def _convert_tuples(ids: List[Tuple[str, int]]) -> np.ndarray: + + sample_tuple = ids[0] + dtype_tuple = [] + + for i, tuple_elem in enumerate(sample_tuple): + if isinstance(tuple_elem, str): + max_len = max(len(str(el[i])) for el in ids) + dtype_tuple.append((f"field{i}", f"U{max_len}")) + else: + dtype_tuple.append((f"field{i}", type(tuple_elem))) + + ids = np.array(ids, dtype=dtype_tuple) + ids = [tuple(x) for x in ids] + return ids + def eval( self, results: Dict[str, Any] = None, ids: Optional[Sequence[Hashable]] = None ) -> Dict[str, Any]: @@ -902,7 +919,11 @@ def eval( raise Exception( "Error: confidence interval is supported only when a unique identifier is specified. 
Add key 'id' to your data" ) - ids = np.array(ids) + + if isinstance(ids[0], tuple): + ids = self._convert_tuples(ids) + else: + ids = np.array(ids) rnd = np.random.RandomState(self._rnd_seed) original_sample_results = self._metric.eval(results, ids=ids) @@ -920,7 +941,11 @@ def eval( stratum_filter = stratum_id == stratum n_stratum = sum(stratum_filter) random_sample = rnd.randint(0, n_stratum, size=n_stratum) - sampled_ids[stratum_filter] = ids[stratum_filter][random_sample] + + flt_indx = np.where(stratum_filter)[0] + for i, idx in enumerate(random_sample): + sampled_ids[flt_indx[i]] = ids[flt_indx[idx]] + boot_results.append(self._metric.eval(results, sampled_ids)) # results can be either a list of floats or a list of dictionaries From f54914081f7cca14b330235d9cfafbbdb7f1972a Mon Sep 17 00:00:00 2001 From: ellabarkan Date: Mon, 19 Aug 2024 08:18:02 -0400 Subject: [PATCH 2/2] fixed the signature of tuple convert method --- fuse/eval/metrics/metrics_common.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/fuse/eval/metrics/metrics_common.py b/fuse/eval/metrics/metrics_common.py index 99c37aa49..fbd98a287 100644 --- a/fuse/eval/metrics/metrics_common.py +++ b/fuse/eval/metrics/metrics_common.py @@ -888,7 +888,7 @@ def reset(self) -> None: return super().reset() @staticmethod - def _convert_tuples(ids: List[Tuple[str, int]]) -> np.ndarray: + def _convert_tuples(ids: List[Tuple[Any, ...]]) -> np.ndarray: sample_tuple = ids[0] dtype_tuple = []