diff --git a/tuning/rerun_best_trial.py b/tuning/rerun_best_trial.py
index 7467729a4..ed269c30f 100644
--- a/tuning/rerun_best_trial.py
+++ b/tuning/rerun_best_trial.py
@@ -12,7 +12,7 @@ def make_parser() -> argparse.ArgumentParser:
     parser = argparse.ArgumentParser(
         description=
-        "Re-run the best trials from a previous tuning run.",
+        "Re-run the best trial from a previous tuning run.",
         epilog=f"Example usage:\n"
                f"python rerun_best_trials.py tuning_run.json\n",
         formatter_class=argparse.RawDescriptionHelpFormatter,
@@ -25,12 +25,6 @@ def make_parser() -> argparse.ArgumentParser:
         help="The algorithm that has been tuned. "
              "Can usually be deduced from the study name.",
     )
-    parser.add_argument(
-        "--top-k",
-        type=int,
-        default=1,
-        help="Chooses the kth best trial to re-run."
-    )
     parser.add_argument(
         "journal_log",
         type=str,
@@ -45,7 +39,7 @@ def make_parser() -> argparse.ArgumentParser:
     return parser
 
 
-def infer_algo_name(study: optuna.Study) -> Tuple[str, List[str]]:
+def infer_algo_name(study: optuna.Study) -> str:
     """Infer the algo name from the study name.
 
     Assumes that the study name is of the form "tuning_{algo}_with_{named_configs}".
@@ -55,23 +49,6 @@ def infer_algo_name(study: optuna.Study) -> Tuple[str, List[str]]:
     return study.study_name[len("tuning_"):].split("_with_")[0]
 
 
-def get_top_k_trial(study: optuna.Study, k: int) -> optuna.trial.Trial:
-    if k <= 0:
-        raise ValueError(f"--top-k must be positive, but is {k}.")
-    finished_trials = [t for t in study.trials if t.state == optuna.trial.TrialState.COMPLETE]
-    if len(finished_trials) == 0:
-        raise ValueError("No trials have completed.")
-    if len(finished_trials) < k:
-        raise ValueError(
-            f"Only {len(finished_trials)} trials have completed, but --top-k is {k}."
-        )
-
-    return sorted(
-        finished_trials,
-        key=lambda t: t.value, reverse=True,
-    )[k-1]
-
-
 def main():
     parser = make_parser()
     args = parser.parse_args()
@@ -83,9 +60,7 @@ def main():
         # inferred
         study_name=None,
     )
-    trial = get_top_k_trial(study, args.top_k)
-
-    print(trial.value, trial.params)
+    trial = study.best_trial
 
     algo_name = args.algo or infer_algo_name(study)
     sacred_experiment: sacred.Experiment = hp_search_spaces.objectives_by_algo[algo_name].sacred_ex
diff --git a/tuning/rerun_on_slurm.sh b/tuning/rerun_on_slurm.sh
index 1f684b79b..d1cb9d301 100644
--- a/tuning/rerun_on_slurm.sh
+++ b/tuning/rerun_on_slurm.sh
@@ -1,7 +1,7 @@
 #!/bin/bash
 #SBATCH --array=1-5
 # Avoid cluttering the root directory with log files:
-#SBATCH --output=%x/%a/sbatch_cout.txt
+#SBATCH --output=%x/reruns/%a/sbatch_cout.txt
 #SBATCH --cpus-per-task=8
 #SBATCH --gpus=0
 #SBATCH --mem=8gb
@@ -16,28 +16,26 @@
 #   A folder with a hyperparameter sweep as started by tune_on_slurm.sh.
 #
 # USAGE:
-#   sbatch rerun_on_slurm <tune_folder> [top_k]
+#   sbatch --job-name=<tune_folder> rerun_on_slurm.sh
 #
-# Picks the top-k trial from the optuna study in <tune_folder> and reruns them with
+# Picks the best trial from the optuna study in <tune_folder> and re-runs it with
 # the same hyperparameters but different seeds.
 # OUTPUT:
-#   Creates a subfolder in the given tune_folder for each worker:
-#   <tune_folder>/reruns/top_<k>/<array_task_id>
+#   Creates a sub-folder in the given tune_folder for each worker:
+#   <tune_folder>/reruns/<array_task_id>
 #   The output of each worker is written to a cout.txt.
 
 source "/nas/ucb/$(whoami)/imitation/venv/bin/activate"
 
-if [ -z $2 ]; then
-  top_k=1
-else
-  top_k=$2
-fi
-
-worker_dir="$1/reruns/top_$top_k/$SLURM_ARRAY_TASK_ID/"
+worker_dir="$SLURM_JOB_NAME/reruns/$SLURM_ARRAY_TASK_ID/"
 
 if [ -f "$worker_dir/cout.txt" ]; then
+  # An existing cout.txt indicates that there is already a worker running
+  # in that directory, so we abort.
+  echo "There is already a worker running in this directory. \
+Try different seeds by picking a different array range!"
   exit 1
 else
   # Note: we run each worker in a separate working directory to avoid race
@@ -47,4 +45,4 @@ fi
 
 cd "$worker_dir" || exit
 
-srun --output="$worker_dir/cout.txt" python ../../rerun_on_slurm.py "$1/optuna_study.log" --top_k "$top_k" --seed "$SLURM_ARRAY_TASK_ID"
+srun --output="$worker_dir/cout.txt" python ../../../rerun_best_trial.py "$SLURM_JOB_NAME/optuna_study.log" --seed "$SLURM_ARRAY_TASK_ID"
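
Reviewer note: a minimal sketch of the behavior this diff switches to, assuming the study lives in an optuna journal log as the `optuna_study.log` name above suggests. The `tune_folder/` path is a placeholder, and `JournalFileStorage` is renamed `JournalFileBackend` in newer optuna releases; the rest mirrors calls visible in the hunks.

```python
import optuna
from optuna.storages import JournalStorage, JournalFileStorage

# Open the journal log written by the tuning run (path is illustrative).
storage = JournalStorage(JournalFileStorage("tune_folder/optuna_study.log"))

# study_name=None works when the log contains exactly one study, which is
# what the "# inferred" comment in main() relies on.
study = optuna.load_study(study_name=None, storage=storage)

# best_trial replaces the removed get_top_k_trial(study, k=1): it returns
# the completed trial with the best objective value.
trial = study.best_trial
print(trial.value, trial.params)
```

One behavioral nuance: the removed helper always sorted trial values in descending order (implicitly assuming maximization), while `best_trial` follows the study's declared optimization direction, so the new code is also correct for minimization studies.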