Skip to content

Commit

Permalink
Apply DI implementation to project
Browse files Browse the repository at this point in the history
  • Loading branch information
pdemro committed Sep 6, 2024
1 parent aa23cab commit 9616842
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 32 deletions.
10 changes: 10 additions & 0 deletions sweagent/environment/swe_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,9 @@
import docker.errors
import docker.models.containers
from sweagent import REPO_ROOT
from sweagent.agent.issueService.issue_service_factory import (
IssueServiceFactory
)
from sweagent.environment.utils import (
PROCESS_DONE_MARKER_END,
PROCESS_DONE_MARKER_START,
Expand Down Expand Up @@ -165,6 +168,12 @@ def __init__(self, args: EnvironmentArguments):

self._github_token: str = keys_config.get("GITHUB_TOKEN", "") # type: ignore

# Get Problem Statement
self.logger.debug("Hello Demro")
issue_service_factory = IssueServiceFactory()
issue_service = issue_service_factory.create_issue_factory(self.data_path)
problem_statement_results = issue_service.get_problem_statement()

# Load Task Instances
self.data_path = self.args.data_path
self.data = get_instances(
Expand All @@ -173,6 +182,7 @@ def __init__(self, args: EnvironmentArguments):
self.args.split,
token=self._github_token,
repo_path=self.args.repo_path,
problem_statement_results=problem_statement_results
)
#: Instance we're currently processing. Gets set in self.reset.
self.record: dict[str, Any] | None = None
Expand Down
40 changes: 8 additions & 32 deletions sweagent/environment/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@
from sweagent.utils.config import keys_config
from sweagent.utils.log import get_logger

from sweagent.agent.issueService.issue_service import ProblemStatementResults

DOCKER_START_UP_DELAY = float(keys_config.get("SWE_AGENT_DOCKER_START_UP_DELAY", 1))
GITHUB_ISSUE_URL_PATTERN = re.compile(r"github\.com\/(.*?)\/(.*?)\/issues\/(\d+)")
GITHUB_REPO_URL_PATTERN = re.compile(r".*[/@]?github\.com\/([^/]+)\/([^/]+)")
Expand Down Expand Up @@ -516,37 +518,10 @@ def __init__(self, token: str | None = None):
self.token = token
self._instance_id_problem_suffix = ""

def set_problem_statement_from_gh_issue(self, issue_url: str):
owner, repo, issue_number = parse_gh_issue_url(issue_url)
self.args["problem_statement"] = get_problem_statement_from_github_issue(
owner,
repo,
issue_number,
token=self.token,
)
self.args["instance_id"] = f"{owner}__{repo}-i{issue_number}"
self.args["problem_statement_source"] = "online"

def set_problem_statement_from_file(self, file_path: str):
self.set_problem_statement_from_text(Path(file_path).read_text())

def set_problem_statement_from_text(self, text: str):
self.args["problem_statement"] = text
self.args["instance_id"] = hashlib.sha256(self.args["problem_statement"].encode()).hexdigest()[:6]
self.args["problem_statement_source"] = "local"

def set_problem_statement(self, data_path: str):
"""Get problem statement for a single instance from a github issue url or a
path to a markdown or text file.
"""
if data_path.startswith("text://"):
return self.set_problem_statement_from_text(data_path.removeprefix("text://"))
if is_github_issue_url(data_path):
return self.set_problem_statement_from_gh_issue(data_path)
if Path(data_path).is_file():
return self.set_problem_statement_from_file(data_path)
msg = f"Not sure how to get problem statement from {data_path=}."
raise ValueError(msg)
def set_problem_statement(self, problem_statement_results: ProblemStatementResults):
self.args["problem_statement"] = problem_statement_results.problem_statement
self.args["instance_id"] = problem_statement_results.instance_id.value
self.args["problem_statement_source"] = problem_statement_results.problem_statement_source

def set_repo_info_from_gh_url(self, url: str, base_commit: str | None = None):
owner, repo = parse_gh_repo_url(url)
Expand Down Expand Up @@ -632,6 +607,7 @@ def get_instances(
token: str | None = None,
*,
repo_path: str = "",
problem_statement_results: ProblemStatementResults
) -> list[dict[str, Any]]:
"""
Getter function for handling json, jsonl files
Expand Down Expand Up @@ -661,7 +637,7 @@ def postproc_instance_list(instances):
or is_github_issue_url(file_path)
):
ib = InstanceBuilder(token=token)
ib.set_problem_statement(file_path)
ib.set_problem_statement(problem_statement_results)
if repo_path:
ib.set_repo_info(repo_path, base_commit=base_commit)
elif is_github_repo_url(file_path):
Expand Down

0 comments on commit 9616842

Please sign in to comment.