Skip to content

Commit

Permalink
NEOS-1426 Fix SQL Server bit data type in sync (#2646)
Browse files Browse the repository at this point in the history
  • Loading branch information
alishakawaguchi authored Sep 9, 2024
1 parent c26ce60 commit c7769fd
Show file tree
Hide file tree
Showing 8 changed files with 219 additions and 131 deletions.
75 changes: 75 additions & 0 deletions internal/sqlserver/utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,11 @@ package sqlserver

import (
"database/sql"
"fmt"
"strings"

"github.com/gofrs/uuid"
sqlmanager_shared "github.com/nucleuscloud/neosync/backend/pkg/sqlmanager/shared"
)

func SqlRowToSqlServerTypesMap(rows *sql.Rows) (map[string]any, error) {
Expand Down Expand Up @@ -68,3 +70,76 @@ func BitsToUuidString(bits []byte) (string, error) {
}
return u.String(), nil
}

// GeSqlServerDefaultValuesInsertSql builds rowCount copies of the statement
// "INSERT INTO <schema>.<table> DEFAULT VALUES;" concatenated into a single
// string with no separator between the statements.
//
// schema and table are rendered with fmt's %q verb, so each is wrapped in
// double quotes using Go string-escaping rules.
// NOTE(review): %q applies Go escaping, not T-SQL identifier escaping; names
// containing quotes or backslashes may not round-trip — confirm callers only
// pass trusted identifiers.
//
// A rowCount of zero or less yields the empty string.
func GeSqlServerDefaultValuesInsertSql(schema, table string, rowCount int) string {
	// strings.Repeat panics on a negative count, so keep the "no rows, no SQL"
	// behaviour explicit.
	if rowCount <= 0 {
		return ""
	}
	// The statement is identical for every row: format it once and repeat it
	// rather than re-formatting and re-concatenating on each iteration.
	stmt := fmt.Sprintf("INSERT INTO %q.%q DEFAULT VALUES;", schema, table)
	return strings.Repeat(stmt, rowCount)
}

// GoTypeToSqlServerType returns a copy of rows in which every bool value has
// been replaced by its integer bit form (true -> 1, false -> 0). All other
// values are carried over as-is and keep their position. The input rows are
// not modified; the result and each row inside it are always non-nil.
func GoTypeToSqlServerType(rows [][]any) [][]any {
	converted := make([][]any, 0, len(rows))
	for i := range rows {
		out := make([]any, len(rows[i]))
		for j, val := range rows[i] {
			if b, isBool := val.(bool); isBool {
				out[j] = toBit(b)
				continue
			}
			out[j] = val
		}
		converted = append(converted, out)
	}
	return converted
}

// toBit maps a Go bool onto the integer form of a bit value: 1 for true and
// 0 for false.
func toBit(v bool) int {
	bit := 0
	if v {
		bit = 1
	}
	return bit
}

// FilterOutSqlServerDefaultIdentityColumns removes, from every row, the
// arguments that belong to an identity column and equal the literal string
// "DEFAULT", and returns the column list narrowed to match.
//
// When driver is not sqlmanager_shared.MssqlDriver, or identityCols is empty,
// columnNames and argRows are returned untouched.
//
// Otherwise the returned columns are the entries of columnNames, in their
// original order, for which at least one row kept a value. "DEFAULT" values
// in non-identity columns are kept. Each row is expected to line up
// positionally with columnNames; a row longer than columnNames panics on the
// index lookup.
//
// NOTE(review): if an identity column holds "DEFAULT" in some rows and a
// concrete value in others, the column is kept while the rows end up with
// differing lengths — presumably callers never mix the two in one batch;
// confirm.
func FilterOutSqlServerDefaultIdentityColumns(
	driver string,
	identityCols, columnNames []string,
	argRows [][]any,
) (columns []string, rows [][]any) {
	// Nothing to filter for other drivers or when there are no identity columns.
	if driver != sqlmanager_shared.MssqlDriver || len(identityCols) == 0 {
		return columnNames, argRows
	}

	// Lookup set of identity column names.
	isIdentity := map[string]bool{}
	for i := range identityCols {
		isIdentity[identityCols[i]] = true
	}

	// Names of columns for which some row kept a value.
	retained := map[string]struct{}{}
	rows = [][]any{}
	for _, args := range argRows {
		kept := []any{}
		for pos, val := range args {
			name := columnNames[pos]
			dropped := isIdentity[name] && val == "DEFAULT"
			if !dropped {
				kept = append(kept, val)
				retained[name] = struct{}{}
			}
		}
		rows = append(rows, kept)
	}

	// Walk columnNames rather than the map so the original ordering survives.
	columns = []string{}
	for _, name := range columnNames {
		if _, ok := retained[name]; !ok {
			continue
		}
		columns = append(columns, name)
	}
	return columns, rows
}
109 changes: 109 additions & 0 deletions internal/sqlserver/utils_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,109 @@
package sqlserver

import (
"testing"

sqlmanager_shared "github.com/nucleuscloud/neosync/backend/pkg/sqlmanager/shared"
"github.com/stretchr/testify/require"
)

// Test_filterIdentityColumns exercises FilterOutSqlServerDefaultIdentityColumns
// across driver types and identity-column configurations. Only arguments that
// sit in an identity column AND equal "DEFAULT" are expected to be dropped.
func Test_filterIdentityColumns(t *testing.T) {
	t.Run("Non-MSSQL driver", func(t *testing.T) {
		driver := "mysql"
		identityCols := []string{"id"}
		columnNames := []string{"id", "name", "age"}
		argRows := [][]any{{1, "Alice", 30}, {2, "Bob", 25}}

		gotCols, gotRows := FilterOutSqlServerDefaultIdentityColumns(driver, identityCols, columnNames, argRows)

		require.Equal(t, columnNames, gotCols, "Columns should remain unchanged for non-MSSQL driver")
		require.Equal(t, argRows, gotRows, "Rows should remain unchanged for non-MSSQL driver")
	})

	t.Run("MSSQL driver with identity columns", func(t *testing.T) {
		driver := sqlmanager_shared.MssqlDriver
		identityCols := []string{"id"}
		columnNames := []string{"id", "name", "age"}
		// The identity column carries explicit values, so nothing is filtered.
		argRows := [][]any{{1, "Alice", 30}, {2, "Bob", 25}}

		gotCols, gotRows := FilterOutSqlServerDefaultIdentityColumns(driver, identityCols, columnNames, argRows)

		require.Equal(t, []string{"id", "name", "age"}, gotCols, "Identity column with explicit values should be kept")
		require.Equal(t, [][]any{{1, "Alice", 30}, {2, "Bob", 25}}, gotRows, "Rows should remain unchanged when identity values are explicit")
	})

	t.Run("MSSQL driver with DEFAULT value", func(t *testing.T) {
		driver := sqlmanager_shared.MssqlDriver
		identityCols := []string{"id"}
		columnNames := []string{"id", "name", "age", "city"}
		// "city" is not an identity column, so its DEFAULT must survive.
		argRows := [][]any{{"DEFAULT", "Alice", 30, "DEFAULT"}, {"DEFAULT", "Bob", 25, "DEFAULT"}}

		gotCols, gotRows := FilterOutSqlServerDefaultIdentityColumns(driver, identityCols, columnNames, argRows)

		require.Equal(t, []string{"name", "age", "city"}, gotCols, "Only the identity column set to DEFAULT should be removed")
		require.Equal(t, [][]any{{"Alice", 30, "DEFAULT"}, {"Bob", 25, "DEFAULT"}}, gotRows, "DEFAULT should be stripped from the identity column only")
	})

	t.Run("Empty identity columns", func(t *testing.T) {
		driver := sqlmanager_shared.MssqlDriver
		identityCols := []string{}
		columnNames := []string{"id", "name", "age"}
		argRows := [][]any{{1, "Alice", 30}, {2, "Bob", 25}}

		gotCols, gotRows := FilterOutSqlServerDefaultIdentityColumns(driver, identityCols, columnNames, argRows)

		require.Equal(t, columnNames, gotCols, "Columns should remain unchanged with empty identity columns")
		require.Equal(t, argRows, gotRows, "Rows should remain unchanged with empty identity columns")
	})

	t.Run("Multiple identity columns", func(t *testing.T) {
		driver := sqlmanager_shared.MssqlDriver
		identityCols := []string{"id", "created_at"}
		columnNames := []string{"id", "name", "age", "created_at"}
		argRows := [][]any{{"DEFAULT", "Alice", 30, "DEFAULT"}, {"DEFAULT", "Bob", 25, "DEFAULT"}}

		gotCols, gotRows := FilterOutSqlServerDefaultIdentityColumns(driver, identityCols, columnNames, argRows)

		require.Equal(t, []string{"name", "age"}, gotCols, "Multiple identity columns should be removed")
		require.Equal(t, [][]any{{"Alice", 30}, {"Bob", 25}}, gotRows, "Multiple identity column values should be removed")
	})
}

// Test_GoTypeToSqlServerType checks that GoTypeToSqlServerType rewrites bool
// values to 1/0 and leaves every other value, and the row layout, untouched.
func Test_GoTypeToSqlServerType(t *testing.T) {
	cases := []struct {
		name     string
		input    [][]any
		expected [][]any
	}{
		{
			name:     "Empty input",
			input:    [][]any{},
			expected: [][]any{},
		},
		{
			name:     "Single row with no boolean",
			input:    [][]any{{1, "test", 3.14}},
			expected: [][]any{{1, "test", 3.14}},
		},
		{
			name:     "Single row with boolean",
			input:    [][]any{{true, false, "test"}},
			expected: [][]any{{1, 0, "test"}},
		},
		{
			name: "Multiple rows with mixed types",
			input: [][]any{
				{1, true, "test1"},
				{2, false, "test2"},
				{3, true, "test3"},
			},
			expected: [][]any{
				{1, 1, "test1"},
				{2, 0, "test2"},
				{3, 1, "test3"},
			},
		},
	}

	t.Run("GoTypeToSqlServerType", func(t *testing.T) {
		for _, tc := range cases {
			t.Run(tc.name, func(t *testing.T) {
				actual := GoTypeToSqlServerType(tc.input)
				require.Equal(t, tc.expected, actual)
			})
		}
	})
}
74 changes: 12 additions & 62 deletions worker/pkg/benthos/sql/output_sql_insert.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,11 @@ import (
"sync"

"github.com/Jeffail/shutdown"
"github.com/doug-martin/goqu/v9"
_ "github.com/doug-martin/goqu/v9/dialect/mysql"
_ "github.com/doug-martin/goqu/v9/dialect/postgres"
mysql_queries "github.com/nucleuscloud/neosync/backend/gen/go/db/dbschemas/mysql"
sqlmanager_shared "github.com/nucleuscloud/neosync/backend/pkg/sqlmanager/shared"
sqlserverutil "github.com/nucleuscloud/neosync/internal/sqlserver"
querybuilder "github.com/nucleuscloud/neosync/worker/pkg/query-builder"
"github.com/warpstreamlabs/bento/public/bloblang"
"github.com/warpstreamlabs/bento/public/service"
Expand Down Expand Up @@ -271,24 +271,14 @@ func (s *pooledInsertOutput) WriteBatch(ctx context.Context, batch service.Messa
rows = append(rows, args)
}

filteredCols, filteredRows := filterOutMssqlDefaultIdentityColumns(s.driver, s.identityColumns, s.columns, rows)

// set any default transformations
for i, row := range filteredRows {
for j, arg := range row {
if arg == "DEFAULT" {
filteredRows[i][j] = goqu.L("DEFAULT")
}
}
}

insertQuery, err := querybuilder.BuildInsertQuery(s.driver, fmt.Sprintf("%s.%s", s.schema, s.table), filteredCols, filteredRows, &s.onConflictDoNothing)
processedCols, processedRows := s.processRows(s.columns, rows)
insertQuery, err := querybuilder.BuildInsertQuery(s.driver, fmt.Sprintf("%s.%s", s.schema, s.table), processedCols, processedRows, &s.onConflictDoNothing)
if err != nil {
return err
}

if s.driver == sqlmanager_shared.MssqlDriver && len(filteredCols) == 0 {
insertQuery = getMssqlDefaultValuesInsertSql(s.schema, s.table, len(rows))
if s.driver == sqlmanager_shared.MssqlDriver && len(processedCols) == 0 {
insertQuery = sqlserverutil.GeSqlServerDefaultValuesInsertSql(s.schema, s.table, len(rows))
}

query := s.buildQuery(insertQuery)
Expand All @@ -298,54 +288,14 @@ func (s *pooledInsertOutput) WriteBatch(ctx context.Context, batch service.Messa
return nil
}

// getMssqlDefaultValuesInsertSql builds rowCount copies of the statement
// "INSERT INTO <schema>.<table> DEFAULT VALUES;" concatenated into one string
// with no separator. Use when all columns are identity generation columns.
//
// schema and table are rendered with fmt's %q verb (double-quoted, Go
// escaping). A rowCount of zero or less yields the empty string.
func getMssqlDefaultValuesInsertSql(schema, table string, rowCount int) string {
	if rowCount <= 0 {
		return ""
	}
	// The statement never changes between rows, so format it once and copy it
	// into a buffer sized up front instead of growing a string with +=.
	stmt := fmt.Sprintf("INSERT INTO %q.%q DEFAULT VALUES;", schema, table)
	buf := make([]byte, 0, len(stmt)*rowCount)
	for i := 0; i < rowCount; i++ {
		buf = append(buf, stmt...)
	}
	return string(buf)
}

func filterOutMssqlDefaultIdentityColumns(
driver string,
identityCols, columnNames []string,
argRows [][]any,
) (columns []string, rows [][]any) {
if len(identityCols) == 0 || driver != sqlmanager_shared.MssqlDriver {
return columnNames, argRows
}

// build map of identity columns
identityColMap := map[string]bool{}
for _, id := range identityCols {
identityColMap[id] = true
}

nonIdentityColumnMap := map[string]struct{}{} // map of non identity columns
newRows := [][]any{}
// build rows removing identity columns/args with default set
for _, row := range argRows {
newRow := []any{}
for idx, arg := range row {
col := columnNames[idx]
if identityColMap[col] && arg == "DEFAULT" {
// pass on identity columns with a default
continue
}
newRow = append(newRow, arg)
nonIdentityColumnMap[col] = struct{}{}
}
newRows = append(newRows, newRow)
}
newColumns := []string{}
// build new columns list while maintaining same order
for _, col := range columnNames {
if _, ok := nonIdentityColumnMap[col]; ok {
newColumns = append(newColumns, col)
}
// processRows applies driver-specific preparation to a batch of insert
// arguments and returns the column names and rows to use for the insert.
//
// For the MSSQL driver, bool values are first converted with
// sqlserverutil.GoTypeToSqlServerType and the result is then passed through
// sqlserverutil.FilterOutSqlServerDefaultIdentityColumns. For every other
// driver, columnNames and dataRows are returned unchanged.
func (s *pooledInsertOutput) processRows(columnNames []string, dataRows [][]any) (columns []string, rows [][]any) {
switch s.driver {
case sqlmanager_shared.MssqlDriver:
newDataRows := sqlserverutil.GoTypeToSqlServerType(dataRows)
// NOTE(review): this branch passes s.columns rather than the columnNames
// parameter; the visible call site passes s.columns as columnNames, so the
// two presumably always match — confirm.
return sqlserverutil.FilterOutSqlServerDefaultIdentityColumns(s.driver, s.identityColumns, s.columns, newDataRows)
default:
return columnNames, dataRows
}
// NOTE(review): newColumns/newRows are not declared in this function; this
// line looks like residue of the removed filterOutMssqlDefaultIdentityColumns
// body left in by the diff rendering — confirm.
return newColumns, newRows
}

func (s *pooledInsertOutput) buildQuery(insertQuery string) string {
Expand Down
63 changes: 0 additions & 63 deletions worker/pkg/benthos/sql/output_sql_insert_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@ import (
"context"
"testing"

sqlmanager_shared "github.com/nucleuscloud/neosync/backend/pkg/sqlmanager/shared"
"github.com/stretchr/testify/require"
"github.com/warpstreamlabs/bento/public/service"
)
Expand All @@ -27,65 +26,3 @@ args_mapping: 'root = [this.id]'
require.NoError(t, err)
require.NoError(t, insertOutput.Close(context.Background()))
}

// Test_filterIdentityColumns exercises filterOutMssqlDefaultIdentityColumns
// across driver types and identity-column configurations. Only arguments that
// sit in an identity column AND equal "DEFAULT" are expected to be dropped.
func Test_filterIdentityColumns(t *testing.T) {
	t.Run("Non-MSSQL driver", func(t *testing.T) {
		driver := "mysql"
		identityCols := []string{"id"}
		columnNames := []string{"id", "name", "age"}
		argRows := [][]any{{1, "Alice", 30}, {2, "Bob", 25}}

		gotCols, gotRows := filterOutMssqlDefaultIdentityColumns(driver, identityCols, columnNames, argRows)

		require.Equal(t, columnNames, gotCols, "Columns should remain unchanged for non-MSSQL driver")
		require.Equal(t, argRows, gotRows, "Rows should remain unchanged for non-MSSQL driver")
	})

	t.Run("MSSQL driver with identity columns", func(t *testing.T) {
		driver := sqlmanager_shared.MssqlDriver
		identityCols := []string{"id"}
		columnNames := []string{"id", "name", "age"}
		// The identity column carries explicit values, so nothing is filtered.
		argRows := [][]any{{1, "Alice", 30}, {2, "Bob", 25}}

		gotCols, gotRows := filterOutMssqlDefaultIdentityColumns(driver, identityCols, columnNames, argRows)

		require.Equal(t, []string{"id", "name", "age"}, gotCols, "Identity column with explicit values should be kept")
		require.Equal(t, [][]any{{1, "Alice", 30}, {2, "Bob", 25}}, gotRows, "Rows should remain unchanged when identity values are explicit")
	})

	t.Run("MSSQL driver with DEFAULT value", func(t *testing.T) {
		driver := sqlmanager_shared.MssqlDriver
		identityCols := []string{"id"}
		columnNames := []string{"id", "name", "age", "city"}
		// "city" is not an identity column, so its DEFAULT must survive.
		argRows := [][]any{{"DEFAULT", "Alice", 30, "DEFAULT"}, {"DEFAULT", "Bob", 25, "DEFAULT"}}

		gotCols, gotRows := filterOutMssqlDefaultIdentityColumns(driver, identityCols, columnNames, argRows)

		require.Equal(t, []string{"name", "age", "city"}, gotCols, "Only the identity column set to DEFAULT should be removed")
		require.Equal(t, [][]any{{"Alice", 30, "DEFAULT"}, {"Bob", 25, "DEFAULT"}}, gotRows, "DEFAULT should be stripped from the identity column only")
	})

	t.Run("Empty identity columns", func(t *testing.T) {
		driver := sqlmanager_shared.MssqlDriver
		identityCols := []string{}
		columnNames := []string{"id", "name", "age"}
		argRows := [][]any{{1, "Alice", 30}, {2, "Bob", 25}}

		gotCols, gotRows := filterOutMssqlDefaultIdentityColumns(driver, identityCols, columnNames, argRows)

		require.Equal(t, columnNames, gotCols, "Columns should remain unchanged with empty identity columns")
		require.Equal(t, argRows, gotRows, "Rows should remain unchanged with empty identity columns")
	})

	t.Run("Multiple identity columns", func(t *testing.T) {
		driver := sqlmanager_shared.MssqlDriver
		identityCols := []string{"id", "created_at"}
		columnNames := []string{"id", "name", "age", "created_at"}
		argRows := [][]any{{"DEFAULT", "Alice", 30, "DEFAULT"}, {"DEFAULT", "Bob", 25, "DEFAULT"}}

		gotCols, gotRows := filterOutMssqlDefaultIdentityColumns(driver, identityCols, columnNames, argRows)

		require.Equal(t, []string{"name", "age"}, gotCols, "Multiple identity columns should be removed")
		require.Equal(t, [][]any{{"Alice", 30}, {"Bob", 25}}, gotRows, "Multiple identity column values should be removed")
	})
}
Loading

0 comments on commit c7769fd

Please sign in to comment.