Skip to content

Commit

Permalink
NEOS-1489 Fix default being converted to database DEFAULT (#2741)
Browse files Browse the repository at this point in the history
  • Loading branch information
alishakawaguchi authored Sep 25, 2024
1 parent c5d48b6 commit 074ff2c
Show file tree
Hide file tree
Showing 22 changed files with 514 additions and 176 deletions.
21 changes: 21 additions & 0 deletions backend/pkg/sqlmanager/mssql/mssql-manager.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ import (
"github.com/nucleuscloud/neosync/backend/internal/neosyncdb"
mssql_queries "github.com/nucleuscloud/neosync/backend/pkg/mssql-querier"
sqlmanager_shared "github.com/nucleuscloud/neosync/backend/pkg/sqlmanager/shared"
"github.com/nucleuscloud/neosync/internal/gotypeutil"
)

type Manager struct {
Expand Down Expand Up @@ -297,3 +298,23 @@ func BuildMssqlSetIdentityInsertStatement(
}
return fmt.Sprintf("SET IDENTITY_INSERT %q.%q %s;", schema, table, enabledKeyword)
}

func GetMssqlColumnOverrideAndResetProperties(columnInfo *sqlmanager_shared.ColumnInfo) (needsOverride, needsReset bool) {
needsOverride = false
needsReset = false

// check if the column is an idenitity type
if columnInfo.IdentityGeneration != nil && *columnInfo.IdentityGeneration != "" {
needsOverride = true
needsReset = true
return
}

// check if column default is sequence
if columnInfo.ColumnDefault != "" && gotypeutil.CaseInsensitiveContains(columnInfo.ColumnDefault, "NEXT VALUE") {
needsReset = true
return
}

return
}
37 changes: 37 additions & 0 deletions backend/pkg/sqlmanager/mssql/mssql-manager_integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,8 @@ func (s *IntegrationTestSuite) Test_GetTableConstraintsBySchema() {

"sqlmanagermssql2.TableA": {"IdA1", "IdA2"},
"sqlmanagermssql2.TableB": {"IdB1", "IdB2"},

"sqlmanagermssql2.defaults_table": {"id"},
},
}

Expand Down Expand Up @@ -301,3 +303,38 @@ func containsSubset[T any](t testing.TB, array, subset []T) {
require.Contains(t, array, elem)
}
}

type testColumnProperties struct {
needsOverride bool
needsReset bool
}

func (s *IntegrationTestSuite) Test_GetMssqlColumnOverrideAndResetProperties() {
manager := NewManager(s.source.querier, s.source.testDb, func() {})

colInfoMap, err := manager.GetSchemaColumnMap(context.Background())
require.NoError(s.T(), err)

testDefaultTable := colInfoMap["testdb.sqlmanagermssql2.defaults_table"]

var expectedProperties = map[string]testColumnProperties{
"description": {needsOverride: false, needsReset: false},
"registration_date": {needsOverride: false, needsReset: false},
"score": {needsOverride: false, needsReset: false},
"status": {needsOverride: false, needsReset: false},
"id": {needsOverride: true, needsReset: true},
"last_login": {needsOverride: false, needsReset: false},
"age": {needsOverride: false, needsReset: false},
"is_active": {needsOverride: false, needsReset: false},
"created_at": {needsOverride: false, needsReset: false},
"uuid": {needsOverride: false, needsReset: false},
}

for col, colInfo := range testDefaultTable {
needsOverride, needsReset := GetMssqlColumnOverrideAndResetProperties(colInfo)
expected, ok := expectedProperties[col]
require.Truef(s.T(), ok, "Missing expected column %q", col)
require.Equalf(s.T(), expected.needsOverride, needsOverride, "Incorrect needsOverride value for column %q", col)
require.Equalf(s.T(), expected.needsReset, needsReset, "Incorrect needsReset value for column %q", col)
}
}
14 changes: 14 additions & 0 deletions backend/pkg/sqlmanager/mssql/testdata/source-setup/004.sql
Original file line number Diff line number Diff line change
Expand Up @@ -55,3 +55,17 @@ FOREIGN KEY (IdB1, IdB2) REFERENCES testdb.sqlmanagermssql2.TableB(IdB1, IdB2);
ALTER TABLE testdb.sqlmanagermssql2.TableB
ADD CONSTRAINT FK_TableB_TableA
FOREIGN KEY (IdA1, IdA2) REFERENCES testdb.sqlmanagermssql2.TableA(IdA1, IdA2);


