diff --git a/internal/sqlserver/utils.go b/internal/sqlserver/utils.go index 8851e98c6a..a734ec5104 100644 --- a/internal/sqlserver/utils.go +++ b/internal/sqlserver/utils.go @@ -2,9 +2,11 @@ package sqlserver import ( "database/sql" + "fmt" "strings" "github.com/gofrs/uuid" + sqlmanager_shared "github.com/nucleuscloud/neosync/backend/pkg/sqlmanager/shared" ) func SqlRowToSqlServerTypesMap(rows *sql.Rows) (map[string]any, error) { @@ -68,3 +70,76 @@ func BitsToUuidString(bits []byte) (string, error) { } return u.String(), nil } + +func GeSqlServerDefaultValuesInsertSql(schema, table string, rowCount int) string { + var sqlStr string + for i := 0; i < rowCount; i++ { + sqlStr += fmt.Sprintf("INSERT INTO %q.%q DEFAULT VALUES;", schema, table) + } + return sqlStr +} + +func GoTypeToSqlServerType(rows [][]any) [][]any { + newRows := [][]any{} + for _, r := range rows { + newRow := []any{} + for _, v := range r { + switch t := v.(type) { + case bool: + newRow = append(newRow, toBit(t)) + default: + newRow = append(newRow, t) + } + } + newRows = append(newRows, newRow) + } + return newRows +} + +func toBit(v bool) int { + if v { + return 1 + } + return 0 +} + +func FilterOutSqlServerDefaultIdentityColumns( + driver string, + identityCols, columnNames []string, + argRows [][]any, +) (columns []string, rows [][]any) { + if len(identityCols) == 0 || driver != sqlmanager_shared.MssqlDriver { + return columnNames, argRows + } + + // build map of identity columns + identityColMap := map[string]bool{} + for _, id := range identityCols { + identityColMap[id] = true + } + + nonIdentityColumnMap := map[string]struct{}{} // map of non identity columns + newRows := [][]any{} + // build rows removing identity columns/args with default set + for _, row := range argRows { + newRow := []any{} + for idx, arg := range row { + col := columnNames[idx] + if identityColMap[col] && arg == "DEFAULT" { + // pass on identity columns with a default + continue + } + newRow = append(newRow, arg) + nonIdentityColumnMap[col] = struct{}{} + } + newRows = append(newRows, newRow) + } + newColumns := []string{} + // build new columns list while maintaining same order + for _, col := range columnNames { + if _, ok := nonIdentityColumnMap[col]; ok { + newColumns = append(newColumns, col) + } + } + return newColumns, newRows +} diff --git a/internal/sqlserver/utils_test.go b/internal/sqlserver/utils_test.go new file mode 100644 index 0000000000..65e07d4b37 --- /dev/null +++ b/internal/sqlserver/utils_test.go @@ -0,0 +1,109 @@ +package sqlserver + +import ( + "testing" + + sqlmanager_shared "github.com/nucleuscloud/neosync/backend/pkg/sqlmanager/shared" + "github.com/stretchr/testify/require" +) + +func Test_filterIdentityColumns(t *testing.T) { + t.Run("Non-MSSQL driver", func(t *testing.T) { + driver := "mysql" + identityCols := []string{"id"} + columnNames := []string{"id", "name", "age"} + argRows := [][]any{{1, "Alice", 30}, {2, "Bob", 25}} + + gotCols, gotRows := FilterOutSqlServerDefaultIdentityColumns(driver, identityCols, columnNames, argRows) + + require.Equal(t, columnNames, gotCols, "Columns should remain unchanged for non-MSSQL driver") + require.Equal(t, argRows, gotRows, "Rows should remain unchanged for non-MSSQL driver") + }) + + t.Run("MSSQL driver with identity columns", func(t *testing.T) { + driver := sqlmanager_shared.MssqlDriver + identityCols := []string{"id"} + columnNames := []string{"id", "name", "age"} + argRows := [][]any{{1, "Alice", 30}, {2, "Bob", 25}} + + gotCols, gotRows := FilterOutSqlServerDefaultIdentityColumns(driver, identityCols, columnNames, argRows) + + require.Equal(t, []string{"id", "name", "age"}, gotCols, "Identity column should be removed") + require.Equal(t, [][]any{{1, "Alice", 30}, {2, "Bob", 25}}, gotRows, "Identity column values should be removed") + }) + + t.Run("MSSQL driver with DEFAULT value", func(t *testing.T) { + driver := sqlmanager_shared.MssqlDriver + identityCols := []string{"id"} + columnNames := []string{"id", "name", "age", "city"} + argRows := [][]any{{"DEFAULT", "Alice", 30, "DEFAULT"}, {"DEFAULT", "Bob", 25, "DEFAULT"}} + + gotCols, gotRows := FilterOutSqlServerDefaultIdentityColumns(driver, identityCols, columnNames, argRows) + + require.Equal(t, []string{"name", "age", "city"}, gotCols, "All columns should be present when DEFAULT is used") + require.Equal(t, [][]any{{"Alice", 30, "DEFAULT"}, {"Bob", 25, "DEFAULT"}}, gotRows, "All rows should remain unchanged when DEFAULT is used") + }) + + t.Run("Empty identity columns", func(t *testing.T) { + driver := sqlmanager_shared.MssqlDriver + identityCols := []string{} + columnNames := []string{"id", "name", "age"} + argRows := [][]any{{1, "Alice", 30}, {2, "Bob", 25}} + + gotCols, gotRows := FilterOutSqlServerDefaultIdentityColumns(driver, identityCols, columnNames, argRows) + + require.Equal(t, columnNames, gotCols, "Columns should remain unchanged with empty identity columns") + require.Equal(t, argRows, gotRows, "Rows should remain unchanged with empty identity columns") + }) + + t.Run("Multiple identity columns", func(t *testing.T) { + driver := sqlmanager_shared.MssqlDriver + identityCols := []string{"id", "created_at"} + columnNames := []string{"id", "name", "age", "created_at"} + argRows := [][]any{{"DEFAULT", "Alice", 30, "DEFAULT"}, {"DEFAULT", "Bob", 25, "DEFAULT"}} + + gotCols, gotRows := FilterOutSqlServerDefaultIdentityColumns(driver, identityCols, columnNames, argRows) + + require.Equal(t, []string{"name", "age"}, gotCols, "Multiple identity columns should be removed") + require.Equal(t, [][]any{{"Alice", 30}, {"Bob", 25}}, gotRows, "Multiple identity column values should be removed") + }) +} + +func Test_GoTypeToSqlServerType(t *testing.T) { + t.Run("GoTypeToSqlServerType", func(t *testing.T) { + t.Run("Empty input", func(t *testing.T) { + input := [][]any{} + result := GoTypeToSqlServerType(input) + require.Equal(t, [][]any{}, result) + }) + + t.Run("Single row with no boolean", func(t *testing.T) { + input := [][]any{{1, "test", 3.14}} + expected := [][]any{{1, "test", 3.14}} + result := GoTypeToSqlServerType(input) + require.Equal(t, expected, result) + }) + + t.Run("Single row with boolean", func(t *testing.T) { + input := [][]any{{true, false, "test"}} + expected := [][]any{{1, 0, "test"}} + result := GoTypeToSqlServerType(input) + require.Equal(t, expected, result) + }) + + t.Run("Multiple rows with mixed types", func(t *testing.T) { + input := [][]any{ + {1, true, "test1"}, + {2, false, "test2"}, + {3, true, "test3"}, + } + expected := [][]any{ + {1, 1, "test1"}, + {2, 0, "test2"}, + {3, 1, "test3"}, + } + result := GoTypeToSqlServerType(input) + require.Equal(t, expected, result) + }) + }) +} diff --git a/worker/pkg/benthos/sql/output_sql_insert.go b/worker/pkg/benthos/sql/output_sql_insert.go index 58f1406c3a..17bd7c7f11 100644 --- a/worker/pkg/benthos/sql/output_sql_insert.go +++ b/worker/pkg/benthos/sql/output_sql_insert.go @@ -7,11 +7,11 @@ import ( "sync" "github.com/Jeffail/shutdown" - "github.com/doug-martin/goqu/v9" _ "github.com/doug-martin/goqu/v9/dialect/mysql" _ "github.com/doug-martin/goqu/v9/dialect/postgres" mysql_queries "github.com/nucleuscloud/neosync/backend/gen/go/db/dbschemas/mysql" sqlmanager_shared "github.com/nucleuscloud/neosync/backend/pkg/sqlmanager/shared" + sqlserverutil "github.com/nucleuscloud/neosync/internal/sqlserver" querybuilder "github.com/nucleuscloud/neosync/worker/pkg/query-builder" "github.com/warpstreamlabs/bento/public/bloblang" "github.com/warpstreamlabs/bento/public/service" @@ -271,24 +271,14 @@ func (s *pooledInsertOutput) WriteBatch(ctx context.Context, batch service.Messa rows = append(rows, args) } - filteredCols, filteredRows := filterOutMssqlDefaultIdentityColumns(s.driver, s.identityColumns, s.columns, rows) - - // set any default transformations - for i, row := range filteredRows { - for j, arg := range row { - if arg == "DEFAULT" { - filteredRows[i][j] = goqu.L("DEFAULT") - } - } - } - - insertQuery, err := querybuilder.BuildInsertQuery(s.driver, fmt.Sprintf("%s.%s", s.schema, s.table), filteredCols, filteredRows, &s.onConflictDoNothing) + processedCols, processedRows := s.processRows(s.columns, rows) + insertQuery, err := querybuilder.BuildInsertQuery(s.driver, fmt.Sprintf("%s.%s", s.schema, s.table), processedCols, processedRows, &s.onConflictDoNothing) if err != nil { return err } - if s.driver == sqlmanager_shared.MssqlDriver && len(filteredCols) == 0 { - insertQuery = getMssqlDefaultValuesInsertSql(s.schema, s.table, len(rows)) + if s.driver == sqlmanager_shared.MssqlDriver && len(processedCols) == 0 { + insertQuery = sqlserverutil.GeSqlServerDefaultValuesInsertSql(s.schema, s.table, len(rows)) } query := s.buildQuery(insertQuery) @@ -298,54 +288,14 @@ func (s *pooledInsertOutput) WriteBatch(ctx context.Context, batch service.Messa return nil } -// use when all columns are identity generation columns -func getMssqlDefaultValuesInsertSql(schema, table string, rowCount int) string { - var sql string - for i := 0; i < rowCount; i++ { - sql += fmt.Sprintf("INSERT INTO %q.%q DEFAULT VALUES;", schema, table) - } - return sql -} - -func filterOutMssqlDefaultIdentityColumns( - driver string, - identityCols, columnNames []string, - argRows [][]any, -) (columns []string, rows [][]any) { - if len(identityCols) == 0 || driver != sqlmanager_shared.MssqlDriver { - return columnNames, argRows - } - - // build map of identity columns - identityColMap := map[string]bool{} - for _, id := range identityCols { - identityColMap[id] = true - } - - nonIdentityColumnMap := map[string]struct{}{} // map of non identity columns - newRows := [][]any{} - // build rows removing identity columns/args with default set - for _, row := range argRows { - newRow := []any{} - for idx, arg := range row { - col := columnNames[idx] - if identityColMap[col] && arg == "DEFAULT" { - // pass on identity columns with a default - continue - } - newRow = append(newRow, arg) - nonIdentityColumnMap[col] = struct{}{} - } - newRows = append(newRows, newRow) - } - newColumns := []string{} - // build new columns list while maintaining same order - for _, col := range columnNames { - if _, ok := nonIdentityColumnMap[col]; ok { - newColumns = append(newColumns, col) - } +func (s *pooledInsertOutput) processRows(columnNames []string, dataRows [][]any) (columns []string, rows [][]any) { + switch s.driver { + case sqlmanager_shared.MssqlDriver: + newDataRows := sqlserverutil.GoTypeToSqlServerType(dataRows) + return sqlserverutil.FilterOutSqlServerDefaultIdentityColumns(s.driver, s.identityColumns, s.columns, newDataRows) + default: + return columnNames, dataRows } - return newColumns, newRows } func (s *pooledInsertOutput) buildQuery(insertQuery string) string { diff --git a/worker/pkg/benthos/sql/output_sql_insert_test.go b/worker/pkg/benthos/sql/output_sql_insert_test.go index 3463195763..9719e86bef 100644 --- a/worker/pkg/benthos/sql/output_sql_insert_test.go +++ b/worker/pkg/benthos/sql/output_sql_insert_test.go @@ -4,7 +4,6 @@ import ( "context" "testing" - sqlmanager_shared "github.com/nucleuscloud/neosync/backend/pkg/sqlmanager/shared" "github.com/stretchr/testify/require" "github.com/warpstreamlabs/bento/public/service" ) @@ -27,65 +26,3 @@ args_mapping: 'root = [this.id]' require.NoError(t, err) require.NoError(t, insertOutput.Close(context.Background())) } - -func Test_filterIdentityColumns(t *testing.T) { - t.Run("Non-MSSQL driver", func(t *testing.T) { - driver := "mysql" - identityCols := []string{"id"} - columnNames := []string{"id", "name", "age"} - argRows := [][]any{{1, "Alice", 30}, {2, "Bob", 25}} - - gotCols, gotRows := filterOutMssqlDefaultIdentityColumns(driver, identityCols, columnNames, argRows) - - require.Equal(t, columnNames, gotCols, "Columns should remain unchanged for non-MSSQL driver") - require.Equal(t, argRows, gotRows, "Rows should remain unchanged for non-MSSQL driver") - }) - - t.Run("MSSQL driver with identity columns", func(t *testing.T) { - driver := sqlmanager_shared.MssqlDriver - identityCols := []string{"id"} - columnNames := []string{"id", "name", "age"} - argRows := [][]any{{1, "Alice", 30}, {2, "Bob", 25}} - - gotCols, gotRows := filterOutMssqlDefaultIdentityColumns(driver, identityCols, columnNames, argRows) - - require.Equal(t, []string{"id", "name", "age"}, gotCols, "Identity column should be removed") - require.Equal(t, [][]any{{1, "Alice", 30}, {2, "Bob", 25}}, gotRows, "Identity column values should be removed") - }) - - t.Run("MSSQL driver with DEFAULT value", func(t *testing.T) { - driver := sqlmanager_shared.MssqlDriver - identityCols := []string{"id"} - columnNames := []string{"id", "name", "age", "city"} - argRows := [][]any{{"DEFAULT", "Alice", 30, "DEFAULT"}, {"DEFAULT", "Bob", 25, "DEFAULT"}} - - gotCols, gotRows := filterOutMssqlDefaultIdentityColumns(driver, identityCols, columnNames, argRows) - - require.Equal(t, []string{"name", "age", "city"}, gotCols, "All columns should be present when DEFAULT is used") - require.Equal(t, [][]any{{"Alice", 30, "DEFAULT"}, {"Bob", 25, "DEFAULT"}}, gotRows, "All rows should remain unchanged when DEFAULT is used") - }) - - t.Run("Empty identity columns", func(t *testing.T) { - driver := sqlmanager_shared.MssqlDriver - identityCols := []string{} - columnNames := []string{"id", "name", "age"} - argRows := [][]any{{1, "Alice", 30}, {2, "Bob", 25}} - - gotCols, gotRows := filterOutMssqlDefaultIdentityColumns(driver, identityCols, columnNames, argRows) - - require.Equal(t, columnNames, gotCols, "Columns should remain unchanged with empty identity columns") - require.Equal(t, argRows, gotRows, "Rows should remain unchanged with empty identity columns") - }) - - t.Run("Multiple identity columns", func(t *testing.T) { - driver := sqlmanager_shared.MssqlDriver - identityCols := []string{"id", "created_at"} - columnNames := []string{"id", "name", "age", "created_at"} - argRows := [][]any{{"DEFAULT", "Alice", 30, "DEFAULT"}, {"DEFAULT", "Bob", 25, "DEFAULT"}} - - gotCols, gotRows := filterOutMssqlDefaultIdentityColumns(driver, identityCols, columnNames, argRows) - - require.Equal(t, []string{"name", "age"}, gotCols, "Multiple identity columns should be removed") - require.Equal(t, [][]any{{"Alice", 30}, {"Bob", 25}}, gotRows, "Multiple identity column values should be removed") - }) -} diff --git a/worker/pkg/query-builder/query-builder.go b/worker/pkg/query-builder/query-builder.go index 6214e6d8f7..7399e0144a 100644 --- a/worker/pkg/query-builder/query-builder.go +++ b/worker/pkg/query-builder/query-builder.go @@ -14,6 +14,8 @@ import ( pgutil "github.com/nucleuscloud/neosync/internal/postgres" ) +const defaultStr = "DEFAULT" + type SubsetReferenceKey struct { Table string Columns []string @@ -73,7 +75,11 @@ func getGoquVals(driver string, row []any) goqu.Vals { } gval := goqu.Vals{} for _, a := range row { - gval = append(gval, a) + if a == defaultStr { + gval = append(gval, goqu.Literal(defaultStr)) + } else { + gval = append(gval, a) + } } return gval } @@ -86,6 +92,8 @@ func getPgGoquVals(row []any) goqu.Vals { gval = append(gval, goqu.Literal(pgutil.FormatPgArrayLiteral(a))) } else if ok { gval = append(gval, pq.Array(ar)) + } else if a == defaultStr { + gval = append(gval, goqu.Literal(defaultStr)) } else { gval = append(gval, a) } @@ -134,8 +142,8 @@ func BuildUpdateQuery( updateRecord := goqu.Record{} for _, col := range insertColumns { val := columnValueMap[col] - if val == "DEFAULT" { - updateRecord[col] = goqu.L("DEFAULT") + if val == defaultStr { + updateRecord[col] = goqu.L(defaultStr) } else { updateRecord[col] = val } diff --git a/worker/pkg/workflows/datasync/workflow/testdata/mssql/data-types/create-table.sql b/worker/pkg/workflows/datasync/workflow/testdata/mssql/data-types/create-table.sql index 87ceb1f4f5..a1c3ccfa5d 100644 --- a/worker/pkg/workflows/datasync/workflow/testdata/mssql/data-types/create-table.sql +++ b/worker/pkg/workflows/datasync/workflow/testdata/mssql/data-types/create-table.sql @@ -2,7 +2,7 @@ CREATE TABLE alltypes.alldatatypes ( -- Exact numerics col_bigint BIGINT, col_numeric NUMERIC(18,0), - -- col_bit BIT, + col_bit BIT, col_smallint SMALLINT, col_decimal DECIMAL(18,0), col_smallmoney SMALLMONEY, diff --git a/worker/pkg/workflows/datasync/workflow/testdata/mssql/data-types/insert.sql b/worker/pkg/workflows/datasync/workflow/testdata/mssql/data-types/insert.sql index 777905674c..a4ecdee50f 100644 --- a/worker/pkg/workflows/datasync/workflow/testdata/mssql/data-types/insert.sql +++ b/worker/pkg/workflows/datasync/workflow/testdata/mssql/data-types/insert.sql @@ -1,7 +1,7 @@ INSERT INTO alltypes.alldatatypes ( -- Exact numerics col_bigint, col_numeric, - -- col_bit, + col_bit, col_smallint, col_decimal, col_smallmoney, col_int, col_tinyint, col_money, -- Approximate numerics col_float, col_real, @@ -24,7 +24,7 @@ VALUES ( -- Exact numerics 9223372036854775807, -- BIGINT max value 1234567890, -- NUMERIC - -- 1, -- BIT + 1, -- BIT 32767, -- SMALLINT max value 1234567890, -- DECIMAL 214748.3647, -- SMALLMONEY max value diff --git a/worker/pkg/workflows/datasync/workflow/testdata/mssql/data-types/job_mappings.go b/worker/pkg/workflows/datasync/workflow/testdata/mssql/data-types/job_mappings.go index 3920038472..468275c96c 100644 --- a/worker/pkg/workflows/datasync/workflow/testdata/mssql/data-types/job_mappings.go +++ b/worker/pkg/workflows/datasync/workflow/testdata/mssql/data-types/job_mappings.go @@ -26,6 +26,14 @@ func GetDefaultSyncJobMappings()[]*mgmtv1alpha1.JobMapping { Source: mgmtv1alpha1.TransformerSource_TRANSFORMER_SOURCE_PASSTHROUGH, }, }, + { + Schema: "alltypes", + Table: "alldatatypes", + Column: "col_bit", + Transformer: &mgmtv1alpha1.JobMappingTransformer{ + Source: mgmtv1alpha1.TransformerSource_TRANSFORMER_SOURCE_PASSTHROUGH, + }, + }, { Schema: "alltypes", Table: "alldatatypes", @@ -212,6 +220,7 @@ func GetTableColumnTypeMap() map[string]map[string]string { "alltypes.alldatatypes": { "col_bigint": "BIGINT", "col_numeric": "NUMERIC(18,0)", + "col_bit": "BIT", "col_smallint": "SMALLINT", "col_decimal": "DECIMAL(18,0)", "col_smallmoney": "SMALLMONEY",