Skip to content

Commit

Permalink
Merge pull request #430 from lavanet:CNS-374-VRF-self-index-mismatch
Browse files Browse the repository at this point in the history
CNS-374: fix vrfIndex mismatch between consumer and provider
  • Loading branch information
omerlavanet authored Apr 20, 2023
2 parents dbea6bf + 33af6db commit 6b87806
Show file tree
Hide file tree
Showing 8 changed files with 182 additions and 72 deletions.
2 changes: 1 addition & 1 deletion protocol/lavasession/end_to_end_lavasession_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ func TestHappyFlowE2E(t *testing.T) {
require.Error(t, err)
require.True(t, ConsumerNotRegisteredYet.Is(err))
// expect session to be missing, so we need to register it for the first time
sps, err = psm.RegisterProviderSessionWithConsumer(ctx, consumerOneAddress, uint64(cs.Client.PairingEpoch), uint64(cs.SessionId), cs.RelayNum, cs.Client.MaxComputeUnits, selfProviderIndex)
sps, err = psm.RegisterProviderSessionWithConsumer(ctx, consumerOneAddress, uint64(cs.Client.PairingEpoch), uint64(cs.SessionId), cs.RelayNum, cs.Client.MaxComputeUnits, selfProviderIndex, pairedProviders)
// validate session was added
require.NotEmpty(t, psm.sessionsWithAllConsumers)
require.Nil(t, err)
Expand Down
49 changes: 33 additions & 16 deletions protocol/lavasession/provider_session_manager.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ type ProviderSessionManager struct {
blockDistanceForEpochValidity uint64 // sessionsWithAllConsumers with epochs older than ((latest epoch) - numberOfBlocksKeptInMemory) are deleted.
}

func (psm *ProviderSessionManager) GetProviderIndexWithConsumer(epoch uint64, consumerAddress string) (int64, error) {
func (psm *ProviderSessionManager) GetProviderIndexWithConsumer(epoch uint64, consumerAddress string) (int64, int64, error) {
providerSessionWithConsumer, err := psm.IsActiveConsumer(epoch, consumerAddress)
if err != nil {
// if consumer not active maybe it has a DR session. so check there as well
Expand All @@ -29,13 +29,13 @@ func (psm *ProviderSessionManager) GetProviderIndexWithConsumer(epoch uint64, co
if found {
drProviderSessionWithConsumer, foundDrSession := drSessionData.sessionMap[consumerAddress]
if foundDrSession {
return drProviderSessionWithConsumer.atomicReadProviderIndex(), nil
return drProviderSessionWithConsumer.atomicReadProviderIndex(), drProviderSessionWithConsumer.atomicReadPairedProviders(), nil
}
}
// we didn't find the consumer in both maps
return IndexNotFound, CouldNotFindIndexAsConsumerNotYetRegisteredError
return IndexNotFound, IndexNotFound, CouldNotFindIndexAsConsumerNotYetRegisteredError
}
return providerSessionWithConsumer.atomicReadProviderIndex(), nil
return providerSessionWithConsumer.atomicReadProviderIndex(), providerSessionWithConsumer.atomicReadPairedProviders(), nil
}

// reads cs.BlockedEpoch atomically
Expand Down Expand Up @@ -78,7 +78,7 @@ func (psm *ProviderSessionManager) getSingleSessionFromProviderSessionWithConsum
return singleProviderSession, nil
}

func (psm *ProviderSessionManager) getOrCreateDataReliabilitySessionWithConsumer(address string, epoch uint64, sessionId uint64, selfProviderIndex int64) (providerSessionWithConsumer *ProviderSessionsWithConsumer, err error) {
func (psm *ProviderSessionManager) getOrCreateDataReliabilitySessionWithConsumer(address string, epoch uint64, sessionId uint64, selfProviderIndex, pairedProviders int64) (providerSessionWithConsumer *ProviderSessionsWithConsumer, err error) {
if mapOfDataReliabilitySessionsWithConsumer, consumerFoundInEpoch := psm.dataReliabilitySessionsWithAllConsumers[epoch]; consumerFoundInEpoch {
if providerSessionWithConsumer, consumerAddressFound := mapOfDataReliabilitySessionsWithConsumer.sessionMap[address]; consumerAddressFound {
if providerSessionWithConsumer.atomicReadConsumerBlocked() == blockListedConsumer { // we atomic read block listed so we dont need to lock the provider. (double lock is always a bad idea.)
Expand All @@ -90,6 +90,9 @@ func (psm *ProviderSessionManager) getOrCreateDataReliabilitySessionWithConsumer
if selfProviderIndex != providerSessionWithConsumer.atomicReadProviderIndex() {
return nil, ProviderIndexMisMatchError
}
if pairedProviders != providerSessionWithConsumer.atomicReadPairedProviders() {
return nil, ProviderIndexMisMatchError
}
return providerSessionWithConsumer, nil // no error
}
} else {
Expand All @@ -98,13 +101,13 @@ func (psm *ProviderSessionManager) getOrCreateDataReliabilitySessionWithConsumer
}

// If we got here, we need to create a new instance for this consumer address.
providerSessionWithConsumer = NewProviderSessionsWithConsumer(address, nil, isDataReliabilityPSWC, selfProviderIndex)
providerSessionWithConsumer = NewProviderSessionsWithConsumer(address, nil, isDataReliabilityPSWC, selfProviderIndex, pairedProviders)
psm.dataReliabilitySessionsWithAllConsumers[epoch].sessionMap[address] = providerSessionWithConsumer
return providerSessionWithConsumer, nil
}

// GetDataReliabilitySession fetches a data reliability session
func (psm *ProviderSessionManager) GetDataReliabilitySession(address string, epoch uint64, sessionId uint64, relayNumber uint64, selfProviderIndex int64) (*SingleProviderSession, error) {
func (psm *ProviderSessionManager) GetDataReliabilitySession(address string, epoch uint64, sessionId uint64, relayNumber uint64, selfProviderIndex, pairedProviders int64) (*SingleProviderSession, error) {
// validate Epoch
if !psm.IsValidEpoch(epoch) { // fast checking to see if epoch is even relevant
utils.LavaFormatError("GetSession", InvalidEpochError, utils.Attribute{Key: "RequestedEpoch", Value: epoch})
Expand All @@ -113,22 +116,34 @@ func (psm *ProviderSessionManager) GetDataReliabilitySession(address string, epo

// validate sessionId
if sessionId > DataReliabilitySessionId {
return nil, utils.LavaFormatError("request's sessionId is larger than the data reliability allowed session ID", nil, utils.Attribute{Key: "sessionId", Value: sessionId}, utils.Attribute{Key: "DataReliabilitySessionId", Value: strconv.Itoa(DataReliabilitySessionId)})
return nil, utils.LavaFormatError("request's sessionId is larger than the data reliability allowed session ID", nil,
utils.Attribute{Key: "sessionId", Value: sessionId},
utils.Attribute{Key: "DataReliabilitySessionId", Value: strconv.Itoa(DataReliabilitySessionId)},
)
}

// validate RelayNumber
if relayNumber == 0 {
return nil, utils.LavaFormatError("request's relayNumber zero, expecting consumer to increment", nil, utils.Attribute{Key: "relayNumber", Value: relayNumber}, utils.Attribute{Key: "DataReliabilityRelayNumber", Value: DataReliabilityRelayNumber})
return nil, utils.LavaFormatError("request's relayNumber zero, expecting consumer to increment", nil,
utils.Attribute{Key: "relayNumber", Value: relayNumber},
utils.Attribute{Key: "DataReliabilityRelayNumber", Value: DataReliabilityRelayNumber},
)
}

if relayNumber > DataReliabilityRelayNumber {
return nil, utils.LavaFormatError("request's relayNumber is larger than the DataReliabilityRelayNumber allowed in Data Reliability", nil, utils.Attribute{Key: "relayNumber", Value: relayNumber}, utils.Attribute{Key: "DataReliabilityRelayNumber", Value: DataReliabilityRelayNumber})
return nil, utils.LavaFormatError("request's relayNumber is larger than the DataReliabilityRelayNumber allowed in Data Reliability", nil,
utils.Attribute{Key: "relayNumber", Value: relayNumber},
utils.Attribute{Key: "DataReliabilityRelayNumber", Value: DataReliabilityRelayNumber},
)
}

// validate active consumer.
providerSessionWithConsumer, err := psm.getOrCreateDataReliabilitySessionWithConsumer(address, epoch, sessionId, selfProviderIndex)
providerSessionWithConsumer, err := psm.getOrCreateDataReliabilitySessionWithConsumer(address, epoch, sessionId, selfProviderIndex, pairedProviders)
if err != nil {
return nil, utils.LavaFormatError("getOrCreateDataReliabilitySessionWithConsumer Failed", err, utils.Attribute{Key: "relayNumber", Value: relayNumber}, utils.Attribute{Key: "DataReliabilityRelayNumber", Value: DataReliabilityRelayNumber})
return nil, utils.LavaFormatError("getOrCreateDataReliabilitySessionWithConsumer Failed", err,
utils.Attribute{Key: "relayNumber", Value: relayNumber},
utils.Attribute{Key: "DataReliabilityRelayNumber", Value: DataReliabilityRelayNumber},
)
}

// singleProviderSession is locked after this method is called unless we got an error
Expand Down Expand Up @@ -160,7 +175,7 @@ func (psm *ProviderSessionManager) GetSession(ctx context.Context, address strin
return psm.getSingleSessionFromProviderSessionWithConsumer(ctx, providerSessionsWithConsumer, sessionId, epoch, relayNumber)
}

func (psm *ProviderSessionManager) registerNewConsumer(consumerAddr string, epoch uint64, maxCuForConsumer uint64, selfProviderIndex int64) (*ProviderSessionsWithConsumer, error) {
func (psm *ProviderSessionManager) registerNewConsumer(consumerAddr string, epoch uint64, maxCuForConsumer uint64, selfProviderIndex, pairedProviders int64) (*ProviderSessionsWithConsumer, error) {
psm.lock.Lock()
defer psm.lock.Unlock()
if !psm.IsValidEpoch(epoch) { // checking again because we are now locked and epoch cant change now.
Expand All @@ -176,17 +191,19 @@ func (psm *ProviderSessionManager) registerNewConsumer(consumerAddr string, epoc

providerSessionWithConsumer, foundAddressInMap := mapOfProviderSessionsWithConsumer.sessionMap[consumerAddr]
if !foundAddressInMap {
providerSessionWithConsumer = NewProviderSessionsWithConsumer(consumerAddr, &ProviderSessionsEpochData{MaxComputeUnits: maxCuForConsumer}, notDataReliabilityPSWC, selfProviderIndex)
epochData := &ProviderSessionsEpochData{MaxComputeUnits: maxCuForConsumer}
providerSessionWithConsumer = NewProviderSessionsWithConsumer(consumerAddr, epochData, notDataReliabilityPSWC, selfProviderIndex, pairedProviders)
mapOfProviderSessionsWithConsumer.sessionMap[consumerAddr] = providerSessionWithConsumer
}

return providerSessionWithConsumer, nil
}

func (psm *ProviderSessionManager) RegisterProviderSessionWithConsumer(ctx context.Context, consumerAddress string, epoch uint64, sessionId uint64, relayNumber uint64, maxCuForConsumer uint64, selfProviderIndex int64) (*SingleProviderSession, error) {
func (psm *ProviderSessionManager) RegisterProviderSessionWithConsumer(ctx context.Context, consumerAddress string, epoch uint64, sessionId uint64, relayNumber uint64, maxCuForConsumer uint64, selfProviderIndex, pairedProviders int64) (*SingleProviderSession, error) {
providerSessionWithConsumer, err := psm.IsActiveConsumer(epoch, consumerAddress)
if err != nil {
if ConsumerNotRegisteredYet.Is(err) {
providerSessionWithConsumer, err = psm.registerNewConsumer(consumerAddress, epoch, maxCuForConsumer, selfProviderIndex)
providerSessionWithConsumer, err = psm.registerNewConsumer(consumerAddress, epoch, maxCuForConsumer, selfProviderIndex, pairedProviders)
if err != nil {
return nil, utils.LavaFormatError("RegisterProviderSessionWithConsumer Failed to registerNewSession", err)
}
Expand Down
11 changes: 6 additions & 5 deletions protocol/lavasession/provider_session_manager_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ const (
epoch2 = testNumberOfBlocksKeptInMemory + epoch1
consumerOneAddress = "consumer1"
selfProviderIndex = int64(1)
pairedProviders = int64(1)
)

func initProviderSessionManager() *ProviderSessionManager {
Expand All @@ -53,7 +54,7 @@ func prepareSession(t *testing.T, ctx context.Context) (*ProviderSessionManager,
require.True(t, ConsumerNotRegisteredYet.Is(err))

// expect session to be missing, so we need to register it for the first time
sps, err = psm.RegisterProviderSessionWithConsumer(ctx, consumerOneAddress, epoch1, sessionId, relayNumber, maxCu, selfProviderIndex)
sps, err = psm.RegisterProviderSessionWithConsumer(ctx, consumerOneAddress, epoch1, sessionId, relayNumber, maxCu, selfProviderIndex, pairedProviders)

// validate session was added
require.NotEmpty(t, psm.sessionsWithAllConsumers)
Expand All @@ -79,7 +80,7 @@ func prepareDRSession(t *testing.T, ctx context.Context) (*ProviderSessionManage
psm := initProviderSessionManager()

// get data reliability session
sps, err := psm.GetDataReliabilitySession(consumerOneAddress, epoch1, dataReliabilitySessionId, relayNumber, selfProviderIndex)
sps, err := psm.GetDataReliabilitySession(consumerOneAddress, epoch1, dataReliabilitySessionId, relayNumber, selfProviderIndex, pairedProviders)

// validate results
require.Nil(t, err)
Expand Down Expand Up @@ -268,7 +269,7 @@ func TestPSMDataReliabilityTwicePerEpoch(t *testing.T) {
require.Equal(t, epoch1, sps.PairingEpoch)

// try to get a data reliability session again.
sps, err := psm.GetDataReliabilitySession(consumerOneAddress, epoch1, dataReliabilitySessionId, relayNumber, selfProviderIndex)
sps, err := psm.GetDataReliabilitySession(consumerOneAddress, epoch1, dataReliabilitySessionId, relayNumber, selfProviderIndex, pairedProviders)

// validate we cant get more than one data reliability session per epoch (might change in the future)
require.Error(t, err)
Expand Down Expand Up @@ -308,7 +309,7 @@ func TestPSMDataReliabilityRetryAfterFailure(t *testing.T) {
require.Equal(t, epoch1, sps.PairingEpoch)

// try to get a data reliability session again.
sps, err := psm.GetDataReliabilitySession(consumerOneAddress, epoch1, dataReliabilitySessionId, relayNumber, selfProviderIndex)
sps, err := psm.GetDataReliabilitySession(consumerOneAddress, epoch1, dataReliabilitySessionId, relayNumber, selfProviderIndex, pairedProviders)

// validate we can get a data reliability session if we failed before
require.Nil(t, err)
Expand Down Expand Up @@ -660,7 +661,7 @@ func TestPSMUsageSync(t *testing.T) {
require.True(t, needsRegister)
needsRegister = false
utils.LavaFormatInfo("registered session", utils.Attribute{Key: "sessionID", Value: sessionStoreTest.sessionID}, utils.Attribute{Key: "epoch", Value: sessionStoreTest.epoch})
session, err := psm.RegisterProviderSessionWithConsumer(ctx, consumerAddress, sessionStoreTest.epoch, sessionStoreTest.sessionID, sessionStoreTest.relayNum+1, maxCuForConsumer, selfProviderIndex)
session, err := psm.RegisterProviderSessionWithConsumer(ctx, consumerAddress, sessionStoreTest.epoch, sessionStoreTest.sessionID, sessionStoreTest.relayNum+1, maxCuForConsumer, selfProviderIndex, pairedProviders)
require.NoError(t, err)
sessionStoreTest.session = session
sessionStoreTest.history = append(sessionStoreTest.history, ",RegisterGet")
Expand Down
9 changes: 8 additions & 1 deletion protocol/lavasession/provider_types.go
Original file line number Diff line number Diff line change
Expand Up @@ -103,21 +103,28 @@ type ProviderSessionsWithConsumer struct {
epochData *ProviderSessionsEpochData
Lock sync.RWMutex
isDataReliability uint32 // 0 is false, 1 is true. set to uint so we can atomically read
pairedProviders int64
selfProviderIndex int64
}

func NewProviderSessionsWithConsumer(consumerAddr string, epochData *ProviderSessionsEpochData, isDataReliability uint32, selfProviderIndex int64) *ProviderSessionsWithConsumer {
func NewProviderSessionsWithConsumer(consumerAddr string, epochData *ProviderSessionsEpochData, isDataReliability uint32, selfProviderIndex, pairedProviders int64) *ProviderSessionsWithConsumer {
pswc := &ProviderSessionsWithConsumer{
Sessions: map[uint64]*SingleProviderSession{},
isBlockListed: 0,
consumerAddr: consumerAddr,
epochData: epochData,
isDataReliability: isDataReliability,
pairedProviders: pairedProviders,
selfProviderIndex: selfProviderIndex,
}
return pswc
}

// reads the pairedProviders data atomically for DR
func (pswc *ProviderSessionsWithConsumer) atomicReadPairedProviders() int64 {
return atomic.LoadInt64(&pswc.pairedProviders)
}

// reads the selfProviderIndex data atomically for DR
func (pswc *ProviderSessionsWithConsumer) atomicReadProviderIndex() int64 {
return atomic.LoadInt64(&pswc.selfProviderIndex)
Expand Down
2 changes: 1 addition & 1 deletion protocol/rpcprovider/rpcprovider.go
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ type ProviderStateTrackerInf interface {
SendVoteCommitment(voteID string, vote *reliabilitymanager.VoteData) error
LatestBlock() int64
GetVrfPkAndMaxCuForUser(ctx context.Context, consumerAddress string, chainID string, epocu uint64) (vrfPk *utils.VrfPubKey, maxCu uint64, err error)
VerifyPairing(ctx context.Context, consumerAddress string, providerAddress string, epoch uint64, chainID string) (valid bool, index int64, err error)
VerifyPairing(ctx context.Context, consumerAddress string, providerAddress string, epoch uint64, chainID string) (valid bool, index, total int64, err error)
GetProvidersCountForConsumer(ctx context.Context, consumerAddress string, epoch uint64, chainID string) (uint32, error)
GetEpochSize(ctx context.Context) (uint64, error)
EarliestBlockInMemory(ctx context.Context) (uint64, error)
Expand Down
Loading

0 comments on commit 6b87806

Please sign in to comment.