From a64eb5dd7183a634a9e99ae5b9844c0fd6098910 Mon Sep 17 00:00:00 2001 From: Ryan Wood Date: Fri, 20 Nov 2020 22:04:24 +0900 Subject: [PATCH 1/3] Fixed issues with TVPs --- tvp_go19.go | 18 ++++++++++++++++-- 1 file changed, 16 insertions(+), 2 deletions(-) diff --git a/tvp_go19.go b/tvp_go19.go index d3890af9..5a5c48bd 100644 --- a/tvp_go19.go +++ b/tvp_go19.go @@ -54,7 +54,7 @@ func (tvp TVP) check() error { if valueOf.IsNil() { return ErrorTypeSliceIsEmpty } - if reflect.TypeOf(tvp.Value).Elem().Kind() != reflect.Struct { + if elem := reflect.TypeOf(tvp.Value).Elem(); (elem.Kind() == reflect.Ptr && elem.Elem().Kind() != reflect.Struct) || (elem.Kind() != reflect.Struct) { return ErrorTypeSlice } return nil @@ -75,7 +75,18 @@ func (tvp TVP) encode(schema, name string, columnStr []columnStruct, tvpFieldInd writeBVarChar(buf, name) binary.Write(buf, binary.LittleEndian, uint16(len(columnStr))) + val := reflect.ValueOf(tvp.Value) + var elemType reflect.Type + if elem := val.Elem(); elem.Kind() == reflect.Ptr { + elemType = elem.Elem().Type() + } else { + elemType = elem.Type() + } + for i, column := range columnStr { + if elemType.Field(i).PkgPath == "" { + continue + } binary.Write(buf, binary.LittleEndian, uint32(column.UserType)) binary.Write(buf, binary.LittleEndian, uint16(column.Flags)) writeTypeInfo(buf, &columnStr[i].ti) @@ -91,12 +102,15 @@ func (tvp TVP) encode(schema, name string, columnStr []columnStruct, tvpFieldInd c: conn, } - val := reflect.ValueOf(tvp.Value) for i := 0; i < val.Len(); i++ { refStr := reflect.ValueOf(val.Index(i).Interface()) buf.WriteByte(_TVP_ROW_TOKEN) for columnStrIdx, fieldIdx := range tvpFieldIndexes { field := refStr.Field(fieldIdx) + if refStr.Type().Field(fieldIdx).PkgPath == "" { + continue + } + tvpVal := field.Interface() if tvp.verifyStandardTypeOnNull(buf, tvpVal) { continue From 3429fd42f4c4a303f8b88bc01c489067ecf23dd7 Mon Sep 17 00:00:00 2001 From: Ryan Wood Date: Fri, 20 Nov 2020 22:32:01 +0900 Subject: [PATCH 2/3] Fixed panic --- tvp_go19.go | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tvp_go19.go b/tvp_go19.go index 5a5c48bd..c2f6ea02 100644 --- a/tvp_go19.go +++ b/tvp_go19.go @@ -77,10 +77,10 @@ func (tvp TVP) encode(schema, name string, columnStr []columnStruct, tvpFieldInd val := reflect.ValueOf(tvp.Value) var elemType reflect.Type - if elem := val.Elem(); elem.Kind() == reflect.Ptr { - elemType = elem.Elem().Type() + if elem := val.Type().Elem(); elem.Kind() == reflect.Ptr { + elemType = elem.Elem() } else { - elemType = elem.Type() + elemType = elem } for i, column := range columnStr { From 4517f4fd02b65fe971f08e0a446479516bc161bc Mon Sep 17 00:00:00 2001 From: Ryan Wood Date: Tue, 24 Nov 2020 12:14:02 +0900 Subject: [PATCH 3/3] Made unexported fields not be included --- tvp_go19.go | 21 ++++----------- tvp_go19_db_test.go | 38 ++++++++++++++++++++++++--- tvp_go19_test.go | 62 ++++++++++++++++++++++----------------------- 3 files changed, 71 insertions(+), 50 deletions(-) diff --git a/tvp_go19.go b/tvp_go19.go index c2f6ea02..72ec5be5 100644 --- a/tvp_go19.go +++ b/tvp_go19.go @@ -54,7 +54,7 @@ func (tvp TVP) check() error { if valueOf.IsNil() { return ErrorTypeSliceIsEmpty } - if elem := reflect.TypeOf(tvp.Value).Elem(); (elem.Kind() == reflect.Ptr && elem.Elem().Kind() != reflect.Struct) || (elem.Kind() != reflect.Struct) { + if elem := reflect.TypeOf(tvp.Value).Elem(); (elem.Kind() == reflect.Ptr && elem.Elem().Kind() != reflect.Struct) && (elem.Kind() != reflect.Struct) { return ErrorTypeSlice } return nil @@ -75,18 +75,7 @@ func (tvp TVP) encode(schema, name string, columnStr []columnStruct, tvpFieldInd writeBVarChar(buf, name) binary.Write(buf, binary.LittleEndian, uint16(len(columnStr))) - val := reflect.ValueOf(tvp.Value) - var elemType reflect.Type - if elem := val.Type().Elem(); elem.Kind() == reflect.Ptr { - elemType = elem.Elem() - } else { - elemType = elem - } - for i, column := range columnStr { - if elemType.Field(i).PkgPath == "" { - continue - } binary.Write(buf, binary.LittleEndian, uint32(column.UserType)) binary.Write(buf, binary.LittleEndian, uint16(column.Flags)) writeTypeInfo(buf, &columnStr[i].ti) @@ -102,15 +91,12 @@ func (tvp TVP) encode(schema, name string, columnStr []columnStruct, tvpFieldInd c: conn, } + val := reflect.ValueOf(tvp.Value) for i := 0; i < val.Len(); i++ { refStr := reflect.ValueOf(val.Index(i).Interface()) buf.WriteByte(_TVP_ROW_TOKEN) for columnStrIdx, fieldIdx := range tvpFieldIndexes { field := refStr.Field(fieldIdx) - if refStr.Type().Field(fieldIdx).PkgPath == "" { - continue - } - tvpVal := field.Interface() if tvp.verifyStandardTypeOnNull(buf, tvpVal) { continue @@ -167,6 +153,9 @@ func (tvp TVP) columnTypes() ([]columnStruct, []int, error) { if IsSkipField(tvpTagValue, isTvpTag, jsonTagValue, isJsonTag) { continue } + if field.PkgPath == "" { + continue + } tvpFieldIndexes = append(tvpFieldIndexes, i) if field.Type.Kind() == reflect.Ptr { v := reflect.New(field.Type.Elem()) diff --git a/tvp_go19_db_test.go b/tvp_go19_db_test.go index 6703ce5c..6a69802d 100644 --- a/tvp_go19_db_test.go +++ b/tvp_go19_db_test.go @@ -340,6 +340,33 @@ func TestTVPGoSQLTypes(t *testing.T) { }, } + param2 := []*TvpGoSQLTypes{ + { + PBool: sql.NullBool{ + Bool: true, + Valid: true, + }, + PBoolNull: sql.NullBool{}, + PFloat64: sql.NullFloat64{ + Float64: 14.33, + Valid: true, + }, + PFloat64Null: sql.NullFloat64{}, + PInt64: sql.NullInt64{ + Int64: 777, + Valid: true, + }, + PInt64Null: sql.NullInt64{}, + PString: sql.NullString{ + String: "test=tvp", + Valid: true, + }, + PStringNull: sql.NullString{}, + }, + } + + testResult := param1[:] + testResult = append(testResult, param1...) tvpType := TVP{ TypeName: "tvpGoSQLTypes", Value: param1, @@ -348,12 +375,17 @@ func TestTVPGoSQLTypes(t *testing.T) { TypeName: "tvpGoSQLTypes", Value: []TvpGoSQLTypes{}, } + tvpPointerType := TVP{ + TypeName: "tvpGoSQLTypes", + Value: param2, + } rows, err := db.QueryContext(ctx, - "exec spwithtvpGoSQLTypes @param1, @param2, @param3", + "exec spwithtvpGoSQLTypes @param1, @param2, @param3, @param4", sql.Named("param1", tvpType), sql.Named("param2", tvpTypeEmpty), sql.Named("param3", "test"), + sql.Named("param4", tvpPointerType), ) if err != nil { @@ -380,8 +412,8 @@ func TestTVPGoSQLTypes(t *testing.T) { result1 = append(result1, val) } - if !reflect.DeepEqual(param1, result1) { - t.Logf("expected: %+v", param1) + if !reflect.DeepEqual(testResult, result1) { + t.Logf("expected: %+v", testResult) t.Logf("actual: %+v", result1) t.Errorf("first resultset did not match param1") } diff --git a/tvp_go19_test.go b/tvp_go19_test.go index d38556cb..2638e0dc 100644 --- a/tvp_go19_test.go +++ b/tvp_go19_test.go @@ -182,12 +182,12 @@ func TestTVPType_check(t *testing.T) { wantErr: true, }, { - name: "Value isn't right", + name: "Value is pointer", fields: fields{ TVPName: "Test", TVPValue: []*fields{}, }, - wantErr: true, + wantErr: false, }, { name: "Value is right", @@ -331,10 +331,10 @@ func BenchmarkColumnTypes(b *testing.B) { func TestIsSkipField(t *testing.T) { type args struct { - tvpTagValue string - isTvpValue bool - jsonTagValue string - isJsonTagValue bool + TvpTagValue string + IsTvpValue bool + JsonTagValue string + IsJsonTagValue bool } tests := []struct { name string @@ -349,78 +349,78 @@ func TestIsSkipField(t *testing.T) { name: "tvp is skip", want: true, args: args{ - isTvpValue: true, - tvpTagValue: skipTagValue, + IsTvpValue: true, + TvpTagValue: skipTagValue, }, }, { name: "tvp is any", want: false, args: args{ - isTvpValue: true, - tvpTagValue: "tvp", + IsTvpValue: true, + TvpTagValue: "tvp", }, }, { name: "Json is skip", want: true, args: args{ - isJsonTagValue: true, - jsonTagValue: skipTagValue, + IsJsonTagValue: true, + JsonTagValue: skipTagValue, }, }, { name: "Json is any", want: false, args: args{ - isJsonTagValue: true, - jsonTagValue: "any", + IsJsonTagValue: true, + JsonTagValue: "any", }, }, { name: "Json is skip tvp is skip", want: true, args: args{ - isJsonTagValue: true, - jsonTagValue: skipTagValue, - isTvpValue: true, - tvpTagValue: skipTagValue, + IsJsonTagValue: true, + JsonTagValue: skipTagValue, + IsTvpValue: true, + TvpTagValue: skipTagValue, }, }, { name: "Json is skip tvp is any", want: false, args: args{ - isJsonTagValue: true, - jsonTagValue: skipTagValue, - isTvpValue: true, - tvpTagValue: "tvp", + IsJsonTagValue: true, + JsonTagValue: skipTagValue, + IsTvpValue: true, + TvpTagValue: "tvp", }, }, { name: "Json is any tvp is skip", want: true, args: args{ - isJsonTagValue: true, - jsonTagValue: "json", - isTvpValue: true, - tvpTagValue: skipTagValue, + IsJsonTagValue: true, + JsonTagValue: "json", + IsTvpValue: true, + TvpTagValue: skipTagValue, }, }, { name: "Json is any tvp is skip", want: false, args: args{ - isJsonTagValue: true, - jsonTagValue: "json", - isTvpValue: true, - tvpTagValue: "tvp", + IsJsonTagValue: true, + JsonTagValue: "json", + IsTvpValue: true, + TvpTagValue: "tvp", }, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - if got := IsSkipField(tt.args.tvpTagValue, tt.args.isTvpValue, tt.args.jsonTagValue, tt.args.isJsonTagValue); got != tt.want { + if got := IsSkipField(tt.args.TvpTagValue, tt.args.IsTvpValue, tt.args.JsonTagValue, tt.args.IsJsonTagValue); got != tt.want { t.Errorf("IsSkipField() = %v, schema %v", got, tt.want) } })