Skip to content

Commit

Permalink
feat: Add support for Prediction dedicated endpoint. predict/rawPredi…
Browse files Browse the repository at this point in the history
…ct/streamRawPredict can use dedicated DNS to access the dedicated endpoint.

PiperOrigin-RevId: 667018843
  • Loading branch information
vertex-sdk-bot authored and copybara-github committed Aug 24, 2024
1 parent a0d4ff2 commit 3d68777
Show file tree
Hide file tree
Showing 2 changed files with 304 additions and 6 deletions.
142 changes: 136 additions & 6 deletions google/cloud/aiplatform/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -782,6 +782,7 @@ def create(
enable_request_response_logging=False,
request_response_logging_sampling_rate: Optional[float] = None,
request_response_logging_bq_destination_table: Optional[str] = None,
dedicated_endpoint_enabled=False,
) -> "Endpoint":
"""Creates a new endpoint.
Expand Down Expand Up @@ -849,6 +850,10 @@ def create(
request_response_logging_bq_destination_table (str):
Optional. The request response logging bigquery destination. If not set, will create a table with name:
``bq://{project_id}.logging_{endpoint_display_name}_{endpoint_id}.request_response_logging``.
dedicated_endpoint_enabled (bool):
Optional. If enabled, a dedicated DNS name will be created and your
traffic will be fully isolated from other customers' traffic and
latency will be reduced.
Returns:
endpoint (aiplatform.Endpoint):
Expand Down Expand Up @@ -893,6 +898,7 @@ def create(
create_request_timeout=create_request_timeout,
endpoint_id=endpoint_id,
predict_request_response_logging_config=predict_request_response_logging_config,
dedicated_endpoint_enabled=dedicated_endpoint_enabled,
)

@classmethod
Expand All @@ -918,6 +924,7 @@ def _create(
private_service_connect_config: Optional[
gca_service_networking.PrivateServiceConnectConfig
] = None,
dedicated_endpoint_enabled=False,
) -> "Endpoint":
"""Creates a new endpoint by calling the API client.
Expand Down Expand Up @@ -984,6 +991,10 @@ def _create(
private_service_connect_config (aiplatform.service_network.PrivateServiceConnectConfig):
If enabled, the endpoint can be accessible via [Private Service Connect](https://cloud.google.com/vpc/docs/private-service-connect).
Cannot be enabled when network is specified.
dedicated_endpoint_enabled (bool):
Optional. If enabled, a dedicated DNS name will be created and your
traffic will be fully isolated from other customers' traffic and
latency will be reduced.
Returns:
endpoint (aiplatform.Endpoint):
Expand All @@ -1002,6 +1013,7 @@ def _create(
network=network,
predict_request_response_logging_config=predict_request_response_logging_config,
private_service_connect_config=private_service_connect_config,
dedicated_endpoint_enabled=dedicated_endpoint_enabled,
)

operation_future = api_client.create_endpoint(
Expand Down Expand Up @@ -2167,9 +2179,18 @@ def predict(
parameters: Optional[Dict] = None,
timeout: Optional[float] = None,
use_raw_predict: Optional[bool] = False,
*,
use_dedicated_endpoint: Optional[bool] = False,
) -> Prediction:
"""Make a prediction against this Endpoint.
For dedicated endpoint, set use_dedicated_endpoint = True:
```
response = my_endpoint.predict(instances=[...],
use_dedicated_endpoint=True)
my_predictions = response.predictions
```
Args:
instances (List):
Required. The instances that are the input to the
Expand All @@ -2194,6 +2215,9 @@ def predict(
use_raw_predict (bool):
Optional. Default value is False. If set to True, the underlying prediction call will be made
against Endpoint.raw_predict().
use_dedicated_endpoint (bool):
Optional. Default value is False. If set to True, the underlying prediction call will be made
using the dedicated endpoint DNS.
Returns:
prediction (aiplatform.Prediction):
Expand All @@ -2204,6 +2228,7 @@ def predict(
raw_predict_response = self.raw_predict(
body=json.dumps({"instances": instances, "parameters": parameters}),
headers={"Content-Type": "application/json"},
use_dedicated_endpoint=use_dedicated_endpoint,
)
json_response = raw_predict_response.json()
return Prediction(
Expand All @@ -2219,6 +2244,51 @@ def predict(
_RAW_PREDICT_MODEL_VERSION_ID_KEY, None
),
)

if use_dedicated_endpoint:
self._sync_gca_resource_if_skipped()
if (
not self._gca_resource.dedicated_endpoint_enabled
or self._gca_resource.dedicated_endpoint_dns is None
):
raise ValueError(
"Dedicated endpoint is not enabled or DNS is empty."
"Please make sure endpoint has dedicated endpoint enabled"
"and model are ready before making a prediction."
)

if not self.authorized_session:
self.credentials._scopes = constants.base.DEFAULT_AUTHED_SCOPES
self.authorized_session = google_auth_requests.AuthorizedSession(
self.credentials
)

headers = {
"Content-Type": "application/json",
}

url = f"https://{self._gca_resource.dedicated_endpoint_dns}/v1/{self.resource_name}:predict"
response = self.authorized_session.post(
url=url,
data=json.dumps(
{
"instances": instances,
"parameters": parameters,
}
),
headers=headers,
)

prediction_response = json.loads(response.text)

return Prediction(
predictions=prediction_response.get("predictions"),
metadata=prediction_response.get("metadata"),
deployed_model_id=prediction_response.get("deployedModelId"),
model_resource_name=prediction_response.get("model"),
model_version_id=prediction_response.get("modelVersionId"),
)

else:
prediction_response = self._prediction_client.predict(
endpoint=self._gca_resource.name,
Expand Down Expand Up @@ -2307,7 +2377,11 @@ async def predict_async(
)

def raw_predict(
self, body: bytes, headers: Dict[str, str]
self,
body: bytes,
headers: Dict[str, str],
*,
use_dedicated_endpoint: Optional[bool] = False,
) -> requests.models.Response:
"""Makes a prediction request using arbitrary headers.
Expand All @@ -2317,6 +2391,12 @@ def raw_predict(
body = b'{"instances":[{"feat_1":val_1, "feat_2":val_2}]}'
headers = {'Content-Type':'application/json'}
)
# For dedicated endpoint:
response = my_endpoint.raw_predict(
body = b'{"instances":[{"feat_1":val_1, "feat_2":val_2}]}',
headers = {'Content-Type':'application/json'},
use_dedicated_endpoint=True,
)
status_code = response.status_code
results = json.dumps(response.text)
Expand All @@ -2325,6 +2405,9 @@ def raw_predict(
The body of the prediction request in bytes. This must not exceed 1.5 mb per request.
headers (Dict[str, str]):
The header of the request as a dictionary. There are no restrictions on the header.
use_dedicated_endpoint (bool):
Optional. Default value is False. If set to True, the underlying prediction call will be made
using the dedicated endpoint DNS.
Returns:
A requests.models.Response object containing the status code and prediction results.
Expand All @@ -2338,12 +2421,29 @@ def raw_predict(
if self.raw_predict_request_url is None:
self.raw_predict_request_url = f"https://{self.location}-{constants.base.API_BASE_PATH}/v1/projects/{self.project}/locations/{self.location}/endpoints/{self.name}:rawPredict"

return self.authorized_session.post(
url=self.raw_predict_request_url, data=body, headers=headers
)
url = self.raw_predict_request_url

if use_dedicated_endpoint:
self._sync_gca_resource_if_skipped()
if (
not self._gca_resource.dedicated_endpoint_enabled
or self._gca_resource.dedicated_endpoint_dns is None
):
raise ValueError(
"Dedicated endpoint is not enabled or DNS is empty."
"Please make sure endpoint has dedicated endpoint enabled"
"and model are ready before making a prediction."
)
url = f"https://{self._gca_resource.dedicated_endpoint_dns}/v1/{self.resource_name}:rawPredict"

return self.authorized_session.post(url=url, data=body, headers=headers)

def stream_raw_predict(
self, body: bytes, headers: Dict[str, str]
self,
body: bytes,
headers: Dict[str, str],
*,
use_dedicated_endpoint: Optional[bool] = False,
) -> Iterator[requests.models.Response]:
"""Makes a streaming prediction request using arbitrary headers.
Expand All @@ -2358,13 +2458,28 @@ def stream_raw_predict(
stream_result = json.dumps(response.text)
```
For dedicated endpoint:
```
my_endpoint = aiplatform.Endpoint(ENDPOINT_ID)
for stream_response in my_endpoint.stream_raw_predict(
body = b'{"instances":[{"feat_1":val_1, "feat_2":val_2}]}',
headers = {'Content-Type':'application/json'},
use_dedicated_endpoint=True,
):
status_code = stream_response.status_code
stream_result = json.dumps(stream_response.text)
```
Args:
body (bytes):
The body of the prediction request in bytes. This must not
exceed 10 mb per request.
headers (Dict[str, str]):
The header of the request as a dictionary. There are no
restrictions on the header.
use_dedicated_endpoint (bool):
Optional. Default value is False. If set to True, the underlying prediction call will be made
using the dedicated endpoint DNS.
Yields:
predictions (Iterator[requests.models.Response]):
Expand All @@ -2379,8 +2494,23 @@ def stream_raw_predict(
if self.stream_raw_predict_request_url is None:
self.stream_raw_predict_request_url = f"https://{self.location}-{constants.base.API_BASE_PATH}/v1/projects/{self.project}/locations/{self.location}/endpoints/{self.name}:streamRawPredict"

url = self.raw_predict_request_url

if use_dedicated_endpoint:
self._sync_gca_resource_if_skipped()
if (
not self._gca_resource.dedicated_endpoint_enabled
or self._gca_resource.dedicated_endpoint_dns is None
):
raise ValueError(
"Dedicated endpoint is not enabled or DNS is empty."
"Please make sure endpoint has dedicated endpoint enabled"
"and model are ready before making a prediction."
)
url = f"https://{self._gca_resource.dedicated_endpoint_dns}/v1/{self.resource_name}:streamRawPredict"

with self.authorized_session.post(
url=self.stream_raw_predict_request_url,
url=url,
data=body,
headers=headers,
stream=True,
Expand Down
Loading

0 comments on commit 3d68777

Please sign in to comment.