feat: PrivateEndpoint.stream_raw_predict
PiperOrigin-RevId: 669479012
vertex-sdk-bot authored and copybara-github committed Aug 30, 2024
1 parent efbcb54 commit 197f333
Showing 2 changed files with 184 additions and 1 deletion.
89 changes: 89 additions & 0 deletions google/cloud/aiplatform/models.py
@@ -3666,6 +3666,95 @@ def raw_predict(
            headers=headers_with_token,
        )

    def stream_raw_predict(
        self,
        body: bytes,
        headers: Dict[str, str],
        endpoint_override: Optional[str] = None,
    ) -> Iterator[bytes]:
        """Make a streaming prediction request using arbitrary headers.

        Example usage:
            my_endpoint = aiplatform.PrivateEndpoint(ENDPOINT_ID)
            # Prepare the request body
            request_body = json.dumps({...}).encode('utf-8')
            # Define the headers
            headers = {
                'Content-Type': 'application/json',
            }
            # Use stream_raw_predict to send the request and process the response
            for stream_response in my_endpoint.stream_raw_predict(
                body=request_body,
                headers=headers,
                endpoint_override="10.128.0.26"  # Replace with your actual endpoint
            ):
                stream_response_text = stream_response.decode('utf-8')

        Args:
            body (bytes):
                The body of the prediction request in bytes. This must not
                exceed 10 MB per request.
            headers (Dict[str, str]):
                The header of the request as a dictionary. There are no
                restrictions on the header.
            endpoint_override (Optional[str]):
                The Private Service Connect endpoint's IP address or DNS that
                points to the endpoint's service attachment.

        Yields:
            predictions (Iterator[bytes]):
                The streaming prediction results as lines of bytes.

        Raises:
            ValueError: If an endpoint override is not provided for a PSC-based
                endpoint.
            ValueError: If the endpoint override is invalid for a PSC-based
                endpoint.
        """
        self.wait()
        if self.network or not self.private_service_connect_config:
            raise ValueError(
                "PSA based private endpoint does not support streaming prediction."
            )

        if self.private_service_connect_config:
            if not endpoint_override:
                raise ValueError(
                    "Cannot make a predict request because endpoint override is "
                    "not provided. Please ensure an endpoint override is "
                    "provided."
                )
            if not self._validate_endpoint_override(endpoint_override):
                raise ValueError(
                    "Invalid endpoint override provided. Please only use IP "
                    "address or DNS."
                )
            # Refresh the cached credentials if the access token is no longer valid.
            if not self.credentials.valid:
                self.credentials.refresh(google_auth_requests.Request())

            token = self.credentials.token
            headers_with_token = dict(headers)
            headers_with_token["Authorization"] = f"Bearer {token}"

            # Create an AuthorizedSession on first use and reuse it afterwards.
            if not self.authorized_session:
                self.credentials._scopes = constants.base.DEFAULT_AUTHED_SCOPES
                self.authorized_session = google_auth_requests.AuthorizedSession(
                    self.credentials
                )

            # Stream the response and yield it line by line as raw bytes.
            url = f"https://{endpoint_override}/v1/projects/{self.project}/locations/{self.location}/endpoints/{self.name}:streamRawPredict"
            with self.authorized_session.post(
                url=url,
                data=body,
                headers=headers_with_token,
                stream=True,
                verify=False,
            ) as resp:
                for line in resp.iter_lines():
                    yield line

    def explain(self):
        raise NotImplementedError(
            f"{self.__class__.__name__} class does not support 'explain' as of now."
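For context, a minimal client-side usage sketch of the new method. The project, region, endpoint ID, override address, and payload below are placeholders rather than values from this commit, and it assumes the model server emits one JSON object per line, as the unit test in the next file simulates.

import json

from google.cloud import aiplatform

# Placeholder project/region/endpoint values; substitute your own.
aiplatform.init(project="my-project", location="us-central1")
psc_endpoint = aiplatform.PrivateEndpoint("1234567890")

request_body = json.dumps({"instances": [[1.0, 2.0, 3.0]]}).encode("utf-8")

for raw_line in psc_endpoint.stream_raw_predict(
    body=request_body,
    headers={"Content-Type": "application/json"},
    endpoint_override="10.128.0.26",  # PSC forwarding-rule IP or DNS (placeholder)
):
    if raw_line:  # iter_lines() may yield empty keep-alive lines
        chunk = json.loads(raw_line.decode("utf-8"))
        print(chunk.get("predictions"))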
96 changes: 95 additions & 1 deletion tests/unit/aiplatform/test_endpoints.py
@@ -18,8 +18,8 @@
import copy
from datetime import datetime, timedelta
from importlib import reload
import requests
import json
import requests
from unittest import mock

from google.api_core import operation as ga_operation
@@ -920,6 +920,49 @@ def predict_private_endpoint_mock():
        yield predict_mock


@pytest.fixture
def stream_raw_predict_private_endpoint_mock():
    with mock.patch.object(
        google_auth_requests.AuthorizedSession, "post"
    ) as stream_raw_predict_mock:
        # Create a mock response object
        mock_response = mock.Mock(spec=requests.Response)

        # Configure the mock to be used as a context manager
        stream_raw_predict_mock.return_value.__enter__.return_value = mock_response

        # Set the status code to 200 (OK)
        mock_response.status_code = 200

        # Simulate streaming data with iter_lines
        mock_response.iter_lines = mock.Mock(
            return_value=iter(
                [
                    json.dumps(
                        {
                            "predictions": [1.0, 2.0, 3.0],
                            "metadata": {"key": "value"},
                            "deployedModelId": "model-id-123",
                            "model": "model-name",
                            "modelVersionId": "1",
                        }
                    ).encode("utf-8"),
                    json.dumps(
                        {
                            "predictions": [4.0, 5.0, 6.0],
                            "metadata": {"key": "value"},
                            "deployedModelId": "model-id-123",
                            "model": "model-name",
                            "modelVersionId": "1",
                        }
                    ).encode("utf-8"),
                ]
            )
        )

        yield stream_raw_predict_mock


@pytest.fixture
def health_check_private_endpoint_mock():
    with mock.patch.object(urllib3.PoolManager, "request") as health_check_mock:
@@ -3195,6 +3238,57 @@ def test_psc_predict(self, predict_private_endpoint_mock):
            },
        )

    @pytest.mark.usefixtures("get_psc_private_endpoint_mock")
    def test_psc_stream_raw_predict(self, stream_raw_predict_private_endpoint_mock):
        test_endpoint = models.PrivateEndpoint(
            project=_TEST_PROJECT, location=_TEST_LOCATION, endpoint_name=_TEST_ID
        )

        test_prediction_iterator = test_endpoint.stream_raw_predict(
            body='{"instances": [[1.0, 2.0, 3.0], [1.0, 3.0, 4.0]]}',
            headers={
                "Content-Type": "application/json",
                "Authorization": "Bearer None",
            },
            endpoint_override=_TEST_ENDPOINT_OVERRIDE,
        )

        test_prediction = list(test_prediction_iterator)

        stream_raw_predict_private_endpoint_mock.assert_called_once_with(
            url=f"https://{_TEST_ENDPOINT_OVERRIDE}/v1/projects/{_TEST_PROJECT}/locations/{_TEST_LOCATION}/endpoints/{_TEST_ID}:streamRawPredict",
            data='{"instances": [[1.0, 2.0, 3.0], [1.0, 3.0, 4.0]]}',
            headers={
                "Content-Type": "application/json",
                "Authorization": "Bearer None",
            },
            stream=True,
            verify=False,
        )

        # Validate the content of the returned predictions
        expected_predictions = [
            json.dumps(
                {
                    "predictions": [1.0, 2.0, 3.0],
                    "metadata": {"key": "value"},
                    "deployedModelId": "model-id-123",
                    "model": "model-name",
                    "modelVersionId": "1",
                }
            ).encode("utf-8"),
            json.dumps(
                {
                    "predictions": [4.0, 5.0, 6.0],
                    "metadata": {"key": "value"},
                    "deployedModelId": "model-id-123",
                    "model": "model-name",
                    "modelVersionId": "1",
                }
            ).encode("utf-8"),
        ]
        assert test_prediction == expected_predictions

    @pytest.mark.usefixtures("get_psc_private_endpoint_mock")
    def test_psc_predict_without_endpoint_override(self):
        test_endpoint = models.PrivateEndpoint(
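The diff above is truncated at the existing test_psc_predict_without_endpoint_override test. This commit does not add an error-path test for the new streaming method; a minimal sketch of one, assuming the same class-level fixtures and module constants used above (the test name is hypothetical), could look like:

    @pytest.mark.usefixtures("get_psc_private_endpoint_mock")
    def test_psc_stream_raw_predict_without_endpoint_override(self):
        # Hypothetical companion test, not included in this commit.
        test_endpoint = models.PrivateEndpoint(
            project=_TEST_PROJECT, location=_TEST_LOCATION, endpoint_name=_TEST_ID
        )

        with pytest.raises(ValueError):
            # stream_raw_predict is a generator, so the endpoint_override
            # validation only runs once the iterator is consumed.
            list(
                test_endpoint.stream_raw_predict(
                    body=b'{"instances": [[1.0, 2.0, 3.0]]}',
                    headers={"Content-Type": "application/json"},
                )
            )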
