Skip to content

Commit

Permalink
add Field examples to model prototypes (#210)
Browse files Browse the repository at this point in the history
* add Field examples

* add test

* clean up test

* handle serializing prototypes + better checking

* handle pydantic v1

* handle json strings, tests
  • Loading branch information
isabelizimm committed Mar 27, 2024
1 parent e1ab9bc commit 3f13ed5
Show file tree
Hide file tree
Showing 7 changed files with 121 additions and 29 deletions.
16 changes: 11 additions & 5 deletions vetiver/pin_read_write.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from .vetiver_model import VetiverModel
from .meta import VetiverMeta
from .utils import inform
from .utils import inform, serialize_prototype
import warnings
import logging

Expand Down Expand Up @@ -72,10 +72,16 @@ def vetiver_pin_write(board, model: VetiverModel, versioned: bool = True):
"user": model.metadata.user,
"vetiver_meta": {
"required_pkgs": model.metadata.required_pkgs,
"prototype": None if not model.prototype else model.prototype().json(),
"python_version": None
if not model.metadata.python_version
else list(model.metadata.python_version),
"prototype": (
None
if not model.prototype
else serialize_prototype(model.prototype)
),
"python_version": (
None
if not model.metadata.python_version
else list(model.metadata.python_version)
),
},
},
versioned=versioned,
Expand Down
15 changes: 12 additions & 3 deletions vetiver/prototype.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import pandas as pd
import numpy as np
import pydantic
from pydantic import Field
from warnings import warn
from .types import create_prototype

Expand Down Expand Up @@ -159,7 +160,8 @@ def _item(value):
# if it's a numpy type, we have to take the Python type due to Pydantic

dict_data = {
f"{key}": (type(value.item()), _item(value)) for key, value in dict_data.items()
f"{key}": (type(value.item()), Field(..., example=_item(value)))
for key, value in dict_data.items()
}
prototype = create_prototype(**dict_data)
return prototype
Expand All @@ -182,7 +184,14 @@ def _(data: dict):
# automatically create for simple prototypes
try:
for key, value in data["properties"].items():
dict_data.update({key: (type(value["default"]), value["default"])})
dict_data.update(
{
key: (
type(value["example"]),
Field(..., example=value["example"]),
)
}
)
# error for complex objects
except KeyError:
raise InvalidPTypeError(
Expand Down Expand Up @@ -223,5 +232,5 @@ def _(data: NoneType):
def _to_field(data):
    """Turn a mapping of example values into pydantic field definitions.

    Each key of ``data`` is mapped to a ``(type, Field)`` tuple: the type is
    the Python type of the value, and the ``Field`` has no default (``...``)
    and carries the value as its ``example``.

    Parameters
    ----------
    data
        Mapping of field name to an example value.

    Returns
    -------
    dict
        Field name -> ``(type(value), Field(..., example=value))``.
        Presumably passed on to the prototype/model builder -- confirm
        against the callers.
    """
    # The value is attached as `example` only; the earlier
    # `(type(value), value)` entry was overwritten right away, so it is gone.
    return {
        key: (type(value), Field(..., example=value))
        for key, value in data.items()
    }
50 changes: 47 additions & 3 deletions vetiver/tests/test_build_vetiver_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,15 +41,36 @@ def test_vetiver_model_array_prototype():
description=None,
metadata=None,
)
try:
json_schema = v.prototype.model_json_schema()
expected = {
"properties": {
"0": {"example": 96, "title": "0", "type": "integer"},
"1": {"example": 11, "title": "1", "type": "integer"},
"2": {"example": 33, "title": "2", "type": "integer"},
},
"required": ["0", "1", "2"],
"title": "prototype",
"type": "object",
}
except AttributeError: # pydantic v1
json_schema = v.prototype.schema_json()
expected = '{\
"title": "prototype", \
"type": "object", \
"properties": {"0": {"title": "0", "example": 96, "type": "integer"}, \
"1": {"title": "1", "example": 11, "type": "integer"}, \
"2": {"title": "2", "example": 33, "type": "integer"}}, \
"required": ["0", "1", "2"]}'

assert v.model == model
assert issubclass(v.prototype, vetiver.Prototype)
# change to model_construct for pydantic v3
assert isinstance(v.prototype.construct(), pydantic.BaseModel)
assert v.prototype.construct().__dict__ == {"0": 96, "1": 11, "2": 33}
assert json_schema == expected


@pytest.mark.parametrize("prototype_data", [{"B": 96, "C": 0, "D": 0}, X_df])
@pytest.mark.parametrize("prototype_data", [{"B": 96, "C": 11, "D": 33}, X_df])
def test_vetiver_model_dict_like_prototype(prototype_data):
v = VetiverModel(
model=model,
Expand All @@ -63,7 +84,30 @@ def test_vetiver_model_dict_like_prototype(prototype_data):
assert v.model == model
# change to model_construct for pydantic v3
assert isinstance(v.prototype.construct(), pydantic.BaseModel)
assert v.prototype.construct().B == 96

try:
json_schema = v.prototype.model_json_schema()
expected = {
"properties": {
"B": {"example": 96, "title": "B", "type": "integer"},
"C": {"example": 11, "title": "C", "type": "integer"},
"D": {"example": 33, "title": "D", "type": "integer"},
},
"required": ["B", "C", "D"],
"title": "prototype",
"type": "object",
}
except AttributeError: # pydantic v1
json_schema = v.prototype.schema_json()
expected = '{\
"title": "prototype", \
"type": "object", \
"properties": {"B": {"title": "B", "example": 96, "type": "integer"}, \
"C": {"title": "C", "example": 11, "type": "integer"}, \
"D": {"title": "D", "example": 33, "type": "integer"}}, \
"required": ["B", "C", "D"]}'

assert json_schema == expected


@pytest.mark.parametrize("prototype_data", [MockPrototype(B=4, C=0, D=0), None])
Expand Down
32 changes: 16 additions & 16 deletions vetiver/tests/test_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
vetiver_create_prototype,
InvalidPTypeError,
vetiver_endpoint,
predict,
)
from pydantic import BaseModel, conint
from fastapi.testclient import TestClient
Expand All @@ -14,7 +15,7 @@


