Skip to content

Commit

Permalink
Remove asynctest dependency. (google-gemini#63)
Browse files Browse the repository at this point in the history
* Remove asynctest, fixes:b/278080256

* black .
  • Loading branch information
MarkDaoust authored and markmcd committed Oct 30, 2023
1 parent e32dc21 commit 96e07fe
Show file tree
Hide file tree
Showing 2 changed files with 47 additions and 63 deletions.
1 change: 0 additions & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,6 @@ def get_version():
extras_require = {
"dev": [
"absl-py",
"asynctest",
"black",
"nose2",
"pandas",
Expand Down
109 changes: 47 additions & 62 deletions tests/test_discuss_async.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,81 +16,66 @@
import sys
import unittest

if sys.version_info < (3, 11):
import asynctest
from asynctest import mock as async_mock

import google.ai.generativelanguage_v1beta3 as glm

from google.generativeai import discuss
from absl.testing import absltest
from absl.testing import parameterized

bases = (parameterized.TestCase,)

if sys.version_info < (3, 11):
bases = bases + (asynctest.TestCase,)

unittest.skipIf(
sys.version_info >= (3, 11), "asynctest is not suported on python 3.11+"
)


class AsyncTests(*bases):
if sys.version_info < (3, 11):
class AsyncTests(parameterized.TestCase, unittest.IsolatedAsyncioTestCase):
async def test_chat_async(self):
client = unittest.mock.AsyncMock()

async def test_chat_async(self):
client = async_mock.MagicMock()
observed_request = None

observed_request = None

async def fake_generate_message(
request: glm.GenerateMessageRequest,
) -> glm.GenerateMessageResponse:
nonlocal observed_request
observed_request = request
return glm.GenerateMessageResponse(
candidates=[
glm.Message(
author="1", content="Why did the chicken cross the road?"
)
]
)

client.generate_message = fake_generate_message
async def fake_generate_message(
request: glm.GenerateMessageRequest,
) -> glm.GenerateMessageResponse:
nonlocal observed_request
observed_request = request
return glm.GenerateMessageResponse(
candidates=[
glm.Message(
author="1", content="Why did the chicken cross the road?"
)
]
)

observed_response = await discuss.chat_async(
client.generate_message = fake_generate_message

observed_response = await discuss.chat_async(
model="models/bard",
context="Example Prompt",
examples=[["Example from human", "Example response from AI"]],
messages=["Tell me a joke"],
temperature=0.75,
candidate_count=1,
client=client,
)

self.assertEqual(
observed_request,
glm.GenerateMessageRequest(
model="models/bard",
context="Example Prompt",
examples=[["Example from human", "Example response from AI"]],
messages=["Tell me a joke"],
prompt=glm.MessagePrompt(
context="Example Prompt",
examples=[
glm.Example(
input=glm.Message(content="Example from human"),
output=glm.Message(content="Example response from AI"),
)
],
messages=[glm.Message(author="0", content="Tell me a joke")],
),
temperature=0.75,
candidate_count=1,
client=client,
)

self.assertEqual(
observed_request,
glm.GenerateMessageRequest(
model="models/bard",
prompt=glm.MessagePrompt(
context="Example Prompt",
examples=[
glm.Example(
input=glm.Message(content="Example from human"),
output=glm.Message(content="Example response from AI"),
)
],
messages=[glm.Message(author="0", content="Tell me a joke")],
),
temperature=0.75,
candidate_count=1,
),
)
self.assertEqual(
observed_response.candidates,
[{"author": "1", "content": "Why did the chicken cross the road?"}],
)
),
)
self.assertEqual(
observed_response.candidates,
[{"author": "1", "content": "Why did the chicken cross the road?"}],
)


if __name__ == "__main__":
Expand Down

0 comments on commit 96e07fe

Please sign in to comment.