diff --git a/CybersecurityBenchmarks/benchmark/run.py b/CybersecurityBenchmarks/benchmark/run.py
index bec547eae..f9b2eedc2 100644
--- a/CybersecurityBenchmarks/benchmark/run.py
+++ b/CybersecurityBenchmarks/benchmark/run.py
@@ -12,6 +12,8 @@
 
 from pathlib import Path
 
+from typing import Callable
+
 from . import llm
 from .benchmark import Benchmark
 
@@ -59,6 +61,7 @@ def main(
     default_stat_path: str | None = None,
     default_use_precomputed_responses: bool = False,
     default_only_generate_llm_responses: bool = False,
+    default_llm_creator: Callable[[str], llm.LLM] = llm.create,
 ) -> None:
     parser = argparse.ArgumentParser(description="Run the Cybersecurity Benchmarks.")
 
@@ -163,7 +166,7 @@ def main(
     materialized_llms_under_test = []
     if not args.use_precomputed_responses:
         materialized_llms_under_test = llms_under_test or [
-            llm.create(specification) for specification in args.llm_under_test
+            default_llm_creator(specification) for specification in args.llm_under_test
         ]
         if len(materialized_llms_under_test) == 0:
             raise ValueError("No LLMs were specified.")
@@ -205,8 +208,8 @@ def main(
     else:
         benchmark = MitreBenchmark(
             llms_under_test=materialized_llms_under_test,
-            judge_llm=llm.create(args.judge_llm),
-            expansion_llm=llm.create(args.expansion_llm),
+            judge_llm=default_llm_creator(args.judge_llm),
+            expansion_llm=default_llm_creator(args.expansion_llm),
             response_path=Path(args.response_path),
             judge_response_path=judge_response_path,
             stat_path=stat_response_path,
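
For context, a minimal sketch of how a downstream entry point might use the new default_llm_creator hook. The wrapper module layout and the custom_create name are assumptions for illustration, not part of this diff; the only API relied on is the new keyword argument to main().

# Hypothetical wrapper script (not part of this diff) that reuses the
# benchmark CLI but swaps in a custom LLM factory via default_llm_creator.
from CybersecurityBenchmarks.benchmark import llm
from CybersecurityBenchmarks.benchmark.run import main


def custom_create(specification: str) -> llm.LLM:
    # Interpret the --llm-under-test / --judge-llm / --expansion-llm
    # specification string however a custom provider requires; here we
    # simply delegate to the stock factory.
    return llm.create(specification)


if __name__ == "__main__":
    # With this diff applied, every specification given on the command line
    # is materialized through custom_create instead of llm.create.
    main(default_llm_creator=custom_create)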