Skip to content

Commit

Permalink
pipeline queue for client requests
Browse files Browse the repository at this point in the history
  • Loading branch information
zenkovev committed Jan 6, 2025
1 parent 76593f3 commit de3f868
Show file tree
Hide file tree
Showing 2 changed files with 461 additions and 83 deletions.
220 changes: 162 additions & 58 deletions pgconn/pgconn.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package pgconn

import (
"container/list"
"context"
"crypto/md5"
"crypto/tls"
Expand Down Expand Up @@ -1408,9 +1409,8 @@ func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (Co

// MultiResultReader is a reader for a command that could return multiple results such as Exec or ExecBatch.
type MultiResultReader struct {
pgConn *PgConn
ctx context.Context
pipeline *Pipeline
pgConn *PgConn
ctx context.Context

rr *ResultReader

Expand Down Expand Up @@ -1443,12 +1443,8 @@ func (mrr *MultiResultReader) receiveMessage() (pgproto3.BackendMessage, error)
switch msg := msg.(type) {
case *pgproto3.ReadyForQuery:
mrr.closed = true
if mrr.pipeline != nil {
mrr.pipeline.expectedReadyForQueryCount--
} else {
mrr.pgConn.contextWatcher.Unwatch()
mrr.pgConn.unlock()
}
mrr.pgConn.contextWatcher.Unwatch()
mrr.pgConn.unlock()
case *pgproto3.ErrorResponse:
mrr.err = ErrorResponseToPgError(msg)
}
Expand Down Expand Up @@ -1672,7 +1668,11 @@ func (rr *ResultReader) receiveMessage() (msg pgproto3.BackendMessage, err error
case *pgproto3.EmptyQueryResponse:
rr.concludeCommand(CommandTag{}, nil)
case *pgproto3.ErrorResponse:
rr.concludeCommand(CommandTag{}, ErrorResponseToPgError(msg))
pgErr := ErrorResponseToPgError(msg)
if rr.pipeline != nil {
rr.pipeline.state.HandleError(pgErr)
}
rr.concludeCommand(CommandTag{}, pgErr)
}

return msg, nil
Expand Down Expand Up @@ -1999,9 +1999,7 @@ type Pipeline struct {
conn *PgConn
ctx context.Context

expectedReadyForQueryCount int
pendingSync bool

state pipelineState
err error
closed bool
}
Expand All @@ -2012,6 +2010,122 @@ type PipelineSync struct{}
// CloseComplete is returned by GetResults when a CloseComplete message is received.
type CloseComplete struct{}

type pipelineRequestType int

const (
PIPELINE_NIL pipelineRequestType = iota
PIPELINE_PREPARE
PIPELINE_QUERY_PARAMS
PIPELINE_QUERY_PREPARED
PIPELINE_DEALLOCATE
PIPELINE_SYNC_REQUEST
PIPELINE_FLUSH_REQUEST
)

type pipelineRequestEvent struct {
RequestType pipelineRequestType
WasSentToServer bool
BeforeFlushOrSync bool
}

type pipelineState struct {
requestEventQueue list.List
lastRequestType pipelineRequestType
pgErr *PgError
expectedReadyForQueryCount int
}

func (s *pipelineState) Init() {
s.requestEventQueue.Init()
s.lastRequestType = PIPELINE_NIL
}

func (s *pipelineState) RegisterSendingToServer() {
for elem := s.requestEventQueue.Back(); elem != nil; elem = elem.Prev() {
val := elem.Value.(pipelineRequestEvent)
if val.WasSentToServer {
return
}
val.WasSentToServer = true
elem.Value = val
}
}

func (s *pipelineState) registerFlushingBufferOnServer() {
for elem := s.requestEventQueue.Back(); elem != nil; elem = elem.Prev() {
val := elem.Value.(pipelineRequestEvent)
if val.BeforeFlushOrSync {
return
}
val.BeforeFlushOrSync = true
elem.Value = val
}
}

func (s *pipelineState) PushBackRequestType(req pipelineRequestType) {
if req == PIPELINE_NIL {
return
}

if req != PIPELINE_FLUSH_REQUEST {
s.requestEventQueue.PushBack(pipelineRequestEvent{RequestType: req})
}
if req == PIPELINE_FLUSH_REQUEST || req == PIPELINE_SYNC_REQUEST {
s.registerFlushingBufferOnServer()
}
s.lastRequestType = req

if req == PIPELINE_SYNC_REQUEST {
s.expectedReadyForQueryCount++
}
}

func (s *pipelineState) ExtractFrontRequestType() pipelineRequestType {
for {
elem := s.requestEventQueue.Front()
if elem == nil {
return PIPELINE_NIL
}
val := elem.Value.(pipelineRequestEvent)
if !(val.WasSentToServer && val.BeforeFlushOrSync) {
return PIPELINE_NIL
}

s.requestEventQueue.Remove(elem)
if val.RequestType == PIPELINE_SYNC_REQUEST {
s.pgErr = nil
}
if s.pgErr == nil {
return val.RequestType
}
}
}

func (s *pipelineState) HandleError(err *PgError) {
s.pgErr = err
}

func (s *pipelineState) HandleReadyForQuery() {
s.expectedReadyForQueryCount--
}

func (s *pipelineState) PendingSync() bool {
var notPendingSync bool

if elem := s.requestEventQueue.Back(); elem != nil {
val := elem.Value.(pipelineRequestEvent)
notPendingSync = (val.RequestType == PIPELINE_SYNC_REQUEST) && val.WasSentToServer
} else {
notPendingSync = (s.lastRequestType == PIPELINE_SYNC_REQUEST) || (s.lastRequestType == PIPELINE_NIL)
}

return !notPendingSync
}

func (s *pipelineState) ExpectedReadyForQuery() int {
return s.expectedReadyForQueryCount
}

// StartPipeline switches the connection to pipeline mode and returns a *Pipeline. In pipeline mode requests can be sent
// to the server without waiting for a response. Close must be called on the returned *Pipeline to return the connection
// to normal mode. While in pipeline mode, no methods that communicate with the server may be called except
Expand All @@ -2020,16 +2134,21 @@ type CloseComplete struct{}
// Prefer ExecBatch when only sending one group of queries at once.
func (pgConn *PgConn) StartPipeline(ctx context.Context) *Pipeline {
if err := pgConn.lock(); err != nil {
return &Pipeline{
pipeline := &Pipeline{
closed: true,
err: err,
}
pipeline.state.Init()

return pipeline
}

pgConn.pipeline = Pipeline{
conn: pgConn,
ctx: ctx,
}
pgConn.pipeline.state.Init()

pipeline := &pgConn.pipeline

if ctx != context.Background() {
Expand All @@ -2052,45 +2171,45 @@ func (p *Pipeline) SendPrepare(name, sql string, paramOIDs []uint32) {
if p.closed {
return
}
p.pendingSync = true

p.conn.frontend.SendParse(&pgproto3.Parse{Name: name, Query: sql, ParameterOIDs: paramOIDs})
p.conn.frontend.SendDescribe(&pgproto3.Describe{ObjectType: 'S', Name: name})
p.state.PushBackRequestType(PIPELINE_PREPARE)
}

// SendDeallocate deallocates a prepared statement.
func (p *Pipeline) SendDeallocate(name string) {
if p.closed {
return
}
p.pendingSync = true

p.conn.frontend.SendClose(&pgproto3.Close{ObjectType: 'S', Name: name})
p.state.PushBackRequestType(PIPELINE_DEALLOCATE)
}

// SendQueryParams is the pipeline version of *PgConn.QueryParams.
func (p *Pipeline) SendQueryParams(sql string, paramValues [][]byte, paramOIDs []uint32, paramFormats []int16, resultFormats []int16) {
if p.closed {
return
}
p.pendingSync = true

p.conn.frontend.SendParse(&pgproto3.Parse{Query: sql, ParameterOIDs: paramOIDs})
p.conn.frontend.SendBind(&pgproto3.Bind{ParameterFormatCodes: paramFormats, Parameters: paramValues, ResultFormatCodes: resultFormats})
p.conn.frontend.SendDescribe(&pgproto3.Describe{ObjectType: 'P'})
p.conn.frontend.SendExecute(&pgproto3.Execute{})
p.state.PushBackRequestType(PIPELINE_QUERY_PARAMS)
}

// SendQueryPrepared is the pipeline version of *PgConn.QueryPrepared.
func (p *Pipeline) SendQueryPrepared(stmtName string, paramValues [][]byte, paramFormats []int16, resultFormats []int16) {
if p.closed {
return
}
p.pendingSync = true

p.conn.frontend.SendBind(&pgproto3.Bind{PreparedStatement: stmtName, ParameterFormatCodes: paramFormats, Parameters: paramValues, ResultFormatCodes: resultFormats})
p.conn.frontend.SendDescribe(&pgproto3.Describe{ObjectType: 'P'})
p.conn.frontend.SendExecute(&pgproto3.Execute{})
p.state.PushBackRequestType(PIPELINE_QUERY_PREPARED)
}

// SendFlushRequest sends a request for the server to flush its output buffer.
Expand All @@ -2104,9 +2223,24 @@ func (p *Pipeline) SendFlushRequest() {
if p.closed {
return
}
p.pendingSync = true

p.conn.frontend.Send(&pgproto3.Flush{})
p.state.PushBackRequestType(PIPELINE_FLUSH_REQUEST)
}

// SendPipelineSync marks a synchronization point in a pipeline by sending a sync message
// without flushing the send buffer. This serves as the delimiter of an implicit
// transaction and an error recovery point.
//
// Note that the request is not itself flushed to the server automatically; use Flush if
// necessary. This copies the behavior of libpq PQsendPipelineSync.
func (p *Pipeline) SendPipelineSync() {
if p.closed {
return
}

p.conn.frontend.SendSync(&pgproto3.Sync{})
p.state.PushBackRequestType(PIPELINE_SYNC_REQUEST)
}

// Flush flushes the queued requests without establishing a synchronization point.
Expand All @@ -2131,28 +2265,14 @@ func (p *Pipeline) Flush() error {
return err
}

p.state.RegisterSendingToServer()
return nil
}

// Sync establishes a synchronization point and flushes the queued requests.
func (p *Pipeline) Sync() error {
if p.closed {
if p.err != nil {
return p.err
}
return errors.New("pipeline closed")
}

p.conn.frontend.SendSync(&pgproto3.Sync{})
err := p.Flush()
if err != nil {
return err
}

p.pendingSync = false
p.expectedReadyForQueryCount++

return nil
p.SendPipelineSync()
return p.Flush()
}

// GetResults gets the next results. If results are present, results may be a *ResultReader, *StatementDescription, or
Expand All @@ -2166,30 +2286,13 @@ func (p *Pipeline) GetResults() (results any, err error) {
return nil, errors.New("pipeline closed")
}

if p.expectedReadyForQueryCount == 0 {
if p.state.ExtractFrontRequestType() == PIPELINE_NIL {
return nil, nil
}

return p.getResults()
}

// GetResultsNotCheckSync gets the next results. If results are present, results may be a *ResultReader, *StatementDescription,
// or *PipelineSync. If an ErrorResponse is received from the server, results will be nil and err will be a *PgError.
//
// This method should be used only if the request was sent to the server via methods SendFlushRequest and Flush,
// without using Sync. In this case, you need to identify on your own when all results are received and
// there is no need to call the method anymore.
func (p *Pipeline) GetResultsNotCheckSync() (results any, err error) {
if p.closed {
if p.err != nil {
return nil, p.err
}
return nil, errors.New("pipeline closed")
}

return p.getResults()
}

func (p *Pipeline) getResults() (results any, err error) {
for {
msg, err := p.conn.receiveMessage()
Expand Down Expand Up @@ -2228,13 +2331,13 @@ func (p *Pipeline) getResults() (results any, err error) {
case *pgproto3.CloseComplete:
return &CloseComplete{}, nil
case *pgproto3.ReadyForQuery:
p.expectedReadyForQueryCount--
p.state.HandleReadyForQuery()
return &PipelineSync{}, nil
case *pgproto3.ErrorResponse:
pgErr := ErrorResponseToPgError(msg)
p.state.HandleError(pgErr)
return nil, pgErr
}

}
}

Expand Down Expand Up @@ -2264,6 +2367,7 @@ func (p *Pipeline) getResultsPrepare() (*StatementDescription, error) {
// These should never happen here. But don't take chances that could lead to a deadlock.
case *pgproto3.ErrorResponse:
pgErr := ErrorResponseToPgError(msg)
p.state.HandleError(pgErr)
return nil, pgErr
case *pgproto3.CommandComplete:
p.conn.asyncClose()
Expand All @@ -2283,7 +2387,7 @@ func (p *Pipeline) Close() error {

p.closed = true

if p.pendingSync {
if p.state.PendingSync() {
p.conn.asyncClose()
p.err = errors.New("pipeline has unsynced requests")
p.conn.contextWatcher.Unwatch()
Expand All @@ -2292,7 +2396,7 @@ func (p *Pipeline) Close() error {
return p.err
}

for p.expectedReadyForQueryCount > 0 {
for p.state.ExpectedReadyForQuery() > 0 {
_, err := p.getResults()
if err != nil {
p.err = err
Expand Down
Loading

0 comments on commit de3f868

Please sign in to comment.