Skip to content

Commit

Permalink
support handcrafted exp
Browse files Browse the repository at this point in the history
  • Loading branch information
zhangxu0307 committed Jan 3, 2024
1 parent 2fbf7ad commit fc9c30d
Show file tree
Hide file tree
Showing 4 changed files with 91 additions and 32 deletions.
4 changes: 2 additions & 2 deletions scripts/experience_mgt.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,10 +46,10 @@ def __init__(self):
)
app_injector.binder.bind(AppConfigSource, to=app_config)
self.experience_generator = app_injector.create_object(ExperienceGenerator)
self.experience_generator.summarize_experience_in_batch(args.target_role, refresh=False)
self.experience_generator.load_experience(args.target_role)

def refresh(self):
self.experience_generator.summarize_experience_in_batch(args.target_role, refresh=True)
self.experience_generator.refresh(args.target_role)
print("Refreshed experience list")

def delete_experience(self, session_id: str):
Expand Down
109 changes: 85 additions & 24 deletions taskweaver/memory/experience.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
import json
import os
import warnings
from dataclasses import dataclass
from typing import List, Literal, Optional, Tuple
from dataclasses import dataclass, field
from typing import Any, Dict, List, Literal, Optional, Tuple

import numpy as np
from injector import inject
Expand All @@ -20,7 +20,7 @@ class Experience:
session_id: str
raw_experience_path: Optional[str] = None
embedding_model: Optional[str] = None
embedding: Optional[List[float]] = None
embedding: List[float] = field(default_factory=list)

def to_dict(self):
return {
Expand All @@ -31,6 +31,16 @@ def to_dict(self):
"embedding": self.embedding,
}

@staticmethod
def from_dict(d: Dict[str, Any]) -> "Experience":
    """Build an ``Experience`` from a plain dictionary (e.g. a parsed YAML file).

    ``session_id`` and ``experience_text`` are required and raise ``KeyError``
    when absent. Every other key is optional so that files which carry only
    the text of an experience can still be loaded:

    * ``raw_experience_path`` falls back to ``None``
    * ``embedding_model`` falls back to ``None``
    * ``embedding`` falls back to an empty list
    """
    return Experience(
        session_id=d["session_id"],
        experience_text=d["experience_text"],
        # Optional on the dataclass, so do not force the key to be present.
        raw_experience_path=d.get("raw_experience_path"),
        embedding_model=d.get("embedding_model"),
        embedding=d.get("embedding", []),
    )


