Skip to content

Commit

Permalink
Add temporary validation for seed + PP
Browse files Browse the repository at this point in the history
  • Loading branch information
njhill authored and dtrifiro committed Jul 22, 2024
1 parent a87d800 commit 0f7df61
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 2 deletions.
9 changes: 8 additions & 1 deletion src/vllm_tgis_adapter/grpc/grpc_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,6 +186,9 @@ def __init__(
self.skip_special_tokens = not args.output_special_tokens
self.default_include_stop_seqs = args.default_include_stop_seqs

# Temporary to validate parameters currently broken with PP
self.pipeline_parallel = args.pipeline_parallel_size > 1

# Backwards compatibility for TGIS: PREFIX_STORE_PATH
adapter_cache_path = args.adapter_cache or args.prefix_store_path
self.adapter_store = (
Expand Down Expand Up @@ -504,7 +507,11 @@ async def _validate_and_convert_params(
"""Return (sampling_params, deadline)."""
# First run TGIS validation to raise errors that match the TGIS api
try:
validate_params(params, self.max_max_new_tokens)
validate_params(
params,
self.max_max_new_tokens,
pipeline_parallel=self.pipeline_parallel,
)
except ValueError as tgis_validation_error:
service_metrics.count_request_failure(FailureReasonLabel.VALIDATION)
await context.abort(StatusCode.INVALID_ARGUMENT, str(tgis_validation_error))
Expand Down
12 changes: 11 additions & 1 deletion src/vllm_tgis_adapter/grpc/validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,9 +78,11 @@ def validate_input(
)


def validate_params( # noqa: C901
def validate_params( # noqa: C901 PLR0912
params: Parameters,
max_max_new_tokens: int,
*,
pipeline_parallel: bool = False,
) -> None:
"""Raise ValueError from the TGISValidationError enum if Parameters is invalid."""
# TODO: split into checks that are covered by vllm.SamplingParams
Expand All @@ -91,6 +93,14 @@ def validate_params( # noqa: C901
stopping = params.stopping
decoding = params.decoding

# Temporary validation -- this will be removed once we fix use of seed
# with pipeline parallel
if pipeline_parallel and sampling.HasField("seed") and sampling.seed:
raise ValueError(
"seed parameter is currently not supported for "
"pipeline parallel model deployments"
)

# Decoding parameter checks
if decoding.HasField("length_penalty") and not (
1.0 <= decoding.length_penalty.decay_factor <= 10.0
Expand Down

0 comments on commit 0f7df61

Please sign in to comment.