Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Separately save patch files + some typing cleanup #126

Merged
merged 1 commit into from
Apr 7, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 27 additions & 7 deletions run.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import os
import re
import traceback
from typing import Any, Dict
from typing import Any, Dict, Optional
import yaml

from dataclasses import dataclass
Expand Down Expand Up @@ -92,14 +92,15 @@ def main(args: ScriptArguments):
env = SWEEnv(args.environment)

traj_dir = Path("trajectories") / Path(getuser()) / args.run_name
os.makedirs(traj_dir, exist_ok=True)
traj_dir.mkdir(parents=True, exist_ok=True)

save_arguments(traj_dir, args)

for index in range(len(env.data)):
try:
# Reset environment
instance_id = env.data[index]["instance_id"]
assert isinstance(instance_id, str) # mypy
if should_skip(args, traj_dir, instance_id):
continue
logger.info("▶️ Beginning task " + str(index))
Expand Down Expand Up @@ -140,6 +141,7 @@ def main(args: ScriptArguments):
return_type="info_trajectory",
)
save_predictions(traj_dir, instance_id, info)
save_patch(traj_dir, instance_id, info)
if args.actions.open_pr and should_open_pr(args, info, token=env.token):
env.open_pr(args.actions, info, trajectory)

Expand All @@ -156,7 +158,7 @@ def main(args: ScriptArguments):
continue


def should_open_pr(args, info: Dict[str, Any], *, token: str="") -> bool:
def should_open_pr(args: ScriptArguments, info: Dict[str, Any], *, token: str="") -> bool:
"""Does opening a PR make sense?"""
if not info.get("submission"):
logger.info("Not openening PR because submission was made.")
Expand Down Expand Up @@ -194,7 +196,7 @@ def should_open_pr(args, info: Dict[str, Any], *, token: str="") -> bool:
return True


def save_arguments(traj_dir, args):
def save_arguments(traj_dir: Path, args: ScriptArguments) -> None:
"""Save the arguments to a yaml file to the run's trajectory directory."""
log_path = traj_dir / "args.yaml"

Expand All @@ -212,7 +214,7 @@ def save_arguments(traj_dir, args):
args.dump_yaml(f)


def should_skip(args, traj_dir, instance_id):
def should_skip(args: ScriptArguments, traj_dir: Path, instance_id: str) -> bool:
"""Check if we should skip this instance based on the instance filter and skip_existing flag."""
# Skip instances that don't match the instance filter
if re.match(args.instance_filter, instance_id) is None:
Expand Down Expand Up @@ -240,8 +242,8 @@ def should_skip(args, traj_dir, instance_id):
return False


def save_predictions(traj_dir, instance_id, info):
output_file = Path(traj_dir) / "all_preds.jsonl"
def save_predictions(traj_dir: Path, instance_id: str, info):
output_file = traj_dir / "all_preds.jsonl"
model_patch = info["submission"] if "submission" in info else None
datum = {
KEY_MODEL: Path(traj_dir).name,
Expand All @@ -253,6 +255,24 @@ def save_predictions(traj_dir, instance_id, info):
logger.info(f"Saved predictions to {output_file}")


def save_patch(traj_dir: Path, instance_id: str, info) -> Optional[Path]:
"""Create patch files that can be applied with `git am`.

Returns:
The path to the patch file, if it was saved. Otherwise, returns None.
"""
patch_output_dir = traj_dir / "patches"
patch_output_dir.mkdir(exist_ok=True, parents=True)
patch_output_file = patch_output_dir / f"{instance_id}.patch"
if not "submission" in info:
logger.info("No patch to save.")
return
model_patch = info["submission"]
patch_output_file.write_text(model_patch)
logger.info(f"Saved patch to {patch_output_file}")
return patch_output_file


def get_args(args=None) -> ScriptArguments:
"""Parse command line arguments and return a ScriptArguments object.

Expand Down