Skip to content

Commit

Permalink
fix: Fixed generated columns dumping and restoration #77
Browse files Browse the repository at this point in the history
* Added column attgenerated introspection for pg >= 12 version
* Excluded generated columns from COPY stmnt
  • Loading branch information
wwoytenko committed Apr 26, 2024
1 parent c21cc3b commit 36ff821
Show file tree
Hide file tree
Showing 7 changed files with 42 additions and 27 deletions.
4 changes: 2 additions & 2 deletions internal/db/postgres/context/context.go
Original file line number Diff line number Diff line change
Expand Up @@ -71,14 +71,14 @@ func NewRuntimeContext(
return nil, fmt.Errorf("cannot validate and build table config: %w", err)
}

dataSectionObjects, err := getDumpObjects(ctx, tx, opt, tables)
dataSectionObjects, err := getDumpObjects(ctx, version, tx, opt, tables)
if err != nil {
return nil, fmt.Errorf("cannot build dump object list: %w", err)
}

scoreTablesEntriesAndSort(dataSectionObjects, cfg)

schema, err := getDatabaseSchema(ctx, tx, opt)
schema, err := getDatabaseSchema(ctx, tx, opt, version)
if err != nil {
return nil, fmt.Errorf("cannot get database schema: %w", err)
}
Expand Down
4 changes: 2 additions & 2 deletions internal/db/postgres/context/pg_catalog.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ const (
// TODO: Rewrite it using gotemplate

func getDumpObjects(
ctx context.Context, tx pgx.Tx, options *pgdump.Options, config map[toolkit.Oid]*entries.Table,
ctx context.Context, version int, tx pgx.Tx, options *pgdump.Options, config map[toolkit.Oid]*entries.Table,
) ([]entries.Entry, error) {

// Building relation search query using regexp adaptation rules and pre-defined query templates
Expand Down Expand Up @@ -137,7 +137,7 @@ func getDumpObjects(
for _, obj := range dataObjects {
switch v := obj.(type) {
case *entries.Table:
columns, err := getColumnsConfig(ctx, tx, v.Oid)
columns, err := getColumnsConfig(ctx, tx, v.Oid, version)
if err != nil {
return nil, fmt.Errorf("unable to collect table columns: %w", err)
}
Expand Down
7 changes: 5 additions & 2 deletions internal/db/postgres/context/queries.go
Original file line number Diff line number Diff line change
Expand Up @@ -67,18 +67,21 @@ var (
`

// TableColumnsQuery - SQL query for getting all columns of table
TableColumnsQuery = `
TableColumnsQuery = template.Must(template.New("TableColumnsQuery").Parse(`
SELECT
a.attname as name,
a.atttypid::TEXT::INT as typeoid,
pg_catalog.format_type(a.atttypid, a.atttypmod) as typename,
a.attnotnull as notnull,
a.atttypmod as mod,
a.attnum as num
{{ if ge .Version 120000 }}
,a.attgenerated != '' as attgenerated
{{ end }}
FROM pg_catalog.pg_attribute a
WHERE a.attrelid = $1 AND a.attnum > 0 AND NOT a.attisdropped
ORDER BY a.attnum
`
`))

CustomTypesWithTypeChainQuery = `
with RECURSIVE
Expand Down
4 changes: 2 additions & 2 deletions internal/db/postgres/context/schema.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ import (
)

func getDatabaseSchema(
ctx context.Context, tx pgx.Tx, options *pgdump.Options,
ctx context.Context, tx pgx.Tx, options *pgdump.Options, version int,
) ([]*toolkit.Table, error) {
var res []*toolkit.Table
query, err := BuildSchemaIntrospectionQuery(
Expand Down Expand Up @@ -41,7 +41,7 @@ func getDatabaseSchema(

// fill columns
for _, table := range res {
columns, err := getColumnsConfig(ctx, tx, table.Oid)
columns, err := getColumnsConfig(ctx, tx, table.Oid, version)
if err != nil {
return nil, err
}
Expand Down
30 changes: 20 additions & 10 deletions internal/db/postgres/context/table.go
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ func validateAndBuildTablesConfig(
table.Constraints = constraints

// Assign columns and transformersMap if were found
columns, err := getColumnsConfig(ctx, tx, table.Oid)
columns, err := getColumnsConfig(ctx, tx, table.Oid, version)
if err != nil {
return nil, nil, err
}
Expand Down Expand Up @@ -177,9 +177,17 @@ func getTable(ctx context.Context, tx pgx.Tx, t *domains.Table) ([]*entries.Tabl
return tables, warnings, nil
}

func getColumnsConfig(ctx context.Context, tx pgx.Tx, oid toolkit.Oid) ([]*toolkit.Column, error) {
func getColumnsConfig(ctx context.Context, tx pgx.Tx, oid toolkit.Oid, version int) ([]*toolkit.Column, error) {
var res []*toolkit.Column
rows, err := tx.Query(ctx, TableColumnsQuery, oid)
buf := bytes.NewBuffer(nil)
err := TableColumnsQuery.Execute(
buf,
map[string]int{"Version": version},
)
if err != nil {
return nil, fmt.Errorf("error templating TableColumnsQuery: %w", err)
}
rows, err := tx.Query(ctx, buf.String(), oid)
if err != nil {
return nil, fmt.Errorf("unable execute tableColumnQuery: %w", err)
}
Expand All @@ -188,8 +196,14 @@ func getColumnsConfig(ctx context.Context, tx pgx.Tx, oid toolkit.Oid) ([]*toolk
idx := 0
for rows.Next() {
column := toolkit.Column{Idx: idx}
if err = rows.Scan(&column.Name, &column.TypeOid, &column.TypeName,
&column.NotNull, &column.Length, &column.Num); err != nil {
if version >= 120000 {
err = rows.Scan(&column.Name, &column.TypeOid, &column.TypeName,
&column.NotNull, &column.Length, &column.Num, &column.IsGenerated)
} else {
err = rows.Scan(&column.Name, &column.TypeOid, &column.TypeName,
&column.NotNull, &column.Length, &column.Num)
}
if err != nil {
return nil, fmt.Errorf("cannot scan tableColumnQuery: %w", err)
}
res = append(res, &column)
Expand Down Expand Up @@ -289,11 +303,7 @@ func getTableConstraints(ctx context.Context, tx pgx.Tx, tableOid toolkit.Oid, v
buf := bytes.NewBuffer(nil)
err = TablePrimaryKeyReferencesConstraintsQuery.Execute(
buf,
struct {
Version int
}{
Version: version,
},
map[string]int{"Version": version},
)
if err != nil {
return nil, fmt.Errorf("error templating TablePrimaryKeyReferencesConstraintsQuery: %w", err)
Expand Down
5 changes: 3 additions & 2 deletions internal/db/postgres/entries/table.go
Original file line number Diff line number Diff line change
Expand Up @@ -78,10 +78,11 @@ func (t *Table) Entry() (*toc.Entry, error) {
columns := make([]string, 0, len(t.Columns))

for _, column := range t.Columns {
columns = append(columns, fmt.Sprintf(`"%s"`, column.Name))
if !column.IsGenerated {
columns = append(columns, fmt.Sprintf(`"%s"`, column.Name))
}
}

//var query = `COPY "%s"."%s" (%s) FROM stdin WITH (FORMAT CSV, NULL '\N');`
var query = `COPY "%s"."%s" (%s) FROM stdin`
var schemaName, tableName string
if t.LoadViaPartitionRoot && t.RootPtSchema != "" && t.RootPtName != "" {
Expand Down
15 changes: 8 additions & 7 deletions pkg/toolkit/column.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,12 @@
package toolkit

type Column struct {
Name string `json:"name"`
TypeName string `json:"type_name"`
TypeOid Oid `json:"type_oid"`
Num AttNum `json:"num"`
NotNull bool `json:"not_null"`
Length int `json:"length"`
Idx int `json:"idx"`
Name string `json:"name"`
TypeName string `json:"type_name"`
TypeOid Oid `json:"type_oid"`
Num AttNum `json:"num"`
NotNull bool `json:"not_null"`
Length int `json:"length"`
Idx int `json:"idx"`
IsGenerated bool `json:"is_generated"`
}

0 comments on commit 36ff821

Please sign in to comment.