Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix bugs, improve code clarity, and enhance overall reliability across several files. #339

Merged
merged 12 commits into from
May 17, 2024
30 changes: 27 additions & 3 deletions CONTRIBUTING.md
Original file line number Diff line number Diff line change
Expand Up @@ -62,17 +62,41 @@ This "editable" mode lets you edit the source without needing to reinstall the p

### Testing

Use the built-in `unittest` package:
To ensure the integrity of the codebase, we have a suite of tests located in the `generative-ai-python/tests` directory.

You can run all these tests using Python's built-in `unittest` module or the `pytest` library.

For `unittest`, open a terminal and navigate to the root directory of the project. Then, execute the following command:

```
python -m unittest discover -s tests

# or more simply
python -m unittest
```

Alternatively, if you prefer using `pytest`, you can install it using pip:

```
pip install pytest
```

Then, run the tests with the following command:

```
pytest tests

# or more simply
pytest
```


Or to debug, use:

```commandline
pip install nose2

nose2 --debugger
```

### Type checking

Expand Down
6 changes: 3 additions & 3 deletions google/generativeai/answer.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ def _make_grounding_passages(source: GroundingPassagesOptions) -> glm.GroundingP

if not isinstance(source, Iterable):
raise TypeError(
f"`source` must be a valid `GroundingPassagesOptions` type object got a: `{type(source)}`."
f"The 'source' argument must be an instance of 'GroundingPassagesOptions', but got a '{type(source).__name__}' object instead."
)

