Skip to content

Commit

Permalink
feat: Support api keys in initializer and create_client
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 661401962
  • Loading branch information
matthew29tang authored and copybara-github committed Aug 9, 2024
1 parent d352cec commit 7404f67
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 4 deletions.
28 changes: 25 additions & 3 deletions google/cloud/aiplatform/initializer.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ def _set_project_as_env_var_or_google_auth_default(self):
the project and credentials have already been set.
"""

if not self._project:
if not self._project and not self._api_key:
# Project is not set. Trying to get it from the environment.
# See https://github.com/googleapis/python-aiplatform/issues/852
# See https://github.com/googleapis/google-auth-library-python/issues/924
Expand Down Expand Up @@ -104,7 +104,7 @@ def _set_project_as_env_var_or_google_auth_default(self):
self._credentials = self._credentials or credentials
self._project = project

if not self._credentials:
if not self._credentials and not self._api_key:
credentials, _ = google.auth.default()
self._credentials = credentials

Expand All @@ -117,6 +117,7 @@ def __init__(self):
self._network = None
self._service_account = None
self._api_endpoint = None
self._api_key = None
self._api_transport = None
self._request_metadata = None
self._resource_type = None
Expand All @@ -137,6 +138,7 @@ def init(
network: Optional[str] = None,
service_account: Optional[str] = None,
api_endpoint: Optional[str] = None,
api_key: Optional[str] = None,
api_transport: Optional[str] = None,
request_metadata: Optional[Sequence[Tuple[str, str]]] = None,
):
Expand Down Expand Up @@ -197,6 +199,9 @@ def init(
api_endpoint (str):
Optional. The desired API endpoint,
e.g., us-central1-aiplatform.googleapis.com
api_key (str):
Optional. The API key to use for service calls.
NOTE: Not all services support API keys.
api_transport (str):
Optional. The transport method which is either 'grpc' or 'rest'.
NOTE: "rest" transport functionality is currently in a
Expand Down Expand Up @@ -252,6 +257,8 @@ def init(
self._service_account = service_account
if request_metadata is not None:
self._request_metadata = request_metadata
if api_key is not None:
self._api_key = api_key
self._resource_type = None

# Finally, perform secondary state updates
Expand Down Expand Up @@ -304,6 +311,11 @@ def api_endpoint(self) -> Optional[str]:
"""Default API endpoint, if provided."""
return self._api_endpoint

@property
def api_key(self) -> Optional[str]:
"""API Key, if provided."""
return self._api_key

@property
def project(self) -> str:
"""Default project."""
Expand All @@ -325,7 +337,7 @@ def project(self) -> str:
except GoogleAuthError as exc:
raise GoogleAuthError(project_not_found_exception_str) from exc

if not project_id:
if not project_id and not self.api_key:
raise ValueError(project_not_found_exception_str)

return project_id
Expand Down Expand Up @@ -403,6 +415,7 @@ def get_client_options(
location_override: Optional[str] = None,
prediction_client: bool = False,
api_base_path_override: Optional[str] = None,
api_key: Optional[str] = None,
api_path_override: Optional[str] = None,
) -> client_options.ClientOptions:
"""Creates GAPIC client_options using location and type.
Expand All @@ -414,6 +427,7 @@ def get_client_options(
Vertex AI.
prediction_client (str): Optional. flag to use a prediction endpoint.
api_base_path_override (str): Optional. Override default API base path.
api_key (str): Optional. API key to use for the client.
api_path_override (str): Optional. Override default api path.
Returns:
clients_options (google.api_core.client_options.ClientOptions):
Expand Down Expand Up @@ -447,6 +461,11 @@ def get_client_options(
else api_path_override
)

# Project/location take precedence over api_key
if api_key and not self._project:
return client_options.ClientOptions(
api_endpoint=api_endpoint, api_key=api_key
)
return client_options.ClientOptions(api_endpoint=api_endpoint)

def common_location_path(
Expand Down Expand Up @@ -479,6 +498,7 @@ def create_client(
location_override: Optional[str] = None,
prediction_client: bool = False,
api_base_path_override: Optional[str] = None,
api_key: Optional[str] = None,
api_path_override: Optional[str] = None,
appended_user_agent: Optional[List[str]] = None,
appended_gapic_version: Optional[str] = None,
Expand All @@ -493,6 +513,7 @@ def create_client(
Optional. Custom auth credentials. If not provided will use the current config.
location_override (str): Optional. location override.
prediction_client (str): Optional. flag to use a prediction endpoint.
api_key (str): Optional. API key to use for the client.
api_base_path_override (str): Optional. Override default api base path.
api_path_override (str): Optional. Override default api path.
appended_user_agent (List[str]):
Expand Down Expand Up @@ -539,6 +560,7 @@ def create_client(
"client_options": self.get_client_options(
location_override=location_override,
prediction_client=prediction_client,
api_key=api_key,
api_base_path_override=api_base_path_override,
api_path_override=api_path_override,
),
Expand Down
4 changes: 3 additions & 1 deletion vertexai/generative_models/_generative_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,8 @@ def _get_resource_name_from_model_name(
) -> str:
"""Returns the full resource name starting with projects/ given a model name."""
if model_name.startswith("publishers/"):
if not project:
return model_name
return f"projects/{project}/locations/{location}/{model_name}"
elif model_name.startswith("projects/"):
return model_name
Expand Down Expand Up @@ -337,7 +339,7 @@ def __init__(

location = aiplatform_utils.extract_project_and_location_from_parent(
prediction_resource_name
)["location"]
).get("location")

self._model_name = model_name
self._prediction_resource_name = prediction_resource_name
Expand Down

0 comments on commit 7404f67

Please sign in to comment.