diff --git a/tests/system/vertexai/test_generative_models.py b/tests/system/vertexai/test_generative_models.py index f62dbecddf..2f05d659c8 100644 --- a/tests/system/vertexai/test_generative_models.py +++ b/tests/system/vertexai/test_generative_models.py @@ -257,6 +257,8 @@ def test_generate_content_with_parameters(self, api_endpoint_env_name): candidate_count=1, max_output_tokens=100, stop_sequences=["STOP!"], + response_logprobs=True, + logprobs=3, ), safety_settings={ generative_models.HarmCategory.HARM_CATEGORY_HATE_SPEECH: generative_models.HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE, diff --git a/tests/unit/vertexai/test_generative_models.py b/tests/unit/vertexai/test_generative_models.py index 1e49a6262f..b331d939bc 100644 --- a/tests/unit/vertexai/test_generative_models.py +++ b/tests/unit/vertexai/test_generative_models.py @@ -585,6 +585,8 @@ def test_generate_content(self, generative_models: generative_models): stop_sequences=["\n\n\n"], presence_penalty=0.0, frequency_penalty=0.0, + logprobs=5, + response_logprobs=True, ), safety_settings=[ generative_models.SafetySetting( diff --git a/vertexai/generative_models/_generative_models.py b/vertexai/generative_models/_generative_models.py index bdd24c3847..6b348ce72d 100644 --- a/vertexai/generative_models/_generative_models.py +++ b/vertexai/generative_models/_generative_models.py @@ -1577,6 +1577,8 @@ def __init__( response_schema: Optional[Dict[str, Any]] = None, seed: Optional[int] = None, routing_config: Optional["RoutingConfig"] = None, + logprobs: Optional[int] = None, + response_logprobs: Optional[bool] = None, ): r"""Constructs a GenerationConfig object. @@ -1603,6 +1605,8 @@ def __init__( response_schema: Output response schema of the genreated candidate text. Only valid when response_mime_type is application/json. routing_config: Model routing preference set in the request. + logprobs: Logit probabilities. + reponse_logprobs: If true, export the logprobs results in response. Usage: ``` @@ -1637,6 +1641,8 @@ def __init__( response_mime_type=response_mime_type, response_schema=raw_schema, seed=seed, + logprobs=logprobs, + response_logprobs=response_logprobs, ) if routing_config is not None: self._raw_generation_config.routing_config = ( @@ -2223,6 +2229,10 @@ def content(self) -> "Content": def avg_logprobs(self) -> float: return self._raw_candidate.avg_logprobs + @property + def logprobs_result(self) -> gapic_content_types.LogprobsResult: + return self._raw_candidate.logprobs_result + @property def finish_reason(self) -> gapic_content_types.Candidate.FinishReason: return self._raw_candidate.finish_reason