Skip to content

Commit

Permalink
feat: GenAI - Add support for logprobs and response_logprobs.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 677832804
  • Loading branch information
sasha-gitg authored and copybara-github committed Sep 23, 2024
1 parent 86fc215 commit 7acf0f7
Show file tree
Hide file tree
Showing 3 changed files with 14 additions and 0 deletions.
2 changes: 2 additions & 0 deletions tests/system/vertexai/test_generative_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -257,6 +257,8 @@ def test_generate_content_with_parameters(self, api_endpoint_env_name):
candidate_count=1,
max_output_tokens=100,
stop_sequences=["STOP!"],
response_logprobs=True,
logprobs=3,
),
safety_settings={
generative_models.HarmCategory.HARM_CATEGORY_HATE_SPEECH: generative_models.HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE,
Expand Down
2 changes: 2 additions & 0 deletions tests/unit/vertexai/test_generative_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -585,6 +585,8 @@ def test_generate_content(self, generative_models: generative_models):
stop_sequences=["\n\n\n"],
presence_penalty=0.0,
frequency_penalty=0.0,
logprobs=5,
response_logprobs=True,
),
safety_settings=[
generative_models.SafetySetting(
Expand Down
10 changes: 10 additions & 0 deletions vertexai/generative_models/_generative_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -1577,6 +1577,8 @@ def __init__(
response_schema: Optional[Dict[str, Any]] = None,
seed: Optional[int] = None,
routing_config: Optional["RoutingConfig"] = None,
logprobs: Optional[int] = None,
response_logprobs: Optional[bool] = None,
):
r"""Constructs a GenerationConfig object.
Expand All @@ -1603,6 +1605,8 @@ def __init__(
response_schema: Output response schema of the genreated candidate text. Only valid when
response_mime_type is application/json.
routing_config: Model routing preference set in the request.
logprobs: Logit probabilities.
reponse_logprobs: If true, export the logprobs results in response.
Usage:
```
Expand Down Expand Up @@ -1637,6 +1641,8 @@ def __init__(
response_mime_type=response_mime_type,
response_schema=raw_schema,
seed=seed,
logprobs=logprobs,
response_logprobs=response_logprobs,
)
if routing_config is not None:
self._raw_generation_config.routing_config = (
Expand Down Expand Up @@ -2223,6 +2229,10 @@ def content(self) -> "Content":
def avg_logprobs(self) -> float:
return self._raw_candidate.avg_logprobs

@property
def logprobs_result(self) -> gapic_content_types.LogprobsResult:
return self._raw_candidate.logprobs_result

@property
def finish_reason(self) -> gapic_content_types.Candidate.FinishReason:
return self._raw_candidate.finish_reason
Expand Down

0 comments on commit 7acf0f7

Please sign in to comment.