From b66e1bf47abb91cdf494b5eef05fda13c6d36b71 Mon Sep 17 00:00:00 2001 From: Nikita Skrynnik <93182827+NikitaSkrynnik@users.noreply.github.com> Date: Fri, 27 Sep 2024 17:37:34 +0700 Subject: [PATCH] Add more mutexes in dial chain element to fix race conditions (#1670) * some minor change Signed-off-by: NikitaSkrynnik * add more locks Signed-off-by: NikitaSkrynnik --------- Signed-off-by: NikitaSkrynnik --- pkg/networkservice/common/dial/dialer.go | 25 +++++++++++++------ .../common/discoverforwarder/server.go | 2 +- 2 files changed, 19 insertions(+), 8 deletions(-) diff --git a/pkg/networkservice/common/dial/dialer.go b/pkg/networkservice/common/dial/dialer.go index 2b285f9d9..2d9e769b7 100644 --- a/pkg/networkservice/common/dial/dialer.go +++ b/pkg/networkservice/common/dial/dialer.go @@ -31,10 +31,9 @@ import ( ) type dialer struct { - ctx context.Context - cleanupContext context.Context - clientURL *url.URL - cleanupCancel context.CancelFunc + ctx context.Context + clientURL *url.URL + cleanupCancel context.CancelFunc *grpc.ClientConn dialOptions []grpc.DialOption dialTimeout time.Duration @@ -70,7 +69,10 @@ func (di *dialer) Dial(ctx context.Context, clientURL *url.URL) error { } // Dial + di.mu.Lock() target := grpcutils.URLToTarget(di.clientURL) + di.mu.Unlock() + cc, err := grpc.DialContext(dialCtx, target, di.dialOptions...) if err != nil { if cc != nil { @@ -78,26 +80,32 @@ func (di *dialer) Dial(ctx context.Context, clientURL *url.URL) error { } return errors.Wrapf(err, "failed to dial %s", target) } + di.mu.Lock() di.ClientConn = cc - - di.cleanupContext, di.cleanupCancel = context.WithCancel(di.ctx) + var cleanupContext context.Context + cleanupContext, di.cleanupCancel = context.WithCancel(di.ctx) + di.mu.Unlock() go func(cleanupContext context.Context, cc *grpc.ClientConn) { <-cleanupContext.Done() _ = cc.Close() - }(di.cleanupContext, cc) + }(cleanupContext, cc) return nil } func (di *dialer) Close() error { if di != nil && di.cleanupCancel != nil { + di.mu.Lock() di.cleanupCancel() + di.mu.Unlock() runtime.Gosched() } return nil } func (di *dialer) Invoke(ctx context.Context, method string, args, reply interface{}, opts ...grpc.CallOption) error { + di.mu.Lock() + defer di.mu.Unlock() if di.ClientConn == nil { return errors.New("no dialer.ClientConn found") } @@ -105,6 +113,9 @@ func (di *dialer) Invoke(ctx context.Context, method string, args, reply interfa } func (di *dialer) NewStream(ctx context.Context, desc *grpc.StreamDesc, method string, opts ...grpc.CallOption) (grpc.ClientStream, error) { + di.mu.Lock() + defer di.mu.Unlock() + if di.ClientConn == nil { return nil, errors.New("no dialer.ClientConn found") } diff --git a/pkg/networkservice/common/discoverforwarder/server.go b/pkg/networkservice/common/discoverforwarder/server.go index 0c583956b..285a55f5d 100644 --- a/pkg/networkservice/common/discoverforwarder/server.go +++ b/pkg/networkservice/common/discoverforwarder/server.go @@ -1,6 +1,6 @@ // Copyright (c) 2021-2022 Doc.ai and/or its affiliates. // -// Copyright (c) 2023 Cisco and/or its affiliates. +// Copyright (c) 2023-2024 Cisco and/or its affiliates. // // SPDX-License-Identifier: Apache-2.0 //