Skip to content

Commit

Permalink
Remove top-k feature when re-running trials.
Browse files Browse the repository at this point in the history
  • Loading branch information
ernestum committed Feb 19, 2024
1 parent 7178fee commit 263db3e
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 41 deletions.
31 changes: 3 additions & 28 deletions tuning/rerun_best_trial.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
def make_parser() -> argparse.ArgumentParser:
parser = argparse.ArgumentParser(
description=
"Re-run the best trials from a previous tuning run.",
"Re-run the best trial from a previous tuning run.",
epilog=f"Example usage:\n"
f"python rerun_best_trials.py tuning_run.json\n",
formatter_class=argparse.RawDescriptionHelpFormatter,
Expand All @@ -25,12 +25,6 @@ def make_parser() -> argparse.ArgumentParser:
help="The algorithm that has been tuned. "
"Can usually be deduced from the study name.",
)
parser.add_argument(
"--top-k",
type=int,
default=1,
help="Chooses the kth best trial to re-run."
)
parser.add_argument(
"journal_log",
type=str,
Expand All @@ -45,7 +39,7 @@ def make_parser() -> argparse.ArgumentParser:
return parser


def infer_algo_name(study: optuna.Study) -> Tuple[str, List[str]]:
def infer_algo_name(study: optuna.Study) -> str:
"""Infer the algo name from the study name.
Assumes that the study name is of the form "tuning_{algo}_with_{named_configs}".
Expand All @@ -55,23 +49,6 @@ def infer_algo_name(study: optuna.Study) -> Tuple[str, List[str]]:
return study.study_name[len("tuning_"):].split("_with_")[0]


def get_top_k_trial(study: optuna.Study, k: int) -> optuna.trial.FrozenTrial:
    # Note: `study.trials` returns FrozenTrial objects, so the return annotation
    # is FrozenTrial (the original `optuna.trial.Trial` annotation was inaccurate).
    """Return the k-th best completed trial of `study` (1-indexed).

    Args:
        study: The optuna study to pick a trial from.
        k: 1-based rank of the trial to return; k=1 is the best trial.

    Returns:
        The k-th best trial among trials in the COMPLETE state.

    Raises:
        ValueError: If k is not positive, if no trials have completed, or if
            fewer than k trials have completed.
    """
    if k <= 0:
        raise ValueError(f"--top-k must be positive, but is {k}.")
    # Only trials that ran to completion have a meaningful objective value.
    finished_trials = [t for t in study.trials if t.state == optuna.trial.TrialState.COMPLETE]
    if len(finished_trials) == 0:
        raise ValueError("No trials have completed.")
    if len(finished_trials) < k:
        raise ValueError(
            f"Only {len(finished_trials)} trials have completed, but --top-k is {k}."
        )

    # reverse=True ranks highest objective value first — assumes the study
    # maximizes its objective (NOTE(review): confirm study direction).
    return sorted(
        finished_trials,
        key=lambda t: t.value, reverse=True,
    )[k-1]


def main():
parser = make_parser()
args = parser.parse_args()
Expand All @@ -83,9 +60,7 @@ def main():
# inferred
study_name=None,
)
trial = get_top_k_trial(study, args.top_k)

print(trial.value, trial.params)
trial = study.best_trial

algo_name = args.algo or infer_algo_name(study)
sacred_experiment: sacred.Experiment = hp_search_spaces.objectives_by_algo[algo_name].sacred_ex
Expand Down
24 changes: 11 additions & 13 deletions tuning/rerun_on_slurm.sh
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
#!/bin/bash
#SBATCH --array=1-5
# Avoid cluttering the root directory with log files:
#SBATCH --output=%x/%a/sbatch_cout.txt
#SBATCH --output=%x/reruns/%a/sbatch_cout.txt
#SBATCH --cpus-per-task=8
#SBATCH --gpus=0
#SBATCH --mem=8gb
Expand All @@ -16,28 +16,26 @@
# A folder with a hyperparameter sweep as started by tune_on_slurm.sh.

# USAGE:
# sbatch rerun_on_slurm <tune_folder> <top-k>
# sbatch --job-name=<name of previous tuning job> rerun_on_slurm.sh
#
# Picks the top-k trial from the optuna study in <tune_folder> and reruns them with
# Picks the best trial from the optuna study in <tune_folder> and reruns it with
# the same hyperparameters but different seeds.

# OUTPUT:
# Creates a subfolder in the given tune_folder for each worker:
# <tune_folder>/reruns/top_<top-k>/<seed>
# Creates a sub-folder in the given tune_folder for each worker:
# <tune_folder>/reruns/<seed>
# The output of each worker is written to a cout.txt.


source "/nas/ucb/$(whoami)/imitation/venv/bin/activate"

if [ -z $2 ]; then
top_k=1
else
top_k=$2
fi

worker_dir="$1/reruns/top_$top_k/$SLURM_ARRAY_TASK_ID/"
worker_dir="$SLURM_JOB_NAME/reruns/$SLURM_ARRAY_TASK_ID/"

if [ -f "$worker_dir/cout.txt" ]; then
# This indicates that there is already a worker running in that directory.
# So we better abort!
echo "There is already a worker running in this directory. \
Try different seeds by picking a different array range!"
exit 1
else
# Note: we run each worker in a separate working directory to avoid race
Expand All @@ -47,4 +45,4 @@ fi

cd "$worker_dir" || exit

srun --output="$worker_dir/cout.txt" python ../../rerun_on_slurm.py "$1/optuna_study.log" --top_k "$top_k" --seed "$SLURM_ARRAY_TASK_ID"
srun --output="$worker_dir/cout.txt" python ../../../rerun_best_trial.py "$SLURM_JOB_NAME/optuna_study.log" --seed "$SLURM_ARRAY_TASK_ID"

0 comments on commit 263db3e

Please sign in to comment.