Skip to content

Commit

Permalink
PRT-598 lock misuse bug in provider session manager (#394)
Browse files Browse the repository at this point in the history
* added test, solved relay verification problem

* fixed the bug, added prints, and tests
  • Loading branch information
omerlavanet authored Apr 3, 2023
1 parent 1b9728b commit bd55173
Show file tree
Hide file tree
Showing 3 changed files with 212 additions and 9 deletions.
33 changes: 24 additions & 9 deletions protocol/lavasession/provider_session_manager.go
Original file line number Diff line number Diff line change
Expand Up @@ -59,20 +59,22 @@ func (psm *ProviderSessionManager) IsActiveConsumer(epoch uint64, address string
return providerSessionWithConsumer, nil // no error
}

func (psm *ProviderSessionManager) getSingleSessionFromProviderSessionWithConsumer(providerSessionWithConsumer *ProviderSessionsWithConsumer, sessionId uint64, epoch uint64, relayNumber uint64) (*SingleProviderSession, error) {
if providerSessionWithConsumer.atomicReadConsumerBlocked() != notBlockListedConsumer {
return nil, utils.LavaFormatError("This consumer address is blocked.", nil, utils.Attribute{Key: "RequestedEpoch", Value: epoch}, utils.Attribute{Key: "consumer", Value: providerSessionWithConsumer.consumerAddr})
func (psm *ProviderSessionManager) getSingleSessionFromProviderSessionWithConsumer(providerSessionsWithConsumer *ProviderSessionsWithConsumer, sessionId uint64, epoch uint64, relayNumber uint64) (*SingleProviderSession, error) {
if providerSessionsWithConsumer.atomicReadConsumerBlocked() != notBlockListedConsumer {
return nil, utils.LavaFormatError("This consumer address is blocked.", nil, utils.Attribute{Key: "RequestedEpoch", Value: epoch}, utils.Attribute{Key: "consumer", Value: providerSessionsWithConsumer.consumerAddr})
}
// before getting any sessions.
singleProviderSession, err := psm.getSessionFromAnActiveConsumer(providerSessionWithConsumer, sessionId, epoch) // after getting session verify relayNum etc..
// get a single session and lock it, for error it's not locked
singleProviderSession, err := psm.getSessionFromAnActiveConsumer(providerSessionsWithConsumer, sessionId, epoch) // after getting session verify relayNum etc..
if err != nil {
return nil, utils.LavaFormatError("getSessionFromAnActiveConsumer Failure", err, utils.Attribute{Key: "RequestedEpoch", Value: epoch}, utils.Attribute{Key: "sessionId", Value: sessionId})
}
if singleProviderSession.RelayNum+1 < relayNumber { // validate relay number here, but add only in PrepareSessionForUsage
if singleProviderSession.RelayNum+1 > relayNumber { // validate relay number here, but add only in PrepareSessionForUsage
// unlock the session since we are returning an error
defer singleProviderSession.lock.Unlock()
return nil, utils.LavaFormatError("singleProviderSession.RelayNum mismatch, session out of sync", SessionOutOfSyncError, utils.Attribute{Key: "singleProviderSession.RelayNum", Value: singleProviderSession.RelayNum + 1}, utils.Attribute{Key: "request.relayNumber", Value: relayNumber})
}
// singleProviderSession is locked at this point.
return singleProviderSession, err
return singleProviderSession, nil
}

func (psm *ProviderSessionManager) getOrCreateDataReliabilitySessionWithConsumer(address string, epoch uint64, sessionId uint64, selfProviderIndex int64) (providerSessionWithConsumer *ProviderSessionsWithConsumer, err error) {
Expand Down Expand Up @@ -145,12 +147,12 @@ func (psm *ProviderSessionManager) GetSession(address string, epoch uint64, sess
return nil, InvalidEpochError
}

providerSessionWithConsumer, err := psm.IsActiveConsumer(epoch, address)
providerSessionsWithConsumer, err := psm.IsActiveConsumer(epoch, address)
if err != nil {
return nil, err
}

return psm.getSingleSessionFromProviderSessionWithConsumer(providerSessionWithConsumer, sessionId, epoch, relayNumber)
return psm.getSingleSessionFromProviderSessionWithConsumer(providerSessionsWithConsumer, sessionId, epoch, relayNumber)
}

func (psm *ProviderSessionManager) registerNewConsumer(consumerAddr string, epoch uint64, maxCuForConsumer uint64, selfProviderIndex int64) (*ProviderSessionsWithConsumer, error) {
Expand Down Expand Up @@ -229,6 +231,14 @@ func (psm *ProviderSessionManager) ReportConsumer() (address string, epoch uint6

// OnSessionDone unlocks the session gracefully, this happens when session finished with an error
func (psm *ProviderSessionManager) OnSessionFailure(singleProviderSession *SingleProviderSession) (err error) {
if !psm.IsValidEpoch(singleProviderSession.PairingEpoch) {
// the single provider session is no longer valid, so do not do a onSessionFailure, we don;t want it racing with cleanup touching other objects
utils.LavaFormatWarning("epoch changed during session usage, so discarding sessionID changes on failure", nil,
utils.Attribute{Key: "sessionID", Value: singleProviderSession.SessionID},
utils.Attribute{Key: "cuSum", Value: singleProviderSession.CuSum},
utils.Attribute{Key: "PairingEpoch", Value: singleProviderSession.PairingEpoch})
return singleProviderSession.onSessionDone() // to unlock it and resume
}
return singleProviderSession.onSessionFailure()
}

Expand All @@ -245,6 +255,11 @@ func (psm *ProviderSessionManager) RPCProviderEndpoint() *RPCProviderEndpoint {
func (psm *ProviderSessionManager) UpdateEpoch(epoch uint64) {
psm.lock.Lock()
defer psm.lock.Unlock()
if epoch <= psm.blockedEpochHeight {
// this shouldn't happen, but nothing to do
utils.LavaFormatWarning("called updateEpoch with invalid epoch", nil, utils.Attribute{Key: "epoch", Value: epoch}, utils.Attribute{Key: "blockedEpoch", Value: psm.blockedEpochHeight})
return
}
if epoch > psm.blockDistanceForEpochValidity {
psm.blockedEpochHeight = epoch - psm.blockDistanceForEpochValidity
} else {
Expand Down
187 changes: 187 additions & 0 deletions protocol/lavasession/provider_session_manager_test.go
Original file line number Diff line number Diff line change
@@ -1,9 +1,13 @@
package lavasession

import (
"math"
"math/rand"
"testing"
"time"

"github.com/lavanet/lava/protocol/common"
"github.com/lavanet/lava/utils"
"github.com/stretchr/testify/require"
)

Expand Down Expand Up @@ -553,3 +557,186 @@ func TestPSMSubscribeEpochChange(t *testing.T) {
require.Empty(t, psm.subscriptionSessionsWithAllConsumers)
require.Empty(t, psm.sessionsWithAllConsumers)
}

type testSessionData struct {
currentCU uint64
inUse bool
sessionID uint64
relayNum uint64
epoch uint64
session *SingleProviderSession
history []string
}

// this test is running sessions and usage in a sync way to see integrity of behavior, opening and closing of sessions is separate
func TestPSMUsageSync(t *testing.T) {
psm := NewProviderSessionManager(&RPCProviderEndpoint{
NetworkAddress: "127.0.0.1:6666",
ChainID: "LAV1",
ApiInterface: "tendermint",
Geolocation: 1,
NodeUrls: []common.NodeUrl{{Url: "http://localhost:666"}, {Url: "ws://localhost:666/websocket"}},
}, 20)
seed := time.Now().UnixNano()
rand.Seed(seed)
utils.LavaFormatInfo("started test with randomness, to reproduce use seed", utils.Attribute{Key: "seed", Value: seed})
consumerAddress := "stub-consumer"
maxCuForConsumer := uint64(math.MaxInt64)
selfProviderIndex := int64(0)
numSessions := 5
psm.UpdateEpoch(10)
sessionsStore := initSessionStore(numSessions, 10)
sessionsStoreTooAdvanced := initSessionStore(numSessions, 15) // sessionIDs will overlap, this is intentional
// an attempt is either a valid opening, valid closing, invalid opening, erroring session, epoch too advanced usage
simulateUsageOnSessionsStore := func(attemptsNum int, sessionsStoreArg []*testSessionData, needsRegister bool) {
for attempts := 0; attempts < attemptsNum; attempts++ {
// pick scenario:
sessionIdx := rand.Intn(len(sessionsStoreArg))
sessionStoreTest := sessionsStoreArg[sessionIdx]
inUse := sessionStoreTest.inUse
if inUse {
// session is in use so either we close it or try to use and fail
choice := rand.Intn(2)
if choice == 0 {
// close it
choice = rand.Intn(2)
// proper closing or error closing
if choice == 0 {
relayCU := sessionStoreTest.session.LatestRelayCu
// proper closing
err := psm.OnSessionDone(sessionStoreTest.session)
require.NoError(t, err)
sessionStoreTest.inUse = false
sessionStoreTest.relayNum += 1
sessionStoreTest.currentCU += relayCU
sessionStoreTest.history = append(sessionStoreTest.history, ",OnSessionDone")
} else {
// error closing
err := psm.OnSessionFailure(sessionStoreTest.session)
require.NoError(t, err)
sessionStoreTest.inUse = false
sessionStoreTest.history = append(sessionStoreTest.history, ",OnSessionFailure")
}
} else {
// try to use and fail
relayNumToGet := sessionStoreTest.relayNum + uint64(rand.Intn(3))
_, err := psm.GetSession(consumerAddress, sessionStoreTest.epoch, sessionStoreTest.sessionID, relayNumToGet)
require.Error(t, err)
require.False(t, ConsumerNotRegisteredYet.Is(err))
sessionStoreTest.history = append(sessionStoreTest.history, ",TryToUseAgain")
}
} else {
// session not in use yet, so try to use it. we have several options:
// 1. proper usage /
// 2. usage with wrong CU
// 3. usage with wrong relay number
// 4. usage with wrong epoch number
choice := rand.Intn(2)
if choice == 0 || sessionStoreTest.relayNum == 0 {
// getSession should work
session, err := psm.GetSession(consumerAddress, sessionStoreTest.epoch, sessionStoreTest.sessionID, sessionStoreTest.relayNum+1)
if sessionStoreTest.relayNum > 0 {
// this is not a first relay so we expect this to work
require.NoError(t, err, "sessionID %d relayNum %d storedRelayNum %d epoch %d, history %s", sessionStoreTest.sessionID, sessionStoreTest.relayNum+1, sessionStoreTest.session.RelayNum, sessionStoreTest.epoch, sessionStoreTest.history)
require.Same(t, session, sessionStoreTest.session)
sessionStoreTest.history = append(sessionStoreTest.history, ",GetSession")
} else {
// this can be a first relay or after an error, so allow not registered error
if err != nil {
// first relay
require.True(t, ConsumerNotRegisteredYet.Is(err))
require.True(t, needsRegister)
needsRegister = false
utils.LavaFormatInfo("registered session", utils.Attribute{Key: "sessionID", Value: sessionStoreTest.sessionID}, utils.Attribute{Key: "epoch", Value: sessionStoreTest.epoch})
session, err := psm.RegisterProviderSessionWithConsumer(consumerAddress, sessionStoreTest.epoch, sessionStoreTest.sessionID, sessionStoreTest.relayNum+1, maxCuForConsumer, selfProviderIndex)
require.NoError(t, err)
sessionStoreTest.session = session
sessionStoreTest.history = append(sessionStoreTest.history, ",RegisterGet")
} else {
sessionStoreTest.session = session
sessionStoreTest.history = append(sessionStoreTest.history, ",GetSession")
}
}
choice := rand.Intn(2)
switch choice {
case 0:
cuToUse := uint64(rand.Intn(10)) + 1
err = sessionStoreTest.session.PrepareSessionForUsage(cuToUse, cuToUse+sessionStoreTest.currentCU, sessionStoreTest.relayNum+1)
require.NoError(t, err)
sessionStoreTest.inUse = true
sessionStoreTest.history = append(sessionStoreTest.history, ",PrepareForUsage")
case 1:
cuToUse := uint64(rand.Intn(10)) + 1
cuMissing := rand.Intn(int(cuToUse)) + 1
if cuToUse+sessionStoreTest.currentCU <= uint64(cuMissing) {
cuToUse += 1
}
err = sessionStoreTest.session.PrepareSessionForUsage(cuToUse, cuToUse+sessionStoreTest.currentCU-uint64(cuMissing), sessionStoreTest.relayNum+1)
require.Error(t, err)
sessionStoreTest.history = append(sessionStoreTest.history, ",ErrCUPrepareForUsage")
}
} else {
// getSession should fail
relayNumSubs := rand.Intn(int(sessionStoreTest.relayNum) + 1) // [0,relayNum]
_, err := psm.GetSession(consumerAddress, sessionStoreTest.epoch, sessionStoreTest.sessionID, sessionStoreTest.relayNum-uint64(relayNumSubs))
require.Error(t, err, "sessionID %d relayNum %d storedRelayNum %d", sessionStoreTest.sessionID, sessionStoreTest.relayNum-uint64(relayNumSubs), sessionStoreTest.session.RelayNum)
_, err = psm.GetSession(consumerAddress, sessionStoreTest.epoch-1, sessionStoreTest.sessionID, sessionStoreTest.relayNum+1)
require.Error(t, err)
_, err = psm.GetSession(consumerAddress, 5, sessionStoreTest.sessionID, sessionStoreTest.relayNum+1)
require.Error(t, err)
sessionStoreTest.history = append(sessionStoreTest.history, ",ErrGet")
}
}
}
}

simulateUsageOnSessionsStore(500, sessionsStore, true)
// now repeat with epoch advancement on consumer and provider node
simulateUsageOnSessionsStore(100, sessionsStoreTooAdvanced, true)

psm.UpdateEpoch(20) // update session, still within size, so shouldn't affect anything

simulateUsageOnSessionsStore(500, sessionsStore, false)
simulateUsageOnSessionsStore(100, sessionsStoreTooAdvanced, false)

psm.UpdateEpoch(40) // update session, still within size, so shouldn't affect anything
for attempts := 0; attempts < 100; attempts++ {
// pick scenario:
sessionIdx := rand.Intn(len(sessionsStore))
sessionStoreTest := sessionsStore[sessionIdx]
inUse := sessionStoreTest.inUse
if inUse {
err := psm.OnSessionDone(sessionStoreTest.session)
require.NoError(t, err)
sessionStoreTest.inUse = false
sessionStoreTest.relayNum += 1
} else {
_, err := psm.GetSession(consumerAddress, sessionStoreTest.epoch, sessionStoreTest.sessionID, sessionStoreTest.relayNum+1)
require.Error(t, err)
}
}
// .IsValidEpoch(uint64(request.RelaySession.Epoch))
// .GetSession(consumerAddressString, uint64(request.Epoch), request.SessionId, request.RelayNum)
// on err: lavasession.ConsumerNotRegisteredYet.Is(err)
// // .RegisterProviderSessionWithConsumer(consumerAddressString, uint64(request.Epoch), request.SessionId, request.RelayNum, maxCuForConsumer, selfProviderIndex)
// .PrepareSessionForUsage(relayCU, request.RelaySession.CuSum, request.RelaySession.RelayNum)
// simulate error: .OnSessionFailure(relaySession)
// simulate success: .OnSessionDone(relaySession)
}

func initSessionStore(numSessions int, epoch uint64) []*testSessionData {
retSessions := make([]*testSessionData, numSessions)
for i := 0; i < numSessions; i++ {
retSessions[i] = &testSessionData{
currentCU: 0,
inUse: false,
sessionID: uint64(i) + 1,
relayNum: 0,
epoch: epoch,
session: nil,
history: []string{},
}
utils.LavaFormatInfo("session", utils.Attribute{Key: "epoch", Value: epoch}, utils.Attribute{Key: "sessionID", Value: retSessions[i].sessionID})
}
return retSessions
}
1 change: 1 addition & 0 deletions protocol/lavasession/provider_types.go
Original file line number Diff line number Diff line change
Expand Up @@ -350,6 +350,7 @@ func (sps *SingleProviderSession) onSessionFailure() error {
}

func (sps *SingleProviderSession) onSessionDone() error {
// this can be called on collected sessions, so if in the future you need to touch the parent, take this into consideration to change the OnSessionDone calls in provider_session_manager
err := sps.VerifyLock() // sps is locked
if err != nil {
return utils.LavaFormatError("sps.verifyLock() failed in onSessionDone", err)
Expand Down

0 comments on commit bd55173

Please sign in to comment.