diff --git a/bulkcopy.go b/bulkcopy.go index 15512a9e..a9aa1102 100644 --- a/bulkcopy.go +++ b/bulkcopy.go @@ -263,7 +263,7 @@ func (b *Bulk) createColMetadata() []byte { } binary.Write(buf, binary.LittleEndian, uint16(col.Flags)) - writeTypeInfo(buf, &b.bulkColumns[i].ti) + writeTypeInfo(buf, &b.bulkColumns[i].ti, false) if col.ti.TypeId == typeNText || col.ti.TypeId == typeText || diff --git a/queries_go19_test.go b/queries_go19_test.go index 20e22505..7578a175 100644 --- a/queries_go19_test.go +++ b/queries_go19_test.go @@ -346,6 +346,155 @@ SELECT @param2 = 'World' }) } +func TestOutputINOUTStringParam(t *testing.T) { + sqltextcreate := ` +CREATE PROCEDURE vinout + @sinout NVARCHAR(4000) OUTPUT +AS +BEGIN + IF @sinout = 'empty' + SET @sinout = NULL + ELSE + SET @sinout = 'long_long_value' +END; +` + sqltextdrop := `DROP PROCEDURE vinout;` + sqltextrun := `vinout` + + checkConnStr(t) + tl := testLogger{t: t} + defer tl.StopLogging() + SetLogger(&tl) + + db, err := sql.Open("sqlserver", makeConnStr(t).String()) + if err != nil { + t.Fatalf("failed to open driver sqlserver") + } + defer db.Close() + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + db.ExecContext(ctx, sqltextdrop) + _, err = db.ExecContext(ctx, sqltextcreate) + if err != nil { + t.Fatal(err) + } + defer db.ExecContext(ctx, sqltextdrop) + + t.Run("original test", func(t *testing.T) { + sinout := "short_value" + _, err = db.ExecContext(ctx, sqltextrun, + sql.Named("sinout", sql.Out{Dest: &sinout}), + ) + if err != nil { + t.Error(err) + } + + if sinout != "long_long_value" { + t.Errorf("expected long_long_value, got %s", sinout) + } + }) + + t.Run("nullable value", func(t *testing.T) { + sinout := sql.NullString{String: "short_value", Valid: true} + _, err = db.ExecContext(ctx, sqltextrun, + sql.Named("sinout", sql.Out{Dest: &sinout}), + ) + if err != nil { + t.Error(err) + } + + if !sinout.Valid || sinout.String != "long_long_value" { + if sinout.Valid { + t.Errorf("expected long_long_value, got %s", sinout.String) + } else { + t.Errorf("expected long_long_value, got NULL") + } + } + }) + + t.Run("null value", func(t *testing.T) { + sinout := sql.NullString{} + _, err = db.ExecContext(ctx, sqltextrun, + sql.Named("sinout", sql.Out{Dest: &sinout}), + ) + if err != nil { + t.Error(err) + } + + if !sinout.Valid || sinout.String != "long_long_value" { + if sinout.Valid { + t.Errorf("expected long_long_value, got %s", sinout.String) + } else { + t.Errorf("expected long_long_value, got NULL") + } + } + }) + + t.Run("null result", func(t *testing.T) { + sinout := sql.NullString{String: "empty", Valid: true} + _, err = db.ExecContext(ctx, sqltextrun, + sql.Named("sinout", sql.Out{Dest: &sinout}), + ) + if err != nil { + t.Error(err) + } + + if sinout.Valid { + t.Errorf("expected NULL, got %s", sinout.String) + } + }) +} + +func TestOutputINOUTBytesParam(t *testing.T) { + sqltextcreate := ` +CREATE PROCEDURE vinout + @binout VARBINARY(4000) OUTPUT +AS +BEGIN + SET @binout = CONVERT(VARBINARY(4000), 'long_long_value') +END; +` + sqltextdrop := `DROP PROCEDURE vinout;` + sqltextrun := `vinout` + + checkConnStr(t) + tl := testLogger{t: t} + defer tl.StopLogging() + SetLogger(&tl) + + db, err := sql.Open("sqlserver", makeConnStr(t).String()) + if err != nil { + t.Fatalf("failed to open driver sqlserver") + } + defer db.Close() + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + db.ExecContext(ctx, sqltextdrop) + _, err = db.ExecContext(ctx, sqltextcreate) + if err != nil { + t.Fatal(err) + } + defer db.ExecContext(ctx, sqltextdrop) + + t.Run("original test", func(t *testing.T) { + binout := []byte("short_value") + _, err = db.ExecContext(ctx, sqltextrun, + sql.Named("binout", sql.Out{Dest: &binout}), + ) + if err != nil { + t.Error(err) + } + + if !bytes.Equal(binout, []byte("long_long_value")) { + t.Errorf("expected long_long_value, got %s", string(binout)) + } + }) +} + func TestOutputINOUTParam(t *testing.T) { sqltextcreate := ` CREATE PROCEDURE abinout diff --git a/rpc.go b/rpc.go index 8f1ef2b4..afda1309 100644 --- a/rpc.go +++ b/rpc.go @@ -73,7 +73,7 @@ func sendRpc(buf *tdsBuffer, headers []headerStruct, proc procId, flags uint16, if err = binary.Write(buf, binary.LittleEndian, param.Flags); err != nil { return } - err = writeTypeInfo(buf, ¶m.ti) + err = writeTypeInfo(buf, ¶m.ti, (param.Flags&fByRevValue) != 0) if err != nil { return } @@ -82,7 +82,7 @@ func sendRpc(buf *tdsBuffer, headers []headerStruct, proc procId, flags uint16, return } if (param.Flags & fEncrypted) == fEncrypted { - err = writeTypeInfo(buf, ¶m.tiOriginal) + err = writeTypeInfo(buf, ¶m.tiOriginal, false) if err != nil { return } diff --git a/tvp_go19.go b/tvp_go19.go index 0d555471..cc5dbfe4 100644 --- a/tvp_go19.go +++ b/tvp_go19.go @@ -80,7 +80,7 @@ func (tvp TVP) encode(schema, name string, columnStr []columnStruct, tvpFieldInd for i, column := range columnStr { binary.Write(buf, binary.LittleEndian, column.UserType) binary.Write(buf, binary.LittleEndian, column.Flags) - writeTypeInfo(buf, &columnStr[i].ti) + writeTypeInfo(buf, &columnStr[i].ti, false) writeBVarChar(buf, "") } // The returned error is always nil diff --git a/types.go b/types.go index 24cc4077..ac99e92e 100644 --- a/types.go +++ b/types.go @@ -148,7 +148,7 @@ func readTypeInfo(r *tdsBuffer, typeId byte, c *cryptoMetadata) (res typeInfo) { } // https://msdn.microsoft.com/en-us/library/dd358284.aspx -func writeTypeInfo(w io.Writer, ti *typeInfo) (err error) { +func writeTypeInfo(w io.Writer, ti *typeInfo, out bool) (err error) { err = binary.Write(w, binary.LittleEndian, ti.TypeId) if err != nil { return @@ -162,7 +162,7 @@ func writeTypeInfo(w io.Writer, ti *typeInfo) (err error) { case typeTvp: ti.Writer = writeFixedType default: // all others are VARLENTYPE - err = writeVarLen(w, ti) + err = writeVarLen(w, ti, out) if err != nil { return } @@ -176,7 +176,7 @@ func writeFixedType(w io.Writer, ti typeInfo, buf []byte) (err error) { } // https://msdn.microsoft.com/en-us/library/dd358341.aspx -func writeVarLen(w io.Writer, ti *typeInfo) (err error) { +func writeVarLen(w io.Writer, ti *typeInfo, out bool) (err error) { switch ti.TypeId { case typeDateN: @@ -222,7 +222,7 @@ func writeVarLen(w io.Writer, ti *typeInfo) (err error) { typeNVarChar, typeNChar, typeXml, typeUdt: // short len types - if ti.Size > 8000 || ti.Size == 0 { + if ti.Size > 8000 || ti.Size == 0 || out { if err = binary.Write(w, binary.LittleEndian, uint16(0xffff)); err != nil { return }