Skip to content

Commit

Permalink
Change strict reading PG to only return rows when valid
Browse files Browse the repository at this point in the history
This is necessary for the error and retry logic to work in the strict read proxy
  • Loading branch information
josephschorr committed Jan 30, 2025
1 parent 542053f commit cdc7f5a
Show file tree
Hide file tree
Showing 5 changed files with 153 additions and 27 deletions.
65 changes: 65 additions & 0 deletions internal/datastore/postgres/postgres_shared_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ import (
"github.com/authzed/spicedb/internal/datastore/common"
pgcommon "github.com/authzed/spicedb/internal/datastore/postgres/common"
pgversion "github.com/authzed/spicedb/internal/datastore/postgres/version"
"github.com/authzed/spicedb/internal/datastore/proxy"
"github.com/authzed/spicedb/internal/testfixtures"
testdatastore "github.com/authzed/spicedb/internal/testserver/datastore"
"github.com/authzed/spicedb/pkg/datastore"
Expand Down Expand Up @@ -240,6 +241,16 @@ func testPostgresDatastore(t *testing.T, config postgresTestConfig) {
MigrationPhase(config.migrationPhase),
))

t.Run("TestStrictReadModeFallback", createReplicaDatastoreTest(
b,
StrictReadModeFallbackTest,
RevisionQuantization(0),
GCWindow(1000*time.Second),
GCInterval(veryLargeGCInterval),
WatchBufferLength(50),
MigrationPhase(config.migrationPhase),
))

