Skip to content

Commit

Permalink
Typing improvements for summarizer (#809)
Browse files Browse the repository at this point in the history
  • Loading branch information
klieret authored Oct 15, 2024
1 parent 12ce5e5 commit d6770db
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 8 deletions.
4 changes: 2 additions & 2 deletions sweagent/agent/agents.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,10 +176,10 @@ def __post_init__(self):
"history_processor",
HistoryProcessor.get(self.history_processor, **self.history_processor_args),
)
if "WINDOW" in self.env_variables and self.summarizer_config.window_length is not None:
if "WINDOW" in self.env_variables:
window_size = self.env_variables["WINDOW"]
if self.summarizer_config.window_length < int(window_size):
msg = f"Summarizer window length is set to {self.summarizer_config.window_length} which is less then the defined window length {window_size}"
msg = f"Summarizer window length is set to {self.summarizer_config.window_length} which is less than the window length {window_size}"
raise ValueError(msg)
object.__setattr__(
self,
Expand Down
13 changes: 7 additions & 6 deletions sweagent/agent/summarizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,16 +21,16 @@ class SummarizerConfig(FrozenSerializable):
"""The configuration for the summarizer"""

function: str = "Identity"
window_length: int | None = 105
window_length: int = 105
template: str | None = None
model: ModelArguments | None = None
system_template: str | None = None
instance_template: str | None = None

def __post_init__(self):
object.__setattr__(self, "function", SummarizeFunction.get(self.function, self.window_length))
object.__setattr__(self, "function", SummarizeFunction.get(self.function, self.window_length)) # type: ignore
if isinstance(self.model, dict):
object.__setattr__(self, "model", ModelArguments.from_dict(self.summarizer_model))
object.__setattr__(self, "model", ModelArguments.from_dict(self.summarizer_model)) # type: ignore


# ABSTRACT BASE CLASSES
Expand Down Expand Up @@ -59,7 +59,7 @@ class SummarizeFunction(metaclass=SummarizeFunctionMeta):
We use get to generate the right summarizer based on the name of the summarizer.
"""

def __init__(self, window_length: int | None):
def __init__(self, window_length: int):
self._window_length = window_length
self.logger = get_logger("summarizer")

Expand All @@ -74,7 +74,8 @@ def _slugify_action(action: str) -> str:
return "".join(c if c.isalnum() else "_" for c in action)[:50]

@staticmethod
def _upload_file_to_container(file_content: bytes, file_path_on_container: str, env: SWEEnv):
def _upload_file_to_container(file_content: str, file_path_on_container: str, env: SWEEnv):
assert env.container_obj is not None
env.communicate(f'mkdir -p "{Path(file_path_on_container).parent}"')
with tempfile.NamedTemporaryFile() as fp:
fp.write(file_content.encode("utf-8"))
Expand Down Expand Up @@ -123,7 +124,7 @@ class SimpleSummarizer(SummarizeFunction):
"search_dir",
]

def __call__(self, input: str, observation: str, env: SWEEnv, model: type[BaseModel]) -> tuple[str, APIStats]:
def __call__(self, input: str, observation: str, env: SWEEnv, model: BaseModel) -> tuple[str, APIStats]:
try:
if (
any(input.startswith(s) for s in self.block_list_input)
Expand Down

0 comments on commit d6770db

Please sign in to comment.