Skip to content

Commit

Permalink
bigquery: check for recursive types during schema inference
Browse files Browse the repository at this point in the history
Keep track of which types we've seen to avoid infinite recursion
on recursive types.

Change-Id: Ibb36190adde8199f1ebfc32f9260661b3834c747
Reviewed-on: https://code-review.googlesource.com/9831
Reviewed-by: Sarah Adams <shadams@google.com>
  • Loading branch information
jba committed Dec 10, 2016
1 parent 0c87a68 commit dd37f36
Show file tree
Hide file tree
Showing 3 changed files with 35 additions and 13 deletions.
34 changes: 22 additions & 12 deletions bigquery/schema.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ package bigquery

import (
"errors"
"fmt"
"reflect"

bq "google.golang.org/api/bigquery/v2"
Expand Down Expand Up @@ -128,34 +129,43 @@ var typeOfByteSlice = reflect.TypeOf([]byte{})
// InferSchema tries to derive a BigQuery schema from the supplied struct value.
// NOTE: All fields in the returned Schema are configured to be required,
// unless the corresponding field in the supplied struct is a slice or array.
//
// It is considered an error if the struct (including nested structs) contains
// any exported fields that are pointers or one of the following types:
// uint, uint64, uintptr, map, interface, complex64, complex128, func, chan.
// In these cases, an error will be returned.
// Future versions may handle these cases without error.
//
// Recursively defined structs are also disallowed.
func InferSchema(st interface{}) (Schema, error) {
return inferStruct(reflect.TypeOf(st))
return inferSchemaReflect(reflect.TypeOf(st))
}

func inferStruct(rt reflect.Type) (Schema, error) {
switch rt.Kind() {
func inferSchemaReflect(t reflect.Type) (Schema, error) {
return inferStruct(t, map[reflect.Type]bool{})
}
func inferStruct(t reflect.Type, seen map[reflect.Type]bool) (Schema, error) {
if seen[t] {
return nil, fmt.Errorf("bigquery: schema inference for recursive type %s", t)
}
seen[t] = true
switch t.Kind() {
case reflect.Ptr:
if rt.Elem().Kind() != reflect.Struct {
if t.Elem().Kind() != reflect.Struct {
return nil, errNoStruct
}
rt = rt.Elem()
t = t.Elem()
fallthrough

case reflect.Struct:
return inferFields(rt)
return inferFields(t, seen)
default:
return nil, errNoStruct
}

}

// inferFieldSchema infers the FieldSchema for a Go type
func inferFieldSchema(rt reflect.Type) (*FieldSchema, error) {
func inferFieldSchema(rt reflect.Type, seen map[reflect.Type]bool) (*FieldSchema, error) {
switch rt {
case typeOfByteSlice:
return &FieldSchema{Required: true, Type: BytesFieldType}, nil
Expand All @@ -179,15 +189,15 @@ func inferFieldSchema(rt reflect.Type) (*FieldSchema, error) {
return nil, errUnsupportedFieldType
}

f, err := inferFieldSchema(et)
f, err := inferFieldSchema(et, seen)
if err != nil {
return nil, err
}
f.Repeated = true
f.Required = false
return f, nil
case reflect.Struct, reflect.Ptr:
nested, err := inferStruct(rt)
nested, err := inferStruct(rt, seen)
if err != nil {
return nil, err
}
Expand All @@ -204,14 +214,14 @@ func inferFieldSchema(rt reflect.Type) (*FieldSchema, error) {
}

// inferFields extracts all exported field types from struct type.
func inferFields(rt reflect.Type) (Schema, error) {
func inferFields(rt reflect.Type, seen map[reflect.Type]bool) (Schema, error) {
var s Schema
fields, err := fieldCache.Fields(rt)
if err != nil {
return nil, err
}
for _, field := range fields {
f, err := inferFieldSchema(field.Type)
f, err := inferFieldSchema(field.Type, seen)
if err != nil {
return nil, err
}
Expand Down
12 changes: 12 additions & 0 deletions bigquery/schema_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -496,6 +496,18 @@ func TestEmbeddedInference(t *testing.T) {
}
}

func TestRecursiveInference(t *testing.T) {
type List struct {
Val int
Next *List
}

_, err := InferSchema(List{})
if err == nil {
t.Fatal("got nil, want error")
}
}

type withTags struct {
NoTag int
ExcludeTag int `bigquery:"-"`
Expand Down
2 changes: 1 addition & 1 deletion bigquery/uploader.go
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,7 @@ func toValueSaver(x interface{}) (ValueSaver, bool, error) {
return nil, false, nil
}
// TODO(jba): cache schema inference to speed this up.
schema, err := inferStruct(v.Type())
schema, err := inferSchemaReflect(v.Type())
if err != nil {
return nil, false, err
}
Expand Down

0 comments on commit dd37f36

Please sign in to comment.