@pytest.fixture
def vetiver_model():
def model():
np.random.seed(500)
X, y = mock.get_mock_data()
model = mock.get_mock_model().fit(X, y)
Expand All @@ -29,14 +30,7 @@ def vetiver_model():


@pytest.fixture
def client(vetiver_model):
app = VetiverAPI(vetiver_model)

return TestClient(app.app)


@pytest.fixture
def complex_prototype_model():
def complex_prototype_client():
np.random.seed(500)

class CustomPrototype(BaseModel):
Expand Down Expand Up @@ -83,27 +77,28 @@ def test_get_metadata(client):
}


def test_get_prototype(client, vetiver_model):
def test_get_prototype(client, model):
response = client.get("/prototype")
assert response.status_code == 200, response.text
assert response.json() == {
"properties": {
"B": {"default": 55, "type": "integer"},
"C": {"default": 65, "type": "integer"},
"D": {"default": 17, "type": "integer"},
"B": {"example": 55, "type": "integer"},
"C": {"example": 65, "type": "integer"},
"D": {"example": 17, "type": "integer"},
},
"required": ["B", "C", "D"],
"title": "prototype",
"type": "object",
}

assert (
vetiver_model.prototype.construct().dict()
model.prototype.construct().dict()
== vetiver_create_prototype(response.json()).construct().dict()
)


def test_complex_prototype(complex_prototype_model):
response = complex_prototype_model.get("/prototype")
def test_complex_prototype(complex_prototype_client):
response = complex_prototype_client.get("/prototype")
assert response.status_code == 200, response.text
assert response.json() == {
"properties": {
Expand All @@ -120,6 +115,11 @@ def test_complex_prototype(complex_prototype_model):
vetiver_create_prototype(response.json())


def test_predict_wrong_input(client):
    """predict() should raise TypeError for a payload the API rejects."""
    # Presumably rejected because the record lacks the "D" field the model's
    # prototype requires -- confirm against the `client` fixture.
    incomplete_rows = [{"B": 43, "C": 43}]

    with pytest.raises(TypeError):
        predict(endpoint="/predict/", data=incomplete_rows, test_client=client)


def test_vetiver_endpoint():
url_raw = "http://127.0.0.1:8000/predict/"
url = vetiver_endpoint(url_raw)
Expand Down
6 changes: 5 additions & 1 deletion vetiver/tests/test_sklearn.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,11 @@ def model() -> VetiverModel:

@pytest.mark.parametrize(
"data,expected_length",
[([{"B": 0, "C": 0, "D": 0}], 1), (pd.Series(data=[0, 0, 0]), 1), (X, 100)],
[
([{"B": 0, "C": 0, "D": 0}], 1),
(pd.Series(data=[0, 0, 0], index=["B", "C", "D"]), 1),
(X, 100),
],
)
def test_predict_sklearn_ptype(data, expected_length, client):
response = predict(endpoint="/predict/", data=data, test_client=client)
Expand Down
17 changes: 16 additions & 1 deletion vetiver/tests/test_spacy.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,22 @@ def test_bad_prototype_shape(data, spacy_model):
def test_good_prototype_shape(data, spacy_model):
v = vetiver.VetiverModel(spacy_model, "animals", prototype_data=data)

assert v.prototype.construct().dict() == {"col": "1"}
try:
model_schema = v.prototype.model_json_schema()
expected = {
"properties": {
"col": {"example": "1", "title": "Col", "type": "string"},
},
"required": ["col"],
"title": "prototype",
"type": "object",
}
except AttributeError: # pydantic v1
model_schema = v.prototype.schema_json()
expected = '{"title": "prototype", "type": "object", "properties": \
{"col": {"title": "Col", "example": "1", "type": "string"}}, "required": ["col"]}'

assert model_schema == expected


def test_vetiver_predict_with_prototype(client: TestClient):
Expand Down
14 changes: 14 additions & 0 deletions vetiver/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import sys
import os
import subprocess
import json
from types import SimpleNamespace

no_notebook = False
Expand Down Expand Up @@ -53,3 +54,16 @@ def get_workbench_path(port):
return path
else:
return None


def serialize_prototype(prototype):
    """Serialize a prototype's per-field example values to a JSON string.

    Reads the "properties" section of the prototype's JSON schema and maps
    each field name to its "example" value, falling back to its "default"
    when no example is recorded.

    Parameters
    ----------
    prototype
        Object exposing ``model_json_schema()`` (returns a dict) or, failing
        that, ``schema_json()`` (returns a JSON string, pydantic v1).

    Returns
    -------
    str
        JSON object mapping field name to its example/default value
        (``null`` when the field has neither). ``"{}"`` when the schema has
        no "properties".
    """
    try:
        schema = prototype.model_json_schema()
    except AttributeError:  # pydantic v1
        schema = json.loads(prototype.schema_json())

    # A schema without fields may omit "properties" entirely.
    properties = schema.get("properties") or {}

    serialized_schema = {}
    for key, value in properties.items():
        example = value.get("example")
        # Compare against None rather than relying on truthiness, so that
        # legitimate falsy examples (0, "", False) are not discarded.
        serialized_schema[key] = (
            example if example is not None else value.get("default")
        )

    return json.dumps(serialized_schema)

0 comments on commit 3f13ed5

Please sign in to comment.