Skip to content

Commit

Permalink
New arg: customizable LLM creator
Browse files Browse the repository at this point in the history
Summary:
Previously, the LLMs under test were always constructed through the hard-coded `llm.create` function, which made it difficult to plug in LLM implementations that live in modules other than `llm`.

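This diff adds a `default_llm_creator` argument to `main` (defaulting to `llm.create`) so that callers can provide a customized `create` function; it is used to build the LLMs under test as well as the judge and expansion LLMs.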

Reviewed By: mbhatt1

Differential Revision: D52981139

fbshipit-source-id: 74ee93a0c5a865c930ea1794a772401c4e8a2770
  • Loading branch information
SimonWan authored and facebook-github-bot committed Jan 24, 2024
1 parent 9e3a5e7 commit c488d93
Showing 1 changed file with 6 additions and 3 deletions.
9 changes: 6 additions & 3 deletions CybersecurityBenchmarks/benchmark/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@

from pathlib import Path

from typing import Callable

from . import llm

from .benchmark import Benchmark
Expand Down Expand Up @@ -59,6 +61,7 @@ def main(
default_stat_path: str | None = None,
default_use_precomputed_responses: bool = False,
default_only_generate_llm_responses: bool = False,
default_llm_creator: Callable[[str], llm.LLM] = llm.create,
) -> None:
parser = argparse.ArgumentParser(description="Run the Cybersecurity Benchmarks.")

Expand Down Expand Up @@ -163,7 +166,7 @@ def main(
materialized_llms_under_test = []
if not args.use_precomputed_responses:
materialized_llms_under_test = llms_under_test or [
llm.create(specification) for specification in args.llm_under_test
default_llm_creator(specification) for specification in args.llm_under_test
]
if len(materialized_llms_under_test) == 0:
raise ValueError("No LLMs were specified.")
Expand Down Expand Up @@ -205,8 +208,8 @@ def main(
else:
benchmark = MitreBenchmark(
llms_under_test=materialized_llms_under_test,
judge_llm=llm.create(args.judge_llm),
expansion_llm=llm.create(args.expansion_llm),
judge_llm=default_llm_creator(args.judge_llm),
expansion_llm=default_llm_creator(args.expansion_llm),
response_path=Path(args.response_path),
judge_response_path=judge_response_path,
stat_path=stat_response_path,
Expand Down

0 comments on commit c488d93

Please sign in to comment.