Skip to content

Commit

Permalink
use a proper context for test + fix data races in dialer
Browse files Browse the repository at this point in the history
Signed-off-by: NikitaSkrynnik <nikita.skrynnik@xored.com>
  • Loading branch information
NikitaSkrynnik committed Dec 24, 2024
1 parent 55b1f2b commit ae3ded9
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 8 deletions.
12 changes: 4 additions & 8 deletions pkg/networkservice/chains/nsmgr/vl3_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -100,15 +100,11 @@ func Test_NSC_ConnectsTo_vl3NSE(t *testing.T) {
for i := 0; i < 10; i++ {
nsc := domain.Nodes[0].NewClient(ctx, sandbox.GenerateTestToken)

reqCtx, reqClose := context.WithTimeout(ctx, time.Second*1)
defer reqClose()

req := defaultRequest(nsReg.Name)
req.Connection.Id = uuid.New().String()

req.Connection.Labels["podName"] = nscName + fmt.Sprint(i)

resp, err := nsc.Request(reqCtx, req)
resp, err := nsc.Request(ctx, req)
require.NoError(t, err)

req.Connection = resp.Clone()
Expand All @@ -117,15 +113,15 @@ func Test_NSC_ConnectsTo_vl3NSE(t *testing.T) {

requireIPv4Lookup(ctx, t, &resolver, nscName+fmt.Sprint(i)+".vl3", "10.0.0.1")

resp, err = nsc.Request(reqCtx, req)
resp, err = nsc.Request(ctx, req)
require.NoError(t, err)

requireIPv4Lookup(ctx, t, &resolver, nscName+fmt.Sprint(i)+".vl3", "10.0.0.1")

_, err = nsc.Close(reqCtx, resp)
_, err = nsc.Close(ctx, resp)
require.NoError(t, err)

_, err = resolver.LookupIP(reqCtx, "ip4", nscName+fmt.Sprint(i)+".vl3")
_, err = resolver.LookupIP(ctx, "ip4", nscName+fmt.Sprint(i)+".vl3")
require.Error(t, err)
}
}
Expand Down
9 changes: 9 additions & 0 deletions pkg/networkservice/common/dial/dialer.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ import (
"context"
"net/url"
"runtime"
"sync"
"time"

"github.com/pkg/errors"
Expand All @@ -37,6 +38,8 @@ type dialer struct {
*grpc.ClientConn
dialOptions []grpc.DialOption
dialTimeout time.Duration

mu sync.Mutex
}

func newDialer(ctx context.Context, dialTimeout time.Duration, dialOptions ...grpc.DialOption) *dialer {
Expand Down Expand Up @@ -74,7 +77,9 @@ func (di *dialer) Dial(ctx context.Context, clientURL *url.URL) error {
}
return errors.Wrapf(err, "failed to dial %s", target)
}
di.mu.Lock()
di.ClientConn = cc
di.mu.Unlock()

di.cleanupContext, di.cleanupCancel = context.WithCancel(di.ctx)

Expand All @@ -94,13 +99,17 @@ func (di *dialer) Close() error {
}

func (di *dialer) Invoke(ctx context.Context, method string, args, reply interface{}, opts ...grpc.CallOption) error {
di.mu.Lock()
defer di.mu.Unlock()
if di.ClientConn == nil {
return errors.New("no dialer.ClientConn found")
}
return di.ClientConn.Invoke(ctx, method, args, reply, opts...)
}

func (di *dialer) NewStream(ctx context.Context, desc *grpc.StreamDesc, method string, opts ...grpc.CallOption) (grpc.ClientStream, error) {
di.mu.Lock()
defer di.mu.Unlock()
if di.ClientConn == nil {
return nil, errors.New("no dialer.ClientConn found")
}
Expand Down

0 comments on commit ae3ded9

Please sign in to comment.