diff --git a/evaluation/swe_bench/run_infer.py b/evaluation/swe_bench/run_infer.py index 62333662c5..72ed3ccc70 100644 --- a/evaluation/swe_bench/run_infer.py +++ b/evaluation/swe_bench/run_infer.py @@ -13,6 +13,7 @@ from evaluation.utils.shared import ( EvalMetadata, EvalOutput, + assert_and_raise, codeact_user_response, make_metadata, prepare_dataset, @@ -163,14 +164,16 @@ def initialize_runtime( logger.info(action, extra={'msg_type': 'ACTION'}) obs = runtime.run_action(action) logger.info(obs, extra={'msg_type': 'OBSERVATION'}) - assert obs.exit_code == 0 + assert_and_raise( + obs.exit_code == 0, f'Failed to export SWE_INSTANCE_ID: {obs.content}' + ) action = CmdRunAction(command="""export USER=$(whoami); echo USER=${USER} """) action.timeout = 600 logger.info(action, extra={'msg_type': 'ACTION'}) obs = runtime.run_action(action) logger.info(obs, extra={'msg_type': 'OBSERVATION'}) - assert obs.exit_code == 0 + assert_and_raise(obs.exit_code == 0, f'Failed to export USER: {obs.content}') if USE_INSTANCE_IMAGE: # inject the init script @@ -182,9 +185,10 @@ def initialize_runtime( logger.info(action, extra={'msg_type': 'ACTION'}) obs = runtime.run_action(action) logger.info(obs, extra={'msg_type': 'OBSERVATION'}) - assert ( - obs.exit_code == 0 - ), f'Failed to create /swe_util/eval_data/instances: {obs.content}' + assert_and_raise( + obs.exit_code == 0, + f'Failed to create /swe_util/eval_data/instances: {obs.content}', + ) swe_instance_json_name = 'swe-bench-instance.json' with tempfile.TemporaryDirectory() as temp_dir: @@ -210,44 +214,53 @@ def initialize_runtime( logger.info(action, extra={'msg_type': 'ACTION'}) obs = runtime.run_action(action) logger.info(obs, extra={'msg_type': 'OBSERVATION'}) - assert obs.exit_code == 0 + assert_and_raise(obs.exit_code == 0, f'Failed to cat ~/.bashrc: {obs.content}') action = CmdRunAction(command='source ~/.bashrc') action.timeout = 600 logger.info(action, extra={'msg_type': 'ACTION'}) obs = runtime.run_action(action) logger.info(obs, extra={'msg_type': 'OBSERVATION'}) - assert obs.exit_code == 0 + assert_and_raise( + obs.exit_code == 0, f'Failed to source ~/.bashrc: {obs.content}' + ) action = CmdRunAction(command='source /swe_util/instance_swe_entry.sh') action.timeout = 3600 logger.info(action, extra={'msg_type': 'ACTION'}) obs = runtime.run_action(action) logger.info(obs, extra={'msg_type': 'OBSERVATION'}) - assert obs.exit_code == 0 + assert_and_raise( + obs.exit_code == 0, + f'Failed to source /swe_util/instance_swe_entry.sh: {obs.content}', + ) else: action = CmdRunAction(command='source /swe_util/swe_entry.sh') action.timeout = 1800 logger.info(action, extra={'msg_type': 'ACTION'}) obs = runtime.run_action(action) logger.info(obs, extra={'msg_type': 'OBSERVATION'}) - assert ( - obs.exit_code == 0 - ), f'Failed to source /swe_util/swe_entry.sh: {obs.content}' + assert_and_raise( + obs.exit_code == 0, + f'Failed to source /swe_util/swe_entry.sh: {obs.content}', + ) action = CmdRunAction(command=f'cd /workspace/{workspace_dir_name}') action.timeout = 600 logger.info(action, extra={'msg_type': 'ACTION'}) obs = runtime.run_action(action) logger.info(obs, extra={'msg_type': 'OBSERVATION'}) - assert obs.exit_code == 0 + assert_and_raise( + obs.exit_code == 0, + f'Failed to cd to /workspace/{workspace_dir_name}: {obs.content}', + ) action = CmdRunAction(command='git reset --hard') action.timeout = 600 logger.info(action, extra={'msg_type': 'ACTION'}) obs = runtime.run_action(action) logger.info(obs, extra={'msg_type': 'OBSERVATION'}) - assert obs.exit_code == 0 + assert_and_raise(obs.exit_code == 0, f'Failed to git reset --hard: {obs.content}') action = CmdRunAction( command='for remote_name in $(git remote); do git remote remove "${remote_name}"; done' @@ -256,7 +269,7 @@ def initialize_runtime( logger.info(action, extra={'msg_type': 'ACTION'}) obs = runtime.run_action(action) logger.info(obs, extra={'msg_type': 'OBSERVATION'}) - assert obs.exit_code == 0 + assert_and_raise(obs.exit_code == 0, f'Failed to remove git remotes: {obs.content}') logger.info('-' * 30) logger.info('END Runtime Initialization Fn') @@ -284,21 +297,27 @@ def complete_runtime( logger.info(action, extra={'msg_type': 'ACTION'}) obs = runtime.run_action(action) logger.info(obs, extra={'msg_type': 'OBSERVATION'}) - assert obs.exit_code == 0 + assert_and_raise( + obs.exit_code == 0, + f'Failed to cd to /workspace/{workspace_dir_name}: {obs.content}', + ) action = CmdRunAction(command='git config --global core.pager ""') action.timeout = 600 logger.info(action, extra={'msg_type': 'ACTION'}) obs = runtime.run_action(action) logger.info(obs, extra={'msg_type': 'OBSERVATION'}) - assert obs.exit_code == 0 + assert_and_raise( + obs.exit_code == 0, + f'Failed to git config --global core.pager "": {obs.content}', + ) action = CmdRunAction(command='git add -A') action.timeout = 600 logger.info(action, extra={'msg_type': 'ACTION'}) obs = runtime.run_action(action) logger.info(obs, extra={'msg_type': 'OBSERVATION'}) - assert obs.exit_code == 0 + assert_and_raise(obs.exit_code == 0, f'Failed to git add -A: {obs.content}') n_retries = 0 git_patch = None @@ -323,7 +342,7 @@ def complete_runtime( logger.error(f'Error occurred: {obs.content}. Retrying...') sleep_if_should_continue(10) else: - raise ValueError(f'Unexpected observation type: {type(obs)}') + assert_and_raise(False, f'Unexpected observation type: {type(obs)}') logger.info('-' * 30) logger.info('END Runtime Completion Fn') @@ -346,31 +365,32 @@ def process_instance( logger.info(f'Starting evaluation for instance {instance.instance_id}.') runtime = create_runtime(config, sid=instance.instance_id) - initialize_runtime(runtime, instance) - - instruction = get_instruction(instance, metadata) - - # Here's how you can run the agent (similar to the `main` function) and get the final task state - state: State | None = asyncio.run( - run_controller( - config=config, - task_str=instruction, - runtime=runtime, - fake_user_response_fn=AGENT_CLS_TO_FAKE_USER_RESPONSE_FN[ - metadata.agent_class - ], + try: + initialize_runtime(runtime, instance) + + instruction = get_instruction(instance, metadata) + + # Here's how you can run the agent (similar to the `main` function) and get the final task state + state: State | None = asyncio.run( + run_controller( + config=config, + task_str=instruction, + runtime=runtime, + fake_user_response_fn=AGENT_CLS_TO_FAKE_USER_RESPONSE_FN[ + metadata.agent_class + ], + ) ) - ) - # ======= THIS IS SWE-Bench specific ======= - # Get git patch - return_val = complete_runtime(runtime, instance) - git_patch = return_val['git_patch'] - logger.info( - f'Got git diff for instance {instance.instance_id}:\n--------\n{git_patch}\n--------' - ) - - runtime.close() + # ======= THIS IS SWE-Bench specific ======= + # Get git patch + return_val = complete_runtime(runtime, instance) + git_patch = return_val['git_patch'] + logger.info( + f'Got git diff for instance {instance.instance_id}:\n--------\n{git_patch}\n--------' + ) + finally: + runtime.close() # ========================================== # ======= Attempt to evaluate the agent's edits ======= diff --git a/evaluation/utils/shared.py b/evaluation/utils/shared.py index 62c3cc0d7e..70e1118313 100644 --- a/evaluation/utils/shared.py +++ b/evaluation/utils/shared.py @@ -86,6 +86,10 @@ def model_dump_json(self, *args, **kwargs): return json.dumps(dumped_dict) +class EvalException(Exception): + pass + + def codeact_user_response( state: State, encapsulate_solution: bool = False, @@ -252,6 +256,15 @@ def update_progress( output_fp.flush() +def assert_and_raise(condition: bool, msg: str): + """Raise an EvalException if the condition is not met. + + This will be used in conjunction with _process_instance_wrapper to handle retries. An EvalException should trigger a retry. + """ + if not condition: + raise EvalException(msg) + + def _process_instance_wrapper( process_instance_func: Callable[[pd.Series, EvalMetadata, bool], EvalOutput], instance: pd.Series,