From a75e81c9e8bfe577468205fc0fc97366ff06f19d Mon Sep 17 00:00:00 2001 From: A Vertex SDK engineer Date: Wed, 10 Jan 2024 10:48:53 -0800 Subject: [PATCH] feat: enable inline context in grounding to TextGenerationModel predict. PiperOrigin-RevId: 597296033 --- vertexai/language_models/_language_models.py | 58 ++++++++++++++++++-- 1 file changed, 52 insertions(+), 6 deletions(-) diff --git a/vertexai/language_models/_language_models.py b/vertexai/language_models/_language_models.py index 2ae7a29e9c..ed96b20775 100644 --- a/vertexai/language_models/_language_models.py +++ b/vertexai/language_models/_language_models.py @@ -749,6 +749,27 @@ def _to_grounding_source_dict(self) -> Dict[str, Any]: } +@dataclasses.dataclass +class InlineContext(_GroundingSourceBase): + """InlineContext represents a grounding source using provided inline context. + Attributes: + inline_context: The content used as inline context. + """ + + inline_context: str + _type: str = dataclasses.field(default="INLINE", init=False, repr=False) + + def _to_grounding_source_dict(self) -> Dict[str, Any]: + return { + "sources": [ + { + "type": self._type, + } + ], + "inlineContext": self.inline_context, + } + + @dataclasses.dataclass class VertexAISearch(_GroundingSourceBase): """VertexAISearchDatastore represents a grounding source using Vertex AI Search datastore @@ -792,6 +813,7 @@ class GroundingSource: WebSearch = WebSearch VertexAISearch = VertexAISearch + InlineContext = InlineContext @dataclasses.dataclass @@ -976,7 +998,11 @@ def predict( stop_sequences: Optional[List[str]] = None, candidate_count: Optional[int] = None, grounding_source: Optional[ - Union[GroundingSource.WebSearch, GroundingSource.VertexAISearch] + Union[ + GroundingSource.WebSearch, + GroundingSource.VertexAISearch, + GroundingSource.InlineContext, + ] ] = None, logprobs: Optional[int] = None, presence_penalty: Optional[float] = None, @@ -1053,7 +1079,11 @@ async def predict_async( stop_sequences: Optional[List[str]] = None, candidate_count: Optional[int] = None, grounding_source: Optional[ - Union[GroundingSource.WebSearch, GroundingSource.VertexAISearch] + Union[ + GroundingSource.WebSearch, + GroundingSource.VertexAISearch, + GroundingSource.InlineContext, + ] ] = None, logprobs: Optional[int] = None, presence_penalty: Optional[float] = None, @@ -1284,7 +1314,11 @@ def _create_text_generation_prediction_request( stop_sequences: Optional[List[str]] = None, candidate_count: Optional[int] = None, grounding_source: Optional[ - Union[GroundingSource.WebSearch, GroundingSource.VertexAISearch] + Union[ + GroundingSource.WebSearch, + GroundingSource.VertexAISearch, + GroundingSource.InlineContext, + ] ] = None, logprobs: Optional[int] = None, presence_penalty: Optional[float] = None, @@ -2136,7 +2170,11 @@ def _prepare_request( stop_sequences: Optional[List[str]] = None, candidate_count: Optional[int] = None, grounding_source: Optional[ - Union[GroundingSource.WebSearch, GroundingSource.VertexAISearch] + Union[ + GroundingSource.WebSearch, + GroundingSource.VertexAISearch, + GroundingSource.InlineContext, + ] ] = None, ) -> _PredictionRequest: """Prepares a request for the language model. @@ -2289,7 +2327,11 @@ def send_message( stop_sequences: Optional[List[str]] = None, candidate_count: Optional[int] = None, grounding_source: Optional[ - Union[GroundingSource.WebSearch, GroundingSource.VertexAISearch] + Union[ + GroundingSource.WebSearch, + GroundingSource.VertexAISearch, + GroundingSource.InlineContext, + ] ] = None, ) -> "MultiCandidateTextGenerationResponse": """Sends message to the language model and gets a response. @@ -2352,7 +2394,11 @@ async def send_message_async( stop_sequences: Optional[List[str]] = None, candidate_count: Optional[int] = None, grounding_source: Optional[ - Union[GroundingSource.WebSearch, GroundingSource.VertexAISearch] + Union[ + GroundingSource.WebSearch, + GroundingSource.VertexAISearch, + GroundingSource.InlineContext, + ] ] = None, ) -> "MultiCandidateTextGenerationResponse": """Asynchronously sends message to the language model and gets a response.