passages = []
Expand Down Expand Up @@ -182,7 +182,7 @@ def _make_generate_answer_request(
temperature: float | None = None,
) -> glm.GenerateAnswerRequest:
"""
Calls the API to generate a grounded answer from the model.
constructs a glm.GenerateAnswerRequest object by organizing the input parameters for the API call to generate a grounded answer from the model.

Args:
model: Name of the model used to generate the grounded response.
Expand Down Expand Up @@ -219,7 +219,7 @@ def _make_generate_answer_request(
elif semantic_retriever is not None:
semantic_retriever = _make_semantic_retriever_config(semantic_retriever, contents[-1])
else:
TypeError(
raise TypeError(
f"The source must be either an `inline_passages` xor `semantic_retriever_config`, but both are `None`"
)

Expand Down
4 changes: 2 additions & 2 deletions google/generativeai/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -328,9 +328,9 @@ def get_default_retriever_async_client() -> glm.RetrieverAsyncClient:
return _client_manager.get_default_client("retriever_async")


def get_dafault_permission_client() -> glm.PermissionServiceClient:
def get_default_permission_client() -> glm.PermissionServiceClient:
return _client_manager.get_default_client("permission")


def get_dafault_permission_async_client() -> glm.PermissionServiceAsyncClient:
def get_default_permission_async_client() -> glm.PermissionServiceAsyncClient:
return _client_manager.get_default_client("permission_async")
10 changes: 6 additions & 4 deletions google/generativeai/generative_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -389,7 +389,7 @@ def start_chat(
>>> response = chat.send_message("Hello?")

Arguments:
history: An iterable of `glm.Content` objects, or equvalents to initialize the session.
history: An iterable of `glm.Content` objects, or equivalents to initialize the session.
"""
if self._generation_config.get("candidate_count", 1) > 1:
raise ValueError("Can't chat with `candidate_count > 1`")
Expand All @@ -403,11 +403,13 @@ def start_chat(
class ChatSession:
"""Contains an ongoing conversation with the model.

>>> model = genai.GenerativeModel(model="gemini-pro")
>>> model = genai.GenerativeModel('models/gemini-pro')
>>> chat = model.start_chat()
>>> response = chat.send_message("Hello")
>>> print(response.text)
>>> response = chat.send_message(...)
>>> response = chat.send_message("Hello again")
>>> print(response.text)
>>> response = chat.send_message(...

This `ChatSession` object collects the messages sent and received, in its
`ChatSession.history` attribute.
Expand Down Expand Up @@ -446,7 +448,7 @@ def send_message(

Appends the request and response to the conversation history.

>>> model = genai.GenerativeModel(model="gemini-pro")
>>> model = genai.GenerativeModel('models/gemini-pro')
>>> chat = model.start_chat()
>>> response = chat.send_message("Hello")
>>> print(response.text)
Expand Down
24 changes: 13 additions & 11 deletions google/generativeai/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,29 +33,31 @@ def get_model(
client=None,
request_options: dict[str, Any] | None = None,
) -> model_types.Model | model_types.TunedModel:
"""Given a model name, fetch the `types.Model` or `types.TunedModel` object.
"""Given a model name, fetch the `types.Model`

```
import pprint
model = genai.get_tuned_model(model_name):
model = genai.get_model('models/gemini-pro')
pprint.pprint(model)
```

Args:
name: The name of the model to fetch.
name: The name of the model to fetch. Should start with `models/`
client: The client to use.
request_options: Options for the request.

Returns:
A `types.Model` or `types.TunedModel` object.
A `types.Model`
"""
name = model_types.make_model_name(name)
if name.startswith("models/"):
return get_base_model(name, client=client, request_options=request_options)
elif name.startswith("tunedModels/"):
return get_tuned_model(name, client=client, request_options=request_options)
else:
raise ValueError("Model names must start with `models/` or `tunedModels/`")
raise ValueError(
f"Model names must start with `models/` or `tunedModels/`. Received: {name}"
)


def get_base_model(
Expand All @@ -68,12 +70,12 @@ def get_base_model(

```
import pprint
model = genai.get_model('models/chat-bison-001'):
model = genai.get_base_model('models/chat-bison-001')
pprint.pprint(model)
```

Args:
name: The name of the model to fetch.
name: The name of the model to fetch. Should start with `models/`
client: The client to use.
request_options: Options for the request.

Expand All @@ -88,7 +90,7 @@ def get_base_model(

name = model_types.make_model_name(name)
if not name.startswith("models/"):
raise ValueError(f"Base model names must start with `models/`, got: {name}")
raise ValueError(f"Base model names must start with `models/`, received: {name}")

result = client.get_model(name=name, **request_options)
result = type(result).to_dict(result)
Expand All @@ -105,12 +107,12 @@ def get_tuned_model(

```
import pprint
model = genai.get_tuned_model('tunedModels/my-model-1234'):
model = genai.get_tuned_model('tunedModels/gemini-1.0-pro-001')
pprint.pprint(model)
```

Args:
name: The name of the model to fetch.
name: The name of the model to fetch. Should start with `tunedModels/`
client: The client to use.
request_options: Options for the request.

Expand All @@ -126,7 +128,7 @@ def get_tuned_model(
name = model_types.make_model_name(name)

if not name.startswith("tunedModels/"):
raise ValueError("Tuned model names must start with `tunedModels/`")
raise ValueError("Tuned model names must start with `tunedModels/` received: {name}")

result = client.get_tuned_model(name=name, **request_options)

Expand Down
28 changes: 14 additions & 14 deletions google/generativeai/types/permission_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,8 @@

from google.protobuf import field_mask_pb2

from google.generativeai.client import get_dafault_permission_client
from google.generativeai.client import get_dafault_permission_async_client
from google.generativeai.client import get_default_permission_client
from google.generativeai.client import get_default_permission_async_client
from google.generativeai.utils import flatten_update_paths
from google.generativeai import string_utils

Expand Down Expand Up @@ -107,7 +107,7 @@ def delete(
Delete permission (self).
"""
if client is None:
client = get_dafault_permission_client()
client = get_default_permission_client()
delete_request = glm.DeletePermissionRequest(name=self.name)
client.delete_permission(request=delete_request)

Expand All @@ -119,7 +119,7 @@ async def delete_async(
This is the async version of `Permission.delete`.
"""
if client is None:
client = get_dafault_permission_async_client()
client = get_default_permission_async_client()
delete_request = glm.DeletePermissionRequest(name=self.name)
await client.delete_permission(request=delete_request)

Expand All @@ -146,7 +146,7 @@ def update(
`Permission` object with specified updates.
"""
if client is None:
client = get_dafault_permission_client()
client = get_default_permission_client()

updates = flatten_update_paths(updates)
for update_path in updates:
Expand Down Expand Up @@ -176,7 +176,7 @@ async def update_async(
This is the async version of `Permission.update`.
"""
if client is None:
client = get_dafault_permission_async_client()
client = get_default_permission_async_client()

updates = flatten_update_paths(updates)
for update_path in updates:
Expand Down Expand Up @@ -224,7 +224,7 @@ def get(
Requested permission as an instance of `Permission`.
"""
if client is None:
client = get_dafault_permission_client()
client = get_default_permission_client()
get_perm_request = glm.GetPermissionRequest(name=name)
get_perm_response = client.get_permission(request=get_perm_request)
get_perm_response = type(get_perm_response).to_dict(get_perm_response)
Expand All @@ -240,7 +240,7 @@ async def get_async(
This is the async version of `Permission.get`.
"""
if client is None:
client = get_dafault_permission_async_client()
client = get_default_permission_async_client()
get_perm_request = glm.GetPermissionRequest(name=name)
get_perm_response = await client.get_permission(request=get_perm_request)
get_perm_response = type(get_perm_response).to_dict(get_perm_response)
Expand Down Expand Up @@ -313,7 +313,7 @@ def create(
ValueError: When email_address is not specified and grantee_type is not set to EVERYONE.
"""
if client is None:
client = get_dafault_permission_client()
client = get_default_permission_client()

request = self._make_create_permission_request(
role=role, grantee_type=grantee_type, email_address=email_address
Expand All @@ -333,7 +333,7 @@ async def create_async(
This is the async version of `PermissionAdapter.create_permission`.
"""
if client is None:
client = get_dafault_permission_async_client()
client = get_default_permission_async_client()

request = self._make_create_permission_request(
role=role, grantee_type=grantee_type, email_address=email_address
Expand All @@ -358,7 +358,7 @@ def list(
Paginated list of `Permission` objects.
"""
if client is None:
client = get_dafault_permission_client()
client = get_default_permission_client()

request = glm.ListPermissionsRequest(
parent=self.parent, page_size=page_size # pytype: disable=attribute-error
Expand All @@ -376,7 +376,7 @@ async def list_async(
This is the async version of `PermissionAdapter.list_permissions`.
"""
if client is None:
client = get_dafault_permission_async_client()
client = get_default_permission_async_client()

request = glm.ListPermissionsRequest(
parent=self.parent, page_size=page_size # pytype: disable=attribute-error
Expand All @@ -400,7 +400,7 @@ def transfer_ownership(
if self.parent.startswith("corpora"):
raise NotImplementedError("Can'/t transfer_ownership for a Corpus")
if client is None:
client = get_dafault_permission_client()
client = get_default_permission_client()
transfer_request = glm.TransferOwnershipRequest(
name=self.parent, email_address=email_address # pytype: disable=attribute-error
)
Expand All @@ -415,7 +415,7 @@ async def transfer_ownership_async(
if self.parent.startswith("corpora"):
raise NotImplementedError("Can'/t transfer_ownership for a Corpus")
if client is None:
client = get_dafault_permission_async_client()
client = get_default_permission_async_client()
transfer_request = glm.TransferOwnershipRequest(
name=self.parent, email_address=email_address # pytype: disable=attribute-error
)
Expand Down
38 changes: 32 additions & 6 deletions tests/notebook/text_model_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,21 +68,47 @@ def _generate_text(


class TextModelTestCase(absltest.TestCase):
def test_generate_text(self):
def test_generate_text_without_args(self):
model = TestModel()

result = model.call_model("prompt goes in")
self.assertEqual(result.text_results[0], "prompt goes in_1")
self.assertIsNone(result.text_results[1])
self.assertIsNone(result.text_results[2])
self.assertIsNone(result.text_results[3])

def test_generate_text_without_args_none_results(self):
model = TestModel()

result = model.call_model("prompt goes in")
self.assertEqual(result.text_results[1], "None")
self.assertEqual(result.text_results[2], "None")
self.assertEqual(result.text_results[3], "None")

def test_generate_text_with_args_first_result(self):
model = TestModel()
args = model_lib.ModelArguments(model="model_name", temperature=0.42, candidate_count=5)

result = model.call_model("prompt goes in", args)
self.assertEqual(result.text_results[0], "prompt goes in_1")

def test_generate_text_with_args_model_name(self):
model = TestModel()
args = model_lib.ModelArguments(model="model_name", temperature=0.42, candidate_count=5)

result = model.call_model("prompt goes in", args)
self.assertEqual(result.text_results[1], "model_name")
self.assertEqual(result.text_results[2], 0.42)
self.assertEqual(result.text_results[3], 5)

def test_generate_text_with_args_temperature(self):
model = TestModel()
args = model_lib.ModelArguments(model="model_name", temperature=0.42, candidate_count=5)
result = model.call_model("prompt goes in", args)

self.assertEqual(result.text_results[2], str(0.42))

def test_generate_text_with_args_candidate_count(self):
model = TestModel()
args = model_lib.ModelArguments(model="model_name", temperature=0.42, candidate_count=5)

result = model.call_model("prompt goes in", args)
self.assertEqual(result.text_results[3], str(5))

def test_retry(self):
model = TestModel()
Expand Down
Loading