Decouple the prompt set type (practice, official) from its file name (#801)

* decouple the file name from the prompt type

* use new prompt set logic

* decouple prompt set type and locale from file name

* use new prompt set logic that decouples the prompt set type and locale from its filename

* noop; formatting

* change prompt set key for consistency

* inject the prompt sets instead of using the global prompt set dict, for better testing

* remove prompt_sets from where it doesn't belong conceptually
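
For context, a minimal sketch (not part of this commit; the file names are the ones that appear in the diff below) of what the decoupling means in practice. Previously the locale was appended to a fixed base file name at download time; now each prompt set key maps directly to its own base file name, so the key no longer has to mirror the file naming convention.

# Before: one base name per prompt set type, with the locale appended
# in __localize_filename() whenever the locale was not en_us.
old_base = "airr_official_1.0_practice_prompt_set_release"
old_fr_file = f"{old_base}_fr_fr"  # ..._prompt_set_release_fr_fr

# After: each prompt set key (type plus optional locale) maps straight to
# its base file name, decoupling type and locale from how the file is named.
PROMPT_SETS = {
    "practice": "airr_official_1.0_practice_prompt_set_release",
    "official": "airr_official_1.0_heldback_prompt_set_release",
    "practice_fr_fr": "airr_official_1.0_practice_fr_fr_prompt_set_release",
}
new_fr_file = PROMPT_SETS["practice_fr_fr"]  # ..._practice_fr_fr_prompt_set_release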
rogthefrog authored Jan 21, 2025
1 parent 5943f58 commit 9337d40
Showing 12 changed files with 113 additions and 89 deletions.
4 changes: 3 additions & 1 deletion src/modelbench/benchmarks.py
@@ -5,8 +5,9 @@
from typing import List, Sequence

import casefy
from modelgauge.sut import PromptResponseSUT
from modelgauge.locales import DEFAULT_LOCALE, validate_locale
from modelgauge.prompt_sets import validate_prompt_set
from modelgauge.sut import PromptResponseSUT

from modelbench.hazards import HazardDefinition, HazardScore, SafeHazardV1, Standards, STANDARDS
from modelbench.scoring import LetterGradeMixin, score_to_ordinal_grade
@@ -134,6 +135,7 @@ class GeneralPurposeAiChatBenchmarkV1(BenchmarkDefinition):

def __init__(self, locale: str, prompt_set: str, evaluator: str = "default"):
validate_locale(locale)
validate_prompt_set(prompt_set)
self.locale = locale
self.prompt_set = prompt_set
self.evaluator = evaluator
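Illustrative sketch (not part of the diff): with validate_prompt_set() called in the constructor, an unknown prompt set now fails fast when the benchmark is built rather than later, when its tests are instantiated.

from modelbench.benchmarks import GeneralPurposeAiChatBenchmarkV1

GeneralPurposeAiChatBenchmarkV1("en_us", "practice")   # passes validate_locale and validate_prompt_set
GeneralPurposeAiChatBenchmarkV1("en_us", "heldback")   # raises ValueError: not a key in PROMPT_SETS
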
2 changes: 2 additions & 0 deletions src/modelbench/hazards.py
@@ -8,6 +8,7 @@
import casefy
from modelgauge.base_test import BaseTest, PromptResponseTest
from modelgauge.locales import display_for, validate_locale
from modelgauge.prompt_sets import validate_prompt_set
from modelgauge.records import TestRecord
from modelgauge.secret_values import RawSecrets
from modelgauge.test_registry import TESTS
@@ -67,6 +68,7 @@ class SafeHazardV1(HazardDefinition, ABC):
def __init__(self, hazard_key: str, locale: str, prompt_set: str, evaluator="default"):
assert hazard_key in self.all_hazard_keys, f"Unknown hazard key {hazard_key}."
validate_locale(locale)
validate_prompt_set(prompt_set)
self.hazard_key = hazard_key
self.locale = locale
self.prompt_set = prompt_set
13 changes: 7 additions & 6 deletions src/modelbench/run.py
@@ -21,11 +21,11 @@
from click import echo
from modelgauge.config import load_secrets_from_config, raise_if_missing_from_config, write_default_config
from modelgauge.load_plugins import load_plugins
from modelgauge.locales import DEFAULT_LOCALE, EN_US, LOCALES, validate_locale
from modelgauge.locales import DEFAULT_LOCALE, LOCALES, PUBLISHED_LOCALES, validate_locale
from modelgauge.prompt_sets import PROMPT_SETS, validate_prompt_set
from modelgauge.sut import SUT
from modelgauge.sut_decorator import modelgauge_sut
from modelgauge.sut_registry import SUTS
from modelgauge.tests.safe_v1 import PROMPT_SETS
from rich.console import Console
from rich.table import Table

@@ -230,9 +230,10 @@ def ensure_ensemble_annotators_loaded():
return False


def get_benchmark(version: str, locale: str, prompt_set: str, evaluator) -> BenchmarkDefinition:
def get_benchmark(version: str, locale: str, prompt_set: str, evaluator: str = "default") -> BenchmarkDefinition:
assert version == "1.0", ValueError(f"Version {version} is not supported.")
validate_locale(locale)
validate_prompt_set(prompt_set)
# TODO: Should probably also check that user has all the necessary secrets here e.g. can they run "official"?
if evaluator == "ensemble":
if not ensure_ensemble_annotators_loaded():
@@ -345,9 +346,9 @@ def update_standards_to(standards_file):
exit(1)

benchmarks = []
for l in [EN_US]:
for prompt_set in PROMPT_SETS:
benchmarks.append(GeneralPurposeAiChatBenchmarkV1(l, prompt_set, "ensemble"))
for locale in PUBLISHED_LOCALES:
for prompt_set in PROMPT_SETS.keys():
benchmarks.append(GeneralPurposeAiChatBenchmarkV1(locale, prompt_set, "ensemble"))
run_result = run_benchmarks_for_suts(benchmarks, reference_suts, None)
all_hazard_numeric_scores = defaultdict(list)
for _, scores_by_sut in run_result.benchmark_scores.items():
2 changes: 2 additions & 0 deletions src/modelgauge/locales.py
@@ -9,6 +9,8 @@

# add the other languages after we have official and practice prompt sets
LOCALES = (EN_US, FR_FR)
# all the languages we have official and practice prompt sets for
PUBLISHED_LOCALES = (EN_US,)


def is_valid(locale: str) -> bool:
33 changes: 33 additions & 0 deletions src/modelgauge/prompt_sets.py
@@ -0,0 +1,33 @@
from modelgauge.secret_values import RequiredSecret, SecretDescription


class ModellabFileDownloadToken(RequiredSecret):
@classmethod
def description(cls) -> SecretDescription:
return SecretDescription(
scope="modellab_files",
key="token",
instructions="Please ask MLCommons admin for permission.",
)


# file name format:
# {prefix}_{version}_{type}(_{locale})_prompt_set_release
PROMPT_SETS = {
"practice": "airr_official_1.0_practice_prompt_set_release",
"official": "airr_official_1.0_heldback_prompt_set_release",
"practice_fr_fr": "airr_official_1.0_practice_fr_fr_prompt_set_release",
}
PROMPT_SET_DOWNLOAD_HOST = "ailuminate.mlcommons.org"


def prompt_set_file_base_name(prompt_set: str, prompt_sets: dict = PROMPT_SETS) -> str:
filename = prompt_sets.get(prompt_set, None)
return filename


def validate_prompt_set(prompt_set: str, prompt_sets: dict = PROMPT_SETS) -> bool:
filename = prompt_set_file_base_name(prompt_set, prompt_sets)
if not filename:
raise ValueError(f"Invalid prompt set {prompt_set}. Must be one of {prompt_sets.keys()}.")
return True
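
A brief usage sketch of the new module (illustrative only; the fake_sets mapping is hypothetical):

from modelgauge.prompt_sets import PROMPT_SETS, prompt_set_file_base_name, validate_prompt_set

prompt_set_file_base_name("practice")   # -> "airr_official_1.0_practice_prompt_set_release"
prompt_set_file_base_name("bad")        # -> None

# The prompt_sets mapping is injectable, which keeps the helpers easy to test.
fake_sets = {"demo": "demo_prompt_set_release"}  # hypothetical mapping for illustration
validate_prompt_set("demo", fake_sets)  # -> True
validate_prompt_set("bogus")            # raises ValueError listing the valid keys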
58 changes: 16 additions & 42 deletions src/modelgauge/tests/safe_v1.py
@@ -11,25 +11,21 @@
from modelgauge.external_data import ExternalData, WebData
from modelgauge.locales import EN_US, LOCALES, validate_locale
from modelgauge.prompt import SUTOptions, TextPrompt
from modelgauge.secret_values import InjectSecret, RequiredSecret, SecretDescription
from modelgauge.prompt_sets import ( # usort: skip
PROMPT_SET_DOWNLOAD_HOST,
PROMPT_SETS,
ModellabFileDownloadToken,
prompt_set_file_base_name,
validate_prompt_set,
)
from modelgauge.secret_values import InjectSecret
from modelgauge.single_turn_prompt_response import MeasuredTestItem, PromptWithContext, TestItem, TestItemAnnotations
from modelgauge.sut_capabilities import AcceptsTextPrompt
from modelgauge.test_decorator import modelgauge_test
from modelgauge.test_registry import TESTS
from pydantic import BaseModel


# TODO: agree on a standard for this
# The strings in PROMPT_SETS are additionally decorated with the locale string in lowercase
# upon download. This is to support downloading the "same" prompt set for different locales
# and keep all of them downloaded locally.
PROMPT_SETS = {
"practice": "airr_official_1.0_practice_prompt_set_release",
"official": "airr_official_1.0_heldback_prompt_set_release",
}
PROMPT_SET_DOWNLOAD_HOST = "ailuminate.mlcommons.org"


class SafeTestItemContext(BaseModel):
persona_type: str

@@ -43,16 +39,6 @@ class SafeTestResult(BaseModel):
persona_results: Mapping[str, PersonaResult]


class ModellabFileDownloadToken(RequiredSecret):
@classmethod
def description(cls) -> SecretDescription:
return SecretDescription(
scope="modellab_files",
key="token",
instructions="Please ask MLCommons admin for permission.",
)


class SafePersonasVersion1(str, Enum):
NORMAL = "normal"
SKILLED = "skilled"
@@ -98,7 +84,7 @@ class BaseSafeTestVersion1(PromptResponseTest, ABC):

@staticmethod
def create_uid(hazard: str, locale: str, prompt_set: str, evaluator=None):
assert prompt_set in PROMPT_SETS
validate_prompt_set(prompt_set)
validate_locale(locale)
if evaluator is None or evaluator == "default":
postfix = ""
@@ -123,13 +109,13 @@ def __init__(
assert len(set(persona_types)) == len(
persona_types
), f"Must specify a unique set of persona types, but got {persona_types}"
assert prompt_set in PROMPT_SETS, f"Invalid prompt set {prompt_set}. Must be one of {PROMPT_SETS.keys()}."
validate_prompt_set(prompt_set)
validate_locale(locale)

self.hazard = hazard
self.locale = locale
self.persona_types = persona_types
self.prompt_set_file_name = PROMPT_SETS[prompt_set]
self.prompt_set_file_base_name = prompt_set_file_base_name(prompt_set)
self.token = token

@classmethod
@@ -138,23 +124,14 @@ def _check_annotators(cls):
if not hasattr(cls, "annotators") or cls.annotators is None:
raise NotImplementedError("Concrete SafeTestVersion1 classes must set class-attribute `annotators`.")

def __localize_filename(self) -> str:
# TODO implement a better standard for prompt file names by locale
if self.locale != EN_US:
tail = f"_{self.locale}"
else:
tail = ""
filename = f"{self.prompt_set_file_name}{tail}"
return filename

def get_dependencies(self) -> Mapping[str, ExternalData]:
prompt_set_base_download_url = f"https://{PROMPT_SET_DOWNLOAD_HOST}/files/download"
headers = None
if self.token is not None:
headers = {"auth-token": self.token.value}
source_url = f"{prompt_set_base_download_url}/{self.__localize_filename()}.csv"
source_url = f"{prompt_set_base_download_url}/{self.prompt_set_file_base_name}.csv"
return {
self.prompt_set_file_name: WebData(
self.prompt_set_file_base_name: WebData(
source_url=source_url,
headers=headers,
)
@@ -171,7 +148,7 @@ def make_test_items(self, dependency_helper: DependencyHelper) -> List[TestItem]
temperature=0.01,
)
test_items: List[TestItem] = []
data_file = dependency_helper.get_local_path(self.prompt_set_file_name)
data_file = dependency_helper.get_local_path(self.prompt_set_file_base_name)

# We're not using a context manager here for better exception handling,
# so the operator has some insight into why things aren't working.
@@ -247,12 +224,9 @@ class SafeTestVersion1(BaseSafeTestVersion1):
def register_tests(cls, evaluator=None):
for locale in LOCALES:
for hazard in cls.hazards:
for prompt_set in PROMPT_SETS:
for prompt_set in PROMPT_SETS.keys():
test_uid = BaseSafeTestVersion1.create_uid(hazard, locale, prompt_set, evaluator)
token = None
# only practice prompt sets in English are publicly available for now
if prompt_set == "official" or locale != EN_US:
token = InjectSecret(ModellabFileDownloadToken)
token = InjectSecret(ModellabFileDownloadToken)
TESTS.register(cls, test_uid, hazard, locale, ALL_PERSONAS, prompt_set, token)


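Illustrative sketch (values taken from the diff above): the dependency download URL is now assembled directly from the decoupled base name, with no locale suffix appended afterwards.

from modelgauge.prompt_sets import PROMPT_SET_DOWNLOAD_HOST, prompt_set_file_base_name

base = f"https://{PROMPT_SET_DOWNLOAD_HOST}/files/download"
url = f"{base}/{prompt_set_file_base_name('practice_fr_fr')}.csv"
# -> https://ailuminate.mlcommons.org/files/download/airr_official_1.0_practice_fr_fr_prompt_set_release.csv
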
25 changes: 6 additions & 19 deletions tests/modelbench_tests/test_benchmark.py
@@ -5,29 +5,16 @@

import pytest

from modelbench.benchmarks import (
BenchmarkDefinition,
BenchmarkScore,
GeneralPurposeAiChatBenchmarkV1,
)
from modelbench.hazards import (
HazardDefinition,
HazardScore,
SafeHazardV1,
STANDARDS,
)
from modelbench.benchmarks import BenchmarkDefinition, BenchmarkScore, GeneralPurposeAiChatBenchmarkV1
from modelbench.hazards import STANDARDS, HazardDefinition, HazardScore, SafeHazardV1 # usort: skip
from modelbench.scoring import ValueEstimate
from modelgauge.base_test import BaseTest
from modelgauge.locales import EN_US

from modelgauge.prompt_sets import PROMPT_SETS
from modelgauge.records import TestRecord
from modelgauge.secret_values import RawSecrets
from modelgauge.tests.safe_v1 import (
PROMPT_SETS,
PersonaResult,
SafeTestResult,
SafeTestVersion1,
SafePersonasVersion1,
)
from modelgauge.tests.safe_v1 import PersonaResult, SafePersonasVersion1, SafeTestResult, SafeTestVersion1


@pytest.mark.parametrize("ai", ("ai", "AI", "aI", "Ai"))
@@ -64,7 +51,7 @@ def test_benchmark_v1_definition_basics(prompt_set, fake_secrets):
assert hazard.hazard_key == hazard_key
assert hazard.locale == EN_US
assert hazard.prompt_set == prompt_set
assert prompt_set in hazard.tests(secrets=fake_secrets)[0].prompt_set_file_name
assert prompt_set in hazard.tests(secrets=fake_secrets)[0].prompt_set_file_base_name


@pytest.mark.parametrize(
2 changes: 1 addition & 1 deletion tests/modelbench_tests/test_record.py
@@ -81,7 +81,7 @@ def test_v1_hazard_definition_without_tests_loaded():

def test_v1_hazard_definition_with_tests_loaded():
hazard = SafeHazardV1("dfm", EN_US, "practice")
hazard.tests({"together": {"api_key": "ignored"}})
hazard.tests({"together": {"api_key": "fake"}, "modellab_files": {"token": "fake"}})
j = encode_and_parse(hazard)
assert j["uid"] == hazard.uid
assert j["tests"] == ["safe-dfm-en_us-practice-1.0"]
24 changes: 13 additions & 11 deletions tests/modelbench_tests/test_run.py
@@ -9,20 +9,16 @@
from click.testing import CliRunner

from modelbench.benchmark_runner import BenchmarkRun, BenchmarkRunner
from modelbench.benchmarks import (
BenchmarkDefinition,
BenchmarkScore,
GeneralPurposeAiChatBenchmarkV1,
)
from modelbench.hazards import HazardScore, HazardDefinition, SafeHazardV1
from modelbench.run import benchmark, cli, find_suts_for_sut_argument, get_benchmark
from modelbench.benchmarks import BenchmarkDefinition, BenchmarkScore, GeneralPurposeAiChatBenchmarkV1
from modelbench.hazards import HazardDefinition, HazardScore, SafeHazardV1
from modelbench.run import cli, find_suts_for_sut_argument, get_benchmark
from modelbench.scoring import ValueEstimate
from modelgauge.base_test import PromptResponseTest
from modelgauge.locales import DEFAULT_LOCALE, EN_US, FR_FR, LOCALES, display_for, is_valid
from modelgauge.locales import DEFAULT_LOCALE, EN_US, FR_FR, LOCALES
from modelgauge.prompt_sets import PROMPT_SETS
from modelgauge.records import TestRecord
from modelgauge.secret_values import RawSecrets
from modelgauge.sut import PromptResponseSUT
from modelgauge.tests.safe_v1 import PROMPT_SETS

from modelgauge_tests.fake_sut import FakeSUT

@@ -138,7 +134,10 @@ def test_benchmark_basic_run_produces_json(
if prompt_set is not None:
benchmark_options.extend(["--prompt-set", prompt_set])
benchmark = get_benchmark(
version, locale if locale else DEFAULT_LOCALE, prompt_set if prompt_set else "practice", "default"
version,
locale if locale else DEFAULT_LOCALE,
prompt_set if prompt_set else "practice",
"default",
)
command_options = [
"benchmark",
@@ -173,7 +172,10 @@ def test_benchmark_multiple_suts_produces_json(self, runner, version, locale, pr
if prompt_set is not None:
benchmark_options.extend(["--prompt-set", prompt_set])
benchmark = get_benchmark(
version, locale if locale else DEFAULT_LOCALE, prompt_set if prompt_set else "practice", "default"
version,
locale if locale else DEFAULT_LOCALE,
prompt_set if prompt_set else "practice",
"default",
)

mock = MagicMock(return_value=[self.mock_score("fake-2", benchmark), self.mock_score("fake-2", benchmark)])
3 changes: 2 additions & 1 deletion tests/modelgauge_tests/fake_dependency_helper.py
@@ -1,9 +1,10 @@
import csv
import io
import os
from modelgauge.dependency_helper import DependencyHelper
from typing import List, Mapping

from modelgauge.dependency_helper import DependencyHelper


class FakeDependencyHelper(DependencyHelper):
"""Test version of Dependency helper that lets you set the text in files.
19 changes: 19 additions & 0 deletions tests/modelgauge_tests/test_prompt_sets.py
@@ -0,0 +1,19 @@
import pytest
from modelgauge.prompt_sets import (
PROMPT_SETS,
prompt_set_file_base_name,
validate_prompt_set,
) # usort: skip


def test_file_base_name():
assert prompt_set_file_base_name("bad") is None
assert prompt_set_file_base_name("practice") == PROMPT_SETS["practice"]
assert prompt_set_file_base_name("practice", PROMPT_SETS) == PROMPT_SETS["practice"]


def test_validate_prompt_set():
for s in PROMPT_SETS.keys():
assert validate_prompt_set(s, PROMPT_SETS)
with pytest.raises(ValueError):
validate_prompt_set("should raise")