diff --git a/bulkcopy.go b/bulkcopy.go index a9aa1102..ec0e9b8d 100644 --- a/bulkcopy.go +++ b/bulkcopy.go @@ -3,6 +3,7 @@ package mssql import ( "bytes" "context" + "database/sql/driver" "encoding/binary" "fmt" "math" @@ -318,6 +319,19 @@ func (b *Bulk) makeParam(val DataValue, col columnStruct) (res param, err error) res.ti.Size = col.ti.Size res.ti.TypeId = col.ti.TypeId + switch valuer := val.(type) { + case driver.Valuer: + var e error + val, e = driver.DefaultParameterConverter.ConvertValue(valuer) + if e != nil { + err = e + return + } + if val != nil { + return b.makeParam(val, col) + } + } + if val == nil { res.ti.Size = 0 return diff --git a/bulkcopy_test.go b/bulkcopy_test.go index ce7168cf..2dc6c6ae 100644 --- a/bulkcopy_test.go +++ b/bulkcopy_test.go @@ -14,6 +14,103 @@ import ( "time" ) +func TestBulkcopyWithInvalidNullableType(t *testing.T) { + // Arrange + tableName := "#table_test" + columns := []string{ + "test_nullfloat", + "test_nullstring", + "test_nullbyte", + "test_nullbool", + "test_nullint64", + "test_nullint32", + "test_nullint16", + "test_nulltime", + "test_nulluniqueidentifier", + } + values := []interface{}{ + sql.NullFloat64{Valid: false}, + sql.NullString{Valid: false}, + sql.NullByte{Valid: false}, + sql.NullBool{Valid: false}, + sql.NullInt64{Valid: false}, + sql.NullInt32{Valid: false}, + sql.NullInt16{Valid: false}, + sql.NullTime{Valid: false}, + NullUniqueIdentifier{Valid: false}, + } + + pool, logger := open(t) + defer pool.Close() + defer logger.StopLogging() + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + conn, err := pool.Conn(ctx) + if err != nil { + t.Fatal("failed to pull connection from pool", err) + } + defer conn.Close() + + err = setupNullableTypeTable(ctx, t, conn, tableName) + if err != nil { + t.Error("Setup table failed: ", err) + return + } + + stmt, err := conn.PrepareContext(ctx, CopyIn(tableName, BulkOptions{}, columns...)) + if err != nil { + t.Fatal(err) + } + defer stmt.Close() + + _, err = stmt.Exec(values...) + if err != nil { + t.Fatal("AddRow failed: ", err.Error()) + } + + result, err := stmt.Exec() + if err != nil { + t.Fatal("bulkcopy failed: ", err.Error()) + } + + insertedRowCount, _ := result.RowsAffected() + if insertedRowCount == 0 { + t.Fatal("0 row inserted!") + } + + //data verification + rows, err := conn.QueryContext(ctx, "select "+strings.Join(columns, ",")+" from "+tableName) + if err != nil { + t.Fatal(err) + } + defer rows.Close() + for rows.Next() { + + ptrs := make([]interface{}, len(columns)) + container := make([]interface{}, len(columns)) + for i := range ptrs { + ptrs[i] = &container[i] + } + if err := rows.Scan(ptrs...); err != nil { + t.Fatal(err) + } + for i, c := range columns { + if !compareValue(container[i], nil) { + v := container[i] + if s, ok := v.([]uint8); ok { + v = string(s) + } + t.Errorf("columns %s : expected: %T %v, got: %T %v\n", c, nil, nil, container[i], v) + } + } + } + if err := rows.Err(); err != nil { + t.Error(err) + } +} + func TestBulkcopy(t *testing.T) { // TDS level Bulk Insert is not supported on Azure SQL Server. if dsn := makeConnStr(t); strings.HasSuffix(strings.Split(dsn.Host, ":")[0], ".database.windows.net") { @@ -69,6 +166,14 @@ func TestBulkcopy(t *testing.T) { {"test_geom", geom, string(geom)}, {"test_uniqueidentifier", uid, string(uid)}, {"test_nulluniqueidentifier", nil, nil}, + {"test_nullfloat", sql.NullFloat64{64, true}, 64.0}, + {"test_nullstring", sql.NullString{"abcdefg", true}, "abcdefg"}, + {"test_nullbyte", sql.NullByte{0x01, true}, 1}, + {"test_nullbool", sql.NullBool{true, true}, true}, + {"test_nullint64", sql.NullInt64{9223372036854775807, true}, 9223372036854775807}, + {"test_nullint32", sql.NullInt32{2147483647, true}, 2147483647}, + {"test_nullint16", sql.NullInt16{32767, true}, 32767}, + {"test_nulltime", sql.NullTime{time.Date(2010, 11, 12, 13, 14, 15, 120000000, time.UTC), true}, time.Date(2010, 11, 12, 13, 14, 15, 120000000, time.UTC)}, // {"test_smallmoney", 1234.56, nil}, // {"test_money", 1234.56, nil}, {"test_decimal_18_0", 1234.0001, "1234"}, @@ -223,6 +328,30 @@ func compareValue(a interface{}, expected interface{}) bool { } } +func setupNullableTypeTable(ctx context.Context, t *testing.T, conn *sql.Conn, tableName string) (err error) { + tablesql := `CREATE TABLE ` + tableName + ` ( + [id] [int] IDENTITY(1,1) NOT NULL, + [test_nullfloat] [float] NULL, + [test_nullstring] [nvarchar](50) NULL, + [test_nullbyte] [tinyint] NULL, + [test_nullbool] [bit] NULL, + [test_nullint64] [bigint] NULL, + [test_nullint32] [int] NULL, + [test_nullint16] [smallint] NULL, + [test_nulltime] [datetime] NULL, + [test_nulluniqueidentifier] [uniqueidentifier] NULL, + CONSTRAINT [PK_` + tableName + `_id] PRIMARY KEY CLUSTERED +( + [id] ASC +)WITH (PAD_INDEX = OFF, STATISTICS_NORECOMPUTE = OFF, IGNORE_DUP_KEY = OFF, ALLOW_ROW_LOCKS = ON, ALLOW_PAGE_LOCKS = ON) ON [PRIMARY] +) ON [PRIMARY];` + _, err = conn.ExecContext(ctx, tablesql) + if err != nil { + t.Fatal("tablesql failed:", err) + } + return +} + func setupTable(ctx context.Context, t *testing.T, conn *sql.Conn, tableName string) (err error) { tablesql := `CREATE TABLE ` + tableName + ` ( [id] [int] IDENTITY(1,1) NOT NULL, @@ -290,6 +419,14 @@ func setupTable(ctx context.Context, t *testing.T, conn *sql.Conn, tableName str [test_int16nvarchar] [varchar](4) NULL, [test_int8nvarchar] [varchar](3) NULL, [test_intnvarchar] [varchar](4) NULL, + [test_nullfloat] [float] NULL, + [test_nullstring] [nvarchar](50) NULL, + [test_nullbyte] [tinyint] NULL, + [test_nullbool] [bit] NULL, + [test_nullint64] [bigint] NULL, + [test_nullint32] [int] NULL, + [test_nullint16] [smallint] NULL, + [test_nulltime] [datetime] NULL, CONSTRAINT [PK_` + tableName + `_id] PRIMARY KEY CLUSTERED ( [id] ASC