Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Optimize xproto. Cache portal descriptions. #696

Merged
merged 1 commit into from
Jul 5, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions cmd/router/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -263,15 +263,15 @@ var runCmd = &cobra.Command{
if cpuProfile {
// write profile
pprof.StopCPUProfile()
spqrlog.Zero.Info().Msg("writing cpu prof")
spqrlog.Zero.Info().Str("fname", pprofCpuFile.Name()).Msg("writing cpu prof")

if err := pprofCpuFile.Close(); err != nil {
spqrlog.Zero.Error().Err(err).Msg("")
}
}
if memProfile {
// write profile
spqrlog.Zero.Info().Msg("writing mem prof")
spqrlog.Zero.Info().Str("fname", pprofMemFile.Name()).Msg("writing mem prof")

if err := pprof.WriteHeapProfile(pprofMemFile); err != nil {
spqrlog.Zero.Error().Err(err).Msg("")
Expand Down
11 changes: 7 additions & 4 deletions examples/init.sql
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
CREATE DISTRIBUTION ds1 COLUMN TYPES integer;
create distribution ds1 column types integer;

CREATE KEY RANGE krid3 FROM 21 ROUTE TO sh2 FOR DISTRIBUTION ds1;
CREATE KEY RANGE krid2 FROM 11 ROUTE TO sh1 FOR DISTRIBUTION ds1;
CREATE KEY RANGE krid1 FROM 1 ROUTE TO sh1 FOR DISTRIBUTION ds1;
alter distribution ds1 attach relation pgbench_branches distribution key bid;
alter distribution ds1 attach relation pgbench_tellers distribution key tid;
alter distribution ds1 attach relation abalance distribution key aid;
alter distribution ds1 attach relation pgbench_accounts distribution key aid;

CREATE KEY RANGE krid1 FROM 0 ROUTE TO sh1 FOR DISTRIBUTION ds1;
24 changes: 17 additions & 7 deletions router/client/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import (
"sync"

"github.com/pg-sharding/spqr/pkg/models/spqrerror"
"github.com/spaolacci/murmur3"

"github.com/jackc/pgx/v5/pgproto3"
"github.com/pg-sharding/spqr/pkg/auth"
Expand All @@ -31,6 +32,7 @@ var NotRouted = fmt.Errorf("client not routed")

type PreparedStatementMapper interface {
PreparedStatementQueryByName(name string) string
PreparedStatementQueryHashByName(name string) uint64
StorePreparedStatement(name, query string)
}

Expand Down Expand Up @@ -97,7 +99,8 @@ type PsqlClient struct {

r *route.Route

prepStmts map[string]string
prepStmts map[string]string
prepStmtsHash map[string]uint64

/* target-session-attrs */
tsa string
Expand Down Expand Up @@ -215,12 +218,13 @@ func NewPsqlClient(pgconn conn.RawConn, pt port.RouterPortType, defaultRouteBeha
session.SPQR_DISTRIBUTION: "default",
session.SPQR_DEFAULT_ROUTE_BEHAVIOUR: defaultRouteBehaviour,
},
conn: pgconn,
startupMsg: &pgproto3.StartupMessage{},
prepStmts: map[string]string{},
tsa: tsa,
defaultTsa: tsa,
rh: routehint.EmptyRouteHint{},
conn: pgconn,
startupMsg: &pgproto3.StartupMessage{},
prepStmts: map[string]string{},
prepStmtsHash: map[string]uint64{},
tsa: tsa,
defaultTsa: tsa,
rh: routehint.EmptyRouteHint{},

show_notice_messages: showNoticeMessages,
}
Expand Down Expand Up @@ -328,7 +332,9 @@ func (cl *PsqlClient) ResetAll() {
}

func (cl *PsqlClient) StorePreparedStatement(name, query string) {
hash := murmur3.Sum64([]byte(query))
cl.prepStmts[name] = query
cl.prepStmtsHash[name] = hash
}

func (cl *PsqlClient) PreparedStatementQueryByName(name string) string {
Expand All @@ -338,6 +344,10 @@ func (cl *PsqlClient) PreparedStatementQueryByName(name string) string {
return ""
}

func (cl *PsqlClient) PreparedStatementQueryHashByName(name string) uint64 {
return cl.prepStmtsHash[name]
}

func (cl *PsqlClient) ResetParam(name string) {
if val, ok := cl.startupMsg.Parameters[name]; ok {
cl.activeParamSet[name] = val
Expand Down
1 change: 1 addition & 0 deletions router/frontend/frontend_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -246,6 +246,7 @@ func TestFrontendXProto(t *testing.T) {

cl.EXPECT().StorePreparedStatement("stmtcache_1", "select 'Hello, world!'").Times(1).Return()
cl.EXPECT().PreparedStatementQueryByName("stmtcache_1").AnyTimes().Return("select 'Hello, world!'")
cl.EXPECT().PreparedStatementQueryHashByName("stmtcache_1").AnyTimes().Return(uint64(17731273590378676854))

cl.EXPECT().ServerAcquireUse().AnyTimes()
cl.EXPECT().ServerReleaseUse().AnyTimes()
Expand Down
28 changes: 28 additions & 0 deletions router/mock/client/mock_client.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 2 additions & 0 deletions router/relay/qstate.go
Original file line number Diff line number Diff line change
Expand Up @@ -374,6 +374,7 @@ func ProcQueryAdvanced(rst RelayStateMgr, query string, ph ProtoStateHandler, bi
// sql level prepares stmt pooling
if AdvancedPoolModeNeeded(rst) {
spqrlog.Zero.Debug().Msg("sql level prep statement pooling support is on")

rst.Client().StorePreparedStatement(st.Name, st.Query)
return nil
} else {
Expand All @@ -383,6 +384,7 @@ func ProcQueryAdvanced(rst RelayStateMgr, query string, ph ProtoStateHandler, bi
case parser.ParseStateExecute:
if AdvancedPoolModeNeeded(rst) {
// do nothing
// wtf? TODO: test and fix
rst.Client().PreparedStatementQueryByName(st.Name)
return nil
} else {
Expand Down
147 changes: 94 additions & 53 deletions router/relay/relay.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@ import (
"github.com/pg-sharding/spqr/router/routingstate"
"github.com/pg-sharding/spqr/router/server"
"github.com/pg-sharding/spqr/router/statistics"
"github.com/spaolacci/murmur3"
"golang.org/x/exp/slices"
)

Expand Down Expand Up @@ -106,6 +105,11 @@ func InternalBufferedMessage(q pgproto3.FrontendMessage) BufferedMessage {
}
}

type PortalDesc struct {
rd *pgproto3.RowDescription
nodata *pgproto3.NoData
}

type RelayStateImpl struct {
txStatus txstatus.TXStatus
CopyActive bool
Expand Down Expand Up @@ -137,7 +141,8 @@ type RelayStateImpl struct {

execute func() error

saveBind *pgproto3.Bind
saveBind *pgproto3.Bind
savedPortalDesc map[string]PortalDesc

// buffer of messages to process on Sync request
xBuf []pgproto3.FrontendMessage
Expand All @@ -162,6 +167,7 @@ func NewRelayState(qr qrouter.QueryRouter, client client.RouterClient, manager p
maintain_params: rcfg.MaintainParams,
pgprotoDebug: rcfg.PgprotoDebug,
execute: nil,
savedPortalDesc: map[string]PortalDesc{},
}
}

Expand Down Expand Up @@ -994,7 +1000,7 @@ var MultiShardPrepStmtDeployError = fmt.Errorf("multishard prepared statement de
// TODO : unit tests
func (rst *RelayStateImpl) DeployPrepStmt(qname string) (*shard.PreparedStatementDescriptor, pgproto3.BackendMessage, error) {
query := rst.Client().PreparedStatementQueryByName(qname)
hash := murmur3.Sum64([]byte(query))
hash := rst.Client().PreparedStatementQueryHashByName(qname)

if len(rst.Client().Server().Datashards()) != 1 {
return nil, nil, MultiShardPrepStmtDeployError
Expand Down Expand Up @@ -1043,7 +1049,10 @@ func (rst *RelayStateImpl) ProcessExtendedBuffer(cmngr poolmgr.PoolMgr) error {
switch q := msg.(type) {
case *pgproto3.Parse:

hash := murmur3.Sum64([]byte(q.Query))
rst.Client().StorePreparedStatement(q.Name, q.Query)

hash := rst.Client().PreparedStatementQueryHashByName(q.Name)

spqrlog.Zero.Debug().
Str("name", q.Name).
Str("query", q.Query).
Expand All @@ -1056,7 +1065,6 @@ func (rst *RelayStateImpl) ProcessExtendedBuffer(cmngr poolmgr.PoolMgr) error {
return err
}
}
rst.Client().StorePreparedStatement(q.Name, q.Query)

fin, err := rst.PrepareRelayStepOnAnyRoute(cmngr)
if err != nil {
Expand Down Expand Up @@ -1113,7 +1121,7 @@ func (rst *RelayStateImpl) ProcessExtendedBuffer(cmngr poolmgr.PoolMgr) error {
rst.saveBind.DestinationPortal = q.DestinationPortal

rst.lastBindName = q.PreparedStatement
hash := murmur3.Sum64([]byte(rst.lastBindQuery))
hash := rst.Client().PreparedStatementQueryHashByName(q.PreparedStatement)

rst.saveBind.PreparedStatement = fmt.Sprintf("%d", hash)
rst.saveBind.ParameterFormatCodes = q.ParameterFormatCodes
Expand Down Expand Up @@ -1184,63 +1192,96 @@ func (rst *RelayStateImpl) ProcessExtendedBuffer(cmngr poolmgr.PoolMgr) error {
Str("last-bind-name", rst.lastBindName).
Msg("Describe portal")

err := rst.PrepareRelayStepOnHintRoute(cmngr, rst.bindRoute)
if err != nil {
return err
}
if cachedPd, ok := rst.savedPortalDesc[rst.lastBindName]; ok {
if cachedPd.rd != nil {
// send to the client
if err := rst.Client().Send(cachedPd.rd); err != nil {
return err
}
}
if cachedPd.nodata != nil {
// send to the client
if err := rst.Client().Send(cachedPd.nodata); err != nil {
return err
}
}
} else {

if _, _, err := rst.DeployPrepStmt(rst.lastBindName); err != nil {
return err
}
cachedPd = PortalDesc{}

// do not send saved bind twice
if rst.saveBind == nil {
// wtf?
return fmt.Errorf("failed to describe statement, stmt was never deployed")
}
err := rst.PrepareRelayStepOnHintRoute(cmngr, rst.bindRoute)
if err != nil {
return err
}

_, _, err = rst.RelayStep(rst.saveBind, false, false)
if err != nil {
return err
}
if _, _, err := rst.DeployPrepStmt(rst.lastBindName); err != nil {
return err
}

_, _, err = rst.RelayStep(q, false, false)
if err != nil {
return err
}
// do not send saved bind twice
if rst.saveBind == nil {
// wtf?
return fmt.Errorf("failed to describe statement, stmt was never deployed")
}

_, _, err = rst.RelayStep(&pgproto3.Close{
ObjectType: 'P',
}, false, false)
if err != nil {
return err
}
_, _, err = rst.RelayStep(rst.saveBind, false, false)
if err != nil {
return err
}

_, unreplied, err := rst.RelayStep(&pgproto3.Sync{}, true, false)
if err != nil {
return err
}
_, _, err = rst.RelayStep(q, false, false)
if err != nil {
return err
}

for _, msg := range unreplied {
spqrlog.Zero.Debug().Type("msg type", msg).Msg("desctibe portal unreplied message")
// https://www.postgresql.org/docs/current/protocol-flow.html
switch qq := msg.(type) {
case *pgproto3.RowDescription:
// send to the client
if err := rst.Client().Send(qq); err != nil {
return err
}
case *pgproto3.NoData:
// send to the client
if err := rst.Client().Send(qq); err != nil {
return err
_, _, err = rst.RelayStep(&pgproto3.Close{
ObjectType: 'P',
}, false, false)
if err != nil {
return err
}

_, unreplied, err := rst.RelayStep(&pgproto3.Sync{}, true, false)
if err != nil {
return err
}

for _, msg := range unreplied {
spqrlog.Zero.Debug().Type("msg type", msg).Msg("desctibe portal unreplied message")
// https://www.postgresql.org/docs/current/protocol-flow.html
switch qq := msg.(type) {
case *pgproto3.RowDescription:

cachedPd.rd = &pgproto3.RowDescription{}

cachedPd.rd.Fields = make([]pgproto3.FieldDescription, len(qq.Fields))

for i := 0; i < len(qq.Fields); i++ {
s := make([]byte, len(qq.Fields[i].Name))
copy(s, qq.Fields[i].Name)

cachedPd.rd.Fields[i] = qq.Fields[i]
cachedPd.rd.Fields[i].Name = s
}
// send to the client
if err := rst.Client().Send(qq); err != nil {
return err
}
case *pgproto3.NoData:
cpQ := *qq
cachedPd.nodata = &cpQ
// send to the client
if err := rst.Client().Send(qq); err != nil {
return err
}
default:
// error out? panic? protoc violation?
// no, just chill
}
default:
// error out? panic? protoc violation?
// no, just chill
}
}

rst.savedPortalDesc[rst.lastBindName] = cachedPd
}
} else {
spqrlog.Zero.Debug().
Uint("client", rst.Client().ID()).
Expand Down
Loading