diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml index 362b9e5..067bd73 100644 --- a/.github/workflows/main.yml +++ b/.github/workflows/main.yml @@ -67,3 +67,26 @@ jobs: pylint --disable=trailing-whitespace,missing-class-docstring,missing-final-newline,trailing-newlines \ --fail-under=9.0 \ $(git ls-files '*.py') || echo "::warning::Pylint check failed, but the workflow will continue." + + test: + runs-on: ubuntu-latest + needs: linter + strategy: + matrix: + python-version: ["3.9", "3.10"] + + steps: + - name: Checkout repository + uses: actions/checkout@v4 + - name: Set up Python + uses: actions/setup-python@v4 + with: + python-version: ${{ matrix.python-version }} + - name: Install dependencies + run: | + python -m pip install --upgrade pip + pip install -r requirements.txt + pip install -r requirements-dev.txt + - name: Run tests + run: | + pytest diff --git a/README.md b/README.md index 87ccecf..11b9916 100644 --- a/README.md +++ b/README.md @@ -84,7 +84,7 @@ urartu launch --name=roleplay action_config=roleplay +action_config/task/model_i The `action_config` parameter specifies which configuration file to use to run the action. Afterward, we define the configuration file for the inquirer using the `model_inquirer` argument and set the configuration for the responder with the `model_responder` argument. -To execute the command on a Slurm cluster, modify the `roleplay/configs/action_config/generate_dialogues.yaml` file with the corresponding fields, and then use the same command to run the job. For more details on how to edit the configuration files, please refer to the upcoming sections. +To execute the command on a Slurm cluster, modify the `roleplay/configs/action_config/dialogue_generator.yaml` file with the corresponding fields, and then use the same command to run the job. For more details on how to edit the configuration files, please refer to the upcoming sections. > **Huggingface Authentication** > You might need to log in to HuggingFace to authenticate your use of Mistral 8x7B. To do this, use the `huggingface-cli` login command and provide your access token. @@ -103,7 +103,7 @@ The default configs which shape the way of configs are defined in `urartu` under You have two flexible options for tailoring your configurations in `roleplay`. 1. **Custom Config Files**: To simplify configuration adjustments, `roleplay` provides a dedicated `configs` directory where you can store personalized configuration files. These files seamlessly integrate with Hydra's search path. The directory structure mirrors that of `urartu/config`. You can define project-specific configurations in specially named files. - The `generate_dialogues.yaml` file within the `configs` directory houses all the configurations specific to our `roleplay` project, with customized settings. + The `dialogue_generator.yaml` file within the `configs` directory houses all the configurations specific to our `roleplay` project, with customized settings. - **Personalized User Configs**: To further tailor configurations for individual users, create a directory named `configs_{username}` at the same level as the `configs` directory, where `{username}` represents your operating system username (check out `configs_tamoyan` for an example). The beauty of this approach is that there are no additional steps required. Your customizations will smoothly load and override the default configurations, ensuring a seamless and hassle-free experience.
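Editorial note, not part of the patch: a minimal sketch of how the renamed dialogue_generator.yaml action config could be composed and inspected with Hydra's compose API. The config directory path and the override key are assumptions based on the layout described above; in practice the project composes its configs through urartu's launcher.

from hydra import compose, initialize_config_dir
from omegaconf import OmegaConf

CONFIG_DIR = "/abs/path/to/llm_roleplay/configs"  # hypothetical absolute path to the configs directory

with initialize_config_dir(config_dir=CONFIG_DIR, version_base=None):
    # "action_config/dialogue_generator" refers to the renamed YAML file;
    # the override below is only an example of adjusting one task setting.
    cfg = compose(
        config_name="action_config/dialogue_generator",
        overrides=["action_config.task.num_turns=3"],
    )

print(OmegaConf.to_yaml(cfg))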
✨ @@ -122,7 +122,7 @@ Choose the method that suits your workflow best and enjoy the flexibility `urart With `urartu`, launching actions is incredibly easy, offering you two options. 🚀 - **Local Marvel:** This option allows you to run jobs on your local machine, right where the script is executed. -- **Cluster Voyage:** This choice takes you on a journey to the Slurm cluster. By adjusting the `slurm.use_slurm` setting in `roleplay/configs/action_config/generate_dialogues.yaml`, you can easily switch between local and cluster execution. +- **Cluster Voyage:** This choice takes you on a journey to the Slurm cluster. By adjusting the `slurm.use_slurm` setting in `roleplay/configs/action_config/dialogue_generator.yaml`, you can easily switch between local and cluster execution. Enjoy the flexibility to choose the launch adventure that best suits your needs and goals! diff --git a/llm_roleplay/VERSION b/llm_roleplay/VERSION index f93ea0c..6acdb44 100644 --- a/llm_roleplay/VERSION +++ b/llm_roleplay/VERSION @@ -1 +1 @@ -2.0.2 \ No newline at end of file +2.0.3 \ No newline at end of file diff --git a/llm_roleplay/actions/generate_dialogues.py b/llm_roleplay/actions/dialogue_generator.py similarity index 70% rename from llm_roleplay/actions/generate_dialogues.py rename to llm_roleplay/actions/dialogue_generator.py index 5013c5c..491967d 100644 --- a/llm_roleplay/actions/generate_dialogues.py +++ b/llm_roleplay/actions/dialogue_generator.py @@ -1,6 +1,7 @@ import os from pathlib import Path +import logging import hydra import jsonlines import torch @@ -9,11 +10,12 @@ from tqdm import tqdm from urartu.common.action import Action from urartu.common.dataset import Dataset +from llm_roleplay.common.model import Model from llm_roleplay.common.persona import Persona -class Roleplay(Action): +class DialogueGenerator(Action): def __init__(self, cfg: DictConfig, aim_run: Run) -> None: super().__init__(cfg, aim_run) @@ -24,7 +26,7 @@ def track(self, prompt, name, context=None): context=context, ) - def main(self): + def initialize(self): self.aim_run["num_no_prompts"] = 0 self.aim_run["num_multiple_prompts"] = 0 self.aim_run["num_non_coherent"] = 0 @@ -33,44 +35,38 @@ def main(self): self.aim_run["num_non_coherent_model_responder"] = 0 self.aim_run["personas"] = {} - task_cfg = self.action_cfg.task + self.task_cfg = self.action_cfg.task - records_dir = Path(self.action_cfg.workdir).joinpath( + self.records_dir = Path(self.action_cfg.workdir).joinpath( "dialogs", - f"{task_cfg.model_inquirer.name.split('/')[-1]}", + f"{self.task_cfg.model_inquirer.name.split('/')[-1]}", str(self.aim_run.hash), ) - os.makedirs(records_dir, exist_ok=True) + os.makedirs(self.records_dir, exist_ok=True) - dataset = Dataset.get_dataset(task_cfg.dataset) - personas = Persona.get_personas(task_cfg.persona) + self.dataset = Dataset.get_dataset(self.task_cfg.dataset) + self.personas = Persona.get_personas(self.task_cfg.persona) - model_inquirer = hydra.utils.instantiate( - task_cfg.model_inquirer.type, task_cfg.model_inquirer, "model_inquirer" - ) - model_responder = hydra.utils.instantiate( - task_cfg.model_responder.type, task_cfg.model_responder, "model_responder" - ) + self.model_inquirer = Model.get_model(self.task_cfg.model_inquirer, role="model_inquirer") + self.model_responder = Model.get_model(self.task_cfg.model_responder, role="model_responder") - model_inquirer.spec_tokens = task_cfg.spec_tokens - model_responder.spec_tokens = task_cfg.spec_tokens - model_inquirer.aim_run = self.aim_run - model_responder.aim_run = 
self.aim_run + self.model_inquirer.spec_tokens = self.task_cfg.spec_tokens + self.model_responder.spec_tokens = self.task_cfg.spec_tokens + self.model_inquirer.aim_run = self.aim_run + self.model_responder.aim_run = self.aim_run - for idx, sample in tqdm( - enumerate(dataset.dataset), total=len(dataset.dataset), desc="samples" - ): - for persona, persona_hash in tqdm(personas, desc="personas", leave=False): + def generate(self) -> Path: + for idx, sample in tqdm(enumerate(self.dataset.dataset), total=len(self.dataset.dataset), desc="samples"): + for persona, persona_hash in tqdm(self.personas, desc="personas", leave=False): self.aim_run["personas"][persona_hash] = persona - model_inquirer.history = [] - model_responder.history = [] + self.model_inquirer.history = [] + self.model_responder.history = [] dialog = [] raw_dialog = [] instructions = [ - instruct.lstrip().rstrip() - for instruct in sample[task_cfg.dataset.input_key].split("\n") + instruct.lstrip().rstrip() for instruct in sample[self.task_cfg.dataset.input_key].split("\n") ] if self.action_cfg.task.model_inquirer.regenerate_tries: @@ -78,11 +74,11 @@ def main(self): inquirer_generate_cfg = None responder_output = None turn = 0 - with tqdm(total=task_cfg.num_turns, desc="turns", leave=False) as pbar: - while turn < task_cfg.num_turns: + with tqdm(total=self.task_cfg.num_turns, desc="turns", leave=False) as pbar: + while turn < self.task_cfg.num_turns: pbar.set_postfix(turn=turn + 1) - # ------------------------------------------ Model A ------------------------------------------ - inquirer_prompt = model_inquirer.get_prompt( + # ------------------------------------------ Inquirer Model ------------------------------------------ + inquirer_prompt = self.model_inquirer.get_prompt( turn=turn, response_msg=responder_output, persona=persona, @@ -98,7 +94,7 @@ def main(self): "persona_hash": persona_hash, }, ) - inquirer_output, _ = model_inquirer.generate( + inquirer_output, _ = self.model_inquirer.generate( prompt=inquirer_prompt, generate_cfg=( inquirer_generate_cfg @@ -119,28 +115,23 @@ def main(self): ) # --------------------- if model_inquirer failed to provide coherent text --------------------- - if model_inquirer.is_non_coherent(inquirer_output): + if self.model_inquirer.is_non_coherent(inquirer_output): self.aim_run["num_non_coherent"] += 1 break # --------------------- if model_inquirer wants to stop the dialog --------------------- - if model_inquirer.stop_dialog(inquirer_output): + if self.model_inquirer.stop_dialog(inquirer_output): break - inquirer_output_extract, num_prompts = ( - model_inquirer.extract_prompt(prompt=inquirer_output) + inquirer_output_extract, num_prompts = self.model_inquirer.extract_prompt( + prompt=inquirer_output ) if self.action_cfg.task.model_inquirer.regenerate_tries: # --------------------- if model_inquirer failed to provide prompt --------------------- if inquirer_output_extract is None: - if ( - regeneratinon_idx - < self.action_cfg.task.model_inquirer.regenerate_tries - ): - inquirer_generate_cfg = ( - model_inquirer.get_generation_cfg() - ) + if regeneratinon_idx < self.action_cfg.task.model_inquirer.regenerate_tries: + inquirer_generate_cfg = self.model_inquirer.get_generation_cfg() regeneratinon_idx += 1 continue else: @@ -169,14 +160,14 @@ def main(self): # As the context for model_inquirer is getting bigger much faster -> Starts answering it's own questions # To prevent this keep in the inquirer_history only the output prompt(the thing that model_responder will see). 
- model_inquirer.update_history( + self.model_inquirer.update_history( prompt=inquirer_prompt, output_extract=inquirer_output_extract, ) - # ------------------------------------------ Model B ------------------------------------------ + # ------------------------------------------ Responder Model ------------------------------------------ - responder_prompt = model_responder.get_prompt( + responder_prompt = self.model_responder.get_prompt( turn=turn, response_msg=inquirer_output_extract ) @@ -189,11 +180,9 @@ def main(self): "persona_hash": persona_hash, }, ) - responder_output, responder_model_output_template = ( - model_responder.generate( - prompt=responder_prompt, - generate_cfg=self.action_cfg.task.model_responder.generate, - ) + responder_output, responder_model_output_template = self.model_responder.generate( + prompt=responder_prompt, + generate_cfg=self.action_cfg.task.model_responder.generate, ) if not responder_output: break @@ -208,16 +197,16 @@ def main(self): ) # --------------------- if model_responder failed to provide coherent text --------------------- - if model_responder.is_non_coherent(responder_output): + if self.model_responder.is_non_coherent(responder_output): self.aim_run["num_non_coherent_model_responder"] += 1 break - model_responder.update_history( + self.model_responder.update_history( prompt=responder_prompt, output_extract=responder_model_output_template, ) - # --------------------------------------- Save the dialog --------------------------------------- + # --------------------------------------- Saving the dialogue --------------------------------------- dialog.append( { "turn": turn, @@ -232,9 +221,7 @@ def main(self): turn += 1 pbar.update(1) - with jsonlines.open( - records_dir.joinpath(f"{self.cfg.seed}.jsonl"), mode="a" - ) as writer: + with jsonlines.open(self.records_dir.joinpath(f"{self.cfg.seed}.jsonl"), mode="a") as writer: writer.write( { "persona": persona, @@ -244,7 +231,11 @@ def main(self): } ) + return self.records_dir + def main(cfg: DictConfig, aim_run: Run): - roleplay = Roleplay(cfg, aim_run) - roleplay.main() + dialogue_generator = DialogueGenerator(cfg, aim_run) + dialogue_generator.initialize() + dialogues_dir = dialogue_generator.generate() + logging.info(f"Dialogues successfully generated and stored in: {dialogues_dir}") diff --git a/llm_roleplay/common/model.py b/llm_roleplay/common/model.py index 3808ad4..2573d3c 100644 --- a/llm_roleplay/common/model.py +++ b/llm_roleplay/common/model.py @@ -3,6 +3,7 @@ import random import re import string +import hydra from typing import Any, Dict, List from urartu.common.device import Device @@ -18,10 +19,14 @@ def __init__(self, cfg: List[Dict[str, Any]], role=None): self.tokenizer = None self.role = role self.history = [] - self._load_model() + self._get_model() - def _load_model(self): - raise NotImplementedError("method '_load_model' is not implemented") + @staticmethod + def get_model(cfg, role): + return hydra.utils.instantiate(cfg.type, cfg, role) + + def _get_model(self): + raise NotImplementedError("method '_get_model' is not implemented") def get_prompt(self, turn, response_msg, persona=None, instructions=None): raise NotImplementedError("method 'get_prompt' is not implemented") diff --git a/llm_roleplay/configs/action_config/generate_dialogues.yaml b/llm_roleplay/configs/action_config/dialogue_generator.yaml similarity index 99% rename from llm_roleplay/configs/action_config/generate_dialogues.yaml rename to llm_roleplay/configs/action_config/dialogue_generator.yaml index
4884e9b..10b7cd2 100644 --- a/llm_roleplay/configs/action_config/generate_dialogues.yaml +++ b/llm_roleplay/configs/action_config/dialogue_generator.yaml @@ -1,5 +1,5 @@ # @package _global_ -action_name: generate_dialogues +action_name: dialogue_generator seed: 5 action_config: @@ -15,7 +15,7 @@ action_config: dataset: type: - _target_: roleplay.datasets.hf_datasets.HFDatasets + _target_: llm_roleplay.datasets.hf_datasets.HFDatasets input_key: "instruction" data: instruction: diff --git a/llm_roleplay/configs/action_config/task/model_inquirer/falcon.yaml b/llm_roleplay/configs/action_config/task/model_inquirer/falcon.yaml index 647b83a..52c9054 100644 --- a/llm_roleplay/configs/action_config/task/model_inquirer/falcon.yaml +++ b/llm_roleplay/configs/action_config/task/model_inquirer/falcon.yaml @@ -1,5 +1,5 @@ type: - _target_: roleplay.models.causal_lm_model.CausalLMModel + _target_: llm_roleplay.models.causal_lm_model.CausalLMModel name: "tiiuae/falcon-40b-instruct" cache_dir: "" dtype: torch.float16 diff --git a/llm_roleplay/configs/action_config/task/model_inquirer/gpt3.5.yaml b/llm_roleplay/configs/action_config/task/model_inquirer/gpt3.5.yaml index 9f2659f..c6a0599 100644 --- a/llm_roleplay/configs/action_config/task/model_inquirer/gpt3.5.yaml +++ b/llm_roleplay/configs/action_config/task/model_inquirer/gpt3.5.yaml @@ -1,5 +1,5 @@ type: - _target_: roleplay.models.openai_model.OpenAIModel + _target_: llm_roleplay.models.openai_model.OpenAIModel openai_api_type: "azure" openai_api_version: "2023-05-15" azure_openai_endpoint: null diff --git a/llm_roleplay/configs/action_config/task/model_inquirer/gpt4.yaml b/llm_roleplay/configs/action_config/task/model_inquirer/gpt4.yaml index 5b21405..e021a05 100644 --- a/llm_roleplay/configs/action_config/task/model_inquirer/gpt4.yaml +++ b/llm_roleplay/configs/action_config/task/model_inquirer/gpt4.yaml @@ -1,5 +1,5 @@ type: - _target_: roleplay.models.openai_model.OpenAIModel + _target_: llm_roleplay.models.openai_model.OpenAIModel openai_api_type: "azure" openai_api_version: "2023-05-15" azure_openai_endpoint: null diff --git a/llm_roleplay/configs/action_config/task/model_inquirer/llama.yaml b/llm_roleplay/configs/action_config/task/model_inquirer/llama.yaml index 0a5097c..cb7b7b6 100644 --- a/llm_roleplay/configs/action_config/task/model_inquirer/llama.yaml +++ b/llm_roleplay/configs/action_config/task/model_inquirer/llama.yaml @@ -1,5 +1,5 @@ type: - _target_: roleplay.models.causal_lm_model.CausalLMModel + _target_: llm_roleplay.models.causal_lm_model.CausalLMModel name: "meta-llama/Llama-2-13b-chat-hf" cache_dir: "" dtype: torch.float16 diff --git a/llm_roleplay/configs/action_config/task/model_inquirer/mixtral.yaml b/llm_roleplay/configs/action_config/task/model_inquirer/mixtral.yaml index 2fc56ae..c85fa80 100644 --- a/llm_roleplay/configs/action_config/task/model_inquirer/mixtral.yaml +++ b/llm_roleplay/configs/action_config/task/model_inquirer/mixtral.yaml @@ -1,6 +1,7 @@ type: - _target_: roleplay.models.causal_lm_model.CausalLMModel + _target_: llm_roleplay.models.causal_lm_model.CausalLMModel name: "mistralai/Mixtral-8x7B-Instruct-v0.1" +role: "inquirer" cache_dir: "" dtype: torch.float16 non_coherent_max_n: 4 diff --git a/llm_roleplay/configs/action_config/task/model_inquirer/vicuna.yaml b/llm_roleplay/configs/action_config/task/model_inquirer/vicuna.yaml index 8e9d578..007bab8 100644 --- a/llm_roleplay/configs/action_config/task/model_inquirer/vicuna.yaml +++ b/llm_roleplay/configs/action_config/task/model_inquirer/vicuna.yaml @@ 
-1,5 +1,5 @@ type: - _target_: roleplay.models.causal_lm_model.CausalLMModel + _target_: llm_roleplay.models.causal_lm_model.CausalLMModel name: "lmsys/vicuna-13b-v1.5-16k" cache_dir: "" dtype: torch.float16 diff --git a/llm_roleplay/configs/action_config/task/model_responder/llama.yaml b/llm_roleplay/configs/action_config/task/model_responder/llama.yaml index 1cf1223..91ae0d3 100644 --- a/llm_roleplay/configs/action_config/task/model_responder/llama.yaml +++ b/llm_roleplay/configs/action_config/task/model_responder/llama.yaml @@ -1,5 +1,5 @@ type: - _target_: roleplay.models.pipeline_model.PipelineModel + _target_: llm_roleplay.models.pipeline_model.PipelineModel name: "models--llama-2-hf/13B-Chat" cache_dir: "" dtype: torch.float16 diff --git a/llm_roleplay/models/causal_lm_model.py b/llm_roleplay/models/causal_lm_model.py index 66fc30e..f3ec0df 100644 --- a/llm_roleplay/models/causal_lm_model.py +++ b/llm_roleplay/models/causal_lm_model.py @@ -6,7 +6,7 @@ from urartu.common.device import Device from urartu.utils.dtype import eval_dtype -from roleplay.common.model import Model +from llm_roleplay.common.model import Model class CausalLMModel(Model): @@ -18,7 +18,7 @@ class CausalLMModel(Model): def __init__(self, cfg, role) -> None: super().__init__(cfg, role) - def _load_model(self) -> Tuple[AutoModelForCausalLM, AutoTokenizer]: + def _get_model(self) -> Tuple[AutoModelForCausalLM, AutoTokenizer]: self.model = AutoModelForCausalLM.from_pretrained( self.cfg.name, cache_dir=self.cfg.cache_dir, diff --git a/llm_roleplay/models/openai_model.py b/llm_roleplay/models/openai_model.py index 69ecaba..7709a59 100644 --- a/llm_roleplay/models/openai_model.py +++ b/llm_roleplay/models/openai_model.py @@ -5,14 +5,14 @@ from langchain_openai import AzureChatOpenAI from transformers import AutoModelForCausalLM, AutoTokenizer -from roleplay.common.model import Model +from llm_roleplay.common.model import Model class OpenAIModel(Model): def __init__(self, cfg, role) -> None: super().__init__(cfg, role) - def _load_model(self) -> Tuple[AutoModelForCausalLM, AutoTokenizer]: + def _get_model(self) -> Tuple[AutoModelForCausalLM, AutoTokenizer]: self.model = AzureChatOpenAI( deployment_name=self.cfg.name, openai_api_type=self.cfg.openai_api_type, diff --git a/llm_roleplay/models/pipeline_model.py b/llm_roleplay/models/pipeline_model.py index 7fcf8a3..3476015 100644 --- a/llm_roleplay/models/pipeline_model.py +++ b/llm_roleplay/models/pipeline_model.py @@ -4,7 +4,7 @@ from urartu.common.device import Device from urartu.utils.dtype import eval_dtype -from roleplay.common.model import Model +from llm_roleplay.common.model import Model class PipelineModel(Model): @@ -16,7 +16,7 @@ class PipelineModel(Model): def __init__(self, cfg, role) -> None: super().__init__(cfg, role) - def _load_model(self) -> Tuple[AutoModelForCausalLM, AutoTokenizer]: + def _get_model(self) -> Tuple[AutoModelForCausalLM, AutoTokenizer]: model = AutoModelForCausalLM.from_pretrained( self.cfg.name, cache_dir=self.cfg.cache_dir, diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/test_roleplay.py b/tests/test_roleplay.py new file mode 100644 index 0000000..144fae6 --- /dev/null +++ b/tests/test_roleplay.py @@ -0,0 +1,283 @@ +import importlib +import json +import unittest +from unittest.mock import patch + +from aim import Run +from aim.storage.context import Context +from llm_roleplay.actions.dialogue_generator import DialogueGenerator +from omegaconf import OmegaConf +from 
urartu.common.device import Device + + +class TestRoleplay(unittest.TestCase): + def setUp(self): + # Configuration and Run mock setup + self.cfg = OmegaConf.create( + { + "action_name": "test_roleplay", + "aim": {"repo": "tmp"}, + "action_config": { + "workdir": "tmp", + "experiment_name": "test experiment for roleplay", + "device": "auto", + "task": { + "num_turns": 2, + "model_inquirer": { + "type": {"_target_": "llm_roleplay.models.causal_lm_model.CausalLMModel"}, + "name": "mistralai/Mixtral-8x7B-Instruct-v0.1", + "cache_dir": "", + "dtype": "torch.float16", + "non_coherent_max_n": 4, + "non_coherent_r": 2, + "regenerate_tries": None, + "api_token": None, + "generate": {"do_sample": True, "max_new_tokens": 1000}, + "conv_template": { + "first_turn_input": ( + "[INST]\n" + 'You are . You will start a conversation with an assistant. If you accomplish your final goal during the conversation only say "".\n\n' + "Your ultimate goal is as follows: . What prompt will you use to direct the assistant toward achieving your goal? Please provide the prompt within double quotes. Use simple language, keep the prompts brief, and be on point. Do not greet the assistant. Maintain a casual style; avoid being overly friendly, don't say thank you. [/INST]" + ), + "n_th_turn_input": "[INST] [/INST]", + "model_output": "", + "response_forwarding": ( + 'If the assistant didn\'t help you achieve your goal, ask follow-up or clarification questions within double quotes. Be suspicious, curious, and demanding. Keep it simple, brief, and to the point. Stay casual; avoid being overly friendly. Assistant response: \n\n"".' + ), + }, + "idx_of_possible_prompt": 0, + }, + "model_responder": { + "type": {"_target_": "llm_roleplay.models.pipeline_model.PipelineModel"}, + "name": "models--llama-2-hf/13B-Chat", + "cache_dir": "", + "dtype": "torch.float16", + "non_coherent_max_n": 5, + "non_coherent_r": 2, + "api_token": None, + "generate": {"max_new_tokens": 4000}, + "conv_template": { + "first_turn_input": "[INST] <>\n You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature.\n\n If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information.\n <>\n\n [/INST]", + "n_th_turn_input": "[INST] [/INST]", + "model_output": "", + }, + }, + "dataset": { + "type": {"_target_": "llm_roleplay.datasets.hf_datasets.HFDatasets"}, + "input_key": "instruction", + "data": { + "instruction": [ + "You want to know how fast you run different distances. You use a stopwatch to measure the time it takes you to complete a 50-meter, 100-meter, and 200-meter race. You want to know how can you calculate your speed for each race? Based on that, you also want to calculate how many calories you burned during each race." 
+ ] + }, + }, + "persona": { + "prompt": "-year-old individual with a gender identity, holding and English your native language", + "fixed": [ + { + "person": { + "age": "a 25 to 34", + "race": "White", + "gender": "Male", + "education": "Master's degree", + "native_english": "is not", + } + } + ], + }, + "spec_tokens": { + "persona_placeholder": "", + "objective_placeholder": "", + "response_placeholder": "", + "conv_stop_placeholder": "", + "conv_stop_token": "FINISH", + "user_msg": "", + "model_answer": "", + "next_prompt": "", + "bos_token": "", + "sep_token": "", + }, + }, + }, + "seed": 42, + } + ) + + self.sample_inquirer_output = '''"Hey assistant, I'm looking to measure my running speed and calculate..."''' + self.sample_responder_output = '''"Hello! I'd be happy to help you with your questions about measuring..."''' + self.aim_run = Run(repo=self.cfg.aim.repo, experiment=self.cfg.action_config.experiment_name) + self.aim_run.set("cfg", self.cfg, strict=False) + + def test_tracking_calls(self): + dialogue_generator = DialogueGenerator(self.cfg, self.aim_run) + self.assertEqual( + self.aim_run["cfg"]["action_name"], + self.cfg.action_name, + "Action name in AIM run config does not match the expected value", + ) + + dialogue_generator.track(self.sample_inquirer_output, "test_inquirer_input") + text_seq = self.aim_run.get_text_sequence("test_inquirer_input", context=Context({})) + text_record = next(iter(text_seq.data), None) + self.assertIsNotNone(text_record, "No text records found in AIM run tracking") + + @patch("llm_roleplay.models.openai_model.OpenAIModel.generate") + @patch("llm_roleplay.models.pipeline_model.PipelineModel.generate") + @patch("llm_roleplay.models.causal_lm_model.CausalLMModel.generate") + @patch("llm_roleplay.models.openai_model.OpenAIModel._get_model") + @patch("llm_roleplay.models.pipeline_model.PipelineModel._get_model") + @patch("llm_roleplay.models.causal_lm_model.CausalLMModel._get_model") + def test_initialization( + self, + mock_get_model_clm, + mock_get_model_pipe, + mock_get_mode_openai, + mock_generate_clm, + mock_generate_pipe, + mock_generate_openai, + ): + dialogue_generator = DialogueGenerator(self.cfg, self.aim_run) + dialogue_generator.initialize() + + self.assertTrue(hasattr(dialogue_generator, "task_cfg"), "dialogue_generator is missing 'task_cfg' attribute") + self.assertIsNotNone(dialogue_generator.task_cfg, "'task_cfg' attribute is None") + + self.assertTrue( + hasattr(dialogue_generator, "records_dir"), "dialogue_generator is missing 'records_dir' attribute" + ) + self.assertIsNotNone(dialogue_generator.records_dir, "'records_dir' attribute is None") + + self.assertTrue(hasattr(dialogue_generator, "dataset"), "dialogue_generator is missing 'dataset' attribute") + self.assertIsNotNone(dialogue_generator.dataset, "'dataset' attribute is None") + self.assertEqual( + dialogue_generator.dataset.dataset.num_rows, + len(self.cfg.action_config.task.dataset.data.instruction), + "Number of rows in dataset does not match expected value", + ) + + self.assertTrue(hasattr(dialogue_generator, "personas"), "dialogue_generator is missing 'personas' attribute") + self.assertIsNotNone(dialogue_generator.personas, "'personas' attribute is None") + self.assertEqual( + len(dialogue_generator.personas), + len(self.cfg.action_config.task.persona.fixed), + "Mismatch in number of fixed personas", + ) + + self.assertTrue( + hasattr(dialogue_generator, "model_inquirer"), "dialogue_generator is missing 'model_inquirer' attribute" + ) + 
self.assertIsNotNone(dialogue_generator.model_inquirer, "'model_inquirer' attribute is None") + class_path = self.cfg.action_config.task.model_inquirer.type._target_ + module_name, class_name = class_path.rsplit('.', 1) + module = importlib.import_module(module_name) + assert isinstance(dialogue_generator.model_inquirer, getattr(module, class_name)), f"The 'model_inquirer' should be an instance of {class_name} from {module_name}, but got {type(dialogue_generator.model_inquirer).__name__}" + + self.assertTrue( + hasattr(dialogue_generator, "model_responder"), "dialogue_generator is missing 'model_responder' attribute" + ) + self.assertIsNotNone(dialogue_generator.model_responder, "'model_responder' attribute is None") + class_path = self.cfg.action_config.task.model_responder.type._target_ + module_name, class_name = class_path.rsplit('.', 1) + module = importlib.import_module(module_name) + assert isinstance(dialogue_generator.model_responder, getattr(module, class_name)), f"The 'model_responder' should be an instance of {class_name} from {module_name}, but got {type(dialogue_generator.model_responder).__name__}" + + + @patch("llm_roleplay.models.openai_model.OpenAIModel.generate") + @patch("llm_roleplay.models.pipeline_model.PipelineModel.generate") + @patch("llm_roleplay.models.causal_lm_model.CausalLMModel.generate") + @patch("llm_roleplay.models.openai_model.OpenAIModel._get_model") + @patch("llm_roleplay.models.pipeline_model.PipelineModel._get_model") + @patch("llm_roleplay.models.causal_lm_model.CausalLMModel._get_model") + @patch("torch.cuda.empty_cache") + def test_resource_management( + self, + mock_empty_cache, + mock_get_model_clm, + mock_get_model_pipe, + mock_get_mode_openai, + mock_generate_clm, + mock_generate_pipe, + mock_generate_openai, + ): + Device.set_device(self.cfg.action_config.device) + self.assertEqual( + Device.get_device(), + self.cfg.action_config.device, + "Device configuration does not match the expected setting", + ) + + dialogue_generator = DialogueGenerator(self.cfg, self.aim_run) + dialogue_generator.initialize() + + dialogue_generator.model_inquirer.generate.return_value = (self.sample_inquirer_output, None) + dialogue_generator.model_responder.generate.return_value = (self.sample_responder_output, None) + + records_dir = dialogue_generator.generate() + self.assertTrue(records_dir.is_dir(), "Generated records directory does not exist") + self.assertTrue( + (records_dir / f"{self.cfg.seed}.jsonl").exists(), "Expected jsonl file not found in records directory" + ) + self.assertTrue( + str(records_dir).startswith("tmp/dialogs"), "Records directory path does not start with 'tmp/dialogs'" + ) + + mock_empty_cache.assert_called() + + @patch("llm_roleplay.models.openai_model.OpenAIModel.generate") + @patch("llm_roleplay.models.pipeline_model.PipelineModel.generate") + @patch("llm_roleplay.models.causal_lm_model.CausalLMModel.generate") + @patch("llm_roleplay.models.openai_model.OpenAIModel._get_model") + @patch("llm_roleplay.models.pipeline_model.PipelineModel._get_model") + @patch("llm_roleplay.models.causal_lm_model.CausalLMModel._get_model") + def test_dialogue_generation( + self, + mock_get_model_clm, + mock_get_model_pipe, + mock_get_mode_openai, + mock_generate_clm, + mock_generate_pipe, + mock_generate_openai, + ): + dialogue_generator = DialogueGenerator(self.cfg, self.aim_run) + dialogue_generator.initialize() + + dialogue_generator.model_inquirer.generate.return_value = (self.sample_inquirer_output, None) + 
dialogue_generator.model_responder.generate.return_value = (self.sample_responder_output, None) + + records_dir = dialogue_generator.generate() + + line_count = 0 + with (records_dir / f"{self.cfg.seed}.jsonl").open("r", encoding="utf-8") as file: + for line in file: + file_content = json.loads(line) + line_count += 1 + + self.assertEqual( + line_count, + len(self.cfg.action_config.task.persona.fixed), + "Generated line count does not match the number of fixed personas", + ) + self.assertEqual( + file_content["num_turns"], + self.cfg.action_config.task.num_turns, + "Number of turns in dialogue does not match expected number", + ) + + for utterance in file_content["dialog"]: + self.assertIsNotNone(utterance["model_inquirer"], "Model inquirer response is None") + self.assertIsNotNone(utterance["model_responder"], "Model responder response is None") + + self.assertEqual( + dialogue_generator.model_inquirer.generate.call_count, + self.cfg.action_config.task.num_turns, + "Inquirer model was not called the expected number of times", + ) + self.assertEqual( + dialogue_generator.model_responder.generate.call_count, + self.cfg.action_config.task.num_turns, + "Responder model was not called the expected number of times", + ) + + +if __name__ == "__main__": + unittest.main()
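Editorial note, not part of the patch: a short usage sketch of the call sequence this refactor introduces and that the tests above exercise. build_cfg() is a hypothetical helper standing in for however you obtain a config shaped like the OmegaConf object built in TestRoleplay.setUp; the Aim repo and experiment name are read from that config.

import logging

from aim import Run

from llm_roleplay.actions.dialogue_generator import DialogueGenerator


def run_roleplay(cfg):
    # Track the run in Aim, mirroring what the tests do in setUp.
    aim_run = Run(repo=cfg.aim.repo, experiment=cfg.action_config.experiment_name)
    aim_run.set("cfg", cfg, strict=False)

    generator = DialogueGenerator(cfg, aim_run)
    generator.initialize()              # builds dataset, personas, and both models via Model.get_model
    records_dir = generator.generate()  # writes dialogs/<inquirer_name>/<run_hash>/<seed>.jsonl
    logging.info("Dialogues stored in: %s", records_dir)
    return records_dir


# cfg = build_cfg()  # hypothetical: compose the action config (e.g., via Hydra) before calling
# run_roleplay(cfg)

Splitting initialize() from generate() is what lets the tests above patch _get_model and generate on each model class and still drive the full dialogue loop without loading real checkpoints.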