Skip to content

Commit 4ae656b

Browse files
committed
[WIP] Define framework
1 parent 6d1be33 commit 4ae656b

9 files changed

+704
-0
lines changed

refactor_demo/core/benchmark.py

+32
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
# =========== Copyright 2024 @ CAMEL-AI.org. All Rights Reserved. ===========
2+
# Licensed under the Apache License, Version 2.0 (the “License”);
3+
# you may not use this file except in compliance with the License.
4+
# You may obtain a copy of the License at
5+
#
6+
# http://www.apache.org/licenses/LICENSE-2.0
7+
#
8+
# Unless required by applicable law or agreed to in writing, software
9+
# distributed under the License is distributed on an “AS IS” BASIS,
10+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11+
# See the License for the specific language governing permissions and
12+
# limitations under the License.
13+
# =========== Copyright 2024 @ CAMEL-AI.org. All Rights Reserved. ===========
14+
from abc import ABC, abstractmethod
15+
from typing import Iterable
16+
17+
from .environment import Environment
18+
from .evaluator import Evaluator
19+
20+
21+
class Benchmark(ABC):
    """Abstract interface that every benchmark implementation must satisfy.

    A concrete benchmark has to provide the three hooks declared below;
    nothing is implemented at this level.
    """

    @abstractmethod
    def get_corresponding_env(self) -> Environment:
        """Return the `Environment` instance that belongs to this benchmark."""

    @abstractmethod
    def get_task_by_id(self, id: str) -> str:
        """Look up a single task by its identifier.

        Args:
            id: The identifier of the requested task.

        Returns:
            A string for the task (presumably its description -- the exact
            content is left to the implementation).
        """

    @abstractmethod
    def tasks(self) -> Iterable[tuple[str, Evaluator]]:
        """Enumerate the tasks of this benchmark.

        Returns:
            An iterable of `(str, Evaluator)` pairs, one per task.
        """

refactor_demo/core/environment.py

+61
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,61 @@
1+
# =========== Copyright 2024 @ CAMEL-AI.org. All Rights Reserved. ===========
2+
# Licensed under the Apache License, Version 2.0 (the “License”);
3+
# you may not use this file except in compliance with the License.
4+
# You may obtain a copy of the License at
5+
#
6+
# http://www.apache.org/licenses/LICENSE-2.0
7+
#
8+
# Unless required by applicable law or agreed to in writing, software
9+
# distributed under the License is distributed on an “AS IS” BASIS,
10+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11+
# See the License for the specific language governing permissions and
12+
# limitations under the License.
13+
# =========== Copyright 2024 @ CAMEL-AI.org. All Rights Reserved. ===========
14+
from abc import ABC, abstractmethod
15+
from typing import Any
16+
17+
import gymnasium as gym
18+
from openai.types.chat.chat_completion_tool_param import ChatCompletionToolParam
19+
20+
21+
class Environment(gym.Env, ABC):
    """Base class for the environments agents interact with in the CRAB framework.

    A CRAB environment is a `gymnasium.Env` subclass that serves as the common
    parent of every environment in CRAB. Subclasses must implement
    `get_description`, `get_action_schema` and `convert_tool_call_to_action`;
    the latter two make the environment compatible with the OpenAI tool use
    API.
    """

    @abstractmethod
    def get_description(self) -> str:
        """Describe the environment in text suitable for an agent prompt.

        Returns:
            A string description of the environment.
        """
        ...

    @abstractmethod
    def get_action_schema(self) -> list[ChatCompletionToolParam]:
        """Return the tool schema covering the environment's action space.

        The schema describes the whole action space and the parameters of
        every action in the tool calling format, so it can be passed directly
        to the OpenAI API. It should be comprehensive and leave no room for
        misunderstanding by a human reader.

        Returns:
            A list of tool schemas.
        """
        ...

    @abstractmethod
    def convert_tool_call_to_action(self, tool_name: str, parameters: dict) -> Any:
        """Translate a tool call into an action of the environment's action space.

        Args:
            tool_name: The name of the tool.
            parameters: The parameters of the tool call.

        Returns:
            The action corresponding to the tool call.
        """
        ...

refactor_demo/core/evaluator.py

