diff --git a/medcat/stats/kfold.py b/medcat/stats/kfold.py index c4162b9ed..2d8a6dd3f 100644 --- a/medcat/stats/kfold.py +++ b/medcat/stats/kfold.py @@ -4,6 +4,7 @@ from enum import Enum, auto from copy import deepcopy from pydantic import BaseModel +from itertools import islice import numpy as np @@ -205,19 +206,20 @@ def _add_target_ann(self, project: MedCATTrainerExportProject, cur_doc: MedCATTrainerExportDocument = self._find_or_add_doc(project, orig_doc) cur_doc['annotations'].append(ann) - def _targets(self) -> Iterable[Tuple[MedCATTrainerExportProjectInfo, - MedCATTrainerExportDocument, - MedCATTrainerExportAnnotation]]: - return iter_anns(self.mct_export) + def _targets(self, start_at: int) -> Iterable[Tuple[MedCATTrainerExportProjectInfo, + MedCATTrainerExportDocument, + MedCATTrainerExportAnnotation]]: + return islice(iter_anns(self.mct_export), start_at, None) def _create_fold(self, fold_nr: int) -> MedCATTrainerExport: per_fold = self.per_fold[fold_nr] + already_used = sum(self.per_fold[fn] for fn in range(fold_nr)) cur_fold: MedCATTrainerExport = { 'projects': [] } cur_project: Optional[MedCATTrainerExportProject] = None included = 0 - for target in self._targets(): + for target in self._targets(already_used): proj_info, cur_doc, cur_ann = target proj_name = proj_info[0] if not cur_project or cur_project['name'] != proj_name: diff --git a/tests/stats/test_kfold.py b/tests/stats/test_kfold.py index d06b666ec..60e3fcabf 100644 --- a/tests/stats/test_kfold.py +++ b/tests/stats/test_kfold.py @@ -2,6 +2,7 @@ import json from typing import Dict, Union, Optional from copy import deepcopy +from collections import Counter from medcat.stats import kfold from medcat.cat import CAT @@ -80,6 +81,20 @@ def test_folds_keep_all_anns(self): count_all_once = kfold.count_all_annotations(self.mct_export) self.assertEqual(total_anns, count_all_once) + def count_cuis(self, export: MCTExportTests) -> Counter: + cntr = Counter() + for _, _, ann in kfold.iter_anns(export): + cui = ann["cui"] + cntr[cui] += 1 + return cntr + + def test_folds_keep_ann_targets(self): + orig_cntr = self.count_cuis(self.mct_export) + fold_counter = Counter() + for fold in self.folds: + fold_counter += self.count_cuis(fold) + self.assertEqual(orig_cntr, fold_counter) + def test_1fold_same_as_orig(self): folds = kfold.get_fold_creator(self.mct_export, 1, split_type=self.SPLIT_TYPE).create_folds() self.assertEqual(len(folds), 1)