Skip to content

Commit

Permalink
fix: use BLOB sql type to encode []byte in MySQL and SQLite
Browse files Browse the repository at this point in the history
  • Loading branch information
vmihailenco committed Oct 22, 2021
1 parent 01d0ce2 commit 725ec88
Show file tree
Hide file tree
Showing 6 changed files with 66 additions and 9 deletions.
4 changes: 4 additions & 0 deletions dialect/pgdialect/sqltype.go
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,10 @@ func fieldSQLType(field *schema.Field) string {
}
}

if field.DiscoveredSQLType == sqltype.Blob {
return pgTypeBytea
}

return sqlType(field.IndirectType)
}

Expand Down
35 changes: 29 additions & 6 deletions dialect/sqlitedialect/dialect.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package sqlitedialect

import (
"database/sql"
"encoding/hex"

"github.com/uptrace/bun/dialect"
"github.com/uptrace/bun/dialect/feature"
Expand Down Expand Up @@ -47,14 +48,36 @@ func (d *Dialect) OnTable(table *schema.Table) {
}

func (d *Dialect) onField(field *schema.Field) {
// INTEGER PRIMARY KEY is an alias for the ROWID.
// It is safe to convert all ints to INTEGER, because SQLite types don't have size.
switch field.DiscoveredSQLType {
case sqltype.SmallInt, sqltype.BigInt:
field.DiscoveredSQLType = sqltype.Integer
}
field.DiscoveredSQLType = fieldSQLType(field)
}

func (d *Dialect) IdentQuote() byte {
return '"'
}

func (d *Dialect) AppendBytes(b []byte, bs []byte) []byte {
if bs == nil {
return dialect.AppendNull(b)
}

b = append(b, `X'`...)

s := len(b)
b = append(b, make([]byte, hex.EncodedLen(len(bs)))...)
hex.Encode(b[s:], bs)

b = append(b, '\'')

return b
}

func fieldSQLType(field *schema.Field) string {
switch field.DiscoveredSQLType {
case sqltype.SmallInt, sqltype.BigInt:
// INTEGER PRIMARY KEY is an alias for the ROWID.
// It is safe to convert all ints to INTEGER, because SQLite types don't have size.
return sqltype.Integer
default:
return field.DiscoveredSQLType
}
}
1 change: 1 addition & 0 deletions dialect/sqltype/sqltype.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ const (
Real = "REAL"
DoublePrecision = "DOUBLE PRECISION"
VarChar = "VARCHAR"
Blob = "BLOB"
Timestamp = "TIMESTAMP"
JSON = "JSON"
JSONB = "JSONB"
Expand Down
25 changes: 23 additions & 2 deletions internal/dbtest/db_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -227,11 +227,12 @@ func TestDB(t *testing.T) {
{testFKViolation},
{testInterfaceAny},
{testInterfaceJSON},
{testScanBytes},
{testScanRawMessage},
{testPointers},
{testExists},
{testScanTimeIntoString},
{testModelNonPointer},
{testBinaryData},
}

testEachDB(t, func(t *testing.T, dbName string, db *bun.DB) {
Expand Down Expand Up @@ -828,7 +829,7 @@ func testInterfaceJSON(t *testing.T, db *bun.DB) {
require.Equal(t, "hello", model.Value)
}

func testScanBytes(t *testing.T, db *bun.DB) {
func testScanRawMessage(t *testing.T, db *bun.DB) {
type Model struct {
ID int64
Value json.RawMessage
Expand Down Expand Up @@ -914,3 +915,23 @@ func testModelNonPointer(t *testing.T, db *bun.DB) {
require.Error(t, err)
require.Equal(t, "bun: Model(non-pointer dbtest_test.Model)", err.Error())
}

func testBinaryData(t *testing.T, db *bun.DB) {
type Model struct {
ID int64
Data []byte
}

ctx := context.Background()

err := db.ResetModel(ctx, (*Model)(nil))
require.NoError(t, err)

_, err = db.NewInsert().Model(&Model{Data: []byte("hello")}).Exec(ctx)
require.NoError(t, err)

var model Model
err = db.NewSelect().Model(&model).Scan(ctx)
require.NoError(t, err)
require.Equal(t, []byte("hello"), model.Data)
}
2 changes: 1 addition & 1 deletion internal/dbtest/testdata/snapshots/TestQuery-sqlite-60
Original file line number Diff line number Diff line change
@@ -1 +1 @@
INSERT INTO "models" ("bytes") VALUES ('\x00000000000000000000')
INSERT INTO "models" ("bytes") VALUES (X'00000000000000000000')
8 changes: 8 additions & 0 deletions schema/sqltype.go
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,14 @@ func DiscoverSQLType(typ reflect.Type) string {
case nullStringType:
return sqltype.VarChar
}

switch typ.Kind() {
case reflect.Slice:
if typ.Elem().Kind() == reflect.Uint8 {
return sqltype.Blob
}
}

return sqlTypes[typ.Kind()]
}

Expand Down

0 comments on commit 725ec88

Please sign in to comment.