diff --git a/cmd/greenmask/cmd/dump/dump.go b/cmd/greenmask/cmd/dump/dump.go index a2374d57..1fb70a53 100644 --- a/cmd/greenmask/cmd/dump/dump.go +++ b/cmd/greenmask/cmd/dump/dump.go @@ -66,7 +66,6 @@ var ( // TODO: Check how does work mixed options - use-list + tables, etc. // TODO: Options currently are not implemented: // - encoding -// - disable-triggers // - lock-wait-timeout // - no-sync // - data-only @@ -104,7 +103,6 @@ func init() { ) Cmd.Flags().BoolP("no-owner", "O", false, "skip restoration of object ownership in plain-text format") Cmd.Flags().BoolP("schema-only", "s", false, "dump only the schema, no data") - Cmd.Flags().StringP("superuser", "S", "", "superuser user name to use in plain-text format") Cmd.Flags().StringSliceVarP( &Config.Dump.PgDumpOptions.Table, "table", "t", []string{}, "dump the specified table(s) only", ) @@ -113,7 +111,6 @@ func init() { ) Cmd.Flags().BoolP("no-privileges", "X", false, "do not dump privileges (grant/revoke)") Cmd.Flags().BoolP("disable-dollar-quoting", "", false, "disable dollar quoting, use SQL standard quoting") - Cmd.Flags().BoolP("disable-triggers", "", false, "disable triggers during data-only restore") Cmd.Flags().BoolP( "enable-row-security", "", false, "enable row security (dump only content user has access to)", ) @@ -163,8 +160,8 @@ func init() { "file", "jobs", "verbose", "compress", "dbname", "host", "username", "lock-wait-timeout", "no-sync", "data-only", "blobs", "no-blobs", "clean", "create", "extension", "encoding", "schema", "exclude-schema", - "no-owner", "schema-only", "superuser", "table", "exclude-table", "no-privileges", "disable-dollar-quoting", - "disable-triggers", "enable-row-security", "exclude-table-data", "extra-float-digits", "if-exists", + "no-owner", "schema-only", "table", "exclude-table", "no-privileges", "disable-dollar-quoting", + "enable-row-security", "exclude-table-data", "extra-float-digits", "if-exists", "include-foreign-data", "load-via-partition-root", "no-comments", 
"no-publications", "no-security-labels", "no-subscriptions", "no-synchronized-snapshots", "no-tablespaces", "no-toast-compression", "no-unlogged-table-data", "quote-all-identifiers", "section", diff --git a/cmd/greenmask/cmd/restore/restore.go b/cmd/greenmask/cmd/restore/restore.go index d09b6dcd..b263cb41 100644 --- a/cmd/greenmask/cmd/restore/restore.go +++ b/cmd/greenmask/cmd/restore/restore.go @@ -150,12 +150,11 @@ func init() { Cmd.Flags().BoolP("no-owner", "O", false, "skip restoration of object ownership") Cmd.Flags().StringSliceVarP(&Config.Restore.PgRestoreOptions.Function, "function", "P", []string{}, "restore named function") Cmd.Flags().BoolP("schema-only", "s", false, "restore only the schema, no data") - Cmd.Flags().StringP("superuser", "S", "", "superuser user name to use for disabling triggers") Cmd.Flags().StringSliceVarP(&Config.Restore.PgRestoreOptions.Table, "table", "t", []string{}, "restore named relation (table, view, etc.)") Cmd.Flags().StringSliceVarP(&Config.Restore.PgRestoreOptions.Trigger, "trigger", "T", []string{}, "restore named trigger") Cmd.Flags().BoolP("no-privileges", "X", false, "skip restoration of access privileges (grant/revoke)") Cmd.Flags().BoolP("single-transaction", "1", false, "restore as a single transaction") - Cmd.Flags().BoolP("disable-triggers", "", false, "disable triggers during data-only restore") + Cmd.Flags().BoolP("disable-triggers", "", false, "disable triggers during data section restore") Cmd.Flags().BoolP("enable-row-security", "", false, "enable row security") Cmd.Flags().BoolP("if-exists", "", false, "use IF EXISTS when dropping objects") Cmd.Flags().BoolP("no-comments", "", false, "do not restore comments") @@ -169,6 +168,12 @@ func init() { Cmd.Flags().BoolP("strict-names", "", false, "restore named section (pre-data, data, or post-data) match at least one entity each") Cmd.Flags().BoolP("use-set-session-authorization", "", false, "use SET SESSION AUTHORIZATION commands instead of ALTER OWNER commands 
to set ownership") Cmd.Flags().BoolP("on-conflict-do-nothing", "", false, "add ON CONFLICT DO NOTHING to INSERT commands") + Cmd.Flags().StringP("superuser", "S", "", "superuser user name to use for disabling triggers") + Cmd.Flags().BoolP( + "use-session-replication-role-replica", "", false, + "use SET session_replication_role = 'replica' to disable triggers during data section restore"+ + " (alternative for --disable-triggers)", + ) Cmd.Flags().BoolP("inserts", "", false, "restore data as INSERT commands, rather than COPY") Cmd.Flags().BoolP("restore-in-order", "", false, "restore tables in topological order, ensuring that dependent tables are not restored until the tables they depend on have been restored") Cmd.Flags().BoolP( @@ -193,11 +198,11 @@ func init() { "dbname", "file", "verbose", "data-only", "clean", "create", "exit-on-error", "jobs", "list-format", "use-list", "schema", "exclude-schema", - "no-owner", "function", "schema-only", "superuser", "table", "trigger", "no-privileges", "single-transaction", + "no-owner", "function", "schema-only", "table", "trigger", "no-privileges", "single-transaction", "disable-triggers", "enable-row-security", "if-exists", "no-comments", "no-data-for-failed-tables", "no-security-labels", "no-subscriptions", "no-table-access-method", "no-tablespaces", "section", "strict-names", "use-set-session-authorization", "inserts", "on-conflict-do-nothing", "restore-in-order", - "pgzip", "batch-size", "overriding-system-value", + "pgzip", "batch-size", "overriding-system-value", "superuser", "use-session-replication-role-replica", "host", "port", "username", } { diff --git a/docs/commands/dump.md b/docs/commands/dump.md index 9be7786e..ca2f68d4 100644 --- a/docs/commands/dump.md +++ b/docs/commands/dump.md @@ -20,7 +20,6 @@ Mostly it supports the same flags as the `pg_dump` utility, with some extra flag -a, --data-only dump only the data, not the schema -d, --dbname string database to dump (default "postgres") 
--disable-dollar-quoting disable dollar quoting, use SQL standard quoting - --disable-triggers disable triggers during data-only restore --enable-row-security enable row security (dump only content user has access to) -E, --encoding string dump the data in encoding ENCODING -N, --exclude-schema strings dump the specified schema(s) only @@ -37,7 +36,7 @@ Mostly it supports the same flags as the `pg_dump` utility, with some extra flag --lock-wait-timeout int fail after waiting TIMEOUT for a table lock (default -1) -B, --no-blobs exclude large objects in dump --no-comments do not dump comments - -O, --no-owner string skip restoration of object ownership in plain-text format + -O, --no-owner skip restoration of object ownership in plain-text format -X, --no-privileges do not dump privileges (grant/revoke) --no-publications do not dump publications --no-security-labels do not dump security label assignments @@ -51,12 +50,11 @@ Mostly it supports the same flags as the `pg_dump` utility, with some extra flag -p, --port int database server port number (default 5432) --quote-all-identifiers quote all identifiers, even if not key words -n, --schema strings dump the specified schema(s) only - -s, --schema-only string dump only the schema, no data + -s, --schema-only dump only the schema, no data --section string dump named section (pre-data, data, or post-data) --serializable-deferrable wait until the dump can run without anomalies --snapshot string use given snapshot for the dump --strict-names require table and/or schema include patterns to match at least one entity each - -S, --superuser string superuser user name to use in plain-text format -t, --table strings dump the specified table(s) only --test string connect as specified database user (default "postgres") --use-set-session-authorization use SET SESSION AUTHORIZATION commands instead of ALTER OWNER commands to set ownership diff --git a/docs/commands/restore.md b/docs/commands/restore.md index e5d5c2ab..317eb30c 
100644 --- a/docs/commands/restore.md +++ b/docs/commands/restore.md @@ -19,49 +19,50 @@ allowing you to configure the restoration process as needed. Mostly it supports the same flags as the `pg_restore` utility, with some extra flags for Greenmask-specific features. ```text title="Supported flags" - --batch-size int the number of rows to insert in a single batch during the COPY command (0 - all rows will be inserted in a single batch) - -c, --clean clean (drop) database objects before recreating - -C, --create create the target database - -a, --data-only restore only the data, no schema - -d, --dbname string connect to database name (default "postgres") - --disable-triggers disable triggers during data-only restore - --enable-row-security enable row security - -N, --exclude-schema strings do not restore objects in this schema - -e, --exit-on-error exit on error, default is to continue - -f, --file string output file name (- for stdout) - -P, --function strings restore named function - -h, --host string database server host or socket directory (default "/var/run/postgres") - --if-exists use IF EXISTS when dropping objects - -i, --index strings restore named index - --inserts restore data as INSERT commands, rather than COPY - -j, --jobs int use this many parallel jobs to restore (default 1) - --list-format string use table of contents in format of text, json or yaml (default "text") - --no-comments do not restore comments - --no-data-for-failed-tables do not restore data of tables that could not be created - -O, --no-owner string skip restoration of object ownership - -X, --no-privileges skip restoration of access privileges (grant/revoke) - --no-publications do not restore publications - --no-security-labels do not restore security labels - --no-subscriptions ddo not restore subscriptions - --no-table-access-method do not restore table access methods - --no-tablespaces do not restore tablespace assignments - --on-conflict-do-nothing add ON CONFLICT DO NOTHING to 
INSERT commands - --overriding-system-value use OVERRIDING SYSTEM VALUE clause for INSERTs - --pgzip use pgzip decompression instead of gzip - -p, --port int database server port number (default 5432) - --restore-in-order restore tables in topological order, ensuring that dependent tables are not restored until the tables they depend on have been restored - -n, --schema strings restore only objects in this schema - -s, --schema-only restore only the schema, no data - --section string restore named section (pre-data, data, or post-data) - -1, --single-transaction restore as a single transaction - --strict-names restore named section (pre-data, data, or post-data) match at least one entity each - -S, --superuser string superuser user name to use for disabling triggers - -t, --table strings restore named relation (table, view, etc.) - -T, --trigger strings restore named trigger - -L, --use-list string use table of contents from this file for selecting/ordering output - --use-set-session-authorization use SET SESSION AUTHORIZATION commands instead of ALTER OWNER commands to set ownership - -U, --username string connect as specified database user (default "postgres") - -v, --verbose string verbose mode + --batch-size int the number of rows to insert in a single batch during the COPY command (0 - all rows will be inserted in a single batch) + -c, --clean clean (drop) database objects before recreating + -C, --create create the target database + -a, --data-only restore only the data, no schema + -d, --dbname string connect to database name (default "postgres") + --disable-triggers disable triggers during data section restore + --enable-row-security enable row security + -N, --exclude-schema strings do not restore objects in this schema + -e, --exit-on-error exit on error, default is to continue + -f, --file string output file name (- for stdout) + -P, --function strings restore named function + -h, --host string database server host or socket directory (default 
"/var/run/postgres") + --if-exists use IF EXISTS when dropping objects + -i, --index strings restore named index + --inserts restore data as INSERT commands, rather than COPY + -j, --jobs int use this many parallel jobs to restore (default 1) + --list-format string use table of contents in format of text, json or yaml (default "text") + --no-comments do not restore comments + --no-data-for-failed-tables do not restore data of tables that could not be created + -O, --no-owner skip restoration of object ownership + -X, --no-privileges skip restoration of access privileges (grant/revoke) + --no-publications do not restore publications + --no-security-labels do not restore security labels + --no-subscriptions ddo not restore subscriptions + --no-table-access-method do not restore table access methods + --no-tablespaces do not restore tablespace assignments + --on-conflict-do-nothing add ON CONFLICT DO NOTHING to INSERT commands + --overriding-system-value use OVERRIDING SYSTEM VALUE clause for INSERTs + --pgzip use pgzip decompression instead of gzip + -p, --port int database server port number (default 5432) + --restore-in-order restore tables in topological order, ensuring that dependent tables are not restored until the tables they depend on have been restored + -n, --schema strings restore only objects in this schema + -s, --schema-only restore only the schema, no data + --section string restore named section (pre-data, data, or post-data) + -1, --single-transaction restore as a single transaction + --strict-names restore named section (pre-data, data, or post-data) match at least one entity each + -S, --superuser string superuser user name to use for disabling triggers + -t, --table strings restore named relation (table, view, etc.) 
+ -T, --trigger strings restore named trigger + -L, --use-list string use table of contents from this file for selecting/ordering output + --use-session-replication-role-replica use SET session_replication_role = 'replica' to disable triggers during data section restore (alternative for --disable-triggers) + --use-set-session-authorization use SET SESSION AUTHORIZATION commands instead of ALTER OWNER commands to set ownership + -U, --username string connect as specified database user (default "postgres") + -v, --verbose string verbose mode ``` ## Extra features diff --git a/go.mod b/go.mod index c0a093bd..a4d5a603 100644 --- a/go.mod +++ b/go.mod @@ -89,6 +89,7 @@ require ( github.com/sourcegraph/conc v0.3.0 // indirect github.com/spf13/afero v1.11.0 // indirect github.com/spf13/pflag v1.0.5 // indirect + github.com/stretchr/objx v0.5.2 // indirect github.com/subosito/gotenv v1.6.0 // indirect github.com/tidwall/match v1.1.1 // indirect github.com/tidwall/pretty v1.2.0 // indirect diff --git a/internal/db/postgres/cmd/restore.go b/internal/db/postgres/cmd/restore.go index 0c5ab26a..7322a04d 100644 --- a/internal/db/postgres/cmd/restore.go +++ b/internal/db/postgres/cmd/restore.go @@ -632,13 +632,10 @@ func (r *Restore) taskPusher(ctx context.Context, tasks chan restorers.RestoreTa return fmt.Errorf("cannot get table definition from meta: %w", err) } task = restorers.NewTableRestorerInsertFormat( - entry, t, r.st, r.restoreOpt.ExitOnError, r.restoreOpt.OnConflictDoNothing, - r.cfg.ErrorExclusions, r.restoreOpt.Pgzip, r.restoreOpt.OverridingSystemValue, + entry, t, r.st, r.restoreOpt.ToDataSectionSettings(), r.cfg.ErrorExclusions, ) } else { - task = restorers.NewTableRestorer( - entry, r.st, r.restoreOpt.ExitOnError, r.restoreOpt.Pgzip, r.restoreOpt.BatchSize, - ) + task = restorers.NewTableRestorer(entry, r.st, r.restoreOpt.ToDataSectionSettings()) } case toc.SequenceSetDesc: diff --git a/internal/db/postgres/pgdump/adapter.go 
b/internal/db/postgres/pgdump/adapter.go index 81dde491..b6c9ac92 100644 --- a/internal/db/postgres/pgdump/adapter.go +++ b/internal/db/postgres/pgdump/adapter.go @@ -25,66 +25,66 @@ import ( // TODO: Simplify that adapter const ( - NoneActionStage = iota - StartedActionStage - EndedActionStage + noneActionStage = iota + startedActionStage + endedActionStage ) var ( - StateLowerCase StateName = "StateLowerCase" - StateKeepCase StateName = "StateKeepCase" - StateWildcard StateName = "StateWildcard" - StateDoubledQuotes StateName = "StateDoubledQuotes" - StateQuestionMark StateName = "StateQuestionMark" - DefaultState = &State{ - Name: StateLowerCase, - Action: LowerCaseState, + stateLowerCase stateName = "stateLowerCase" + stateKeepCase stateName = "stateKeepCase" + stateWildcard stateName = "stateWildcard" + stateDoubledQuotes stateName = "stateDoubledQuotes" + stateQuestionMark stateName = "stateQuestionMark" + defaultState = &state{ + name: stateLowerCase, + action: lowerCaseState, } ) -type ActionContext struct { +type actionContext struct { strings.Builder - Stage int + stage int } -func (ac *ActionContext) IsDone() bool { - return ac.Stage == EndedActionStage +func (ac *actionContext) isDone() bool { + return ac.stage == endedActionStage } -func (ac *ActionContext) SetDone() { - ac.Stage = EndedActionStage +func (ac *actionContext) setDone() { + ac.stage = endedActionStage } -type StateName string +type stateName string -type Action func(actx *ActionContext, s string, dest *strings.Builder) error +type action func(actx *actionContext, s string, dest *strings.Builder) error -type State struct { - ActionContext - Name StateName - Action Action +type state struct { + actionContext + name stateName + action action } -func LowerCaseState(actx *ActionContext, s string, dest *strings.Builder) error { +func lowerCaseState(actx *actionContext, s string, dest *strings.Builder) error { if s == "" { return errors.New("unexpected char length") } - actx.SetDone() // It is 
always done because it's default state + actx.setDone() // It is always done because it's default state if _, err := dest.WriteString(strings.ToLower(s)); err != nil { return err } return nil } -func KeepCaseState(actx *ActionContext, s string, dest *strings.Builder) error { +func keepCaseState(actx *actionContext, s string, dest *strings.Builder) error { if s == `"` { - if actx.Stage == NoneActionStage { - actx.Stage = StartedActionStage + if actx.stage == noneActionStage { + actx.stage = startedActionStage } else { - actx.Stage = EndedActionStage + actx.stage = endedActionStage } return nil - } else if actx.Stage == NoneActionStage { + } else if actx.stage == noneActionStage { return errors.New("syntax error") } if _, err := dest.WriteString(s); err != nil { @@ -94,58 +94,58 @@ func KeepCaseState(actx *ActionContext, s string, dest *strings.Builder) error { } -func WildCardState(actx *ActionContext, s string, dest *strings.Builder) error { +func wildCardState(actx *actionContext, s string, dest *strings.Builder) error { if string(s) != "*" { return errors.New("unknown character") } if _, err := dest.WriteString(".*"); err != nil { return err } - actx.SetDone() + actx.setDone() return nil } -func QuestionMarkState(actx *ActionContext, s string, dest *strings.Builder) error { +func questionMarkState(actx *actionContext, s string, dest *strings.Builder) error { if s != "?" 
{ return errors.New("unknown character") } if _, err := dest.WriteRune('.'); err != nil { return err } - actx.SetDone() + actx.setDone() return nil } -func DoubleQuoteState(actx *ActionContext, s string, dest *strings.Builder) error { +func doubleQuoteState(actx *actionContext, s string, dest *strings.Builder) error { if s != `""` { return errors.New("unknown character") } if _, err := dest.WriteRune('"'); err != nil { return err } - actx.SetDone() + actx.setDone() return nil } type ParserContext struct { - currentState *State + currentState *state // stateStack - nested states that we would be able to handle - stateStack []*State + stateStack []*state } -func NewParser() *ParserContext { +func newParser() *ParserContext { return &ParserContext{ - currentState: DefaultState, - stateStack: []*State{DefaultState}, + currentState: defaultState, + stateStack: []*state{defaultState}, } } -func (p *ParserContext) PushState(state *State) { +func (p *ParserContext) pushState(state *state) { p.currentState = state p.stateStack = append(p.stateStack, state) } -func (p *ParserContext) PopState() { +func (p *ParserContext) popState() { p.currentState = p.stateStack[len(p.stateStack)-2] p.stateStack = p.stateStack[:len(p.stateStack)-1] } @@ -155,7 +155,7 @@ func (p *ParserContext) Depth() int { } func AdaptRegexp(data string) (string, error) { - pctx := NewParser() + pctx := newParser() src := strings.NewReader(data) dest := &strings.Builder{} var isEOF bool @@ -192,9 +192,9 @@ func AdaptRegexp(data string) (string, error) { switch ch { case '"': // Doubled double quote parse - pctx.PushState(&State{ - Name: StateDoubledQuotes, - Action: DoubleQuoteState, + pctx.pushState(&state{ + name: stateDoubledQuotes, + action: doubleQuoteState, }) if _, err = literals.WriteRune(ch); err != nil { return "", err @@ -207,36 +207,36 @@ func AdaptRegexp(data string) (string, error) { } } - if pctx.currentState.Name != StateKeepCase { - pctx.PushState(&State{ - Name: StateKeepCase, - Action: 
KeepCaseState, + if pctx.currentState.name != stateKeepCase { + pctx.pushState(&state{ + name: stateKeepCase, + action: keepCaseState, }) } } case '*': - if pctx.currentState.Name != StateWildcard { - pctx.PushState(&State{ - Name: StateWildcard, - Action: WildCardState, + if pctx.currentState.name != stateWildcard { + pctx.pushState(&state{ + name: stateWildcard, + action: wildCardState, }) } case '?': - if pctx.currentState.Name != StateQuestionMark { - pctx.PushState(&State{ - Name: StateQuestionMark, - Action: QuestionMarkState, + if pctx.currentState.name != stateQuestionMark { + pctx.pushState(&state{ + name: stateQuestionMark, + action: questionMarkState, }) } } - if err = pctx.currentState.Action(&pctx.currentState.ActionContext, literals.String(), dest); err != nil { + if err = pctx.currentState.action(&pctx.currentState.actionContext, literals.String(), dest); err != nil { return "", err } - if pctx.Depth() > 1 && pctx.currentState.IsDone() { - pctx.PopState() + if pctx.Depth() > 1 && pctx.currentState.isDone() { + pctx.popState() } literals.Reset() } diff --git a/internal/db/postgres/pgdump/pgdump.go b/internal/db/postgres/pgdump/pgdump.go index 222c012a..cc933e78 100644 --- a/internal/db/postgres/pgdump/pgdump.go +++ b/internal/db/postgres/pgdump/pgdump.go @@ -67,12 +67,10 @@ type Options struct { ExcludeSchema []string `mapstructure:"exclude-schema"` NoOwner bool `mapstructure:"no-owner"` SchemaOnly bool `mapstructure:"schema-only"` - SuperUser string `mapstructure:"superuser"` Table []string `mapstructure:"table"` ExcludeTable []string `mapstructure:"exclude-table"` NoPrivileges bool `mapstructure:"no-privileges"` DisableDollarQuoting bool `mapstructure:"disable-dollar-quoting"` - DisableTriggers bool `mapstructure:"disable-triggers"` EnableRowSecurity bool `mapstructure:"enable-row-security"` ExcludeTableData []string `mapstructure:"exclude-table-data"` ExtraFloatDigits string `mapstructure:"extra-float-digits"` @@ -149,7 +147,8 @@ func (o *Options) 
GetParams() []string { args = append(args, "--verbose") } if o.Compression != -1 { - args = append(args, "--compress", strconv.FormatInt(int64(o.Compression), 10)) + panic("FIXME: --compress is not implemented") + //args = append(args, "--compress", strconv.FormatInt(int64(o.Compression), 10)) } if o.LockWaitTimeout != -1 { args = append(args, "--lock-wait-timeout", strconv.FormatInt(int64(o.Compression), 10)) @@ -163,10 +162,12 @@ func (o *Options) GetParams() []string { args = append(args, "--data-only") } if o.Blobs { - args = append(args, "--blobs") + panic("FIXME: --blobs is not implemented") + //args = append(args, "--blobs") } if o.NoBlobs { - args = append(args, "--no-blobs") + panic("FIXME: --no-blobs is not implemented") + //args = append(args, "--no-blobs") } if o.Clean { args = append(args, "--clean") @@ -198,9 +199,6 @@ func (o *Options) GetParams() []string { if o.SchemaOnly { args = append(args, "--schema-only") } - if o.SuperUser != "" { - args = append(args, "--superuser", o.SuperUser) - } if len(o.Table) > 0 { for _, item := range o.Table { args = append(args, "--table", item) @@ -217,14 +215,10 @@ func (o *Options) GetParams() []string { if o.DisableDollarQuoting { args = append(args, "--disable-dollar-quoting") } - if o.DisableTriggers { - //args = append(args, "--disable-triggers") - panic("FIXME: --disable-triggers is not implemented") - } if o.EnableRowSecurity { // TODO: Seems that this options affects COPY - log.Warn().Msgf("FIXME: Seems that this options affects COPY and is not implemented") - args = append(args, "--enable-row-security") + panic("FIXME: --enable-row-security is not implemented") + //args = append(args, "--enable-row-security") } if len(o.ExcludeTableData) > 0 { for _, item := range o.ExcludeTableData { diff --git a/internal/db/postgres/pgrestore/pgrestore.go b/internal/db/postgres/pgrestore/pgrestore.go index dde641fe..2d814ef1 100644 --- a/internal/db/postgres/pgrestore/pgrestore.go +++ 
b/internal/db/postgres/pgrestore/pgrestore.go @@ -45,6 +45,18 @@ func (pr *PgRestore) Run(ctx context.Context, options *Options) error { return cmd_runner.Run(ctx, &log.Logger, path.Join(pr.BinPath, pgRestoreExecutable), options.GetParams()...) } +// DataSectionSettings - settings for data section that changes behavior of dumpers +type DataSectionSettings struct { + ExitOnError bool + UsePgzip bool + BatchSize int64 + OnConflictDoNothing bool + OverridingSystemValue bool + DisableTriggers bool + SuperUser string + UseSessionReplicationRoleReplica bool +} + type Options struct { // Custom DirPath string @@ -99,8 +111,9 @@ type Options struct { // OverridingSystemValue is a custom option that allows to use OVERRIDING SYSTEM VALUE for INSERTs OverridingSystemValue bool `mapstructure:"overriding-system-value"` // Use pgzip decompression instead of gzip - Pgzip bool `mapstructure:"pgzip"` - BatchSize int64 `mapstructure:"batch-size"` + Pgzip bool `mapstructure:"pgzip"` + BatchSize int64 `mapstructure:"batch-size"` + UseSessionReplicationRoleReplica bool `mapstructure:"use-session-replication-role-replica"` // Connection options: Host string `mapstructure:"host"` @@ -111,6 +124,19 @@ type Options struct { Role string `mapstructure:"role"` } +func (o *Options) ToDataSectionSettings() *DataSectionSettings { + return &DataSectionSettings{ + ExitOnError: o.ExitOnError, + UsePgzip: o.Pgzip, + BatchSize: o.BatchSize, + OnConflictDoNothing: o.OnConflictDoNothing, + OverridingSystemValue: o.OverridingSystemValue, + DisableTriggers: o.DisableTriggers, + SuperUser: o.SuperUser, + UseSessionReplicationRoleReplica: o.UseSessionReplicationRoleReplica, + } +} + func (o *Options) GetPgDSN() (string, error) { if strings.HasPrefix(o.DbName, "postgresql://") || strings.Contains(o.DbName, "=") { return o.DbName, nil diff --git a/internal/db/postgres/restorers/base.go b/internal/db/postgres/restorers/base.go new file mode 100644 index 00000000..f7bd7748 --- /dev/null +++ 
b/internal/db/postgres/restorers/base.go @@ -0,0 +1,179 @@ +package restorers + +import ( + "context" + "fmt" + "io" + + "github.com/jackc/pgx/v5" + "github.com/rs/zerolog/log" + + "github.com/greenmaskio/greenmask/internal/db/postgres/pgrestore" + "github.com/greenmaskio/greenmask/internal/db/postgres/toc" + "github.com/greenmaskio/greenmask/internal/storages" + "github.com/greenmaskio/greenmask/internal/utils/ioutils" +) + +type restoreBase struct { + opt *pgrestore.DataSectionSettings + entry *toc.Entry + st storages.Storager +} + +func newRestoreBase(entry *toc.Entry, st storages.Storager, opt *pgrestore.DataSectionSettings) *restoreBase { + return &restoreBase{ + entry: entry, + st: st, + opt: opt, + } + +} + +func (rb *restoreBase) DebugInfo() string { + return fmt.Sprintf("table %s.%s", *rb.entry.Namespace, *rb.entry.Tag) +} + +func (rb *restoreBase) setSessionReplicationRole(ctx context.Context, tx pgx.Tx) error { + if err := rb.setSuperUser(ctx, tx); err != nil { + return err + } + if rb.opt.UseSessionReplicationRoleReplica { + _, err := tx.Exec(ctx, "SET session_replication_role = 'replica'") + if err != nil { + return err + } + } + if err := rb.resetSuperUser(ctx, tx); err != nil { + return err + } + return nil +} + +func (rb *restoreBase) resetSessionReplicationRole(ctx context.Context, tx pgx.Tx) error { + if err := rb.setSuperUser(ctx, tx); err != nil { + return err + } + if rb.opt.UseSessionReplicationRoleReplica { + _, err := tx.Exec(ctx, "RESET session_replication_role") + if err != nil { + return err + } + } + if err := rb.resetSuperUser(ctx, tx); err != nil { + return err + } + return nil +} + +func (rb *restoreBase) disableTriggers(ctx context.Context, tx pgx.Tx) error { + if rb.opt.DisableTriggers { + if err := rb.setSuperUser(ctx, tx); err != nil { + return err + } + _, err := tx.Exec( + ctx, + fmt.Sprintf( + "ALTER TABLE %s.%s DISABLE TRIGGER ALL", + *rb.entry.Namespace, + *rb.entry.Tag, + ), + ) + if err != nil { + return err + } + if err := 
rb.resetSuperUser(ctx, tx); err != nil { + return err + } + } + return nil +} + +func (rb *restoreBase) enableTriggers(ctx context.Context, tx pgx.Tx) error { + if rb.opt.DisableTriggers { + if err := rb.setSuperUser(ctx, tx); err != nil { + return err + } + _, err := tx.Exec( + ctx, + fmt.Sprintf( + "ALTER TABLE %s.%s ENABLE TRIGGER ALL", + *rb.entry.Namespace, + *rb.entry.Tag, + ), + ) + if err != nil { + return err + } + if err := rb.resetSuperUser(ctx, tx); err != nil { + return err + } + } + return nil +} + +func (rb *restoreBase) setSuperUser(ctx context.Context, tx pgx.Tx) error { + if rb.opt.SuperUser != "" { + _, err := tx.Exec(ctx, fmt.Sprintf("SET ROLE %s", rb.opt.SuperUser)) + if err != nil { + return fmt.Errorf("cannot set superuser: %w", err) + } + } + return nil +} + +func (rb *restoreBase) resetSuperUser(ctx context.Context, tx pgx.Tx) error { + if rb.opt.SuperUser != "" { + _, err := tx.Exec(ctx, "RESET ROLE") + if err != nil { + return fmt.Errorf("cannot reset superuser: %w", err) + } + } + return nil +} + +// setupTx - setup transaction before restore. It disables triggers and sets session replication role if set. 
+func (rb *restoreBase) setupTx(ctx context.Context, tx pgx.Tx) error { + if err := rb.setSessionReplicationRole(ctx, tx); err != nil { + return fmt.Errorf("cannot set session replication role: %w", err) + } + if err := rb.disableTriggers(ctx, tx); err != nil { + return fmt.Errorf("cannot disable triggers: %w", err) + } + return nil +} + +// resetTx - reset transaction state after restore so the changes such as temporary alter table will not be +// committed +func (rb *restoreBase) resetTx(ctx context.Context, tx pgx.Tx) error { + if err := rb.enableTriggers(ctx, tx); err != nil { + return fmt.Errorf("cannot enable triggers: %w", err) + } + if err := rb.resetSessionReplicationRole(ctx, tx); err != nil { + return fmt.Errorf("cannot reset session replication role: %w", err) + } + return nil +} + +// getObject returns a reader for the dump file. It wraps the file in a gzip reader. +func (rb *restoreBase) getObject(ctx context.Context) (io.ReadCloser, error) { + if rb.entry.FileName == nil { + return nil, fmt.Errorf("file name in toc.Entry is empty") + } + + r, err := rb.st.GetObject(ctx, *rb.entry.FileName) + if err != nil { + return nil, fmt.Errorf("cannot open dump file: %w", err) + } + + gz, err := ioutils.NewGzipReader(r, rb.opt.UsePgzip) + if err != nil { + if err := r.Close(); err != nil { + log.Warn(). + Err(err).
+ Msg("error closing dump file") + } + return nil, fmt.Errorf("cannot create gzip reader: %w", err) + } + + return gz, nil +} diff --git a/internal/db/postgres/restorers/base_test.go b/internal/db/postgres/restorers/base_test.go new file mode 100644 index 00000000..48778af7 --- /dev/null +++ b/internal/db/postgres/restorers/base_test.go @@ -0,0 +1,597 @@ +package restorers + +import ( + "bytes" + "compress/gzip" + "context" + "fmt" + "testing" + + "github.com/stretchr/testify/mock" + "github.com/stretchr/testify/suite" + + "github.com/greenmaskio/greenmask/internal/db/postgres/pgrestore" + "github.com/greenmaskio/greenmask/internal/db/postgres/toc" + "github.com/greenmaskio/greenmask/internal/utils/testutils" +) + +const ( + migrationUp = ` +CREATE USER non_super_user PASSWORD '1234' NOINHERIT; +GRANT testuser TO non_super_user; +GRANT SELECT, INSERT ON ALL TABLES IN SCHEMA public TO non_super_user; +ALTER DEFAULT PRIVILEGES IN SCHEMA public GRANT SELECT, INSERT ON TABLES TO non_super_user; +GRANT INSERT ON ALL TABLES IN SCHEMA public TO non_super_user; + +-- Create the 'users' table +CREATE TABLE users ( + id SERIAL PRIMARY KEY, + name TEXT NOT NULL, + email TEXT UNIQUE NOT NULL, + created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP +); + +-- Create the 'orders' table +CREATE TABLE orders ( + id SERIAL PRIMARY KEY, + user_id INT NOT NULL, + order_amount NUMERIC(10, 2) NOT NULL, + raise_error TEXT, + order_date DATE DEFAULT CURRENT_DATE, + created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, + CONSTRAINT fk_user FOREIGN KEY (user_id) REFERENCES users (id) ON DELETE CASCADE +); + +-- Trigger function to ensure 'order_date' is always set +CREATE OR REPLACE FUNCTION set_order_date() +RETURNS TRIGGER AS $$ +BEGIN + If NEW.raise_error != '' THEN + RAISE EXCEPTION '%', NEW.raise_error; + END IF; + IF NEW.order_date IS NULL THEN + NEW.order_date = CURRENT_DATE; + END IF; + RETURN NEW; +END; +$$ LANGUAGE plpgsql; + +-- Trigger for 'orders' table +CREATE TRIGGER trg_set_order_date 
+BEFORE INSERT ON orders +FOR EACH ROW +EXECUTE FUNCTION set_order_date(); + +-- Insert sample data into 'users' +INSERT INTO users (name, email) VALUES +('Alice', 'alice@example.com'), +('Bob', 'bob@example.com'); + +-- Insert sample data into 'orders' +INSERT INTO orders (user_id, order_amount) VALUES +(1, 100.50), +(2, 200.75); +` + migrationDown = ` +REVOKE ALL ON SCHEMA public FROM non_super_user; +REVOKE ALL ON ALL TABLES IN SCHEMA public FROM non_super_user; +REVOKE ALL ON ALL SEQUENCES IN SCHEMA public FROM non_super_user; +REVOKE ALL ON ALL FUNCTIONS IN SCHEMA public FROM non_super_user; +REVOKE ALL PRIVILEGES ON ALL TABLES IN SCHEMA public FROM non_super_user; +ALTER DEFAULT PRIVILEGES IN SCHEMA public REVOKE SELECT, INSERT ON TABLES FROM non_super_user; +REVOKE USAGE ON SCHEMA public FROM non_super_user; +REVOKE testuser FROM non_super_user; +DROP USER non_super_user; +DROP TRIGGER IF EXISTS trg_set_order_date ON orders; +DROP FUNCTION IF EXISTS set_order_date; +DROP TABLE IF EXISTS orders; +DROP TABLE IF EXISTS users; +` +) + +type readCloserMock struct { + *bytes.Buffer +} + +func (r *readCloserMock) Close() error { + return nil +} + +type restoresSuite struct { + testutils.PgContainerSuite + nonSuperUserPassword string + nonSuperUser string +} + +func (s *restoresSuite) SetupSuite() { + s.SetMigrationUp(migrationUp). + SetMigrationDown(migrationDown). 
+ SetupSuite() + s.nonSuperUser = "non_super_user" + s.nonSuperUserPassword = "1234" +} + +func (s *restoresSuite) Test_restoreBase_DebugInfo() { + nsp := "public" + tag := "orders" + rb := newRestoreBase(&toc.Entry{ + Namespace: &nsp, + Tag: &tag, + }, nil, nil) + s.Equal("table public.orders", rb.DebugInfo()) +} + +func (s *restoresSuite) Test_restoreBase_setSessionReplicationRole() { + userName := s.GetSuperUser() + opt := &pgrestore.DataSectionSettings{ + UseSessionReplicationRoleReplica: true, + SuperUser: userName, + } + + rb := newRestoreBase(nil, nil, opt) + cxt := context.Background() + conn, err := s.GetConnectionWithUser(cxt, s.nonSuperUser, s.nonSuperUserPassword) + s.Require().NoError(err) + defer conn.Close(cxt) // nolint: errcheck + s.Require().NoError(err) + tx, err := conn.Begin(cxt) + s.Require().NoError(err) + s.Require().NoError(rb.setSessionReplicationRole(cxt, tx)) + + expectedUser := s.nonSuperUser + expectedReplicaRole := "replica" + + var actualUser string + r := tx.QueryRow(cxt, "SELECT current_user") + err = r.Scan(&actualUser) + s.Require().NoError(err) + s.Assert().Equal(expectedUser, actualUser) + + var actualReplicaRole string + r = tx.QueryRow(cxt, "SHOW session_replication_role") + err = r.Scan(&actualReplicaRole) + s.Require().NoError(err) + s.Assert().Equal(expectedReplicaRole, actualReplicaRole) + + s.NoError(tx.Rollback(cxt)) +} + +func (s *restoresSuite) Test_restoreBase_resetSessionReplicationRole() { + userName := s.GetSuperUser() + opt := &pgrestore.DataSectionSettings{ + UseSessionReplicationRoleReplica: true, + SuperUser: userName, + } + + rb := newRestoreBase(nil, nil, opt) + cxt := context.Background() + conn, err := s.GetConnectionWithUser(cxt, s.nonSuperUser, s.nonSuperUserPassword) + s.Require().NoError(err) + defer conn.Close(cxt) // nolint: errcheck + s.Require().NoError(err) + tx, err := conn.Begin(cxt) + s.Require().NoError(err) + + _, err = tx.Exec(cxt, "SET ROLE "+s.GetSuperUser()) + s.Require().NoError(err) + 
_, err = tx.Exec(cxt, "SET session_replication_role = 'replica'") + s.Require().NoError(err) + + err = rb.setSessionReplicationRole(cxt, tx) + s.Require().NoError(err) + + expectedUser := s.nonSuperUser + expectedReplicaRole := "replica" + + var actualUser string + r := tx.QueryRow(cxt, "SELECT current_user") + err = r.Scan(&actualUser) + s.Require().NoError(err) + s.Assert().Equal(expectedUser, actualUser) + + var actualReplicaRole string + r = tx.QueryRow(cxt, "SHOW session_replication_role") + err = r.Scan(&actualReplicaRole) + s.Require().NoError(err) + s.Assert().Equal(expectedReplicaRole, actualReplicaRole) + + s.NoError(tx.Rollback(cxt)) +} + +func (s *restoresSuite) Test_restoreBase_enableTriggers() { + schemaName := "public" + tableName := "orders" + opt := &pgrestore.DataSectionSettings{ + DisableTriggers: true, + SuperUser: s.GetSuperUser(), + } + rb := newRestoreBase(&toc.Entry{ + Namespace: &schemaName, + Tag: &tableName, + }, nil, opt) + ctx := context.Background() + conn, err := s.GetConnectionWithUser(ctx, s.nonSuperUser, s.nonSuperUserPassword) + s.Require().NoError(err) + defer conn.Close(ctx) // nolint: errcheck + s.Require().NoError(err) + tx, err := conn.Begin(ctx) + s.Require().NoError(err) + err = rb.disableTriggers(ctx, tx) + s.Require().NoError(err) + + expectedUser := s.nonSuperUser + var actualUser string + r := tx.QueryRow(ctx, "SELECT current_user") + err = r.Scan(&actualUser) + s.Require().NoError(err) + s.Assert().Equal(expectedUser, actualUser) + + // tgenabled value: + // O = trigger fires in “origin” and “local” modes, D = trigger is disabled, + // R = trigger fires in “replica” mode, A = trigger fires always + checkDisabledTriggerSql := ` +SELECT tgname AS trigger_name, + tgenabled +FROM pg_trigger t + JOIN pg_class c ON t.tgrelid = c.oid + JOIN pg_namespace n ON c.relnamespace = n.oid +WHERE n.nspname = $1 AND c.relname = $2 + AND t.tgname = ANY($3); +` + rows, err := conn.Query( + ctx, checkDisabledTriggerSql, schemaName, 
tableName, []string{"trg_set_order_date"}, + ) + s.Require().NoError(err) + defer rows.Close() + + type triggerStatus struct { + triggerName string + tgenabled rune + } + var triggers []triggerStatus + for rows.Next() { + var t triggerStatus + err = rows.Scan(&t.triggerName, &t.tgenabled) + s.Require().NoError(err) + triggers = append(triggers, t) + } + + expectedTriggerStatus := []triggerStatus{ + {triggerName: "trg_set_order_date", tgenabled: 'D'}, + } + + s.Require().Len(triggers, len(expectedTriggerStatus)) + for i, expected := range expectedTriggerStatus { + s.Assert().Equal(expected.triggerName, triggers[i].triggerName) + s.Assert().Equal(expected.tgenabled, triggers[i].tgenabled) + } + + s.NoError(tx.Rollback(ctx)) +} + +func (s *restoresSuite) Test_restoreBase_disableTriggers() { + cxt := context.Background() + schemaName := "public" + tableName := "orders" + + suConn, err := s.GetConnection(cxt) + s.Require().NoError(err) + defer suConn.Close(cxt) // nolint: errcheck + s.Require().NoError(err) + _, err = suConn.Exec(cxt, "ALTER TABLE public.orders DISABLE TRIGGER ALL") + s.Require().NoError(err) + + opt := &pgrestore.DataSectionSettings{ + DisableTriggers: true, + SuperUser: s.GetSuperUser(), + } + rb := newRestoreBase(&toc.Entry{ + Namespace: &schemaName, + Tag: &tableName, + }, nil, opt) + conn, err := s.GetConnectionWithUser(cxt, s.nonSuperUser, s.nonSuperUserPassword) + s.Require().NoError(err) + defer conn.Close(cxt) // nolint: errcheck + s.Require().NoError(err) + tx, err := conn.Begin(cxt) + s.Require().NoError(err) + err = rb.enableTriggers(cxt, tx) + s.Require().NoError(err) + + expectedUser := s.nonSuperUser + var actualUser string + r := tx.QueryRow(cxt, "SELECT current_user") + err = r.Scan(&actualUser) + s.Require().NoError(err) + s.Assert().Equal(expectedUser, actualUser) + + // tgenabled value: + // O = trigger fires in “origin” and “local” modes, D = trigger is disabled, + // R = trigger fires in “replica” mode, A = trigger fires always + 
checkDisabledTriggerSql := ` +SELECT tgname AS trigger_name, + tgenabled +FROM pg_trigger t + JOIN pg_class c ON t.tgrelid = c.oid + JOIN pg_namespace n ON c.relnamespace = n.oid +WHERE n.nspname = $1 AND c.relname = $2 + AND t.tgname = ANY($3); +` + rows, err := conn.Query( + cxt, checkDisabledTriggerSql, schemaName, tableName, []string{"trg_set_order_date"}, + ) + s.Require().NoError(err) + defer rows.Close() + + type triggerStatus struct { + triggerName string + tgenabled rune + } + var triggers []triggerStatus + for rows.Next() { + var t triggerStatus + err = rows.Scan(&t.triggerName, &t.tgenabled) + s.Require().NoError(err) + triggers = append(triggers, t) + } + + expectedTriggerStatus := []triggerStatus{ + {triggerName: "trg_set_order_date", tgenabled: 'O'}, + } + + s.Require().Len(triggers, len(expectedTriggerStatus)) + for i, expected := range expectedTriggerStatus { + s.Assert().Equal(expected.triggerName, triggers[i].triggerName) + s.Assert().Equal(expected.tgenabled, triggers[i].tgenabled) + } + + s.NoError(tx.Rollback(cxt)) +} + +func (s *restoresSuite) Test_restoreBase_setSuperUser() { + cxt := context.Background() + conn, err := s.GetConnectionWithUser(cxt, s.nonSuperUser, s.nonSuperUserPassword) + s.Require().NoError(err) + defer conn.Close(cxt) + s.Require().NoError(err) + tx, err := conn.Begin(cxt) + s.Require().NoError(err) + defer tx.Rollback(cxt) // nolint: errcheck + rb := newRestoreBase(nil, nil, &pgrestore.DataSectionSettings{ + SuperUser: s.GetSuperUser(), + }) + err = rb.setSuperUser(cxt, tx) + s.Require().NoError(err) + + expectedUser := s.GetSuperUser() + var actualUser string + r := conn.QueryRow(cxt, "SELECT current_user") + err = r.Scan(&actualUser) + s.Require().NoError(err) + s.Assert().Equal(expectedUser, actualUser) +} + +func (s *restoresSuite) Test_restoreBase_resetSuperUser() { + cxt := context.Background() + conn, err := s.GetConnectionWithUser(cxt, s.nonSuperUser, s.nonSuperUserPassword) + s.Require().NoError(err) + defer 
conn.Close(cxt) + s.Require().NoError(err) + tx, err := conn.Begin(cxt) + s.Require().NoError(err) + defer tx.Rollback(cxt) // nolint: errcheck + + _, err = tx.Exec(cxt, fmt.Sprintf("SET ROLE %s", s.GetSuperUser())) + s.Require().NoError(err) + + rb := newRestoreBase(nil, nil, &pgrestore.DataSectionSettings{ + SuperUser: s.GetSuperUser(), + }) + err = rb.resetSuperUser(cxt, tx) + s.Require().NoError(err) + + expectedUser := s.nonSuperUser + var actualUser string + r := conn.QueryRow(cxt, "SELECT current_user") + err = r.Scan(&actualUser) + s.Require().NoError(err) + s.Assert().Equal(expectedUser, actualUser) +} + +func (s *restoresSuite) Test_restoreBase_setupTx() { + // Test triggers enabled + // Test session replication role enabled + schemaName := "public" + tableName := "orders" + opt := &pgrestore.DataSectionSettings{ + DisableTriggers: true, + UseSessionReplicationRoleReplica: true, + SuperUser: s.GetSuperUser(), + } + rb := newRestoreBase(&toc.Entry{ + Namespace: &schemaName, + Tag: &tableName, + }, nil, opt) + cxt := context.Background() + conn, err := s.GetConnectionWithUser(cxt, s.nonSuperUser, s.nonSuperUserPassword) + s.Require().NoError(err) + defer conn.Close(cxt) + s.Require().NoError(err) + tx, err := conn.Begin(cxt) + s.Require().NoError(err) + err = rb.setupTx(cxt, tx) + s.Require().NoError(err) + + // tgenabled value: + // O = trigger fires in “origin” and “local” modes, D = trigger is disabled, + // R = trigger fires in “replica” mode, A = trigger fires always + checkDisabledTriggerSql := ` +SELECT tgname AS trigger_name, + tgenabled +FROM pg_trigger t + JOIN pg_class c ON t.tgrelid = c.oid + JOIN pg_namespace n ON c.relnamespace = n.oid +WHERE n.nspname = $1 AND c.relname = $2 + AND t.tgname = ANY($3); +` + rows, err := conn.Query( + cxt, checkDisabledTriggerSql, schemaName, tableName, []string{"trg_set_order_date"}, + ) + s.Require().NoError(err) + defer rows.Close() + + type triggerStatus struct { + triggerName string + tgenabled rune + } + 
var triggers []triggerStatus + for rows.Next() { + var t triggerStatus + err = rows.Scan(&t.triggerName, &t.tgenabled) + s.Require().NoError(err) + triggers = append(triggers, t) + } + + expectedTriggerStatus := []triggerStatus{ + {triggerName: "trg_set_order_date", tgenabled: 'D'}, + } + + s.Require().Len(triggers, len(expectedTriggerStatus)) + for i, expected := range expectedTriggerStatus { + s.Assert().Equal(expected.triggerName, triggers[i].triggerName) + s.Assert().Equal(expected.tgenabled, triggers[i].tgenabled) + } + + expectedReplicaRole := "replica" + actualReplicaRole := "" + r := tx.QueryRow(cxt, "SHOW session_replication_role") + err = r.Scan(&actualReplicaRole) + s.Require().NoError(err) + s.Assert().Equal(expectedReplicaRole, actualReplicaRole) + + s.NoError(tx.Rollback(cxt)) +} + +func (s *restoresSuite) Test_restoreBase_resetTx() { + // Test triggers enabled + // Test session replication role enabled + schemaName := "public" + tableName := "orders" + opt := &pgrestore.DataSectionSettings{ + DisableTriggers: true, + UseSessionReplicationRoleReplica: true, + SuperUser: s.GetSuperUser(), + } + rb := newRestoreBase(&toc.Entry{ + Namespace: &schemaName, + Tag: &tableName, + }, nil, opt) + cxt := context.Background() + conn, err := s.GetConnectionWithUser(cxt, s.nonSuperUser, s.nonSuperUserPassword) + s.Require().NoError(err) + defer conn.Close(cxt) + s.Require().NoError(err) + tx, err := conn.Begin(cxt) + s.Require().NoError(err) + + _, err = tx.Exec(cxt, "SET ROLE "+s.GetSuperUser()) + s.Require().NoError(err) + _, err = tx.Exec(cxt, "ALTER TABLE public.orders DISABLE TRIGGER ALL") + s.Require().NoError(err) + _, err = tx.Exec(cxt, "SET session_replication_role = 'replica'") + s.Require().NoError(err) + + err = rb.resetTx(cxt, tx) + s.Require().NoError(err) + + // tgenabled value: + // O = trigger fires in “origin” and “local” modes, D = trigger is disabled, + // R = trigger fires in “replica” mode, A = trigger fires always + checkDisabledTriggerSql := 
` +SELECT tgname AS trigger_name, + tgenabled +FROM pg_trigger t + JOIN pg_class c ON t.tgrelid = c.oid + JOIN pg_namespace n ON c.relnamespace = n.oid +WHERE n.nspname = $1 AND c.relname = $2 + AND t.tgname = ANY($3); +` + rows, err := conn.Query( + cxt, checkDisabledTriggerSql, schemaName, tableName, []string{"trg_set_order_date"}, + ) + s.Require().NoError(err) + defer rows.Close() + + type triggerStatus struct { + triggerName string + tgenabled rune + } + var triggers []triggerStatus + for rows.Next() { + var t triggerStatus + err = rows.Scan(&t.triggerName, &t.tgenabled) + s.Require().NoError(err) + triggers = append(triggers, t) + } + + expectedTriggerStatus := []triggerStatus{ + {triggerName: "trg_set_order_date", tgenabled: 'O'}, + } + + s.Require().Len(triggers, len(expectedTriggerStatus)) + for i, expected := range expectedTriggerStatus { + s.Assert().Equal(expected.triggerName, triggers[i].triggerName) + s.Assert().Equal(expected.tgenabled, triggers[i].tgenabled) + } + + expectedReplicaRole := "origin" + actualReplicaRole := "" + r := tx.QueryRow(cxt, "SHOW session_replication_role") + err = r.Scan(&actualReplicaRole) + s.Require().NoError(err) + s.Assert().Equal(expectedReplicaRole, actualReplicaRole) + + s.NoError(tx.Rollback(cxt)) +} + +func (s *restoresSuite) Test_restoreBase_getObject() { + schemaName := "public" + tableName := "orders" + fileName := "test.tar.gz" + + data := `20383 24ca7574-0adb-4b17-8777-93f5589dbea2 2017-12-13 13:46:49.39 +20384 d0d4a55c-7752-453e-8334-772a889fb917 2017-12-13 13:46:49.453 +20385 ac8617aa-5a2d-4bb8-a9a5-ed879a4b33cd 2017-12-13 13:46:49.5 +` + buf := new(bytes.Buffer) + gzData := gzip.NewWriter(buf) + _, err := gzData.Write([]byte(data)) + s.Require().NoError(err) + err = gzData.Flush() + s.Require().NoError(err) + err = gzData.Close() + s.Require().NoError(err) + objSrc := &readCloserMock{Buffer: buf} + + st := new(testutils.StorageMock) + st.On("GetObject", mock.Anything, mock.Anything). 
+ Return(objSrc, nil) + + rb := newRestoreBase(&toc.Entry{ + Namespace: &schemaName, + Tag: &tableName, + FileName: &fileName, + }, st, &pgrestore.DataSectionSettings{}) + ctx := context.Background() + obj, err := rb.getObject(ctx) + s.Require().NoError(err) + readBuf := make([]byte, 1024) + n, err := obj.Read(readBuf) + s.Require().NoError(err) + s.Assert().Equal(data, string(readBuf[:n])) + s.NoError(obj.Close()) +} + +func TestRestorers(t *testing.T) { + suite.Run(t, new(restoresSuite)) +} diff --git a/internal/db/postgres/restorers/blobs.go b/internal/db/postgres/restorers/blobs.go index 80b34f40..4462e7f7 100644 --- a/internal/db/postgres/restorers/blobs.go +++ b/internal/db/postgres/restorers/blobs.go @@ -22,13 +22,13 @@ import ( "strconv" "strings" - "github.com/greenmaskio/greenmask/internal/utils/ioutils" "github.com/jackc/pgx/v5" "github.com/pkg/errors" "github.com/rs/zerolog/log" "github.com/greenmaskio/greenmask/internal/db/postgres/toc" "github.com/greenmaskio/greenmask/internal/storages" + "github.com/greenmaskio/greenmask/internal/utils/ioutils" ) type BlobsRestorer struct { @@ -99,7 +99,7 @@ func (td *BlobsRestorer) execute(ctx context.Context, tx pgx.Tx) error { loApi := tx.LargeObjects() // restoring large objects one by one - buf := make([]byte, DefaultBufferSize) + buf := make([]byte, defaultBufferSize) for _, loOid := range td.largeObjectsOids { log.Debug().Uint32("oid", loOid).Msg("large object restoration is started") err = func() error { diff --git a/internal/db/postgres/restorers/restorer.go b/internal/db/postgres/restorers/restorer.go index 2099c9d6..ac7a9840 100644 --- a/internal/db/postgres/restorers/restorer.go +++ b/internal/db/postgres/restorers/restorer.go @@ -17,8 +17,9 @@ package restorers import ( "context" - "github.com/greenmaskio/greenmask/internal/db/postgres/toc" "github.com/jackc/pgx/v5" + + "github.com/greenmaskio/greenmask/internal/db/postgres/toc" ) type RestoreTask interface { diff --git 
a/internal/db/postgres/restorers/table.go b/internal/db/postgres/restorers/table.go index be30718c..4de41cf4 100644 --- a/internal/db/postgres/restorers/table.go +++ b/internal/db/postgres/restorers/table.go @@ -21,96 +21,88 @@ import ( "fmt" "io" - "github.com/greenmaskio/greenmask/internal/utils/ioutils" - "github.com/greenmaskio/greenmask/internal/utils/pgerrors" - "github.com/greenmaskio/greenmask/internal/utils/reader" "github.com/jackc/pgx/v5" "github.com/jackc/pgx/v5/pgproto3" "github.com/rs/zerolog/log" + "github.com/greenmaskio/greenmask/internal/db/postgres/pgrestore" "github.com/greenmaskio/greenmask/internal/db/postgres/toc" "github.com/greenmaskio/greenmask/internal/storages" + "github.com/greenmaskio/greenmask/internal/utils/pgerrors" + "github.com/greenmaskio/greenmask/internal/utils/reader" ) -const DefaultBufferSize = 1024 * 1024 +const defaultBufferSize = 1024 * 1024 type TableRestorer struct { - Entry *toc.Entry - St storages.Storager - exitOnError bool - usePgzip bool - batchSize int64 + *restoreBase } func NewTableRestorer( - entry *toc.Entry, st storages.Storager, exitOnError bool, usePgzip bool, batchSize int64, + entry *toc.Entry, st storages.Storager, opt *pgrestore.DataSectionSettings, ) *TableRestorer { return &TableRestorer{ - Entry: entry, - St: st, - exitOnError: exitOnError, - usePgzip: usePgzip, - batchSize: batchSize, + restoreBase: newRestoreBase(entry, st, opt), } } func (td *TableRestorer) GetEntry() *toc.Entry { - return td.Entry + return td.entry } func (td *TableRestorer) Execute(ctx context.Context, conn *pgx.Conn) error { // TODO: Add tests - if td.Entry.FileName == nil { + if td.entry.FileName == nil { return fmt.Errorf("cannot get file name from toc Entry") } - r, err := td.St.GetObject(ctx, *td.Entry.FileName) - if err != nil { - return fmt.Errorf("cannot open dump file: %w", err) - } - defer func(reader io.ReadCloser) { - if err := reader.Close(); err != nil { - log.Warn(). - Err(err). 
- Msg("error closing dump file") - } - }(r) - gz, err := ioutils.GetGzipReadCloser(r, td.usePgzip) + r, err := td.getObject(ctx) if err != nil { - return fmt.Errorf("cannot create gzip reader: %w", err) + return fmt.Errorf("cannot get storage object: %w", err) } - defer func(gz io.Closer) { - if err := gz.Close(); err != nil { + defer func() { + if err := r.Close(); err != nil { log.Warn(). Err(err). - Msg("error closing gzip reader") + Str("objectName", td.DebugInfo()). + Msg("cannot close storage object") } - }(gz) + }() // Open new transaction for each task tx, err := conn.Begin(ctx) if err != nil { return fmt.Errorf("cannot start transaction (restoring %s): %w", td.DebugInfo(), err) } + if err := td.setupTx(ctx, tx); err != nil { + return fmt.Errorf("cannot setup transaction: %w", err) + } - log.Debug().Str("copyStmt", *td.Entry.CopyStmt).Msgf("performing pgcopy statement") + log.Debug(). + Str("copyStmt", *td.entry.CopyStmt). + Msgf("performing pgcopy statement") f := tx.Conn().PgConn().Frontend() - if err = td.restoreCopy(ctx, f, gz); err != nil { - if txErr := tx.Rollback(ctx); txErr != nil { - log.Warn(). - Err(txErr). - Str("objectName", td.DebugInfo()). - Msg("cannot rollback transaction") - } - if td.exitOnError { + if err = td.restoreCopy(ctx, f, r); err != nil { + rollbackTransaction(ctx, tx, td.entry) + if td.opt.ExitOnError { return fmt.Errorf("unable to restore table: %w", err) } - log.Warn().Err(err).Msg("unable to restore table") + log.Warn(). + Err(err). + Str("objectName", td.DebugInfo()). 
+ Msg("unable to restore table") return nil } + if err := td.resetTx(ctx, tx); err != nil { + rollbackTransaction(ctx, tx, td.entry) + if td.opt.ExitOnError { + return fmt.Errorf("unable to reset transaction: %w", err) + } + } + if err = tx.Commit(ctx); err != nil { return fmt.Errorf("cannot commit transaction (restoring %s): %w", td.DebugInfo(), err) } @@ -123,7 +115,7 @@ func (td *TableRestorer) restoreCopy(ctx context.Context, f *pgproto3.Frontend, return fmt.Errorf("error initializing pgcopy: %w", err) } - if td.batchSize > 0 { + if td.opt.BatchSize > 0 { if err := td.streamCopyDataByBatch(ctx, f, r); err != nil { return fmt.Errorf("error streaming pgcopy data: %w", err) } @@ -140,7 +132,7 @@ func (td *TableRestorer) restoreCopy(ctx context.Context, f *pgproto3.Frontend, } func (td *TableRestorer) initCopy(ctx context.Context, f *pgproto3.Frontend) error { - err := sendMessage(f, &pgproto3.Query{String: *td.Entry.CopyStmt}) + err := sendMessage(f, &pgproto3.Query{String: *td.entry.CopyStmt}) if err != nil { return fmt.Errorf("error sending Query message: %w", err) } @@ -173,7 +165,7 @@ func (td *TableRestorer) initCopy(ctx context.Context, f *pgproto3.Frontend) err // stops immediately and returns the error func (td *TableRestorer) streamCopyDataByBatch(ctx context.Context, f *pgproto3.Frontend, r io.Reader) (err error) { bi := bufio.NewReader(r) - buf := make([]byte, DefaultBufferSize) + buf := make([]byte, defaultBufferSize) var lineNum int64 for { buf, err = reader.ReadLine(bi, buf) @@ -194,7 +186,7 @@ func (td *TableRestorer) streamCopyDataByBatch(ctx context.Context, f *pgproto3. return fmt.Errorf("error sending CopyData message: %w", err) } - if lineNum%td.batchSize == 0 { + if lineNum%td.opt.BatchSize == 0 { if err = td.completeBatch(ctx, f); err != nil { return fmt.Errorf("error completing batch: %w", err) } @@ -213,7 +205,7 @@ func (td *TableRestorer) streamCopyDataByBatch(ctx context.Context, f *pgproto3. 
func (td *TableRestorer) streamCopyData(ctx context.Context, f *pgproto3.Frontend, r io.Reader) error { // Streaming pgcopy data from table dump - buf := make([]byte, DefaultBufferSize) + buf := make([]byte, defaultBufferSize) for { var n int @@ -279,8 +271,14 @@ func (td *TableRestorer) postStreamingHandle(ctx context.Context, f *pgproto3.Fr } } -func (td *TableRestorer) DebugInfo() string { - return fmt.Sprintf("table %s.%s", *td.Entry.Namespace, *td.Entry.Tag) +func rollbackTransaction(ctx context.Context, tx pgx.Tx, e *toc.Entry) { + if err := tx.Rollback(ctx); err != nil { + log.Warn(). + Err(err). + Str("SchemaName", *e.Namespace). + Str("TableName", *e.Tag). + Msg("cannot rollback transaction") + } } // sendMessage - send a message to the PostgreSQL backend and flush a buffer diff --git a/internal/db/postgres/restorers/table_insert_format.go b/internal/db/postgres/restorers/table_insert_format.go index 5fad13a3..a81b750d 100644 --- a/internal/db/postgres/restorers/table_insert_format.go +++ b/internal/db/postgres/restorers/table_insert_format.go @@ -28,31 +28,25 @@ import ( "github.com/rs/zerolog/log" "github.com/greenmaskio/greenmask/internal/db/postgres/pgcopy" + "github.com/greenmaskio/greenmask/internal/db/postgres/pgrestore" "github.com/greenmaskio/greenmask/internal/db/postgres/toc" "github.com/greenmaskio/greenmask/internal/domains" "github.com/greenmaskio/greenmask/internal/storages" - "github.com/greenmaskio/greenmask/internal/utils/ioutils" "github.com/greenmaskio/greenmask/internal/utils/reader" "github.com/greenmaskio/greenmask/pkg/toolkit" ) type TableRestorerInsertFormat struct { - Entry *toc.Entry - Table *toolkit.Table - St storages.Storager - doNothing bool - exitOnError bool - query string - globalExclusions *domains.GlobalDataRestorationErrorExclusions - tableExclusion *domains.TablesDataRestorationErrorExclusions - usePgzip bool - overridingSystemValue bool + *restoreBase + Table *toolkit.Table + query string + globalExclusions 
*domains.GlobalDataRestorationErrorExclusions + tableExclusion *domains.TablesDataRestorationErrorExclusions } func NewTableRestorerInsertFormat( - entry *toc.Entry, t *toolkit.Table, st storages.Storager, exitOnError bool, - doNothing bool, exclusions *domains.DataRestorationErrorExclusions, - usePgzip bool, overridingSystemValue bool, + entry *toc.Entry, t *toolkit.Table, st storages.Storager, opt *pgrestore.DataSectionSettings, + exclusions *domains.DataRestorationErrorExclusions, ) *TableRestorerInsertFormat { var ( @@ -79,53 +73,33 @@ func NewTableRestorerInsertFormat( } return &TableRestorerInsertFormat{ - Table: t, - Entry: entry, - St: st, - exitOnError: exitOnError, - doNothing: doNothing, - globalExclusions: globalExclusion, - tableExclusion: tableExclusion, - usePgzip: usePgzip, - overridingSystemValue: overridingSystemValue, + restoreBase: newRestoreBase(entry, st, opt), + Table: t, + globalExclusions: globalExclusion, + tableExclusion: tableExclusion, } } func (td *TableRestorerInsertFormat) GetEntry() *toc.Entry { - return td.Entry + return td.entry } func (td *TableRestorerInsertFormat) Execute(ctx context.Context, conn *pgx.Conn) error { - - if td.Entry.FileName == nil { - return fmt.Errorf("cannot get file name from toc Entry") - } - - r, err := td.St.GetObject(ctx, *td.Entry.FileName) + r, err := td.getObject(ctx) if err != nil { - return fmt.Errorf("cannot open dump file: %w", err) + return fmt.Errorf("cannot get storage object: %w", err) } - defer func(reader io.ReadCloser) { - if err := reader.Close(); err != nil { + defer func() { + if err := r.Close(); err != nil { log.Warn(). Err(err). - Msg("error closing dump file") + Str("objectName", td.DebugInfo()). + Msg("cannot close storage object") } - }(r) - gz, err := ioutils.GetGzipReadCloser(r, td.usePgzip) - if err != nil { - return fmt.Errorf("cannot create gzip reader: %w", err) - } - defer func(gz io.Closer) { - if err := gz.Close(); err != nil { - log.Warn(). - Err(err). 
- Msg("error closing gzip reader") - } - }(gz) + }() - if err = td.streamInsertData(ctx, conn, gz); err != nil { - if td.exitOnError { + if err = td.streamInsertData(ctx, conn, r); err != nil { + if td.opt.ExitOnError { return fmt.Errorf("error streaming pgcopy data: %w", err) } log.Warn().Err(err).Msg("error streaming pgcopy data") @@ -146,6 +120,7 @@ func (td *TableRestorerInsertFormat) streamInsertData(ctx context.Context, conn default: } + // TODO: This might require some optimization because too many allocations line, err := reader.ReadLine(buf, nil) if err != nil { if errors.Is(err, io.EOF) { @@ -160,7 +135,7 @@ func (td *TableRestorerInsertFormat) streamInsertData(ctx context.Context, conn return fmt.Errorf("error decoding line: %w", err) } - if err = td.insertDataOnConflictDoNothing(ctx, conn, row); err != nil { + if err = td.insertData(ctx, conn, row); err != nil { if !td.isErrorAllowed(err) { return fmt.Errorf("error inserting data: %w", err) } else { @@ -187,12 +162,12 @@ func (td *TableRestorerInsertFormat) generateInsertStmt(onConflictDoNothing bool } overridingSystemValue := "" - if td.overridingSystemValue { + if td.opt.OverridingSystemValue { overridingSystemValue = "OVERRIDING SYSTEM VALUE " } - tableName := *td.Entry.Tag - tableSchema := *td.Entry.Namespace + tableName := *td.entry.Tag + tableSchema := *td.entry.Namespace if td.Table.RootPtOid != 0 { tableName = td.Table.RootPtName @@ -211,20 +186,39 @@ func (td *TableRestorerInsertFormat) generateInsertStmt(onConflictDoNothing bool return res } -func (td *TableRestorerInsertFormat) insertDataOnConflictDoNothing( +func (td *TableRestorerInsertFormat) insertData( ctx context.Context, conn *pgx.Conn, row *pgcopy.Row, ) error { if td.query == "" { - td.query = td.generateInsertStmt(td.doNothing) + td.query = td.generateInsertStmt(td.opt.OnConflictDoNothing) + } + tx, err := conn.Begin(ctx) + if err != nil { + return fmt.Errorf("cannot start transaction (restoring %s): %w", td.DebugInfo(), err) + } + 
+ if err := td.setupTx(ctx, tx); err != nil { + rollbackTransaction(ctx, tx, td.entry) + return fmt.Errorf("cannot setup transaction: %w", err) } - // TODO: The implementation based on pgx.Conn.Exec is not efficient for bulk inserts. - // Consider rewrite to string literal that contains generated statement instead of using prepared statement + // TODO: The implementation based on Exec is not efficient for bulk inserts. + // Consider rewrite to string literal that contains generated statement instead of using prepared statement // in driver - _, err := conn.Exec(ctx, td.query, getAllArguments(row)...) + _, err = tx.Exec(ctx, td.query, getAllArguments(row)...) if err != nil { + rollbackTransaction(ctx, tx, td.entry) return err } + + if err := td.resetTx(ctx, tx); err != nil { + rollbackTransaction(ctx, tx, td.entry) + return fmt.Errorf("cannot reset transaction: %w", err) + } + + if err := tx.Commit(ctx); err != nil { + return fmt.Errorf("cannot commit transaction (restoring %s): %w", td.DebugInfo(), err) + } return nil } @@ -253,10 +247,6 @@ func (td *TableRestorerInsertFormat) isErrorAllowed(err error) bool { return false } -func (td *TableRestorerInsertFormat) DebugInfo() string { - return fmt.Sprintf("table %s.%s", *td.Entry.Namespace, *td.Entry.Tag) -} - func getAllArguments(row *pgcopy.Row) []any { var res []any for i := 0; i < row.Length(); i++ { diff --git a/internal/db/postgres/restorers/table_insert_format_test.go b/internal/db/postgres/restorers/table_insert_format_test.go new file mode 100644 index 00000000..854bdd1d --- /dev/null +++ b/internal/db/postgres/restorers/table_insert_format_test.go @@ -0,0 +1,197 @@ +package restorers + +import ( + "bytes" + "compress/gzip" + "context" + + "github.com/stretchr/testify/mock" + + "github.com/greenmaskio/greenmask/internal/db/postgres/pgrestore" + "github.com/greenmaskio/greenmask/internal/db/postgres/toc" + "github.com/greenmaskio/greenmask/internal/domains" + 
"github.com/greenmaskio/greenmask/internal/utils/testutils" + "github.com/greenmaskio/greenmask/pkg/toolkit" +) + +func (s *restoresSuite) Test_TableRestorerInsertFormat_check_triggers_errors() { + s.Run("check triggers causes error by default", func() { + ctx := context.Background() + schemaName := "public" + tableName := "orders" + fileName := "test_table" + copyStmt := "COPY orders (id, user_id, order_amount, raise_error) FROM stdin;" + entry := &toc.Entry{ + Namespace: &schemaName, + Tag: &tableName, + FileName: &fileName, + CopyStmt: ©Stmt, + } + data := "6\t1\t100.50\tTest exception\n" + buf := new(bytes.Buffer) + gzData := gzip.NewWriter(buf) + _, err := gzData.Write([]byte(data)) + s.Require().NoError(err) + err = gzData.Flush() + s.Require().NoError(err) + err = gzData.Close() + s.Require().NoError(err) + objSrc := &readCloserMock{Buffer: buf} + + st := new(testutils.StorageMock) + st.On("GetObject", ctx, mock.Anything).Return(objSrc, nil) + opt := &pgrestore.DataSectionSettings{ + ExitOnError: true, + } + t := &toolkit.Table{ + Schema: schemaName, + Name: tableName, + Columns: []*toolkit.Column{ + { + Name: "id", + TypeName: "int4", + }, + { + Name: "user_id", + TypeName: "int4", + }, + { + Name: "order_amount", + TypeName: "numeric", + }, + { + Name: "raise_error", + TypeName: "text", + }, + }, + } + + tr := NewTableRestorerInsertFormat(entry, t, st, opt, new(domains.DataRestorationErrorExclusions)) + + conn, err := s.GetConnectionWithUser(ctx, s.nonSuperUser, s.nonSuperUserPassword) + s.Require().NoError(err) + err = tr.Execute(ctx, conn) + s.Require().ErrorContains(err, "Test exception (SQLSTATE P0001)") + }) + + s.Run("disable triggers", func() { + ctx := context.Background() + schemaName := "public" + tableName := "orders" + fileName := "test_table" + copyStmt := "COPY orders (id, user_id, order_amount, raise_error) FROM stdin;" + entry := &toc.Entry{ + Namespace: &schemaName, + Tag: &tableName, + FileName: &fileName, + CopyStmt: ©Stmt, + } + data := 
"7\t1\t100.50\tTest exception\n" + buf := new(bytes.Buffer) + gzData := gzip.NewWriter(buf) + _, err := gzData.Write([]byte(data)) + s.Require().NoError(err) + err = gzData.Flush() + s.Require().NoError(err) + err = gzData.Close() + s.Require().NoError(err) + objSrc := &readCloserMock{Buffer: buf} + + st := new(testutils.StorageMock) + st.On("GetObject", ctx, mock.Anything).Return(objSrc, nil) + opt := &pgrestore.DataSectionSettings{ + ExitOnError: true, + DisableTriggers: true, + SuperUser: s.GetSuperUser(), + } + t := &toolkit.Table{ + Schema: schemaName, + Name: tableName, + Columns: []*toolkit.Column{ + { + Name: "id", + TypeName: "int4", + }, + { + Name: "user_id", + TypeName: "int4", + }, + { + Name: "order_amount", + TypeName: "numeric", + }, + { + Name: "raise_error", + TypeName: "text", + }, + }, + } + + tr := NewTableRestorerInsertFormat(entry, t, st, opt, new(domains.DataRestorationErrorExclusions)) + conn, err := s.GetConnectionWithUser(ctx, s.nonSuperUser, s.nonSuperUserPassword) + s.Require().NoError(err) + err = tr.Execute(ctx, conn) + s.Require().NoError(err) + }) + + s.Run("session_replication_role is replica", func() { + ctx := context.Background() + schemaName := "public" + tableName := "orders" + fileName := "test_table" + copyStmt := "COPY orders (id, user_id, order_amount, raise_error) FROM stdin;" + entry := &toc.Entry{ + Namespace: &schemaName, + Tag: &tableName, + FileName: &fileName, + CopyStmt: &copyStmt, + } + data := "8\t1\t100.50\tTest exception\n" + buf := new(bytes.Buffer) + gzData := gzip.NewWriter(buf) + _, err := gzData.Write([]byte(data)) + s.Require().NoError(err) + err = gzData.Flush() + s.Require().NoError(err) + err = gzData.Close() + s.Require().NoError(err) + objSrc := &readCloserMock{Buffer: buf} + + st := new(testutils.StorageMock) + st.On("GetObject", ctx, mock.Anything).Return(objSrc, nil) + opt := &pgrestore.DataSectionSettings{ + ExitOnError: true, + UseSessionReplicationRoleReplica: true, + SuperUser: s.GetSuperUser(), + 
} + t := &toolkit.Table{ + Schema: schemaName, + Name: tableName, + Columns: []*toolkit.Column{ + { + Name: "id", + TypeName: "int4", + }, + { + Name: "user_id", + TypeName: "int4", + }, + { + Name: "order_amount", + TypeName: "numeric", + }, + { + Name: "raise_error", + TypeName: "text", + }, + }, + } + + tr := NewTableRestorerInsertFormat(entry, t, st, opt, new(domains.DataRestorationErrorExclusions)) + + conn, err := s.GetConnectionWithUser(ctx, s.nonSuperUser, s.nonSuperUserPassword) + s.Require().NoError(err) + err = tr.Execute(ctx, conn) + s.Require().NoError(err) + }) +} diff --git a/internal/db/postgres/restorers/table_test.go b/internal/db/postgres/restorers/table_test.go new file mode 100644 index 00000000..7f1d31bb --- /dev/null +++ b/internal/db/postgres/restorers/table_test.go @@ -0,0 +1,128 @@ +package restorers + +import ( + "bytes" + "compress/gzip" + "context" + + "github.com/stretchr/testify/mock" + + "github.com/greenmaskio/greenmask/internal/db/postgres/pgrestore" + "github.com/greenmaskio/greenmask/internal/db/postgres/toc" + "github.com/greenmaskio/greenmask/internal/utils/testutils" +) + +func (s *restoresSuite) Test_TableRestorer_check_triggers_errors() { + s.Run("check triggers causes error by default", func() { + // Given + ctx := context.Background() + schemaName := "public" + tableName := "orders" + fileName := "test_table" + copyStmt := "COPY orders (id, user_id, order_amount, raise_error) FROM stdin;" + entry := &toc.Entry{ + Namespace: &schemaName, + Tag: &tableName, + FileName: &fileName, + CopyStmt: &copyStmt, + } + data := "3\t1\t100.50\tTest exception\n" + buf := new(bytes.Buffer) + gzData := gzip.NewWriter(buf) + _, err := gzData.Write([]byte(data)) + s.Require().NoError(err) + err = gzData.Flush() + s.Require().NoError(err) + err = gzData.Close() + s.Require().NoError(err) + objSrc := &readCloserMock{Buffer: buf} + + st := new(testutils.StorageMock) + st.On("GetObject", ctx, mock.Anything).Return(objSrc, nil) + opt := 
&pgrestore.DataSectionSettings{ + ExitOnError: true, + } + tr := NewTableRestorer(entry, st, opt) + + conn, err := s.GetConnectionWithUser(ctx, s.nonSuperUser, s.nonSuperUserPassword) + s.Require().NoError(err) + err = tr.Execute(ctx, conn) + s.Require().ErrorContains(err, "Test exception (code P0001)") + }) + + s.Run("disable triggers", func() { + ctx := context.Background() + schemaName := "public" + tableName := "orders" + fileName := "test_table" + copyStmt := "COPY orders (id, user_id, order_amount, raise_error) FROM stdin;" + entry := &toc.Entry{ + Namespace: &schemaName, + Tag: &tableName, + FileName: &fileName, + CopyStmt: &copyStmt, + } + data := "4\t1\t100.50\tTest exception\n" + buf := new(bytes.Buffer) + gzData := gzip.NewWriter(buf) + _, err := gzData.Write([]byte(data)) + s.Require().NoError(err) + err = gzData.Flush() + s.Require().NoError(err) + err = gzData.Close() + s.Require().NoError(err) + objSrc := &readCloserMock{Buffer: buf} + + st := new(testutils.StorageMock) + st.On("GetObject", ctx, mock.Anything).Return(objSrc, nil) + opt := &pgrestore.DataSectionSettings{ + ExitOnError: true, + DisableTriggers: true, + SuperUser: s.GetSuperUser(), + } + tr := NewTableRestorer(entry, st, opt) + + conn, err := s.GetConnectionWithUser(ctx, s.nonSuperUser, s.nonSuperUserPassword) + s.Require().NoError(err) + err = tr.Execute(ctx, conn) + s.Require().NoError(err) + }) + + s.Run("session_replication_role is replica", func() { + ctx := context.Background() + schemaName := "public" + tableName := "orders" + fileName := "test_table" + copyStmt := "COPY orders (id, user_id, order_amount, raise_error) FROM stdin;" + entry := &toc.Entry{ + Namespace: &schemaName, + Tag: &tableName, + FileName: &fileName, + CopyStmt: &copyStmt, + } + data := "5\t1\t100.50\tTest exception\n" + buf := new(bytes.Buffer) + gzData := gzip.NewWriter(buf) + _, err := gzData.Write([]byte(data)) + s.Require().NoError(err) + err = gzData.Flush() + s.Require().NoError(err) + err = gzData.Close() + 
s.Require().NoError(err) + objSrc := &readCloserMock{Buffer: buf} + + st := new(testutils.StorageMock) + st.On("GetObject", ctx, mock.Anything).Return(objSrc, nil) + opt := &pgrestore.DataSectionSettings{ + ExitOnError: true, + UseSessionReplicationRoleReplica: true, + SuperUser: s.GetSuperUser(), + } + tr := NewTableRestorer(entry, st, opt) + + conn, err := s.GetConnectionWithUser(ctx, s.nonSuperUser, s.nonSuperUserPassword) + s.Require().NoError(err) + err = tr.Execute(ctx, conn) + s.Require().NoError(err) + }) +} diff --git a/internal/utils/ioutils/gzip_reader.go b/internal/utils/ioutils/gzip_reader.go new file mode 100644 index 00000000..141430f5 --- /dev/null +++ b/internal/utils/ioutils/gzip_reader.go @@ -0,0 +1,52 @@ +package ioutils + +import ( + "fmt" + "io" + + "github.com/rs/zerolog/log" +) + +type GzipReader struct { + gz io.ReadCloser + r io.ReadCloser +} + +func NewGzipReader(r io.ReadCloser, usePgzip bool) (*GzipReader, error) { + gz, err := GetGzipReadCloser(r, usePgzip) + if err != nil { + if err := r.Close(); err != nil { + log.Warn(). + Err(err). + Msg("error closing dump file") + } + return nil, fmt.Errorf("cannot create gzip reader: %w", err) + } + + return &GzipReader{ + gz: gz, + r: r, + }, nil + +} + +func (r *GzipReader) Read(p []byte) (n int, err error) { + return r.gz.Read(p) +} + +func (r *GzipReader) Close() error { + var lastErr error + if err := r.gz.Close(); err != nil { + lastErr = fmt.Errorf("error closing gzip reader: %w", err) + log.Warn(). + Err(err). + Msg("error closing gzip reader") + } + if err := r.r.Close(); err != nil { + lastErr = fmt.Errorf("error closing dump file: %w", err) + log.Warn(). + Err(err). 
+ Msg("error closing dump file") + } + return lastErr +} diff --git a/internal/utils/ioutils/gzip_reader_test.go b/internal/utils/ioutils/gzip_reader_test.go new file mode 100644 index 00000000..fb23d289 --- /dev/null +++ b/internal/utils/ioutils/gzip_reader_test.go @@ -0,0 +1,62 @@ +package ioutils + +import ( + "bytes" + "compress/gzip" + "testing" + + "github.com/stretchr/testify/require" +) + +type readCloserMock struct { + *bytes.Buffer + closeCallCount int +} + +func (r *readCloserMock) Close() error { + r.closeCallCount++ + return nil +} + +func TestNewGzipReader_Read(t *testing.T) { + data := `20383 24ca7574-0adb-4b17-8777-93f5589dbea2 2017-12-13 13:46:49.39 +20384 d0d4a55c-7752-453e-8334-772a889fb917 2017-12-13 13:46:49.453 +20385 ac8617aa-5a2d-4bb8-a9a5-ed879a4b33cd 2017-12-13 13:46:49.5 +` + buf := new(bytes.Buffer) + gzData := gzip.NewWriter(buf) + _, err := gzData.Write([]byte(data)) + require.NoError(t, err) + err = gzData.Flush() + require.NoError(t, err) + err = gzData.Close() + require.NoError(t, err) + objSrc := &readCloserMock{Buffer: buf} + r, err := NewGzipReader(objSrc, false) + require.NoError(t, err) + readBuf := make([]byte, 1024) + n, err := r.Read(readBuf) + require.NoError(t, err) + require.Equal(t, []byte(data), readBuf[:n]) +} + +func TestNewGzipReader_Close(t *testing.T) { + data := "" + buf := new(bytes.Buffer) + gzData := gzip.NewWriter(buf) + _, err := gzData.Write([]byte(data)) + require.NoError(t, err) + err = gzData.Flush() + require.NoError(t, err) + err = gzData.Close() + require.NoError(t, err) + objSrc := &readCloserMock{Buffer: buf, closeCallCount: 0} + r, err := NewGzipReader(objSrc, false) + require.NoError(t, err) + err = r.Close() + require.NoError(t, err) + require.Equal(t, 1, objSrc.closeCallCount) + gz := r.gz.(*gzip.Reader) + _, err = gz.Read([]byte{}) + require.Error(t, err) +} diff --git a/internal/utils/ioutils/gzip_writer.go b/internal/utils/ioutils/gzip_writer.go index 5b8390f7..51dd5f54 100644 --- 
a/internal/utils/ioutils/gzip_writer.go +++ b/internal/utils/ioutils/gzip_writer.go @@ -52,17 +52,18 @@ func (gw *GzipWriter) Write(p []byte) (int, error) { // Close - closing method with gz buffer flushing func (gw *GzipWriter) Close() error { - defer gw.w.Close() - flushErr := gw.gz.Flush() - if flushErr != nil { - log.Warn().Err(flushErr).Msg("error flushing gzip buffer") + var globalErr error + if err := gw.gz.Flush(); err != nil { + globalErr = fmt.Errorf("error flushing gzip buffer: %w", err) + log.Warn().Err(err).Msg("error flushing gzip buffer") } - if closeErr := gw.gz.Close(); closeErr != nil || flushErr != nil { - err := closeErr - if flushErr != nil { - err = flushErr - } - return fmt.Errorf("error closing gzip writer: %w", err) + if err := gw.gz.Close(); err != nil { + globalErr = fmt.Errorf("error closing gzip writer: %w", err) + log.Warn().Err(err).Msg("error closing gzip writer") } - return nil + if err := gw.w.Close(); err != nil { + globalErr = fmt.Errorf("error closing dump file: %w", err) + log.Warn().Err(err).Msg("error closing dump file") + } + return globalErr } diff --git a/internal/utils/ioutils/gzip_writer_test.go b/internal/utils/ioutils/gzip_writer_test.go new file mode 100644 index 00000000..f0800dcf --- /dev/null +++ b/internal/utils/ioutils/gzip_writer_test.go @@ -0,0 +1,111 @@ +package ioutils + +import ( + "bytes" + "compress/gzip" + "errors" + "testing" + + "github.com/stretchr/testify/require" +) + +type writeCloserMock struct { + data []byte + writeCallCount int + writeCallFunc func(callCount int) error + closeCallCount int + closeCallFunc func(callCount int) error +} + +func (w *writeCloserMock) Write(p []byte) (n int, err error) { + w.writeCallCount++ + if w.writeCallFunc != nil { + return 0, w.writeCallFunc(w.writeCallCount) + } + w.data = append(w.data, p...) 
+ return len(p), nil +} + +func (w *writeCloserMock) Close() error { + w.closeCallCount++ + if w.closeCallFunc != nil { + return w.closeCallFunc(w.closeCallCount) + } + return nil +} + +func TestNewGzipWriter_Write(t *testing.T) { + data := `20383 24ca7574-0adb-4b17-8777-93f5589dbea2 2017-12-13 13:46:49.39 +20384 d0d4a55c-7752-453e-8334-772a889fb917 2017-12-13 13:46:49.453 +20385 ac8617aa-5a2d-4bb8-a9a5-ed879a4b33cd 2017-12-13 13:46:49.5 +` + testDataBuf := new(bytes.Buffer) + gzData := gzip.NewWriter(testDataBuf) + _, err := gzData.Write([]byte(data)) + require.NoError(t, err) + err = gzData.Flush() + require.NoError(t, err) + err = gzData.Close() + require.NoError(t, err) + expectedData := testDataBuf.Bytes() + + require.NoError(t, err) + err = gzData.Close() + require.NoError(t, err) + objSrc := &writeCloserMock{} + r := NewGzipWriter(objSrc, false) + require.NoError(t, err) + _, err = r.Write([]byte(data)) + require.NoError(t, err) + err = r.Close() + require.NoError(t, err) + + require.Equal(t, expectedData, objSrc.data) +} + +func TestNewGzipWriter_Close(t *testing.T) { + data := `20383 24ca7574-0adb-4b17-8777-93f5589dbea2 2017-12-13 13:46:49.39 +20384 d0d4a55c-7752-453e-8334-772a889fb917 2017-12-13 13:46:49.453 +20385 ac8617aa-5a2d-4bb8-a9a5-ed879a4b33cd 2017-12-13 13:46:49.5 +` + t.Run("Success", func(t *testing.T) { + objSrc := &writeCloserMock{} + r := NewGzipWriter(objSrc, false) + err := r.Close() + require.NoError(t, err) + require.Equal(t, 1, objSrc.closeCallCount) + }) + + t.Run("Flush Error", func(t *testing.T) { + objSrc := &writeCloserMock{ + writeCallFunc: func(c int) error { + if c == 2 { + return errors.New("storage object error") + } + return nil + }, + } + r := NewGzipWriter(objSrc, false) + _, err := r.Write([]byte(data)) + require.NoError(t, err) + + err = r.Close() + require.Error(t, err) + require.ErrorContains(t, err, "error closing gzip writer") + require.Equal(t, 1, objSrc.closeCallCount) + require.Equal(t, 2, objSrc.writeCallCount) + 
}) + + t.Run("Storage object close Error", func(t *testing.T) { + objSrc := &writeCloserMock{ + closeCallFunc: func(c int) error { + return errors.New("storage object error") + }, + } + r := NewGzipWriter(objSrc, false) + err := r.Close() + require.Error(t, err) + require.Equal(t, 1, objSrc.closeCallCount) + require.ErrorContains(t, err, "error closing dump file") + }) +} diff --git a/internal/utils/testutils/containers.go b/internal/utils/testutils/containers.go new file mode 100644 index 00000000..8e5faba5 --- /dev/null +++ b/internal/utils/testutils/containers.go @@ -0,0 +1,126 @@ +package testutils + +import ( + "context" + "fmt" + + "github.com/docker/go-connections/nat" + "github.com/jackc/pgx/v5" + _ "github.com/jackc/pgx/v5/stdlib" + "github.com/stretchr/testify/suite" + "github.com/testcontainers/testcontainers-go" + "github.com/testcontainers/testcontainers-go/wait" +) + +const ( + testContainerPort = "5432" + testContainerDatabase = "testdb" + testContainerUser = "testuser" + testContainerPassword = "testpassword" + testContainerImage = "postgres:17" + testContainerExposedPort = "5432/tcp" +) + +type PgContainerSuite struct { + suite.Suite + username string + Container testcontainers.Container + MigrationUp string + MigrationDown string +} + +func (s *PgContainerSuite) SetupSuite() { + ctx := context.Background() + s.username = testContainerUser + req := testcontainers.ContainerRequest{ + Image: testContainerImage, // Specify the PostgreSQL image + ExposedPorts: []string{testContainerExposedPort}, // Expose the PostgreSQL port + Env: map[string]string{ + "POSTGRES_USER": testContainerUser, + "POSTGRES_PASSWORD": testContainerPassword, + "POSTGRES_DB": testContainerDatabase, + }, + WaitingFor: wait.ForSQL(testContainerExposedPort, "pgx", func(host string, port nat.Port) string { + return fmt.Sprintf( + "postgres://%s:%s@%s:%s/%s?sslmode=disable", + testContainerUser, testContainerPassword, host, port.Port(), testContainerDatabase, + ) + }), + } + + var 
err error + s.Container, err = testcontainers.GenericContainer(ctx, testcontainers.GenericContainerRequest{ + ContainerRequest: req, + Started: true, + }) + s.Require().NoErrorf(err, "failed to start PostgreSQL Container") + + s.MigrateUp(ctx) +} + +func (s *PgContainerSuite) TearDownSuite() { + ctx := context.Background() + s.MigrateDown(ctx) + err := s.Container.Terminate(ctx) + s.Assert().NoErrorf(err, "failed to terminate PostgreSQL Container") +} + +func (s *PgContainerSuite) SetMigrationUp(sql string) *PgContainerSuite { + s.MigrationUp = sql + return s +} + +func (s *PgContainerSuite) SetMigrationDown(sql string) *PgContainerSuite { + s.MigrationDown = sql + return s +} + +func (s *PgContainerSuite) GetConnection(ctx context.Context) ( + conn *pgx.Conn, err error, +) { + return s.GetConnectionWithUser(ctx, testContainerUser, testContainerPassword) +} + +func (s *PgContainerSuite) GetConnectionWithUser(ctx context.Context, username, password string) ( + conn *pgx.Conn, err error, +) { + // Get the host and port for connecting to the Container + host, err := s.Container.Host(ctx) + s.Require().NoErrorf(err, "failed to get Container host") + port, err := s.Container.MappedPort(ctx, testContainerPort) + s.Require().NoErrorf(err, "failed to get Container port") + + // Create the connection string + connStr := fmt.Sprintf( + "postgres://%s:%s@%s:%s/%s?sslmode=disable", + username, password, host, port.Port(), testContainerDatabase, + ) + + return pgx.Connect(ctx, connStr) +} + +func (s *PgContainerSuite) GetSuperUser() string { + return testContainerUser +} + +func (s *PgContainerSuite) MigrateUp(ctx context.Context) { + if s.MigrationUp == "" { + return + } + conn, err := s.GetConnection(ctx) + s.Require().NoErrorf(err, "failed to connect to PostgreSQL") + defer conn.Close(ctx) + _, err = conn.Exec(ctx, s.MigrationUp) + s.Require().NoErrorf(err, "failed to run up migration") +} + +func (s *PgContainerSuite) MigrateDown(ctx context.Context) { + if s.MigrationDown 
== "" { + return + } + conn, err := s.GetConnection(ctx) + s.Require().NoErrorf(err, "failed to connect to PostgreSQL") + defer conn.Close(ctx) + _, err = conn.Exec(ctx, s.MigrationDown) + s.Require().NoErrorf(err, "failed to run down migration") +} diff --git a/internal/utils/testutils/storage.go b/internal/utils/testutils/storage.go new file mode 100644 index 00000000..f854b0fa --- /dev/null +++ b/internal/utils/testutils/storage.go @@ -0,0 +1,65 @@ +package testutils + +import ( + "context" + "io" + + "github.com/stretchr/testify/mock" + + "github.com/greenmaskio/greenmask/internal/storages" + "github.com/greenmaskio/greenmask/internal/storages/domains" +) + +type StorageMock struct { + mock.Mock +} + +func (s *StorageMock) GetCwd() string { + args := s.Called() + return args.String(0) +} + +func (s *StorageMock) Dirname() string { + args := s.Called() + return args.String(0) +} + +func (s *StorageMock) ListDir(ctx context.Context) (files []string, dirs []storages.Storager, err error) { + args := s.Called(ctx) + return args.Get(0).([]string), args.Get(1).([]storages.Storager), args.Error(2) +} + +func (s *StorageMock) GetObject(ctx context.Context, filePath string) (reader io.ReadCloser, err error) { + args := s.Called(ctx, filePath) + return args.Get(0).(io.ReadCloser), args.Error(1) +} + +func (s *StorageMock) PutObject(ctx context.Context, filePath string, body io.Reader) error { + args := s.Called(ctx, filePath, body) + return args.Error(0) +} + +func (s *StorageMock) Delete(ctx context.Context, filePaths ...string) error { + args := s.Called(ctx, filePaths) + return args.Error(0) +} + +func (s *StorageMock) DeleteAll(ctx context.Context, pathPrefix string) error { + args := s.Called(ctx, pathPrefix) + return args.Error(0) +} + +func (s *StorageMock) Exists(ctx context.Context, fileName string) (bool, error) { + args := s.Called(ctx, fileName) + return args.Bool(0), args.Error(1) +} + +func (s *StorageMock) SubStorage(subPath string, relative bool) 
storages.Storager { + args := s.Called(subPath, relative) + return args.Get(0).(storages.Storager) +} + +func (s *StorageMock) Stat(fileName string) (*domains.ObjectStat, error) { + args := s.Called(fileName) + return args.Get(0).(*domains.ObjectStat), args.Error(1) +}