[feat] Add test suite. Minor improvements to the codebase
tamoyan committed Jul 10, 2024
1 parent bb87515 commit 58c5dda
Showing 18 changed files with 383 additions and 80 deletions.
23 changes: 23 additions & 0 deletions .github/workflows/main.yml
@@ -67,3 +67,26 @@ jobs:
pylint --disable=trailing-whitespace,missing-class-docstring,missing-final-newline,trailing-newlines \
--fail-under=9.0 \
$(git ls-files '*.py') || echo "::warning::Pylint check failed, but the workflow will continue."
test:
runs-on: ubuntu-latest
needs: linter
strategy:
matrix:
python-version: ["3.9", "3.10"]  # quoted so YAML does not parse 3.10 as 3.1

steps:
- name: Checkout repository
uses: actions/checkout@v4
- name: Set up Python
uses: actions/setup-python@v4
with:
python-version: ${{ matrix.python-version }}
- name: Install dependencies
run: |
python -m pip install --upgrade pip
pip install -r requirements.txt
pip install -r requirements-dev.txt
- name: Run tests
run: |
pytest
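
The new `test` job runs after the linter, installs both requirements files, and simply invokes `pytest` on Python 3.9 and 3.10. As a point of reference, a minimal test the job would pick up might look like the sketch below; the file name and assertions are illustrative assumptions, not files added by this commit.

```python
# tests/test_version.py -- hypothetical example of a test the new CI job would run
from pathlib import Path


def test_version_is_semver_like():
    # llm_roleplay/VERSION is bumped to 2.0.3 elsewhere in this commit
    version = Path("llm_roleplay/VERSION").read_text().strip()
    parts = version.split(".")
    assert len(parts) == 3 and all(part.isdigit() for part in parts)
```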
6 changes: 3 additions & 3 deletions README.md
@@ -84,7 +84,7 @@ urartu launch --name=roleplay action_config=roleplay +action_config/task/model_i

The `action_config` parameter specifies which configuration file to use to run the action. Afterward, we define the configuration file for the inquirer using the `model_inquirer` argument and set the configuration for the responder with the `model_responder` argument.

To execute the command on a Slurm cluster, modify the `roleplay/configs/action_config/generate_dialogues.yaml` file with the corresponding fields, and then use the same command to run the job. For more details on how to edit the configuration files, please refer to the upcoming sections.
To execute the command on a Slurm cluster, modify the `roleplay/configs/action_config/dialogue_generator.yaml` file with the corresponding fields, and then use the same command to run the job. For more details on how to edit the configuration files, please refer to the upcoming sections.

> **Huggingface Authentication**
> You might need to log in to HuggingFace to authenticate your use of Mixtral 8x7B. To do this, use the `huggingface-cli login` command and provide your access token.
@@ -103,7 +103,7 @@ The default configs which shape the way of configs are defined in `urartu` under
You have two flexible options for tailoring your configurations in `roleplay`.

1. **Custom Config Files**: To simplify configuration adjustments, `roleplay` provides a dedicated `configs` directory where you can store personalized configuration files. These files seamlessly integrate with Hydra's search path. The directory structure mirrors that of `urartu/config`. You can define project-specific configurations in specially named files.
The `generate_dialogues.yaml` file within the `configs` directory houses all the configurations specific to our `roleplay` project, with customized settings.
The `dialogue_generator.yaml` file within the `configs` directory houses all the configurations specific to our `roleplay` project, with customized settings.

- **Personalized User Configs**: To further tailor configurations for individual users, create a directory named `configs_{username}` at the same level as the `configs` directory, where `{username}` represents your operating system username (check out `configs_tamoyan` for an example). The beauty of this approach is that there are no additional steps required. Your customizations will smoothly load and override the default configurations, ensuring a seamless and hassle-free experience. ✨

@@ -122,7 +122,7 @@ Choose the method that suits your workflow best and enjoy the flexibility `urart
With `urartu`, launching actions is incredibly easy, offering you two options. 🚀

- **Local Marvel:** This option allows you to run jobs on your local machine, right where the script is executed.
- **Cluster Voyage:** This choice takes you on a journey to the Slurm cluster. By adjusting the `slurm.use_slurm` setting in `roleplay/configs/action_config/generate_dialogues.yaml`, you can easily switch between local and cluster execution.
- **Cluster Voyage:** This choice takes you on a journey to the Slurm cluster. By adjusting the `slurm.use_slurm` setting in `roleplay/configs/action_config/dialogue_generator.yaml`, you can easily switch between local and cluster execution.

Enjoy the flexibility to choose the launch adventure that best suits your needs and goals!

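The README changes above all point at the renamed `dialogue_generator.yaml` action config and describe Hydra-style config layering: defaults shipped with `urartu`, project settings in `configs`, per-user overrides in `configs_{username}`, and flags such as `slurm.use_slurm` deciding whether a launch stays local or goes to the cluster. The snippet below is a rough, illustrative sketch of that override behaviour using OmegaConf (the config library behind Hydra); only `seed: 5` and the `slurm.use_slurm` flag come from this commit, the remaining keys are assumptions.

```python
# Illustrative sketch of config layering, not the project's actual loading code.
from omegaconf import OmegaConf

project_defaults = OmegaConf.create({"seed": 5, "slurm": {"use_slurm": False}})
user_overrides = OmegaConf.create({"slurm": {"use_slurm": True}})  # e.g. from configs_{username}/

cfg = OmegaConf.merge(project_defaults, user_overrides)  # later configs win
print(cfg.slurm.use_slurm)  # True -> the job would be submitted to the Slurm cluster
```
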
2 changes: 1 addition & 1 deletion llm_roleplay/VERSION
@@ -1 +1 @@
2.0.2
2.0.3
@@ -1,6 +1,7 @@
import os
from pathlib import Path

import logging
import hydra
import jsonlines
import torch
@@ -9,11 +10,12 @@
from tqdm import tqdm
from urartu.common.action import Action
from urartu.common.dataset import Dataset
from llm_roleplay.common.model import Model

from llm_roleplay.common.persona import Persona


class Roleplay(Action):
class DialogueGenerator(Action):
def __init__(self, cfg: DictConfig, aim_run: Run) -> None:
super().__init__(cfg, aim_run)

@@ -24,7 +26,7 @@ def track(self, prompt, name, context=None):
context=context,
)

def main(self):
def initialize(self):
self.aim_run["num_no_prompts"] = 0
self.aim_run["num_multiple_prompts"] = 0
self.aim_run["num_non_coherent"] = 0
@@ -33,56 +35,50 @@ def main(self):
self.aim_run["num_non_coherent_model_responder"] = 0
self.aim_run["personas"] = {}

task_cfg = self.action_cfg.task
self.task_cfg = self.action_cfg.task

records_dir = Path(self.action_cfg.workdir).joinpath(
self.records_dir = Path(self.action_cfg.workdir).joinpath(
"dialogs",
f"{task_cfg.model_inquirer.name.split('/')[-1]}",
f"{self.task_cfg.model_inquirer.name.split('/')[-1]}",
str(self.aim_run.hash),
)
os.makedirs(records_dir, exist_ok=True)
os.makedirs(self.records_dir, exist_ok=True)

dataset = Dataset.get_dataset(task_cfg.dataset)
personas = Persona.get_personas(task_cfg.persona)
self.dataset = Dataset.get_dataset(self.task_cfg.dataset)
self.personas = Persona.get_personas(self.task_cfg.persona)

model_inquirer = hydra.utils.instantiate(
task_cfg.model_inquirer.type, task_cfg.model_inquirer, "model_inquirer"
)
model_responder = hydra.utils.instantiate(
task_cfg.model_responder.type, task_cfg.model_responder, "model_responder"
)
self.model_inquirer = Model.get_model(self.task_cfg.model_inquirer, role="model_inquirer")
self.model_responder = Model.get_model(self.task_cfg.model_responder, role="model_responder")

model_inquirer.spec_tokens = task_cfg.spec_tokens
model_responder.spec_tokens = task_cfg.spec_tokens
model_inquirer.aim_run = self.aim_run
model_responder.aim_run = self.aim_run
self.model_inquirer.spec_tokens = self.task_cfg.spec_tokens
self.model_responder.spec_tokens = self.task_cfg.spec_tokens
self.model_inquirer.aim_run = self.aim_run
self.model_responder.aim_run = self.aim_run

for idx, sample in tqdm(
enumerate(dataset.dataset), total=len(dataset.dataset), desc="samples"
):
for persona, persona_hash in tqdm(personas, desc="personas", leave=False):
def generate(self) -> Path:
for idx, sample in tqdm(enumerate(self.dataset.dataset), total=len(self.dataset.dataset), desc="samples"):
for persona, persona_hash in tqdm(self.personas, desc="personas", leave=False):
self.aim_run["personas"][persona_hash] = persona

model_inquirer.history = []
model_responder.history = []
self.model_inquirer.history = []
self.model_responder.history = []
dialog = []
raw_dialog = []

instructions = [
instruct.lstrip().rstrip()
for instruct in sample[task_cfg.dataset.input_key].split("\n")
instruct.lstrip().rstrip() for instruct in sample[self.task_cfg.dataset.input_key].split("\n")
]

if self.action_cfg.task.model_inquirer.regenerate_tries:
regeneratinon_idx = 0
inquirer_generate_cfg = None
responder_output = None
turn = 0
with tqdm(total=task_cfg.num_turns, desc="turns", leave=False) as pbar:
while turn < task_cfg.num_turns:
with tqdm(total=self.task_cfg.num_turns, desc="turns", leave=False) as pbar:
while turn < self.task_cfg.num_turns:
pbar.set_postfix(turn=turn + 1)
# ------------------------------------------ Model A ------------------------------------------
inquirer_prompt = model_inquirer.get_prompt(
# ------------------------------------------ Inquirer Model ------------------------------------------
inquirer_prompt = self.model_inquirer.get_prompt(
turn=turn,
response_msg=responder_output,
persona=persona,
@@ -98,7 +94,7 @@ def main(self):
"persona_hash": persona_hash,
},
)
inquirer_output, _ = model_inquirer.generate(
inquirer_output, _ = self.model_inquirer.generate(
prompt=inquirer_prompt,
generate_cfg=(
inquirer_generate_cfg
@@ -119,28 +115,23 @@
)

# --------------------- if model_inquirer failed to provide coherent text ---------------------
if model_inquirer.is_non_coherent(inquirer_output):
if self.model_inquirer.is_non_coherent(inquirer_output):
self.aim_run["num_non_coherent"] += 1
break

# --------------------- if model_inquirer wants to stop the dialog ---------------------
if model_inquirer.stop_dialog(inquirer_output):
if self.model_inquirer.stop_dialog(inquirer_output):
break

inquirer_output_extract, num_prompts = (
model_inquirer.extract_prompt(prompt=inquirer_output)
inquirer_output_extract, num_prompts = self.model_inquirer.extract_prompt(
prompt=inquirer_output
)

if self.action_cfg.task.model_inquirer.regenerate_tries:
# --------------------- if model_inquirer failed to provide prompt ---------------------
if inquirer_output_extract is None:
if (
regeneratinon_idx
< self.action_cfg.task.model_inquirer.regenerate_tries
):
inquirer_generate_cfg = (
model_inquirer.get_generation_cfg()
)
if regeneratinon_idx < self.action_cfg.task.model_inquirer.regenerate_tries:
inquirer_generate_cfg = self.model_inquirer.get_generation_cfg()
regeneratinon_idx += 1
continue
else:
@@ -169,14 +160,14 @@ def main(self):

# As the context for model_inquirer grows much faster, it starts answering its own questions.
# To prevent this, keep only the output prompt (the thing model_responder will see) in the inquirer history.
model_inquirer.update_history(
self.model_inquirer.update_history(
prompt=inquirer_prompt,
output_extract=inquirer_output_extract,
)

# ------------------------------------------ Model B ------------------------------------------
# ------------------------------------------ Responder Model ------------------------------------------

responder_prompt = model_responder.get_prompt(
responder_prompt = self.model_responder.get_prompt(
turn=turn, response_msg=inquirer_output_extract
)

@@ -189,11 +180,9 @@ def main(self):
"persona_hash": persona_hash,
},
)
responder_output, responder_model_output_template = (
model_responder.generate(
prompt=responder_prompt,
generate_cfg=self.action_cfg.task.model_responder.generate,
)
responder_output, responder_model_output_template = self.model_responder.generate(
prompt=responder_prompt,
generate_cfg=self.action_cfg.task.model_responder.generate,
)
if not responder_output:
break
@@ -208,16 +197,16 @@ def main(self):
)

# --------------------- if model_responder failed to provide coherent text ---------------------
if model_responder.is_non_coherent(responder_output):
if self.model_responder.is_non_coherent(responder_output):
self.aim_run["num_non_coherent_model_responder"] += 1
break

model_responder.update_history(
self.model_responder.update_history(
prompt=responder_prompt,
output_extract=responder_model_output_template,
)

# --------------------------------------- Save the dialog ---------------------------------------
# --------------------------------------- Saving the dialogue ---------------------------------------
dialog.append(
{
"turn": turn,
@@ -232,9 +221,7 @@ def main(self):
turn += 1
pbar.update(1)

with jsonlines.open(
records_dir.joinpath(f"{self.cfg.seed}.jsonl"), mode="a"
) as writer:
with jsonlines.open(self.records_dir.joinpath(f"{self.cfg.seed}.jsonl"), mode="a") as writer:
writer.write(
{
"persona": persona,
@@ -244,7 +231,11 @@ def main(self):
}
)

return self.records_dir


def main(cfg: DictConfig, aim_run: Run):
roleplay = Roleplay(cfg, aim_run)
roleplay.main()
dialogue_generator = DialogueGenerator(cfg, aim_run)
dialogue_generator.initialize()
dialogues_dir = dialogue_generator.generate()
logging.info(f"Dialogues succesfully generated and stored in: {dialogues_dir}")
11 changes: 8 additions & 3 deletions llm_roleplay/common/model.py
@@ -3,6 +3,7 @@
import random
import re
import string
import hydra
from typing import Any, Dict, List

from urartu.common.device import Device
@@ -18,10 +19,14 @@ def __init__(self, cfg: List[Dict[str, Any]], role=None):
self.tokenizer = None
self.role = role
self.history = []
self._load_model()
self._get_model()

def _load_model(self):
raise NotImplementedError("method '_load_model' is not implemented")
@staticmethod
def get_model(cfg, role):
return hydra.utils.instantiate(cfg.type, cfg, role)

def _get_model(self):
raise NotImplementedError("method '_get_model' is not implemented")

def get_prompt(self, turn, response_msg, persona=None, instructions=None):
raise NotImplementedError("method 'get_prompt' is not implemented")
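`Model` now exposes a static `get_model` factory that defers to `hydra.utils.instantiate`, so callers such as `DialogueGenerator` no longer build Hydra targets inline, and subclasses implement `_get_model` instead of `_load_model`. In essence the factory resolves the dotted `_target_` path from the config and calls it; the sketch below is a simplified, stand-alone approximation of that mechanism, not the Hydra implementation, and it uses a standard-library target so it runs without the package installed.

```python
import importlib


def instantiate(target: str, *args, **kwargs):
    """Simplified stand-in for hydra.utils.instantiate: resolve a dotted path and call it."""
    module_path, attr_name = target.rsplit(".", 1)
    return getattr(importlib.import_module(module_path), attr_name)(*args, **kwargs)


# The configs in this commit point at targets such as
# "llm_roleplay.models.causal_lm_model.CausalLMModel"; a standard-library target
# keeps this sketch runnable anywhere:
print(instantiate("collections.Counter", "roleplay"))
```
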
@@ -1,5 +1,5 @@
# @package _global_
action_name: generate_dialogues
action_name: dialogue_generator
seed: 5

action_config:
@@ -15,7 +15,7 @@ action_config:

dataset:
type:
_target_: roleplay.datasets.hf_datasets.HFDatasets
_target_: llm_roleplay.datasets.hf_datasets.HFDatasets
input_key: "instruction"
data:
instruction:
@@ -1,5 +1,5 @@
type:
_target_: roleplay.models.causal_lm_model.CausalLMModel
_target_: llm_roleplay.models.causal_lm_model.CausalLMModel
name: "tiiuae/falcon-40b-instruct"
cache_dir: ""
dtype: torch.float16
@@ -1,5 +1,5 @@
type:
_target_: roleplay.models.openai_model.OpenAIModel
_target_: llm_roleplay.models.openai_model.OpenAIModel
openai_api_type: "azure"
openai_api_version: "2023-05-15"
azure_openai_endpoint: null
@@ -1,5 +1,5 @@
type:
_target_: roleplay.models.openai_model.OpenAIModel
_target_: llm_roleplay.models.openai_model.OpenAIModel
openai_api_type: "azure"
openai_api_version: "2023-05-15"
azure_openai_endpoint: null
@@ -1,5 +1,5 @@
type:
_target_: roleplay.models.causal_lm_model.CausalLMModel
_target_: llm_roleplay.models.causal_lm_model.CausalLMModel
name: "meta-llama/Llama-2-13b-chat-hf"
cache_dir: ""
dtype: torch.float16
@@ -1,6 +1,7 @@
type:
_target_: roleplay.models.causal_lm_model.CausalLMModel
_target_: llm_roleplay.models.causal_lm_model.CausalLMModel
name: "mistralai/Mixtral-8x7B-Instruct-v0.1"
role: "inquirer"
cache_dir: ""
dtype: torch.float16
non_coherent_max_n: 4
@@ -1,5 +1,5 @@
type:
_target_: roleplay.models.causal_lm_model.CausalLMModel
_target_: llm_roleplay.models.causal_lm_model.CausalLMModel
name: "lmsys/vicuna-13b-v1.5-16k"
cache_dir: ""
dtype: torch.float16
@@ -1,5 +1,5 @@
type:
_target_: roleplay.models.pipeline_model.PipelineModel
_target_: llm_roleplay.models.pipeline_model.PipelineModel
name: "models--llama-2-hf/13B-Chat"
cache_dir: ""
dtype: torch.float16
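The remaining config diffs are the same mechanical rename: every `_target_` now lives under the `llm_roleplay` package instead of `roleplay`, so Hydra (via the new `Model.get_model` factory) resolves the model classes from their new location. Below is an illustrative sketch of the shape of such a config node, with values copied from the Mixtral inquirer config above; the actual `Model.get_model` call is left commented out because it would load real model weights.

```python
from omegaconf import OmegaConf

# Shape of the renamed config nodes above; values taken from the inquirer config.
inquirer_cfg = OmegaConf.create({
    "type": {"_target_": "llm_roleplay.models.causal_lm_model.CausalLMModel"},
    "name": "mistralai/Mixtral-8x7B-Instruct-v0.1",
    "cache_dir": "",
    "dtype": "torch.float16",
})

# from llm_roleplay.common.model import Model
# inquirer = Model.get_model(inquirer_cfg, role="model_inquirer")  # would download weights
```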
