Skip to content

Commit

Permalink
add metadata handling
Browse files Browse the repository at this point in the history
  • Loading branch information
MarkDaoust committed Oct 3, 2023
1 parent 2a33920 commit 41c2b62
Showing 1 changed file with 179 additions and 105 deletions.
284 changes: 179 additions & 105 deletions google/generativeai/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,10 @@
from __future__ import annotations

import os
from typing import cast, Optional, Union
import inspect
import dataclasses
import types
from typing import Any, cast, Sequence

import google.ai.generativelanguage as glm

Expand All @@ -26,15 +29,164 @@

from google.generativeai import version


USER_AGENT = "genai-py"

default_client_config = {}
default_discuss_client = None
default_discuss_async_client = None
default_model_client = None
default_text_client = None
default_operations_client = None

@dataclasses.dataclass
class _ClientManager:
client_config: dict[str, Any] = dataclasses.field(default_factory=dict)
metadata: Sequence[tuple[str, str]] = ()
discuss_client: glm.DiscussServiceClient | None = None
discuss_async_client: glm.DiscussServiceAsyncClient | None = None
model_client: glm.ModelServiceClient | None = None
text_client: glm.TextServiceClient | None = None
operations_client = None

def configure(
self,
*,
api_key: str | None = None,
credentials: ga_credentials.Credentials | dict | None = None,
# The user can pass a string to choose `rest` or `grpc` or 'grpc_asyncio'.
# See `_transport_registry` in `DiscussServiceClientMeta`.
# Since the transport classes align with the client classes it wouldn't make
# sense to accept a `Transport` object here even though the client classes can.
# We could accept a dict since all the `Transport` classes take the same args,
# but that seems rare. Users that need it can just switch to the low level API.
transport: str | None = None,
client_options: client_options_lib.ClientOptions | dict | None = None,
client_info: gapic_v1.client_info.ClientInfo | None = None,
default_metadata: Sequence[tuple[str, str]] = (),
):
"""Captures default client configuration.
If no API key has been provided (either directly, or on `client_options`) and the
`GOOGLE_API_KEY` environment variable is set, it will be used as the API key.
Args:
Refer to `glm.DiscussServiceClient`, and `glm.ModelsServiceClient` for details on additional arguments.
transport: A string, one of: [`rest`, `grpc`, `grpc_asyncio`].
api_key: The API-Key to use when creating the default clients (each service uses
a separate client). This is a shortcut for `client_options={"api_key": api_key}`.
If omitted, and the `GOOGLE_API_KEY` environment variable is set, it will be
used.
default_metadata: Default (key, value) metadata pairs to send with every request.
"""
if isinstance(client_options, dict):
client_options = client_options_lib.from_dict(client_options)
if client_options is None:
client_options = client_options_lib.ClientOptions()
client_options = cast(client_options_lib.ClientOptions, client_options)
had_api_key_value = getattr(client_options, "api_key", None)

if had_api_key_value:
if api_key is not None:
raise ValueError(
"You can't set both `api_key` and `client_options['api_key']`."
)
else:
if api_key is None:
# If no key is provided explicitly, attempt to load one from the
# environment.
api_key = os.getenv("GOOGLE_API_KEY")

client_options.api_key = api_key

user_agent = f"{USER_AGENT}/{version.__version__}"
if client_info:
# Be respectful of any existing agent setting.
if client_info.user_agent:
client_info.user_agent += f" {user_agent}"
else:
client_info.user_agent = user_agent
else:
client_info = gapic_v1.client_info.ClientInfo(user_agent=user_agent)

client_config = {
"credentials": credentials,
"transport": transport,
"client_options": client_options,
"client_info": client_info,
}

client_config = {
key: value for key, value in client_config.items() if value is not None
}

self.client_config = client_config
self.default_metadata = default_metadata
self.discuss_client = None
self.text_client = None
self.model_client = None
self.operations_client = None

def make_client(self, cls):
# Attempt to configure using defaults.
if self.client_config is None:
configure()

client = cls(**self.client_config)

if not self.default_metadata:
return client

def keep(name, f):
if name.startswith("_"):
return False
if not isinstance(f, types.FunctionType):
return False
if isinstance(f, classmethod):
return False
if isinstance(f, staticmethod):
False

return True

def add_default_metadata_wrapper(f):
def call(*args, metadata=(), **kwargs):
metadata = list(metadata) + list(self.default_metadata)
return f(*args, **kwargs, metadata=metadata)

return call

for name, value in cls.__dict__.items():
if not keep(name, value):
continue
f = getattr(client, name)
f = add_default_metadata_wrapper(f)
setattr(client, name, f)

return client

def get_default_discuss_client(self) -> glm.DiscussServiceClient:
if self.discuss_client is None:
self.discuss_client = self.make_client(glm.DiscussServiceClient)
return self.discuss_client

def get_default_text_client(self) -> glm.TextServiceClient:
if self.text_client is None:
self.text_client = self.make_client(glm.TextServiceClient)
return self.text_client

