Skip to content

Commit

Permalink
Refresh collation versions if the underlying locale version changes.
Browse files Browse the repository at this point in the history
  • Loading branch information
davissp14 committed Jun 23, 2024
1 parent 17ed339 commit 58eddf9
Show file tree
Hide file tree
Showing 6 changed files with 231 additions and 146 deletions.
18 changes: 0 additions & 18 deletions bin/refresh-collation

This file was deleted.

109 changes: 0 additions & 109 deletions internal/flypg/admin/admin.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,7 @@ package admin

import (
"context"
"database/sql"
"fmt"
"log"
"regexp"
"strconv"
"strings"

Expand Down Expand Up @@ -431,109 +428,3 @@ func ValidatePGSettings(ctx context.Context, conn *pgx.Conn, requested map[strin

return nil
}

func fixCollationMismatch(ctx context.Context, db *sql.DB) error {
query := `
SELECT pg_describe_object(refclassid, refobjid, refobjsubid) AS "Collation",
pg_describe_object(classid, objid, objsubid) AS "Object"
FROM pg_depend d JOIN pg_collation c
ON refclassid = 'pg_collation'::regclass AND refobjid = c.oid
WHERE c.collversion <> pg_collation_actual_version(c.oid)
ORDER BY 1, 2;`

rows, err := db.Query(query)
if err != nil {
return fmt.Errorf("failed to query collation mismatches: %v", err)
}
defer rows.Close()

var collation, object string
for rows.Next() {
if err := rows.Scan(&collation, &object); err != nil {
return fmt.Errorf("failed to scan row: %v", err)
}

fixObject(db, object)
}

if err := rows.Err(); err != nil {
return fmt.Errorf("failed to iterate over rows: %v", err)
}

return nil
}

func fixObject(db *sql.DB, object string) {
fmt.Printf("Fixing object: %s\n", object)

switch {
case regexp.MustCompile(`index`).MatchString(object):
// reindex(db, object)
case regexp.MustCompile(`column`).MatchString(object):
// alterColumn(db, object)
case regexp.MustCompile(`constraint`).MatchString(object):
// dropAndRecreateConstraint(db, object)
case regexp.MustCompile(`materialized view`).MatchString(object):
// refreshMaterializedView(db, object)
case regexp.MustCompile(`function`).MatchString(object):
// recreateFunction(db, object)
case regexp.MustCompile(`view`).MatchString(object):
// recreateView(db, object)
case regexp.MustCompile(`trigger`).MatchString(object):
// recreateTrigger(db, object)
default:
log.Printf("Unknown object type: %s", object)
}
}

const refreshCollationSQL = `
DO $$
DECLARE
r RECORD;
BEGIN
FOR r IN (SELECT datname FROM pg_database WHERE datallowconn = true)
LOOP
BEGIN
EXECUTE 'ALTER DATABASE ' || quote_ident(r.datname) || ' REFRESH COLLATION VERSION;';
EXCEPTION
WHEN OTHERS THEN
RAISE NOTICE 'Failed to refresh collation for database: % - %', r.datname, SQLERRM;
END;
END LOOP;
END $$;`

// RefreshCollationVersion will refresh the collation version for all databases.
func RefreshCollationVersion(ctx context.Context, conn *pgx.Conn) error {
_, err := conn.Exec(ctx, refreshCollationSQL)
return err
}

const identifyCollationObjectsSQL = `
SELECT pg_describe_object(refclassid, refobjid, refobjsubid) AS "Collation",
pg_describe_object(classid, objid, objsubid) AS "Object"
FROM pg_depend d JOIN pg_collation c
ON refclassid = 'pg_collation'::regclass AND refobjid = c.oid
WHERE c.collversion <> pg_collation_actual_version(c.oid)
ORDER BY 1, 2;`

const reIndexSQL = `
DO $$
DECLARE
r RECORD;
BEGIN
FOR r IN (SELECT n.nspname, i.relname
FROM pg_index x
JOIN pg_class c ON c.oid = x.indrelid
JOIN pg_namespace n ON n.oid = c.relnamespace
JOIN pg_class i ON i.oid = x.indexrelid
JOIN pg_attribute a ON a.attrelid = c.oid AND a.attnum = ANY(x.indkey)
JOIN pg_collation col ON col.oid = a.attcollation
WHERE col.collname = 'en_US.utf8') LOOP
EXECUTE 'REINDEX INDEX ' || quote_ident(r.nspname) || '.' || quote_ident(r.relname);
END LOOP;
END $$;`

func ReIndex(ctx context.Context, conn *pgx.Conn) error {
_, err := conn.Exec(ctx, reIndexSQL)
return err
}
159 changes: 159 additions & 0 deletions internal/flypg/collations.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,159 @@
package flypg

import (
"context"
"crypto/sha256"
"database/sql"
"encoding/hex"
"fmt"
"log"
"os"

"github.com/fly-apps/postgres-flex/internal/utils"
"github.com/jackc/pgx/v5"
)

const collationVersionFile = "/data/.collationVersion"

func calculateLocaleVersionHash() (string, error) {
output, err := utils.RunCommand("locale --version", "postgres")
if err != nil {
return "", fmt.Errorf("failed to read locale version: %w", err)
}

hash := sha256.Sum256(output)
return hex.EncodeToString(hash[:]), nil
}

func writeCollationVersionFile(versionHash string) error {
// Write the collation lock file.
if err := os.WriteFile(collationVersionFile, []byte(versionHash), 0600); err != nil {
return fmt.Errorf("failed to write collation version file: %w", err)
}

return nil
}

func collationHashChanged(versionHash string) (bool, error) {
// Short-circuit if there's no collation file.
_, err := os.Stat(collationVersionFile)
switch {
case os.IsNotExist(err):
return true, nil
case err != nil:
return false, fmt.Errorf("failed to stat collation lock file: %w", err)
}

// Read the collation version file.
oldVersionHash, err := os.ReadFile(collationVersionFile)
if err != nil {
return false, fmt.Errorf("failed to read collation lock file: %w", err)
}

// Compare the version hashes.
return versionHash != string(oldVersionHash), nil
}

const identifyImpactedCollationObjectsSQL = `
SELECT pg_describe_object(refclassid, refobjid, refobjsubid) AS "Collation",
pg_describe_object(classid, objid, objsubid) AS "Object"
FROM pg_depend d JOIN pg_collation c
ON refclassid = 'pg_collation'::regclass AND refobjid = c.oid
WHERE c.collversion <> pg_collation_actual_version(c.oid)
ORDER BY 1, 2;
`

type collationObject struct {
collation string
object string
}

func impactedCollationObjects(ctx context.Context, conn *pgx.Conn) ([]collationObject, error) {
rows, err := conn.Query(ctx, identifyImpactedCollationObjectsSQL)
if err != nil {
return nil, fmt.Errorf("failed to query impacted objects: %v", err)
}
defer rows.Close()

var objects []collationObject

var collation, object string
for rows.Next() {
if err := rows.Scan(&collation, &object); err != nil {
return nil, fmt.Errorf("failed to scan row: %v", err)
}
objects = append(objects, collationObject{collation: collation, object: object})
}

if err := rows.Err(); err != nil {
return nil, fmt.Errorf("failed to iterate over rows: %v", err)
}

return objects, nil
}

func refreshCollations(ctx context.Context, dbConn *pgx.Conn, dbName string) error {
if dbName != "template1" {
if err := refreshDatabaseCollations(ctx, dbConn, dbName); err != nil {
return err
}
}

return refreshDatabase(ctx, dbConn, dbName)
}

func refreshDatabaseCollations(ctx context.Context, dbConn *pgx.Conn, dbName string) error {
collations, err := fetchCollations(ctx, dbConn)
if err != nil {
return fmt.Errorf("failed to fetch collations: %w", err)
}

for _, collation := range collations {
if err := refreshCollation(ctx, dbConn, collation); err != nil {
log.Printf("[WARN] failed to refresh collation version in db %s: %v\n", dbName, err)
}
}

return nil
}

func refreshCollation(ctx context.Context, dbConn *pgx.Conn, collation string) error {
query := fmt.Sprintf("ALTER COLLATION pg_catalog.\"%s\" REFRESH VERSION;", collation)
_, err := dbConn.Exec(ctx, query)
return err
}

func refreshDatabase(ctx context.Context, dbConn *pgx.Conn, dbName string) error {
query := fmt.Sprintf("ALTER DATABASE %s REFRESH COLLATION VERSION;", dbName)
_, err := dbConn.Exec(ctx, query)
if err != nil {
return fmt.Errorf("failed to refresh database collation version: %w", err)
}
return nil
}

func fetchCollations(ctx context.Context, dbConn *pgx.Conn) ([]string, error) {
query := "SELECT DISTINCT datcollate FROM pg_database WHERE datcollate != 'C'"
rows, err := dbConn.Query(ctx, query)
if err != nil {
return nil, fmt.Errorf("failed to fetch collations: %w", err)
}
defer rows.Close()

var collations []string
for rows.Next() {
var collation sql.NullString
if err := rows.Scan(&collation); err != nil {
return nil, fmt.Errorf("failed to scan collation row: %w", err)
}
if collation.Valid {
collations = append(collations, collation.String)
}
}

if rows.Err() != nil {
return nil, rows.Err()
}

return collations, nil
}
66 changes: 56 additions & 10 deletions internal/flypg/node.go
Original file line number Diff line number Diff line change
Expand Up @@ -306,9 +306,9 @@ func (n *Node) PostInit(ctx context.Context) error {
}
}

// Refresh collation for all databases.
if err := refreshCollation(ctx, conn); err != nil {
log.Printf("failed to refresh collation: %s", err)
// This is a safety check to ensure collation integrity is maintained.
if err := n.evaluateCollationIntegrity(ctx, conn); err != nil {
log.Printf("[WARN] Problem occurred while evaluating collation integrity: %s", err)
}

case StandbyRoleName:
Expand Down Expand Up @@ -478,30 +478,76 @@ func setDirOwnership() error {
return nil
}

func (n *Node) fixCollationMismatch(ctx context.Context, conn *pgx.Conn) error {
func (n *Node) evaluateCollationIntegrity(ctx context.Context, conn *pgx.Conn) error {
// Calculate the current collation version hash.
versionHash, err := calculateLocaleVersionHash()
if err != nil {
return fmt.Errorf("failed to calculate collation sum: %w", err)
}

// Check to see if the collation version has changed.
changed, err := collationHashChanged(versionHash)
if err != nil {
return fmt.Errorf("failed to check collation version file: %s", err)
}

if !changed {
log.Printf("[INFO] Collation version has not changed.\n")
return nil
}

fmt.Printf("[INFO] Collation version has changed or has not been evaluated. Evaluating collation integrity.\n")

dbs, err := admin.ListDatabases(ctx, conn)
if err != nil {
return fmt.Errorf("failed to list databases: %s", err)
}

// Add the template1 database to the list of databases to refresh.
dbs = append(dbs, admin.DbInfo{Name: "template1"})

collationIssues := 0

for _, db := range dbs {
// Establish a connection to the database.
dbConn, err := n.NewLocalConnection(ctx, db.Name, n.SUCredentials)
if err != nil {
return fmt.Errorf("failed to establish connection to local node: %s", err)
return fmt.Errorf("failed to establish connection to database %s: %s", db.Name, err)
}
defer func() { _ = dbConn.Close(ctx) }()

if err := admin.RefreshCollationVersion(ctx, dbConn); err != nil {
return fmt.Errorf("failed to refresh collation: %s", err)
log.Printf("[INFO] Refreshing collations for database %s\n", db.Name)

if err := refreshCollations(ctx, dbConn, db.Name); err != nil {
return fmt.Errorf("failed to refresh collations for db %s: %s", db.Name, err)
}

// TODO - Consider logging a link to documentation on how to resolve collation issues not resolved by the refresh process.

// The collation refresh process should resolve "most" issues, but there are cases that may require
// re-indexing or other manual intervention. In the event any objects are found we will log a warning.
colObjects, err := impactedCollationObjects(ctx, dbConn)
if err != nil {
return fmt.Errorf("failed to fetch impacted collation objects: %s", err)
}

if err := admin.ReIndex(ctx, dbConn); err != nil {
return fmt.Errorf("failed to reindex database: %s", err)
for _, obj := range colObjects {
log.Printf("[WARN] Collation mismatch detected - Database %s, Collation: %s, Object: %s\n", db.Name, obj.collation, obj.object)
collationIssues++
}
}

// Don't set the version file if there are collation issues.
// This will force the system to re-evaluate the collation integrity on the next boot and ensure
// issues continue to be logged.
if collationIssues > 0 {
return nil
}

// No collation issues found, we can safely update the version file.
// This will prevent the system from re-evaluating the collation integrity on every boot.
if err := writeCollationVersionFile(versionHash); err != nil {
return fmt.Errorf("failed to write collation version file: %s", err)
}

return nil
}
Loading

0 comments on commit 58eddf9

Please sign in to comment.