From deec9fc26a8965c8aa8d525b70bb6f2f6f5fdca6 Mon Sep 17 00:00:00 2001 From: Yichen Wang <18348405+Aiee@users.noreply.github.com> Date: Fri, 10 Feb 2023 18:23:04 +0800 Subject: [PATCH] Retry to get session when session is invalid in `SessionPool` (#252) * Update readme for v3.4 release (#250) * Fix format * Update README * Retry to get session when session is invalid * Simplify client test * Add comments * Add comments --- client_test.go | 20 +++++----------- configs.go | 32 +++++++++++++------------ session.go | 4 ++-- session_pool.go | 56 +++++++++++++++++++++++++++++++++++++++++++- session_pool_test.go | 44 ++++++++++++++++++++++++++++++++++ 5 files changed, 124 insertions(+), 32 deletions(-) diff --git a/client_test.go b/client_test.go index 1b2546c1..3367062f 100644 --- a/client_test.go +++ b/client_test.go @@ -1202,15 +1202,11 @@ func TestReconnect(t *testing.T) { defer pool.Close() // Create session - var sessionList []*Session - - for i := 0; i < 3; i++ { - session, err := pool.GetSession(username, password) - if err != nil { - t.Errorf("fail to create a new session from connection pool, %s", err.Error()) - } - sessionList = append(sessionList, session) + session, err := pool.GetSession(username, password) + if err != nil { + t.Errorf("fail to create a new session from connection pool, %s", err.Error()) } + defer session.Release() // Send query to server periodically for i := 0; i < timeoutConfig.MaxConnPoolSize; i++ { @@ -1221,7 +1217,7 @@ func TestReconnect(t *testing.T) { if i == 7 { stopContainer(t, "nebula-docker-compose_graphd1_1") } - _, err := sessionList[0].Execute("SHOW HOSTS;") + _, err := session.Execute("SHOW HOSTS;") fmt.Println("Sending query...") if err != nil { @@ -1230,7 +1226,7 @@ func TestReconnect(t *testing.T) { } } - resp, err := sessionList[0].Execute("SHOW HOSTS;") + resp, err := session.Execute("SHOW HOSTS;") if err != nil { t.Fatalf(err.Error()) return @@ -1240,10 +1236,6 @@ func TestReconnect(t *testing.T) { startContainer(t, 
"nebula-docker-compose_graphd_1") startContainer(t, "nebula-docker-compose_graphd1_1") - for i := 0; i < len(sessionList); i++ { - sessionList[i].Release() - } - // Wait for graphd to be up time.Sleep(5 * time.Second) } diff --git a/configs.go b/configs.go index 6263e89e..e6e16ba0 100644 --- a/configs.go +++ b/configs.go @@ -109,12 +109,13 @@ func openAndReadFile(path string) ([]byte, error) { // SessionPoolConf is the configs of a session pool // Note that the space name is bound to the session pool for its lifetime type SessionPoolConf struct { - username string // username for authentication - password string // password for authentication - serviceAddrs []HostAddress // service addresses for session pool - hostIndex int // index of the host in ServiceAddrs that the next new session will connect to - spaceName string // The space name that all sessions in the pool are bound to - sslConfig *tls.Config // Optional SSL config for the connection + username string // username for authentication + password string // password for authentication + serviceAddrs []HostAddress // service addresses for session pool + hostIndex int // index of the host in ServiceAddrs that the next new session will connect to + spaceName string // The space name that all sessions in the pool are bound to + sslConfig *tls.Config // Optional SSL config for the connection + retryGetSessionTimes int // The max times to retry get new session when executing a query // Basic pool configs // Socket timeout and Socket connection timeout, unit: seconds @@ -138,15 +139,16 @@ func NewSessionPoolConf( spaceName string, opts ...SessionPoolConfOption) (*SessionPoolConf, error) { // Set default values for basic pool configs newPoolConf := SessionPoolConf{ - username: username, - password: password, - serviceAddrs: serviceAddrs, - spaceName: spaceName, - timeOut: 0 * time.Millisecond, - idleTime: 0 * time.Millisecond, - maxSize: 30, - minSize: 1, - hostIndex: 0, + username: username, + password: password, + 
serviceAddrs: serviceAddrs, + spaceName: spaceName, + retryGetSessionTimes: 1, + timeOut: 0 * time.Millisecond, + idleTime: 0 * time.Millisecond, + maxSize: 30, + minSize: 1, + hostIndex: 0, } // Iterate the given options and apply them to the config. diff --git a/session.go b/session.go index 3b611b2f..8cf23e43 100644 --- a/session.go +++ b/session.go @@ -198,7 +198,7 @@ func (session *Session) ExecuteJsonWithParameter(stmt string, params map[string] } func (session *Session) reConnect() error { - newconnection, err := session.connPool.getIdleConn() + newConnection, err := session.connPool.getIdleConn() if err != nil { err = fmt.Errorf(err.Error()) return err @@ -206,7 +206,7 @@ func (session *Session) reConnect() error { // Release connection to pool session.connPool.release(session.connection) - session.connection = newconnection + session.connection = newConnection return nil } diff --git a/session_pool.go b/session_pool.go index 5883d77e..c1ba19c1 100644 --- a/session_pool.go +++ b/session_pool.go @@ -12,10 +12,12 @@ import ( "container/list" "crypto/tls" "fmt" + "strconv" "sync" "time" "github.com/vesoft-inc/nebula-go/v3/nebula" + "github.com/vesoft-inc/nebula-go/v3/nebula/graph" ) // SessionPool is a pool that manages sessions internally. 
@@ -119,7 +121,15 @@ func (pool *SessionPool) ExecuteWithParameter(stmt string, params map[string]int } // Execute the query - resp, err := session.connection.executeWithParameter(session.sessionID, stmt, paramsMap) + execFunc := func(s *Session) (*graph.ExecutionResponse, error) { + resp, err := s.connection.executeWithParameter(s.sessionID, stmt, paramsMap) + if err != nil { + return nil, err + } + return resp, nil + } + + resp, err := pool.executeWithRetry(session, execFunc, pool.conf.retryGetSessionTimes) if err != nil { return nil, err } @@ -385,6 +395,50 @@ func (pool *SessionPool) getIdleSession() (*Session, error) { " session pool and the total session count has reached the limit") } +// executeWithRetry runs f on session; if the response reports ErrorCode_E_SESSION_INVALID, it makes up to retry attempts to create a session that answers Ping and then runs f once on that new session. +func (pool *SessionPool) executeWithRetry( + session *Session, + f func(*Session) (*graph.ExecutionResponse, error), + retry int) (*graph.ExecutionResponse, error) { + pool.rwLock.Lock() + defer pool.rwLock.Unlock() + + resp, err := f(session) + if err != nil { + pool.removeSessionFromList(&pool.activeSessions, session) + return nil, err + } + + if resp.ErrorCode == nebula.ErrorCode_SUCCEEDED { + return resp, nil + } else if ErrorCode(resp.ErrorCode) != ErrorCode_E_SESSION_INVALID { // only retry when the session is invalid + return resp, err + } + + // remove the invalid session from the active list regardless of whether the retry succeeds + defer pool.removeSessionFromList(&pool.activeSessions, session) + // The session is invalid: try to create a replacement session and run f on it + for i := 0; i < retry; i++ { + pool.log.Info("retry to get sessions") + newSession, err := pool.newSession() + if err != nil { + return nil, err + } + + pingErr := newSession.Ping() + if pingErr != nil { + pool.log.Error("failed to ping the session, error: " + pingErr.Error()) + continue + } + pool.log.Info("retry to get sessions successfully") + pool.addSessionToList(&pool.activeSessions, newSession) + + return f(newSession) + } + 
pool.log.Error(fmt.Sprintf("failed to get session after " + strconv.Itoa(retry) + " retries")) + return nil, fmt.Errorf("failed to get session after %d retries", retry) +} + // startCleaner starts sessionCleaner if idleTime > 0. func (pool *SessionPool) startCleaner() { if pool.conf.idleTime > 0 && pool.cleanerChan == nil { diff --git a/session_pool_test.go b/session_pool_test.go index cd4c5006..5b36e1ce 100644 --- a/session_pool_test.go +++ b/session_pool_test.go @@ -319,6 +319,50 @@ func TestIdleSessionCleaner(t *testing.T) { sessionPool.conf.minSize, sessionPool.GetTotalSessionCount()) } +func TestRetryGetSession(t *testing.T) { + err := prepareSpace("client_test") + if err != nil { + t.Fatal(err) + } + defer dropSpace("client_test") + + hostAddress := HostAddress{Host: address, Port: port} + config, err := NewSessionPoolConf( + "root", + "nebula", + []HostAddress{hostAddress}, + "client_test") + if err != nil { + t.Errorf("failed to create session pool config, %s", err.Error()) + } + config.minSize = 2 + config.maxSize = 2 + config.retryGetSessionTimes = 1 + + // create session pool + sessionPool, err := NewSessionPool(*config, DefaultLogger{}) + if err != nil { + t.Fatal(err) + } + defer sessionPool.Close() + + // kill all sessions in the cluster + resultSet, err := sessionPool.Execute("SHOW SESSIONS | KILL SESSIONS $-.SessionId") + if err != nil { + t.Fatal(err) + } + assert.True(t, resultSet.IsSucceed(), fmt.Errorf("error code: %d, error msg: %s", + resultSet.GetErrorCode(), resultSet.GetErrorMsg())) + + // execute query, it should retry to get session + resultSet, err = sessionPool.Execute("SHOW HOSTS;") + if err != nil { + t.Fatal(err) + } + assert.True(t, resultSet.IsSucceed(), fmt.Errorf("error code: %d, error msg: %s", + resultSet.GetErrorCode(), resultSet.GetErrorMsg())) +} + func BenchmarkConcurrency(b *testing.B) { err := prepareSpace("client_test") if err != nil {