def get_default_discuss_async_client(self) -> glm.DiscussServiceAsyncClient:
if self.discuss_async_client is None:
self.discuss_async_client = self.make_client(glm.DiscussServiceAsyncClient)
return self.discuss_async_client

def get_default_model_client(self) -> glm.ModelServiceClient:
if self.model_client is None:
self.model_client = self.make_client(glm.ModelServiceClient)
return self.model_client

def get_default_operations_client(self) -> operations_v1.OperationsClient:
if self.operations_client is None:
self.model_client = get_default_model_client()
self.operations_client = model_client._transport.operations_client

return self.operations_client


_client_manager = _ClientManager()


def configure(
Expand All @@ -50,6 +202,7 @@ def configure(
transport: str | None = None,
client_options: client_options_lib.ClientOptions | dict | None = None,
client_info: gapic_v1.client_info.ClientInfo | None = None,
default_metadata: Sequence[tuple[str, str]] = (),
):
"""Captures default client configuration.
Expand All @@ -58,117 +211,38 @@ def configure(
Args:
Refer to `glm.DiscussServiceClient`, and `glm.ModelsServiceClient` for details on additional arguments.
transport: A string, one of: [`rest`, `grpc`, `grpc_asyncio`].
api_key: The API-Key to use when creating the default clients (each service uses
a separate client). This is a shortcut for `client_options={"api_key": api_key}`.
If omitted, and the `GOOGLE_API_KEY` environment variable is set, it will be
used.
default_metadata: Default `(key, value)` metadata pairs to send with every request.
"""
global default_client_config
global default_discuss_client
global default_model_client
global default_text_client
global default_operations_client

if isinstance(client_options, dict):
client_options = client_options_lib.from_dict(client_options)
if client_options is None:
client_options = client_options_lib.ClientOptions()
client_options = cast(client_options_lib.ClientOptions, client_options)
had_api_key_value = getattr(client_options, "api_key", None)

if had_api_key_value:
if api_key is not None:
raise ValueError(
"You can't set both `api_key` and `client_options['api_key']`."
)
else:
if api_key is None:
# If no key is provided explicitly, attempt to load one from the
# environment.
api_key = os.getenv("GOOGLE_API_KEY")

client_options.api_key = api_key

user_agent = f"{USER_AGENT}/{version.__version__}"
if client_info:
# Be respectful of any existing agent setting.
if client_info.user_agent:
client_info.user_agent += f" {user_agent}"
else:
client_info.user_agent = user_agent
else:
client_info = gapic_v1.client_info.ClientInfo(user_agent=user_agent)

new_default_client_config = {
"credentials": credentials,
"transport": transport,
"client_options": client_options,
"client_info": client_info,
}

new_default_client_config = {
key: value
for key, value in new_default_client_config.items()
if value is not None
}

default_client_config = new_default_client_config
default_discuss_client = None
default_text_client = None
default_model_client = None
default_operations_client = None
return _client_manager.configure(
api_key=api_key,
credentials=credentials,
transport=transport,
client_options=client_options,
client_info=client_info,
default_metadata=default_metadata,
)


def get_default_discuss_client() -> glm.DiscussServiceClient:
global default_discuss_client
if default_discuss_client is None:
# Attempt to configure using defaults.
if not default_client_config:
configure()
default_discuss_client = glm.DiscussServiceClient(**default_client_config)

return default_discuss_client
return _client_manager.get_default_discuss_client()


def get_default_text_client() -> glm.TextServiceClient:
global default_text_client
if default_text_client is None:
# Attempt to configure using defaults.
if not default_client_config:
configure()
default_text_client = glm.TextServiceClient(**default_client_config)

return default_text_client


def get_default_discuss_async_client() -> glm.DiscussServiceAsyncClient:
global default_discuss_async_client
if default_discuss_async_client is None:
# Attempt to configure using defaults.
if not default_client_config:
configure()
default_discuss_async_client = glm.DiscussServiceAsyncClient(
**default_client_config
)

return default_discuss_async_client
return _client_manager.get_default_discuss_client()


def get_default_model_client() -> glm.ModelServiceClient:
global default_model_client
if default_model_client is None:
# Attempt to configure using defaults.
if not default_client_config:
configure()
default_model_client = glm.ModelServiceClient(**default_client_config)
def get_default_operations_client() -> operations_v1.OperationsClient:
return _client_manager.get_default_operations_client()

return default_model_client

def get_default_discuss_async_client() -> glm.DiscussServiceAsyncClient:
return _client_manager.get_default_discuss_async_client()

def get_default_operations_client() -> operations_v1.OperationsClient:
global default_operations_client
if default_operations_client is None:
model_client = get_default_model_client()
default_operations_client = model_client._transport.operations_client

return default_operations_client
def get_default_model_client() -> glm.ModelServiceAsyncClient:
return _client_manager.get_default_model_client()

0 comments on commit 41c2b62

Please sign in to comment.