Skip to content

Commit

Permalink
format
Browse files Browse the repository at this point in the history
Change-Id: I15fd5701dd5c4200461a32c968fa19e375403a7e
  • Loading branch information
MarkDaoust committed Sep 23, 2024
1 parent 6e00eed commit 907e5a2
Show file tree
Hide file tree
Showing 3 changed files with 20 additions and 12 deletions.
5 changes: 2 additions & 3 deletions google/generativeai/types/generation_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -361,11 +361,10 @@ def _join_prompt_feedbacks(

def _join_chunks(chunks: Iterable[protos.GenerateContentResponse]):
chunks = tuple(chunks)
if 'usage_metadata' in chunks[-1]:
if "usage_metadata" in chunks[-1]:
usage_metadata = chunks[-1].usage_metadata
else:
usage_metadata=None

usage_metadata = None

return protos.GenerateContentResponse(
candidates=_join_candidate_lists(c.candidates for c in chunks),
Expand Down
4 changes: 3 additions & 1 deletion google/generativeai/types/model_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,9 @@ def idecode_time(parent: dict["str", Any], name: str):

def decode_tuned_model(tuned_model: protos.TunedModel | dict["str", Any]) -> TunedModel:
if isinstance(tuned_model, protos.TunedModel):
tuned_model = type(tuned_model).to_dict(tuned_model, including_default_value_fields=False) # pytype: disable=attribute-error
tuned_model = type(tuned_model).to_dict(
tuned_model, including_default_value_fields=False
) # pytype: disable=attribute-error
tuned_model["state"] = to_tuned_model_state(tuned_model.pop("state", None))

base_model = tuned_model.pop("base_model", None)
Expand Down
23 changes: 15 additions & 8 deletions tests/test_generation.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ class Person(TypedDict):

class UnitTests(parameterized.TestCase):
maxDiff = None

@parameterized.named_parameters(
[
"protos.GenerationConfig",
Expand Down Expand Up @@ -473,7 +474,10 @@ def test_join_prompt_feedbacks(self):
def test_join_candidates(self):
candidate_lists = [[protos.Candidate(c) for c in cl] for cl in self.CANDIDATE_LISTS]
result = generation_types._join_candidate_lists(candidate_lists)
self.assertEqual(self.MERGED_CANDIDATES, [type(r).to_dict(r, including_default_value_fields=False) for r in result])
self.assertEqual(
self.MERGED_CANDIDATES,
[type(r).to_dict(r, including_default_value_fields=False) for r in result],
)

def test_join_chunks(self):
chunks = [protos.GenerateContentResponse(candidates=cl) for cl in self.CANDIDATE_LISTS]
Expand All @@ -485,7 +489,9 @@ def test_join_chunks(self):
],
)

chunks[-1].usage_metadata = protos.GenerateContentResponse.UsageMetadata(prompt_token_count=5)
chunks[-1].usage_metadata = protos.GenerateContentResponse.UsageMetadata(
prompt_token_count=5
)

result = generation_types._join_chunks(chunks)

Expand All @@ -502,15 +508,16 @@ def test_join_chunks(self):
}
],
},
"usage_metadata": {
"prompt_token_count": 5
}

"usage_metadata": {"prompt_token_count": 5},
},
)

expected = json.dumps(type(expected).to_dict(expected, including_default_value_fields=False), indent=4)
result = json.dumps(type(result).to_dict(result, including_default_value_fields=False), indent=4)
expected = json.dumps(
type(expected).to_dict(expected, including_default_value_fields=False), indent=4
)
result = json.dumps(
type(result).to_dict(result, including_default_value_fields=False), indent=4
)

self.assertEqual(expected, result)

Expand Down

0 comments on commit 907e5a2

Please sign in to comment.