class ExperienceConfig(ModuleConfig):
def _configure(self) -> None:
Expand Down Expand Up @@ -113,51 +123,67 @@ def summarize_experience(

return summarized_experience

def summarize_experience_in_batch(
def refresh(
self,
target_role: Literal["Planner", "CodeInterpreter"],
prompt: Optional[str] = None,
refresh: bool = False,
):
if not os.path.exists(self.config.experience_dir):
raise ValueError(f"Experience directory {self.config.experience_dir} does not exist.")

exp_files = os.listdir(self.config.experience_dir)
session_ids = [
conv_session_ids = [
os.path.splitext(os.path.basename(exp_file))[0].split("_")[2]
for exp_file in exp_files
if exp_file.startswith("raw_exp")
]

handcrafted_session_ids = [
os.path.splitext(os.path.basename(exp_file))[0].split("_")[2]
for exp_file in exp_files
if exp_file.startswith("handcrafted_exp")
]

session_ids = conv_session_ids + handcrafted_session_ids

if len(session_ids) == 0:
warnings.warn("No experience found. Please type #SAVE AS EXP in the chat window to save experience.")
warnings.warn(
"No raw experience found. "
"Please type #SAVE AS EXP in the chat window to save raw experience"
"or write handcrafted experience.",
)
return

if refresh:
self.experience_list = []
for session_id in session_ids:
self.delete_experience(session_id, target_role)

to_be_embedded = []
for idx, session_id in enumerate(session_ids):
exp_file_name = f"{target_role}_exp_{session_id}.yaml"
# if the experience file already exists, load it
# if the experience file already exists and the embedding is valid, skip
if exp_file_name in os.listdir(self.config.experience_dir):
exp_file_path = os.path.join(self.config.experience_dir, exp_file_name)
experience = read_yaml(exp_file_path)
experience_obj = Experience(**experience)
self.experience_list.append(experience_obj)
self.logger.info(f"Experience {exp_file_name} loaded.")
if (
experience["embedding_model"] == self.llm_api.embedding_service.config.embedding_model
and len(experience["embedding"]) > 0
):
continue
else:
# otherwise, summarize the experience and save it
summarized_experience = self.summarize_experience(session_id, prompt, target_role)
experience_obj = Experience(
experience_text=summarized_experience,
session_id=session_id,
raw_experience_path=os.path.join(
if session_id in conv_session_ids:
summarized_experience = self.summarize_experience(session_id, prompt, target_role)
experience_obj = Experience(
experience_text=summarized_experience,
session_id=session_id,
raw_experience_path=os.path.join(
self.config.experience_dir,
f"raw_exp_{session_id}.yaml",
),
)
else:
handcrafted_exp_file_path = os.path.join(
self.config.experience_dir,
f"raw_exp_{session_id}.yaml",
),
)
f"handcrafted_exp_{session_id}.yaml",
)
experience_obj = Experience.from_dict(read_yaml(handcrafted_exp_file_path))
self.experience_list.append(experience_obj)
to_be_embedded.append(idx)
self.logger.info("Experience created. Experience files number: {}".format(len(session_ids)))
Expand All @@ -177,6 +203,41 @@ def summarize_experience_in_batch(
write_yaml(experience_file_path, exp.to_dict())
self.logger.info("Experience obj saved.")

def load_experience(
    self,
    target_role: Literal["Planner", "CodeInterpreter"],
):
    """Load the saved experience files of ``target_role`` into ``self.experience_list``.

    Every file in ``self.config.experience_dir`` whose name starts with
    ``{target_role}_exp_`` is read and appended to ``self.experience_list``.

    Args:
        target_role: The role whose experience files should be loaded.

    Raises:
        ValueError: If the experience directory does not exist, or if a file
            has no embedding or an embedding produced by a model other than
            the currently configured one.

    A warning is emitted (and nothing is loaded) when no matching file exists.
    """
    if not os.path.exists(self.config.experience_dir):
        raise ValueError(f"Experience directory {self.config.experience_dir} does not exist.")

    exp_files = [
        exp_file
        for exp_file in os.listdir(self.config.experience_dir)
        if exp_file.startswith(f"{target_role}_exp_")
    ]
    if len(exp_files) == 0:
        warnings.warn(
            f"No experience found for {target_role}."
            f" Please type #SAVE AS EXP in the chat window to save experience.",
        )
        return

    # The configured model does not change while loading; look it up once.
    current_embedding_model = self.llm_api.embedding_service.config.embedding_model

    for exp_file in exp_files:
        exp_file_path = os.path.join(self.config.experience_dir, exp_file)
        experience = read_yaml(exp_file_path)
        # Use .get so that a file lacking these keys (or holding a null
        # embedding) is reported through the ValueError below, with
        # instructions, rather than as a bare KeyError/TypeError.
        if experience.get("embedding_model") != current_embedding_model or not experience.get("embedding"):
            # One single message string: adjacent literals are concatenated,
            # so every fragment must end with a space and none may be
            # followed by a comma.
            raise ValueError(
                "The embedding model of the experience is not the same as the current one. "
                "Please re-summarize and generate embedding for the experience. "
                "Please cd to the `scripts` directory and "
                "run `python -m experience_mgt --refresh` to refresh the experience.",
            )
        self.experience_list.append(Experience(**experience))

def retrieve_experience(self, user_query: str) -> List[Tuple[Experience, float]]:
user_query_embedding = np.array(self.llm_api.get_embedding(user_query))

Expand Down
5 changes: 1 addition & 4 deletions taskweaver/planner/planner.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,10 +121,7 @@ def __init__(
if self.config.use_experience:
self.experience_generator = experience_generator
self.experience_prompt_template = read_yaml(self.config.exp_prompt_path)["content"]
self.experience_generator.summarize_experience_in_batch(
prompt=self.experience_prompt_template,
target_role="Planner",
)
self.experience_generator.load_experience(target_role="Planner")

self.logger.info("Planner initialized successfully")

Expand Down
5 changes: 3 additions & 2 deletions tests/unit_tests/test_experience.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,8 @@ def test_experience_generation():
app_injector.binder.bind(AppConfigSource, to=app_config)
experience_manager = app_injector.create_object(ExperienceGenerator)

experience_manager.summarize_experience_in_batch(target_role="Planner", refresh=True)
experience_manager.refresh(target_role="Planner")
experience_manager.load_experience(target_role="Planner")

exp_files = os.listdir(os.path.join(os.path.dirname(os.path.abspath(__file__)), "data/experience"))
assert len(exp_files) == 2
Expand Down Expand Up @@ -98,7 +99,7 @@ def test_experience_retrieval():

user_query = "show top 10 data in ./data.csv"

experience_manager.summarize_experience_in_batch(target_role="Planner")
experience_manager.load_experience(target_role="Planner")

assert len(experience_manager.experience_list) == 1
exp = experience_manager.experience_list[0]
Expand Down

0 comments on commit fc9c30d

Please sign in to comment.