From 0f7df61caa47f3140f8633b8191ca683f0978395 Mon Sep 17 00:00:00 2001 From: Nick Hill Date: Fri, 19 Jul 2024 11:47:28 -0700 Subject: [PATCH] Add temporary validation for seed + PP --- src/vllm_tgis_adapter/grpc/grpc_server.py | 9 ++++++++- src/vllm_tgis_adapter/grpc/validation.py | 12 +++++++++++- 2 files changed, 19 insertions(+), 2 deletions(-) diff --git a/src/vllm_tgis_adapter/grpc/grpc_server.py b/src/vllm_tgis_adapter/grpc/grpc_server.py index ccacc71..4d41bcd 100644 --- a/src/vllm_tgis_adapter/grpc/grpc_server.py +++ b/src/vllm_tgis_adapter/grpc/grpc_server.py @@ -186,6 +186,9 @@ def __init__( self.skip_special_tokens = not args.output_special_tokens self.default_include_stop_seqs = args.default_include_stop_seqs + # Temporary to validate parameters currently broken with PP + self.pipeline_parallel = args.pipeline_parallel_size > 1 + # Backwards compatibility for TGIS: PREFIX_STORE_PATH adapter_cache_path = args.adapter_cache or args.prefix_store_path self.adapter_store = ( @@ -504,7 +507,11 @@ async def _validate_and_convert_params( """Return (sampling_params, deadline).""" # First run TGIS validation to raise errors that match the TGIS api try: - validate_params(params, self.max_max_new_tokens) + validate_params( + params, + self.max_max_new_tokens, + pipeline_parallel=self.pipeline_parallel, + ) except ValueError as tgis_validation_error: service_metrics.count_request_failure(FailureReasonLabel.VALIDATION) await context.abort(StatusCode.INVALID_ARGUMENT, str(tgis_validation_error)) diff --git a/src/vllm_tgis_adapter/grpc/validation.py b/src/vllm_tgis_adapter/grpc/validation.py index 327c10f..14bd1f3 100644 --- a/src/vllm_tgis_adapter/grpc/validation.py +++ b/src/vllm_tgis_adapter/grpc/validation.py @@ -78,9 +78,11 @@ def validate_input( ) -def validate_params( # noqa: C901 +def validate_params( # noqa: C901 PLR0912 params: Parameters, max_max_new_tokens: int, + *, + pipeline_parallel: bool = False, ) -> None: """Raise ValueError from the TGISValidationError enum if Parameters is invalid.""" # TODO: split into checks that are covered by vllm.SamplingParams @@ -91,6 +93,14 @@ def validate_params( # noqa: C901 stopping = params.stopping decoding = params.decoding + # Temporary validation -- this will be removed once we fix use of seed + # with pipeline parallel + if pipeline_parallel and sampling.HasField("seed") and sampling.seed: + raise ValueError( + "seed parameter is currently not supported for " + "pipeline parallel model deployments" + ) + # Decoding parameter checks if decoding.HasField("length_penalty") and not ( 1.0 <= decoding.length_penalty.decay_factor <= 10.0