Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor the logic in agent_controller to improve readability #3873

Merged
merged 4 commits into from
Sep 16, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
321 changes: 176 additions & 145 deletions openhands/controller/agent_controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,56 +172,83 @@ async def on_event(self, event: Event):
Args:
event (Event): The incoming event to process.
"""
if isinstance(event, ChangeAgentStateAction):
await self.set_agent_state_to(event.agent_state) # type: ignore
elif isinstance(event, MessageAction):
if event.source == EventSource.USER:
logger.info(
event,
extra={'msg_type': 'ACTION', 'event_source': EventSource.USER},
)
if self.get_agent_state() != AgentState.RUNNING:
await self.set_agent_state_to(AgentState.RUNNING)
elif event.source == EventSource.AGENT and event.wait_for_response:
await self.set_agent_state_to(AgentState.AWAITING_USER_INPUT)
elif isinstance(event, AgentDelegateAction):
await self.start_delegate(event)
elif isinstance(event, AddTaskAction):
self.state.root_task.add_subtask(event.parent, event.goal, event.subtasks)
elif isinstance(event, ModifyTaskAction):
self.state.root_task.set_subtask_state(event.task_id, event.state)
elif isinstance(event, AgentFinishAction):
self.state.outputs = event.outputs
if isinstance(event, Action):
await self._handle_action(event)
elif isinstance(event, Observation):
await self._handle_observation(event)

async def _handle_action(self, action: Action):
"""Handles actions from the event stream.

Args:
action (Action): The action to handle.
"""
if isinstance(action, ChangeAgentStateAction):
await self.set_agent_state_to(action.agent_state) # type: ignore
elif isinstance(action, MessageAction):
await self._handle_message_action(action)
elif isinstance(action, AgentDelegateAction):
await self.start_delegate(action)
elif isinstance(action, AddTaskAction):
self.state.root_task.add_subtask(
action.parent, action.goal, action.subtasks
)
elif isinstance(action, ModifyTaskAction):
self.state.root_task.set_subtask_state(action.task_id, action.state)
elif isinstance(action, AgentFinishAction):
self.state.outputs = action.outputs
self.state.metrics.merge(self.state.local_metrics)
await self.set_agent_state_to(AgentState.FINISHED)
elif isinstance(event, AgentRejectAction):
self.state.outputs = event.outputs
elif isinstance(action, AgentRejectAction):
self.state.outputs = action.outputs
self.state.metrics.merge(self.state.local_metrics)
await self.set_agent_state_to(AgentState.REJECTED)
elif isinstance(event, Observation):
if (
self._pending_action
and hasattr(self._pending_action, 'is_confirmed')
and self._pending_action.is_confirmed
== ActionConfirmationStatus.AWAITING_CONFIRMATION
):
return
if self._pending_action and self._pending_action.id == event.cause:
self._pending_action = None
if self.state.agent_state == AgentState.USER_CONFIRMED:
await self.set_agent_state_to(AgentState.RUNNING)
if self.state.agent_state == AgentState.USER_REJECTED:
await self.set_agent_state_to(AgentState.AWAITING_USER_INPUT)
logger.info(event, extra={'msg_type': 'OBSERVATION'})
elif isinstance(event, CmdOutputObservation):
logger.info(event, extra={'msg_type': 'OBSERVATION'})
elif isinstance(event, AgentDelegateObservation):
self.state.history.on_event(event)
logger.info(event, extra={'msg_type': 'OBSERVATION'})
elif isinstance(event, ErrorObservation):
logger.info(event, extra={'msg_type': 'OBSERVATION'})
if self.state.agent_state == AgentState.ERROR:
self.state.metrics.merge(self.state.local_metrics)

async def _handle_observation(self, observation: Observation):
    """Process an observation received from the event stream.

    While the pending action is still awaiting user confirmation the
    observation is ignored entirely (it is not even logged). Otherwise the
    observation is logged and then handled in one of two ways:

    * If it was caused by the pending action, that action is cleared, a
      USER_CONFIRMED agent state is moved to RUNNING and a USER_REJECTED
      agent state is moved to AWAITING_USER_INPUT, and handling stops.
    * Otherwise it is dispatched on its type: delegate observations are
      forwarded to the state history, and error observations merge the
      local metrics into the global metrics when the agent is in ERROR.

    Args:
        observation (Observation): The observation to handle.
    """
    pending = self._pending_action

    # Observations are dropped until the user has confirmed or rejected
    # the pending action.
    if (
        pending
        and hasattr(pending, 'is_confirmed')
        and pending.is_confirmed == ActionConfirmationStatus.AWAITING_CONFIRMATION
    ):
        return

    logger.info(observation, extra={'msg_type': 'OBSERVATION'})

    if pending and pending.id == observation.cause:
        # This observation is the result of the pending action.
        self._pending_action = None
        # The state is re-read for each check on purpose: the first
        # transition may already have changed it.
        if self.state.agent_state == AgentState.USER_CONFIRMED:
            await self.set_agent_state_to(AgentState.RUNNING)
        if self.state.agent_state == AgentState.USER_REJECTED:
            await self.set_agent_state_to(AgentState.AWAITING_USER_INPUT)
        return

    if isinstance(observation, CmdOutputObservation):
        return

    if isinstance(observation, AgentDelegateObservation):
        self.state.history.on_event(observation)
        return

    if (
        isinstance(observation, ErrorObservation)
        and self.state.agent_state == AgentState.ERROR
    ):
        self.state.metrics.merge(self.state.local_metrics)

async def _handle_message_action(self, action: MessageAction):
    """React to a message action taken from the event stream.

    A message from the user is logged and, unless the agent is already
    RUNNING, switches the agent to RUNNING. A message from the agent that
    waits for a response switches the agent to AWAITING_USER_INPUT. Any
    other message is ignored.

    Args:
        action (MessageAction): The message action to handle.
    """
    sender = action.source

    if sender == EventSource.USER:
        log_extra = {'msg_type': 'ACTION', 'event_source': EventSource.USER}
        logger.info(action, extra=log_extra)
        if self.get_agent_state() != AgentState.RUNNING:
            await self.set_agent_state_to(AgentState.RUNNING)
        return

    if sender == EventSource.AGENT and action.wait_for_response:
        await self.set_agent_state_to(AgentState.AWAITING_USER_INPUT)

def reset_task(self):
"""Resets the agent's task."""
Expand All @@ -242,9 +269,11 @@ async def set_agent_state_to(self, new_state: AgentState):
if new_state == self.state.agent_state:
return

