diff --git a/amlb/datasets/openml.py b/amlb/datasets/openml.py
index 3779f3d36..657f49f99 100644
--- a/amlb/datasets/openml.py
+++ b/amlb/datasets/openml.py
@@ -69,6 +69,15 @@ def __init__(self, oml_task: oml.OpenMLTask, oml_dataset: oml.OpenMLDataset, fol
         self.fold = fold
         self._train = None
         self._test = None
+        self._nrows = None
+
+    @property
+    def nrows(self) -> int:
+        """Number of rows in the full dataset, computed on first access and cached afterwards."""
+        if self._nrows is None:
+            # Counting rows requires loading the full data, so keep the result for later calls.
+            self._nrows = len(self._load_full_data(fmt='dataframe'))
+        return self._nrows
 
     @lazy_property
     def type(self):
@@ -110,9 +119,14 @@ def inference_subsample_files(self, fmt: str, with_labels: bool = False, scikit_
             are imputed.
         """
         seed = rget().seed(self.fold)
+        itm_config = rconfig().inference_time_measurements
+        batch_sizes = itm_config.batch_sizes
+        if itm_config.limit_by_dataset_size:
+            # Read `nrows` only when limiting is enabled: computing it loads the full dataset.
+            batch_sizes = [batch_size for batch_size in batch_sizes if batch_size <= self.nrows]
         return [
             (n, str(self._inference_subsample(fmt=fmt, n=n, seed=seed + i, with_labels=with_labels, scikit_safe=scikit_safe)))
-            for n in rconfig().inference_time_measurements.batch_sizes
+            for n in batch_sizes
             for i, _ in enumerate(range(rconfig().inference_time_measurements.repeats))
         ]
 
diff --git a/resources/config.yaml b/resources/config.yaml
index e0d526a7a..0d4755936 100644
--- a/resources/config.yaml
+++ b/resources/config.yaml
@@ -87,8 +87,10 @@ results: # configuration namespace for the results.csv file.
 inference_time_measurements: # configuration namespace for performing additional inference time measurements on various batch sizes
   enabled: false
   batch_sizes: [1, 10, 100, 1000, 10000] # the batch sizes for which inference speed should be measured
-  repeats: 100 # the number of times to repeat the inference measurement for each batch size
+  repeats: 10 # the number of times to repeat the inference measurement for each batch size
   additional_job_time: 300 # the time in seconds that will be added to the maximum job time if inference time is measured
+  limit_by_dataset_size: true # Don't measure inference time for any entry of `batch_sizes` that exceeds the number of rows in the dataset.
+  # E.g., on micro-mass (571 rows) with `batch_sizes` [1, 10, 100, 1000, 10000], only measure [1, 10, 100].
 
 openml: # configuration namespace for openML.
   apikey: c1994bdb7ecb3c6f3c8f3b35f4b47f1f