
Decouple the prompt set type (practice, official) from its file name #801

Merged: 8 commits, Jan 21, 2025
4 changes: 3 additions & 1 deletion src/modelbench/benchmarks.py
@@ -5,8 +5,9 @@
from typing import List, Sequence

import casefy
from modelgauge.sut import PromptResponseSUT
from modelgauge.locales import DEFAULT_LOCALE, validate_locale
from modelgauge.prompt_sets import validate_prompt_set
from modelgauge.sut import PromptResponseSUT

from modelbench.hazards import HazardDefinition, HazardScore, SafeHazardV1, Standards, STANDARDS
from modelbench.scoring import LetterGradeMixin, score_to_ordinal_grade
@@ -134,6 +135,7 @@ class GeneralPurposeAiChatBenchmarkV1(BenchmarkDefinition):

def __init__(self, locale: str, prompt_set: str, evaluator: str = "default"):
validate_locale(locale)
validate_prompt_set(prompt_set)
self.locale = locale
self.prompt_set = prompt_set
self.evaluator = evaluator
2 changes: 2 additions & 0 deletions src/modelbench/hazards.py
@@ -8,6 +8,7 @@
import casefy
from modelgauge.base_test import BaseTest, PromptResponseTest
from modelgauge.locales import display_for, validate_locale
from modelgauge.prompt_sets import validate_prompt_set
from modelgauge.records import TestRecord
from modelgauge.secret_values import RawSecrets
from modelgauge.test_registry import TESTS
@@ -67,6 +68,7 @@ class SafeHazardV1(HazardDefinition, ABC):
def __init__(self, hazard_key: str, locale: str, prompt_set: str, evaluator="default"):
assert hazard_key in self.all_hazard_keys, f"Unknown hazard key {hazard_key}."
validate_locale(locale)
validate_prompt_set(prompt_set)
self.hazard_key = hazard_key
self.locale = locale
self.prompt_set = prompt_set
13 changes: 7 additions & 6 deletions src/modelbench/run.py
@@ -21,11 +21,11 @@
from click import echo
from modelgauge.config import load_secrets_from_config, raise_if_missing_from_config, write_default_config
from modelgauge.load_plugins import load_plugins
from modelgauge.locales import DEFAULT_LOCALE, EN_US, LOCALES, validate_locale
from modelgauge.locales import DEFAULT_LOCALE, LOCALES, PUBLISHED_LOCALES, validate_locale
from modelgauge.prompt_sets import PROMPT_SETS, validate_prompt_set
from modelgauge.sut import SUT
from modelgauge.sut_decorator import modelgauge_sut
from modelgauge.sut_registry import SUTS
from modelgauge.tests.safe_v1 import PROMPT_SETS
from rich.console import Console
from rich.table import Table

@@ -230,9 +230,10 @@ def ensure_ensemble_annotators_loaded():
return False


def get_benchmark(version: str, locale: str, prompt_set: str, evaluator) -> BenchmarkDefinition:
def get_benchmark(version: str, locale: str, prompt_set: str, evaluator: str = "default") -> BenchmarkDefinition:
assert version == "1.0", ValueError(f"Version {version} is not supported.")
validate_locale(locale)
validate_prompt_set(prompt_set)
# TODO: Should probably also check that user has all the necessary secrets here e.g. can they run "official"?
if evaluator == "ensemble":
if not ensure_ensemble_annotators_loaded():
@@ -345,9 +346,9 @@ def update_standards_to(standards_file):
exit(1)

benchmarks = []
for l in [EN_US]:
for prompt_set in PROMPT_SETS:
benchmarks.append(GeneralPurposeAiChatBenchmarkV1(l, prompt_set, "ensemble"))
for locale in PUBLISHED_LOCALES:
for prompt_set in PROMPT_SETS.keys():
benchmarks.append(GeneralPurposeAiChatBenchmarkV1(locale, prompt_set, "ensemble"))
run_result = run_benchmarks_for_suts(benchmarks, reference_suts, None)
all_hazard_numeric_scores = defaultdict(list)
for _, scores_by_sut in run_result.benchmark_scores.items():
2 changes: 2 additions & 0 deletions src/modelgauge/locales.py
@@ -9,6 +9,8 @@

# add the other languages after we have official and practice prompt sets
LOCALES = (EN_US, FR_FR)
# all the languages we have official and practice prompt sets for
PUBLISHED_LOCALES = (EN_US,)
Contributor:

Is this the set of locales that we can run in modelbench?

Contributor Author:

LOCALES should be runnable in modelbench. PUBLISHED_LOCALES is there to avoid hardcoded exceptions deeper in the code, e.g. the funky exceptions we had to put in to run French practice tests.
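
For illustration, a minimal sketch of how the two tuples are meant to divide responsibilities (hypothetical call sites, not code from this PR):

# LOCALES: any locale a user may pass to a benchmark run.
# PUBLISHED_LOCALES: only locales that have both practice and official prompt
# sets, e.g. the ones update_standards_to() iterates when recalibrating.
from modelgauge.locales import LOCALES, PUBLISHED_LOCALES, validate_locale

def run_benchmark(locale: str) -> None:
    validate_locale(locale)  # accepts any member of LOCALES
    ...

def recalibrate_standards() -> None:
    for locale in PUBLISHED_LOCALES:  # currently just EN_US
        ...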



def is_valid(locale: str) -> bool:
33 changes: 33 additions & 0 deletions src/modelgauge/prompt_sets.py
@@ -0,0 +1,33 @@
from modelgauge.secret_values import RequiredSecret, SecretDescription


class ModellabFileDownloadToken(RequiredSecret):
@classmethod
def description(cls) -> SecretDescription:
return SecretDescription(
scope="modellab_files",
key="token",
instructions="Please ask MLCommons admin for permission.",
)


# file name format:
# {prefix}_{version}_{type}(_{locale})_prompt_set_release
PROMPT_SETS = {
"practice": "airr_official_1.0_practice_prompt_set_release",
"official": "airr_official_1.0_heldback_prompt_set_release",
"practice_fr_fr": "airr_official_1.0_practice_fr_fr_prompt_set_release",
}
PROMPT_SET_DOWNLOAD_HOST = "ailuminate.mlcommons.org"


def prompt_set_file_base_name(prompt_set: str, prompt_sets: dict = PROMPT_SETS) -> str:
filename = prompt_sets.get(prompt_set, None)
return filename


def validate_prompt_set(prompt_set: str, prompt_sets: dict = PROMPT_SETS) -> bool:
filename = prompt_set_file_base_name(prompt_set, prompt_sets)
if not filename:
raise ValueError(f"Invalid prompt set {prompt_set}. Must be one of {prompt_sets.keys()}.")
return True
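
A quick usage sketch of the new helpers (illustrative, with behavior as defined above):

from modelgauge.prompt_sets import prompt_set_file_base_name, validate_prompt_set

prompt_set_file_base_name("practice")
# -> "airr_official_1.0_practice_prompt_set_release"
prompt_set_file_base_name("bogus")
# -> None
validate_prompt_set("practice")
# -> True
validate_prompt_set("bogus")
# -> raises ValueError naming the valid keys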
58 changes: 16 additions & 42 deletions src/modelgauge/tests/safe_v1.py
@@ -11,25 +11,21 @@
from modelgauge.external_data import ExternalData, WebData
from modelgauge.locales import EN_US, LOCALES, validate_locale
from modelgauge.prompt import SUTOptions, TextPrompt
from modelgauge.secret_values import InjectSecret, RequiredSecret, SecretDescription
from modelgauge.prompt_sets import ( # usort: skip
PROMPT_SET_DOWNLOAD_HOST,
PROMPT_SETS,
ModellabFileDownloadToken,
prompt_set_file_base_name,
validate_prompt_set,
)
from modelgauge.secret_values import InjectSecret
from modelgauge.single_turn_prompt_response import MeasuredTestItem, PromptWithContext, TestItem, TestItemAnnotations
from modelgauge.sut_capabilities import AcceptsTextPrompt
from modelgauge.test_decorator import modelgauge_test
from modelgauge.test_registry import TESTS
from pydantic import BaseModel


# TODO: agree on a standard for this
# The strings in PROMPT_SETS are additionally decorated with the locale string in lowercase
# upon download. This is to support downloading the "same" prompt set for different locales
# and keep all of them downloaded locally.
PROMPT_SETS = {
"practice": "airr_official_1.0_practice_prompt_set_release",
"official": "airr_official_1.0_heldback_prompt_set_release",
}
PROMPT_SET_DOWNLOAD_HOST = "ailuminate.mlcommons.org"


class SafeTestItemContext(BaseModel):
persona_type: str

@@ -43,16 +39,6 @@ class SafeTestResult(BaseModel):
persona_results: Mapping[str, PersonaResult]


class ModellabFileDownloadToken(RequiredSecret):
@classmethod
def description(cls) -> SecretDescription:
return SecretDescription(
scope="modellab_files",
key="token",
instructions="Please ask MLCommons admin for permission.",
)


class SafePersonasVersion1(str, Enum):
NORMAL = "normal"
SKILLED = "skilled"
@@ -98,7 +84,7 @@ class BaseSafeTestVersion1(PromptResponseTest, ABC):

@staticmethod
def create_uid(hazard: str, locale: str, prompt_set: str, evaluator=None):
assert prompt_set in PROMPT_SETS
validate_prompt_set(prompt_set)
validate_locale(locale)
if evaluator is None or evaluator == "default":
postfix = ""
@@ -123,13 +109,13 @@ def __init__(
assert len(set(persona_types)) == len(
persona_types
), f"Must specify a unique set of persona types, but got {persona_types}"
assert prompt_set in PROMPT_SETS, f"Invalid prompt set {prompt_set}. Must be one of {PROMPT_SETS.keys()}."
validate_prompt_set(prompt_set)
validate_locale(locale)

self.hazard = hazard
self.locale = locale
self.persona_types = persona_types
self.prompt_set_file_name = PROMPT_SETS[prompt_set]
self.prompt_set_file_base_name = prompt_set_file_base_name(prompt_set)
self.token = token

@classmethod
@@ -138,23 +124,14 @@ def _check_annotators(cls):
if not hasattr(cls, "annotators") or cls.annotators is None:
raise NotImplementedError("Concrete SafeTestVersion1 classes must set class-attribute `annotators`.")

def __localize_filename(self) -> str:
# TODO implement a better standard for prompt file names by locale
if self.locale != EN_US:
tail = f"_{self.locale}"
else:
tail = ""
filename = f"{self.prompt_set_file_name}{tail}"
return filename

def get_dependencies(self) -> Mapping[str, ExternalData]:
prompt_set_base_download_url = f"https://{PROMPT_SET_DOWNLOAD_HOST}/files/download"
headers = None
if self.token is not None:
headers = {"auth-token": self.token.value}
source_url = f"{prompt_set_base_download_url}/{self.__localize_filename()}.csv"
source_url = f"{prompt_set_base_download_url}/{self.prompt_set_file_base_name}.csv"
return {
self.prompt_set_file_name: WebData(
self.prompt_set_file_base_name: WebData(
source_url=source_url,
headers=headers,
)
@@ -171,7 +148,7 @@ def make_test_items(self, dependency_helper: DependencyHelper) -> List[TestItem]
temperature=0.01,
)
test_items: List[TestItem] = []
data_file = dependency_helper.get_local_path(self.prompt_set_file_name)
data_file = dependency_helper.get_local_path(self.prompt_set_file_base_name)

# We're not using a context manager here for better exception handling,
# so the operator has some insight into why things aren't working.
@@ -247,12 +224,9 @@ class SafeTestVersion1(BaseSafeTestVersion1):
def register_tests(cls, evaluator=None):
for locale in LOCALES:
for hazard in cls.hazards:
for prompt_set in PROMPT_SETS:
for prompt_set in PROMPT_SETS.keys():
test_uid = BaseSafeTestVersion1.create_uid(hazard, locale, prompt_set, evaluator)
token = None
# only practice prompt sets in English are publicly available for now
if prompt_set == "official" or locale != EN_US:
token = InjectSecret(ModellabFileDownloadToken)
token = InjectSecret(ModellabFileDownloadToken)
TESTS.register(cls, test_uid, hazard, locale, ALL_PERSONAS, prompt_set, token)


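Net effect of the refactor (a sketch using the PROMPT_SETS mapping from prompt_sets.py): a locale-specific file is now selected by its prompt set key, rather than by appending a locale suffix to the file name at download time as the removed __localize_filename() did:

from modelgauge.prompt_sets import PROMPT_SET_DOWNLOAD_HOST, prompt_set_file_base_name

base = prompt_set_file_base_name("practice_fr_fr")
# -> "airr_official_1.0_practice_fr_fr_prompt_set_release"
url = f"https://{PROMPT_SET_DOWNLOAD_HOST}/files/download/{base}.csv"
# -> "https://ailuminate.mlcommons.org/files/download/airr_official_1.0_practice_fr_fr_prompt_set_release.csv"
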
25 changes: 6 additions & 19 deletions tests/modelbench_tests/test_benchmark.py
@@ -5,29 +5,16 @@

import pytest

from modelbench.benchmarks import (
BenchmarkDefinition,
BenchmarkScore,
GeneralPurposeAiChatBenchmarkV1,
)
from modelbench.hazards import (
HazardDefinition,
HazardScore,
SafeHazardV1,
STANDARDS,
)
from modelbench.benchmarks import BenchmarkDefinition, BenchmarkScore, GeneralPurposeAiChatBenchmarkV1
from modelbench.hazards import STANDARDS, HazardDefinition, HazardScore, SafeHazardV1 # usort: skip
from modelbench.scoring import ValueEstimate
from modelgauge.base_test import BaseTest
from modelgauge.locales import EN_US

from modelgauge.prompt_sets import PROMPT_SETS
from modelgauge.records import TestRecord
from modelgauge.secret_values import RawSecrets
from modelgauge.tests.safe_v1 import (
PROMPT_SETS,
PersonaResult,
SafeTestResult,
SafeTestVersion1,
SafePersonasVersion1,
)
from modelgauge.tests.safe_v1 import PersonaResult, SafePersonasVersion1, SafeTestResult, SafeTestVersion1


@pytest.mark.parametrize("ai", ("ai", "AI", "aI", "Ai"))
@@ -64,7 +51,7 @@ def test_benchmark_v1_definition_basics(prompt_set, fake_secrets):
assert hazard.hazard_key == hazard_key
assert hazard.locale == EN_US
assert hazard.prompt_set == prompt_set
assert prompt_set in hazard.tests(secrets=fake_secrets)[0].prompt_set_file_name
assert prompt_set in hazard.tests(secrets=fake_secrets)[0].prompt_set_file_base_name


@pytest.mark.parametrize(
2 changes: 1 addition & 1 deletion tests/modelbench_tests/test_record.py
@@ -81,7 +81,7 @@ def test_v1_hazard_definition_without_tests_loaded():

def test_v1_hazard_definition_with_tests_loaded():
hazard = SafeHazardV1("dfm", EN_US, "practice")
hazard.tests({"together": {"api_key": "ignored"}})
hazard.tests({"together": {"api_key": "fake"}, "modellab_files": {"token": "fake"}})
j = encode_and_parse(hazard)
assert j["uid"] == hazard.uid
assert j["tests"] == ["safe-dfm-en_us-practice-1.0"]
24 changes: 13 additions & 11 deletions tests/modelbench_tests/test_run.py
@@ -9,20 +9,16 @@
from click.testing import CliRunner

from modelbench.benchmark_runner import BenchmarkRun, BenchmarkRunner
from modelbench.benchmarks import (
BenchmarkDefinition,
BenchmarkScore,
GeneralPurposeAiChatBenchmarkV1,
)
from modelbench.hazards import HazardScore, HazardDefinition, SafeHazardV1
from modelbench.run import benchmark, cli, find_suts_for_sut_argument, get_benchmark
from modelbench.benchmarks import BenchmarkDefinition, BenchmarkScore, GeneralPurposeAiChatBenchmarkV1
from modelbench.hazards import HazardDefinition, HazardScore, SafeHazardV1
from modelbench.run import cli, find_suts_for_sut_argument, get_benchmark
from modelbench.scoring import ValueEstimate
from modelgauge.base_test import PromptResponseTest
from modelgauge.locales import DEFAULT_LOCALE, EN_US, FR_FR, LOCALES, display_for, is_valid
from modelgauge.locales import DEFAULT_LOCALE, EN_US, FR_FR, LOCALES
from modelgauge.prompt_sets import PROMPT_SETS
from modelgauge.records import TestRecord
from modelgauge.secret_values import RawSecrets
from modelgauge.sut import PromptResponseSUT
from modelgauge.tests.safe_v1 import PROMPT_SETS

from modelgauge_tests.fake_sut import FakeSUT

@@ -138,7 +134,10 @@ def test_benchmark_basic_run_produces_json(
if prompt_set is not None:
benchmark_options.extend(["--prompt-set", prompt_set])
benchmark = get_benchmark(
version, locale if locale else DEFAULT_LOCALE, prompt_set if prompt_set else "practice", "default"
version,
locale if locale else DEFAULT_LOCALE,
prompt_set if prompt_set else "practice",
"default",
)
command_options = [
"benchmark",
@@ -173,7 +172,10 @@ def test_benchmark_multiple_suts_produces_json(self, runner, version, locale, pr
if prompt_set is not None:
benchmark_options.extend(["--prompt-set", prompt_set])
benchmark = get_benchmark(
version, locale if locale else DEFAULT_LOCALE, prompt_set if prompt_set else "practice", "default"
version,
locale if locale else DEFAULT_LOCALE,
prompt_set if prompt_set else "practice",
"default",
)

mock = MagicMock(return_value=[self.mock_score("fake-2", benchmark), self.mock_score("fake-2", benchmark)])
3 changes: 2 additions & 1 deletion tests/modelgauge_tests/fake_dependency_helper.py
@@ -1,9 +1,10 @@
import csv
import io
import os
from modelgauge.dependency_helper import DependencyHelper
from typing import List, Mapping

from modelgauge.dependency_helper import DependencyHelper


class FakeDependencyHelper(DependencyHelper):
"""Test version of Dependency helper that lets you set the text in files.
19 changes: 19 additions & 0 deletions tests/modelgauge_tests/test_prompt_sets.py
@@ -0,0 +1,19 @@
import pytest
from modelgauge.prompt_sets import (
PROMPT_SETS,
prompt_set_file_base_name,
validate_prompt_set,
) # usort: skip


def test_file_base_name():
assert prompt_set_file_base_name("bad") is None
assert prompt_set_file_base_name("practice") == PROMPT_SETS["practice"]
assert prompt_set_file_base_name("practice", PROMPT_SETS) == PROMPT_SETS["practice"]


def test_validate_prompt_set():
for s in PROMPT_SETS.keys():
assert validate_prompt_set(s, PROMPT_SETS)
with pytest.raises(ValueError):
validate_prompt_set("should raise")