diff --git a/protocol/lavasession/consumer_session_manager.go b/protocol/lavasession/consumer_session_manager.go
index 863ecd4c87..e1cb1a9a86 100644
--- a/protocol/lavasession/consumer_session_manager.go
+++ b/protocol/lavasession/consumer_session_manager.go
@@ -70,6 +70,7 @@ func (csm *ConsumerSessionManager) UpdateAllProviders(epoch uint64, pairingList
 
 	// Reset the pairingPurge.
 	// This happens only after an entire epoch. so its impossible to have session connected to the old purged list
+	csm.closePurgedUnusedPairingsConnections() // this must be before updating csm.pairingPurge as we want to close the connections of older sessions (prev 2 epochs)
 	csm.pairingPurge = csm.pairing
 	csm.pairing = make(map[string]*ConsumerSessionsWithProvider, pairingListLength)
 	for idx, provider := range pairingList {
@@ -81,6 +82,20 @@ func (csm *ConsumerSessionManager) UpdateAllProviders(epoch uint64, pairingList
 	return nil
 }
 
+// closePurgedUnusedPairingsConnections closes the connections of all endpoints
+// in csm.pairingPurge. After 2 epochs these connections must be closed explicitly,
+// otherwise the Go garbage collector does not close them and they remain open forever.
+// Endpoints that never established a connection (nil connection) are skipped.
+func (csm *ConsumerSessionManager) closePurgedUnusedPairingsConnections() {
+	for _, purgedPairing := range csm.pairingPurge {
+		for _, endpoint := range purgedPairing.Endpoints {
+			if endpoint.connection != nil {
+				endpoint.connection.Close()
+			}
+		}
+	}
+}
+
 func (csm *ConsumerSessionManager) validAddressesLen() int {
 	csm.lock.RLock()
 	defer csm.lock.RUnlock()
diff --git a/protocol/lavasession/consumer_session_manager_test.go b/protocol/lavasession/consumer_session_manager_test.go
index db211a5617..75ffbcbdb2 100644
--- a/protocol/lavasession/consumer_session_manager_test.go
+++ b/protocol/lavasession/consumer_session_manager_test.go
@@ -12,8 +12,11 @@ import (
 
 	"github.com/lavanet/lava/protocol/provideroptimizer"
 	"github.com/lavanet/lava/utils"
+	pairingtypes "github.com/lavanet/lava/x/pairing/types"
 	"github.com/stretchr/testify/require"
 	"google.golang.org/grpc"
+	"google.golang.org/grpc/credentials/insecure"
+	"google.golang.org/protobuf/types/known/wrapperspb"
 )
 
 const (
@@ -516,3 +519,20 @@ func TestContext(t *testing.T) {
 	require.Equal(t, ctxTO.Err(), context.DeadlineExceeded)
 	cancel()
 }
+
+// TestGrpcClientHang verifies that a closed gRPC connection returns an error on a second Close and on further RPC calls.
+func TestGrpcClientHang(t *testing.T) {
+	ctx := context.Background()
+	s := createGRPCServer(t) // create a grpcServer so we can connect to its endpoint and validate everything works.
+	defer s.Stop()           // stop the server when finished.
+	conn, err := grpc.DialContext(ctx, grpcListener, grpc.WithTransportCredentials(insecure.NewCredentials()), grpc.WithBlock())
+	require.NoError(t, err)
+	client := pairingtypes.NewRelayerClient(conn)
+	err = conn.Close()
+	require.NoError(t, err)
+	err = conn.Close()
+	require.Error(t, err)
+	_, err = client.Probe(ctx, &wrapperspb.UInt64Value{})
+	fmt.Println(err)
+	require.Error(t, err)
+}
diff --git a/protocol/lavasession/consumer_types.go b/protocol/lavasession/consumer_types.go
index b59dc51ffb..dbb4d745c6 100644
--- a/protocol/lavasession/consumer_types.go
+++ b/protocol/lavasession/consumer_types.go
@@ -13,6 +13,7 @@ import (
 	"github.com/lavanet/lava/utils"
 	pairingtypes "github.com/lavanet/lava/x/pairing/types"
 	"google.golang.org/grpc"
+	"google.golang.org/grpc/connectivity"
 	"google.golang.org/grpc/credentials/insecure"
 )
 
@@ -59,6 +60,7 @@ type Endpoint struct {
 	NetworkAddress     string // change at the end to NetworkAddress
 	Enabled            bool
 	Client             *pairingtypes.RelayerClient
+	connection         *grpc.ClientConn
 	ConnectionRefusals uint64
 }
 
@@ -184,18 +186,18 @@ func (cswp *ConsumerSessionsWithProvider) decreaseUsedComputeUnits(cu uint64) er
 	return nil
 }
 
-func (cswp *ConsumerSessionsWithProvider) connectRawClientWithTimeout(ctx context.Context, addr string) (*pairingtypes.RelayerClient, error) {
+func (cswp *ConsumerSessionsWithProvider) connectRawClientWithTimeout(ctx context.Context, addr string) (*pairingtypes.RelayerClient, *grpc.ClientConn, error) {
 	connectCtx, cancel := context.WithTimeout(ctx, TimeoutForEstablishingAConnection)
 	defer cancel()
 
 	conn, err := grpc.DialContext(connectCtx, addr, grpc.WithTransportCredentials(insecure.NewCredentials()), grpc.WithBlock())
 	if err != nil {
-		return nil, err
+		return nil, nil, err
 	}
 	/*defer conn.Close()*/
 
 	c := pairingtypes.NewRelayerClient(conn)
-	return &c, nil
+	return &c, conn, nil
 }
 
 func (cswp *ConsumerSessionsWithProvider) getConsumerSessionInstanceFromEndpoint(endpoint *Endpoint, numberOfResets uint64) (singleConsumerSession *SingleConsumerSession, pairingEpoch uint64, err error) {
@@ -262,8 +264,11 @@ func (cswp *ConsumerSessionsWithProvider) fetchEndpointConnectionFromConsumerSes
 			if !endpoint.Enabled {
 				continue
 			}
-			if endpoint.Client == nil {
-				conn, err := cswp.connectRawClientWithTimeout(ctx, endpoint.NetworkAddress)
+			connectEndpoint := func(cswp *ConsumerSessionsWithProvider, ctx context.Context, endpoint *Endpoint) (connected_ bool) {
+				if endpoint.Client != nil && endpoint.connection.GetState() != connectivity.Shutdown {
+					return true
+				}
+				client, conn, err := cswp.connectRawClientWithTimeout(ctx, endpoint.NetworkAddress)
 				if err != nil {
 					endpoint.ConnectionRefusals++
 					utils.LavaFormatError("error connecting to provider", err, utils.Attribute{Key: "provider endpoint", Value: endpoint.NetworkAddress}, utils.Attribute{Key: "provider address", Value: cswp.PublicLavaAddress}, utils.Attribute{Key: "endpoint", Value: endpoint})
@@ -271,10 +276,29 @@ func (cswp *ConsumerSessionsWithProvider) fetchEndpointConnectionFromConsumerSes
 						endpoint.Enabled = false
 						utils.LavaFormatWarning("disabling provider endpoint for the duration of current epoch.", nil, utils.Attribute{Key: "Endpoint", Value: endpoint.NetworkAddress}, utils.Attribute{Key: "address", Value: cswp.PublicLavaAddress})
 					}
-					continue
+					return false
 				}
 				endpoint.ConnectionRefusals = 0
-				endpoint.Client = conn
+				endpoint.Client = client
+				if endpoint.connection != nil {
+					endpoint.connection.Close() // just to be safe
+				}
+				endpoint.connection = conn
+				return true
+			}
+			if endpoint.Client == nil {
+				connected_ := connectEndpoint(cswp, ctx, endpoint)
+				if !connected_ {
+					continue
+				}
+			} else if endpoint.connection.GetState() == connectivity.Shutdown {
+				// connection was shut down, so we need to create a new one
+				endpoint.connection.Close()
+				endpoint.Client = nil
+				connected_ := connectEndpoint(cswp, ctx, endpoint)
+				if !connected_ {
+					continue
+				}
 			}
 			cswp.Endpoints[idx] = endpoint
 			return true, endpoint, false