CREATE TABLE testdb.sqlmanagermssql2.defaults_table (
id INT IDENTITY(1,1) PRIMARY KEY,
description NVARCHAR(MAX),
age INT DEFAULT 18,
is_active BIT DEFAULT 1,
registration_date DATE DEFAULT GETDATE(),
last_login DATETIME2,
score DECIMAL(10,2) DEFAULT 0.00,
status NVARCHAR(20) DEFAULT 'pending',
created_at DATETIME2 DEFAULT SYSDATETIME(),
uuid UNIQUEIDENTIFIER DEFAULT NEWID()
);
9 changes: 7 additions & 2 deletions backend/pkg/sqlmanager/mysql/mysql-manager.go
Original file line number Diff line number Diff line change
Expand Up @@ -316,8 +316,7 @@ func (m *MysqlManager) GetTableInitStatements(ctx context.Context, tables []*sql
return output, nil
}

//nolint:gofmt
func convertUInt8ToString(value interface{}) (string, error) {
func convertUInt8ToString(value any) (string, error) {
convertedType, ok := value.([]uint8)
if !ok {
return "", fmt.Errorf("failed to convert []uint8 to string")
Expand Down Expand Up @@ -793,3 +792,9 @@ func EscapeMysqlColumns(cols []string) []string {
func EscapeMysqlColumn(col string) string {
return fmt.Sprintf("`%s`", col)
}

func GetMysqlColumnOverrideAndResetProperties(columnInfo *sqlmanager_shared.ColumnInfo) (needsOverride, needsReset bool) {
needsOverride = false
needsReset = false
return
}
26 changes: 26 additions & 0 deletions backend/pkg/sqlmanager/postgres/postgres-manager.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import (
pg_queries "github.com/nucleuscloud/neosync/backend/gen/go/db/dbschemas/postgresql"
"github.com/nucleuscloud/neosync/backend/internal/neosyncdb"
sqlmanager_shared "github.com/nucleuscloud/neosync/backend/pkg/sqlmanager/shared"
"github.com/nucleuscloud/neosync/internal/gotypeutil"
"golang.org/x/sync/errgroup"
)

Expand Down Expand Up @@ -914,3 +915,28 @@ func BuildPgInsertIdentityAlwaysSql(
sqlSplit := strings.Split(insertQuery, ") VALUES (")
return sqlSplit[0] + ") OVERRIDING SYSTEM VALUE VALUES(" + sqlSplit[1]
}

func GetPostgresColumnOverrideAndResetProperties(columnInfo *sqlmanager_shared.ColumnInfo) (needsOverride, needsReset bool) {
needsOverride = false
needsReset = false

// check if the column is an idenitity type
if columnInfo.IdentityGeneration != nil && *columnInfo.IdentityGeneration != "" {
switch *columnInfo.IdentityGeneration {
case "a": // ALWAYS
needsOverride = true
needsReset = true
case "d": // DEFAULT
needsReset = true
}
return
}

// check if column default is sequence
if columnInfo.ColumnDefault != "" && gotypeutil.CaseInsensitiveContains(columnInfo.ColumnDefault, "nextVal") {
needsReset = true
return
}

return
}
Original file line number Diff line number Diff line change
Expand Up @@ -371,3 +371,40 @@ func containsSubset[T any](t testing.TB, array, subset []T) {
require.Contains(t, array, elem)
}
}

type testColumnProperties struct {
needsOverride bool
needsReset bool
}

func (s *IntegrationTestSuite) Test_GetPostgresColumnOverrideAndResetProperties() {
manager := PostgresManager{querier: s.querier, pool: s.pgpool}

colInfoMap, err := manager.GetSchemaColumnMap(context.Background())
require.NoError(s.T(), err)

testDefaultTable := colInfoMap["sqlmanagerpostgres@special.defaults_table"]

var expectedProperties = map[string]testColumnProperties{
"description": {needsOverride: false, needsReset: false},
"registration_date": {needsOverride: false, needsReset: false},
"score": {needsOverride: false, needsReset: false},
"status": {needsOverride: false, needsReset: false},
"id": {needsOverride: true, needsReset: true},
"sequence_number": {needsOverride: false, needsReset: true},
"last_login": {needsOverride: false, needsReset: false},
"age": {needsOverride: false, needsReset: false},
"is_active": {needsOverride: false, needsReset: false},
"created_at": {needsOverride: false, needsReset: false},
"uuid": {needsOverride: false, needsReset: false},
"serial_number": {needsOverride: false, needsReset: true},
}

for col, colInfo := range testDefaultTable {
needsOverride, needsReset := GetPostgresColumnOverrideAndResetProperties(colInfo)
expected, ok := expectedProperties[col]
require.Truef(s.T(), ok, "Missing expected column %q", col)
require.Equalf(s.T(), expected.needsOverride, needsOverride, "Incorrect needsOverride value for column %q", col)
require.Equalf(s.T(), expected.needsReset, needsReset, "Incorrect needsReset value for column %q", col)
}
}
16 changes: 16 additions & 0 deletions backend/pkg/sqlmanager/postgres/testdata/setup.sql
Original file line number Diff line number Diff line change
Expand Up @@ -115,3 +115,19 @@ CREATE TABLE tablewithcount (
id TEXT NOT NULL
);
INSERT INTO tablewithcount(id) VALUES ('1'), ('2');


CREATE TABLE defaults_table (
id INT GENERATED ALWAYS AS IDENTITY PRIMARY KEY,
description TEXT,
age INT DEFAULT 18,
is_active BOOLEAN DEFAULT true,
registration_date DATE DEFAULT CURRENT_DATE,
last_login TIMESTAMP,
score NUMERIC(10,2) DEFAULT 0.00,
status VARCHAR(20) DEFAULT 'pending',
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
sequence_number INT GENERATED BY DEFAULT AS IDENTITY,
uuid UUID DEFAULT gen_random_uuid(),
serial_number SERIAL
);
16 changes: 16 additions & 0 deletions backend/pkg/sqlmanager/sql-manager.go
Original file line number Diff line number Diff line change
Expand Up @@ -367,3 +367,19 @@ func (s *SqlManager) NewSqlDbFromUrl(
Driver: driver,
}, nil
}

func GetColumnOverrideAndResetProperties(driver string, cInfo *sqlmanager_shared.ColumnInfo) (needsOverride, needsReset bool, err error) {
switch driver {
case sqlmanager_shared.PostgresDriver:
needsOverride, needsReset := sqlmanager_postgres.GetPostgresColumnOverrideAndResetProperties(cInfo)
return needsOverride, needsReset, nil
case sqlmanager_shared.MysqlDriver:
needsOverride, needsReset := sqlmanager_mysql.GetMysqlColumnOverrideAndResetProperties(cInfo)
return needsOverride, needsReset, nil
case sqlmanager_shared.MssqlDriver:
needsOverride, needsReset := sqlmanager_mssql.GetMssqlColumnOverrideAndResetProperties(cInfo)
return needsOverride, needsReset, nil
default:
return false, false, fmt.Errorf("unsupported sql driver: %s", driver)
}
}
7 changes: 7 additions & 0 deletions internal/gotypeutil/strings.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
package gotypeutil

import "strings"

func CaseInsensitiveContains(s, substr string) bool {
return strings.Contains(strings.ToLower(s), strings.ToLower(substr))
}
25 changes: 25 additions & 0 deletions internal/gotypeutil/strings_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
package gotypeutil

import (
"testing"

"github.com/stretchr/testify/require"
)

func Test_CaseInsensitiveContains(t *testing.T) {
t.Run("CaseInsensitiveContains", func(t *testing.T) {
require.True(t, CaseInsensitiveContains("Hello, World!", "hello"), "Should find lowercase substring")
require.True(t, CaseInsensitiveContains("Hello, World!", "WORLD"), "Should find uppercase substring")
require.True(t, CaseInsensitiveContains("Hello, World!", "o, wo"), "Should find mixed case substring")
require.True(t, CaseInsensitiveContains("Hello, World!", ""), "Should return true for empty substring")
require.True(t, CaseInsensitiveContains("Hello, World!", "Hello, World!"), "Should find when substring is equal to string")

require.False(t, CaseInsensitiveContains("Hello, World!", "goodbye"), "Should not find non-existent substring")
require.False(t, CaseInsensitiveContains("", "test"), "Should return false when string is empty and substring is not")
require.False(t, CaseInsensitiveContains("Hello", "Hello, World!"), "Should return false when substring is longer than string")

require.True(t, CaseInsensitiveContains("HeLLo, WoRLD!", "hello, world!"), "Should handle mixed case in both string and substring")
require.True(t, CaseInsensitiveContains("HELLO", "hello"), "Should handle all uppercase string and lowercase substring")
require.True(t, CaseInsensitiveContains("hello", "HELLO"), "Should handle all lowercase string and uppercase substring")
})
}
12 changes: 6 additions & 6 deletions internal/sqlserver/utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -105,17 +105,17 @@ func toBit(v bool) int {

func FilterOutSqlServerDefaultIdentityColumns(
driver string,
identityCols, columnNames []string,
defaultIdentityCols, columnNames []string,
argRows [][]any,
) (columns []string, rows [][]any) {
if len(identityCols) == 0 || driver != sqlmanager_shared.MssqlDriver {
if len(defaultIdentityCols) == 0 || driver != sqlmanager_shared.MssqlDriver {
return columnNames, argRows
}

// build map of identity columns
identityColMap := map[string]bool{}
for _, id := range identityCols {
identityColMap[id] = true
defaultIdentityColMap := map[string]bool{}
for _, id := range defaultIdentityCols {
defaultIdentityColMap[id] = true
}

nonIdentityColumnMap := map[string]struct{}{} // map of non identity columns
Expand All @@ -125,7 +125,7 @@ func FilterOutSqlServerDefaultIdentityColumns(
newRow := []any{}
for idx, arg := range row {
col := columnNames[idx]
if identityColMap[col] && arg == "DEFAULT" {
if defaultIdentityColMap[col] {
// pass on identity columns with a default
continue
}
Expand Down
4 changes: 2 additions & 2 deletions internal/sqlserver/utils_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,8 @@ func Test_filterIdentityColumns(t *testing.T) {

gotCols, gotRows := FilterOutSqlServerDefaultIdentityColumns(driver, identityCols, columnNames, argRows)

require.Equal(t, []string{"id", "name", "age"}, gotCols, "Identity column should be removed")
require.Equal(t, [][]any{{1, "Alice", 30}, {2, "Bob", 25}}, gotRows, "Identity column values should be removed")
require.Equal(t, []string{"name", "age"}, gotCols, "Identity column should be removed")
require.Equal(t, [][]any{{"Alice", 30}, {"Bob", 25}}, gotRows, "Identity column values should be removed")
})

t.Run("MSSQL driver with DEFAULT value", func(t *testing.T) {
Expand Down
34 changes: 20 additions & 14 deletions worker/pkg/benthos/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -367,21 +367,27 @@ type PooledSqlUpdate struct {
Batching *Batching `json:"batching,omitempty" yaml:"batching,omitempty"`
}

type ColumnDefaultProperties struct {
NeedsReset bool `json:"needs_reset" yaml:"needs_reset"`
NeedsOverride bool `json:"needs_override" yaml:"needs_override"`
HasDefaultTransformer bool `json:"has_default_transformer" yaml:"has_default_transformer"`
}

type PooledSqlInsert struct {
Driver string `json:"driver" yaml:"driver"`
Dsn string `json:"dsn" yaml:"dsn"`
Schema string `json:"schema" yaml:"schema"`
Table string `json:"table" yaml:"table"`
Columns []string `json:"columns" yaml:"columns"`
ColumnsDataTypes []string `json:"column_data_types" yaml:"column_data_types"`
IdentityColumns []string `json:"identity_columns" yaml:"identity_columns"`
OnConflictDoNothing bool `json:"on_conflict_do_nothing" yaml:"on_conflict_do_nothing"`
TruncateOnRetry bool `json:"truncate_on_retry" yaml:"truncate_on_retry"`
SkipForeignKeyViolations bool `json:"skip_foreign_key_violations" yaml:"skip_foreign_key_violations"`
ArgsMapping string `json:"args_mapping" yaml:"args_mapping"`
Batching *Batching `json:"batching,omitempty" yaml:"batching,omitempty"`
Prefix *string `json:"prefix,omitempty" yaml:"prefix,omitempty"`
Suffix *string `json:"suffix,omitempty" yaml:"suffix,omitempty"`
Driver string `json:"driver" yaml:"driver"`
Dsn string `json:"dsn" yaml:"dsn"`
Schema string `json:"schema" yaml:"schema"`
Table string `json:"table" yaml:"table"`
Columns []string `json:"columns" yaml:"columns"`
ColumnsDataTypes []string `json:"column_data_types" yaml:"column_data_types"`
ColumnDefaultProperties map[string]*ColumnDefaultProperties `json:"column_default_properties" yaml:"column_default_properties"`
OnConflictDoNothing bool `json:"on_conflict_do_nothing" yaml:"on_conflict_do_nothing"`
TruncateOnRetry bool `json:"truncate_on_retry" yaml:"truncate_on_retry"`
SkipForeignKeyViolations bool `json:"skip_foreign_key_violations" yaml:"skip_foreign_key_violations"`
ArgsMapping string `json:"args_mapping" yaml:"args_mapping"`
Batching *Batching `json:"batching,omitempty" yaml:"batching,omitempty"`
Prefix *string `json:"prefix,omitempty" yaml:"prefix,omitempty"`
Suffix *string `json:"suffix,omitempty" yaml:"suffix,omitempty"`
}

type SqlInsert struct {
Expand Down
Loading

0 comments on commit 074ff2c

Please sign in to comment.