Skip to content

Commit

Permalink
refactor: move get_pairs from memory to shared utils (#4411)
Browse files Browse the repository at this point in the history
  • Loading branch information
xingyaoww authored Oct 15, 2024
1 parent 2cf77e2 commit da23189
Show file tree
Hide file tree
Showing 3 changed files with 81 additions and 52 deletions.
56 changes: 56 additions & 0 deletions openhands/events/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
from openhands.core.logger import openhands_logger as logger
from openhands.events.action.action import Action
from openhands.events.action.empty import NullAction
from openhands.events.event import Event
from openhands.events.observation.commands import CmdOutputObservation
from openhands.events.observation.empty import NullObservation
from openhands.events.observation.observation import Observation


def get_pairs_from_events(events: list[Event]) -> list[tuple[Action, Observation]]:
"""Return the history as a list of tuples (action, observation)."""
tuples: list[tuple[Action, Observation]] = []
action_map: dict[int, Action] = {}
observation_map: dict[int, Observation] = {}

# runnable actions are set as cause of observations
# (MessageAction, NullObservation) for source=USER
# (MessageAction, NullObservation) for source=AGENT
# (other_action?, NullObservation)
# (NullAction, CmdOutputObservation) background CmdOutputObservations

for event in events:
if event.id is None or event.id == -1:
logger.debug(f'Event {event} has no ID')

if isinstance(event, Action):
action_map[event.id] = event

if isinstance(event, Observation):
if event.cause is None or event.cause == -1:
logger.debug(f'Observation {event} has no cause')

if event.cause is None:
# runnable actions are set as cause of observations
# NullObservations have no cause
continue

observation_map[event.cause] = event

for action_id, action in action_map.items():
observation = observation_map.get(action_id)
if observation:
# observation with a cause
tuples.append((action, observation))
else:
tuples.append((action, NullObservation('')))

for cause_id, observation in observation_map.items():
if cause_id not in action_map:
if isinstance(observation, NullObservation):
continue
if not isinstance(observation, CmdOutputObservation):
logger.debug(f'Observation {observation} has no cause')
tuples.append((NullAction(), observation))

return tuples.copy()
54 changes: 4 additions & 50 deletions openhands/memory/history.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,12 @@
from openhands.events.action.message import MessageAction
from openhands.events.event import Event, EventSource
from openhands.events.observation.agent import AgentStateChangedObservation
from openhands.events.observation.commands import CmdOutputObservation
from openhands.events.observation.delegate import AgentDelegateObservation
from openhands.events.observation.empty import NullObservation
from openhands.events.observation.observation import Observation
from openhands.events.serialization.event import event_to_dict
from openhands.events.stream import EventStream
from openhands.events.utils import get_pairs_from_events


class ShortTermHistory(list[Event]):
Expand Down Expand Up @@ -216,55 +216,9 @@ def on_event(self, event: Event):
def compatibility_for_eval_history_pairs(self) -> list[tuple[dict, dict]]:
history_pairs = []

for action, observation in self.get_pairs():
for action, observation in get_pairs_from_events(
self.get_events_as_list(include_delegates=True)
):
history_pairs.append((event_to_dict(action), event_to_dict(observation)))

return history_pairs

def get_pairs(self) -> list[tuple[Action, Observation]]:
"""Return the history as a list of tuples (action, observation)."""
tuples: list[tuple[Action, Observation]] = []
action_map: dict[int, Action] = {}
observation_map: dict[int, Observation] = {}

# runnable actions are set as cause of observations
# (MessageAction, NullObservation) for source=USER
# (MessageAction, NullObservation) for source=AGENT
# (other_action?, NullObservation)
# (NullAction, CmdOutputObservation) background CmdOutputObservations

for event in self.get_events_as_list(include_delegates=True):
if event.id is None or event.id == -1:
logger.debug(f'Event {event} has no ID')

if isinstance(event, Action):
action_map[event.id] = event

if isinstance(event, Observation):
if event.cause is None or event.cause == -1:
logger.debug(f'Observation {event} has no cause')

if event.cause is None:
# runnable actions are set as cause of observations
# NullObservations have no cause
continue

observation_map[event.cause] = event

for action_id, action in action_map.items():
observation = observation_map.get(action_id)
if observation:
# observation with a cause
tuples.append((action, observation))
else:
tuples.append((action, NullObservation('')))

for cause_id, observation in observation_map.items():
if cause_id not in action_map:
if isinstance(observation, NullObservation):
continue
if not isinstance(observation, CmdOutputObservation):
logger.debug(f'Observation {observation} has no cause')
tuples.append((NullAction(), observation))

return tuples.copy()
23 changes: 21 additions & 2 deletions tests/unit/test_is_stuck.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from openhands.events.observation.empty import NullObservation
from openhands.events.observation.error import ErrorObservation
from openhands.events.stream import EventSource, EventStream
from openhands.events.utils import get_pairs_from_events
from openhands.memory.history import ShortTermHistory
from openhands.storage import get_file_store

Expand Down Expand Up @@ -170,7 +171,16 @@ def test_is_stuck_repeating_action_observation(

assert len(collect_events(event_stream)) == 10
assert len(list(stuck_detector.state.history.get_events())) == 8
assert len(stuck_detector.state.history.get_pairs()) == 5
assert (
len(
get_pairs_from_events(
stuck_detector.state.history.get_events_as_list(
include_delegates=True
)
)
)
== 5
)

assert stuck_detector.is_stuck() is False
assert stuck_detector.state.almost_stuck == 1
Expand All @@ -186,7 +196,16 @@ def test_is_stuck_repeating_action_observation(

assert len(collect_events(event_stream)) == 12
assert len(list(stuck_detector.state.history.get_events())) == 10
assert len(stuck_detector.state.history.get_pairs()) == 6
assert (
len(
get_pairs_from_events(
stuck_detector.state.history.get_events_as_list(
include_delegates=True
)
)
)
== 6
)

with patch('logging.Logger.warning') as mock_warning:
assert stuck_detector.is_stuck() is True
Expand Down

0 comments on commit da23189

Please sign in to comment.