Skip to content

Commit

Permalink
Use custom Config class to support env and keys.cfg
Browse files Browse the repository at this point in the history
  • Loading branch information
klieret committed May 28, 2024
1 parent 2ae6e97 commit 521ff85
Show file tree
Hide file tree
Showing 4 changed files with 51 additions and 12 deletions.
1 change: 1 addition & 0 deletions sweagent/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,4 +33,5 @@
"SWEEnv",
"get_data_path_name",
"PACKAGE_DIR",
"CONFIG_DIR",
]
8 changes: 4 additions & 4 deletions sweagent/agent/models.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import config
import json
import logging
import os
Expand All @@ -9,6 +8,7 @@
from dataclasses import dataclass, fields
from openai import BadRequestError, OpenAI, AzureOpenAI
from simple_parsing.helpers.serialization.serializable import FrozenSerializable, Serializable
from sweagent.utils.config import Config
from sweagent.agent.commands import Command
from tenacity import (
retry,
Expand Down Expand Up @@ -239,7 +239,7 @@ def __init__(self, args: ModelArguments, commands: list[Command]):
logging.getLogger("httpx").setLevel(logging.WARNING)

# Set OpenAI key
cfg = config.Config(os.path.join(os.getcwd(), "keys.cfg"))
cfg = Config()
if self.args.model_name.startswith("azure"):
self.api_model = cfg["AZURE_OPENAI_DEPLOYMENT"]
self.client = AzureOpenAI(api_key=cfg["AZURE_OPENAI_API_KEY"], azure_endpoint=cfg["AZURE_OPENAI_ENDPOINT"], api_version=cfg.get("AZURE_OPENAI_API_VERSION", "2024-02-01"))
Expand Down Expand Up @@ -338,7 +338,7 @@ def __init__(self, args: ModelArguments, commands: list[Command]):
super().__init__(args, commands)

# Set Anthropic key
cfg = config.Config(os.path.join(os.getcwd(), "keys.cfg"))
cfg = Config()
self.api = Anthropic(api_key=cfg["ANTHROPIC_API_KEY"])

def history_to_messages(
Expand Down Expand Up @@ -660,7 +660,7 @@ def __init__(self, args: ModelArguments, commands: list[Command]):
assert together.version >= '1.1.0', "Please upgrade to Together SDK v1.1.0 or later."

# Set Together key
cfg = config.Config(os.path.join(os.getcwd(), "keys.cfg"))
cfg = Config()
together.api_key = cfg.TOGETHER_API_KEY

def history_to_messages(
Expand Down
12 changes: 4 additions & 8 deletions sweagent/environment/swe_env.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
from pathlib import Path
import random
import config
import datetime
import docker
import gymnasium as gym
Expand Down Expand Up @@ -41,6 +40,8 @@
)
from typing import List, Optional, Tuple

from sweagent.utils.config import Config

LONG_TIMEOUT = 500
PATH_TO_REQS = "/root/requirements.txt"
PATH_TO_ENV_YML = "/root/environment.yml"
Expand Down Expand Up @@ -125,12 +126,7 @@ def __init__(self, args: EnvironmentArguments):
except:
logger.warning("Failed to get commit hash for this repo")

self._github_token: str = os.environ.get("GITHUB_TOKEN", "")
if not self._github_token and os.path.isfile(
os.path.join(os.getcwd(), "keys.cfg")
):
cfg = config.Config(os.path.join(os.getcwd(), "keys.cfg"))
self._github_token: str = cfg.get("GITHUB_TOKEN", "") # type: ignore
self._github_token: str = Config().get("GITHUB_TOKEN", "") # type: ignore

# Load Task Instances
self.data_path = self.args.data_path
Expand Down Expand Up @@ -556,7 +552,7 @@ def _communicate(
input: str,
timeout_duration=25,
) -> str:
if "SWE_AGENT_EXPERIMENTAL_COMMUNICATE" in os.environ:
if "SWE_AGENT_EXPERIMENTAL_COMMUNICATE" in Config():
return self._communicate_experimental(input, timeout_duration)
try:
self.returncode = None
Expand Down
42 changes: 42 additions & 0 deletions sweagent/utils/config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
import logging
from typing import Any
import config as config_file
import os


logger = logging.getLogger("config")


class Config:
def __init__(self, ):
"""This wrapper class is used to load keys from environment variables or keys.cfg file.
Whenever both are presents, the environment variable is used.
"""
# Defer import to avoid circular import
from sweagent import PACKAGE_DIR
self._keys_cfg = None
keys_cfg_path = PACKAGE_DIR / "keys.cfg"
if keys_cfg_path.exists():
try:
self._keys_cfg = config_file.Config(PACKAGE_DIR / "keys.cfg")
except Exception as e:
raise RuntimeError(f"Error loading keys.cfg. Please check the file.") from e
else:
logger.error(f"keys.cfg not found in {PACKAGE_DIR}")

def get(self, key, default=None) -> Any:
if key in os.environ:
return os.environ[key]
if self._keys_cfg is not None and key in self._keys_cfg:
return self._keys_cfg[key]
return default

def __getitem__(self, key: str) -> Any:
if key in os.environ:
return os.environ[key]
if self._keys_cfg is not None and key in self._keys_cfg:
return self._keys_cfg[key]
raise KeyError(f"Key {key} not found in environment or keys.cfg")

def __contains__(self, key: str) -> bool:
return key in os.environ or (self._keys_cfg is not None and key in self._keys_cfg)

0 comments on commit 521ff85

Please sign in to comment.