From ae3ded99195799967aa8b8da395d11da730dab48 Mon Sep 17 00:00:00 2001 From: NikitaSkrynnik Date: Tue, 24 Dec 2024 12:48:57 +1100 Subject: [PATCH] use a proper context for test + fix data races in dialer Signed-off-by: NikitaSkrynnik --- pkg/networkservice/chains/nsmgr/vl3_test.go | 12 ++++-------- pkg/networkservice/common/dial/dialer.go | 9 +++++++++ 2 files changed, 13 insertions(+), 8 deletions(-) diff --git a/pkg/networkservice/chains/nsmgr/vl3_test.go b/pkg/networkservice/chains/nsmgr/vl3_test.go index c28280f7c..261d16465 100644 --- a/pkg/networkservice/chains/nsmgr/vl3_test.go +++ b/pkg/networkservice/chains/nsmgr/vl3_test.go @@ -100,15 +100,11 @@ func Test_NSC_ConnectsTo_vl3NSE(t *testing.T) { for i := 0; i < 10; i++ { nsc := domain.Nodes[0].NewClient(ctx, sandbox.GenerateTestToken) - reqCtx, reqClose := context.WithTimeout(ctx, time.Second*1) - defer reqClose() - req := defaultRequest(nsReg.Name) req.Connection.Id = uuid.New().String() - req.Connection.Labels["podName"] = nscName + fmt.Sprint(i) - resp, err := nsc.Request(reqCtx, req) + resp, err := nsc.Request(ctx, req) require.NoError(t, err) req.Connection = resp.Clone() @@ -117,15 +113,15 @@ func Test_NSC_ConnectsTo_vl3NSE(t *testing.T) { requireIPv4Lookup(ctx, t, &resolver, nscName+fmt.Sprint(i)+".vl3", "10.0.0.1") - resp, err = nsc.Request(reqCtx, req) + resp, err = nsc.Request(ctx, req) require.NoError(t, err) requireIPv4Lookup(ctx, t, &resolver, nscName+fmt.Sprint(i)+".vl3", "10.0.0.1") - _, err = nsc.Close(reqCtx, resp) + _, err = nsc.Close(ctx, resp) require.NoError(t, err) - _, err = resolver.LookupIP(reqCtx, "ip4", nscName+fmt.Sprint(i)+".vl3") + _, err = resolver.LookupIP(ctx, "ip4", nscName+fmt.Sprint(i)+".vl3") require.Error(t, err) } } diff --git a/pkg/networkservice/common/dial/dialer.go b/pkg/networkservice/common/dial/dialer.go index b0abe5d14..119e125bb 100644 --- a/pkg/networkservice/common/dial/dialer.go +++ b/pkg/networkservice/common/dial/dialer.go @@ -20,6 +20,7 @@ import ( "context" "net/url" "runtime" + "sync" "time" "github.com/pkg/errors" @@ -37,6 +38,8 @@ type dialer struct { *grpc.ClientConn dialOptions []grpc.DialOption dialTimeout time.Duration + + mu sync.Mutex } func newDialer(ctx context.Context, dialTimeout time.Duration, dialOptions ...grpc.DialOption) *dialer { @@ -74,7 +77,9 @@ func (di *dialer) Dial(ctx context.Context, clientURL *url.URL) error { } return errors.Wrapf(err, "failed to dial %s", target) } + di.mu.Lock() di.ClientConn = cc + di.mu.Unlock() di.cleanupContext, di.cleanupCancel = context.WithCancel(di.ctx) @@ -94,6 +99,8 @@ func (di *dialer) Close() error { } func (di *dialer) Invoke(ctx context.Context, method string, args, reply interface{}, opts ...grpc.CallOption) error { + di.mu.Lock() + defer di.mu.Unlock() if di.ClientConn == nil { return errors.New("no dialer.ClientConn found") } @@ -101,6 +108,8 @@ func (di *dialer) Invoke(ctx context.Context, method string, args, reply interfa } func (di *dialer) NewStream(ctx context.Context, desc *grpc.StreamDesc, method string, opts ...grpc.CallOption) (grpc.ClientStream, error) { + di.mu.Lock() + defer di.mu.Unlock() if di.ClientConn == nil { return nil, errors.New("no dialer.ClientConn found") }