Skip to content

Commit

Permalink
Liqun/refactor role (#441)
Browse files Browse the repository at this point in the history
refactored the load example/experience functions to abstract a
preparation process.
  • Loading branch information
liqul authored Nov 19, 2024
1 parent 9ddf42c commit 08e0ddf
Showing 1 changed file with 84 additions and 72 deletions.
156 changes: 84 additions & 72 deletions taskweaver/role/role.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import os.path
from dataclasses import dataclass
from datetime import timedelta
from typing import List, Optional, Set, Tuple, Union
from typing import List, Literal, Optional, Set, Tuple, Union

from injector import Module, inject, provider

Expand Down Expand Up @@ -153,101 +153,113 @@ def format_experience(
else ""
)

def prepare_loading(
self,
use_flag: bool,
dynamic_sub_path: bool,
base_path: str,
memory: Optional[Memory],
loaded_from_attr: str,
item_type: Literal["experience", "example"],
) -> Optional[str]:
"""Prepare for loading by checking configurations and memory, and return load_from path if applicable."""
if not use_flag:
setattr(self, f"{item_type}s", [])
return None

if not os.path.exists(base_path):
raise FileNotFoundError(
f"The default {item_type} base path {base_path} does not exist."
f"The original {item_type} base paths have been changed to `{item_type}s` folder."
f"Please migrate the {item_type}s to the new base path.",
)

sub_path = ""
if dynamic_sub_path:
assert memory is not None, f"Memory should be provided when dynamic_{item_type}_sub_path is True"
sub_paths = memory.get_shared_memory_entries(entry_type=f"{item_type}_sub_path")
if sub_paths:
self.tracing.set_span_attribute(f"{item_type}_sub_path", str(sub_paths))
# todo: handle multiple sub paths
sub_path = sub_paths[0].content
else:
self.logger.info(f"No {item_type} sub path found in memory.")
setattr(self, f"{item_type}s", [])
return None

load_from = os.path.join(base_path, sub_path)
if getattr(self, loaded_from_attr) is not None and getattr(self, loaded_from_attr) == load_from:
self.logger.info(f"{item_type.capitalize()} already loaded from {load_from}.")
return None

setattr(self, loaded_from_attr, load_from)
return sub_path

def role_load_experience(
self,
query: str,
memory: Optional[Memory] = None,
) -> None:
if not self.config.use_experience:
self.experiences = []
sub_path = self.prepare_loading(
self.config.use_experience,
self.config.dynamic_experience_sub_path,
self.config.experience_dir,
memory,
"experience_loaded_from",
"experience",
)
if sub_path is None:
return

if self.experience_generator is None:
raise ValueError(
"Experience generator is not initialized. Each role instance should have its own generator.",
)

experience_sub_path = ""
if self.config.dynamic_experience_sub_path:
assert memory is not None, "Memory should be provided when dynamic_experience_sub_path is True"
experience_sub_paths = memory.get_shared_memory_entries(entry_type="experience_sub_path")
if experience_sub_paths:
self.tracing.set_span_attribute("experience_sub_path", str(experience_sub_paths))
# todo: handle multiple experience sub paths
experience_sub_path = experience_sub_paths[0].content
else:
self.logger.info("No experience sub path found in memory.")
self.experiences = []
return

load_from = os.path.join(self.config.experience_dir, experience_sub_path)
if self.experience_loaded_from is None or self.experience_loaded_from != load_from:
self.experience_loaded_from = load_from
self.experience_generator.set_experience_dir(self.config.experience_dir)
self.experience_generator.set_sub_path(experience_sub_path)
self.experience_generator.refresh()
self.experience_generator.load_experience()
self.logger.info(
"Experience loaded successfully for {}, there are {} experiences with filter [{}]".format(
self.alias,
len(self.experience_generator.experience_list),
experience_sub_path,
),
)
else:
self.logger.info(f"Experience already loaded from {load_from}.")
self.experience_generator.set_experience_dir(self.config.experience_dir)
self.experience_generator.set_sub_path(sub_path)
self.experience_generator.refresh()
self.experience_generator.load_experience()
self.logger.info(
"Experience loaded successfully for {}, there are {} experiences with filter [{}]".format(
self.alias,
len(self.experience_generator.experience_list),
sub_path,
),
)

experiences = self.experience_generator.retrieve_experience(query)
self.logger.info(f"Retrieved {len(experiences)} experiences for query [{query}]")
self.experiences = [exp for exp, _ in experiences]

# todo: `role_load_example` is similar to `role_load_experience`, consider refactoring
def role_load_example(
self,
role_set: Set[str],
memory: Optional[Memory] = None,
) -> None:
if not self.config.use_example:
self.examples = []
sub_path = self.prepare_loading(
self.config.use_example,
self.config.dynamic_example_sub_path,
self.config.example_base_path,
memory,
"example_loaded_from",
"example",
)
if sub_path is None:
return

if not os.path.exists(self.config.example_base_path):
raise FileNotFoundError(
f"The default example base path {self.config.example_base_path} does not exist."
"The original example base paths have been changed to `examples` folder."
"Please migrate the examples to the new base path.",
)

example_sub_path = ""
if self.config.dynamic_example_sub_path:
assert memory is not None, "Memory should be provided when dynamic_example_sub_path is True"
example_sub_paths = memory.get_shared_memory_entries(entry_type="example_sub_path")
if example_sub_paths:
self.tracing.set_span_attribute("example_sub_path", str(example_sub_paths))
# todo: handle multiple sub paths
example_sub_path = example_sub_paths[0].content
else:
self.logger.info("No example sub path found in memory.")
self.examples = []
return

load_from = os.path.join(self.config.example_base_path, example_sub_path)
if self.example_loaded_from is None or self.example_loaded_from != load_from:
self.example_loaded_from = load_from
self.examples = load_examples(
folder=self.config.example_base_path,
sub_path=example_sub_path,
role_set=role_set,
)
self.logger.info(
"Example loaded successfully for {}, there are {} examples with filter [{}]".format(
self.alias,
len(self.examples),
example_sub_path,
),
)
else:
self.logger.info(f"Example already loaded from {load_from}.")
self.examples = load_examples(
folder=self.config.example_base_path,
sub_path=sub_path,
role_set=role_set,
)
self.logger.info(
"Example loaded successfully for {}, there are {} examples with filter [{}]".format(
self.alias,
len(self.examples),
sub_path,
),
)


class RoleModuleConfig(ModuleConfig):
Expand Down

0 comments on commit 08e0ddf

Please sign in to comment.