if (
self.state.agent_state == AgentState.PAUSED
and new_state == AgentState.RUNNING
if new_state == AgentState.STOPPED or new_state == AgentState.ERROR:
self.reset_task()
elif (
new_state == AgentState.RUNNING
and self.state.agent_state == AgentState.PAUSED
and self.state.traffic_control_state == TrafficControlState.THROTTLING
):
# user intends to interrupt traffic control and let the task resume temporarily
Expand All @@ -257,19 +286,15 @@ async def set_agent_state_to(self, new_state: AgentState):
):
if self.state.iteration >= self.state.max_iterations:
self.state.max_iterations += self._initial_max_iterations

if (
self.state.metrics.accumulated_cost is not None
and self.max_budget_per_task is not None
and self._initial_max_budget_per_task is not None
):
if self.state.metrics.accumulated_cost >= self.max_budget_per_task:
self.max_budget_per_task += self._initial_max_budget_per_task

self.state.agent_state = new_state
if new_state == AgentState.STOPPED or new_state == AgentState.ERROR:
self.reset_task()

if self._pending_action is not None and (
elif self._pending_action is not None and (
new_state == AgentState.USER_CONFIRMED
or new_state == AgentState.USER_REJECTED
):
Expand All @@ -281,6 +306,7 @@ async def set_agent_state_to(self, new_state: AgentState):
self._pending_action.is_confirmed = ActionConfirmationStatus.REJECTED # type: ignore[attr-defined]
self.event_stream.add_event(self._pending_action, EventSource.AGENT)

self.state.agent_state = new_state
self.event_stream.add_event(
AgentStateChangedObservation('', self.state.agent_state), EventSource.AGENT
)
Expand Down Expand Up @@ -355,107 +381,29 @@ async def _step(self) -> None:
return

if self.delegate is not None:
logger.debug(f'[Agent Controller {self.id}] Delegate not none, awaiting...')
assert self.delegate != self
await self.delegate._step()
logger.debug(f'[Agent Controller {self.id}] Delegate step done')
assert self.delegate is not None
delegate_state = self.delegate.get_agent_state()
logger.debug(
f'[Agent Controller {self.id}] Delegate state: {delegate_state}'
)
if delegate_state == AgentState.ERROR:
# update iteration that shall be shared across agents
self.state.iteration = self.delegate.state.iteration

# close the delegate upon error
await self.delegate.close()
self.delegate = None
self.delegateAction = None

await self.report_error('Delegator agent encounters an error')
return
delegate_done = delegate_state in (AgentState.FINISHED, AgentState.REJECTED)
if delegate_done:
logger.info(
f'[Agent Controller {self.id}] Delegate agent has finished execution'
)
# retrieve delegate result
outputs = self.delegate.state.outputs if self.delegate.state else {}

# update iteration that shall be shared across agents
self.state.iteration = self.delegate.state.iteration

# close delegate controller: we must close the delegate controller before adding new events
await self.delegate.close()

# update delegate result observation
# TODO: replace this with AI-generated summary (#2395)
formatted_output = ', '.join(
f'{key}: {value}' for key, value in outputs.items()
)
content = (
f'{self.delegate.agent.name} finishes task with {formatted_output}'
)
obs: Observation = AgentDelegateObservation(
outputs=outputs, content=content
)

# clean up delegate status
self.delegate = None
self.delegateAction = None
self.event_stream.add_event(obs, EventSource.AGENT)
await self._delegate_step()
return

logger.info(
f'{self.agent.name} LEVEL {self.state.delegate_level} LOCAL STEP {self.state.local_iteration} GLOBAL STEP {self.state.iteration}',
extra={'msg_type': 'STEP'},
)

# check if agent hit the resources limit
stop_step = False
if self.state.iteration >= self.state.max_iterations:
if self.state.traffic_control_state == TrafficControlState.PAUSED:
logger.info(
'Hitting traffic control, temporarily resume upon user request'
)
self.state.traffic_control_state = TrafficControlState.NORMAL
else:
self.state.traffic_control_state = TrafficControlState.THROTTLING
if self.headless_mode:
# set to ERROR state if running in headless mode
# since user cannot resume on the web interface
await self.report_error(
'Agent reached maximum number of iterations in headless mode, task stopped.'
)
await self.set_agent_state_to(AgentState.ERROR)
else:
await self.report_error(
f'Agent reached maximum number of iterations, task paused. {TRAFFIC_CONTROL_REMINDER}'
)
await self.set_agent_state_to(AgentState.PAUSED)
return
elif self.max_budget_per_task is not None:
stop_step = await self._handle_traffic_control(
'iteration', self.state.iteration, self.state.max_iterations
)
if self.max_budget_per_task is not None:
current_cost = self.state.metrics.accumulated_cost
if current_cost > self.max_budget_per_task:
if self.state.traffic_control_state == TrafficControlState.PAUSED:
logger.info(
'Hitting traffic control, temporarily resume upon user request'
)
self.state.traffic_control_state = TrafficControlState.NORMAL
else:
self.state.traffic_control_state = TrafficControlState.THROTTLING
if self.headless_mode:
# set to ERROR state if running in headless mode
# there is no way to resume
await self.report_error(
f'Task budget exceeded. Current cost: {current_cost:.2f}, max budget: {self.max_budget_per_task:.2f}, task stopped.'
)
await self.set_agent_state_to(AgentState.ERROR)
else:
await self.report_error(
f'Task budget exceeded. Current cost: {current_cost:.2f}, Max budget: {self.max_budget_per_task:.2f}, task paused. {TRAFFIC_CONTROL_REMINDER}'
)
await self.set_agent_state_to(AgentState.PAUSED)
return
stop_step = await self._handle_traffic_control(
'budget', current_cost, self.max_budget_per_task
)
if stop_step:
return

self.update_state_before_step()
action: Action = NullAction()
Expand Down Expand Up @@ -492,6 +440,89 @@ async def _step(self) -> None:
await self.report_error('Agent got stuck in a loop')
await self.set_agent_state_to(AgentState.ERROR)

async def _delegate_step(self):
    """Run one step of the delegate agent and react to its resulting state.

    If the delegate ends up in ERROR, its iteration count is copied back,
    the delegate is closed and detached, and an error is reported. If it
    ends up FINISHED or REJECTED, its outputs are wrapped in an
    AgentDelegateObservation that is added to the event stream after the
    delegate has been closed and detached. In any other state nothing
    further happens.
    """
    logger.debug(f'[Agent Controller {self.id}] Delegate not none, awaiting...')
    await self.delegate._step()  # type: ignore[union-attr]
    logger.debug(f'[Agent Controller {self.id}] Delegate step done')
    assert self.delegate is not None
    delegate_state = self.delegate.get_agent_state()
    logger.debug(f'[Agent Controller {self.id}] Delegate state: {delegate_state}')

    if delegate_state == AgentState.ERROR:
        # The iteration counter is shared across agents, so copy it back
        # before the delegate is discarded.
        self.state.iteration = self.delegate.state.iteration

        # Close and detach the failed delegate.
        await self.delegate.close()
        self.delegate = None
        self.delegateAction = None

        await self.report_error('Delegator agent encounters an error')
        return

    if delegate_state not in (AgentState.FINISHED, AgentState.REJECTED):
        # The delegate is still working; nothing to collect yet.
        return

    logger.info(
        f'[Agent Controller {self.id}] Delegate agent has finished execution'
    )

    # Collect the delegate's result.
    if self.delegate.state:
        outputs = self.delegate.state.outputs
    else:
        outputs = {}

    # The iteration counter is shared across agents.
    self.state.iteration = self.delegate.state.iteration

    # The delegate controller must be closed before new events are added.
    await self.delegate.close()

    # Build the observation that reports the delegate's result.
    # TODO: replace this with AI-generated summary (#2395)
    formatted_output = ', '.join(
        f'{key}: {value}' for key, value in outputs.items()
    )
    obs: Observation = AgentDelegateObservation(
        outputs=outputs,
        content=f'{self.delegate.agent.name} finishes task with {formatted_output}',
    )

    # Detach the delegate, then publish its result.
    self.delegate = None
    self.delegateAction = None
    self.event_stream.add_event(obs, EventSource.AGENT)

async def _handle_traffic_control(
    self, limit_type: str, current_value: float, max_value: float
):
    """Update the agent after a traffic control limit has been hit.

    When traffic control is currently PAUSED, it is reset to NORMAL and the
    step is allowed to continue. Otherwise traffic control is set to
    THROTTLING, an error is reported, and the agent is moved to ERROR in
    headless mode or to PAUSED otherwise.

    Args:
        limit_type (str): The type of limit that was hit.
        current_value (float): The current value of the limit.
        max_value (float): The maximum value of the limit.

    Returns:
        True when the current step should stop, False when it may continue.
    """
    if self.state.traffic_control_state == TrafficControlState.PAUSED:
        logger.info('Hitting traffic control, temporarily resume upon user request')
        self.state.traffic_control_state = TrafficControlState.NORMAL
        return False

    self.state.traffic_control_state = TrafficControlState.THROTTLING

    if self.headless_mode:
        # Headless runs go to ERROR because the user cannot resume them on
        # the web interface.
        message = (
            f'Agent reached maximum {limit_type} in headless mode, task stopped. '
            f'Current {limit_type}: {current_value:.2f}, max {limit_type}: {max_value:.2f}'
        )
        next_state = AgentState.ERROR
    else:
        message = (
            f'Agent reached maximum {limit_type}, task paused. '
            f'Current {limit_type}: {current_value:.2f}, max {limit_type}: {max_value:.2f}. '
            f'{TRAFFIC_CONTROL_REMINDER}'
        )
        next_state = AgentState.PAUSED

    await self.report_error(message)
    await self.set_agent_state_to(next_state)
    return True

def get_state(self):
"""Returns the current running state object.

Expand Down
Loading
Loading