Skip to content

Commit

Permalink
chore: LLM - update disableAttribution field level to reflect the r…
Browse files Browse the repository at this point in the history
…ecent backend schema change.

PiperOrigin-RevId: 582884203
  • Loading branch information
vertex-sdk-bot authored and copybara-github committed Nov 16, 2023
1 parent 595b580 commit 3a8f22c
Show file tree
Hide file tree
Showing 2 changed files with 46 additions and 30 deletions.
49 changes: 26 additions & 23 deletions tests/unit/aiplatform/test_language_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -1584,15 +1584,18 @@ def test_text_generation_multiple_candidates_grounding(self):
"collections/default_collection/dataStores/test_datastore"
)
expected_grounding_sources = [
{"sources": [{"type": "WEB", "disableAttribution": False}]},
{
"sources": [{"type": "WEB"}],
"disableAttribution": False,
},
{
"sources": [
{
"type": "VERTEX_AI_SEARCH",
"vertexAiSearchDatastore": datastore_path,
"disableAttribution": False,
}
]
],
"disableAttribution": False,
},
]

Expand Down Expand Up @@ -1702,18 +1705,18 @@ async def test_text_generation_multiple_candidates_grounding_async(self):
"sources": [
{
"type": "WEB",
"disableAttribution": False,
}
]
],
"disableAttribution": False,
},
{
"sources": [
{
"type": "VERTEX_AI_SEARCH",
"vertexAiSearchDatastore": datastore_path,
"disableAttribution": False,
}
]
],
"disableAttribution": False,
},
]

Expand Down Expand Up @@ -2499,18 +2502,18 @@ def test_chat(self):
"sources": [
{
"type": "WEB",
"disableAttribution": False,
}
]
],
"disableAttribution": False,
},
{
"sources": [
{
"type": "VERTEX_AI_SEARCH",
"vertexAiSearchDatastore": datastore_path,
"disableAttribution": False,
}
]
],
"disableAttribution": False,
},
]
for test_grounding_source, expected_grounding_source in zip(
Expand Down Expand Up @@ -2552,18 +2555,18 @@ def test_chat(self):
"sources": [
{
"type": "WEB",
"disableAttribution": False,
}
]
],
"disableAttribution": False,
},
{
"sources": [
{
"type": "VERTEX_AI_SEARCH",
"vertexAiSearchDatastore": datastore_path,
"disableAttribution": False,
}
]
],
"disableAttribution": False,
},
]
for test_grounding_source, expected_grounding_source in zip(
Expand Down Expand Up @@ -2636,18 +2639,18 @@ async def test_chat_async(self):
"sources": [
{
"type": "WEB",
"disableAttribution": False,
}
]
],
"disableAttribution": False,
},
{
"sources": [
{
"type": "VERTEX_AI_SEARCH",
"vertexAiSearchDatastore": datastore_path,
"disableAttribution": False,
}
]
],
"disableAttribution": False,
},
]
for test_grounding_source, expected_grounding_source in zip(
Expand Down Expand Up @@ -2693,18 +2696,18 @@ async def test_chat_async(self):
"sources": [
{
"type": "WEB",
"disableAttribution": False,
}
]
],
"disableAttribution": False,
},
{
"sources": [
{
"type": "VERTEX_AI_SEARCH",
"vertexAiSearchDatastore": datastore_path,
"disableAttribution": False,
}
]
],
"disableAttribution": False,
},
]
for test_grounding_source, expected_grounding_source in zip(
Expand Down
27 changes: 20 additions & 7 deletions vertexai/language_models/_language_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -739,7 +739,14 @@ class WebSearch(_GroundingSourceBase):
_type: str = dataclasses.field(default="WEB", init=False, repr=False)

def _to_grounding_source_dict(self) -> Dict[str, Any]:
return {"type": self._type, "disableAttribution": self.disable_attribution}
return {
"sources": [
{
"type": self._type,
}
],
"disableAttribution": self.disable_attribution,
}


@dataclasses.dataclass
Expand Down Expand Up @@ -770,8 +777,12 @@ def _get_datastore_path(self) -> str:

def _to_grounding_source_dict(self) -> Dict[str, Any]:
return {
"type": self._type,
"vertexAiSearchDatastore": self._get_datastore_path(),
"sources": [
{
"type": self._type,
"vertexAiSearchDatastore": self._get_datastore_path(),
}
],
"disableAttribution": self.disable_attribution,
}

Expand Down Expand Up @@ -1206,8 +1217,9 @@ def _create_text_generation_prediction_request(
prediction_parameters["candidateCount"] = candidate_count

if grounding_source is not None:
sources = [grounding_source._to_grounding_source_dict()]
prediction_parameters["groundingConfig"] = {"sources": sources}
prediction_parameters[
"groundingConfig"
] = grounding_source._to_grounding_source_dict()

return _PredictionRequest(
instance=instance,
Expand Down Expand Up @@ -2044,8 +2056,9 @@ def _prepare_request(
prediction_parameters["candidateCount"] = candidate_count

if grounding_source is not None:
sources = [grounding_source._to_grounding_source_dict()]
prediction_parameters["groundingConfig"] = {"sources": sources}
prediction_parameters[
"groundingConfig"
] = grounding_source._to_grounding_source_dict()

message_structs = []
for past_message in self._message_history:
Expand Down

0 comments on commit 3a8f22c

Please sign in to comment.