Skip to content

Commit

Permalink
Retry to get session when session is invalid in SessionPool (#252)
Browse files Browse the repository at this point in the history
* Update readme for v3.4 release (#250)

* Fix format

* Update README

* Retry to get session when session is invalid

* Simplify client test

* Add comments

* Add comments
  • Loading branch information
Aiee authored and Sophie-Xie committed Mar 3, 2023
1 parent 2aa065d commit deec9fc
Show file tree
Hide file tree
Showing 5 changed files with 124 additions and 32 deletions.
20 changes: 6 additions & 14 deletions client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1202,15 +1202,11 @@ func TestReconnect(t *testing.T) {
defer pool.Close()

// Create session
var sessionList []*Session

for i := 0; i < 3; i++ {
session, err := pool.GetSession(username, password)
if err != nil {
t.Errorf("fail to create a new session from connection pool, %s", err.Error())
}
sessionList = append(sessionList, session)
session, err := pool.GetSession(username, password)
if err != nil {
t.Errorf("fail to create a new session from connection pool, %s", err.Error())
}
defer session.Release()

// Send query to server periodically
for i := 0; i < timeoutConfig.MaxConnPoolSize; i++ {
Expand All @@ -1221,7 +1217,7 @@ func TestReconnect(t *testing.T) {
if i == 7 {
stopContainer(t, "nebula-docker-compose_graphd1_1")
}
_, err := sessionList[0].Execute("SHOW HOSTS;")
_, err := session.Execute("SHOW HOSTS;")
fmt.Println("Sending query...")

if err != nil {
Expand All @@ -1230,7 +1226,7 @@ func TestReconnect(t *testing.T) {
}
}

resp, err := sessionList[0].Execute("SHOW HOSTS;")
resp, err := session.Execute("SHOW HOSTS;")
if err != nil {
t.Fatalf(err.Error())
return
Expand All @@ -1240,10 +1236,6 @@ func TestReconnect(t *testing.T) {
startContainer(t, "nebula-docker-compose_graphd_1")
startContainer(t, "nebula-docker-compose_graphd1_1")

for i := 0; i < len(sessionList); i++ {
sessionList[i].Release()
}

// Wait for graphd to be up
time.Sleep(5 * time.Second)
}
Expand Down
32 changes: 17 additions & 15 deletions configs.go
Original file line number Diff line number Diff line change
Expand Up @@ -109,12 +109,13 @@ func openAndReadFile(path string) ([]byte, error) {
// SessionPoolConf is the configs of a session pool
// Note that the space name is bound to the session pool for its lifetime
type SessionPoolConf struct {
username string // username for authentication
password string // password for authentication
serviceAddrs []HostAddress // service addresses for session pool
hostIndex int // index of the host in ServiceAddrs that the next new session will connect to
spaceName string // The space name that all sessions in the pool are bound to
sslConfig *tls.Config // Optional SSL config for the connection
username string // username for authentication
password string // password for authentication
serviceAddrs []HostAddress // service addresses for session pool
hostIndex int // index of the host in ServiceAddrs that the next new session will connect to
spaceName string // The space name that all sessions in the pool are bound to
sslConfig *tls.Config // Optional SSL config for the connection
retryGetSessionTimes int // The max times to retry get new session when executing a query

// Basic pool configs
// Socket timeout and Socket connection timeout, unit: seconds
Expand All @@ -138,15 +139,16 @@ func NewSessionPoolConf(
spaceName string, opts ...SessionPoolConfOption) (*SessionPoolConf, error) {
// Set default values for basic pool configs
newPoolConf := SessionPoolConf{
username: username,
password: password,
serviceAddrs: serviceAddrs,
spaceName: spaceName,
timeOut: 0 * time.Millisecond,
idleTime: 0 * time.Millisecond,
maxSize: 30,
minSize: 1,
hostIndex: 0,
username: username,
password: password,
serviceAddrs: serviceAddrs,
spaceName: spaceName,
retryGetSessionTimes: 1,
timeOut: 0 * time.Millisecond,
idleTime: 0 * time.Millisecond,
maxSize: 30,
minSize: 1,
hostIndex: 0,
}

// Iterate the given options and apply them to the config.
Expand Down
4 changes: 2 additions & 2 deletions session.go
Original file line number Diff line number Diff line change
Expand Up @@ -198,15 +198,15 @@ func (session *Session) ExecuteJsonWithParameter(stmt string, params map[string]
}

func (session *Session) reConnect() error {
newconnection, err := session.connPool.getIdleConn()
newConnection, err := session.connPool.getIdleConn()
if err != nil {
err = fmt.Errorf(err.Error())
return err
}

// Release connection to pool
session.connPool.release(session.connection)
session.connection = newconnection
session.connection = newConnection
return nil
}

Expand Down
56 changes: 55 additions & 1 deletion session_pool.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,12 @@ import (
"container/list"
"crypto/tls"
"fmt"
"strconv"
"sync"
"time"

"github.com/vesoft-inc/nebula-go/v3/nebula"
"github.com/vesoft-inc/nebula-go/v3/nebula/graph"
)

// SessionPool is a pool that manages sessions internally.
Expand Down Expand Up @@ -119,7 +121,15 @@ func (pool *SessionPool) ExecuteWithParameter(stmt string, params map[string]int
}

// Execute the query
resp, err := session.connection.executeWithParameter(session.sessionID, stmt, paramsMap)
execFunc := func(s *Session) (*graph.ExecutionResponse, error) {
resp, err := s.connection.executeWithParameter(s.sessionID, stmt, paramsMap)
if err != nil {
return nil, err
}
return resp, nil
}

resp, err := pool.executeWithRetry(session, execFunc, pool.conf.retryGetSessionTimes)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -385,6 +395,50 @@ func (pool *SessionPool) getIdleSession() (*Session, error) {
" session pool and the total session count has reached the limit")
}

// retryGetSession tries to create a new session when the current session is invalid.
func (pool *SessionPool) executeWithRetry(
session *Session,
f func(*Session) (*graph.ExecutionResponse, error),
retry int) (*graph.ExecutionResponse, error) {
pool.rwLock.Lock()
defer pool.rwLock.Unlock()

resp, err := f(session)
if err != nil {
pool.removeSessionFromList(&pool.activeSessions, session)
return nil, err
}

if resp.ErrorCode == nebula.ErrorCode_SUCCEEDED {
return resp, nil
} else if ErrorCode(resp.ErrorCode) != ErrorCode_E_SESSION_INVALID { // only retry when the session is invalid
return resp, err
}

// remove invalid session regardless of the retry is successful or not
defer pool.removeSessionFromList(&pool.activeSessions, session)
// If the session is invalid, close it and get a new session
for i := 0; i < retry; i++ {
pool.log.Info("retry to get sessions")
newSession, err := pool.newSession()
if err != nil {
return nil, err
}

pingErr := newSession.Ping()
if pingErr != nil {
pool.log.Error("failed to ping the session, error: " + pingErr.Error())
continue
}
pool.log.Info("retry to get sessions successfully")
pool.addSessionToList(&pool.activeSessions, newSession)

return f(newSession)
}
pool.log.Error(fmt.Sprintf("failed to get session after " + strconv.Itoa(retry) + " retries"))
return nil, fmt.Errorf("failed to get session after %d retries", retry)
}

// startCleaner starts sessionCleaner if idleTime > 0.
func (pool *SessionPool) startCleaner() {
if pool.conf.idleTime > 0 && pool.cleanerChan == nil {
Expand Down
44 changes: 44 additions & 0 deletions session_pool_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -319,6 +319,50 @@ func TestIdleSessionCleaner(t *testing.T) {
sessionPool.conf.minSize, sessionPool.GetTotalSessionCount())
}

func TestRetryGetSession(t *testing.T) {
err := prepareSpace("client_test")
if err != nil {
t.Fatal(err)
}
defer dropSpace("client_test")

hostAddress := HostAddress{Host: address, Port: port}
config, err := NewSessionPoolConf(
"root",
"nebula",
[]HostAddress{hostAddress},
"client_test")
if err != nil {
t.Errorf("failed to create session pool config, %s", err.Error())
}
config.minSize = 2
config.maxSize = 2
config.retryGetSessionTimes = 1

// create session pool
sessionPool, err := NewSessionPool(*config, DefaultLogger{})
if err != nil {
t.Fatal(err)
}
defer sessionPool.Close()

// kill all sessions in the cluster
resultSet, err := sessionPool.Execute("SHOW SESSIONS | KILL SESSIONS $-.SessionId")
if err != nil {
t.Fatal(err)
}
assert.True(t, resultSet.IsSucceed(), fmt.Errorf("error code: %d, error msg: %s",
resultSet.GetErrorCode(), resultSet.GetErrorMsg()))

// execute query, it should retry to get session
resultSet, err = sessionPool.Execute("SHOW HOSTS;")
if err != nil {
t.Fatal(err)
}
assert.True(t, resultSet.IsSucceed(), fmt.Errorf("error code: %d, error msg: %s",
resultSet.GetErrorCode(), resultSet.GetErrorMsg()))
}

func BenchmarkConcurrency(b *testing.B) {
err := prepareSpace("client_test")
if err != nil {
Expand Down

0 comments on commit deec9fc

Please sign in to comment.