|
1 | 1 | import logging
|
| 2 | +import multiprocessing as mp |
2 | 3 | import os
|
| 4 | +import traceback |
| 5 | +import typing |
3 | 6 | from typing import Literal
|
4 | 7 |
|
5 |
| -import gymnasium as gym |
6 | 8 | import numpy as np
|
7 | 9 |
|
8 | 10 | from browsergym.experiments.loop import SEED_MAX, EnvArgs
|
@@ -206,35 +208,73 @@ def prepare_backend(backend: str):
|
206 | 208 | raise NotImplementedError(f"Unknown benchmark backend {repr(backend)}")
|
207 | 209 |
|
208 | 210 |
|
209 |
| -def massage_tasks(task_ids: list[str], max_retries: int = 1): |
| 211 | +def massage_tasks(task_ids: list[str], max_retries: int = 1, timeout: int = 60): |
210 | 212 | for i, task_id in enumerate(task_ids):
|
211 |
| - gym_id = f"browsergym/{task_id}" |
212 |
| - logger.info(f"Massaging task {i + 1} / {len(task_ids)}: {gym_id}") |
213 |
| - task_retries = 0 |
214 |
| - while True: |
215 |
| - env = gym.make(gym_id) |
216 |
| - try: |
217 |
| - env.reset() # task setup |
218 |
| - try: |
219 |
| - no_action = "noop()" |
220 |
| - # check if action space exists and is compatible with "noop()" |
221 |
| - env.unwrapped.action_mapping(no_action) |
222 |
| - except: |
223 |
| - # fallback plan |
224 |
| - no_action = "" |
225 |
| - env.step(no_action) # task validation |
226 |
| - env.step(no_action) # task validation again |
227 |
| - logger.info(f"Massage successful") |
| 213 | + logger.info(f"Massaging task {i + 1} / {len(task_ids)}: {task_id}") |
| 214 | + for retries in range(max_retries + 1): |
| 215 | + outcome, err_msg = massage_task_within_subprocess(task_id=task_id, timeout=timeout) |
| 216 | + if outcome == "success": |
228 | 217 | break
|
229 |
| - except Exception as e: |
230 |
| - if task_retries < max_retries: |
231 |
| - task_retries += 1 |
232 |
| - logger.info(f"Massage failed, retrying ({task_retries} / {max_retries})") |
233 |
| - continue |
234 |
| - else: |
235 |
| - logger.warning( |
236 |
| - f"Error during task massage after {task_retries} retries ({gym_id}): {e}" |
237 |
| - ) |
238 |
| - break |
239 |
| - finally: |
240 |
| - env.close() |
| 218 | + if retries < max_retries: |
| 219 | + logger.info( |
| 220 | + f"Massage resulted in {outcome}, retrying ({retries + 1} / {max_retries} retries)" |
| 221 | + ) |
| 222 | + else: |
| 223 | + logger.warning( |
| 224 | + f"Massage unsuccessful after {retries} retries, skipping. Last error message: {err_msg}" |
| 225 | + ) |
| 226 | + |
| 227 | + |
| 228 | +def massage_task_within_subprocess( |
| 229 | + task_id: str, timeout: int, kill_timeout: int = 10 |
| 230 | +) -> typing.Tuple[str, str]: |
| 231 | + """Massages a BrowserGym task (reset, noop, noop) inside a subprocess to monitor execution |
| 232 | + times and kill the process after a timeout. |
| 233 | +
|
| 234 | + Returns: an (outcome, err_msg) tuple. |
| 235 | + - outcome: the outcome of the massage, one of 'success', 'exception' or 'timeout'. |
| 236 | + - err_msg: error message if any, or None. |
| 237 | + """ |
| 238 | + |
| 239 | + def run_massage(outcome_queue: mp.Queue): |
| 240 | + import gymnasium as gym |
| 241 | + |
| 242 | + gym_id = f"browsergym/{task_id}" |
| 243 | + env = gym.make(gym_id) |
| 244 | + no_action = "noop()" |
| 245 | + # check if action space exists and is compatible with "noop()" |
| 246 | + try: |
| 247 | + env.unwrapped.action_mapping(no_action) |
| 248 | + except: |
| 249 | + no_action = "" # fallback plan |
| 250 | + # run massage |
| 251 | + try: |
| 252 | + env.reset() # task setup |
| 253 | + env.step(no_action) # task validation |
| 254 | + env.step(no_action) # task validation again |
| 255 | + outcome = "success", None |
| 256 | + except Exception as e: |
| 257 | + outcome = "exception", traceback.format_exception(e) |
| 258 | + finally: |
| 259 | + env.close() |
| 260 | + outcome_queue.put(outcome) |
| 261 | + |
| 262 | + queue = mp.Queue() |
| 263 | + process = mp.Process(target=run_massage, args=queue) |
| 264 | + process.start() |
| 265 | + process.join(timeout=timeout) |
| 266 | + |
| 267 | + if process.is_alive(): |
| 268 | + # if the process is still alive after the timeout |
| 269 | + outcome = "timeout", f"Timeout {timeout} seconds exceeded" |
| 270 | + process.kill() |
| 271 | + process.join(timeout=kill_timeout) |
| 272 | + if process.is_alive(): |
| 273 | + # if the process is still alive after the kill |
| 274 | + logger.warning( |
| 275 | + f"Massage sub-process still alive {kill_timeout} seconds after kill(), you might have a zombie process now." |
| 276 | + ) |
| 277 | + else: |
| 278 | + outcome = queue.get_nowait() |
| 279 | + |
| 280 | + return outcome |
0 commit comments