+20
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
# =========== Copyright 2024 @ CAMEL-AI.org. All Rights Reserved. ===========
2+
# Licensed under the Apache License, Version 2.0 (the “License”);
3+
# you may not use this file except in compliance with the License.
4+
# You may obtain a copy of the License at
5+
#
6+
# http://www.apache.org/licenses/LICENSE-2.0
7+
#
8+
# Unless required by applicable law or agreed to in writing, software
9+
# distributed under the License is distributed on an “AS IS” BASIS,
10+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11+
# See the License for the specific language governing permissions and
12+
# limitations under the License.
13+
# =========== Copyright 2024 @ CAMEL-AI.org. All Rights Reserved. ===========
14+
from abc import ABC, abstractmethod
15+
16+
17+
class Evaluator(ABC):
    """Abstract interface for objects that evaluate a task in an environment.

    The annotations on `step` are quoted forward references: this module does
    not import `Environment`, `Task` or `Any`, and unquoted annotations would
    raise `NameError` as soon as the class body is executed. Quoting them also
    keeps this module free of an import of the task module, which itself
    depends on `Evaluator`.
    """

    @abstractmethod
    def step(self, environment: "Environment", task: "Task") -> "Any":
        """Run one evaluation step.

        Args:
            environment: The environment to evaluate against.
            task: The task being evaluated.

        Returns:
            The evaluation result; its type is defined by the implementation.
        """

refactor_demo/core/policy.py

+13
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
# =========== Copyright 2024 @ CAMEL-AI.org. All Rights Reserved. ===========
2+
# Licensed under the Apache License, Version 2.0 (the “License”);
3+
# you may not use this file except in compliance with the License.
4+
# You may obtain a copy of the License at
5+
#
6+
# http://www.apache.org/licenses/LICENSE-2.0
7+
#
8+
# Unless required by applicable law or agreed to in writing, software
9+
# distributed under the License is distributed on an “AS IS” BASIS,
10+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11+
# See the License for the specific language governing permissions and
12+
# limitations under the License.
13+
# =========== Copyright 2024 @ CAMEL-AI.org. All Rights Reserved. ===========

refactor_demo/core/task.py

+27
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
# =========== Copyright 2024 @ CAMEL-AI.org. All Rights Reserved. ===========
2+
# Licensed under the Apache License, Version 2.0 (the “License”);
3+
# you may not use this file except in compliance with the License.
4+
# You may obtain a copy of the License at
5+
#
6+
# http://www.apache.org/licenses/LICENSE-2.0
7+
#
8+
# Unless required by applicable law or agreed to in writing, software
9+
# distributed under the License is distributed on an “AS IS” BASIS,
10+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11+
# See the License for the specific language governing permissions and
12+
# limitations under the License.
13+
# =========== Copyright 2024 @ CAMEL-AI.org. All Rights Reserved. ===========
14+
from abc import ABC, abstractmethod
from typing import Any

from pydantic import BaseModel, ConfigDict

from .environment import Environment
from .evaluator import Evaluator
19+
20+
21+
class Task(BaseModel):
    """Pydantic model describing a single task.

    Fields:
        id: Identifier of the task.
        description: Text description of the task.
        evaluator: The `Evaluator` attached to this task.
        setup: Setup entries for the task; empty by default.
        extra_action: Additional actions for the task; empty by default.
    """

    # `Evaluator` is a plain ABC rather than a pydantic type, so arbitrary
    # types have to be allowed for the `evaluator` field.
    model_config = ConfigDict(arbitrary_types_allowed=True)

    id: str
    description: str
    evaluator: Evaluator
    # The original annotation was `setup: setup`, which annotates the field
    # with its own default value instead of a type.
    # TODO(review): narrow `list[Any]` once the element type is defined.
    setup: list[Any] = []
    # `Action` is not defined or imported anywhere in view, so the element
    # type is left open for now.
    # TODO(review): narrow `list[Any]` once an `Action` type exists.
    extra_action: list[Any] = []

refactor_demo/core/task_wrapper.py

