Skip to content

Commit

Permalink
refactor & improve retry for the reliability of RemoteRuntime & eva…
Browse files Browse the repository at this point in the history
…luation (#3846)
  • Loading branch information
xingyaoww authored Sep 13, 2024
1 parent 7506b20 commit 78c5f58
Show file tree
Hide file tree
Showing 5 changed files with 185 additions and 80 deletions.
2 changes: 2 additions & 0 deletions evaluation/swe_bench/run_infer.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,6 +176,7 @@ def initialize_runtime(

# inject the instance info
action = CmdRunAction(command='mkdir -p /swe_util/eval_data/instances')
action.timeout = 600
logger.info(action, extra={'msg_type': 'ACTION'})
obs = runtime.run_action(action)
logger.info(obs, extra={'msg_type': 'OBSERVATION'})
Expand Down Expand Up @@ -233,6 +234,7 @@ def initialize_runtime(
), f'Failed to source /swe_util/swe_entry.sh: {obs.content}'

action = CmdRunAction(command=f'cd /workspace/{workspace_dir_name}')
action.timeout = 600
logger.info(action, extra={'msg_type': 'ACTION'})
obs = runtime.run_action(action)
logger.info(obs, extra={'msg_type': 'OBSERVATION'})
Expand Down
102 changes: 73 additions & 29 deletions evaluation/utils/shared.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import pathlib
import subprocess
import time
import traceback
from concurrent.futures import ProcessPoolExecutor
from typing import Any, Awaitable, Callable

Expand Down Expand Up @@ -77,6 +78,12 @@ def model_dump_json(self, *args, **kwargs):
return json.dumps(dumped_dict)


class EvalError(BaseModel):
instance_id: str
error: str
stacktrace: str


def codeact_user_response(
state: State,
encapsulate_solution: bool = False,
Expand Down Expand Up @@ -227,6 +234,20 @@ def prepare_dataset(
return pd.DataFrame(new_dataset)


def process_instance(
instance, metadata, use_multiprocessing, process_instance_func
) -> EvalOutput | EvalError:
try:
return process_instance_func(instance, metadata, use_multiprocessing)
except Exception as e:
logger.error(f'Error processing instance [{instance.instance_id}]: {e}')
return EvalError(
instance_id=instance.instance_id,
error=str(e),
stacktrace=traceback.format_exc(),
)


def run_evaluation(
dataset: pd.DataFrame,
metadata: EvalMetadata,
Expand All @@ -241,42 +262,65 @@ def run_evaluation(
f'Evaluation started with Agent {metadata.agent_class}:\n'
f'model {metadata.llm_config.model}, max iterations {metadata.max_iterations}.\n'
)
pbar = tqdm(total=len(dataset))
output_fp = open(output_file, 'a')

def update_progress(future):
pbar.update(1)
output: EvalOutput = future.result() if use_multiprocessing else future
instance_queue = mp.Queue()
for _, instance in dataset.iterrows():
instance_queue.put(instance)

pbar.set_description(f'Instance {output.instance_id}')
pbar.set_postfix_str(f'Test Result: {output.test_result}')
logger.info(
f'Finished evaluation for instance {output.instance_id}: {str(output.test_result)[:300]}...\n'
)
output_fp.write(json.dumps(output.model_dump()) + '\n')
output_fp.flush()
total_instances = instance_queue.qsize()
pbar = tqdm(total=total_instances, desc='Instances processed')
output_fp = open(output_file, 'a')

def update_progress(result: EvalOutput | EvalError, instance: pd.Series):
if isinstance(result, EvalOutput):
pbar.update(1)
pbar.set_description(f'Instance {result.instance_id}')
pbar.set_postfix_str(f'Test Result: {result.test_result}')
logger.info(
f'Finished evaluation for instance {result.instance_id}: {str(result.test_result)[:300]}...\n'
)
output_fp.write(json.dumps(result.model_dump()) + '\n')
output_fp.flush()
else:
logger.error(
f'Retrying instance [{instance.instance_id}] due to error: {result.error}. Stacktrace:\n{result.stacktrace}'
+ '\n'
+ '-' * 10
+ '[You may ignore this error if it is a transient issue - the instance will be automatically retried.]'
+ '-' * 10
+ '\n'
)
instance_queue.put(instance)
pbar.total += 1
pbar.refresh()

try:
if use_multiprocessing:
with ProcessPoolExecutor(num_workers) as executor:
futures = []
for _, instance in dataset.iterrows():
future = executor.submit(
process_instance_func,
instance,
metadata,
bool(num_workers > 1),
)
future.add_done_callback(update_progress)
futures.append(future)
for future in futures:
future.result()
# Use plain for loop for single process for easier debugging
while not instance_queue.empty():
futures = []
for _ in range(min(num_workers, instance_queue.qsize())):
instance = instance_queue.get()
future = executor.submit(
process_instance,
instance,
metadata,
True,
process_instance_func,
)
future.add_done_callback(
lambda f, inst=instance: update_progress(f.result(), inst)
)
futures.append(future)
for future in futures:
future.result()
else:
assert num_workers == 1
for _, instance in dataset.iterrows():
output = process_instance_func(instance, metadata, False)
update_progress(output)
while not instance_queue.empty():
instance = instance_queue.get()
result = process_instance(
instance, metadata, False, process_instance_func
)
update_progress(result, instance)

except KeyboardInterrupt:
print('\nKeyboardInterrupt received. Cleaning up...\n')
Expand Down
21 changes: 13 additions & 8 deletions openhands/runtime/builder/remote.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

from openhands.core.logger import openhands_logger as logger
from openhands.runtime.builder import RuntimeBuilder
from openhands.runtime.utils.request import send_request


class RemoteRuntimeBuilder(RuntimeBuilder):
Expand All @@ -15,6 +16,8 @@ class RemoteRuntimeBuilder(RuntimeBuilder):
def __init__(self, api_url: str, api_key: str):
self.api_url = api_url
self.api_key = api_key
self.session = requests.Session()
self.session.headers.update({'X-API-Key': self.api_key})

def build(self, path: str, tags: list[str]) -> str:
"""Builds a Docker image using the Runtime API's /build endpoint."""
Expand All @@ -38,8 +41,9 @@ def build(self, path: str, tags: list[str]) -> str:
files.append(('tags', (None, tag)))

# Send the POST request to /build
headers = {'X-API-Key': self.api_key}
response = requests.post(f'{self.api_url}/build', files=files, headers=headers)
response = send_request(
self.session, 'POST', f'{self.api_url}/build', files=files
)

if response.status_code != 202:
logger.error(f'Build initiation failed: {response.text}')
Expand All @@ -57,10 +61,11 @@ def build(self, path: str, tags: list[str]) -> str:
logger.error('Build timed out after 30 minutes')
raise RuntimeError('Build timed out after 30 minutes')

status_response = requests.get(
status_response = send_request(
self.session,
'GET',
f'{self.api_url}/build_status',
params={'build_id': build_id},
headers=headers,
)

if status_response.status_code != 200:
Expand Down Expand Up @@ -90,14 +95,14 @@ def build(self, path: str, tags: list[str]) -> str:
raise RuntimeError(error_message)

# Wait before polling again
time.sleep(5)
time.sleep(30)

def image_exists(self, image_name: str) -> bool:
"""Checks if an image exists in the remote registry using the /image_exists endpoint."""
params = {'image': image_name}
session = requests.Session()
session.headers.update({'X-API-Key': self.api_key})
response = session.get(f'{self.api_url}/image_exists', params=params)
response = send_request(
self.session, 'GET', f'{self.api_url}/image_exists', params=params
)

if response.status_code != 200:
logger.error(f'Failed to check image existence: {response.text}')
Expand Down
78 changes: 35 additions & 43 deletions openhands/runtime/remote/runtime.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,11 @@
import os
import ssl
import tempfile
import threading
import uuid
from typing import Any, Type
from zipfile import ZipFile

import requests
from requests.exceptions import HTTPError, RequestException, Timeout
from requests.exceptions import Timeout
from tenacity import (
retry,
retry_if_exception_type,
Expand Down Expand Up @@ -37,15 +35,13 @@
from openhands.runtime.builder.remote import RemoteRuntimeBuilder
from openhands.runtime.plugins import PluginRequirement
from openhands.runtime.runtime import Runtime
from openhands.runtime.utils.request import (
DEFAULT_RETRY_EXCEPTIONS,
is_404_error,
send_request,
)
from openhands.runtime.utils.runtime_build import build_runtime_image

DEFAULT_RETRY_EXCEPTIONS = [
ssl.SSLCertVerificationError,
RequestException,
HTTPError,
Timeout,
]


class RemoteRuntime(Runtime):
"""This runtime will connect to a remote od-runtime-client."""
Expand Down Expand Up @@ -99,7 +95,7 @@ def __init__(
self.container_image: str = self.config.sandbox.base_container_image
self.container_name = 'od-remote-runtime-' + self.instance_id
logger.debug(f'RemoteRuntime `{sid}` config:\n{self.config}')
response = self._send_request('GET', f'{self.api_url}/registry_prefix')
response = send_request(self.session, 'GET', f'{self.api_url}/registry_prefix')
response_json = response.json()
registry_prefix = response_json['registry_prefix']
os.environ['OD_RUNTIME_RUNTIME_IMAGE_REPO'] = (
Expand All @@ -122,7 +118,8 @@ def __init__(
)

# Use the /image_exists endpoint to check if the image exists
response = self._send_request(
response = send_request(
self.session,
'GET',
f'{self.api_url}/image_exists',
params={'image': self.container_image},
Expand Down Expand Up @@ -157,8 +154,8 @@ def __init__(
}

# Start the sandbox using the /start endpoint
response = self._send_request(
'POST', f'{self.api_url}/start', json=start_request
response = send_request(
self.session, 'POST', f'{self.api_url}/start', json=start_request
)
if response.status_code != 201:
raise RuntimeError(f'Failed to start sandbox: {response.text}')
Expand All @@ -184,29 +181,6 @@ def __init__(
self.runtime_url is not None
), 'Runtime URL is not set. This should never happen.'

def _send_request(
self,
method: str,
url: str,
retry_exceptions: list[Type[Exception]] | None = None,
**kwargs: Any,
) -> requests.Response:
if retry_exceptions is None:
retry_exceptions = DEFAULT_RETRY_EXCEPTIONS

@retry(
stop=stop_after_attempt(30),
wait=wait_exponential(multiplier=1, min=4, max=60),
retry=retry_if_exception_type(tuple(retry_exceptions)),
reraise=True,
)
def _send_request_with_retry():
response = self.session.request(method, url, **kwargs)
response.raise_for_status()
return response

return _send_request_with_retry()

@retry(
stop=stop_after_attempt(10),
wait=wait_exponential(multiplier=1, min=4, max=60),
Expand All @@ -215,7 +189,15 @@ def _send_request_with_retry():
)
def _wait_until_alive(self):
logger.info('Waiting for sandbox to be alive...')
response = self._send_request('GET', f'{self.runtime_url}/alive')
response = send_request(
self.session,
'GET',
f'{self.runtime_url}/alive',
# Retry 404 errors for the /alive endpoint
# because the runtime might just be starting up
# and have not registered the endpoint yet
retry_fns=[is_404_error],
)
if response.status_code != 200:
msg = f'Runtime is not alive yet (id={self.runtime_id}). Status: {response.status_code}.'
logger.warning(msg)
Expand All @@ -228,8 +210,11 @@ def sandbox_workspace_dir(self):
def close(self):
if self.runtime_id:
try:
response = self._send_request(
'POST', f'{self.api_url}/stop', json={'runtime_id': self.runtime_id}
response = send_request(
self.session,
'POST',
f'{self.api_url}/stop',
json={'runtime_id': self.runtime_id},
)
if response.status_code != 200:
logger.error(f'Failed to stop sandbox: {response.text}')
Expand Down Expand Up @@ -262,14 +247,19 @@ def run_action(self, action: Action) -> Observation:
logger.info('Executing action')
request_body = {'action': event_to_dict(action)}
logger.debug(f'Request body: {request_body}')
response = self._send_request(
response = send_request(
self.session,
'POST',
f'{self.runtime_url}/execute_action',
json=request_body,
timeout=action.timeout,
retry_exceptions=list(
filter(lambda e: e != TimeoutError, DEFAULT_RETRY_EXCEPTIONS)
),
# Retry 404 errors for the /execute_action endpoint
# because the runtime might just be starting up
# and have not registered the endpoint yet
retry_fns=[is_404_error],
)
if response.status_code == 200:
output = response.json()
Expand Down Expand Up @@ -335,7 +325,8 @@ def copy_to(

params = {'destination': sandbox_dest, 'recursive': str(recursive).lower()}

response = self._send_request(
response = send_request(
self.session,
'POST',
f'{self.runtime_url}/upload_file',
files=upload_data,
Expand Down Expand Up @@ -368,7 +359,8 @@ def list_files(self, path: str | None = None) -> list[str]:
if path is not None:
data['path'] = path

response = self._send_request(
response = send_request(
self.session,
'POST',
f'{self.runtime_url}/list_files',
json=data,
Expand Down
Loading

0 comments on commit 78c5f58

Please sign in to comment.