Update dataset construction code to handle more input types. (google-gemini#79)

* Allow dicts and pd.DataFrames as tuning datasets
* Fix type hints
* Support CSV and JSON files, and CSV URLs
* Docs
* Add TODO
* Allow JSON URLs and streaming decoding of CSVs
MarkDaoust authored and markmcd committed Oct 30, 2023
1 parent 00e0d57 commit f6e425a
Showing 7 changed files with 207 additions and 22 deletions.
24 changes: 19 additions & 5 deletions google/generativeai/models.py
@@ -244,16 +244,18 @@ def create_tuned_model(
epoch_count: int | None = None,
batch_size: int | None = None,
learning_rate: float | None = None,
input_key: str = "text_input",
output_key: str = "output",
client: glm.ModelServiceClient | None = None,
) -> operations.CreateTunedModelOperation:
"""Launches a tuning job to create a TunedModel.
Since tuning a model can take significant time, this API doesn't wait for the tuning to complete.
Instead, it returns a `google.api_core.operation.Operation` object that lets you check on the
status of the tuning job, or wait for it to complete, and check the result.
After the job completes you can either find the resulting `TunedModel` object in
`Operation.result()` or `palm.list_tuned_models` or `palm.get_tuned_model(model_id)`.
```
my_id = "my-tuned-model-id"
@@ -275,6 +277,16 @@ def create_tuned_model(
* `glm.TuningExample`,
* {'text_input': text_input, 'output': output} dicts, or
* `(text_input, output)` tuples.
* A `Mapping` of `Iterable[str]` - use `input_key` and `output_key` to choose which
columns to use as the input/output.
* A csv file (read with `csv.DictReader` and handled as an `Iterable` of row dicts,
as above). This can be:
* A local path as a `str` or `pathlib.Path`.
* A URL for a csv file.
* The URL of a Google Sheets file.
* A JSON file - its contents will be handled as either an `Iterable` or a `Mapping`,
as above. This can be:
* A local path as a `str` or `pathlib.Path`.
id: The model identifier, used to refer to the model in the API
`tunedModels/{id}`. Must be unique.
display_name: A human-readable name for display.
@@ -308,7 +320,9 @@ def create_tuned_model(
else:
raise ValueError(f"Not understood: `{source_model=}`")

training_data = model_types.encode_tuning_data(
training_data, input_key=input_key, output_key=output_key
)

hyperparameters = glm.Hyperparameters(
epoch_count=epoch_count,
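For context (not part of the diff), here is a minimal usage sketch of the updated `create_tuned_model` call. The base model name, tuned-model id, column names, and API key below are illustrative assumptions, not values taken from this commit:

```python
import google.generativeai as palm
import pandas as pd

palm.configure(api_key="...")  # assumes a valid PaLM API key

# Column-oriented data with custom column names, selected via input_key/output_key.
frame = pd.DataFrame({"prompt": ["a", "b", "c"], "answer": ["1", "2", "3"]})

operation = palm.create_tuned_model(
    source_model="models/text-bison-001",  # illustrative tunable base model
    training_data=frame,                   # a dict, csv/json path, or URL also works now
    input_key="prompt",
    output_key="answer",
    id="my-tuned-model-id",
)

# Tuning runs server-side; block until it finishes and fetch the resulting TunedModel.
tuned_model = operation.result()
```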
110 changes: 94 additions & 16 deletions google/generativeai/types/model_types.py
@@ -15,10 +15,15 @@
"""Type definitions for the models service."""
from __future__ import annotations

from collections.abc import Mapping
import csv
import dataclasses
import datetime
import json
import pathlib
import re
from typing import Any, Iterable, TypedDict, Union
import urllib.request

import google.ai.generativelanguage as glm
from google.generativeai import string_utils
@@ -72,17 +77,18 @@ class Model:
"""A dataclass representation of a `glm.Model`.
Attributes:
name: The resource name of the `Model`. Format: `models/{model}` with a `{model}` naming
convention of: "{base_model_id}-{version}". For example: `models/chat-bison-001`.
base_model_id: The base name of the model. For example: `chat-bison`.
version: The major version number of the model. For example: `001`.
display_name: The human-readable name of the model. E.g. `"Chat Bison"`. The name can be up
to 128 characters long and can consist of any UTF-8 characters.
description: A short description of the model.
input_token_limit: Maximum number of input tokens allowed for this model.
output_token_limit: Maximum number of output tokens available for this model.
supported_generation_methods: Lists which methods are supported by the model. The method
names are defined as Pascal case strings, such as `generateMessage`, which correspond to
API methods.
"""

name: str
@@ -187,28 +193,100 @@ class TuningExampleDict(TypedDict):
output: str


TuningExampleOptions = Union[TuningExampleDict, glm.TuningExample, tuple[str, str], list[str]]

# TODO(markdaoust): gs:// URLS? File-type argument for files without extension?
TuningDataOptions = Union[
pathlib.Path,
str,
glm.Dataset,
Mapping[str, Iterable[str]],
Iterable[TuningExampleOptions],
]


def encode_tuning_data(
data: TuningDataOptions, input_key="text_input", output_key="output"
) -> glm.Dataset:
if isinstance(data, glm.Dataset):
return data

if isinstance(data, str):
# Strings are either URLs or system paths.
if re.match(r"^\w+://\S+$", data):
data = _normalize_url(data)
else:
# Normalize system paths to use pathlib
data = pathlib.Path(data)

if isinstance(data, (str, pathlib.Path)):
if isinstance(data, str):
f = urllib.request.urlopen(data)
# csv needs strings, json does not.
content = (line.decode("utf-8") for line in f)
else:
f = data.open("r")
content = f

if str(data).lower().endswith(".json"):
with f:
data = json.load(f)
else:
with f:
data = csv.DictReader(content)
return _convert_iterable(data, input_key, output_key)

if hasattr(data, "keys"):
return _convert_dict(data, input_key, output_key)
else:
return _convert_iterable(data, input_key, output_key)


def _normalize_url(url: str) -> str:
sheet_base = "https://docs.google.com/spreadsheets"
if url.startswith(sheet_base):
# Normalize google-sheets URLs to download the csv.
match = re.match(f"{sheet_base}/d/[^/]+", url)
if match is None:
raise ValueError("Incomplete Google Sheets URL: {data}")
url = f"{match.group(0)}/export?format=csv"
return url


def _convert_dict(data, input_key, output_key):
new_data = list()

try:
inputs = data[input_key]
except KeyError as e:
raise KeyError(f'input_key is "{input_key}", but data has keys: {sorted(data.keys())}') from e

try:
outputs = data[output_key]
except KeyError as e:
raise KeyError(f'output_key is "{output_key}", but data has keys: {sorted(data.keys())}') from e

for i, o in zip(inputs, outputs):
new_data.append(glm.TuningExample({"text_input": str(i), "output": str(o)}))
return glm.Dataset(examples=glm.TuningExamples(examples=new_data))


def _convert_iterable(data, input_key, output_key):
new_data = list()
for example in data:
example = encode_tuning_example(example, input_key, output_key)
new_data.append(example)
return glm.Dataset(examples=glm.TuningExamples(examples=new_data))


def encode_tuning_example(example: TuningExampleOptions, input_key, output_key):
if isinstance(example, glm.TuningExample):
return example
elif isinstance(example, (tuple, list)):
a, b = example
example = glm.TuningExample(text_input=a, output=b)
else:  # dict
example = glm.TuningExample(text_input=example[input_key], output=example[output_key])
return example


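For orientation (again, not part of the diff), a small sketch of how the different accepted spellings converge on the same `glm.Dataset`; it only assumes the package and its `google.ai.generativelanguage` dependency are importable, and the column names are made up:

```python
import google.ai.generativelanguage as glm
from google.generativeai.types import model_types

# Column-oriented data: input_key/output_key pick the columns.
columns = {"prompt": ["a", "b", "c"], "answer": ["1", "2", "3"]}
ds_from_dict = model_types.encode_tuning_data(columns, input_key="prompt", output_key="answer")

# Row-oriented data: tuples, dicts, and glm.TuningExample objects can be mixed.
rows = [
    ("a", "1"),
    {"text_input": "b", "output": "2"},
    glm.TuningExample(text_input="c", output="3"),
]
ds_from_rows = model_types.encode_tuning_data(rows)

# Both spellings decode to the same proto.
assert ds_from_dict == ds_from_rows

# Strings are treated as csv/json paths or URLs; a Google Sheets share link is first
# rewritten by _normalize_url to its ".../export?format=csv" form before download.
```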
4 changes: 4 additions & 0 deletions tests/test.csv
@@ -0,0 +1,4 @@
text_input,output
a,1
b,2
c,3
5 changes: 5 additions & 0 deletions tests/test1.json
@@ -0,0 +1,5 @@
[
{"text_input": "a", "output": "1"},
{"text_input": "b", "output": "2"},
{"text_input": "c", "output": "3"}
]
1 change: 1 addition & 0 deletions tests/test2.json
@@ -0,0 +1 @@
{"text_input": ["a", "b", "c"], "output": ["1", "2", "3"]}
5 changes: 5 additions & 0 deletions tests/test3.json
@@ -0,0 +1,5 @@
[
["a","1"],
["b","2"],
["c","3"]
]
80 changes: 79 additions & 1 deletion tests/test_models.py
@@ -15,6 +15,7 @@
import copy
import datetime
import dataclasses
import pathlib
import pytz
from typing import Any, Union
import unittest
@@ -29,7 +30,10 @@
from google.generativeai import models
from google.generativeai import client
from google.generativeai.types import model_types
from google.protobuf import field_mask_pb2

import pandas as pd

HERE = pathlib.Path(__file__).parent


class UnitTests(parameterized.TestCase):
@@ -385,6 +389,80 @@ def test_create_tuned_model_on_tuned_model(self, tuned_source):
"models/swim-fish-000",
)

@parameterized.named_parameters(
[
"glm",
glm.Dataset(
examples=glm.TuningExamples(
examples=[
{"text_input": "a", "output": "1"},
{"text_input": "b", "output": "2"},
{"text_input": "c", "output": "3"},
]
)
),
],
[
"list",
[
("a", "1"),
{"text_input": "b", "output": "2"},
glm.TuningExample({"text_input": "c", "output": "3"}),
],
],
["dict", {"text_input": ["a", "b", "c"], "output": ["1", "2", "3"]}],
[
"dict_custom_keys",
{"my_inputs": ["a", "b", "c"], "my_outputs": ["1", "2", "3"]},
"my_inputs",
"my_outputs",
],
[
"pd.DataFrame",
pd.DataFrame(
[
{"text_input": "a", "output": "1"},
{"text_input": "b", "output": "2"},
{"text_input": "c", "output": "3"},
]
),
],
["csv-path-string", str(HERE / "test.csv")],
["csv-path", HERE / "test.csv"],
["json-file-1", HERE / "test1.json"],
["json-file-2", HERE / "test2.json"],
["json-file-3", HERE / "test3.json"],
[
"json-url",
"https://storage.googleapis.com/generativeai-downloads/data/test1.json",
],
[
"csv-url",
"https://storage.googleapis.com/generativeai-downloads/data/test.csv",
],
[
"sheet-share",
"https://docs.google.com/spreadsheets/d/1OffcVSqN6X-RYdWLGccDF3KtnKoIpS7O_9cZbicKK4A/edit?usp=sharing",
],
[
"sheet-export-csv",
"https://docs.google.com/spreadsheets/d/1OffcVSqN6X-RYdWLGccDF3KtnKoIpS7O_9cZbicKK4A/export?format=csv",
],
)
def test_create_dataset(self, data, ik="text_input", ok="output"):
ds = model_types.encode_tuning_data(data, input_key=ik, output_key=ok)

expect = glm.Dataset(
examples=glm.TuningExamples(
examples=[
{"text_input": "a", "output": "1"},
{"text_input": "b", "output": "2"},
{"text_input": "c", "output": "3"},
]
)
)
self.assertEqual(expect, ds)


if __name__ == "__main__":
absltest.main()