+76
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,76 @@
1+
# =========== Copyright 2024 @ CAMEL-AI.org. All Rights Reserved. ===========
2+
# Licensed under the Apache License, Version 2.0 (the “License”);
3+
# you may not use this file except in compliance with the License.
4+
# You may obtain a copy of the License at
5+
#
6+
# http://www.apache.org/licenses/LICENSE-2.0
7+
#
8+
# Unless required by applicable law or agreed to in writing, software
9+
# distributed under the License is distributed on an “AS IS” BASIS,
10+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11+
# See the License for the specific language governing permissions and
12+
# limitations under the License.
13+
# =========== Copyright 2024 @ CAMEL-AI.org. All Rights Reserved. ===========
14+
from typing import Any
15+
16+
import gymnasium as gym
17+
from gymnasium import Wrapper
18+
from gymnasium.core import ActType, ObsType, WrapperObsType
19+
from gymnasium.spaces import Dict, Space, Text, Tuple
20+
21+
22+
class TaskWrapper(Wrapper[WrapperObsType, ActType, ObsType, ActType]):
    """Gymnasium wrapper that adds a task description to every observation.

    The wrapped observation space depends on the inner environment's space:

    * `Dict`: the task is added under `dict_task_key`.
    * `Tuple`: the task is appended as the last element.
    * anything else: observations become `{"obs": ..., "task": ...}`.

    The reward returned by `step` is produced by the task's evaluator and
    replaces the reward of the wrapped environment.
    """

    def __init__(
        self,
        env: gym.Env[ObsType, ActType],
        # Quoted: `Task` is not imported by this module, so an unquoted
        # annotation would raise NameError when the class is defined.
        task: "Task",
        *,
        dict_task_key: str = "task",
    ):
        """Wrap `env` so that its observations carry the description of `task`.

        Args:
            env: The environment to wrap.
            task: The task whose description is attached to observations and
                whose evaluator provides the reward.
            dict_task_key: Key used for the task entry when the wrapped
                observation space is a `Dict`.
        """
        super().__init__(env)
        self.env = env
        self.task = task

        # NOTE(review): Text(500) limits the description to 500 characters and
        # uses gymnasium's default charset; descriptions outside of that are
        # not members of this space -- confirm this is acceptable.
        task_space = Text(500)

        # Build the combined observation space and the matching merge function.
        if isinstance(env.observation_space, Dict):
            assert dict_task_key not in env.observation_space.keys(), (
                f"Observation space already contains the key {dict_task_key!r}."
            )
            observation_space = Dict(
                {dict_task_key: task_space, **env.observation_space.spaces}
            )
            self._append_data_func = lambda obs, task: {dict_task_key: task, **obs}
        elif isinstance(env.observation_space, Tuple):
            observation_space = Tuple(env.observation_space.spaces + (task_space,))
            self._append_data_func = lambda obs, task: obs + (task,)
        else:
            observation_space = Dict(obs=env.observation_space, task=task_space)
            self._append_data_func = lambda obs, task: {"obs": obs, "task": task}

        self.observation_space: gym.Space[WrapperObsType] = observation_space

    def reset(
        self, *, seed: int | None = None, options: dict[str, Any] | None = None
    ) -> tuple[WrapperObsType, dict[str, Any]]:
        """Reset the wrapped environment and return the task-augmented observation.

        Args:
            seed: Seed forwarded to the wrapped environment.
            options: Options forwarded to the wrapped environment.

        Returns:
            The modified initial observation and the wrapped environment's info.
        """
        obs, info = self.env.reset(seed=seed, options=options)
        return self.observation(obs), info

    def step(
        self, action: ActType
    ) -> tuple[WrapperObsType, float, bool, bool, dict[str, Any]]:
        """Step the wrapped environment and score the result with the task evaluator.

        Args:
            action: The action forwarded to the wrapped environment.

        Returns:
            The modified observation, the evaluator's reward, the terminated
            and truncated flags, and the info dict of the wrapped environment.
        """
        # Delegate to the wrapped env; calling `self.step` here would recurse
        # without end.
        observation, _, terminal, truncated, info = self.env.step(action)
        # The task model carries an `evaluator` (with a `step(environment,
        # task)` method) and no `evaluate` method of its own.
        reward = self.task.evaluator.step(self.env, self.task)
        return self.observation(observation), reward, terminal, truncated, info

    def observation(self, observation: ObsType):
        """Return `observation` combined with the task description.

        Args:
            observation: The observation of the wrapped environment.

        Returns:
            The modified observation.
        """
        return self._append_data_func(observation, self.task.description)

refactor_demo/core/workflow.py

+20
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
# =========== Copyright 2024 @ CAMEL-AI.org. All Rights Reserved. ===========
2+
# Licensed under the Apache License, Version 2.0 (the “License”);
3+
# you may not use this file except in compliance with the License.
4+
# You may obtain a copy of the License at
5+
#
6+
# http://www.apache.org/licenses/LICENSE-2.0
7+
#
8+
# Unless required by applicable law or agreed to in writing, software
9+
# distributed under the License is distributed on an “AS IS” BASIS,
10+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11+
# See the License for the specific language governing permissions and
12+
# limitations under the License.
13+
# =========== Copyright 2024 @ CAMEL-AI.org. All Rights Reserved. ===========
14+
from abc import ABC, abstractmethod
15+
16+
17+
class Workflow(ABC):
    """Abstract base class declaring a single `step` hook for subclasses."""

    @abstractmethod
    def step(self):
        """Perform one step; concrete subclasses define what that means."""
        return None

refactor_demo/envs/multi_env.py

