Skip to content

Commit

Permalink
fix: support nullable types for bulkcopy (#192)
Browse files Browse the repository at this point in the history
* fix: support nullable types for bulkcopy

* Add test cases for all nullable types

* Fix test cases

* Add bulkcopy test for invalid nullable types

* Add case in convertInputParameter to bypass uniqueidentifier type

* Add test cases for invalid nullable test

* Revert bypass change
  • Loading branch information
vecknishwaran authored May 23, 2024
1 parent ada30cb commit b3a8513
Show file tree
Hide file tree
Showing 2 changed files with 151 additions and 0 deletions.
14 changes: 14 additions & 0 deletions bulkcopy.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package mssql
import (
"bytes"
"context"
"database/sql/driver"
"encoding/binary"
"fmt"
"math"
Expand Down Expand Up @@ -318,6 +319,19 @@ func (b *Bulk) makeParam(val DataValue, col columnStruct) (res param, err error)
res.ti.Size = col.ti.Size
res.ti.TypeId = col.ti.TypeId

switch valuer := val.(type) {
case driver.Valuer:
var e error
val, e = driver.DefaultParameterConverter.ConvertValue(valuer)
if e != nil {
err = e
return
}
if val != nil {
return b.makeParam(val, col)
}
}

if val == nil {
res.ti.Size = 0
return
Expand Down
137 changes: 137 additions & 0 deletions bulkcopy_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,103 @@ import (
"time"
)

func TestBulkcopyWithInvalidNullableType(t *testing.T) {
// Arrange
tableName := "#table_test"
columns := []string{
"test_nullfloat",
"test_nullstring",
"test_nullbyte",
"test_nullbool",
"test_nullint64",
"test_nullint32",
"test_nullint16",
"test_nulltime",
"test_nulluniqueidentifier",
}
values := []interface{}{
sql.NullFloat64{Valid: false},
sql.NullString{Valid: false},
sql.NullByte{Valid: false},
sql.NullBool{Valid: false},
sql.NullInt64{Valid: false},
sql.NullInt32{Valid: false},
sql.NullInt16{Valid: false},
sql.NullTime{Valid: false},
NullUniqueIdentifier{Valid: false},
}

pool, logger := open(t)
defer pool.Close()
defer logger.StopLogging()

ctx, cancel := context.WithCancel(context.Background())
defer cancel()

conn, err := pool.Conn(ctx)
if err != nil {
t.Fatal("failed to pull connection from pool", err)
}
defer conn.Close()

err = setupNullableTypeTable(ctx, t, conn, tableName)
if err != nil {
t.Error("Setup table failed: ", err)
return
}

stmt, err := conn.PrepareContext(ctx, CopyIn(tableName, BulkOptions{}, columns...))
if err != nil {
t.Fatal(err)
}
defer stmt.Close()

_, err = stmt.Exec(values...)
if err != nil {
t.Fatal("AddRow failed: ", err.Error())
}

result, err := stmt.Exec()
if err != nil {
t.Fatal("bulkcopy failed: ", err.Error())
}

insertedRowCount, _ := result.RowsAffected()
if insertedRowCount == 0 {
t.Fatal("0 row inserted!")
}

//data verification
rows, err := conn.QueryContext(ctx, "select "+strings.Join(columns, ",")+" from "+tableName)
if err != nil {
t.Fatal(err)
}
defer rows.Close()
for rows.Next() {

ptrs := make([]interface{}, len(columns))
container := make([]interface{}, len(columns))
for i := range ptrs {
ptrs[i] = &container[i]
}
if err := rows.Scan(ptrs...); err != nil {
t.Fatal(err)
}
for i, c := range columns {
if !compareValue(container[i], nil) {
v := container[i]
if s, ok := v.([]uint8); ok {
v = string(s)
}
t.Errorf("columns %s : expected: %T %v, got: %T %v\n", c, nil, nil, container[i], v)
}
}
}
if err := rows.Err(); err != nil {
t.Error(err)
}
}

func TestBulkcopy(t *testing.T) {
// TDS level Bulk Insert is not supported on Azure SQL Server.
if dsn := makeConnStr(t); strings.HasSuffix(strings.Split(dsn.Host, ":")[0], ".database.windows.net") {
Expand Down Expand Up @@ -69,6 +166,14 @@ func TestBulkcopy(t *testing.T) {
{"test_geom", geom, string(geom)},
{"test_uniqueidentifier", uid, string(uid)},
{"test_nulluniqueidentifier", nil, nil},
{"test_nullfloat", sql.NullFloat64{64, true}, 64.0},
{"test_nullstring", sql.NullString{"abcdefg", true}, "abcdefg"},
{"test_nullbyte", sql.NullByte{0x01, true}, 1},
{"test_nullbool", sql.NullBool{true, true}, true},
{"test_nullint64", sql.NullInt64{9223372036854775807, true}, 9223372036854775807},
{"test_nullint32", sql.NullInt32{2147483647, true}, 2147483647},
{"test_nullint16", sql.NullInt16{32767, true}, 32767},
{"test_nulltime", sql.NullTime{time.Date(2010, 11, 12, 13, 14, 15, 120000000, time.UTC), true}, time.Date(2010, 11, 12, 13, 14, 15, 120000000, time.UTC)},
// {"test_smallmoney", 1234.56, nil},
// {"test_money", 1234.56, nil},
{"test_decimal_18_0", 1234.0001, "1234"},
Expand Down Expand Up @@ -223,6 +328,30 @@ func compareValue(a interface{}, expected interface{}) bool {
}
}

func setupNullableTypeTable(ctx context.Context, t *testing.T, conn *sql.Conn, tableName string) (err error) {
tablesql := `CREATE TABLE ` + tableName + ` (
[id] [int] IDENTITY(1,1) NOT NULL,
[test_nullfloat] [float] NULL,
[test_nullstring] [nvarchar](50) NULL,
[test_nullbyte] [tinyint] NULL,
[test_nullbool] [bit] NULL,
[test_nullint64] [bigint] NULL,
[test_nullint32] [int] NULL,
[test_nullint16] [smallint] NULL,
[test_nulltime] [datetime] NULL,
[test_nulluniqueidentifier] [uniqueidentifier] NULL,
CONSTRAINT [PK_` + tableName + `_id] PRIMARY KEY CLUSTERED
(
[id] ASC
)WITH (PAD_INDEX = OFF, STATISTICS_NORECOMPUTE = OFF, IGNORE_DUP_KEY = OFF, ALLOW_ROW_LOCKS = ON, ALLOW_PAGE_LOCKS = ON) ON [PRIMARY]
) ON [PRIMARY];`
_, err = conn.ExecContext(ctx, tablesql)
if err != nil {
t.Fatal("tablesql failed:", err)
}
return
}

func setupTable(ctx context.Context, t *testing.T, conn *sql.Conn, tableName string) (err error) {
tablesql := `CREATE TABLE ` + tableName + ` (
[id] [int] IDENTITY(1,1) NOT NULL,
Expand Down Expand Up @@ -290,6 +419,14 @@ func setupTable(ctx context.Context, t *testing.T, conn *sql.Conn, tableName str
[test_int16nvarchar] [varchar](4) NULL,
[test_int8nvarchar] [varchar](3) NULL,
[test_intnvarchar] [varchar](4) NULL,
[test_nullfloat] [float] NULL,
[test_nullstring] [nvarchar](50) NULL,
[test_nullbyte] [tinyint] NULL,
[test_nullbool] [bit] NULL,
[test_nullint64] [bigint] NULL,
[test_nullint32] [int] NULL,
[test_nullint16] [smallint] NULL,
[test_nulltime] [datetime] NULL,
CONSTRAINT [PK_` + tableName + `_id] PRIMARY KEY CLUSTERED
(
[id] ASC
Expand Down

0 comments on commit b3a8513

Please sign in to comment.