t.Run("TestLocking", createMultiDatastoreTest(
b,
LockingTest,
Expand Down Expand Up @@ -1568,6 +1579,60 @@ func LockingTest(t *testing.T, ds datastore.Datastore, ds2 datastore.Datastore)
require.NoError(t, err)
}

func StrictReadModeFallbackTest(t *testing.T, primaryDS datastore.Datastore, unwrappedReplicaDS datastore.Datastore) {
require := require.New(t)

ctx, cancel := context.WithCancel(context.Background())
defer cancel()

// Write some relationships.
_, err := primaryDS.ReadWriteTx(ctx, func(ctx context.Context, rwt datastore.ReadWriteTransaction) error {
rtu := tuple.Touch(tuple.MustParse("resource:123#reader@user:456"))
return rwt.WriteRelationships(ctx, []tuple.RelationshipUpdate{rtu})
})
require.NoError(err)

// Get the HEAD revision.
lowestRevision, err := primaryDS.HeadRevision(ctx)
require.NoError(err)

// Wrap the replica DS.
replicaDS, err := proxy.NewStrictReplicatedDatastore(primaryDS, unwrappedReplicaDS.(datastore.StrictReadDatastore))
require.NoError(err)

// Perform a read at the head revision, which should succeed.
reader := replicaDS.SnapshotReader(lowestRevision)
it, err := reader.QueryRelationships(ctx, datastore.RelationshipsFilter{
OptionalResourceType: "resource",
})
require.NoError(err)

found, err := datastore.IteratorToSlice(it)
require.NoError(err)
require.NotEmpty(found)

// Perform a read at a manually constructed revision beyond head, which should fallback to the primary.
badRev := postgresRevision{
snapshot: pgSnapshot{
// NOTE: the struct defines this value as uint64, but the underlying
// revision is defined as an int64, so we run into an overflow issue
// if we try and use a big uint64.
xmin: 123456789,
xmax: 123456789,
},
}

limit := uint64(50)
it, err = replicaDS.SnapshotReader(badRev).QueryRelationships(ctx, datastore.RelationshipsFilter{
OptionalResourceType: "resource",
}, options.WithLimit(&limit))
require.NoError(err)

found2, err := datastore.IteratorToSlice(it)
require.NoError(err)
require.Equal(len(found), len(found2))
}

func StrictReadModeTest(t *testing.T, primaryDS datastore.Datastore, replicaDS datastore.Datastore) {
require := require.New(t)

Expand Down
28 changes: 22 additions & 6 deletions internal/datastore/postgres/strictreader.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import (

"github.com/authzed/spicedb/internal/datastore/common"
pgxcommon "github.com/authzed/spicedb/internal/datastore/postgres/common"
"github.com/authzed/spicedb/pkg/spiceerrors"
)

const pgInvalidArgument = "22023"
Expand All @@ -26,15 +27,15 @@ type strictReaderQueryFuncs struct {
func (srqf strictReaderQueryFuncs) ExecFunc(ctx context.Context, tagFunc func(ctx context.Context, tag pgconn.CommandTag, err error) error, sql string, args ...any) error {
// NOTE: it is *required* for the pgx.QueryExecModeSimpleProtocol to be added as pgx will otherwise wrap
// the query as a prepared statement, which does *not* support running more than a single statement at a time.
return srqf.rewriteError(srqf.wrapped.ExecFunc(ctx, tagFunc, srqf.addAssertToSQL(sql), append([]interface{}{pgx.QueryExecModeSimpleProtocol}, args...)...))
return srqf.rewriteError(srqf.wrapped.ExecFunc(ctx, tagFunc, srqf.addAssertToSelectSQL(sql), append([]interface{}{pgx.QueryExecModeSimpleProtocol}, args...)...))
}

func (srqf strictReaderQueryFuncs) QueryFunc(ctx context.Context, rowsFunc func(ctx context.Context, rows pgx.Rows) error, sql string, args ...any) error {
return srqf.rewriteError(srqf.wrapped.QueryFunc(ctx, rowsFunc, srqf.addAssertToSQL(sql), append([]interface{}{pgx.QueryExecModeSimpleProtocol}, args...)...))
return srqf.rewriteError(srqf.wrapped.QueryFunc(ctx, rowsFunc, srqf.addAssertToSelectSQL(sql), append([]interface{}{pgx.QueryExecModeSimpleProtocol}, args...)...))
}

func (srqf strictReaderQueryFuncs) QueryRowFunc(ctx context.Context, rowFunc func(ctx context.Context, row pgx.Row) error, sql string, args ...any) error {
return srqf.rewriteError(srqf.wrapped.QueryRowFunc(ctx, rowFunc, srqf.addAssertToSQL(sql), append([]interface{}{pgx.QueryExecModeSimpleProtocol}, args...)...))
return srqf.rewriteError(srqf.wrapped.QueryRowFunc(ctx, rowFunc, srqf.addAssertToSelectSQL(sql), append([]interface{}{pgx.QueryExecModeSimpleProtocol}, args...)...))
}

func (srqf strictReaderQueryFuncs) rewriteError(err error) error {
Expand All @@ -53,13 +54,28 @@ func (srqf strictReaderQueryFuncs) rewriteError(err error) error {
return err
}

func (srqf strictReaderQueryFuncs) addAssertToSQL(sql string) string {
func (srqf strictReaderQueryFuncs) addAssertToSelectSQL(sql string) string {
spiceerrors.DebugAssert(func() bool {
return strings.HasPrefix(sql, "SELECT ")
}, "strictReaderQueryFuncs can only wrap SELECT queries")

// The assertion checks that the transaction is not reading from the future or from a
// transaction that is still in-progress on the replica. If the transaction is not yet
// available on the replica at all, the call to `pg_xact_status` will fail with an invalid
// argument error and a message indicating that the xid "is in the future". If the transaction
// does exist, but has not yet been committed (or aborted), the call to `pg_xact_status` will return
// "in progress". rewriteError will catch these cases and return a RevisionUnavailableError.
assertion := fmt.Sprintf(`; do $$ begin assert (select pg_xact_status(%d::text::xid8) != 'in progress'), 'replica missing revision';end;$$`, srqf.revision.snapshot.xmin-1)
return sql + assertion
//
// We run the query *first* (but filtered) as PGX will not be able to read rows if the assertion
// is run first. However, we do not want to return any rows if the assertion will fail, so we add it
// as a filter to the select as well.
wrapped := fmt.Sprintf(`
SELECT * FROM (%s) WHERE pg_xact_status(%d::text::xid8) != 'in progress';
DO $$
BEGIN
ASSERT (select pg_xact_status(%d::text::xid8) != 'in progress'), 'replica missing revision';
END
$$;
`, sql, srqf.revision.snapshot.xmin-1, srqf.revision.snapshot.xmin-1)
return wrapped
}
46 changes: 31 additions & 15 deletions internal/datastore/proxy/checkingreplicated_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,8 @@ import (
)

func TestCheckingReplicatedReaderFallsbackToPrimaryOnCheckRevisionFailure(t *testing.T) {
primary := fakeDatastore{true, revisionparsing.MustParseRevisionForTest("2")}
replica := fakeDatastore{false, revisionparsing.MustParseRevisionForTest("1")}
primary := fakeDatastore{"primary", revisionparsing.MustParseRevisionForTest("2")}
replica := fakeDatastore{"replica", revisionparsing.MustParseRevisionForTest("1")}

replicated, err := NewCheckingReplicatedDatastore(primary, replica)
require.NoError(t, err)
Expand All @@ -40,8 +40,8 @@ func TestCheckingReplicatedReaderFallsbackToPrimaryOnCheckRevisionFailure(t *tes
}

func TestCheckingReplicatedReaderFallsbackToPrimaryOnRevisionNotAvailableError(t *testing.T) {
primary := fakeDatastore{true, revisionparsing.MustParseRevisionForTest("2")}
replica := fakeDatastore{false, revisionparsing.MustParseRevisionForTest("1")}
primary := fakeDatastore{"primary", revisionparsing.MustParseRevisionForTest("2")}
replica := fakeDatastore{"replica", revisionparsing.MustParseRevisionForTest("1")}

replicated, err := NewCheckingReplicatedDatastore(primary, replica)
require.NoError(t, err)
Expand All @@ -55,8 +55,8 @@ func TestCheckingReplicatedReaderFallsbackToPrimaryOnRevisionNotAvailableError(t
func TestReplicatedReaderReturnsExpectedError(t *testing.T) {
for _, requireCheck := range []bool{true, false} {
t.Run(fmt.Sprintf("requireCheck=%v", requireCheck), func(t *testing.T) {
primary := fakeDatastore{true, revisionparsing.MustParseRevisionForTest("2")}
replica := fakeDatastore{false, revisionparsing.MustParseRevisionForTest("1")}
primary := fakeDatastore{"primary", revisionparsing.MustParseRevisionForTest("2")}
replica := fakeDatastore{"replica", revisionparsing.MustParseRevisionForTest("1")}

var ds datastore.Datastore
if requireCheck {
Expand All @@ -79,14 +79,14 @@ func TestReplicatedReaderReturnsExpectedError(t *testing.T) {
}

type fakeDatastore struct {
isPrimary bool
revision datastore.Revision
state string
revision datastore.Revision
}

func (f fakeDatastore) SnapshotReader(revision datastore.Revision) datastore.Reader {
return fakeSnapshotReader{
revision: revision,
isPrimary: f.isPrimary,
revision: revision,
state: f.state,
}
}

Expand Down Expand Up @@ -143,12 +143,12 @@ func (f fakeDatastore) IsStrictReadModeEnabled() bool {
}

type fakeSnapshotReader struct {
revision datastore.Revision
isPrimary bool
revision datastore.Revision
state string
}

func (fsr fakeSnapshotReader) LookupNamespacesWithNames(_ context.Context, nsNames []string) ([]datastore.RevisionedDefinition[*corev1.NamespaceDefinition], error) {
if fsr.isPrimary {
if fsr.state == "primary" {
return []datastore.RevisionedDefinition[*corev1.NamespaceDefinition]{
{
Definition: &corev1.NamespaceDefinition{
Expand All @@ -159,7 +159,7 @@ func (fsr fakeSnapshotReader) LookupNamespacesWithNames(_ context.Context, nsNam
}, nil
}

if !fsr.isPrimary && fsr.revision.GreaterThan(revisionparsing.MustParseRevisionForTest("2")) {
if fsr.revision.GreaterThan(revisionparsing.MustParseRevisionForTest("2")) {
return nil, common.NewRevisionUnavailableError(fmt.Errorf("revision not available"))
}

Expand Down Expand Up @@ -208,7 +208,7 @@ func (fakeSnapshotReader) LookupCounters(ctx context.Context) ([]datastore.Relat

func fakeIterator(fsr fakeSnapshotReader) datastore.RelationshipIterator {
return func(yield func(tuple.Relationship, error) bool) {
if fsr.isPrimary {
if fsr.state == "primary" {
if !yield(tuple.MustParse("resource:123#viewer@user:tom"), nil) {
return
}
Expand All @@ -218,6 +218,22 @@ func fakeIterator(fsr fakeSnapshotReader) datastore.RelationshipIterator {
return
}

if fsr.state == "replica-with-normal-error" {
if !yield(tuple.MustParse("resource:123#viewer@user:tom"), nil) {
return
}
if !yield(tuple.MustParse("resource:456#viewer@user:tom"), nil) {
return
}
if !yield(tuple.Relationship{}, fmt.Errorf("raising an expected error")) {
return
}
if !yield(tuple.MustParse("resource:789#viewer@user:tom"), nil) {
return
}
return
}

if fsr.revision.GreaterThan(revisionparsing.MustParseRevisionForTest("2")) {
yield(tuple.Relationship{}, common.NewRevisionUnavailableError(fmt.Errorf("revision not available")))
return
Expand Down
7 changes: 4 additions & 3 deletions internal/datastore/proxy/strictreplicated.go
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,7 @@ func queryRelationships[F any, O any](
return nil, err
}

isFirstResult := true
beforeResultsYielded := true
requiresFallback := false
return func(yield func(tuple.Relationship, error) bool) {
replicaLoop:
Expand All @@ -143,7 +143,7 @@ func queryRelationships[F any, O any](
// If the RevisionUnavailableError is returned on the first result, we should fallback
// to the primary.
if errors.As(err, &common.RevisionUnavailableError{}) {
if !isFirstResult {
if !beforeResultsYielded {
yield(tuple.Relationship{}, spiceerrors.MustBugf("RevisionUnavailableError should only be returned on the first result"))
return
}
Expand All @@ -154,9 +154,10 @@ func queryRelationships[F any, O any](
if !yield(tuple.Relationship{}, err) {
return
}
continue
}

isFirstResult = false
beforeResultsYielded = false
if !yield(result, nil) {
return
}
Expand Down
34 changes: 31 additions & 3 deletions internal/datastore/proxy/strictreplicated_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ import (
)

func TestStrictReplicatedReaderWithOnlyPrimary(t *testing.T) {
primary := fakeDatastore{true, revisionparsing.MustParseRevisionForTest("2")}
primary := fakeDatastore{"primary", revisionparsing.MustParseRevisionForTest("2")}

replicated, err := NewStrictReplicatedDatastore(primary)
require.NoError(t, err)
Expand All @@ -20,8 +20,8 @@ func TestStrictReplicatedReaderWithOnlyPrimary(t *testing.T) {
}

func TestStrictReplicatedQueryFallsbackToPrimaryOnRevisionNotAvailableError(t *testing.T) {
primary := fakeDatastore{true, revisionparsing.MustParseRevisionForTest("2")}
replica := fakeDatastore{false, revisionparsing.MustParseRevisionForTest("1")}
primary := fakeDatastore{"primary", revisionparsing.MustParseRevisionForTest("2")}
replica := fakeDatastore{"replica", revisionparsing.MustParseRevisionForTest("1")}

replicated, err := NewStrictReplicatedDatastore(primary, replica)
require.NoError(t, err)
Expand Down Expand Up @@ -87,3 +87,31 @@ func TestStrictReplicatedQueryFallsbackToPrimaryOnRevisionNotAvailableError(t *t
require.NoError(t, err)
require.Equal(t, 2, len(revfound))
}

func TestStrictReplicatedQueryNonFallbackError(t *testing.T) {
primary := fakeDatastore{"primary", revisionparsing.MustParseRevisionForTest("2")}
replica := fakeDatastore{"replica-with-normal-error", revisionparsing.MustParseRevisionForTest("1")}

replicated, err := NewStrictReplicatedDatastore(primary, replica)
require.NoError(t, err)

// Query the replicated, which should return the error.
reader := replicated.SnapshotReader(revisionparsing.MustParseRevisionForTest("3"))
iter, err := reader.QueryRelationships(context.Background(), datastore.RelationshipsFilter{
OptionalResourceType: "resource",
})
require.NoError(t, err)

relsCollected := 0
var errFound error
for _, err := range iter {
if err != nil {
errFound = err
} else {
relsCollected++
}
}

require.Equal(t, 3, relsCollected)
require.ErrorContains(t, errFound, "raising an expected error")
}

0 comments on commit cdc7f5a

Please sign in to comment.