Skip to content

Commit

Permalink
More TF2OpenAPI Unit Tests (kubeflow#248)
Browse files Browse the repository at this point in the history
* Test with real models

* Resolve non-determinism in array matching

* Make OpenAPI restrictions more strict

* disallow additional properties in req body that aren't specified in
OpenAPI

* Update unit tests to be more restrictive

* Add generator unit test

* Add generator tests

* Add JSON extension to spec files for highlighting

* Move utils to kfserving utils

* Fix formatting

* Organize utils types into separate file
  • Loading branch information
jc2729 authored and k8s-ci-robot committed Jul 18, 2019
1 parent b1ae27b commit 3107cdf
Show file tree
Hide file tree
Showing 24 changed files with 764 additions and 25 deletions.
9 changes: 9 additions & 0 deletions pkg/utils/types.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
package utils

func Bool(b bool) *bool {
return &b
}

func UInt64(u uint64) *uint64 {
return &u
}
24 changes: 24 additions & 0 deletions pkg/utils/types_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
package utils

import (
"testing"
"github.com/google/go-cmp/cmp"
)

func TestBool(t *testing.T) {
input := true
expected := &input
result := Bool(input)
if diff := cmp.Diff(expected, result); diff != "" {
t.Errorf("Test %q unexpected result (-want +got): %v", t.Name(), diff)
}
}

func TestUInt64(t *testing.T) {
input := uint64(63)
expected := &input
result := UInt64(input)
if diff := cmp.Diff(expected, result); diff != "" {
t.Errorf("Test %q unexpected result (-want +got): %v", t.Name(), diff)
}
}
2 changes: 1 addition & 1 deletion tools/tf2openapi/Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ TF_PROTO_OUT := generated

# Run tests
test: generate
go test ./types/...
go test ./types/... ./generator/...

# Generate code
generate:
Expand Down
12 changes: 6 additions & 6 deletions tools/tf2openapi/generator/generate.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,8 @@ import (
)

const (
defaultSigDefKey = "serving_default"
defaultTag = "serve"
DefaultSigDefKey = "serving_default"
DefaultTag = "serve"
)

// Known error messages
Expand All @@ -33,10 +33,10 @@ type Builder struct {

func (b *Builder) Build() Generator {
if b.Generator.sigDefKey == "" {
b.SetSigDefKey(defaultSigDefKey)
b.SetSigDefKey(DefaultSigDefKey)
}
if len(b.Generator.metaGraphTags) == 0 {
b.SetMetaGraphTags([]string{defaultTag})
b.SetMetaGraphTags([]string{DefaultTag})
}
return b.Generator
}
Expand Down Expand Up @@ -68,10 +68,10 @@ func (g *Generator) GenerateOpenAPI(model *pb.SavedModel) (string, error) {
}
json, marshallingErr := spec.MarshalJSON()
if marshallingErr != nil {
return "", fmt.Errorf(UnmarshallableSpecError, marshallingErr.Error(), json)
panic(fmt.Errorf(UnmarshallableSpecError, marshallingErr.Error(), json))
}
if validationErr := validateOpenAPI(json); validationErr != nil {
return "", validationErr
panic(validationErr)
}
return string(json), nil
}
Expand Down
308 changes: 308 additions & 0 deletions tools/tf2openapi/generator/generate_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,308 @@
package generator

import (
"bytes"
"context"
"encoding/json"
"fmt"
"github.com/getkin/kin-openapi/openapi3"
"github.com/getkin/kin-openapi/openapi3filter"
"github.com/golang/protobuf/proto"
"github.com/kubeflow/kfserving/tools/tf2openapi/generated/framework"
pb "github.com/kubeflow/kfserving/tools/tf2openapi/generated/protobuf"
"github.com/kubeflow/kfserving/tools/tf2openapi/types"
"github.com/onsi/gomega"
"io/ioutil"
"net/http"
"path/filepath"
"testing"
)

func TestGeneratorBuilderSpecifiedFields(t *testing.T) {
g := gomega.NewGomegaWithT(t)
builder := &Builder{}
builder.SetName("model")
builder.SetVersion("1")
builder.SetMetaGraphTags([]string{"tag"})
builder.SetSigDefKey("sigDefKey")
generator := builder.Build()
expectedGenerator := Generator{
name: "model",
version: "1",
metaGraphTags: []string{"tag"},
sigDefKey: "sigDefKey",
}
g.Expect(generator).Should(gomega.Equal(expectedGenerator))
}

func TestGeneratorBuilderDefault(t *testing.T) {
g := gomega.NewGomegaWithT(t)
builder := &Builder{}
builder.SetName("model")
builder.SetVersion("1")
generator := builder.Build()
expectedGenerator := defaultGenerator()
g.Expect(generator).Should(gomega.Equal(expectedGenerator))
}

func TestGenerateOpenAPIConstructionErr(t *testing.T) {
g := gomega.NewGomegaWithT(t)
generator := defaultGenerator()
model := &pb.SavedModel{
MetaGraphs: []*pb.MetaGraphDef{
{
MetaInfoDef: &pb.MetaGraphDef_MetaInfoDef{
Tags: []string{
"serve",
},
},
SignatureDef: map[string]*pb.SignatureDef{
"sigDefKey": {
MethodName: "tensorflow/serving/predict",
Inputs: map[string]*pb.TensorInfo{
"inputTensorName": {
// Unsupported data type will err
Dtype: framework.DataType_DT_COMPLEX128,
TensorShape: &framework.TensorShapeProto{
Dim: []*framework.TensorShapeProto_Dim{
{Size: -1},
{Size: 3},
},
UnknownRank: false,
},
},
},
Outputs: map[string]*pb.TensorInfo{
"outputTensorName": {
Dtype: framework.DataType_DT_INT8,
TensorShape: &framework.TensorShapeProto{
Dim: []*framework.TensorShapeProto_Dim{
{Size: -1},
{Size: 3},
},
UnknownRank: false,
},
},
},
},
},
},
},
}
_, specErr := generator.GenerateOpenAPI(model)
expectedErr := fmt.Sprintf(types.UnsupportedDataTypeError, "inputTensorName", "DT_COMPLEX128")
g.Expect(specErr).To(gomega.MatchError(expectedErr))
}

func TestGenerateOpenAPISpecGenerationErr(t *testing.T) {
g := gomega.NewGomegaWithT(t)
generator := defaultGenerator()
model := &pb.SavedModel{
MetaGraphs: []*pb.MetaGraphDef{
{
MetaInfoDef: &pb.MetaGraphDef_MetaInfoDef{
Tags: []string{
"serve",
},
},
SignatureDef: map[string]*pb.SignatureDef{
"serving_default": {
MethodName: "tensorflow/serving/classify",
Inputs: map[string]*pb.TensorInfo{
"inputTensorName": {
Dtype: framework.DataType_DT_INT8,
TensorShape: &framework.TensorShapeProto{
Dim: []*framework.TensorShapeProto_Dim{
{Size: -1},
{Size: 3},
},
UnknownRank: false,
},
},
},
Outputs: map[string]*pb.TensorInfo{
"outputTensorName": {
Dtype: framework.DataType_DT_INT8,
TensorShape: &framework.TensorShapeProto{
Dim: []*framework.TensorShapeProto_Dim{
{Size: -1},
{Size: 3},
},
UnknownRank: false,
},
},
},
},
},
},
},
}
_, specErr := generator.GenerateOpenAPI(model)
expectedErr := fmt.Sprintf(SpecGenerationError, types.UnsupportedAPISchemaError)
g.Expect(specErr).To(gomega.MatchError(expectedErr))
}


func TestGenerateOpenAPIForRowFmtMultipleTensors(t *testing.T) {
// model src: gs://kfserving-samples/models/tensorflow/flowers
g := gomega.NewGomegaWithT(t)
model := model(t, "TestRowFmtMultipleTensors")
generator := defaultGenerator()
spec, specErr := generator.GenerateOpenAPI(model)
g.Expect(specErr).Should(gomega.BeNil())

swagger := &openapi3.Swagger{}
g.Expect(json.Unmarshal([]byte(spec), &swagger)).To(gomega.Succeed())

expectedSpec := string(openAPI(t, "TestRowFmtMultipleTensors"))
expectedSwagger := &openapi3.Swagger{}
// remove any formatting from expectedSpec
buffer := new(bytes.Buffer)
if err := json.Compact(buffer, []byte(expectedSpec)); err != nil {
t.Fatal(err)
}
g.Expect(json.Unmarshal(buffer.Bytes(), &expectedSwagger)).To(gomega.Succeed())

// test equality, ignoring order in JSON arrays
instances := swagger.Components.RequestBodies["modelInput"].Value.Content.Get("application/json").
Schema.Value.Properties["instances"].Value
expectedInstances := expectedSwagger.Components.RequestBodies["modelInput"].Value.Content.
Get("application/json").Schema.Value.Properties["instances"].Value
g.Expect(instances.Items.Value.Required).Should(gomega.Not(gomega.BeNil()))
g.Expect(instances.Items.Value.Required).To(gomega.ConsistOf(expectedInstances.Items.Value.Required))
g.Expect(instances.Items.Value.AdditionalPropertiesAllowed).Should(gomega.Not(gomega.BeNil()))
g.Expect(instances.Items.Value.AdditionalPropertiesAllowed).Should(gomega.Equal(expectedInstances.Items.Value.AdditionalPropertiesAllowed))
g.Expect(instances.Items.Value.Properties).Should(gomega.Equal(expectedInstances.Items.Value.Properties))
}

func TestGenerateOpenAPIForColFmtMultipleTensors(t *testing.T) {
g := gomega.NewGomegaWithT(t)
model := model(t, "TestColFmtMultipleTensors")
generator := defaultGenerator()
spec, specErr := generator.GenerateOpenAPI(model)
g.Expect(specErr).Should(gomega.BeNil())

swagger := &openapi3.Swagger{}
g.Expect(json.Unmarshal([]byte(spec), &swagger)).To(gomega.Succeed())

expectedSpec := string(openAPI(t, "TestColFmtMultipleTensors"))
expectedSwagger := &openapi3.Swagger{}
// remove any formatting from expectedSpec
buffer := new(bytes.Buffer)
if err := json.Compact(buffer, []byte(expectedSpec)); err != nil {
t.Fatal(err)
}
g.Expect(json.Unmarshal(buffer.Bytes(), &expectedSwagger)).To(gomega.Succeed())

// ignore order in JSON arrays
inputs := swagger.Components.RequestBodies["modelInput"].Value.Content.Get("application/json").
Schema.Value.Properties["inputs"].Value
expectedInputs := expectedSwagger.Components.RequestBodies["modelInput"].Value.Content.
Get("application/json").Schema.Value.Properties["inputs"].Value
g.Expect(inputs.Required).Should(gomega.Not(gomega.BeNil()))
g.Expect(inputs.Required).To(gomega.ConsistOf(expectedInputs.Required))
g.Expect(inputs.Properties).Should(gomega.Equal(expectedInputs.Properties))
g.Expect(inputs.AdditionalPropertiesAllowed).Should(gomega.Equal(expectedInputs.AdditionalPropertiesAllowed))
}

func TestGenerateOpenAPIForVariousFmtsStrictly(t *testing.T) {
inputFmts := []struct {
name string
}{
{"TestColFmtSingleTensor"}, {"TestColFmtScalar"}, {"TestRowFmtSingleTensor"},
}
for _, fmt := range inputFmts {
g := gomega.NewGomegaWithT(t)
model := model(t, fmt.name)
generator := defaultGenerator()
spec, specErr := generator.GenerateOpenAPI(model)
expectedSpec := openAPI(t, fmt.name)
g.Expect(spec).Should(gomega.MatchJSON(expectedSpec))
g.Expect(specErr).Should(gomega.BeNil())
}
}

func TestAcceptsValidTFServingInput(t *testing.T) {
inputFmts := []struct {
name string
}{
{"TestColFmtSingleTensor"},
{"TestColFmtScalar"},
{"TestRowFmtSingleTensor"},
{"TestColFmtMultipleTensors"},
{"TestRowFmtMultipleTensors"},
}
for _, fmt := range inputFmts {
g := gomega.NewGomegaWithT(t)
g.Expect(acceptsValidReq(t, fmt.name)).Should(gomega.BeNil())
}
}

func defaultGenerator() Generator {
return Generator{
name: "model",
version: "1",
metaGraphTags: []string{DefaultTag},
sigDefKey: DefaultSigDefKey,
}
}

func model(t *testing.T, fName string) *pb.SavedModel {
model := &pb.SavedModel{}
fPath := filepath.Join("testdata", fName+".pb")
modelPb, err := ioutil.ReadFile(fPath)
if err != nil {
t.Fatalf("failed reading %s: %s", fPath, err)
}
if err := proto.Unmarshal(modelPb, model); err != nil {
t.Fatal("SavedModel not in expected format. May be corrupted: " + err.Error())
}
return model
}

func openAPI(t *testing.T, fName string) []byte {
fPath := filepath.Join("testdata", fName+".golden.json")
openAPI, err := ioutil.ReadFile(fPath)
if err != nil {
t.Fatalf("failed reading %s: %s", fPath, err)
}
return openAPI
}

func acceptsValidReq(t *testing.T, fName string) error {
router := openapi3filter.NewRouter().WithSwagger(loadSwagger(t, fName))
req, reqErr := http.NewRequest(http.MethodPost, "/v1/models/model/versions/1:predict",
bytes.NewReader(loadPayload(t, fName)))
if reqErr != nil {
t.Fatalf("error creating request: %s", reqErr)
}
route, pathParams, routeErr := router.FindRoute(req.Method, req.URL)
if routeErr != nil {
t.Fatalf("error finding route: %s", routeErr)
}
req.Header.Set("Content-Type", "application/json")
requestValidationInput := &openapi3filter.RequestValidationInput{
Request: req,
PathParams: pathParams,
Route: route,
}
return openapi3filter.ValidateRequest(context.TODO(), requestValidationInput)
}

func loadSwagger(t *testing.T, fName string) *openapi3.Swagger {
fPath := filepath.Join("testdata", fName+".golden.json")
swagger, err := openapi3.NewSwaggerLoader().LoadSwaggerFromFile(fPath)
if err != nil {
t.Fatalf("failed reading %s: %s", fPath, err)
}
return swagger
}

func loadPayload(t *testing.T, fName string) []byte {
fPath := filepath.Join("testdata", fName+"Req.json")
payload, err := ioutil.ReadFile(fPath)
if err != nil {
t.Fatalf("failed reading %s: %s", fPath, err)
}
return payload
}
Loading

0 comments on commit 3107cdf

Please sign in to comment.