Skip to content

Commit

Permalink
[misc] update utils to support comparing multiple settings (vllm-proj…
Browse files Browse the repository at this point in the history
…ect#9140)

Signed-off-by: Sumit Dubey <sumit.dubey2@ibm.com>
  • Loading branch information
youkaichao authored and sumitd2 committed Nov 14, 2024
1 parent ea51af4 commit 4e60a33
Showing 1 changed file with 44 additions and 11 deletions.
55 changes: 44 additions & 11 deletions tests/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -310,14 +310,38 @@ def compare_two_settings(model: str,
env2: The second set of environment variables to pass to the API server.
"""

compare_all_settings(
model,
[arg1, arg2],
[env1, env2],
method=method,
max_wait_seconds=max_wait_seconds,
)


def compare_all_settings(model: str,
all_args: List[List[str]],
all_envs: List[Optional[Dict[str, str]]],
*,
method: Literal["generate", "encode"] = "generate",
max_wait_seconds: Optional[float] = None) -> None:
"""
Launch API server with several different sets of arguments/environments
and compare the results of the API calls with the first set of arguments.
Args:
model: The model to test.
all_args: A list of argument lists to pass to the API server.
all_envs: A list of environment dictionaries to pass to the API server.
"""

trust_remote_code = False
for args in (arg1, arg2):
for args in all_args:
if "--trust-remote-code" in args:
trust_remote_code = True
break

tokenizer_mode = "auto"
for args in (arg1, arg2):
for args in all_args:
if "--tokenizer-mode" in args:
tokenizer_mode = args[args.index("--tokenizer-mode") + 1]
break
Expand All @@ -330,8 +354,10 @@ def compare_two_settings(model: str,

prompt = "Hello, my name is"
token_ids = tokenizer(prompt).input_ids
results = []
for args, env in ((arg1, env1), (arg2, env2)):
ref_results: List = []
for i, (args, env) in enumerate(zip(all_args, all_envs)):
compare_results: List = []
results = ref_results if i == 0 else compare_results
with RemoteOpenAIServer(model,
args,
env_dict=env,
Expand All @@ -355,13 +381,20 @@ def compare_two_settings(model: str,
else:
assert_never(method)

n = len(results) // 2
arg1_results = results[:n]
arg2_results = results[n:]
for arg1_result, arg2_result in zip(arg1_results, arg2_results):
assert arg1_result == arg2_result, (
f"Results for {model=} are not the same with {arg1=} and {arg2=}. "
f"{arg1_result=} != {arg2_result=}")
if i > 0:
# if any setting fails, raise an error early
ref_args = all_args[0]
ref_envs = all_envs[0]
compare_args = all_args[i]
compare_envs = all_envs[i]
for ref_result, compare_result in zip(ref_results,
compare_results):
assert ref_result == compare_result, (
f"Results for {model=} are not the same.\n"
f"{ref_args=} {ref_envs=}\n"
f"{compare_args=} {compare_envs=}\n"
f"{ref_result=}\n"
f"{compare_result=}\n")


def init_test_distributed_environment(
Expand Down

0 comments on commit 4e60a33

Please sign in to comment.