Skip to content

Commit

Permalink
(testing) Add unit tests for lab generate (NVIDIA#715)
Browse files Browse the repository at this point in the history
Closes NVIDIA#259

Signed-off-by: Anik Bhattacharjee <anbhatta@redhat.com>
  • Loading branch information
anik120 authored Mar 27, 2024
1 parent 19e74e3 commit 19b9f47
Show file tree
Hide file tree
Showing 6 changed files with 323 additions and 90 deletions.
6 changes: 4 additions & 2 deletions cli/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@
DEFAULT_GENERATED_FILES_OUTPUT_DIR = "generated"
DEFAULT_GREEDY_MODE = False
DEFAULT_YAML_RULES = "yaml_rules.yaml"
DEFAULT_NUM_CPUS = 10
DEFAULT_NUM_INSTRUCTIONS = 100


class ConfigException(Exception):
Expand Down Expand Up @@ -156,8 +158,8 @@ def get_default_config():
)
generate = _generate(
model=DEFAULT_MODEL,
num_cpus=10,
num_instructions=100,
num_cpus=DEFAULT_NUM_CPUS,
num_instructions=DEFAULT_NUM_INSTRUCTIONS,
taxonomy_path=DEFAULT_TAXONOMY_PATH,
taxonomy_base=DEFAULT_TAXONOMY_BASE,
output_dir=DEFAULT_GENERATED_FILES_OUTPUT_DIR,
Expand Down
91 changes: 17 additions & 74 deletions cli/generator/generate_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,17 +15,12 @@
# Third Party
from jinja2 import Template
from rouge_score import rouge_scorer
import gitdb
import click
import tqdm
import yaml

try:
# Third Party
import git
except ImportError:
pass

# Local
from ..utils import get_taxonomy_diff
from . import utils

DEFAULT_PROMPT_TEMPLATE = """\
Expand Down Expand Up @@ -325,7 +320,6 @@ def generate_data(
output_dir: Optional[str] = None,
taxonomy: Optional[str] = None,
taxonomy_base: Optional[str] = None,
seed_tasks_path: Optional[str] = None,
prompt_file_path: Optional[str] = None,
model_name: Optional[str] = None,
num_cpus: Optional[int] = None,
Expand Down Expand Up @@ -355,10 +349,7 @@ def generate_data(
raise SystemExit(f"Error: taxonomy ({taxonomy}) does not exist.")

seeds = len(seed_instruction_data)
logger.debug(
f"Loaded {seeds} human-written seed instructions from "
f"{taxonomy or seed_tasks_path}"
)
logger.debug(f"Loaded {seeds} human-written seed instructions from {taxonomy}")
if not seeds:
raise SystemExit("Nothing to generate. Exiting.")

Expand All @@ -370,13 +361,20 @@ def unescape(s):
user = seed_example["instruction"]
if len(seed_example["input"]) > 0:
user += "\n" + seed_example["input"]
test_data.append(
{
"system": utils.SYSTEM_PROMPT,
"user": unescape(user),
"assistant": unescape(seed_example["output"]),
}
)
try:
test_data.append(
{
"system": utils.SYSTEM_PROMPT,
"user": unescape(user),
"assistant": unescape(seed_example["output"]),
}
)
except TypeError as exc:
click.secho(
f"Error reading seed examples: {exc}. Please make sure your answers are verbose enough.",
fg="red",
)
raise click.exceptions.Exit(1)

name = Path(model_name).stem # Just in case it is a file path
date_suffix = datetime.now().replace(microsecond=0).isoformat().replace(":", "_")
Expand Down Expand Up @@ -524,61 +522,6 @@ def unescape(s):
logger.info(f"Generation took {generate_duration:.2f}s")


def istaxonomyfile(fn):
    """Report whether ``fn`` names a taxonomy YAML file.

    A path qualifies when it ends in ``.yaml`` and its first
    "/"-separated component is ``compositional_skills`` or ``knowledge``.
    """
    # Everything before the first "/" is the top-level directory; a path
    # without any "/" is treated as its own top-level component.
    root, _, _ = fn.partition("/")
    return fn.endswith(".yaml") and root in ("compositional_skills", "knowledge")


def get_taxonomy_diff(repo="taxonomy", base="origin/main"):
    """Return taxonomy files that are untracked or changed relative to ``base``.

    Args:
        repo: Path to the taxonomy git repository.
        base: What to diff against. A value containing "/" is treated as a
            remote branch ("<remote>/<branch>"), a value matching a local
            branch name is treated as that branch, and anything else is
            resolved as a git ref (e.g. a commit SHA or tag).

    Returns:
        A de-duplicated list (in no particular order) of paths, as reported
        by git, for untracked and modified/added files that satisfy
        ``istaxonomyfile``.

    Raises:
        SystemExit: If ``base`` cannot be resolved as a git ref, or if no
            first-parent ancestor of HEAD is contained in the base branch.
    """
    repo = git.Repo(repo)
    untracked_files = [u for u in repo.untracked_files if istaxonomyfile(u)]

    branches = [b.name for b in repo.branches]

    head_commit = None
    re_git_branch = None
    # The ref name is escaped so that regex metacharacters that are legal in
    # git ref names (".", "+", ...) are matched literally.
    if "/" in base:
        re_git_branch = re.compile(f"remotes/{re.escape(base)}$", re.MULTILINE)
    elif base in branches:
        re_git_branch = re.compile(f"{re.escape(base)}$", re.MULTILINE)
    else:
        try:
            head_commit = repo.commit(base)
        except gitdb.exc.BadName as exc:
            raise SystemExit(
                yaml.YAMLError(
                    f'Couldn\'t find the taxonomy git ref "{base}" from the current HEAD'
                )
            ) from exc

    # Move backwards from HEAD until we find the first commit that is part of
    # base; the diff is then taken from that commit.
    current_commit = repo.commit("HEAD")
    while not head_commit:
        # One branch name per line, hence the MULTILINE patterns above.
        containing = repo.git.branch("-a", "--contains", current_commit.hexsha)
        if re_git_branch.findall(containing):
            head_commit = current_commit
            break
        try:
            # Only the first parent is followed.
            current_commit = current_commit.parents[0]
        except IndexError as exc:
            raise SystemExit(
                yaml.YAMLError(
                    f'Couldn\'t find the taxonomy base branch "{base}" from the current HEAD'
                )
            ) from exc

    # diff(None) compares the commit against the working tree.
    modified_files = [
        d.b_path
        for d in head_commit.diff(None)
        if not d.deleted_file and istaxonomyfile(d.b_path)
    ]

    updated_taxonomy_files = list(set(untracked_files + modified_files))
    return updated_taxonomy_files


# pylint: disable=broad-exception-caught
def read_taxonomy_file(logger, file_path, yaml_rules: Optional[str] = None):
# pylint: disable=C0415
Expand Down
34 changes: 21 additions & 13 deletions cli/lab.py
Original file line number Diff line number Diff line change
Expand Up @@ -312,21 +312,29 @@ def serve(ctx, model_path, gpu_layers, num_threads, max_ctx_size):
@cli.command()
@click.option(
"--model",
default=config.DEFAULT_MODEL,
show_default=True,
help="Name of the model used during generation.",
)
@click.option(
"--num-cpus",
type=click.INT,
help="Number of processes to use. Defaults to 10.",
help="Number of processes to use.",
default=config.DEFAULT_NUM_CPUS,
show_default=True,
)
@click.option(
"--num-instructions",
type=click.INT,
help="Number of instructions to generate. Defaults to 100.",
help="Number of instructions to generate.",
default=config.DEFAULT_NUM_INSTRUCTIONS,
show_default=True,
)
@click.option(
"--taxonomy-path",
type=click.Path(),
default=config.DEFAULT_TAXONOMY_PATH,
show_default=True,
help=f"Path to {config.DEFAULT_TAXONOMY_REPO} clone or local file path.",
)
@click.option(
Expand All @@ -338,13 +346,9 @@ def serve(ctx, model_path, gpu_layers, num_threads, max_ctx_size):
@click.option(
"--output-dir",
type=click.Path(),
default=config.DEFAULT_GENERATED_FILES_OUTPUT_DIR,
help="Path to output generated files.",
)
@click.option(
"--seed-file",
type=click.Path(),
help="Path to a seed file.",
)
@click.option(
"--rouge-threshold",
type=click.FLOAT,
Expand Down Expand Up @@ -384,7 +388,6 @@ def generate(
taxonomy_path,
taxonomy_base,
output_dir,
seed_file,
rouge_threshold,
quiet,
endpoint_url,
Expand All @@ -399,6 +402,12 @@ def generate(
from .server import ensure_server

server_process = None
logger = logging.getLogger("TODO")
prompt_file_path = config.DEFAULT_PROMPT_FILE
if ctx.obj is not None:
logger = ctx.obj.logger
prompt_file_path = ctx.obj.config.generate.prompt_file

if endpoint_url:
api_base = endpoint_url
else:
Expand All @@ -409,11 +418,11 @@ def generate(
if not api_base:
api_base = ctx.obj.config.serve.api_base()
try:
ctx.obj.logger.info(
f"Generating model '{model}' using {num_cpus} cpus, taxonomy: '{taxonomy_path}' and seed '{seed_file}' against {api_base} server"
click.echo(
f"Generating synthetic data using '{model}' model, taxonomy:'{taxonomy_path}' against {api_base} server"
)
generate_data(
logger=ctx.obj.logger,
logger=logger,
api_base=api_base,
api_key=api_key,
model_name=model,
Expand All @@ -422,8 +431,7 @@ def generate(
taxonomy=taxonomy_path,
taxonomy_base=taxonomy_base,
output_dir=output_dir,
prompt_file_path=ctx.obj.config.generate.prompt_file,
seed_tasks_path=seed_file,
prompt_file_path=prompt_file_path,
rouge_threshold=rouge_threshold,
console_output=not quiet,
yaml_rules=yaml_rules,
Expand Down
59 changes: 59 additions & 0 deletions cli/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,14 @@
import functools
import os
import platform
import re
import subprocess

# Third Party
import click
import git
import gitdb
import yaml


def macos_requirement(echo_func, exit_exception):
Expand Down Expand Up @@ -91,3 +95,58 @@ def lab_check_callback(*args, **kwargs):

lab_check.callback = lab_check_callback
cli.add_command(lab_check)


def istaxonomyfile(fn):
    """Report whether ``fn`` names a taxonomy YAML file.

    A path qualifies when it ends in ``.yaml`` and its first
    "/"-separated component is ``compositional_skills`` or ``knowledge``.
    """
    # Everything before the first "/" is the top-level directory; a path
    # without any "/" is treated as its own top-level component.
    root, _, _ = fn.partition("/")
    return fn.endswith(".yaml") and root in ("compositional_skills", "knowledge")


def get_taxonomy_diff(repo="taxonomy", base="origin/main"):
    """Return taxonomy files that are untracked or changed relative to ``base``.

    Args:
        repo: Path to the taxonomy git repository.
        base: What to diff against. A value containing "/" is treated as a
            remote branch ("<remote>/<branch>"), a value matching a local
            branch name is treated as that branch, and anything else is
            resolved as a git ref (e.g. a commit SHA or tag).

    Returns:
        A de-duplicated list (in no particular order) of paths, as reported
        by git, for untracked and modified/added files that satisfy
        ``istaxonomyfile``.

    Raises:
        SystemExit: If ``base`` cannot be resolved as a git ref, or if no
            first-parent ancestor of HEAD is contained in the base branch.
    """
    repo = git.Repo(repo)
    untracked_files = [u for u in repo.untracked_files if istaxonomyfile(u)]

    branches = [b.name for b in repo.branches]

    head_commit = None
    re_git_branch = None
    # The ref name is escaped so that regex metacharacters that are legal in
    # git ref names (".", "+", ...) are matched literally.
    if "/" in base:
        re_git_branch = re.compile(f"remotes/{re.escape(base)}$", re.MULTILINE)
    elif base in branches:
        re_git_branch = re.compile(f"{re.escape(base)}$", re.MULTILINE)
    else:
        try:
            head_commit = repo.commit(base)
        except gitdb.exc.BadName as exc:
            raise SystemExit(
                yaml.YAMLError(
                    f'Couldn\'t find the taxonomy git ref "{base}" from the current HEAD'
                )
            ) from exc

    # Move backwards from HEAD until we find the first commit that is part of
    # base; the diff is then taken from that commit.
    current_commit = repo.commit("HEAD")
    while not head_commit:
        # One branch name per line, hence the MULTILINE patterns above.
        containing = repo.git.branch("-a", "--contains", current_commit.hexsha)
        if re_git_branch.findall(containing):
            head_commit = current_commit
            break
        try:
            # Only the first parent is followed.
            current_commit = current_commit.parents[0]
        except IndexError as exc:
            raise SystemExit(
                yaml.YAMLError(
                    f'Couldn\'t find the taxonomy base branch "{base}" from the current HEAD'
                )
            ) from exc

    # diff(None) compares the commit against the working tree.
    modified_files = [
        d.b_path
        for d in head_commit.diff(None)
        if not d.deleted_file and istaxonomyfile(d.b_path)
    ]

    updated_taxonomy_files = list(set(untracked_files + modified_files))
    return updated_taxonomy_files
2 changes: 1 addition & 1 deletion tests/taxonomy.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ def create_untracked(self, rel_path: str, contents: Optional[bytes] = None) -> N
"""Create a new untracked file in the repository.
Args:
rel_path (str): Relative path (from repository root) to the file.
rel_path (str): Relative path (from repository root) to the file.
contents (bytes): (optional) Byte string to be written to the file.
"""
assert not Path(rel_path).is_absolute()
Expand Down
Loading

0 comments on commit 19b9f47

Please sign in to comment.