Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Cached task instances (2) #425

Merged
merged 3 commits into from
May 27, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion docs/usage/cl_tutorial.md
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,9 @@ Here,
!!! tip "Running more than once"
* The complete details of the run are saved as a "trajectory" file (more about them [here](trajectories.md)). They can also be turned into new [demonstrations](../config/demonstrations.md).
* If you run the same command more than once, you will find that SWE-agent aborts with ` Skipping existing trajectory`. You can either delete the trajectory file named in the warning message, or add the `--skip_existing=False` flag.
* If you solve multiple issues from the same repository or in the same environment, you can specify the
`--cache_task_images` flag. This creates a persistent Docker image containing the initialized environment
required for the problem, so that later runs can restore it instead of setting it up again.


## Specifying the repository
Expand Down Expand Up @@ -139,4 +142,4 @@ And follow the instructions below it:
git apply "${PATCH_FILE_PATH}"
```

{% include-markdown "../_footer.md" %}
{% include-markdown "../_footer.md" %}
24 changes: 14 additions & 10 deletions sweagent/environment/swe_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@
get_requirements,
MAP_VERSION_TO_INSTALL
)
from typing import Optional, Tuple
from typing import List, Optional, Tuple

LONG_TIMEOUT = 500
PATH_TO_REQS = "/root/requirements.txt"
Expand Down Expand Up @@ -100,6 +100,7 @@ class SWEEnv(gym.Env):
"""Gym environment for SWE-bench. This class should handle all communication with the docker container."""

name = "swe_main"
cached_image_prefix = "swe-agent-task-env-"

def __init__(self, args: EnvironmentArguments):
super().__init__()
Expand Down Expand Up @@ -142,20 +143,22 @@ def __init__(self, args: EnvironmentArguments):
self.image_name = args.image_name
self._reset_container()

# Prepare image tag prefix for cached task environments
if self.args.cache_task_images:
logger.info("Task environment caching enabled")
tag = f"{self.args.data_path.replace('/', '_')}__{self.args.split}__{self.args.base_commit or 'head'}__"
assert len(tag) < 128, f"Cached image tag {tag} too long, probably due to long data path or base commit hash."
image_name_without_tag = self.image_name.split(":")[0]
self.cached_image_prefix = f"{image_name_without_tag}:{tag}"

# Set timeout
self.timeout = self.args.timeout
self.idx = 0
self.clean_multi_line_functions = lambda x: x
self.hooks = []

def _get_cached_task_image_name(self) -> str:
    """Return the image name under which the current task environment is cached.

    The name is ``self.cached_image_prefix`` followed by the first 50 hex
    characters of the SHA-256 digest of the current record's ``repo`` and
    ``base_commit`` plus ``self.args.environment_setup`` (or the literal
    ``"no_setup"`` when that setting is falsy), concatenated without a
    separator.
    """
    assert self.record is not None
    repo = self.record["repo"]
    base_commit = self.record["base_commit"]
    setup = self.args.environment_setup
    if not setup:
        # Fixed placeholder so tasks without a setup still hash consistently.
        setup = "no_setup"
    digest = hashlib.sha256("".join((repo, base_commit, setup)).encode())
    return f"{self.cached_image_prefix}{digest.hexdigest()[:50]}"

def add_hook(self, hook: EnvHook):
hook.on_init()
self.hooks.append(hook)
Expand Down Expand Up @@ -233,7 +236,7 @@ def reset(self, index: Optional[int] = None, apply_test_patch: bool = False) ->
### Reset Container ###

if self.args.cache_task_images:
cached_image = f"{self.cached_image_prefix}{index}"
cached_image = self._get_cached_task_image_name()
if image_exists(cached_image):
logger.info(f"Restore environment from cached image {cached_image}")
self.close() # stop current container
Expand Down Expand Up @@ -309,6 +312,7 @@ def reset(self, index: Optional[int] = None, apply_test_patch: bool = False) ->
envs = self.communicate("env")
logger.debug(f"Environment variables to save:\n{envs}\n")
self.communicate("env >> /.env")
assert self.container_obj is not None # mypy
self.container_obj.commit(cached_image)
logger.info(f"Container with environment {self.container_obj.id} cached as image {cached_image}")

Expand Down
28 changes: 28 additions & 0 deletions tests/test_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import os
from pathlib import Path
import subprocess
import time
import pytest
import yaml
from sweagent.environment.swe_env import EnvHook, EnvironmentArguments, SWEEnv
Expand Down Expand Up @@ -55,6 +56,33 @@ def test_init_swe_env_non_persistent(test_env_args):
env.reset()


@pytest.mark.slow
def test_init_swe_env_cached_task_image(test_env_args):
    """Check that a second reset with ``cache_task_images=True`` is faster than the first.

    The environment is initialized and reset twice with identical arguments.
    The second run is expected to be faster, presumably because it restores
    the cached task image written by the first run.

    The cached Docker images are removed in a ``finally`` block, so a failing
    assertion does not leave persistent images behind on the host.
    """
    # Local import keeps the test self-contained; docker is only needed for cleanup.
    import docker

    test_env_args = dataclasses.replace(test_env_args, cache_task_images=True)
    image_prefix = None
    try:
        start = time.perf_counter()
        with swe_env_context(test_env_args) as env:
            env.reset()
            # Remember the prefix early so cleanup works even if the second run fails.
            image_prefix = env.cached_image_prefix
        duration_no_cache = time.perf_counter() - start

        # now it should be cached, so let's run again
        start = time.perf_counter()
        with swe_env_context(test_env_args) as env:
            env.reset()
            image_prefix = env.cached_image_prefix
        assert image_prefix
        duration_cache = time.perf_counter() - start
        assert duration_cache < duration_no_cache
    finally:
        # Never clean up with an empty prefix: ``startswith("")`` matches every tag.
        if image_prefix:
            client = docker.from_env()
            for image in client.images.list():
                # Check every tag, not only the first one, so that a cached
                # image with additional tags is still found.
                if any(tag.startswith(image_prefix) for tag in image.tags):
                    client.images.remove(image.id)


@pytest.mark.slow
def test_execute_setup_script(tmp_path, test_env_args):
test_script = "echo 'hello world'"
Expand Down
Loading