Skip to content

Commit bfd778d

Browse files
committed
update benchmark
1 parent 56577e8 commit bfd778d

File tree

2 files changed

+51
-4
lines changed

2 files changed

+51
-4
lines changed

gui/gui_experiment.py

+49
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
# =========== Copyright 2024 @ CAMEL-AI.org. All Rights Reserved. ===========
2+
# Licensed under the Apache License, Version 2.0 (the “License”);
3+
# you may not use this file except in compliance with the License.
4+
# You may obtain a copy of the License at
5+
#
6+
# http://www.apache.org/licenses/LICENSE-2.0
7+
#
8+
# Unless required by applicable law or agreed to in writing, software
9+
# distributed under the License is distributed on an “AS IS” BASIS,
10+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11+
# See the License for the specific language governing permissions and
12+
# limitations under the License.
13+
# =========== Copyright 2024 @ CAMEL-AI.org. All Rights Reserved. ===========
14+
from pathlib import Path
15+
from typing import Literal
16+
17+
from crab import AgentPolicy, Benchmark, Experiment, MessageType
18+
19+
20+
class GuiExperiment(Experiment):
21+
def __init__(
22+
self,
23+
benchmark: Benchmark,
24+
task_id: str,
25+
agent_policy: AgentPolicy | Literal["human"],
26+
log_dir: Path | None = None,
27+
) -> None:
28+
super().__init__(benchmark, task_id, agent_policy, log_dir)
29+
30+
def get_prompt(self):
31+
observation, ob_prompt = self.benchmark.observe_with_prompt()
32+
33+
# construct prompt
34+
result_prompt = {}
35+
for env in ob_prompt:
36+
if env == "root":
37+
continue
38+
screenshot = observation[env]["screenshot"]
39+
marked_screenshot, _ = ob_prompt[env]["screenshot"]
40+
result_prompt[env] = [
41+
(f"Here is the current screenshot of {env}:", MessageType.TEXT),
42+
(screenshot, MessageType.IMAGE_JPG_BASE64),
43+
(
44+
f"Here is the screenshot with element labels of {env}:",
45+
MessageType.TEXT,
46+
),
47+
(marked_screenshot, MessageType.IMAGE_JPG_BASE64),
48+
]
49+
return result_prompt

gui/main.py

+2-4
Original file line numberDiff line numberDiff line change
@@ -17,9 +17,9 @@
1717

1818
import customtkinter as ctk
1919

20-
from crab import Experiment
2120
from crab.agents.backend_models import ClaudeModel, GeminiModel, OpenAIModel
2221
from crab.agents.policies import SingleAgentPolicy
22+
from gui.gui_experiment import GuiExperiment
2323
from gui.utils import get_benchmark
2424

2525
warnings.filterwarnings("ignore")
@@ -58,7 +58,7 @@ def assign_task():
5858

5959
task_id = str(uuid4())
6060
benchmark = get_benchmark(task_id, task_description)
61-
experiment = Experiment(
61+
experiment = GuiExperiment(
6262
benchmark=benchmark,
6363
task_id=task_id,
6464
agent_policy=agent_policy,
@@ -82,8 +82,6 @@ def display_message(message, sender="user"):
8282

8383

8484
if __name__ == "__main__":
85-
# TODO: Handle JSON decode error from environment action endpoint and
86-
# display model response in GUI
8785
log_dir = (Path(__file__).parent / "logs").resolve()
8886

8987
ctk.set_appearance_mode("System")

0 commit comments

Comments
 (0)