Skip to content

Commit

Permalink
eval: improve swebench infer error handling and retry (#4205)
Browse files Browse the repository at this point in the history
  • Loading branch information
xingyaoww authored Oct 4, 2024
1 parent 0c2a35b commit 9cc9b19
Show file tree
Hide file tree
Showing 2 changed files with 74 additions and 41 deletions.
102 changes: 61 additions & 41 deletions evaluation/swe_bench/run_infer.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from evaluation.utils.shared import (
EvalMetadata,
EvalOutput,
assert_and_raise,
codeact_user_response,
make_metadata,
prepare_dataset,
Expand Down Expand Up @@ -163,14 +164,16 @@ def initialize_runtime(
logger.info(action, extra={'msg_type': 'ACTION'})
obs = runtime.run_action(action)
logger.info(obs, extra={'msg_type': 'OBSERVATION'})
assert obs.exit_code == 0
assert_and_raise(
obs.exit_code == 0, f'Failed to export SWE_INSTANCE_ID: {obs.content}'
)

action = CmdRunAction(command="""export USER=$(whoami); echo USER=${USER} """)
action.timeout = 600
logger.info(action, extra={'msg_type': 'ACTION'})
obs = runtime.run_action(action)
logger.info(obs, extra={'msg_type': 'OBSERVATION'})
assert obs.exit_code == 0
assert_and_raise(obs.exit_code == 0, f'Failed to export USER: {obs.content}')

if USE_INSTANCE_IMAGE:
# inject the init script
Expand All @@ -182,9 +185,10 @@ def initialize_runtime(
logger.info(action, extra={'msg_type': 'ACTION'})
obs = runtime.run_action(action)
logger.info(obs, extra={'msg_type': 'OBSERVATION'})
assert (
obs.exit_code == 0
), f'Failed to create /swe_util/eval_data/instances: {obs.content}'
assert_and_raise(
obs.exit_code == 0,
f'Failed to create /swe_util/eval_data/instances: {obs.content}',
)

swe_instance_json_name = 'swe-bench-instance.json'
with tempfile.TemporaryDirectory() as temp_dir:
Expand All @@ -210,44 +214,53 @@ def initialize_runtime(
logger.info(action, extra={'msg_type': 'ACTION'})
obs = runtime.run_action(action)
logger.info(obs, extra={'msg_type': 'OBSERVATION'})
assert obs.exit_code == 0
assert_and_raise(obs.exit_code == 0, f'Failed to cat ~/.bashrc: {obs.content}')

action = CmdRunAction(command='source ~/.bashrc')
action.timeout = 600
logger.info(action, extra={'msg_type': 'ACTION'})
obs = runtime.run_action(action)
logger.info(obs, extra={'msg_type': 'OBSERVATION'})
assert obs.exit_code == 0
assert_and_raise(
obs.exit_code == 0, f'Failed to source ~/.bashrc: {obs.content}'
)

action = CmdRunAction(command='source /swe_util/instance_swe_entry.sh')
action.timeout = 3600
logger.info(action, extra={'msg_type': 'ACTION'})
obs = runtime.run_action(action)
logger.info(obs, extra={'msg_type': 'OBSERVATION'})
assert obs.exit_code == 0
assert_and_raise(
obs.exit_code == 0,
f'Failed to source /swe_util/instance_swe_entry.sh: {obs.content}',
)
else:
action = CmdRunAction(command='source /swe_util/swe_entry.sh')
action.timeout = 1800
logger.info(action, extra={'msg_type': 'ACTION'})
obs = runtime.run_action(action)
logger.info(obs, extra={'msg_type': 'OBSERVATION'})
assert (
obs.exit_code == 0
), f'Failed to source /swe_util/swe_entry.sh: {obs.content}'
assert_and_raise(
obs.exit_code == 0,
f'Failed to source /swe_util/swe_entry.sh: {obs.content}',
)

action = CmdRunAction(command=f'cd /workspace/{workspace_dir_name}')
action.timeout = 600
logger.info(action, extra={'msg_type': 'ACTION'})
obs = runtime.run_action(action)
logger.info(obs, extra={'msg_type': 'OBSERVATION'})
assert obs.exit_code == 0
assert_and_raise(
obs.exit_code == 0,
f'Failed to cd to /workspace/{workspace_dir_name}: {obs.content}',
)

action = CmdRunAction(command='git reset --hard')
action.timeout = 600
logger.info(action, extra={'msg_type': 'ACTION'})
obs = runtime.run_action(action)
logger.info(obs, extra={'msg_type': 'OBSERVATION'})
assert obs.exit_code == 0
assert_and_raise(obs.exit_code == 0, f'Failed to git reset --hard: {obs.content}')

action = CmdRunAction(
command='for remote_name in $(git remote); do git remote remove "${remote_name}"; done'
Expand All @@ -256,7 +269,7 @@ def initialize_runtime(
logger.info(action, extra={'msg_type': 'ACTION'})
obs = runtime.run_action(action)
logger.info(obs, extra={'msg_type': 'OBSERVATION'})
assert obs.exit_code == 0
assert_and_raise(obs.exit_code == 0, f'Failed to remove git remotes: {obs.content}')

logger.info('-' * 30)
logger.info('END Runtime Initialization Fn')
Expand Down Expand Up @@ -284,21 +297,27 @@ def complete_runtime(
logger.info(action, extra={'msg_type': 'ACTION'})
obs = runtime.run_action(action)
logger.info(obs, extra={'msg_type': 'OBSERVATION'})
assert obs.exit_code == 0
assert_and_raise(
obs.exit_code == 0,
f'Failed to cd to /workspace/{workspace_dir_name}: {obs.content}',
)

action = CmdRunAction(command='git config --global core.pager ""')
action.timeout = 600
logger.info(action, extra={'msg_type': 'ACTION'})
obs = runtime.run_action(action)
logger.info(obs, extra={'msg_type': 'OBSERVATION'})
assert obs.exit_code == 0
assert_and_raise(
obs.exit_code == 0,
f'Failed to git config --global core.pager "": {obs.content}',
)

action = CmdRunAction(command='git add -A')
action.timeout = 600
logger.info(action, extra={'msg_type': 'ACTION'})
obs = runtime.run_action(action)
logger.info(obs, extra={'msg_type': 'OBSERVATION'})
assert obs.exit_code == 0
assert_and_raise(obs.exit_code == 0, f'Failed to git add -A: {obs.content}')

n_retries = 0
git_patch = None
Expand All @@ -323,7 +342,7 @@ def complete_runtime(
logger.error(f'Error occurred: {obs.content}. Retrying...')
sleep_if_should_continue(10)
else:
raise ValueError(f'Unexpected observation type: {type(obs)}')
assert_and_raise(False, f'Unexpected observation type: {type(obs)}')

logger.info('-' * 30)
logger.info('END Runtime Completion Fn')
Expand All @@ -346,31 +365,32 @@ def process_instance(
logger.info(f'Starting evaluation for instance {instance.instance_id}.')

runtime = create_runtime(config, sid=instance.instance_id)
initialize_runtime(runtime, instance)

instruction = get_instruction(instance, metadata)

# Here's how you can run the agent (similar to the `main` function) and get the final task state
state: State | None = asyncio.run(
run_controller(
config=config,
task_str=instruction,
runtime=runtime,
fake_user_response_fn=AGENT_CLS_TO_FAKE_USER_RESPONSE_FN[
metadata.agent_class
],
try:
initialize_runtime(runtime, instance)

instruction = get_instruction(instance, metadata)

# Here's how you can run the agent (similar to the `main` function) and get the final task state
state: State | None = asyncio.run(
run_controller(
config=config,
task_str=instruction,
runtime=runtime,
fake_user_response_fn=AGENT_CLS_TO_FAKE_USER_RESPONSE_FN[
metadata.agent_class
],
)
)
)

# ======= THIS IS SWE-Bench specific =======
# Get git patch
return_val = complete_runtime(runtime, instance)
git_patch = return_val['git_patch']
logger.info(
f'Got git diff for instance {instance.instance_id}:\n--------\n{git_patch}\n--------'
)

runtime.close()
# ======= THIS IS SWE-Bench specific =======
# Get git patch
return_val = complete_runtime(runtime, instance)
git_patch = return_val['git_patch']
logger.info(
f'Got git diff for instance {instance.instance_id}:\n--------\n{git_patch}\n--------'
)
finally:
runtime.close()
# ==========================================

# ======= Attempt to evaluate the agent's edits =======
Expand Down
13 changes: 13 additions & 0 deletions evaluation/utils/shared.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,10 @@ def model_dump_json(self, *args, **kwargs):
return json.dumps(dumped_dict)


class EvalException(Exception):
pass


def codeact_user_response(
state: State,
encapsulate_solution: bool = False,
Expand Down Expand Up @@ -252,6 +256,15 @@ def update_progress(
output_fp.flush()


def assert_and_raise(condition: bool, msg: str):
"""Raise an EvalException if the condition is not met.
This will be used in conjunction with _process_instance_wrapper to handle retries. An EvalException should trigger a retry.
"""
if not condition:
raise EvalException(msg)


def _process_instance_wrapper(
process_instance_func: Callable[[pd.Series, EvalMetadata, bool], EvalOutput],
instance: pd.Series,
Expand Down

0 comments on commit 9cc9b19

Please sign in to comment.