+106
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,106 @@
1+
# =========== Copyright 2024 @ CAMEL-AI.org. All Rights Reserved. ===========
2+
# Licensed under the Apache License, Version 2.0 (the “License”);
3+
# you may not use this file except in compliance with the License.
4+
# You may obtain a copy of the License at
5+
#
6+
# http://www.apache.org/licenses/LICENSE-2.0
7+
#
8+
# Unless required by applicable law or agreed to in writing, software
9+
# distributed under the License is distributed on an “AS IS” BASIS,
10+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11+
# See the License for the specific language governing permissions and
12+
# limitations under the License.
13+
# =========== Copyright 2024 @ CAMEL-AI.org. All Rights Reserved. ===========
14+
import gymnasium as gym
15+
import numpy as np
16+
from gymnasium import spaces
17+
18+
19+
class MultiEnv(gym.Env):
    """Combine several gymnasium environments into a single environment.

    The action space is a `OneOf` over the sub-environments' action spaces, so
    an action is a pair `(env_index, sub_action)`. The observation space is a
    `Dict` with one entry `env_<i>` per sub-environment.
    """

    def __init__(self, envs):
        """
        Initialize the MultiEnv environment.

        Args:
            envs (list): A list of gymnasium environments to integrate.

        Raises:
            ValueError: If `envs` is empty.
        """
        super().__init__()

        # Store the environments
        self.envs = list(envs)
        if not self.envs:
            raise ValueError("MultiEnv requires at least one environment.")

        # Create action space using OneOf with the action spaces of each environment
        self.action_space = spaces.OneOf([env.action_space for env in self.envs])

        # Create observation space as a Dict space containing each
        # environment's observation space
        self.observation_space = spaces.Dict(
            {f"env_{i}": env.observation_space for i, env in enumerate(self.envs)}
        )

        # Latest observation of every sub-environment, keyed like the
        # observation space, so non-acting environments can report their
        # previous observation in `step`.
        self._last_observations = {}
        # Whether each sub-environment has reported `terminated` since the
        # last reset.
        self._terminated = [False] * len(self.envs)

    def reset(self, *, seed=None, options=None):
        """
        Reset all environments and return initial observations.

        Args:
            seed (int | None): Seed forwarded to every sub-environment.
            options (dict | None): Options forwarded to every sub-environment.

        Returns:
            tuple: A dictionary with the initial observation of each
                environment and a dictionary with the info of each environment.
        """
        super().reset(seed=seed)

        observations = {}
        infos = {}
        for i, env in enumerate(self.envs):
            key = f"env_{i}"
            observations[key], infos[key] = env.reset(seed=seed, options=options)

        self._last_observations = dict(observations)
        self._terminated = [False] * len(self.envs)
        return observations, infos

    def step(self, action):
        """
        Take a step in the selected environment based on the action.

        Args:
            action (tuple): A pair `(env_index, sub_action)` as produced by
                the `OneOf` action space: the index of the environment to step
                and the action to pass to it.

        Returns:
            tuple: `(observations, reward, terminated, truncated, infos)`.
                `observations` holds the new observation of the acting
                environment and the most recent observation of the others,
                `reward` is the reward of the acting environment,
                `terminated` is true once every environment has terminated,
                `truncated` is the flag reported by the acting environment and
                `infos` maps each environment key to its info dictionary
                (empty for non-acting environments).

        Raises:
            ValueError: If `action` is not an `(env_index, sub_action)` pair
                or the index is out of range.
        """
        try:
            env_index, sub_action = action
        except (TypeError, ValueError) as err:
            raise ValueError(
                "Action must be an (env_index, sub_action) pair."
            ) from err

        env_index = int(env_index)
        if not 0 <= env_index < len(self.envs):
            raise ValueError("Invalid action for environment selection.")

        # Perform a step in the selected environment
        obs, reward, terminated, truncated, info = self.envs[env_index].step(
            sub_action
        )

        acting_key = f"env_{env_index}"
        self._last_observations[acting_key] = obs
        self._terminated[env_index] = bool(terminated)

        # Non-acting environments keep their previous observation and report
        # an empty info dictionary.
        observations = dict(self._last_observations)
        infos = {f"env_{i}": {} for i in range(len(self.envs))}
        infos[acting_key] = info

        # Done only once all environments have terminated
        all_terminated = all(self._terminated)

        return observations, reward, all_terminated, bool(truncated), infos

    def render(self, mode="human"):
        """
        Render all environments.

        Args:
            mode (str): Kept for backward compatibility and ignored; gymnasium
                environments select their render mode at construction time.

        Returns:
            list: The render result of each environment, in order.
        """
        return [env.render() for env in self.envs]

    def close(self):
        """
        Close all environments.
        """
        for env in self.envs:
            env.close()

0 commit comments

Comments
 (0)