diff --git a/tvp_go19.go b/tvp_go19.go index d3890af9..72ec5be5 100644 --- a/tvp_go19.go +++ b/tvp_go19.go @@ -54,7 +54,7 @@ func (tvp TVP) check() error { if valueOf.IsNil() { return ErrorTypeSliceIsEmpty } - if reflect.TypeOf(tvp.Value).Elem().Kind() != reflect.Struct { + if elem := reflect.TypeOf(tvp.Value).Elem(); (elem.Kind() == reflect.Ptr && elem.Elem().Kind() != reflect.Struct) && (elem.Kind() != reflect.Struct) { return ErrorTypeSlice } return nil @@ -153,6 +153,9 @@ func (tvp TVP) columnTypes() ([]columnStruct, []int, error) { if IsSkipField(tvpTagValue, isTvpTag, jsonTagValue, isJsonTag) { continue } + if field.PkgPath == "" { + continue + } tvpFieldIndexes = append(tvpFieldIndexes, i) if field.Type.Kind() == reflect.Ptr { v := reflect.New(field.Type.Elem()) diff --git a/tvp_go19_db_test.go b/tvp_go19_db_test.go index 6703ce5c..6a69802d 100644 --- a/tvp_go19_db_test.go +++ b/tvp_go19_db_test.go @@ -340,6 +340,33 @@ func TestTVPGoSQLTypes(t *testing.T) { }, } + param2 := []*TvpGoSQLTypes{ + { + PBool: sql.NullBool{ + Bool: true, + Valid: true, + }, + PBoolNull: sql.NullBool{}, + PFloat64: sql.NullFloat64{ + Float64: 14.33, + Valid: true, + }, + PFloat64Null: sql.NullFloat64{}, + PInt64: sql.NullInt64{ + Int64: 777, + Valid: true, + }, + PInt64Null: sql.NullInt64{}, + PString: sql.NullString{ + String: "test=tvp", + Valid: true, + }, + PStringNull: sql.NullString{}, + }, + } + + testResult := param1[:] + testResult = append(testResult, param1...) tvpType := TVP{ TypeName: "tvpGoSQLTypes", Value: param1, @@ -348,12 +375,17 @@ func TestTVPGoSQLTypes(t *testing.T) { TypeName: "tvpGoSQLTypes", Value: []TvpGoSQLTypes{}, } + tvpPointerType := TVP{ + TypeName: "tvpGoSQLTypes", + Value: param2, + } rows, err := db.QueryContext(ctx, - "exec spwithtvpGoSQLTypes @param1, @param2, @param3", + "exec spwithtvpGoSQLTypes @param1, @param2, @param3, @param4", sql.Named("param1", tvpType), sql.Named("param2", tvpTypeEmpty), sql.Named("param3", "test"), + sql.Named("param4", tvpPointerType), ) if err != nil { @@ -380,8 +412,8 @@ func TestTVPGoSQLTypes(t *testing.T) { result1 = append(result1, val) } - if !reflect.DeepEqual(param1, result1) { - t.Logf("expected: %+v", param1) + if !reflect.DeepEqual(testResult, result1) { + t.Logf("expected: %+v", testResult) t.Logf("actual: %+v", result1) t.Errorf("first resultset did not match param1") } diff --git a/tvp_go19_test.go b/tvp_go19_test.go index d38556cb..2638e0dc 100644 --- a/tvp_go19_test.go +++ b/tvp_go19_test.go @@ -182,12 +182,12 @@ func TestTVPType_check(t *testing.T) { wantErr: true, }, { - name: "Value isn't right", + name: "Value is pointer", fields: fields{ TVPName: "Test", TVPValue: []*fields{}, }, - wantErr: true, + wantErr: false, }, { name: "Value is right", @@ -331,10 +331,10 @@ func BenchmarkColumnTypes(b *testing.B) { func TestIsSkipField(t *testing.T) { type args struct { - tvpTagValue string - isTvpValue bool - jsonTagValue string - isJsonTagValue bool + TvpTagValue string + IsTvpValue bool + JsonTagValue string + IsJsonTagValue bool } tests := []struct { name string @@ -349,78 +349,78 @@ func TestIsSkipField(t *testing.T) { name: "tvp is skip", want: true, args: args{ - isTvpValue: true, - tvpTagValue: skipTagValue, + IsTvpValue: true, + TvpTagValue: skipTagValue, }, }, { name: "tvp is any", want: false, args: args{ - isTvpValue: true, - tvpTagValue: "tvp", + IsTvpValue: true, + TvpTagValue: "tvp", }, }, { name: "Json is skip", want: true, args: args{ - isJsonTagValue: true, - jsonTagValue: skipTagValue, + IsJsonTagValue: true, + JsonTagValue: skipTagValue, }, }, { name: "Json is any", want: false, args: args{ - isJsonTagValue: true, - jsonTagValue: "any", + IsJsonTagValue: true, + JsonTagValue: "any", }, }, { name: "Json is skip tvp is skip", want: true, args: args{ - isJsonTagValue: true, - jsonTagValue: skipTagValue, - isTvpValue: true, - tvpTagValue: skipTagValue, + IsJsonTagValue: true, + JsonTagValue: skipTagValue, + IsTvpValue: true, + TvpTagValue: skipTagValue, }, }, { name: "Json is skip tvp is any", want: false, args: args{ - isJsonTagValue: true, - jsonTagValue: skipTagValue, - isTvpValue: true, - tvpTagValue: "tvp", + IsJsonTagValue: true, + JsonTagValue: skipTagValue, + IsTvpValue: true, + TvpTagValue: "tvp", }, }, { name: "Json is any tvp is skip", want: true, args: args{ - isJsonTagValue: true, - jsonTagValue: "json", - isTvpValue: true, - tvpTagValue: skipTagValue, + IsJsonTagValue: true, + JsonTagValue: "json", + IsTvpValue: true, + TvpTagValue: skipTagValue, }, }, { name: "Json is any tvp is skip", want: false, args: args{ - isJsonTagValue: true, - jsonTagValue: "json", - isTvpValue: true, - tvpTagValue: "tvp", + IsJsonTagValue: true, + JsonTagValue: "json", + IsTvpValue: true, + TvpTagValue: "tvp", }, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - if got := IsSkipField(tt.args.tvpTagValue, tt.args.isTvpValue, tt.args.jsonTagValue, tt.args.isJsonTagValue); got != tt.want { + if got := IsSkipField(tt.args.TvpTagValue, tt.args.IsTvpValue, tt.args.JsonTagValue, tt.args.IsJsonTagValue); got != tt.want { t.Errorf("IsSkipField() = %v, schema %v", got, tt.want) } })