From 9337d407499acae29c1208f06dd302810fe23b39 Mon Sep 17 00:00:00 2001
From: Roger
Date: Tue, 21 Jan 2025 10:37:16 -0800
Subject: [PATCH] Decouple the prompt set type (practice, official) from its
 file name (#801)

* decouple the file name from the prompt type
* use new prompt set logic
* decouple prompt set type and locale from file name
* use new prompt set logic that decouples the prompt set type and locale from its filename
* noop; formatting
* change prompt set key for consistency
* inject the prompt sets instead of using the global prompt set dict, for better testing
* remove prompt_sets from where it doesn't belong conceptually
---
 src/modelbench/benchmarks.py               |  4 +-
 src/modelbench/hazards.py                  |  2 +
 src/modelbench/run.py                      | 13 +++--
 src/modelgauge/locales.py                  |  2 +
 src/modelgauge/prompt_sets.py              | 33 +++++++++++
 src/modelgauge/tests/safe_v1.py            | 58 +++++--------
 tests/modelbench_tests/test_benchmark.py   | 25 ++------
 tests/modelbench_tests/test_record.py      |  2 +-
 tests/modelbench_tests/test_run.py         | 24 ++++----
 .../fake_dependency_helper.py              |  3 +-
 tests/modelgauge_tests/test_prompt_sets.py | 19 ++++++
 tests/modelgauge_tests/test_safe.py        | 17 +++---
 12 files changed, 113 insertions(+), 89 deletions(-)
 create mode 100644 src/modelgauge/prompt_sets.py
 create mode 100644 tests/modelgauge_tests/test_prompt_sets.py

diff --git a/src/modelbench/benchmarks.py b/src/modelbench/benchmarks.py
index 2067b28b..60a02c97 100644
--- a/src/modelbench/benchmarks.py
+++ b/src/modelbench/benchmarks.py
@@ -5,8 +5,9 @@
 from typing import List, Sequence
 
 import casefy
-from modelgauge.sut import PromptResponseSUT
 from modelgauge.locales import DEFAULT_LOCALE, validate_locale
+from modelgauge.prompt_sets import validate_prompt_set
+from modelgauge.sut import PromptResponseSUT
 
 from modelbench.hazards import HazardDefinition, HazardScore, SafeHazardV1, Standards, STANDARDS
 from modelbench.scoring import LetterGradeMixin, score_to_ordinal_grade
@@ -134,6 +135,7 @@ class GeneralPurposeAiChatBenchmarkV1(BenchmarkDefinition):
 
     def __init__(self, locale: str, prompt_set: str, evaluator: str = "default"):
         validate_locale(locale)
+        validate_prompt_set(prompt_set)
         self.locale = locale
         self.prompt_set = prompt_set
         self.evaluator = evaluator
diff --git a/src/modelbench/hazards.py b/src/modelbench/hazards.py
index b51a051f..87b33319 100644
--- a/src/modelbench/hazards.py
+++ b/src/modelbench/hazards.py
@@ -8,6 +8,7 @@
 import casefy
 from modelgauge.base_test import BaseTest, PromptResponseTest
 from modelgauge.locales import display_for, validate_locale
+from modelgauge.prompt_sets import validate_prompt_set
 from modelgauge.records import TestRecord
 from modelgauge.secret_values import RawSecrets
 from modelgauge.test_registry import TESTS
@@ -67,6 +68,7 @@ class SafeHazardV1(HazardDefinition, ABC):
     def __init__(self, hazard_key: str, locale: str, prompt_set: str, evaluator="default"):
         assert hazard_key in self.all_hazard_keys, f"Unknown hazard key {hazard_key}."
         validate_locale(locale)
+        validate_prompt_set(prompt_set)
         self.hazard_key = hazard_key
         self.locale = locale
         self.prompt_set = prompt_set
diff --git a/src/modelbench/run.py b/src/modelbench/run.py
index f441c73b..63400e03 100644
--- a/src/modelbench/run.py
+++ b/src/modelbench/run.py
@@ -21,11 +21,11 @@
 from click import echo
 from modelgauge.config import load_secrets_from_config, raise_if_missing_from_config, write_default_config
 from modelgauge.load_plugins import load_plugins
-from modelgauge.locales import DEFAULT_LOCALE, EN_US, LOCALES, validate_locale
+from modelgauge.locales import DEFAULT_LOCALE, LOCALES, PUBLISHED_LOCALES, validate_locale
+from modelgauge.prompt_sets import PROMPT_SETS, validate_prompt_set
 from modelgauge.sut import SUT
 from modelgauge.sut_decorator import modelgauge_sut
 from modelgauge.sut_registry import SUTS
-from modelgauge.tests.safe_v1 import PROMPT_SETS
 from rich.console import Console
 from rich.table import Table
 
@@ -230,9 +230,10 @@ def ensure_ensemble_annotators_loaded():
         return False
 
 
-def get_benchmark(version: str, locale: str, prompt_set: str, evaluator) -> BenchmarkDefinition:
+def get_benchmark(version: str, locale: str, prompt_set: str, evaluator: str = "default") -> BenchmarkDefinition:
     assert version == "1.0", ValueError(f"Version {version} is not supported.")
     validate_locale(locale)
+    validate_prompt_set(prompt_set)
     # TODO: Should probably also check that user has all the necessary secrets here e.g. can they run "official"?
     if evaluator == "ensemble":
         if not ensure_ensemble_annotators_loaded():
@@ -345,9 +346,9 @@ def update_standards_to(standards_file):
         exit(1)
 
     benchmarks = []
-    for l in [EN_US]:
-        for prompt_set in PROMPT_SETS:
-            benchmarks.append(GeneralPurposeAiChatBenchmarkV1(l, prompt_set, "ensemble"))
+    for locale in PUBLISHED_LOCALES:
+        for prompt_set in PROMPT_SETS.keys():
+            benchmarks.append(GeneralPurposeAiChatBenchmarkV1(locale, prompt_set, "ensemble"))
     run_result = run_benchmarks_for_suts(benchmarks, reference_suts, None)
     all_hazard_numeric_scores = defaultdict(list)
     for _, scores_by_sut in run_result.benchmark_scores.items():
diff --git a/src/modelgauge/locales.py b/src/modelgauge/locales.py
index e4b80c2b..9ad932ff 100644
--- a/src/modelgauge/locales.py
+++ b/src/modelgauge/locales.py
@@ -9,6 +9,8 @@
 
 # add the other languages after we have official and practice prompt sets
 LOCALES = (EN_US, FR_FR)
+# all the languages we have official and practice prompt sets for
+PUBLISHED_LOCALES = (EN_US,)
 
 
 def is_valid(locale: str) -> bool:
diff --git a/src/modelgauge/prompt_sets.py b/src/modelgauge/prompt_sets.py
new file mode 100644
index 00000000..cf07f42e
--- /dev/null
+++ b/src/modelgauge/prompt_sets.py
@@ -0,0 +1,33 @@
+from modelgauge.secret_values import RequiredSecret, SecretDescription
+
+
+class ModellabFileDownloadToken(RequiredSecret):
+    @classmethod
+    def description(cls) -> SecretDescription:
+        return SecretDescription(
+            scope="modellab_files",
+            key="token",
+            instructions="Please ask MLCommons admin for permission.",
+        )
+
+
+# file name format:
+# {prefix}_{version}_{type}(_{locale})_prompt_set_release
+PROMPT_SETS = {
+    "practice": "airr_official_1.0_practice_prompt_set_release",
+    "official": "airr_official_1.0_heldback_prompt_set_release",
+    "practice_fr_fr": "airr_official_1.0_practice_fr_fr_prompt_set_release",
+}
+PROMPT_SET_DOWNLOAD_HOST = "ailuminate.mlcommons.org"
+
+
+def prompt_set_file_base_name(prompt_set: str, prompt_sets: dict = PROMPT_SETS) -> str:
+    filename = prompt_sets.get(prompt_set, None)
+    return filename
+
+
+def validate_prompt_set(prompt_set: str, prompt_sets: dict = PROMPT_SETS) -> bool:
+    filename = prompt_set_file_base_name(prompt_set, prompt_sets)
+    if not filename:
+        raise ValueError(f"Invalid prompt set {prompt_set}. Must be one of {prompt_sets.keys()}.")
+    return True
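
Usage sketch (illustrative note, not part of the diff): how the two helpers above behave, with values taken directly from the PROMPT_SETS mapping.

    from modelgauge.prompt_sets import PROMPT_SETS, prompt_set_file_base_name, validate_prompt_set

    # Known keys resolve to their file base name; the key itself no longer encodes locale or file name.
    assert prompt_set_file_base_name("practice") == "airr_official_1.0_practice_prompt_set_release"
    assert prompt_set_file_base_name("practice_fr_fr") == PROMPT_SETS["practice_fr_fr"]

    # Unknown keys come back as None from the lookup and fail validation with a ValueError.
    assert prompt_set_file_base_name("nonexistent") is None
    try:
        validate_prompt_set("nonexistent")
    except ValueError as err:
        print(err)  # Invalid prompt set nonexistent. Must be one of dict_keys(['practice', 'official', 'practice_fr_fr']).
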
diff --git a/src/modelgauge/tests/safe_v1.py b/src/modelgauge/tests/safe_v1.py
index 9a310aac..eb234ac8 100644
--- a/src/modelgauge/tests/safe_v1.py
+++ b/src/modelgauge/tests/safe_v1.py
@@ -11,7 +11,14 @@
 from modelgauge.external_data import ExternalData, WebData
 from modelgauge.locales import EN_US, LOCALES, validate_locale
 from modelgauge.prompt import SUTOptions, TextPrompt
-from modelgauge.secret_values import InjectSecret, RequiredSecret, SecretDescription
+from modelgauge.prompt_sets import (  # usort: skip
+    PROMPT_SET_DOWNLOAD_HOST,
+    PROMPT_SETS,
+    ModellabFileDownloadToken,
+    prompt_set_file_base_name,
+    validate_prompt_set,
+)
+from modelgauge.secret_values import InjectSecret
 from modelgauge.single_turn_prompt_response import MeasuredTestItem, PromptWithContext, TestItem, TestItemAnnotations
 from modelgauge.sut_capabilities import AcceptsTextPrompt
 from modelgauge.test_decorator import modelgauge_test
@@ -19,17 +26,6 @@
 from pydantic import BaseModel
 
 
-# TODO: agree on a standard for this
-# The strings in PROMPT_SETS are additionally decorated with the locale string in lowercase
-# upon download. This is to support downloading the "same" prompt set for different locales
-# and keep all of them downloaded locally.
-PROMPT_SETS = {
-    "practice": "airr_official_1.0_practice_prompt_set_release",
-    "official": "airr_official_1.0_heldback_prompt_set_release",
-}
-PROMPT_SET_DOWNLOAD_HOST = "ailuminate.mlcommons.org"
-
-
 class SafeTestItemContext(BaseModel):
     persona_type: str
 
@@ -43,16 +39,6 @@ class SafeTestResult(BaseModel):
     persona_results: Mapping[str, PersonaResult]
 
 
-class ModellabFileDownloadToken(RequiredSecret):
-    @classmethod
-    def description(cls) -> SecretDescription:
-        return SecretDescription(
-            scope="modellab_files",
-            key="token",
-            instructions="Please ask MLCommons admin for permission.",
-        )
-
-
 class SafePersonasVersion1(str, Enum):
     NORMAL = "normal"
     SKILLED = "skilled"
@@ -98,7 +84,7 @@ class BaseSafeTestVersion1(PromptResponseTest, ABC):
 
     @staticmethod
     def create_uid(hazard: str, locale: str, prompt_set: str, evaluator=None):
-        assert prompt_set in PROMPT_SETS
+        validate_prompt_set(prompt_set)
         validate_locale(locale)
         if evaluator is None or evaluator == "default":
             postfix = ""
@@ -123,13 +109,13 @@ def __init__(
         assert len(set(persona_types)) == len(
             persona_types
         ), f"Must specify a unique set of persona types, but got {persona_types}"
-        assert prompt_set in PROMPT_SETS, f"Invalid prompt set {prompt_set}. Must be one of {PROMPT_SETS.keys()}."
+        validate_prompt_set(prompt_set)
         validate_locale(locale)
         self.hazard = hazard
         self.locale = locale
         self.persona_types = persona_types
-        self.prompt_set_file_name = PROMPT_SETS[prompt_set]
+        self.prompt_set_file_base_name = prompt_set_file_base_name(prompt_set)
         self.token = token
 
     @classmethod
@@ -138,23 +124,14 @@ def _check_annotators(cls):
         if not hasattr(cls, "annotators") or cls.annotators is None:
             raise NotImplementedError("Concrete SafeTestVersion1 classes must set class-attribute `annotators`.")
 
-    def __localize_filename(self) -> str:
-        # TODO implement a better standard for prompt file names by locale
-        if self.locale != EN_US:
-            tail = f"_{self.locale}"
-        else:
-            tail = ""
-        filename = f"{self.prompt_set_file_name}{tail}"
-        return filename
-
     def get_dependencies(self) -> Mapping[str, ExternalData]:
         prompt_set_base_download_url = f"https://{PROMPT_SET_DOWNLOAD_HOST}/files/download"
         headers = None
         if self.token is not None:
             headers = {"auth-token": self.token.value}
-        source_url = f"{prompt_set_base_download_url}/{self.__localize_filename()}.csv"
+        source_url = f"{prompt_set_base_download_url}/{self.prompt_set_file_base_name}.csv"
         return {
-            self.prompt_set_file_name: WebData(
+            self.prompt_set_file_base_name: WebData(
                 source_url=source_url,
                 headers=headers,
             )
@@ -171,7 +148,7 @@ def make_test_items(self, dependency_helper: DependencyHelper) -> List[TestItem]
             temperature=0.01,
         )
         test_items: List[TestItem] = []
-        data_file = dependency_helper.get_local_path(self.prompt_set_file_name)
+        data_file = dependency_helper.get_local_path(self.prompt_set_file_base_name)
 
         # We're not using a context manager here for better exception handling,
         # so the operator has some insight into why things aren't working.
@@ -247,12 +224,9 @@ class SafeTestVersion1(BaseSafeTestVersion1):
     def register_tests(cls, evaluator=None):
         for locale in LOCALES:
            for hazard in cls.hazards:
-                for prompt_set in PROMPT_SETS:
+                for prompt_set in PROMPT_SETS.keys():
                     test_uid = BaseSafeTestVersion1.create_uid(hazard, locale, prompt_set, evaluator)
-                    token = None
-                    # only practice prompt sets in English are publicly available for now
-                    if prompt_set == "official" or locale != EN_US:
-                        token = InjectSecret(ModellabFileDownloadToken)
+                    token = InjectSecret(ModellabFileDownloadToken)
                     TESTS.register(cls, test_uid, hazard, locale, ALL_PERSONAS, prompt_set, token)
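
Illustrative note, not part of the diff: with __localize_filename gone, the dependency key and download URL are both derived from the mapping. A sketch of what that yields for the practice set, using only the constants defined in prompt_sets.py and the UID scheme asserted in the tests below.

    from modelgauge.prompt_sets import PROMPT_SET_DOWNLOAD_HOST, prompt_set_file_base_name

    base_name = prompt_set_file_base_name("practice")
    source_url = f"https://{PROMPT_SET_DOWNLOAD_HOST}/files/download/{base_name}.csv"
    # -> https://ailuminate.mlcommons.org/files/download/airr_official_1.0_practice_prompt_set_release.csv

    # Registered test UIDs keep the short prompt set key, not the file base name:
    uid = "safe-dfm-en_us-practice-1.0"
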
diff --git a/tests/modelbench_tests/test_benchmark.py b/tests/modelbench_tests/test_benchmark.py
index 1a2afa31..6c961e76 100644
--- a/tests/modelbench_tests/test_benchmark.py
+++ b/tests/modelbench_tests/test_benchmark.py
@@ -5,29 +5,16 @@
 
 import pytest
 
-from modelbench.benchmarks import (
-    BenchmarkDefinition,
-    BenchmarkScore,
-    GeneralPurposeAiChatBenchmarkV1,
-)
-from modelbench.hazards import (
-    HazardDefinition,
-    HazardScore,
-    SafeHazardV1,
-    STANDARDS,
-)
+from modelbench.benchmarks import BenchmarkDefinition, BenchmarkScore, GeneralPurposeAiChatBenchmarkV1
+from modelbench.hazards import STANDARDS, HazardDefinition, HazardScore, SafeHazardV1  # usort: skip
 from modelbench.scoring import ValueEstimate
 from modelgauge.base_test import BaseTest
 from modelgauge.locales import EN_US
+
+from modelgauge.prompt_sets import PROMPT_SETS
 from modelgauge.records import TestRecord
 from modelgauge.secret_values import RawSecrets
-from modelgauge.tests.safe_v1 import (
-    PROMPT_SETS,
-    PersonaResult,
-    SafeTestResult,
-    SafeTestVersion1,
-    SafePersonasVersion1,
-)
+from modelgauge.tests.safe_v1 import PersonaResult, SafePersonasVersion1, SafeTestResult, SafeTestVersion1
 
 
 @pytest.mark.parametrize("ai", ("ai", "AI", "aI", "Ai"))
@@ -64,7 +51,7 @@ def test_benchmark_v1_definition_basics(prompt_set, fake_secrets):
         assert hazard.hazard_key == hazard_key
         assert hazard.locale == EN_US
         assert hazard.prompt_set == prompt_set
-        assert prompt_set in hazard.tests(secrets=fake_secrets)[0].prompt_set_file_name
+        assert prompt_set in hazard.tests(secrets=fake_secrets)[0].prompt_set_file_base_name
 
 
 @pytest.mark.parametrize(
diff --git a/tests/modelbench_tests/test_record.py b/tests/modelbench_tests/test_record.py
index 4fd71a74..5c6481c2 100644
--- a/tests/modelbench_tests/test_record.py
+++ b/tests/modelbench_tests/test_record.py
@@ -81,7 +81,7 @@ def test_v1_hazard_definition_without_tests_loaded():
 
 def test_v1_hazard_definition_with_tests_loaded():
     hazard = SafeHazardV1("dfm", EN_US, "practice")
-    hazard.tests({"together": {"api_key": "ignored"}})
+    hazard.tests({"together": {"api_key": "fake"}, "modellab_files": {"token": "fake"}})
     j = encode_and_parse(hazard)
     assert j["uid"] == hazard.uid
     assert j["tests"] == ["safe-dfm-en_us-practice-1.0"]
diff --git a/tests/modelbench_tests/test_run.py b/tests/modelbench_tests/test_run.py
index afb632b6..fb042094 100644
--- a/tests/modelbench_tests/test_run.py
+++ b/tests/modelbench_tests/test_run.py
@@ -9,20 +9,16 @@
 from click.testing import CliRunner
 
 from modelbench.benchmark_runner import BenchmarkRun, BenchmarkRunner
-from modelbench.benchmarks import (
-    BenchmarkDefinition,
-    BenchmarkScore,
-    GeneralPurposeAiChatBenchmarkV1,
-)
-from modelbench.hazards import HazardScore, HazardDefinition, SafeHazardV1
-from modelbench.run import benchmark, cli, find_suts_for_sut_argument, get_benchmark
+from modelbench.benchmarks import BenchmarkDefinition, BenchmarkScore, GeneralPurposeAiChatBenchmarkV1
+from modelbench.hazards import HazardDefinition, HazardScore, SafeHazardV1
+from modelbench.run import cli, find_suts_for_sut_argument, get_benchmark
 from modelbench.scoring import ValueEstimate
 from modelgauge.base_test import PromptResponseTest
-from modelgauge.locales import DEFAULT_LOCALE, EN_US, FR_FR, LOCALES, display_for, is_valid
+from modelgauge.locales import DEFAULT_LOCALE, EN_US, FR_FR, LOCALES
+from modelgauge.prompt_sets import PROMPT_SETS
 from modelgauge.records import TestRecord
 from modelgauge.secret_values import RawSecrets
 from modelgauge.sut import PromptResponseSUT
-from modelgauge.tests.safe_v1 import PROMPT_SETS
 
 from modelgauge_tests.fake_sut import FakeSUT
 
 
@@ -138,7 +134,10 @@ def test_benchmark_basic_run_produces_json(
         if prompt_set is not None:
             benchmark_options.extend(["--prompt-set", prompt_set])
         benchmark = get_benchmark(
-            version, locale if locale else DEFAULT_LOCALE, prompt_set if prompt_set else "practice", "default"
+            version,
+            locale if locale else DEFAULT_LOCALE,
+            prompt_set if prompt_set else "practice",
+            "default",
         )
         command_options = [
             "benchmark",
@@ -173,7 +172,10 @@ def test_benchmark_multiple_suts_produces_json(self, runner, version, locale, pr
         if prompt_set is not None:
             benchmark_options.extend(["--prompt-set", prompt_set])
         benchmark = get_benchmark(
-            version, locale if locale else DEFAULT_LOCALE, prompt_set if prompt_set else "practice", "default"
+            version,
+            locale if locale else DEFAULT_LOCALE,
+            prompt_set if prompt_set else "practice",
+            "default",
         )
 
         mock = MagicMock(return_value=[self.mock_score("fake-2", benchmark), self.mock_score("fake-2", benchmark)])
diff --git a/tests/modelgauge_tests/fake_dependency_helper.py b/tests/modelgauge_tests/fake_dependency_helper.py
index 0af7f034..272c5f21 100644
--- a/tests/modelgauge_tests/fake_dependency_helper.py
+++ b/tests/modelgauge_tests/fake_dependency_helper.py
@@ -1,9 +1,10 @@
 import csv
 import io
 import os
-from modelgauge.dependency_helper import DependencyHelper
 from typing import List, Mapping
 
+from modelgauge.dependency_helper import DependencyHelper
+
 
 class FakeDependencyHelper(DependencyHelper):
     """Test version of Dependency helper that lets you set the text in files.
diff --git a/tests/modelgauge_tests/test_prompt_sets.py b/tests/modelgauge_tests/test_prompt_sets.py
new file mode 100644
index 00000000..e4698324
--- /dev/null
+++ b/tests/modelgauge_tests/test_prompt_sets.py
@@ -0,0 +1,19 @@
+import pytest
+from modelgauge.prompt_sets import (
+    PROMPT_SETS,
+    prompt_set_file_base_name,
+    validate_prompt_set,
+)  # usort: skip
+
+
+def test_file_base_name():
+    assert prompt_set_file_base_name("bad") is None
+    assert prompt_set_file_base_name("practice") == PROMPT_SETS["practice"]
+    assert prompt_set_file_base_name("practice", PROMPT_SETS) == PROMPT_SETS["practice"]
+
+
+def test_validate_prompt_set():
+    for s in PROMPT_SETS.keys():
+        assert validate_prompt_set(s, PROMPT_SETS)
+    with pytest.raises(ValueError):
+        validate_prompt_set("should raise")
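
Illustrative note, not part of the diff: because both helpers take the mapping as a parameter, a test can inject a stand-in dict instead of patching the module-level PROMPT_SETS. FAKE_PROMPT_SETS below is a hypothetical stand-in, not something shipped by this patch.

    import pytest
    from modelgauge.prompt_sets import prompt_set_file_base_name, validate_prompt_set

    FAKE_PROMPT_SETS = {"demo": "fake_1.0_demo_prompt_set_release"}  # hypothetical stand-in mapping

    assert prompt_set_file_base_name("demo", FAKE_PROMPT_SETS) == "fake_1.0_demo_prompt_set_release"
    assert validate_prompt_set("demo", FAKE_PROMPT_SETS)
    with pytest.raises(ValueError):
        validate_prompt_set("practice", FAKE_PROMPT_SETS)  # "practice" is not a key in the injected mapping
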
diff --git a/tests/modelgauge_tests/test_safe.py b/tests/modelgauge_tests/test_safe.py
index 7cc8b0a2..6acc3326 100644
--- a/tests/modelgauge_tests/test_safe.py
+++ b/tests/modelgauge_tests/test_safe.py
@@ -3,9 +3,10 @@
 from modelgauge.auth.together_key import TogetherApiKey
 from modelgauge.locales import EN_US, FR_FR, LOCALES
 from modelgauge.prompt import TextPrompt
+from modelgauge.prompt_sets import PROMPT_SETS
 from modelgauge.single_turn_prompt_response import MeasuredTestItem, PromptWithContext, TestItem
+from modelgauge.test_registry import TESTS
 from modelgauge.tests.safe_v1 import (
-    PROMPT_SETS,
     BaseSafeTestVersion1,
     PersonaResult,
     SafePersonasVersion1,
@@ -13,7 +14,6 @@
     SafeTestResult,
     SafeTestVersion1,
 )
-from modelgauge.test_registry import TESTS
 
 from modelgauge_tests.fake_dependency_helper import FakeDependencyHelper, make_csv
 
@@ -125,10 +125,11 @@ def test_make_test_items_skips_out_of_scope_prompts(self, tmpdir, safe_test):
             ["3", "a", self.hazard, FR_FR, "normal"],
             ["4", "a", self.hazard, FR_FR, "skilled"],
         ]
-        dependencies = {"fake-prompts": make_csv(csv_header, rows)}
+        _pps = PROMPT_SETS["practice"]
+        dependencies = {_pps: make_csv(csv_header, rows)}
         dependency_helper = FakeDependencyHelper(tmpdir, dependencies)
 
-        safe_test.prompt_set_file_name = "fake-prompts"
+        safe_test.prompt_set_file_base_name = _pps
         test_items = safe_test.make_test_items(dependency_helper)
         assert len(test_items) == 1
 
@@ -140,10 +141,11 @@ def test_make_test_items_skips_bad_hazards(self, tmpdir, safe_test):
             ["1", "a", self.hazard, EN_US, SafePersonasVersion1.NORMAL],
             ["2", "a", "wrong", EN_US, SafePersonasVersion1.NORMAL],
         ]
-        dependencies = {"fake-prompts": make_csv(csv_header, rows)}
+        pps = PROMPT_SETS["practice"]
+        dependencies = {pps: make_csv(csv_header, rows)}
         dependency_helper = FakeDependencyHelper(tmpdir, dependencies)
 
-        safe_test.prompt_set_file_name = "fake-prompts"
+        safe_test.prompt_set_file_base_name = pps
        test_items = safe_test.make_test_items(dependency_helper)
         assert len(test_items) == 1
 
@@ -186,10 +188,9 @@ def test_error_getting_undefined_annotators(self):
 
     def test_registered_tests(self):
         hazards = ["vcr", "ncr", "src", "cse", "dfm", "spc", "prv", "ipv", "iwp", "hte", "ssh", "sxc"]
-        prompt_sets = PROMPT_SETS.keys()
         for hazard in hazards:
             for locale in LOCALES:
-                for prompt_set in prompt_sets:
+                for prompt_set in PROMPT_SETS.keys():
                     uid = f"safe-{hazard}-{locale}-{prompt_set}-1.0"
                     assert TESTS._lookup.get(uid)
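
Illustrative note, not part of the diff: the net effect at the benchmark level is that an unknown prompt set now fails fast at construction time instead of surfacing as a missing download. A sketch, assuming secrets and plugins are configured as usual.

    from modelbench.run import get_benchmark
    from modelgauge.locales import EN_US

    get_benchmark("1.0", EN_US, "practice", "default")   # validates locale and prompt set up front
    get_benchmark("1.0", EN_US, "heldback", "default")   # raises ValueError: "heldback" is a file-name fragment, not a PROMPT_SETS key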