Skip to content

Commit

Permalink
Add a timeout for Closes in begin.Server (#1650)
Browse files Browse the repository at this point in the history
* fix corner cases of the begin chain element

Signed-off-by: denis-tingaikin <denis.tingajkin@xored.com>
Signed-off-by: NikitaSkrynnik <nikita.skrynnik@xored.com>

* disable Test_RestartDuringRefresh

Signed-off-by: denis-tingaikin <denis.tingajkin@xored.com>
Signed-off-by: NikitaSkrynnik <nikita.skrynnik@xored.com>

* add fresh context

Signed-off-by: NikitaSkrynnik <nikita.skrynnik@xored.com>

* add extended context

Signed-off-by: NikitaSkrynnik <nikita.skrynnik@xored.com>

* add refreshed close context everywhere in begin

Signed-off-by: NikitaSkrynnik <nikita.skrynnik@xored.com>

* fix some unit tests

Signed-off-by: NikitaSkrynnik <nikita.skrynnik@xored.com>

* unskip some tests

Signed-off-by: NikitaSkrynnik <nikita.skrynnik@xored.com>

* fix golang linter issues

Signed-off-by: NikitaSkrynnik <nikita.skrynnik@xored.com>

* debug

Signed-off-by: NikitaSkrynnik <nikita.skrynnik@xored.com>

* cleanup

Signed-off-by: NikitaSkrynnik <nikita.skrynnik@xored.com>

* fix race condition

Signed-off-by: NikitaSkrynnik <nikita.skrynnik@xored.com>

* add unit tests

Signed-off-by: NikitaSkrynnik <nikita.skrynnik@xored.com>

* fix go linter issues

Signed-off-by: NikitaSkrynnik <nikita.skrynnik@xored.com>

* fix race conditiong

Signed-off-by: NikitaSkrynnik <nikita.skrynnik@xored.com>

* apply review comments

Signed-off-by: NikitaSkrynnik <nikita.skrynnik@xored.com>

---------

Signed-off-by: denis-tingaikin <denis.tingajkin@xored.com>
Signed-off-by: NikitaSkrynnik <nikita.skrynnik@xored.com>
Co-authored-by: denis-tingaikin <denis.tingajkin@xored.com>
  • Loading branch information
NikitaSkrynnik and denis-tingaikin authored Aug 8, 2024
1 parent bc1d964 commit 3016313
Show file tree
Hide file tree
Showing 5 changed files with 164 additions and 23 deletions.
14 changes: 12 additions & 2 deletions pkg/networkservice/common/begin/event_factory.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,14 @@ package begin

import (
"context"
"time"

"github.com/edwarnicke/serialize"
"github.com/networkservicemesh/api/pkg/api/networkservice"
"google.golang.org/grpc"
"google.golang.org/grpc/peer"

"github.com/networkservicemesh/sdk/pkg/tools/clock"
"github.com/networkservicemesh/sdk/pkg/tools/extend"
"github.com/networkservicemesh/sdk/pkg/tools/postpone"

Expand Down Expand Up @@ -158,14 +160,16 @@ type eventFactoryServer struct {
ctxFunc func() (context.Context, context.CancelFunc)
request *networkservice.NetworkServiceRequest
returnedConnection *networkservice.Connection
closeTimeout time.Duration
afterCloseFunc func()
server networkservice.NetworkServiceServer
}

func newEventFactoryServer(ctx context.Context, afterClose func()) *eventFactoryServer {
func newEventFactoryServer(ctx context.Context, closeTimeout time.Duration, afterClose func()) *eventFactoryServer {
f := &eventFactoryServer{
server: next.Server(ctx),
initialCtxFunc: postpone.Context(ctx),
closeTimeout: closeTimeout,
}
f.updateContext(ctx)

Expand Down Expand Up @@ -231,7 +235,13 @@ func (f *eventFactoryServer) Close(opts ...Option) <-chan error {
default:
ctx, cancel := f.ctxFunc()
defer cancel()
_, err := f.server.Close(ctx, f.request.GetConnection())

c := clock.FromContext(ctx)
closeCtx, cancel := c.WithTimeout(context.Background(), f.closeTimeout)
defer cancel()

closeCtx = extend.WithValuesFromContext(closeCtx, ctx)
_, err := f.server.Close(closeCtx, f.request.GetConnection())
f.afterCloseFunc()
ch <- err
}
Expand Down
12 changes: 7 additions & 5 deletions pkg/networkservice/common/begin/event_factory_server_test.go
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// Copyright (c) 2022 Cisco and/or its affiliates.
// Copyright (c) 2022-2024 Cisco and/or its affiliates.
//
// SPDX-License-Identifier: Apache-2.0
//
Expand Down Expand Up @@ -142,14 +142,15 @@ func TestContextTimeout_Server(t *testing.T) {
clockMock := clockmock.New(ctx)
ctx = clock.WithClock(ctx, clockMock)

ctx, cancel = context.WithDeadline(ctx, clockMock.Now().Add(time.Second*3))
ctx, cancel = clockMock.WithDeadline(ctx, clockMock.Now().Add(time.Second*3))
defer cancel()

closeTimeout := time.Minute
eventFactoryServ := &eventFactoryServer{}
server := chain.NewNetworkServiceServer(
begin.NewServer(),
begin.NewServer(begin.WithCloseTimeout(closeTimeout)),
eventFactoryServ,
&delayedNSEServer{t: t, clock: clockMock},
&delayedNSEServer{t: t, closeTimeout: closeTimeout, clock: clockMock},
)

// Do Request
Expand Down Expand Up @@ -230,6 +231,7 @@ type delayedNSEServer struct {
t *testing.T
clock *clockmock.Mock
initialTimeout time.Duration
closeTimeout time.Duration
}

func (d *delayedNSEServer) Request(ctx context.Context, request *networkservice.NetworkServiceRequest) (*networkservice.Connection, error) {
Expand Down Expand Up @@ -258,7 +260,7 @@ func (d *delayedNSEServer) Close(ctx context.Context, conn *networkservice.Conne
deadline, _ := ctx.Deadline()
clockTime := clock.FromContext(ctx)

require.Equal(d.t, d.initialTimeout, clockTime.Until(deadline))
require.Equal(d.t, d.closeTimeout, clockTime.Until(deadline))

return next.Server(ctx).Close(ctx, conn)
}
15 changes: 12 additions & 3 deletions pkg/networkservice/common/begin/options.go
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// Copyright (c) 2021 Cisco and/or its affiliates.
// Copyright (c) 2021-2024 Cisco and/or its affiliates.
//
// SPDX-License-Identifier: Apache-2.0
//
Expand All @@ -18,11 +18,13 @@ package begin

import (
"context"
"time"
)

type option struct {
cancelCtx context.Context
reselect bool
cancelCtx context.Context
reselect bool
closeTimeout time.Duration
}

// Option - event option
Expand All @@ -41,3 +43,10 @@ func WithReselect() Option {
o.reselect = true
}
}

// WithCloseTimeout - set a custom timeout for a context in begin.Close
func WithCloseTimeout(timeout time.Duration) Option {
return func(o *option) {
o.closeTimeout = timeout
}
}
62 changes: 49 additions & 13 deletions pkg/networkservice/common/begin/server.go
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// Copyright (c) 2021-2023 Cisco and/or its affiliates.
// Copyright (c) 2021-2024 Cisco and/or its affiliates.
//
// SPDX-License-Identifier: Apache-2.0
//
Expand All @@ -18,27 +18,45 @@ package begin

import (
"context"
"time"

"github.com/edwarnicke/genericsync"
"github.com/networkservicemesh/api/pkg/api/networkservice"
"github.com/pkg/errors"
"google.golang.org/protobuf/types/known/emptypb"

"github.com/networkservicemesh/sdk/pkg/tools/extend"
"github.com/networkservicemesh/sdk/pkg/tools/log"

"github.com/networkservicemesh/sdk/pkg/networkservice/core/next"
)

type beginServer struct {
genericsync.Map[string, *eventFactoryServer]
closeTimeout time.Duration
}

// NewServer - creates a new begin chain element
func NewServer() networkservice.NetworkServiceServer {
return &beginServer{}
func NewServer(opts ...Option) networkservice.NetworkServiceServer {
o := &option{
cancelCtx: context.Background(),
reselect: false,
closeTimeout: time.Minute,
}

for _, opt := range opts {
opt(o)
}

return &beginServer{
closeTimeout: o.closeTimeout,
}
}

func (b *beginServer) Request(ctx context.Context, request *networkservice.NetworkServiceRequest) (conn *networkservice.Connection, err error) {
func (b *beginServer) Request(ctx context.Context, request *networkservice.NetworkServiceRequest) (*networkservice.Connection, error) {
var conn *networkservice.Connection
var err error

// No connection.ID, no service
if request.GetConnection().GetId() == "" {
return nil, errors.New("request.EventFactory.Id must not be zero valued")
Expand All @@ -50,12 +68,14 @@ func (b *beginServer) Request(ctx context.Context, request *networkservice.Netwo
eventFactoryServer, _ := b.LoadOrStore(request.GetConnection().GetId(),
newEventFactoryServer(
ctx,
b.closeTimeout,
func() {
b.Delete(request.GetRequestConnection().GetId())
},
),
)
<-eventFactoryServer.executor.AsyncExec(func() {
select {
case <-eventFactoryServer.executor.AsyncExec(func() {
currentEventFactoryServer, _ := b.Load(request.GetConnection().GetId())
if currentEventFactoryServer != eventFactoryServer {
log.FromContext(ctx).Debug("recalling begin.Request because currentEventFactoryServer != eventFactoryServer")
Expand Down Expand Up @@ -93,33 +113,49 @@ func (b *beginServer) Request(ctx context.Context, request *networkservice.Netwo

eventFactoryServer.returnedConnection = conn.Clone()
eventFactoryServer.updateContext(ctx)
})
}):
case <-ctx.Done():
return nil, ctx.Err()
}

return conn, err
}

func (b *beginServer) Close(ctx context.Context, conn *networkservice.Connection) (emp *emptypb.Empty, err error) {
func (b *beginServer) Close(ctx context.Context, conn *networkservice.Connection) (*emptypb.Empty, error) {
var err error
connID := conn.GetId()
// If some other EventFactory is already in the ctx... we are already running in an executor, and can just execute normally
if fromContext(ctx) != nil {
return next.Server(ctx).Close(ctx, conn)
}
eventFactoryServer, ok := b.Load(conn.GetId())
eventFactoryServer, ok := b.Load(connID)
if !ok {
// If we don't have a connection to Close, just let it be
return &emptypb.Empty{}, nil
}
<-eventFactoryServer.executor.AsyncExec(func() {

select {
case <-eventFactoryServer.executor.AsyncExec(func() {
if eventFactoryServer.state != established || eventFactoryServer.request == nil {
return
}
currentServerClient, _ := b.Load(conn.GetId())
currentServerClient, _ := b.Load(connID)
if currentServerClient != eventFactoryServer {
return
}
closeCtx, cancel := context.WithTimeout(context.Background(), b.closeTimeout)
defer cancel()

// Always close with the last valid EventFactory we got
conn = eventFactoryServer.request.Connection
withEventFactoryCtx := withEventFactory(ctx, eventFactoryServer)
emp, err = next.Server(withEventFactoryCtx).Close(withEventFactoryCtx, conn)
closeCtx = extend.WithValuesFromContext(closeCtx, withEventFactoryCtx)
_, err = next.Server(closeCtx).Close(closeCtx, conn)
eventFactoryServer.afterCloseFunc()
})
return &emptypb.Empty{}, err
}):
return &emptypb.Empty{}, err
case <-ctx.Done():
b.Delete(connID)
return nil, ctx.Err()
}
}
84 changes: 84 additions & 0 deletions pkg/networkservice/common/begin/server_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
// Copyright (c) 2024 Cisco and/or its affiliates.
//
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at:
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package begin_test

import (
"context"
"sync/atomic"
"testing"
"time"

"github.com/golang/protobuf/ptypes/empty"
"github.com/networkservicemesh/api/pkg/api/networkservice"
"github.com/stretchr/testify/require"
"go.uber.org/goleak"

"github.com/networkservicemesh/sdk/pkg/networkservice/common/begin"
"github.com/networkservicemesh/sdk/pkg/networkservice/core/next"
)

const (
waitTime = time.Second
)

type waitServer struct {
requestDone atomic.Int32
closeDone atomic.Int32
}

func (s *waitServer) Request(ctx context.Context, request *networkservice.NetworkServiceRequest) (*networkservice.Connection, error) {
time.Sleep(waitTime)
s.requestDone.Store(1)
return next.Server(ctx).Request(ctx, request)
}

func (s *waitServer) Close(ctx context.Context, connection *networkservice.Connection) (*empty.Empty, error) {
time.Sleep(waitTime)
s.closeDone.Store(1)
return next.Server(ctx).Close(ctx, connection)
}

func TestBeginWorksWithSmallTimeout(t *testing.T) {
t.Cleanup(func() {
goleak.VerifyNone(t)
})
requestCtx, cancel := context.WithTimeout(context.Background(), time.Millisecond*200)
defer cancel()

waitSrv := &waitServer{}
server := next.NewNetworkServiceServer(
begin.NewServer(),
waitSrv,
)

request := testRequest("id")
_, err := server.Request(requestCtx, request)
require.EqualError(t, err, context.DeadlineExceeded.Error())
require.Equal(t, int32(0), waitSrv.requestDone.Load())
require.Eventually(t, func() bool {
return waitSrv.requestDone.Load() == 1
}, waitTime*2, time.Millisecond*500)

closeCtx, cancel := context.WithTimeout(context.Background(), time.Millisecond*200)
defer cancel()
_, err = server.Close(closeCtx, request.Connection)
require.EqualError(t, err, context.DeadlineExceeded.Error())
require.Equal(t, int32(0), waitSrv.closeDone.Load())
require.Eventually(t, func() bool {
return waitSrv.closeDone.Load() == 1
}, waitTime*2, time.Millisecond*500)
}

0 comments on commit 3016313

Please sign in to comment.