diff --git a/cmd/greenmask/cmd/restore/restore.go b/cmd/greenmask/cmd/restore/restore.go index 039725ba..1082e02e 100644 --- a/cmd/greenmask/cmd/restore/restore.go +++ b/cmd/greenmask/cmd/restore/restore.go @@ -171,6 +171,10 @@ func init() { "pgzip", "", false, "use pgzip decompression instead of gzip", ) + Cmd.Flags().Int64P( + "batch-size", "", 0, + "the number of rows to insert in a single batch during the COPY command (0 - all rows will be inserted in a single batch_", + ) // Connection options: Cmd.Flags().StringP("host", "h", "/var/run/postgres", "database server host or socket directory") @@ -185,7 +189,7 @@ func init() { "disable-triggers", "enable-row-security", "if-exists", "no-comments", "no-data-for-failed-tables", "no-security-labels", "no-subscriptions", "no-table-access-method", "no-tablespaces", "section", "strict-names", "use-set-session-authorization", "inserts", "on-conflict-do-nothing", "restore-in-order", - "pgzip", + "pgzip", "batch-size", "host", "port", "username", } { diff --git a/docs/commands/restore.md b/docs/commands/restore.md index 6db4a469..1a034bfe 100644 --- a/docs/commands/restore.md +++ b/docs/commands/restore.md @@ -18,6 +18,7 @@ allowing you to configure the restoration process as needed. Mostly it supports the same flags as the `pg_restore` utility, with some extra flags for Greenmask-specific features. ```text title="Supported flags" + --batch-size int the number of rows to insert in a single batch during the COPY command (0 - all rows will be inserted in a single batch_ -c, --clean clean (drop) database objects before recreating -C, --create create the target database -a, --data-only restore only the data, no schema @@ -112,5 +113,23 @@ If your database has cyclic dependencies you will be notified about it but the r By default, Greenmask uses gzip decompression to restore data. In mist cases it is quite slow and does not utilize all available resources and is a bootleneck for IO operations. To speed up the restoration process, you can use the `--pgzip` flag to use pgzip decompression instead of gzip. This method splits the data into blocks, which are -decompressed in parallel, making it ideal for handling large volumes of data. The output remains a standard gzip file. +decompressed in parallel, making it ideal for handling large volumes of data. +```shell title="example with pgzip decompression" +greenmask --config=config.yml restore latest --pgzip +``` + +### Restore data batching + +The COPY command returns the error only on transaction commit. This means that if you have a large dump and an error +occurs, you will have to wait until the end of the transaction to see the error message. To avoid this, you can use the +`--batch-size` flag to specify the number of rows to insert in a single batch during the COPY command. If an error occurs +during the batch insertion, the error message will be displayed immediately. The data will be committed **only if all +batches are inserted successfully**. + +In the example below, the batch size is set to 1000 rows. This means that 1000 rows will be inserted in a single batch, +so you will be notified of any errors immediately after each batch is inserted. + +```shell title="example with batch size" +greenmask --config=config.yml restore latest --batch-size 1000 +``` diff --git a/internal/db/postgres/cmd/restore.go b/internal/db/postgres/cmd/restore.go index fc35946e..d95c62d6 100644 --- a/internal/db/postgres/cmd/restore.go +++ b/internal/db/postgres/cmd/restore.go @@ -646,7 +646,9 @@ func (r *Restore) taskPusher(ctx context.Context, tasks chan restorers.RestoreTa r.cfg.ErrorExclusions, r.restoreOpt.Pgzip, ) } else { - task = restorers.NewTableRestorer(entry, r.st, r.restoreOpt.ExitOnError, r.restoreOpt.Pgzip) + task = restorers.NewTableRestorer( + entry, r.st, r.restoreOpt.ExitOnError, r.restoreOpt.Pgzip, r.restoreOpt.BatchSize, + ) } case toc.SequenceSetDesc: diff --git a/internal/db/postgres/cmd/validate.go b/internal/db/postgres/cmd/validate.go index d756570f..0c93e0e8 100644 --- a/internal/db/postgres/cmd/validate.go +++ b/internal/db/postgres/cmd/validate.go @@ -223,7 +223,7 @@ func (v *Validate) readRecords(r *bufio.Reader, t *entries.Table) (original, tra originalRow = pgcopy.NewRow(len(t.Columns)) transformedRow = pgcopy.NewRow(len(t.Columns)) - originalLine, err = reader.ReadLine(r) + originalLine, err = reader.ReadLine(r, nil) if err != nil { if errors.Is(err, io.EOF) { return nil, nil, err @@ -235,7 +235,7 @@ func (v *Validate) readRecords(r *bufio.Reader, t *entries.Table) (original, tra return nil, nil, io.EOF } - transformedLine, err = reader.ReadLine(r) + transformedLine, err = reader.ReadLine(r, nil) if err != nil { return nil, nil, fmt.Errorf("unable to read line: %w", err) } diff --git a/internal/db/postgres/pgrestore/pgrestore.go b/internal/db/postgres/pgrestore/pgrestore.go index c52895ea..4396054d 100644 --- a/internal/db/postgres/pgrestore/pgrestore.go +++ b/internal/db/postgres/pgrestore/pgrestore.go @@ -97,7 +97,8 @@ type Options struct { Inserts bool `mapstructure:"inserts"` RestoreInOrder bool `mapstructure:"restore-in-order"` // Use pgzip decompression instead of gzip - Pgzip bool `mapstructure:"pgzip"` + Pgzip bool `mapstructure:"pgzip"` + BatchSize int64 `mapstructure:"batch-size"` // Connection options: Host string `mapstructure:"host"` diff --git a/internal/db/postgres/restorers/table.go b/internal/db/postgres/restorers/table.go index 326f75f6..be30718c 100644 --- a/internal/db/postgres/restorers/table.go +++ b/internal/db/postgres/restorers/table.go @@ -15,6 +15,7 @@ package restorers import ( + "bufio" "context" "errors" "fmt" @@ -22,6 +23,7 @@ import ( "github.com/greenmaskio/greenmask/internal/utils/ioutils" "github.com/greenmaskio/greenmask/internal/utils/pgerrors" + "github.com/greenmaskio/greenmask/internal/utils/reader" "github.com/jackc/pgx/v5" "github.com/jackc/pgx/v5/pgproto3" "github.com/rs/zerolog/log" @@ -37,14 +39,18 @@ type TableRestorer struct { St storages.Storager exitOnError bool usePgzip bool + batchSize int64 } -func NewTableRestorer(entry *toc.Entry, st storages.Storager, exitOnError bool, usePgzip bool) *TableRestorer { +func NewTableRestorer( + entry *toc.Entry, st storages.Storager, exitOnError bool, usePgzip bool, batchSize int64, +) *TableRestorer { return &TableRestorer{ Entry: entry, St: st, exitOnError: exitOnError, usePgzip: usePgzip, + batchSize: batchSize, } } @@ -117,8 +123,14 @@ func (td *TableRestorer) restoreCopy(ctx context.Context, f *pgproto3.Frontend, return fmt.Errorf("error initializing pgcopy: %w", err) } - if err := td.streamCopyData(ctx, f, r); err != nil { - return fmt.Errorf("error streaming pgcopy data: %w", err) + if td.batchSize > 0 { + if err := td.streamCopyDataByBatch(ctx, f, r); err != nil { + return fmt.Errorf("error streaming pgcopy data: %w", err) + } + } else { + if err := td.streamCopyData(ctx, f, r); err != nil { + return fmt.Errorf("error streaming pgcopy data: %w", err) + } } if err := td.postStreamingHandle(ctx, f); err != nil { @@ -134,8 +146,7 @@ func (td *TableRestorer) initCopy(ctx context.Context, f *pgproto3.Frontend) err } // Prepare for streaming the pgcopy data - process := true - for process { + for { select { case <-ctx.Done(): return ctx.Err() @@ -148,35 +159,67 @@ func (td *TableRestorer) initCopy(ctx context.Context, f *pgproto3.Frontend) err } switch v := msg.(type) { case *pgproto3.CopyInResponse: - process = false + return nil case *pgproto3.ErrorResponse: return fmt.Errorf("error from postgres connection: %w", pgerrors.NewPgError(v)) default: return fmt.Errorf("unknown message %+v", v) } } - return nil } -func (td *TableRestorer) streamCopyData(ctx context.Context, f *pgproto3.Frontend, r io.Reader) error { - // Streaming pgcopy data from table dump - +// streamCopyDataByBatch - stream pgcopy data from table dump in batches. It handles errors only on the end each batch +// If the batch size is reached it completes the batch and starts a new one. If an error occurs during the batch it +// stops immediately and returns the error +func (td *TableRestorer) streamCopyDataByBatch(ctx context.Context, f *pgproto3.Frontend, r io.Reader) (err error) { + bi := bufio.NewReader(r) buf := make([]byte, DefaultBufferSize) + var lineNum int64 for { - var n int + buf, err = reader.ReadLine(bi, buf) + if err != nil { + if errors.Is(err, io.EOF) { + break + } + return fmt.Errorf("error readimg from table dump: %w", err) + } + if isTerminationSeq(buf) { + break + } + lineNum++ + buf = append(buf, '\n') + + err = sendMessage(f, &pgproto3.CopyData{Data: buf}) + if err != nil { + return fmt.Errorf("error sending CopyData message: %w", err) + } + + if lineNum%td.batchSize == 0 { + if err = td.completeBatch(ctx, f); err != nil { + return fmt.Errorf("error completing batch: %w", err) + } + } + select { case <-ctx.Done(): return ctx.Err() default: } + } + return nil +} + +// streamCopyData - stream pgcopy data from table dump in classic way. It handles errors only on the end of the stream +func (td *TableRestorer) streamCopyData(ctx context.Context, f *pgproto3.Frontend, r io.Reader) error { + // Streaming pgcopy data from table dump + + buf := make([]byte, DefaultBufferSize) + for { + var n int n, err := r.Read(buf) if err != nil { if errors.Is(err, io.EOF) { - completionErr := sendMessage(f, &pgproto3.CopyDone{}) - if completionErr != nil { - return fmt.Errorf("error sending CopyDone message: %w", err) - } break } return fmt.Errorf("error readimg from table dump: %w", err) @@ -186,12 +229,32 @@ func (td *TableRestorer) streamCopyData(ctx context.Context, f *pgproto3.Fronten if err != nil { return fmt.Errorf("error sending DopyData message: %w", err) } + select { + case <-ctx.Done(): + return ctx.Err() + default: + } + } + return nil +} + +// completeBatch - complete batch of pgcopy data and initiate new one +func (td *TableRestorer) completeBatch(ctx context.Context, f *pgproto3.Frontend) error { + if err := td.postStreamingHandle(ctx, f); err != nil { + return err + } + if err := td.initCopy(ctx, f); err != nil { + return err } return nil } func (td *TableRestorer) postStreamingHandle(ctx context.Context, f *pgproto3.Frontend) error { // Perform post streaming handling + err := sendMessage(f, &pgproto3.CopyDone{}) + if err != nil { + return fmt.Errorf("error sending CopyDone message: %w", err) + } var mainErr error for { select { diff --git a/internal/db/postgres/restorers/table_insert_format.go b/internal/db/postgres/restorers/table_insert_format.go index 7d3359cf..5700493e 100644 --- a/internal/db/postgres/restorers/table_insert_format.go +++ b/internal/db/postgres/restorers/table_insert_format.go @@ -140,7 +140,7 @@ func (td *TableRestorerInsertFormat) streamInsertData(ctx context.Context, conn default: } - line, err := reader.ReadLine(buf) + line, err := reader.ReadLine(buf, nil) if err != nil { if errors.Is(err, io.EOF) { break diff --git a/internal/db/postgres/transformers/custom/dynamic_definition.go b/internal/db/postgres/transformers/custom/dynamic_definition.go index a562b810..16c488f1 100644 --- a/internal/db/postgres/transformers/custom/dynamic_definition.go +++ b/internal/db/postgres/transformers/custom/dynamic_definition.go @@ -85,7 +85,7 @@ func GetDynamicTransformerDefinition(ctx context.Context, executable string, arg buf := bufio.NewReader(bytes.NewBuffer(stdoutData)) for { - line, err := reader.ReadLine(buf) + line, err := reader.ReadLine(buf, nil) if err != nil { break } @@ -102,7 +102,7 @@ func GetDynamicTransformerDefinition(ctx context.Context, executable string, arg buf := bufio.NewReader(bytes.NewBuffer(stderrData)) for { - line, err := reader.ReadLine(buf) + line, err := reader.ReadLine(buf, nil) if err != nil { break } diff --git a/internal/db/postgres/transformers/utils/cmd_transformer_base.go b/internal/db/postgres/transformers/utils/cmd_transformer_base.go index 5599e6b3..5207481e 100644 --- a/internal/db/postgres/transformers/utils/cmd_transformer_base.go +++ b/internal/db/postgres/transformers/utils/cmd_transformer_base.go @@ -315,7 +315,7 @@ func (ctb *CmdTransformerBase) init() error { func (ctb *CmdTransformerBase) ReceiveStderrLine(ctx context.Context) (line []byte, err error) { go func() { - line, err = reader.ReadLine(ctb.StderrReader) + line, err = reader.ReadLine(ctb.StderrReader, nil) ctb.receiveChan <- struct{}{} }() select { @@ -333,7 +333,7 @@ func (ctb *CmdTransformerBase) ReceiveStderrLine(ctx context.Context) (line []by func (ctb *CmdTransformerBase) ReceiveStdoutLine(ctx context.Context) (line []byte, err error) { go func() { - line, err = reader.ReadLine(ctb.StdoutReader) + line, err = reader.ReadLine(ctb.StdoutReader, nil) ctb.receiveChan <- struct{}{} }() select { diff --git a/internal/utils/cmd_runner/cmd_runner.go b/internal/utils/cmd_runner/cmd_runner.go index ceec1863..3a7021f7 100644 --- a/internal/utils/cmd_runner/cmd_runner.go +++ b/internal/utils/cmd_runner/cmd_runner.go @@ -53,7 +53,7 @@ func Run(ctx context.Context, logger *zerolog.Logger, name string, args ...strin return gtx.Err() default: } - line, err := reader.ReadLine(lineScanner) + line, err := reader.ReadLine(lineScanner, nil) if err != nil { if errors.Is(err, io.EOF) { return nil @@ -73,7 +73,7 @@ func Run(ctx context.Context, logger *zerolog.Logger, name string, args ...strin return gtx.Err() default: } - line, err := reader.ReadLine(lineScanner) + line, err := reader.ReadLine(lineScanner, nil) if err != nil { if errors.Is(err, io.EOF) { return nil diff --git a/internal/utils/reader/reader.go b/internal/utils/reader/reader.go index 59164ee7..d28d24de 100644 --- a/internal/utils/reader/reader.go +++ b/internal/utils/reader/reader.go @@ -5,18 +5,18 @@ import ( "fmt" ) -func ReadLine(r *bufio.Reader) ([]byte, error) { - var res []byte +func ReadLine(r *bufio.Reader, buf []byte) ([]byte, error) { + buf = buf[:0] for { var line []byte line, isPrefix, err := r.ReadLine() if err != nil { return nil, fmt.Errorf("unable to read line: %w", err) } - res = append(res, line...) + buf = append(buf, line...) if !isPrefix { break } } - return res, nil + return buf, nil }