diff --git a/amlb/benchmark.py b/amlb/benchmark.py index d05b2666d..b6c2840df 100644 --- a/amlb/benchmark.py +++ b/amlb/benchmark.py @@ -21,6 +21,8 @@ import sys from typing import List, Union +import pandas as pd + from .job import Job, JobError, SimpleJobRunner, MultiThreadingJobRunner from .datasets import DataLoader, DataSourceType from .data import DatasetType @@ -59,12 +61,17 @@ class Benchmark: - openml datasets - openml studies (=benchmark suites) - user-defined (list of) datasets + + :param job_history: str or pd.DataFrame, default = None + If specified, jobs will be skipped if their result is present in job_history. + Useful to avoid duplicate work when trying to retry failed jobs. + """ data_loader = None framework_install_required = True - def __init__(self, framework_name: str, benchmark_name: str, constraint_name: str): + def __init__(self, framework_name: str, benchmark_name: str, constraint_name: str, job_history: str | pd.DataFrame | None = None): self.job_runner = None if rconfig().run_mode == 'script': @@ -80,6 +87,8 @@ def __init__(self, framework_name: str, benchmark_name: str, constraint_name: st if Benchmark.data_loader is None: Benchmark.data_loader = DataLoader(rconfig()) + self._job_history = self._load_job_history(job_history=job_history) + fsplits = framework_name.split(':', 1) framework_name = fsplits[0] tag = fsplits[1] if len(fsplits) > 1 else None @@ -110,6 +119,20 @@ def _validate(self): log.warning("Parallelization is not supported in local mode: ignoring `parallel_jobs=%s` parameter.", self.parallel_jobs) self.parallel_jobs = 1 + def _load_job_history(self, job_history: str | pd.DataFrame | None) -> pd.DataFrame | None: + """ + If job_history is None, return None + If str, load result csv from str, return pandas DataFrame + If pandas DataFrame, return pandas DataFrame + """ + if job_history is None: + return None + if isinstance(job_history, str): + log.info(f'Loading job history from {job_history}') + job_history = read_csv(job_history)
+ self._validate_job_history(job_history=job_history) + return job_history + def setup(self, mode: SetupMode): """ ensure all dependencies needed by framework are available """ @@ -211,6 +234,7 @@ def run(self, tasks: str | list[str] | None = None, folds: int | list[int] | Non task_defs = self._get_task_defs(tasks) jobs = flatten([self._task_jobs(task_def, folds) for task_def in task_defs]) + log.info(f"Running {len(jobs)} jobs") results = self._run_jobs(jobs) log.info(f"Processing results for {self.sid}") log.debug(results) @@ -299,14 +323,6 @@ def _make_job(self, task_def, fold: int): """ return BenchmarkTask(self, task_def, fold).as_job() if not self._skip_job(task_def, fold) else None - @lazy_property - def _job_history(self): - jh = rconfig().job_history - if jh and not os.path.exists(jh): - log.warning(f"Job history file {jh} does not exist, ignoring it.") - return None - return read_csv(jh) if jh else None - def _in_job_history(self, task_def, fold): jh = self._job_history if jh is None: @@ -316,13 +332,21 @@ & (jh.id == task_def.id) & (jh.fold == fold)]) > 0 + @staticmethod + def _validate_job_history(job_history): + required_columns = {'framework', 'constraint', 'id', 'fold'} + actual_columns = set(job_history.columns) + if missing_columns := (required_columns - actual_columns): + quoted_columns = ', '.join(repr(c) for c in missing_columns) + raise AssertionError(f'job_history missing required column(s) {quoted_columns}!') + def _skip_job(self, task_def, fold): if fold < 0 or fold >= task_def.folds: log.warning(f"Fold value {fold} is out of range for task {task_def.name}, skipping it.") return True if self._in_job_history(task_def, fold): - log.info(f"Task {task_def.name} with fold {fold} is already present in job history {rconfig().job_history}, skipping it.") + log.info(f"Task {task_def.name} with fold {fold} is already present in job history, skipping it.") return True return False
diff --git a/amlb/runners/aws.py b/amlb/runners/aws.py index 946a1f63b..c8b5dfa22 100644 --- a/amlb/runners/aws.py +++ b/amlb/runners/aws.py @@ -119,7 +119,7 @@ def _on_done(job_self): finally: bench.cleanup() - def __init__(self, framework_name, benchmark_name, constraint_name, region=None): + def __init__(self, framework_name, benchmark_name, constraint_name, region=None, job_history: str | None = None): """ :param framework_name: @@ -127,7 +127,7 @@ def __init__(self, framework_name, benchmark_name, constraint_name, region=None) :param constraint_name: :param region: """ - super().__init__(framework_name, benchmark_name, constraint_name) + super().__init__(framework_name, benchmark_name, constraint_name, job_history=job_history) self.suid = datetime_iso(micros=True, no_sep=True) # short sid for AWS entities whose name length is limited self.region = (region if region else rconfig().aws.region if rconfig().aws['region'] @@ -1219,12 +1219,12 @@ class AWSRemoteBenchmark(Benchmark): # TODO: idea is to handle results progressively on the remote side and push results as soon as they're generated # this would allow to safely run multiple tasks on single AWS instance - def __init__(self, framework_name, benchmark_name, constraint_name, region=None): + def __init__(self, framework_name, benchmark_name, constraint_name, region=None, job_history: str | None = None): self.region = region self.s3 = boto3.resource('s3', region_name=self.region) self.bucket = self._init_bucket() self._download_resources() - super().__init__(framework_name, benchmark_name, constraint_name) + super().__init__(framework_name, benchmark_name, constraint_name, job_history=job_history) def run(self, save_scores=False): super().run(save_scores)
diff --git a/runbenchmark.py b/runbenchmark.py index ad5eac7e3..4c6ffec9c 100644 --- a/runbenchmark.py +++ b/runbenchmark.py @@ -49,6 +49,9 @@ parser.add_argument('-u', '--userdir', metavar='user_dir', default=None, help="Folder where all the customizations are stored." f"(default: '{default_dirs.user_dir}')") +parser.add_argument('--jobhistory', metavar='job_history', default=None, + help="File where prior job run results are stored. Only used when --resume is specified. " + f"(default: None)") parser.add_argument('-p', '--parallel', metavar='parallel_jobs', type=int, default=1, help="The number of jobs (i.e. tasks or folds) that can run in parallel." "\nA hard limit is defined by property `job_scheduler.max_parallel_jobs`" @@ -159,24 +162,38 @@ # merging all configuration files amlb_res = amlb.resources.from_configs(config_default, config_default_dirs, config_user, config_args) if args.resume: - amlb_res.config.job_history = os.path.join(amlb_res.config.output_dir, amlb.results.Scoreboard.results_file) + if args.jobhistory is not None: + job_history = args.jobhistory + else: + job_history = os.path.join(amlb_res.config.output_dir, amlb.results.Scoreboard.results_file) +else: + job_history = None bench = None exit_code = 0 try: + bench_kwargs = dict( + framework_name=args.framework, + benchmark_name=args.benchmark, + constraint_name=args.constraint, + ) + if job_history is not None: + bench_kwargs['job_history'] = job_history + if args.mode == 'local': - bench = amlb.Benchmark(args.framework, args.benchmark, args.constraint) + bench_cls = amlb.Benchmark elif args.mode == 'docker': - bench = amlb.DockerBenchmark(args.framework, args.benchmark, args.constraint) + bench_cls = amlb.DockerBenchmark elif args.mode == 'singularity': - bench = amlb.SingularityBenchmark(args.framework, args.benchmark, args.constraint) + bench_cls = amlb.SingularityBenchmark elif args.mode == 'aws': - bench = amlb.AWSBenchmark(args.framework, args.benchmark, args.constraint) + bench_cls = amlb.AWSBenchmark # bench = amlb.AWSBenchmark(args.framework, args.benchmark, args.constraint, region=args.region) # elif args.mode == "aws-remote": # bench = amlb.AWSRemoteBenchmark(args.framework, args.benchmark, args.constraint, region=args.region) else: raise ValueError("`mode` must be one of 'aws', 'docker', 'singularity' or 'local'.") + bench = bench_cls(**bench_kwargs) if args.setup == 'only': log.warning("Setting up %s environment only for %s, no benchmark will be run.", args.mode, args.framework)