diff --git a/docs/usage/cl_tutorial.md b/docs/usage/cl_tutorial.md
index b671cfdf2..01893f450 100644
--- a/docs/usage/cl_tutorial.md
+++ b/docs/usage/cl_tutorial.md
@@ -31,6 +31,9 @@ Here,
 !!! tip "Running more than once"
     * The complete details of the run are saved as a "trajectory" file (more about them [here](trajectories.md)). They can also be turned into new [demonstrations](../config/demonstrations.md).
     * If you run the same command more than once, you will find that SWE-agent aborts with ` Skipping existing trajectory`. You can either remove the trajectory from the warning message, or add the `--skip_existing=False` flag.
+    * If you solve multiple issues from the same repository/in the same environment, you can specify the
+      `--cache_task_images` flag. This will create a persistent docker image with the initialized environment
+      required for the problem.
 
 ## Specifying the repository
 
@@ -139,4 +142,4 @@ And follow the instructions below it:
 git apply "${PATCH_FILE_PATH}"
 ```
 
-{% include-markdown "../_footer.md" %}
\ No newline at end of file
+{% include-markdown "../_footer.md" %}
diff --git a/sweagent/environment/swe_env.py b/sweagent/environment/swe_env.py
index d9bd4eaa9..5b7579b3e 100644
--- a/sweagent/environment/swe_env.py
+++ b/sweagent/environment/swe_env.py
@@ -39,7 +39,7 @@
     get_requirements,
     MAP_VERSION_TO_INSTALL
 )
-from typing import Optional, Tuple
+from typing import List, Optional, Tuple
 
 LONG_TIMEOUT = 500
 PATH_TO_REQS = "/root/requirements.txt"
@@ -100,6 +100,7 @@ class SWEEnv(gym.Env):
     """Gym environment for SWE-bench. This class should handle all communication with the docker container."""
 
     name = "swe_main"
+    cached_image_prefix = "swe-agent-task-env-"
 
     def __init__(self, args: EnvironmentArguments):
         super().__init__()
@@ -142,20 +143,22 @@ def __init__(self, args: EnvironmentArguments):
         self.image_name = args.image_name
         self._reset_container()
 
-        # Prepare image tag prefix for cached task environments
-        if self.args.cache_task_images:
-            logger.info("Task environment caching enabled")
-            tag = f"{self.args.data_path.replace('/', '_')}__{self.args.split}__{self.args.base_commit or 'head'}__"
-            assert len(tag) < 128, f"Cached image tag {tag} too long, probably due to long data path or base commit hash."
-            image_name_without_tag = self.image_name.split(":")[0]
-            self.cached_image_prefix = f"{image_name_without_tag}:{tag}"
-
         # Set timeout
         self.timeout = self.args.timeout
         self.idx = 0
         self.clean_multi_line_functions = lambda x: x
         self.hooks = []
 
+    def _get_cached_task_image_name(self) -> str:
+        assert self.record is not None
+        inputs: List[str] = [
+            self.record["repo"],
+            self.record["base_commit"],
+            self.args.environment_setup or "no_setup",
+        ]
+        tag = hashlib.sha256("".join(inputs).encode()).hexdigest()[:50]
+        return f"{self.cached_image_prefix}{tag}"
+
     def add_hook(self, hook: EnvHook):
         hook.on_init()
         self.hooks.append(hook)
@@ -233,7 +236,7 @@ def reset(self, index: Optional[int] = None, apply_test_patch: bool = False) ->
         ### Reset Container ###
 
         if self.args.cache_task_images:
-            cached_image = f"{self.cached_image_prefix}{index}"
+            cached_image = self._get_cached_task_image_name()
             if image_exists(cached_image):
                 logger.info(f"Restore environment from cached image {cached_image}")
                 self.close()  # stop current container
@@ -309,6 +312,7 @@ def reset(self, index: Optional[int] = None, apply_test_patch: bool = False) ->
             envs = self.communicate("env")
             logger.debug(f"Environment variables to save:\n{envs}\n")
             self.communicate("env >> /.env")
+            assert self.container_obj is not None  # mypy
             self.container_obj.commit(cached_image)
             logger.info(f"Container with environment {self.container_obj.id} cached as image {cached_image}")
 
diff --git a/tests/test_env.py b/tests/test_env.py
index 8baa9aabe..0be458627 100644
--- a/tests/test_env.py
+++ b/tests/test_env.py
@@ -2,6 +2,7 @@
 import os
 from pathlib import Path
 import subprocess
+import time
 import pytest
 import yaml
 from sweagent.environment.swe_env import EnvHook, EnvironmentArguments, SWEEnv
@@ -55,6 +56,33 @@ def test_init_swe_env_non_persistent(test_env_args):
         env.reset()
 
 
+@pytest.mark.slow
+def test_init_swe_env_cached_task_image(test_env_args):
+    test_env_args = dataclasses.replace(test_env_args, cache_task_images=True)
+    start = time.perf_counter()
+    with swe_env_context(test_env_args) as env:
+        env.reset()
+    duration_no_cache = time.perf_counter() - start
+    start = time.perf_counter()
+    # now it should be cached, so let's run again
+    image_prefix = None
+    with swe_env_context(test_env_args) as env:
+        env.reset()
+        image_prefix = env.cached_image_prefix
+    assert image_prefix
+    duration_cache = time.perf_counter() - start
+    assert duration_cache < duration_no_cache
+    # Retrieve all images with a prefix "prefix"
+    client = docker.from_env()
+    # Remove the images
+    for image in client.images.list():
+        if not image.tags:
+            continue
+        if not image.tags[0].startswith(image_prefix):
+            continue
+        client.images.remove(image.id)
+
+
 @pytest.mark.slow
 def test_execute_setup_script(tmp_path, test_env_args):
     test_script = "echo 'hello world'"