diff --git a/.golangci.yml b/.golangci.yml index d9b2bacd1..3478094af 100644 --- a/.golangci.yml +++ b/.golangci.yml @@ -35,7 +35,7 @@ linters-settings: dupl: threshold: 150 funlen: - Lines: 100 + Lines: 110 Statements: 50 goconst: min-len: 2 @@ -219,10 +219,6 @@ issues: linters: - interfacer text: "can be `fmt.Stringer`" - - path: pkg/networkservice/chains/nsmgr/peertracker/server.go - linters: - - interfacer - text: "can be `fmt.Stringer`" - path: pkg/networkservice/core/trace/client.go linters: - dupl diff --git a/pkg/networkservice/chains/nsmgr/heal_test.go b/pkg/networkservice/chains/nsmgr/heal_test.go index f8eaf6466..4c4953f21 100644 --- a/pkg/networkservice/chains/nsmgr/heal_test.go +++ b/pkg/networkservice/chains/nsmgr/heal_test.go @@ -1,4 +1,4 @@ -// Copyright (c) 2020-2021 Doc.ai and/or its affiliates. +// Copyright (c) 2020-2022 Doc.ai and/or its affiliates. // // SPDX-License-Identifier: Apache-2.0 // @@ -21,12 +21,14 @@ import ( "testing" "time" + "github.com/sirupsen/logrus" "github.com/stretchr/testify/require" "go.uber.org/goleak" "github.com/networkservicemesh/api/pkg/api/registry" "github.com/networkservicemesh/sdk/pkg/networkservice/utils/count" + "github.com/networkservicemesh/sdk/pkg/registry/chains/client" "github.com/networkservicemesh/sdk/pkg/tools/sandbox" ) @@ -71,7 +73,7 @@ func testNSMGRHealEndpoint(t *testing.T, nodeNum int) { nsRegistryClient := domain.NewNSRegistryClient(ctx, sandbox.GenerateTestToken) - nsReg, err := nsRegistryClient.Register(ctx, defaultRegistryService()) + nsReg, err := nsRegistryClient.Register(ctx, defaultRegistryService(t.Name())) require.NoError(t, err) nseReg := defaultRegistryEndpoint(nsReg.Name) @@ -148,7 +150,7 @@ func testNSMGRHealForwarder(t *testing.T, nodeNum int) { nsRegistryClient := domain.NewNSRegistryClient(ctx, sandbox.GenerateTestToken) - nsReg, err := nsRegistryClient.Register(ctx, defaultRegistryService()) + nsReg, err := nsRegistryClient.Register(ctx, defaultRegistryService(t.Name())) require.NoError(t, err) counter := new(count.Server) @@ -230,7 +232,7 @@ func testNSMGRHealNSMgr(t *testing.T, nodeNum int, restored bool) { nsRegistryClient := domain.NewNSRegistryClient(ctx, sandbox.GenerateTestToken) - nsReg, err := nsRegistryClient.Register(ctx, defaultRegistryService()) + nsReg, err := nsRegistryClient.Register(ctx, defaultRegistryService(t.Name())) require.NoError(t, err) nseReg := defaultRegistryEndpoint(nsReg.Name) @@ -300,7 +302,7 @@ func TestNSMGR_HealRegistry(t *testing.T) { nsRegistryClient := domain.NewNSRegistryClient(ctx, sandbox.GenerateTestToken) - nsReg, err := nsRegistryClient.Register(ctx, defaultRegistryService()) + nsReg, err := nsRegistryClient.Register(ctx, defaultRegistryService(t.Name())) require.NoError(t, err) nseReg := defaultRegistryEndpoint(nsReg.Name) @@ -366,14 +368,14 @@ func testNSMGRCloseHeal(t *testing.T, withNSEExpiration bool) { SetRegistryProxySupplier(nil) if withNSEExpiration { - builder = builder.SetRegistryExpiryDuration(sandbox.RegistryExpiryDuration) + builder = builder.SetRegistryExpiryDuration(time.Second / 2) } domain := builder.Build() nsRegistryClient := domain.NewNSRegistryClient(ctx, sandbox.GenerateTestToken) - nsReg, err := nsRegistryClient.Register(ctx, defaultRegistryService()) + nsReg, err := nsRegistryClient.Register(ctx, defaultRegistryService(t.Name())) require.NoError(t, err) nseCtx, nseCtxCancel := context.WithCancel(ctx) @@ -407,7 +409,18 @@ func testNSMGRCloseHeal(t *testing.T, withNSEExpiration bool) { if withNSEExpiration { // 3.1 Wait for the endpoint expiration - time.Sleep(sandbox.RegistryExpiryDuration) + time.Sleep(time.Second) + c := client.NewNetworkServiceEndpointRegistryClient(ctx, domain.Nodes[0].NSMgr.URL, client.WithDialOptions(sandbox.DialOptions(sandbox.WithTokenGenerator(sandbox.GenerateTestToken))...)) + + stream, err := c.Find(ctx, ®istry.NetworkServiceEndpointQuery{ + NetworkServiceEndpoint: ®istry.NetworkServiceEndpoint{ + Name: "final-endpoint", + }, + }) + + require.NoError(t, err) + + require.Len(t, registry.ReadNetworkServiceEndpointList(stream), 0) } // 4. Close connection @@ -415,7 +428,12 @@ func testNSMGRCloseHeal(t *testing.T, withNSEExpiration bool) { nscCtxCancel() + for _, fwd := range domain.Nodes[0].Forwarders { + fwd.Cancel() + } + require.Eventually(t, func() bool { + logrus.Error(goleak.Find()) return goleak.Find(ignoreCurrent) == nil }, timeout, tick) diff --git a/pkg/networkservice/chains/nsmgr/peertracker/server.go b/pkg/networkservice/chains/nsmgr/peertracker/server.go deleted file mode 100644 index 3dc6e94ff..000000000 --- a/pkg/networkservice/chains/nsmgr/peertracker/server.go +++ /dev/null @@ -1,112 +0,0 @@ -// Copyright (c) 2020 Cisco and/or its affiliates. -// -// SPDX-License-Identifier: Apache-2.0 -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at: -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Package peertracker provides a wrapper for a Nsmgr that tracks connections received from local Clients -// Its designed to be used in a DevicePlugin to allow us to properly Close connections on re-Allocate -package peertracker - -import ( - "context" - "net/url" - - "github.com/golang/protobuf/ptypes/empty" - "github.com/networkservicemesh/api/pkg/api/networkservice" - "google.golang.org/grpc/peer" - - "github.com/edwarnicke/serialize" - - "github.com/networkservicemesh/sdk/pkg/networkservice/chains/nsmgr" -) - -const ( - unixScheme = "unix" -) - -type peerTrackerServer struct { - nsmgr.Nsmgr - executor serialize.Executor - // Outer map is peer url.URL.String(), inner map key is Connection.Id - connections map[string]map[string]*networkservice.Connection -} - -// NewServer - Creates a new peer tracker Server -// inner - Nsmgr being wrapped -// closeAll - pointer to memory location to which you should write a pointer to the function to be called -// to close all connections for the provided url (presuming a unix URL) -func NewServer(inner nsmgr.Nsmgr, closeAll *func(ctx context.Context, u *url.URL)) nsmgr.Nsmgr { - rv := &peerTrackerServer{ - connections: make(map[string]map[string]*networkservice.Connection), - Nsmgr: inner, - } - *closeAll = rv.closeAllConnectionsForPeer - return rv -} - -func (p *peerTrackerServer) Request(ctx context.Context, request *networkservice.NetworkServiceRequest) (*networkservice.Connection, error) { - conn, err := p.Nsmgr.Request(ctx, request) - if err != nil { - return nil, err - } - mypeer, ok := peer.FromContext(ctx) - if ok { - if mypeer.Addr.Network() == unixScheme { - u := &url.URL{ - Scheme: mypeer.Addr.Network(), - Path: mypeer.Addr.String(), - } - p.executor.AsyncExec(func() { - _, ok := p.connections[u.String()] - if !ok { - p.connections[u.String()] = make(map[string]*networkservice.Connection) - } - p.connections[u.String()][conn.GetId()] = conn - }) - } - } - return conn, nil -} - -func (p *peerTrackerServer) Close(ctx context.Context, conn *networkservice.Connection) (*empty.Empty, error) { - _, err := p.Nsmgr.Close(ctx, conn) - if err != nil { - return nil, err - } - mypeer, ok := peer.FromContext(ctx) - if ok { - if mypeer.Addr.Network() == unixScheme { - u := &url.URL{ - Scheme: mypeer.Addr.Network(), - Path: mypeer.Addr.String(), - } - p.executor.AsyncExec(func() { - delete(p.connections[u.String()], conn.GetId()) - }) - } - } - return &empty.Empty{}, nil -} - -func (p *peerTrackerServer) closeAllConnectionsForPeer(ctx context.Context, u *url.URL) { - finishedChan := make(chan struct{}) - <-p.executor.AsyncExec(func() { - if connMap, ok := p.connections[u.String()]; ok { - for _, conn := range connMap { - _, _ = p.Close(ctx, conn) - } - } - close(finishedChan) - }) -} diff --git a/pkg/networkservice/chains/nsmgr/server.go b/pkg/networkservice/chains/nsmgr/server.go index 6f63c8458..27557f328 100644 --- a/pkg/networkservice/chains/nsmgr/server.go +++ b/pkg/networkservice/chains/nsmgr/server.go @@ -43,18 +43,21 @@ import ( "github.com/networkservicemesh/sdk/pkg/networkservice/common/metrics" "github.com/networkservicemesh/sdk/pkg/networkservice/core/adapters" "github.com/networkservicemesh/sdk/pkg/registry" + "github.com/networkservicemesh/sdk/pkg/registry/common/begin" + "github.com/networkservicemesh/sdk/pkg/registry/common/clientconn" registryclientinfo "github.com/networkservicemesh/sdk/pkg/registry/common/clientinfo" "github.com/networkservicemesh/sdk/pkg/registry/common/clienturl" registryconnect "github.com/networkservicemesh/sdk/pkg/registry/common/connect" + "github.com/networkservicemesh/sdk/pkg/registry/common/dial" "github.com/networkservicemesh/sdk/pkg/registry/common/expire" "github.com/networkservicemesh/sdk/pkg/registry/common/localbypass" "github.com/networkservicemesh/sdk/pkg/registry/common/memory" registryrecvfd "github.com/networkservicemesh/sdk/pkg/registry/common/recvfd" registrysendfd "github.com/networkservicemesh/sdk/pkg/registry/common/sendfd" + "github.com/networkservicemesh/sdk/pkg/registry/switchcase" - registryserialize "github.com/networkservicemesh/sdk/pkg/registry/common/serialize" registryadapter "github.com/networkservicemesh/sdk/pkg/registry/core/adapters" - registrychain "github.com/networkservicemesh/sdk/pkg/registry/core/chain" + "github.com/networkservicemesh/sdk/pkg/registry/core/chain" "github.com/networkservicemesh/sdk/pkg/tools/grpcutils" "github.com/networkservicemesh/sdk/pkg/tools/token" ) @@ -160,38 +163,62 @@ func NewServer(ctx context.Context, tokenGenerator token.GeneratorFunc, options var nsRegistry = memory.NewNetworkServiceRegistryServer() if opts.regURL != nil { // Use remote registry - nsRegistry = registrychain.NewNetworkServiceRegistryServer( - clienturl.NewNetworkServiceRegistryServer(opts.regURL), - registryconnect.NewNetworkServiceRegistryServer(ctx, registryconnect.WithDialOptions(opts.regDialOptions...)), + nsRegistry = registryconnect.NewNetworkServiceRegistryServer( + chain.NewNetworkServiceRegistryClient( + clienturl.NewNetworkServiceRegistryClient(opts.regURL), + begin.NewNetworkServiceRegistryClient(), + clientconn.NewNetworkServiceRegistryClient(), + dial.NewNetworkServiceRegistryClient(ctx, + dial.WithDialOptions(opts.dialOptions...), + ), + registryconnect.NewNetworkServiceRegistryClient(), + ), ) } - nsRegistry = registrychain.NewNetworkServiceRegistryServer( - registryserialize.NewNetworkServiceRegistryServer(), + nsRegistry = chain.NewNetworkServiceRegistryServer( nsRegistry, ) - var nseInMemoryRegistry = memory.NewNetworkServiceEndpointRegistryServer() - - var nseRegistry = registrychain.NewNetworkServiceEndpointRegistryServer( + var nseRegistry = chain.NewNetworkServiceEndpointRegistryServer( registryclientinfo.NewNetworkServiceEndpointRegistryServer(), - registryserialize.NewNetworkServiceEndpointRegistryServer(), + begin.NewNetworkServiceEndpointRegistryServer(), expire.NewNetworkServiceEndpointRegistryServer(ctx, time.Minute), registryrecvfd.NewNetworkServiceEndpointRegistryServer(), // Allow to receive a passed files - registrysendfd.NewNetworkServiceEndpointRegistryServer(), - nseInMemoryRegistry, + switchcase.NewNetworkServiceEndpointRegistryServer( + switchcase.NSEServerCase{ + Condition: func(c context.Context, nse *registryapi.NetworkServiceEndpoint) bool { + return opts.regURL != nil + }, + Action: registrysendfd.NewNetworkServiceEndpointRegistryServer(), + }, + ), localbypass.NewNetworkServiceEndpointRegistryServer(opts.url), + switchcase.NewNetworkServiceEndpointRegistryServer( + switchcase.NSEServerCase{ + Condition: func(c context.Context, nse *registryapi.NetworkServiceEndpoint) bool { + return opts.regURL == nil + }, + Action: memory.NewNetworkServiceEndpointRegistryServer(), + }, + switchcase.NSEServerCase{ + Condition: func(c context.Context, nse *registryapi.NetworkServiceEndpoint) bool { + return opts.regURL != nil + }, + Action: registryconnect.NewNetworkServiceEndpointRegistryServer( + chain.NewNetworkServiceEndpointRegistryClient( + begin.NewNetworkServiceEndpointRegistryClient(), + clienturl.NewNetworkServiceEndpointRegistryClient(opts.regURL), + clientconn.NewNetworkServiceEndpointRegistryClient(), + dial.NewNetworkServiceEndpointRegistryClient(ctx, + dial.WithDialOptions(opts.dialOptions...), + ), + registryconnect.NewNetworkServiceEndpointRegistryClient(), + ), + ), + }), ) - if opts.regURL != nil { - // Add remote registry - nseRegistry = registrychain.NewNetworkServiceEndpointRegistryServer( - nseRegistry, - clienturl.NewNetworkServiceEndpointRegistryServer(opts.regURL), - registryconnect.NewNetworkServiceEndpointRegistryServer(ctx, registryconnect.WithDialOptions(opts.regDialOptions...)), - ) - } - // Construct Endpoint rv.Endpoint = endpoint.NewServer(ctx, tokenGenerator, endpoint.WithName(opts.name), @@ -200,8 +227,9 @@ func NewServer(ctx context.Context, tokenGenerator token.GeneratorFunc, options adapters.NewClientToServer(clientinfo.NewClient()), discoverforwarder.NewServer( registryadapter.NetworkServiceServerToClient(nsRegistry), - registryadapter.NetworkServiceEndpointServerToClient(nseInMemoryRegistry), + registryadapter.NetworkServiceEndpointServerToClient(nseRegistry), discoverforwarder.WithForwarderServiceName(opts.forwarderServiceName), + discoverforwarder.WithNSMgrURL(opts.url), ), excludedprefixes.NewServer(ctx), recvfd.NewServer(), // Receive any files passed diff --git a/pkg/networkservice/chains/nsmgr/single_test.go b/pkg/networkservice/chains/nsmgr/single_test.go index 127cab4f4..cd3392562 100644 --- a/pkg/networkservice/chains/nsmgr/single_test.go +++ b/pkg/networkservice/chains/nsmgr/single_test.go @@ -1,4 +1,4 @@ -// Copyright (c) 2020-2021 Doc.ai and/or its affiliates. +// Copyright (c) 2020-2022 Doc.ai and/or its affiliates. // // SPDX-License-Identifier: Apache-2.0 // @@ -50,7 +50,7 @@ func Test_DNSUsecase(t *testing.T) { nsRegistryClient := domain.NewNSRegistryClient(ctx, sandbox.GenerateTestToken) - nsReg, err := nsRegistryClient.Register(ctx, defaultRegistryService()) + nsReg, err := nsRegistryClient.Register(ctx, defaultRegistryService(t.Name())) require.NoError(t, err) nseReg := defaultRegistryEndpoint(nsReg.Name) @@ -126,7 +126,7 @@ func Test_ShouldParseNetworkServiceLabelsTemplate(t *testing.T) { nsRegistryClient := domain.NewNSRegistryClient(ctx, sandbox.GenerateTestToken) - nsReg := defaultRegistryService() + nsReg := defaultRegistryService(t.Name()) nsReg.Matches = []*registry.Match{ { Routes: []*registry.Destination{ @@ -176,7 +176,7 @@ func Test_UsecasePoint2MultiPoint(t *testing.T) { SetNodeSetup(func(ctx context.Context, node *sandbox.Node, _ int) { node.NewNSMgr(ctx, "nsmgr", nil, sandbox.GenerateTestToken, nsmgr.NewServer) }). - SetRegistryExpiryDuration(sandbox.RegistryExpiryDuration). + SetRegistryExpiryDuration(time.Second). Build() domain.Nodes[0].NewForwarder(ctx, ®istry.NetworkServiceEndpoint{ @@ -297,7 +297,7 @@ func Test_RemoteUsecase_Point2MultiPoint(t *testing.T) { SetNodeSetup(func(ctx context.Context, node *sandbox.Node, _ int) { node.NewNSMgr(ctx, "nsmgr", nil, sandbox.GenerateTestToken, nsmgr.NewServer) }). - SetRegistryExpiryDuration(sandbox.RegistryExpiryDuration). + SetRegistryExpiryDuration(time.Second). Build() for i := 0; i < nodeCount; i++ { diff --git a/pkg/networkservice/chains/nsmgr/suite_test.go b/pkg/networkservice/chains/nsmgr/suite_test.go index 36b8189c4..5c1489f39 100644 --- a/pkg/networkservice/chains/nsmgr/suite_test.go +++ b/pkg/networkservice/chains/nsmgr/suite_test.go @@ -1,4 +1,4 @@ -// Copyright (c) 2020-2021 Doc.ai and/or its affiliates. +// Copyright (c) 2020-2022 Doc.ai and/or its affiliates. // // SPDX-License-Identifier: Apache-2.0 // @@ -92,7 +92,7 @@ func (s *nsmgrSuite) Test_Remote_ParallelUsecase() { ctx, cancel := context.WithTimeout(context.Background(), time.Second*5) defer cancel() - nsReg, err := s.nsRegistryClient.Register(ctx, defaultRegistryService()) + nsReg, err := s.nsRegistryClient.Register(ctx, defaultRegistryService(t.Name())) require.NoError(t, err) nseReg := defaultRegistryEndpoint(nsReg.Name) @@ -144,7 +144,7 @@ func (s *nsmgrSuite) Test_SelectsRestartingEndpointUsecase() { ctx, cancel := context.WithTimeout(context.Background(), time.Second*5) defer cancel() - nsReg, err := s.nsRegistryClient.Register(ctx, defaultRegistryService()) + nsReg, err := s.nsRegistryClient.Register(ctx, defaultRegistryService(t.Name())) require.NoError(t, err) nseReg := defaultRegistryEndpoint(nsReg.Name) @@ -199,15 +199,17 @@ func (s *nsmgrSuite) Test_Remote_BusyEndpointsUsecase() { ctx, cancel := context.WithTimeout(context.Background(), time.Second*5) defer cancel() - nsReg, err := s.nsRegistryClient.Register(ctx, defaultRegistryService()) + nsReg, err := s.nsRegistryClient.Register(ctx, defaultRegistryService(t.Name())) require.NoError(t, err) counter := new(count.Server) + const nseCount = 3 + var wg sync.WaitGroup - var nseRegs [4]*registry.NetworkServiceEndpoint - var nses [4]*sandbox.EndpointEntry - for i := 0; i < 3; i++ { + var nseRegs [nseCount + 1]*registry.NetworkServiceEndpoint + var nses [nseCount + 1]*sandbox.EndpointEntry + for i := 0; i < nseCount; i++ { wg.Add(1) go func(id int) { nseRegs[id] = defaultRegistryEndpoint(nsReg.Name) @@ -225,10 +227,10 @@ func (s *nsmgrSuite) Test_Remote_BusyEndpointsUsecase() { wg.Wait() time.Sleep(time.Second / 2) - nseRegs[3] = defaultRegistryEndpoint(nsReg.Name) - nseRegs[3].Name += strconv.Itoa(3) + nseRegs[nseCount] = defaultRegistryEndpoint(nsReg.Name) + nseRegs[nseCount].Name += strconv.Itoa(3) - nses[3] = s.domain.Nodes[1].NewEndpoint(ctx, nseRegs[3], sandbox.GenerateTestToken, counter) + nses[nseCount] = s.domain.Nodes[1].NewEndpoint(ctx, nseRegs[nseCount], sandbox.GenerateTestToken, counter) }() nsc := s.domain.Nodes[0].NewClient(ctx, sandbox.GenerateTestToken) @@ -268,7 +270,7 @@ func (s *nsmgrSuite) Test_RemoteUsecase() { ctx, cancel := context.WithTimeout(context.Background(), time.Second*5) defer cancel() - nsReg, err := s.nsRegistryClient.Register(ctx, defaultRegistryService()) + nsReg, err := s.nsRegistryClient.Register(ctx, defaultRegistryService(t.Name())) require.NoError(t, err) nseReg := defaultRegistryEndpoint(nsReg.Name) @@ -312,7 +314,7 @@ func (s *nsmgrSuite) Test_ConnectToDeadNSEUsecase() { ctx, cancel := context.WithTimeout(context.Background(), time.Second*5) defer cancel() - nsReg, err := s.nsRegistryClient.Register(ctx, defaultRegistryService()) + nsReg, err := s.nsRegistryClient.Register(ctx, defaultRegistryService(t.Name())) require.NoError(t, err) nseReg := defaultRegistryEndpoint(nsReg.Name) @@ -361,7 +363,7 @@ func (s *nsmgrSuite) Test_LocalUsecase() { ctx, cancel := context.WithTimeout(context.Background(), time.Second*5) defer cancel() - nsReg, err := s.nsRegistryClient.Register(ctx, defaultRegistryService()) + nsReg, err := s.nsRegistryClient.Register(ctx, defaultRegistryService(t.Name())) require.NoError(t, err) nseReg := defaultRegistryEndpoint(nsReg.Name) @@ -607,6 +609,23 @@ func (s *nsmgrSuite) Test_PassThroughSameSourceSelector() { } } +func (s *nsmgrSuite) Test_ShouldCleanAllClientAndEndpointGoroutines() { + t := s.T() + t.Cleanup(func() { goleak.VerifyNone(t, goleak.IgnoreCurrent()) }) + + ctx, cancel := context.WithTimeout(context.Background(), time.Second*5) + defer cancel() + + nsReg, err := s.nsRegistryClient.Register(ctx, defaultRegistryService(t.Name())) + require.NoError(t, err) + + // At this moment all possible endless NSMgr goroutines have been started. So we expect all newly created goroutines + // to be canceled no later than some of these events: + // 1. GRPC request context cancel + // 2. NSC connection close + // 3. NSE unregister + testNSEAndClient(ctx, t, s.domain, defaultRegistryEndpoint(nsReg.Name)) +} func (s *nsmgrSuite) Test_PassThroughLocalUsecaseMultiLabel() { t := s.T() @@ -672,24 +691,6 @@ func (s *nsmgrSuite) Test_PassThroughLocalUsecaseMultiLabel() { } } -func (s *nsmgrSuite) Test_ShouldCleanAllClientAndEndpointGoroutines() { - t := s.T() - t.Cleanup(func() { goleak.VerifyNone(t, goleak.IgnoreCurrent()) }) - - ctx, cancel := context.WithTimeout(context.Background(), time.Second*5) - defer cancel() - - nsReg, err := s.nsRegistryClient.Register(ctx, defaultRegistryService()) - require.NoError(t, err) - - // At this moment all possible endless NSMgr goroutines have been started. So we expect all newly created goroutines - // to be canceled no later than some of these events: - // 1. GRPC request context cancel - // 2. NSC connection close - // 3. NSE unregister - testNSEAndClient(ctx, t, s.domain, defaultRegistryEndpoint(nsReg.Name)) -} - const ( step = "step" labelA = "label_a" diff --git a/pkg/networkservice/chains/nsmgr/unix_test.go b/pkg/networkservice/chains/nsmgr/unix_test.go index e3c1c4cc6..95087f195 100644 --- a/pkg/networkservice/chains/nsmgr/unix_test.go +++ b/pkg/networkservice/chains/nsmgr/unix_test.go @@ -50,7 +50,7 @@ func Test_Local_NoURLUsecase(t *testing.T) { nsRegistryClient := domain.NewNSRegistryClient(ctx, sandbox.GenerateTestToken) - nsReg, err := nsRegistryClient.Register(ctx, defaultRegistryService()) + nsReg, err := nsRegistryClient.Register(ctx, defaultRegistryService(t.Name())) require.NoError(t, err) nseReg := defaultRegistryEndpoint(nsReg.Name) @@ -124,7 +124,7 @@ func Test_MultiForwarderSendfd(t *testing.T) { nsRegistryClient := domain.NewNSRegistryClient(ctx, sandbox.GenerateTestToken) - nsReg, err := nsRegistryClient.Register(ctx, defaultRegistryService()) + nsReg, err := nsRegistryClient.Register(ctx, defaultRegistryService(t.Name())) require.NoError(t, err) nseReg := defaultRegistryEndpoint(nsReg.Name) diff --git a/pkg/networkservice/chains/nsmgr/utils_test.go b/pkg/networkservice/chains/nsmgr/utils_test.go index 1055643e0..b267dd516 100644 --- a/pkg/networkservice/chains/nsmgr/utils_test.go +++ b/pkg/networkservice/chains/nsmgr/utils_test.go @@ -1,4 +1,4 @@ -// Copyright (c) 2020-2021 Doc.ai and/or its affiliates. +// Copyright (c) 2020-2022 Doc.ai and/or its affiliates. // // SPDX-License-Identifier: Apache-2.0 // @@ -20,7 +20,6 @@ import ( "context" "testing" - "github.com/google/uuid" "github.com/stretchr/testify/require" "github.com/networkservicemesh/api/pkg/api/networkservice" @@ -31,9 +30,9 @@ import ( "github.com/networkservicemesh/sdk/pkg/tools/sandbox" ) -func defaultRegistryService() *registry.NetworkService { +func defaultRegistryService(name string) *registry.NetworkService { return ®istry.NetworkService{ - Name: "ns-" + uuid.New().String(), + Name: name, } } diff --git a/pkg/networkservice/chains/nsmgrproxy/server.go b/pkg/networkservice/chains/nsmgrproxy/server.go index a30681e52..e4e007d7f 100644 --- a/pkg/networkservice/chains/nsmgrproxy/server.go +++ b/pkg/networkservice/chains/nsmgrproxy/server.go @@ -1,4 +1,4 @@ -// Copyright (c) 2020-2021 Doc.ai and/or its affiliates. +// Copyright (c) 2020-2022 Doc.ai and/or its affiliates. // // SPDX-License-Identifier: Apache-2.0 // @@ -35,15 +35,15 @@ import ( "github.com/networkservicemesh/sdk/pkg/networkservice/common/authorize" "github.com/networkservicemesh/sdk/pkg/networkservice/common/connect" "github.com/networkservicemesh/sdk/pkg/networkservice/common/discover" - "github.com/networkservicemesh/sdk/pkg/networkservice/common/interdomainurl" + "github.com/networkservicemesh/sdk/pkg/networkservice/common/interdomainbypass" "github.com/networkservicemesh/sdk/pkg/networkservice/common/swapip" "github.com/networkservicemesh/sdk/pkg/registry" + "github.com/networkservicemesh/sdk/pkg/registry/common/begin" + "github.com/networkservicemesh/sdk/pkg/registry/common/clientconn" "github.com/networkservicemesh/sdk/pkg/registry/common/clienturl" registryconnect "github.com/networkservicemesh/sdk/pkg/registry/common/connect" - "github.com/networkservicemesh/sdk/pkg/registry/common/proxy" - "github.com/networkservicemesh/sdk/pkg/registry/common/seturl" + "github.com/networkservicemesh/sdk/pkg/registry/common/dial" registryswapip "github.com/networkservicemesh/sdk/pkg/registry/common/swapip" - registryadapter "github.com/networkservicemesh/sdk/pkg/registry/core/adapters" "github.com/networkservicemesh/sdk/pkg/registry/core/chain" "github.com/networkservicemesh/sdk/pkg/tools/fs" "github.com/networkservicemesh/sdk/pkg/tools/grpcutils" @@ -65,13 +65,12 @@ type nsmgrProxyServer struct { } type serverOptions struct { - name string - mapipFilePath string - listenOn *url.URL - authorizeServer networkservice.NetworkServiceServer - dialOptions []grpc.DialOption - dialTimeout time.Duration - registryConnectOptions []registryconnect.Option + name string + mapipFilePath string + listenOn *url.URL + authorizeServer networkservice.NetworkServiceServer + dialOptions []grpc.DialOption + dialTimeout time.Duration } func (s *serverOptions) openMapIPChannel(ctx context.Context) <-chan map[string]string { @@ -116,13 +115,6 @@ func WithAuthorizeServer(authorizeServer networkservice.NetworkServiceServer) Op } } -// WithRegistryConnectOptions sets registry connect options -func WithRegistryConnectOptions(connectOptions ...registryconnect.Option) Option { - return func(o *serverOptions) { - o.registryConnectOptions = connectOptions - } -} - // WithListenOn sets current listenOn url func WithListenOn(u *url.URL) Option { return func(o *serverOptions) { @@ -165,27 +157,34 @@ func NewServer(ctx context.Context, regURL, proxyURL *url.URL, tokenGenerator to opt(opts) } - var nseStockServer registryapi.NetworkServiceEndpointRegistryServer + var interdomainBypassNSEServer registryapi.NetworkServiceEndpointRegistryServer - nseClient := registryadapter.NetworkServiceEndpointServerToClient( - chain.NewNetworkServiceEndpointRegistryServer( - clienturl.NewNetworkServiceEndpointRegistryServer(regURL), - registryconnect.NewNetworkServiceEndpointRegistryServer(ctx, opts.registryConnectOptions...), + nseClient := chain.NewNetworkServiceEndpointRegistryClient( + begin.NewNetworkServiceEndpointRegistryClient(), + clienturl.NewNetworkServiceEndpointRegistryClient(regURL), + clientconn.NewNetworkServiceEndpointRegistryClient(), + dial.NewNetworkServiceEndpointRegistryClient(ctx, + dial.WithDialOptions(opts.dialOptions...), + dial.WithDialTimeout(opts.dialTimeout), ), + registryconnect.NewNetworkServiceEndpointRegistryClient(), ) - nsClient := registryadapter.NetworkServiceServerToClient( - chain.NewNetworkServiceRegistryServer( - clienturl.NewNetworkServiceRegistryServer(regURL), - registryconnect.NewNetworkServiceRegistryServer(ctx, opts.registryConnectOptions...), + nsClient := chain.NewNetworkServiceRegistryClient( + begin.NewNetworkServiceRegistryClient(), + clienturl.NewNetworkServiceRegistryClient(regURL), + clientconn.NewNetworkServiceRegistryClient(), + dial.NewNetworkServiceRegistryClient(ctx, + dial.WithDialOptions(opts.dialOptions...), ), + registryconnect.NewNetworkServiceRegistryClient(), ) rv.Endpoint = endpoint.NewServer(ctx, tokenGenerator, endpoint.WithName(opts.name), endpoint.WithAuthorizeServer(opts.authorizeServer), endpoint.WithAdditionalFunctionality( - interdomainurl.NewServer(&nseStockServer), + interdomainbypass.NewServer(&interdomainBypassNSEServer, opts.listenOn), discover.NewServer(nsClient, nseClient), swapip.NewServer(opts.openMapIPChannel(ctx)), connect.NewServer( @@ -200,17 +199,33 @@ func NewServer(ctx context.Context, regURL, proxyURL *url.URL, tokenGenerator to ), ) - var nsServerChain = chain.NewNetworkServiceRegistryServer( - proxy.NewNetworkServiceRegistryServer(proxyURL), - registryconnect.NewNetworkServiceRegistryServer(ctx, opts.registryConnectOptions...), + var nsServerChain = registryconnect.NewNetworkServiceRegistryServer( + chain.NewNetworkServiceRegistryClient( + begin.NewNetworkServiceRegistryClient(), + clienturl.NewNetworkServiceRegistryClient(proxyURL), + clientconn.NewNetworkServiceRegistryClient(), + dial.NewNetworkServiceRegistryClient(ctx, + dial.WithDialOptions(opts.dialOptions...), + ), + registryconnect.NewNetworkServiceRegistryClient(), + ), ) var nseServerChain = chain.NewNetworkServiceEndpointRegistryServer( - proxy.NewNetworkServiceEndpointRegistryServer(proxyURL), - seturl.NewNetworkServiceEndpointRegistryServer(opts.listenOn), + begin.NewNetworkServiceEndpointRegistryServer(), + clienturl.NewNetworkServiceEndpointRegistryServer(proxyURL), + interdomainBypassNSEServer, registryswapip.NewNetworkServiceEndpointRegistryServer(opts.openMapIPChannel(ctx)), - nseStockServer, - registryconnect.NewNetworkServiceEndpointRegistryServer(ctx, opts.registryConnectOptions...), + registryconnect.NewNetworkServiceEndpointRegistryServer( + chain.NewNetworkServiceEndpointRegistryClient( + clientconn.NewNetworkServiceEndpointRegistryClient(), + dial.NewNetworkServiceEndpointRegistryClient(ctx, + dial.WithDialOptions(opts.dialOptions...), + dial.WithDialTimeout(opts.dialTimeout), + ), + registryconnect.NewNetworkServiceEndpointRegistryClient(), + ), + ), ) rv.Registry = registry.NewServer(nsServerChain, nseServerChain) diff --git a/pkg/networkservice/chains/nsmgrproxy/server_test.go b/pkg/networkservice/chains/nsmgrproxy/server_test.go index 6657b47d2..fe2397291 100644 --- a/pkg/networkservice/chains/nsmgrproxy/server_test.go +++ b/pkg/networkservice/chains/nsmgrproxy/server_test.go @@ -1,4 +1,4 @@ -// Copyright (c) 2020-2021 Doc.ai and/or its affiliates. +// Copyright (c) 2020-2022 Doc.ai and/or its affiliates. // // SPDX-License-Identifier: Apache-2.0 // @@ -38,6 +38,7 @@ import ( kernelmech "github.com/networkservicemesh/sdk/pkg/networkservice/common/mechanisms/kernel" "github.com/networkservicemesh/sdk/pkg/networkservice/core/chain" "github.com/networkservicemesh/sdk/pkg/networkservice/core/next" + "github.com/networkservicemesh/sdk/pkg/registry/core/adapters" "github.com/networkservicemesh/sdk/pkg/tools/sandbox" ) @@ -266,6 +267,18 @@ func TestNSMGR_FloatingInterdomainUseCase(t *testing.T) { cluster2.Nodes[0].NewEndpoint(ctx, nseReg, sandbox.GenerateTestToken) + c := adapters.NetworkServiceEndpointServerToClient(cluster2.Nodes[0].NSMgr.NetworkServiceEndpointRegistryServer()) + + s, err := c.Find(ctx, ®istry.NetworkServiceEndpointQuery{NetworkServiceEndpoint: ®istry.NetworkServiceEndpoint{ + Name: "final-endpoint@" + floating.Name, + }}) + + require.NoError(t, err) + + list := registry.ReadNetworkServiceEndpointList(s) + + require.Len(t, list, 1) + nsc := cluster1.Nodes[0].NewClient(ctx, sandbox.GenerateTestToken) request := &networkservice.NetworkServiceRequest{ diff --git a/pkg/networkservice/common/begin/client.go b/pkg/networkservice/common/begin/client.go index fc52238b7..48d8803c6 100644 --- a/pkg/networkservice/common/begin/client.go +++ b/pkg/networkservice/common/begin/client.go @@ -1,4 +1,4 @@ -// Copyright (c) 2021 Cisco and/or its affiliates. +// Copyright (c) 2021-2022 Cisco and/or its affiliates. // // SPDX-License-Identifier: Apache-2.0 // @@ -62,7 +62,7 @@ func (b *beginClient) Request(ctx context.Context, request *networkservice.Netwo currentEventFactoryClient, _ := b.LoadOrStore(request.GetConnection().GetId(), eventFactoryClient) if currentEventFactoryClient != eventFactoryClient { log.FromContext(ctx).Debug("recalling begin.Request because currentEventFactoryClient != eventFactoryClient") - conn, err = b.Request(ctx, request) + conn, err = b.Request(ctx, request, opts...) return } diff --git a/pkg/networkservice/common/discoverforwarder/option.go b/pkg/networkservice/common/discoverforwarder/option.go index ec8b8c438..c16a7f078 100644 --- a/pkg/networkservice/common/discoverforwarder/option.go +++ b/pkg/networkservice/common/discoverforwarder/option.go @@ -1,4 +1,4 @@ -// Copyright (c) 2021 Doc.ai and/or its affiliates. +// Copyright (c) 2021-2022 Doc.ai and/or its affiliates. // // SPDX-License-Identifier: Apache-2.0 // @@ -26,3 +26,10 @@ func WithForwarderServiceName(serviceName string) Option { d.forwarderServiceName = serviceName } } + +// WithNSMgrURL sets URL for NSE Find queriees +func WithNSMgrURL(nsmgrURL string) Option { + return func(d *discoverForwarderServer) { + d.nsmgrURL = nsmgrURL + } +} diff --git a/pkg/networkservice/common/discoverforwarder/server.go b/pkg/networkservice/common/discoverforwarder/server.go index 9a5d2ab14..4077a6b55 100644 --- a/pkg/networkservice/common/discoverforwarder/server.go +++ b/pkg/networkservice/common/discoverforwarder/server.go @@ -1,4 +1,4 @@ -// Copyright (c) 2021 Doc.ai and/or its affiliates. +// Copyright (c) 2021-2022 Doc.ai and/or its affiliates. // // SPDX-License-Identifier: Apache-2.0 // @@ -36,6 +36,7 @@ type discoverForwarderServer struct { nseClient registry.NetworkServiceEndpointRegistryClient nsClient registry.NetworkServiceRegistryClient forwarderServiceName string + nsmgrURL string } // NewServer creates new instance of discoverforwarder networkservice.NetworkServiceServer. @@ -77,6 +78,7 @@ func (d *discoverForwarderServer) Request(ctx context.Context, request *networks NetworkServiceNames: []string{ d.forwarderServiceName, }, + Url: d.nsmgrURL, }, }) @@ -91,9 +93,11 @@ func (d *discoverForwarderServer) Request(ctx context.Context, request *networks return nil, errors.New("no candidates found") } + var candidatesErr = errors.New("all forwarders have failed") + // TODO: Should we consider about load balancing? // https://github.com/networkservicemesh/sdk/issues/790 - for _, candidate := range nses { + for i, candidate := range nses { u, err := url.Parse(candidate.Url) if err != nil { @@ -108,13 +112,15 @@ func (d *discoverForwarderServer) Request(ctx context.Context, request *networks return resp, nil } logger.Errorf("forwarder=%v url=%v returned error=%v", candidate.Name, candidate.Url, err.Error()) + candidatesErr = errors.Wrapf(candidatesErr, "%v. An error during select forwawrder %v --> %v", i, candidate.Name, err.Error()) } - return nil, errors.New("all forwarders failed") + return nil, candidatesErr } stream, err := d.nseClient.Find(ctx, ®istry.NetworkServiceEndpointQuery{ NetworkServiceEndpoint: ®istry.NetworkServiceEndpoint{ Name: forwarderName, + Url: d.nsmgrURL, }, }) @@ -155,6 +161,7 @@ func (d *discoverForwarderServer) Close(ctx context.Context, conn *networkservic stream, err := d.nseClient.Find(ctx, ®istry.NetworkServiceEndpointQuery{ NetworkServiceEndpoint: ®istry.NetworkServiceEndpoint{ Name: forwarderName, + Url: d.nsmgrURL, }, }) diff --git a/pkg/networkservice/common/interdomainbypass/server.go b/pkg/networkservice/common/interdomainbypass/server.go new file mode 100644 index 000000000..3ceaf7873 --- /dev/null +++ b/pkg/networkservice/common/interdomainbypass/server.go @@ -0,0 +1,80 @@ +// Copyright (c) 2021-2022 Doc.ai and/or its affiliates. +// +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at: +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package interdomainbypass injects into incoming context the URL to remote side only if requesting endpoint has been resolved. +package interdomainbypass + +import ( + "context" + "net/url" + + "github.com/golang/protobuf/ptypes/empty" + "github.com/networkservicemesh/api/pkg/api/networkservice" + "github.com/networkservicemesh/api/pkg/api/registry" + + "github.com/networkservicemesh/sdk/pkg/networkservice/core/next" + "github.com/networkservicemesh/sdk/pkg/registry/common/interdomainbypass" + "github.com/networkservicemesh/sdk/pkg/tools/clienturlctx" + "github.com/networkservicemesh/sdk/pkg/tools/interdomain" + "github.com/networkservicemesh/sdk/pkg/tools/stringurl" +) + +type interdomainBypassServer struct { + m stringurl.Map +} + +// NewServer - returns a new NetworkServiceServer that injects the URL to remote side into context on requesting resolved endpoint +func NewServer(rs *registry.NetworkServiceEndpointRegistryServer, listenOn *url.URL) networkservice.NetworkServiceServer { + var rv = new(interdomainBypassServer) + *rs = interdomainbypass.NewNetworkServiceEndpointRegistryServer(&rv.m, listenOn) + return rv +} + +func (n *interdomainBypassServer) Request(ctx context.Context, request *networkservice.NetworkServiceRequest) (*networkservice.Connection, error) { + u, ok := n.m.Load(request.Connection.NetworkServiceEndpointName) + // Always true when we are on local nsmgr proxy side. + // True on theremote nsmgr proxy side when it is floating interdomain usecase. + if ok { + ctx = clienturlctx.WithClientURL(ctx, u) + return next.Server(ctx).Request(ctx, request) + } + originalNSEName := request.GetConnection().NetworkServiceEndpointName + request.GetConnection().NetworkServiceEndpointName = interdomain.Target(originalNSEName) + resp, err := next.Server(ctx).Request(ctx, request) + if err != nil { + return nil, err + } + resp.NetworkServiceEndpointName = originalNSEName + return resp, nil +} + +func (n *interdomainBypassServer) Close(ctx context.Context, conn *networkservice.Connection) (*empty.Empty, error) { + u, ok := n.m.Load(conn.NetworkServiceEndpointName) + // Always true when we are on local nsmgr proxy side. + // True on theremote nsmgr proxy side when it is floating interdomain usecase. + if ok { + ctx = clienturlctx.WithClientURL(ctx, u) + return next.Server(ctx).Close(ctx, conn) + } + originalNSEName := conn.NetworkServiceEndpointName + conn.NetworkServiceEndpointName = interdomain.Target(originalNSEName) + resp, err := next.Server(ctx).Close(ctx, conn) + if err != nil { + return nil, err + } + conn.NetworkServiceEndpointName = originalNSEName + return resp, nil +} diff --git a/pkg/networkservice/common/interdomainurl/server_test.go b/pkg/networkservice/common/interdomainbypass/server_test.go similarity index 93% rename from pkg/networkservice/common/interdomainurl/server_test.go rename to pkg/networkservice/common/interdomainbypass/server_test.go index cd5262992..a2b5f0b3f 100644 --- a/pkg/networkservice/common/interdomainurl/server_test.go +++ b/pkg/networkservice/common/interdomainbypass/server_test.go @@ -1,4 +1,4 @@ -// Copyright (c) 2021 Doc.ai and/or its affiliates. +// Copyright (c) 2021-2022 Doc.ai and/or its affiliates. // // SPDX-License-Identifier: Apache-2.0 // @@ -14,7 +14,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -package interdomainurl_test +package interdomainbypass_test import ( "context" @@ -27,7 +27,7 @@ import ( "github.com/stretchr/testify/require" "go.uber.org/goleak" - "github.com/networkservicemesh/sdk/pkg/networkservice/common/interdomainurl" + "github.com/networkservicemesh/sdk/pkg/networkservice/common/interdomainbypass" "github.com/networkservicemesh/sdk/pkg/networkservice/core/chain" "github.com/networkservicemesh/sdk/pkg/networkservice/utils/checks/checkcontext" "github.com/networkservicemesh/sdk/pkg/registry/common/memory" @@ -46,7 +46,7 @@ func Test_StoreUrlNSEServer(t *testing.T) { var storeRegServer registry.NetworkServiceEndpointRegistryServer var s = chain.NewNetworkServiceServer( - interdomainurl.NewServer(&storeRegServer), + interdomainbypass.NewServer(&storeRegServer, new(url.URL)), checkcontext.NewServer(t, func(t *testing.T, c context.Context) { v := clienturlctx.ClientURL(c) require.NotNil(t, v) diff --git a/pkg/networkservice/common/interdomainurl/server.go b/pkg/networkservice/common/interdomainurl/server.go deleted file mode 100644 index 24b43895c..000000000 --- a/pkg/networkservice/common/interdomainurl/server.go +++ /dev/null @@ -1,73 +0,0 @@ -// Copyright (c) 2021 Doc.ai and/or its affiliates. -// -// SPDX-License-Identifier: Apache-2.0 -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at: -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Package interdomainurl injects into incoming context the URL to remote side only if requesting endpoint has been resolved. -package interdomainurl - -import ( - "context" - - "github.com/golang/protobuf/ptypes/empty" - "github.com/networkservicemesh/api/pkg/api/networkservice" - "github.com/networkservicemesh/api/pkg/api/registry" - - "github.com/networkservicemesh/sdk/pkg/networkservice/core/next" - "github.com/networkservicemesh/sdk/pkg/registry/common/storeurl" - "github.com/networkservicemesh/sdk/pkg/tools/clienturlctx" - "github.com/networkservicemesh/sdk/pkg/tools/interdomain" - "github.com/networkservicemesh/sdk/pkg/tools/stringurl" -) - -type interdomainurlServer struct { - m stringurl.Map -} - -// NewServer - returns a new NetworkServiceServer that injects the URL to remote side into context on requesting resolved endpoint -func NewServer(rs *registry.NetworkServiceEndpointRegistryServer) networkservice.NetworkServiceServer { - var rv = new(interdomainurlServer) - *rs = storeurl.NewNetworkServiceEndpointRegistryServer(&rv.m) - return rv -} - -func (n *interdomainurlServer) Request(ctx context.Context, request *networkservice.NetworkServiceRequest) (*networkservice.Connection, error) { - u, ok := n.m.Load(request.Connection.NetworkServiceEndpointName) - if ok { - ctx = clienturlctx.WithClientURL(ctx, u) - originalNSEName := request.GetConnection().NetworkServiceEndpointName - request.GetConnection().NetworkServiceEndpointName = interdomain.Target(originalNSEName) - resp, err := next.Server(ctx).Request(ctx, request) - if err != nil { - return nil, err - } - resp.NetworkServiceEndpointName = originalNSEName - return resp, nil - } - - return next.Server(ctx).Request(ctx, request) -} - -func (n *interdomainurlServer) Close(ctx context.Context, conn *networkservice.Connection) (*empty.Empty, error) { - u, ok := n.m.Load(conn.NetworkServiceEndpointName) - if ok { - ctx = clienturlctx.WithClientURL(ctx, u) - originalNSEName := conn.NetworkServiceEndpointName - defer func() { - conn.NetworkServiceEndpointName = originalNSEName - }() - conn.NetworkServiceEndpointName = interdomain.Target(originalNSEName) - } - return next.Server(ctx).Close(ctx, conn) -} diff --git a/pkg/networkservice/common/roundrobin/server.go b/pkg/networkservice/common/roundrobin/server.go index 856e22763..900e60483 100644 --- a/pkg/networkservice/common/roundrobin/server.go +++ b/pkg/networkservice/common/roundrobin/server.go @@ -51,10 +51,12 @@ func (s *selectEndpointServer) Request(ctx context.Context, request *networkserv } candidates := discover.Candidates(ctx) + var candidatesErr = errors.New("all candidates have failed") + for i := 0; i < len(candidates.Endpoints); i++ { endpoint := s.selector.selectEndpoint(candidates.NetworkService, candidates.Endpoints) if endpoint == nil { - return nil, errors.Errorf("failed to find endpoint for Network Service: %v %v", candidates.NetworkService, candidates.Endpoints) + return nil, errors.Errorf("failed to select endpoint for Network Service: %v %v", candidates.NetworkService, candidates.Endpoints) } u, err := url.Parse(endpoint.Url) if err != nil { @@ -66,8 +68,9 @@ func (s *selectEndpointServer) Request(ctx context.Context, request *networkserv if err == nil { return resp, nil } + candidatesErr = errors.Wrapf(candidatesErr, "%v. An error during select endpoint %v --> %v", i, endpoint.Name, err.Error()) } - return nil, errors.Errorf("all candidates %#v fail", candidates) + return nil, candidatesErr } func (s *selectEndpointServer) Close(ctx context.Context, conn *networkservice.Connection) (*empty.Empty, error) { diff --git a/pkg/networkservice/common/timeout/server_test.go b/pkg/networkservice/common/timeout/server_test.go index 29e98c383..6d33f9962 100644 --- a/pkg/networkservice/common/timeout/server_test.go +++ b/pkg/networkservice/common/timeout/server_test.go @@ -1,4 +1,4 @@ -// Copyright (c) 2020-2021 Doc.ai and/or its affiliates. +// Copyright (c) 2020-2022 Doc.ai and/or its affiliates. // // SPDX-License-Identifier: Apache-2.0 // @@ -235,7 +235,11 @@ func TestTimeoutServer_RefreshFailure(t *testing.T) { client := testClient( ctx, - refresh.NewClient(ctx), + next.NewNetworkServiceClient( + begin.NewClient(), + metadata.NewClient(), + refresh.NewClient(ctx), + ), next.NewNetworkServiceServer( injecterror.NewServer( injecterror.WithRequestErrorTimes(1, -1), diff --git a/pkg/registry/chains/client/ns_client.go b/pkg/registry/chains/client/ns_client.go index ce198ffa9..7b50c12b4 100644 --- a/pkg/registry/chains/client/ns_client.go +++ b/pkg/registry/chains/client/ns_client.go @@ -1,4 +1,4 @@ -// Copyright (c) 2021 Doc.ai and/or its affiliates. +// Copyright (c) 2021-2022 Doc.ai and/or its affiliates. // // SPDX-License-Identifier: Apache-2.0 // @@ -19,10 +19,15 @@ package client import ( "context" "net/url" + "time" "github.com/networkservicemesh/api/pkg/api/registry" + "github.com/networkservicemesh/sdk/pkg/registry/common/begin" + "github.com/networkservicemesh/sdk/pkg/registry/common/clientconn" + "github.com/networkservicemesh/sdk/pkg/registry/common/clienturl" "github.com/networkservicemesh/sdk/pkg/registry/common/connect" + "github.com/networkservicemesh/sdk/pkg/registry/common/dial" "github.com/networkservicemesh/sdk/pkg/registry/common/heal" "github.com/networkservicemesh/sdk/pkg/registry/common/retry" "github.com/networkservicemesh/sdk/pkg/registry/core/chain" @@ -35,19 +40,16 @@ func NewNetworkServiceRegistryClient(ctx context.Context, connectTo *url.URL, op opt(clientOpts) } - c := new(registry.NetworkServiceRegistryClient) - *c = chain.NewNetworkServiceRegistryClient( - retry.NewNetworkServiceRegistryClient(), - connect.NewNetworkServiceRegistryClient(ctx, connectTo, - connect.WithNSAdditionalFunctionality( - append( - clientOpts.nsAdditionalFunctionality, - heal.NewNetworkServiceRegistryClient(ctx, c), - )..., - ), - connect.WithDialOptions(clientOpts.dialOptions...), + return chain.NewNetworkServiceRegistryClient( + begin.NewNetworkServiceRegistryClient(), + retry.NewNetworkServiceRegistryClient(ctx), + heal.NewNetworkServiceRegistryClient(ctx), + clienturl.NewNetworkServiceRegistryClient(connectTo), + clientconn.NewNetworkServiceRegistryClient(), + dial.NewNetworkServiceRegistryClient(ctx, + dial.WithDialOptions(clientOpts.dialOptions...), + dial.WithDialTimeout(time.Second), ), + connect.NewNetworkServiceRegistryClient(), ) - - return *c } diff --git a/pkg/registry/chains/client/nse_client.go b/pkg/registry/chains/client/nse_client.go index 81a01d7b7..03ebffdb8 100644 --- a/pkg/registry/chains/client/nse_client.go +++ b/pkg/registry/chains/client/nse_client.go @@ -1,4 +1,4 @@ -// Copyright (c) 2021 Doc.ai and/or its affiliates. +// Copyright (c) 2021-2022 Doc.ai and/or its affiliates. // // SPDX-License-Identifier: Apache-2.0 // @@ -23,12 +23,14 @@ import ( "github.com/networkservicemesh/api/pkg/api/registry" + "github.com/networkservicemesh/sdk/pkg/registry/common/begin" + "github.com/networkservicemesh/sdk/pkg/registry/common/clientconn" + "github.com/networkservicemesh/sdk/pkg/registry/common/clienturl" "github.com/networkservicemesh/sdk/pkg/registry/common/connect" + "github.com/networkservicemesh/sdk/pkg/registry/common/dial" "github.com/networkservicemesh/sdk/pkg/registry/common/heal" "github.com/networkservicemesh/sdk/pkg/registry/common/refresh" "github.com/networkservicemesh/sdk/pkg/registry/common/retry" - "github.com/networkservicemesh/sdk/pkg/registry/common/sendfd" - "github.com/networkservicemesh/sdk/pkg/registry/common/serialize" "github.com/networkservicemesh/sdk/pkg/registry/core/chain" ) @@ -39,20 +41,16 @@ func NewNetworkServiceEndpointRegistryClient(ctx context.Context, connectTo *url opt(clientOpts) } - c := new(registry.NetworkServiceEndpointRegistryClient) - *c = chain.NewNetworkServiceEndpointRegistryClient( - serialize.NewNetworkServiceEndpointRegistryClient(), - retry.NewNetworkServiceEndpointRegistryClient(), + return chain.NewNetworkServiceEndpointRegistryClient( + begin.NewNetworkServiceEndpointRegistryClient(), + retry.NewNetworkServiceEndpointRegistryClient(ctx), + heal.NewNetworkServiceEndpointRegistryClient(ctx), refresh.NewNetworkServiceEndpointRegistryClient(ctx), - connect.NewNetworkServiceEndpointRegistryClient(ctx, connectTo, - connect.WithNSEAdditionalFunctionality( - append( - clientOpts.nseAdditionalFunctionality, - heal.NewNetworkServiceEndpointRegistryClient(ctx, c), - sendfd.NewNetworkServiceEndpointRegistryClient())...), - connect.WithDialOptions(clientOpts.dialOptions...), + clienturl.NewNetworkServiceEndpointRegistryClient(connectTo), + clientconn.NewNetworkServiceEndpointRegistryClient(), + dial.NewNetworkServiceEndpointRegistryClient(ctx, + dial.WithDialOptions(clientOpts.dialOptions...), ), + connect.NewNetworkServiceEndpointRegistryClient(), ) - - return *c } diff --git a/pkg/registry/chains/memory/server.go b/pkg/registry/chains/memory/server.go index ef3be3405..be2875547 100644 --- a/pkg/registry/chains/memory/server.go +++ b/pkg/registry/chains/memory/server.go @@ -1,4 +1,4 @@ -// Copyright (c) 2020-2021 Doc.ai and/or its affiliates. +// Copyright (c) 2020-2022 Doc.ai and/or its affiliates. // // SPDX-License-Identifier: Apache-2.0 // @@ -24,33 +24,89 @@ import ( "google.golang.org/grpc" + "github.com/networkservicemesh/api/pkg/api/registry" + registryserver "github.com/networkservicemesh/sdk/pkg/registry" + "github.com/networkservicemesh/sdk/pkg/registry/common/begin" + "github.com/networkservicemesh/sdk/pkg/registry/common/clientconn" + "github.com/networkservicemesh/sdk/pkg/registry/common/clienturl" "github.com/networkservicemesh/sdk/pkg/registry/common/connect" + "github.com/networkservicemesh/sdk/pkg/registry/common/dial" "github.com/networkservicemesh/sdk/pkg/registry/common/expire" "github.com/networkservicemesh/sdk/pkg/registry/common/memory" - "github.com/networkservicemesh/sdk/pkg/registry/common/proxy" - "github.com/networkservicemesh/sdk/pkg/registry/common/serialize" "github.com/networkservicemesh/sdk/pkg/registry/common/setpayload" "github.com/networkservicemesh/sdk/pkg/registry/common/setregistrationtime" "github.com/networkservicemesh/sdk/pkg/registry/core/chain" + "github.com/networkservicemesh/sdk/pkg/registry/switchcase" + "github.com/networkservicemesh/sdk/pkg/tools/interdomain" ) // NewServer creates new registry server based on memory storage func NewServer(ctx context.Context, expiryDuration time.Duration, proxyRegistryURL *url.URL, dialOptions ...grpc.DialOption) registryserver.Registry { nseChain := chain.NewNetworkServiceEndpointRegistryServer( - serialize.NewNetworkServiceEndpointRegistryServer(), - setregistrationtime.NewNetworkServiceEndpointRegistryServer(), - expire.NewNetworkServiceEndpointRegistryServer(ctx, expiryDuration), - memory.NewNetworkServiceEndpointRegistryServer(), - proxy.NewNetworkServiceEndpointRegistryServer(proxyRegistryURL), - connect.NewNetworkServiceEndpointRegistryServer(ctx, connect.WithDialOptions(dialOptions...)), + begin.NewNetworkServiceEndpointRegistryServer(), + switchcase.NewNetworkServiceEndpointRegistryServer(switchcase.NSEServerCase{ + Condition: func(c context.Context, nse *registry.NetworkServiceEndpoint) bool { + if interdomain.Is(nse.GetName()) { + return true + } + for _, ns := range nse.GetNetworkServiceNames() { + if interdomain.Is(ns) { + return true + } + } + return false + }, + Action: chain.NewNetworkServiceEndpointRegistryServer( + connect.NewNetworkServiceEndpointRegistryServer( + chain.NewNetworkServiceEndpointRegistryClient( + begin.NewNetworkServiceEndpointRegistryClient(), + clienturl.NewNetworkServiceEndpointRegistryClient(proxyRegistryURL), + clientconn.NewNetworkServiceEndpointRegistryClient(), + dial.NewNetworkServiceEndpointRegistryClient(ctx, + dial.WithDialOptions(dialOptions...), + ), + connect.NewNetworkServiceEndpointRegistryClient(), + ), + ), + ), + }, + switchcase.NSEServerCase{ + Condition: func(c context.Context, nse *registry.NetworkServiceEndpoint) bool { return true }, + Action: chain.NewNetworkServiceEndpointRegistryServer( + setregistrationtime.NewNetworkServiceEndpointRegistryServer(), + expire.NewNetworkServiceEndpointRegistryServer(ctx, expiryDuration), + memory.NewNetworkServiceEndpointRegistryServer(), + ), + }, + ), ) nsChain := chain.NewNetworkServiceRegistryServer( - serialize.NewNetworkServiceRegistryServer(), setpayload.NewNetworkServiceRegistryServer(), - memory.NewNetworkServiceRegistryServer(), - proxy.NewNetworkServiceRegistryServer(proxyRegistryURL), - connect.NewNetworkServiceRegistryServer(ctx, connect.WithDialOptions(dialOptions...)), + switchcase.NewNetworkServiceRegistryServer( + switchcase.NSServerCase{ + Condition: func(c context.Context, ns *registry.NetworkService) bool { + return interdomain.Is(ns.GetName()) + }, + Action: connect.NewNetworkServiceRegistryServer( + chain.NewNetworkServiceRegistryClient( + clienturl.NewNetworkServiceRegistryClient(proxyRegistryURL), + begin.NewNetworkServiceRegistryClient(), + clientconn.NewNetworkServiceRegistryClient(), + dial.NewNetworkServiceRegistryClient(ctx, + dial.WithDialOptions(dialOptions...), + ), + connect.NewNetworkServiceRegistryClient(), + ), + ), + }, + switchcase.NSServerCase{ + Condition: func(c context.Context, ns *registry.NetworkService) bool { + return true + }, + Action: memory.NewNetworkServiceRegistryServer(), + }, + ), ) return registryserver.NewServer(nsChain, nseChain) diff --git a/pkg/registry/chains/proxydns/server.go b/pkg/registry/chains/proxydns/server.go index 1601ad788..9920a2da2 100644 --- a/pkg/registry/chains/proxydns/server.go +++ b/pkg/registry/chains/proxydns/server.go @@ -1,4 +1,4 @@ -// Copyright (c) 2020-2021 Doc.ai and/or its affiliates. +// Copyright (c) 2020-2022 Doc.ai and/or its affiliates. // // SPDX-License-Identifier: Apache-2.0 // @@ -23,7 +23,10 @@ import ( "google.golang.org/grpc" "github.com/networkservicemesh/sdk/pkg/registry" + "github.com/networkservicemesh/sdk/pkg/registry/common/begin" + "github.com/networkservicemesh/sdk/pkg/registry/common/clientconn" "github.com/networkservicemesh/sdk/pkg/registry/common/connect" + "github.com/networkservicemesh/sdk/pkg/registry/common/dial" "github.com/networkservicemesh/sdk/pkg/registry/common/dnsresolve" "github.com/networkservicemesh/sdk/pkg/registry/core/chain" ) @@ -31,10 +34,29 @@ import ( // NewServer creates new stateless registry server that proxies queries to the second registries by DNS domains func NewServer(ctx context.Context, dnsResolver dnsresolve.Resolver, dialOptions ...grpc.DialOption) registry.Registry { nseChain := chain.NewNetworkServiceEndpointRegistryServer( + begin.NewNetworkServiceEndpointRegistryServer(), dnsresolve.NewNetworkServiceEndpointRegistryServer(dnsresolve.WithResolver(dnsResolver)), - connect.NewNetworkServiceEndpointRegistryServer(ctx, connect.WithDialOptions(dialOptions...))) + connect.NewNetworkServiceEndpointRegistryServer( + chain.NewNetworkServiceEndpointRegistryClient( + clientconn.NewNetworkServiceEndpointRegistryClient(), + dial.NewNetworkServiceEndpointRegistryClient(ctx, + dial.WithDialOptions(dialOptions...), + ), + connect.NewNetworkServiceEndpointRegistryClient(), + ), + )) nsChain := chain.NewNetworkServiceRegistryServer( + begin.NewNetworkServiceRegistryServer(), dnsresolve.NewNetworkServiceRegistryServer(dnsresolve.WithResolver(dnsResolver)), - connect.NewNetworkServiceRegistryServer(ctx, connect.WithDialOptions(dialOptions...))) + connect.NewNetworkServiceRegistryServer( + chain.NewNetworkServiceRegistryClient( + clientconn.NewNetworkServiceRegistryClient(), + dial.NewNetworkServiceRegistryClient( + ctx, + dial.WithDialOptions(dialOptions...), + ), + connect.NewNetworkServiceRegistryClient(), + ), + )) return registry.NewServer(nsChain, nseChain) } diff --git a/pkg/registry/chains/proxydns/server_ns_test.go b/pkg/registry/chains/proxydns/server_ns_test.go index 824989730..c425cfa79 100644 --- a/pkg/registry/chains/proxydns/server_ns_test.go +++ b/pkg/registry/chains/proxydns/server_ns_test.go @@ -1,4 +1,4 @@ -// Copyright (c) 2020-2021 Doc.ai and/or its affiliates. +// Copyright (c) 2020-2022 Doc.ai and/or its affiliates. // // SPDX-License-Identifier: Apache-2.0 // diff --git a/pkg/registry/chains/proxydns/server_nse_test.go b/pkg/registry/chains/proxydns/server_nse_test.go index b86fd4050..ffcf79772 100644 --- a/pkg/registry/chains/proxydns/server_nse_test.go +++ b/pkg/registry/chains/proxydns/server_nse_test.go @@ -52,7 +52,7 @@ import ( func TestInterdomainNetworkServiceEndpointRegistry(t *testing.T) { t.Cleanup(func() { goleak.VerifyNone(t) }) - ctx, cancel := context.WithTimeout(context.Background(), time.Second*10) + ctx, cancel := context.WithTimeout(context.Background(), time.Second*5) defer cancel() dnsServer := new(sandbox.FakeDNSResolver) diff --git a/pkg/registry/common/begin/close_client_test.go b/pkg/registry/common/begin/close_client_test.go new file mode 100644 index 000000000..7d798bccb --- /dev/null +++ b/pkg/registry/common/begin/close_client_test.go @@ -0,0 +1,130 @@ +// Copyright (c) 2022 Cisco and/or its affiliates. +// +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at: +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package begin_test + +import ( + "context" + "sync" + "testing" + + "github.com/golang/protobuf/ptypes/empty" + "github.com/networkservicemesh/api/pkg/api/registry" + "github.com/stretchr/testify/assert" + "go.uber.org/goleak" + "google.golang.org/grpc" + + "github.com/networkservicemesh/sdk/pkg/registry/common/begin" + "github.com/networkservicemesh/sdk/pkg/registry/core/chain" + "github.com/networkservicemesh/sdk/pkg/registry/core/next" +) + +const ( + mark = "mark" +) + +func TestCloseClient(t *testing.T) { + t.Cleanup(func() { goleak.VerifyNone(t) }) + client := chain.NewNetworkServiceEndpointRegistryClient( + begin.NewNetworkServiceEndpointRegistryClient(), + &markClient{t: t}, + ) + id := "1" + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + resp, err := client.Register(ctx, ®istry.NetworkServiceEndpoint{ + Name: id, + }) + assert.NotNil(t, t, resp) + assert.NoError(t, err) + assert.Equal(t, mark, resp.GetNetworkServiceLabels()[mark].Labels[mark]) + resp = resp.Clone() + delete(resp.GetNetworkServiceLabels()[mark].Labels, mark) + assert.Empty(t, resp.GetNetworkServiceLabels()[mark].Labels) + _, err = client.Unregister(ctx, resp) + assert.NoError(t, err) +} + +type markClient struct { + t *testing.T +} + +func (m *markClient) Register(ctx context.Context, in *registry.NetworkServiceEndpoint, opts ...grpc.CallOption) (*registry.NetworkServiceEndpoint, error) { + if in.GetNetworkServiceLabels() == nil { + in.NetworkServiceLabels = make(map[string]*registry.NetworkServiceLabels) + } + + in.GetNetworkServiceLabels()[mark] = ®istry.NetworkServiceLabels{ + Labels: map[string]string{ + mark: mark, + }, + } + + return next.NetworkServiceEndpointRegistryClient(ctx).Register(ctx, in, opts...) +} + +func (m *markClient) Find(ctx context.Context, in *registry.NetworkServiceEndpointQuery, opts ...grpc.CallOption) (registry.NetworkServiceEndpointRegistry_FindClient, error) { + return next.NetworkServiceEndpointRegistryClient(ctx).Find(ctx, in, opts...) +} + +func (m *markClient) Unregister(ctx context.Context, in *registry.NetworkServiceEndpoint, opts ...grpc.CallOption) (*empty.Empty, error) { + assert.NotNil(m.t, in.GetNetworkServiceLabels()) + assert.Equal(m.t, mark, in.GetNetworkServiceLabels()[mark].Labels[mark]) + return next.NetworkServiceEndpointRegistryClient(ctx).Unregister(ctx, in, opts...) +} + +func TestDoubleCloseClient(t *testing.T) { + t.Cleanup(func() { goleak.VerifyNone(t) }) + client := chain.NewNetworkServiceEndpointRegistryClient( + begin.NewNetworkServiceEndpointRegistryClient(), + &doubleCloseClient{t: t}, + ) + id := "1" + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + conn, err := client.Register(ctx, ®istry.NetworkServiceEndpoint{ + Name: id, + }) + assert.NotNil(t, t, conn) + assert.NoError(t, err) + conn = conn.Clone() + _, err = client.Unregister(ctx, conn) + assert.NoError(t, err) + _, err = client.Unregister(ctx, conn) + assert.NoError(t, err) +} + +type doubleCloseClient struct { + t *testing.T + sync.Once +} + +func (s *doubleCloseClient) Register(ctx context.Context, in *registry.NetworkServiceEndpoint, opts ...grpc.CallOption) (*registry.NetworkServiceEndpoint, error) { + return next.NetworkServiceEndpointRegistryClient(ctx).Register(ctx, in, opts...) +} + +func (s *doubleCloseClient) Find(ctx context.Context, in *registry.NetworkServiceEndpointQuery, opts ...grpc.CallOption) (registry.NetworkServiceEndpointRegistry_FindClient, error) { + return next.NetworkServiceEndpointRegistryClient(ctx).Find(ctx, in, opts...) +} + +func (s *doubleCloseClient) Unregister(ctx context.Context, in *registry.NetworkServiceEndpoint, opts ...grpc.CallOption) (*empty.Empty, error) { + count := 1 + s.Do(func() { + count++ + }) + assert.Equal(s.t, 2, count, "Close has been called more than once") + return next.NetworkServiceEndpointRegistryClient(ctx).Unregister(ctx, in, opts...) +} diff --git a/pkg/registry/common/begin/close_server_test.go b/pkg/registry/common/begin/close_server_test.go new file mode 100644 index 000000000..716fb11a5 --- /dev/null +++ b/pkg/registry/common/begin/close_server_test.go @@ -0,0 +1,93 @@ +// Copyright (c) 2022 Cisco and/or its affiliates. +// +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at: +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package begin_test + +import ( + "context" + "sync" + "testing" + + "github.com/networkservicemesh/api/pkg/api/registry" + + "github.com/networkservicemesh/sdk/pkg/registry/common/begin" + "github.com/networkservicemesh/sdk/pkg/registry/common/null" + "github.com/networkservicemesh/sdk/pkg/registry/core/adapters" + "github.com/networkservicemesh/sdk/pkg/registry/core/chain" + "github.com/networkservicemesh/sdk/pkg/registry/core/next" + + "github.com/stretchr/testify/assert" + "go.uber.org/goleak" + "google.golang.org/protobuf/types/known/emptypb" +) + +func TestCloseServer(t *testing.T) { + t.Cleanup(func() { goleak.VerifyNone(t) }) + server := chain.NewNetworkServiceEndpointRegistryServer( + begin.NewNetworkServiceEndpointRegistryServer(), + adapters.NetworkServiceEndpointClientToServer(&markClient{t: t}), + ) + id := "1" + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + conn, err := server.Register(ctx, ®istry.NetworkServiceEndpoint{ + Name: id, + }) + assert.NotNil(t, t, conn) + assert.NoError(t, err) + assert.Equal(t, conn.GetNetworkServiceLabels()[mark].Labels[mark], mark) + conn = conn.Clone() + delete(conn.GetNetworkServiceLabels()[mark].Labels, mark) + assert.Zero(t, conn.GetNetworkServiceLabels()[mark].Labels[mark]) + _, err = server.Unregister(ctx, conn) + assert.NoError(t, err) +} + +func TestDoubleCloseServer(t *testing.T) { + t.Cleanup(func() { goleak.VerifyNone(t) }) + server := chain.NewNetworkServiceEndpointRegistryServer( + begin.NewNetworkServiceEndpointRegistryServer(), + &doubleCloseServer{t: t, NetworkServiceEndpointRegistryServer: null.NewNetworkServiceEndpointRegistryServer()}, + ) + id := "1" + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + conn, err := server.Register(ctx, ®istry.NetworkServiceEndpoint{ + Name: id, + }) + assert.NotNil(t, t, conn) + assert.NoError(t, err) + conn = conn.Clone() + _, err = server.Unregister(ctx, conn) + assert.NoError(t, err) + _, err = server.Unregister(ctx, conn) + assert.NoError(t, err) +} + +type doubleCloseServer struct { + t *testing.T + sync.Once + registry.NetworkServiceEndpointRegistryServer +} + +func (s *doubleCloseServer) Unregister(ctx context.Context, in *registry.NetworkServiceEndpoint) (*emptypb.Empty, error) { + count := 1 + s.Do(func() { + count++ + }) + assert.Equal(s.t, 2, count, "Close has been called more than once") + return next.NetworkServiceEndpointRegistryServer(ctx).Unregister(ctx, in) +} diff --git a/pkg/registry/common/begin/context.go b/pkg/registry/common/begin/context.go new file mode 100644 index 000000000..4eb3db3d2 --- /dev/null +++ b/pkg/registry/common/begin/context.go @@ -0,0 +1,61 @@ +// Copyright (c) 2022 Cisco and/or its affiliates. +// +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at: +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package begin + +import ( + "context" +) + +// EventFactory - allows firing off a Request or Close event from midchain +type EventFactory interface { + Register(opts ...Option) <-chan error + Unregister(opts ...Option) <-chan error +} + +type connectionState int + +const ( + established = iota + 1 + closed +) + +type key struct{} + +func withEventFactory(parent context.Context, eventFactory EventFactory) context.Context { + if parent.Value(key{}) != nil { + return parent + } + ctx := context.WithValue(parent, key{}, eventFactory) + return ctx +} + +// FromContext - returns EventFactory from context +func FromContext(ctx context.Context) EventFactory { + value := fromContext(ctx) + if value == nil { + panic("EventFactory not found please add begin chain element to your chain") + } + return value +} + +func fromContext(ctx context.Context) EventFactory { + value, ok := ctx.Value(key{}).(EventFactory) + if ok { + return value + } + return nil +} diff --git a/pkg/registry/common/begin/doc.go b/pkg/registry/common/begin/doc.go new file mode 100644 index 000000000..9e4c78203 --- /dev/null +++ b/pkg/registry/common/begin/doc.go @@ -0,0 +1,101 @@ +// Copyright (c) 2022 Cisco and/or its affiliates. +// +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at: +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +/* +Package begin provides a chain element that can be put at the beginning of the chain, after Connection.Id has been set +but before any chain elements that would mutate the Connection on the return path. +the begin.New{Client,Server}() guarantee: + +Scope + +All Request() or Close() events are scoped to a particular Connection, uniquely identified by its Connection.Id + +Exclusivity + +Only one event is processed for a Connection.Id at a time + +Order + +Events for a given Connection.Id are processed in the order in which they are received + +Close Correctness + +When a Close(Connection) event is received, begin will replace the Connection provided with the last Connection +successfully returned from the chain for Connection.Id + +Midchain Originated Events + +A midchain element may originate a Request() or Close() event to be processed +from the beginning of the chain (Timeout, Refresh,Heal): + + errCh := begin.FromContext(ctx).Request() + errCh := begin.FromContext(ctx).Close() + +errCh will receive any error from the firing of the event, and will be closed after the event has fully +processed. + +Note: if a chain is a server chain continued by a client chain, the beginning of the chain is at the beginning of +the server chain, even if there is a subsequent begin.NewClient() in the client chain. + +Optionally you may use the CancelContext(context.Context) option: + + begin.FromContext(ctx).Request(CancelContext(cancelContext)) + begin.FromContext(ctx).Close(CancelContext(cancelContext)) + +If cancelContext is canceled prior to the processing of the event, the event processing will be skipped, +and the errCh returned simply closed. + +Midchain Originated Request Event + +Example: + + begin.FromContext(ctx).Request() + +will use the networkservice.NetworkServiceRequest from the chain's last successfully completed Request() event +with networkservice.NetworkServiceRequest.Connection replaced with the Connection returned by the chain's last +successfully completed Request() event + +Chain Placement + +begin.New{Server/Client} should always proceed any chain element which: +- Maintains state +- Mutates the Connection object along the return path of processing a Request() event. + +Reasoning + +networkservice.NetworkService{Client,Server} processes two kinds of events: + - Request() + - Close() +Each Request() or Close() event is scoped to a networkservice.Connection, which can be uniquely identified by its Connection.Id + +For a given Connection.Id, at most one event can be processed at a time (exclusivity). +For a given Connection.Id, events must be processed in the order they were received (order). +For Close(), the Connection passed to it must be identical to the last one returned by the chain to insure all state +is correctly cleared (close correctness). + +Typically, a chain element receives a Request() or Close() event from the element before it in the chain +and sends a Request() or Close() and either terminates processing returning an error, or sends a Request() or Close() +event to the next element in the chain. + +There are some circumstances in which a Request() or Close() event needs to be originated by a chain element +in the middle of the chain, but processed from the beginning of the chain. Examples include (but are not limited to): + - A server timing out an expired Connection + - A client refreshing a Connection so that it does not expire + - A client healing from a lost Connection +In all of these cases, the Request() or Close() event should be processed starting at the beginning of the chain, to ensure +that all of the proper side effects occur within the chain. +*/ +package begin diff --git a/pkg/registry/common/connect/gen.go b/pkg/registry/common/begin/gen.go similarity index 51% rename from pkg/registry/common/connect/gen.go rename to pkg/registry/common/begin/gen.go index 752770a72..9e10b394d 100644 --- a/pkg/registry/common/connect/gen.go +++ b/pkg/registry/common/begin/gen.go @@ -1,4 +1,4 @@ -// Copyright (c) 2020-2021 Doc.ai and/or its affiliates. +// Copyright (c) 2022 Cisco and/or its affiliates. // // SPDX-License-Identifier: Apache-2.0 // @@ -14,18 +14,26 @@ // See the License for the specific language governing permissions and // limitations under the License. -package connect +package begin -import "sync" +import ( + "sync" +) -//go:generate go-syncmap -output nse_info_map.gen.go -type nseInfoMap -//go:generate go-syncmap -output nse_client_map.gen.go -type nseClientMap +//go:generate go-syncmap -output nse_client_map.gen.go -type nseClientMap +//go:generate go-syncmap -output nse_server_map.gen.go -type nseServerMap -type nseInfoMap sync.Map +//go:generate go-syncmap -output ns_client_map.gen.go -type nsClientMap +//go:generate go-syncmap -output ns_server_map.gen.go -type nsServerMap + +// nseClientMap - sync.Map with key == string and value == *eventNSEFactoryClient type nseClientMap sync.Map -//go:generate go-syncmap -output ns_info_map.gen.go -type nsInfoMap -//go:generate go-syncmap -output ns_client_map.gen.go -type nsClientMap +// nseServerMap - sync.Map with key == string and value == *eventNSEFactoryClient +type nseServerMap sync.Map -type nsInfoMap sync.Map +// nsClientMap - sync.Map with key == string and value == *eventNSFactoryClient type nsClientMap sync.Map + +// nsServerMap - sync.Map with key == string and value == *eventNSFactoryClient +type nsServerMap sync.Map diff --git a/pkg/registry/common/connect/context.go b/pkg/registry/common/begin/merge.go similarity index 52% rename from pkg/registry/common/connect/context.go rename to pkg/registry/common/begin/merge.go index d1ff1a464..cd2a0e752 100644 --- a/pkg/registry/common/connect/context.go +++ b/pkg/registry/common/begin/merge.go @@ -1,4 +1,4 @@ -// Copyright (c) 2021 Doc.ai and/or its affiliates. +// Copyright (c) 2022 Cisco and/or its affiliates. // // SPDX-License-Identifier: Apache-2.0 // @@ -14,30 +14,32 @@ // See the License for the specific language governing permissions and // limitations under the License. -package connect +package begin -import ( - "context" +import "github.com/networkservicemesh/api/pkg/api/registry" - "google.golang.org/grpc" -) +func mergeNSE(left, right *registry.NetworkServiceEndpoint) *registry.NetworkServiceEndpoint { + if left == nil || right == nil { + return left + } -const ( - ccKey contextKeyType = "cc" -) + var result = right.Clone() -type contextKeyType string + result.Name = left.Name -func withCC(parent context.Context, cc grpc.ClientConnInterface) context.Context { - if parent == nil { - panic("cannot create context from nil parent") - } - return context.WithValue(parent, ccKey, cc) + result.ExpirationTime = nil + + return result } -func ccFromContext(ctx context.Context) grpc.ClientConnInterface { - if cc, ok := ctx.Value(ccKey).(grpc.ClientConnInterface); ok { - return cc +func mergeNS(left, right *registry.NetworkService) *registry.NetworkService { + if left == nil || right == nil { + return left } - return nil + + var result = right.Clone() + + result.Name = left.Name + + return result } diff --git a/pkg/registry/common/begin/ns_client.go b/pkg/registry/common/begin/ns_client.go new file mode 100644 index 000000000..2853bbed9 --- /dev/null +++ b/pkg/registry/common/begin/ns_client.go @@ -0,0 +1,120 @@ +// Copyright (c) 2022 Cisco and/or its affiliates. +// +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at: +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package begin + +import ( + "context" + + "github.com/golang/protobuf/ptypes/empty" + "github.com/networkservicemesh/api/pkg/api/registry" + "github.com/pkg/errors" + "google.golang.org/grpc" + + "github.com/networkservicemesh/sdk/pkg/registry/core/next" + "github.com/networkservicemesh/sdk/pkg/tools/log" +) + +type beginNSClient struct { + nsClientMap +} + +func (b *beginNSClient) Register(ctx context.Context, in *registry.NetworkService, opts ...grpc.CallOption) (*registry.NetworkService, error) { + id := in.GetName() + if id == "" { + return nil, errors.New("registry.NetworkService.Name must not be zero valued") + } + // If some other EventFactory is already in the ctx... we are already running in an executor, and can just execute normally + if fromContext(ctx) != nil { + return next.NetworkServiceRegistryClient(ctx).Register(ctx, in, opts...) + } + eventFactoryClient, _ := b.LoadOrStore(id, + newEventNSFactoryClient( + ctx, + func() { + b.Delete(id) + }, + opts..., + ), + ) + var resp *registry.NetworkService + var err error + <-eventFactoryClient.executor.AsyncExec(func() { + // If the eventFactory has changed, usually because the connection has been Closed and re-established + // go back to the beginning and try again. + currentEventFactoryClient, _ := b.LoadOrStore(id, eventFactoryClient) + if currentEventFactoryClient != eventFactoryClient { + log.FromContext(ctx).Debug("recalling begin.Request because currentEventFactoryClient != eventFactoryClient") + resp, err = b.Register(ctx, in, opts...) + return + } + + ctx = withEventFactory(ctx, eventFactoryClient) + resp, err = next.NetworkServiceRegistryClient(ctx).Register(ctx, in, opts...) + if err != nil { + if eventFactoryClient.state != established { + eventFactoryClient.state = closed + b.Delete(id) + } + return + } + eventFactoryClient.opts = opts + eventFactoryClient.state = established + eventFactoryClient.registration = mergeNS(in, resp.Clone()) + eventFactoryClient.response = resp.Clone() + }) + return resp, err +} + +func (b *beginNSClient) Find(ctx context.Context, in *registry.NetworkServiceQuery, opts ...grpc.CallOption) (registry.NetworkServiceRegistry_FindClient, error) { + return next.NetworkServiceRegistryClient(ctx).Find(ctx, in, opts...) +} + +func (b *beginNSClient) Unregister(ctx context.Context, in *registry.NetworkService, opts ...grpc.CallOption) (*empty.Empty, error) { + id := in.GetName() + if fromContext(ctx) != nil { + return next.NetworkServiceRegistryClient(ctx).Unregister(ctx, in, opts...) + } + eventFactoryClient, ok := b.Load(id) + if !ok { + return new(empty.Empty), nil + } + var emp *empty.Empty + var err error + <-eventFactoryClient.executor.AsyncExec(func() { + // If the connection is not established, don't do anything + if eventFactoryClient.state != established || eventFactoryClient.client == nil || eventFactoryClient.registration == nil { + return + } + + // If this isn't the connection we started with, do nothing + currentEventFactoryClient, _ := b.LoadOrStore(id, eventFactoryClient) + if currentEventFactoryClient != eventFactoryClient { + return + } + // Always close with the last valid Connection we got + ctx = withEventFactory(ctx, eventFactoryClient) + emp, err = next.NetworkServiceRegistryClient(ctx).Unregister(ctx, eventFactoryClient.registration, opts...) + // afterCloseFunc() is used to cleanup things like the entry in the Map for EventFactories + eventFactoryClient.afterCloseFunc() + }) + return emp, err +} + +// NewNetworkServiceRegistryClient - returns a new null client that does nothing but call next.NetworkServiceRegistryClient(ctx). +func NewNetworkServiceRegistryClient() registry.NetworkServiceRegistryClient { + return new(beginNSClient) +} diff --git a/pkg/registry/common/connect/ns_client_map.gen.go b/pkg/registry/common/begin/ns_client_map.gen.go similarity index 64% rename from pkg/registry/common/connect/ns_client_map.gen.go rename to pkg/registry/common/begin/ns_client_map.gen.go index 993c9dc1a..7ad961084 100644 --- a/pkg/registry/common/connect/ns_client_map.gen.go +++ b/pkg/registry/common/begin/ns_client_map.gen.go @@ -1,5 +1,5 @@ -// Code generated by "-output ns_client_map.gen.go -type nsClientMap -output ns_client_map.gen.go -type nsClientMap"; DO NOT EDIT. -package connect +// Code generated by "-output ns_client_map.gen.go -type nsClientMap -output ns_client_map.gen.go -type nsClientMap"; DO NOT EDIT. +package begin import ( "sync" // Used by sync.Map. @@ -12,43 +12,43 @@ func _() { _ = (sync.Map)(nsClientMap{}) } -var _nil_nsClientMap_nsClient_value = func() (val *nsClient) { return }() +var _nil_nsClientMap_eventNSFactoryClient_value = func() (val *eventNSFactoryClient) { return }() // Load returns the value stored in the map for a key, or nil if no // value is present. // The ok result indicates whether value was found in the map. -func (m *nsClientMap) Load(key string) (*nsClient, bool) { +func (m *nsClientMap) Load(key string) (*eventNSFactoryClient, bool) { value, ok := (*sync.Map)(m).Load(key) if value == nil { - return _nil_nsClientMap_nsClient_value, ok + return _nil_nsClientMap_eventNSFactoryClient_value, ok } - return value.(*nsClient), ok + return value.(*eventNSFactoryClient), ok } // Store sets the value for a key. -func (m *nsClientMap) Store(key string, value *nsClient) { +func (m *nsClientMap) Store(key string, value *eventNSFactoryClient) { (*sync.Map)(m).Store(key, value) } // LoadOrStore returns the existing value for the key if present. // Otherwise, it stores and returns the given value. // The loaded result is true if the value was loaded, false if stored. -func (m *nsClientMap) LoadOrStore(key string, value *nsClient) (*nsClient, bool) { +func (m *nsClientMap) LoadOrStore(key string, value *eventNSFactoryClient) (*eventNSFactoryClient, bool) { actual, loaded := (*sync.Map)(m).LoadOrStore(key, value) if actual == nil { - return _nil_nsClientMap_nsClient_value, loaded + return _nil_nsClientMap_eventNSFactoryClient_value, loaded } - return actual.(*nsClient), loaded + return actual.(*eventNSFactoryClient), loaded } // LoadAndDelete deletes the value for a key, returning the previous value if any. // The loaded result reports whether the key was present. -func (m *nsClientMap) LoadAndDelete(key string) (value *nsClient, loaded bool) { +func (m *nsClientMap) LoadAndDelete(key string) (value *eventNSFactoryClient, loaded bool) { actual, loaded := (*sync.Map)(m).LoadAndDelete(key) if actual == nil { - return _nil_nsClientMap_nsClient_value, loaded + return _nil_nsClientMap_eventNSFactoryClient_value, loaded } - return actual.(*nsClient), loaded + return actual.(*eventNSFactoryClient), loaded } // Delete deletes the value for a key. @@ -66,8 +66,8 @@ func (m *nsClientMap) Delete(key string) { // // Range may be O(N) with the number of elements in the map even if f returns // false after a constant number of calls. -func (m *nsClientMap) Range(f func(key string, value *nsClient) bool) { +func (m *nsClientMap) Range(f func(key string, value *eventNSFactoryClient) bool) { (*sync.Map)(m).Range(func(key, value interface{}) bool { - return f(key.(string), value.(*nsClient)) + return f(key.(string), value.(*eventNSFactoryClient)) }) } diff --git a/pkg/registry/common/begin/ns_event_factory.go b/pkg/registry/common/begin/ns_event_factory.go new file mode 100644 index 000000000..7af3f8c71 --- /dev/null +++ b/pkg/registry/common/begin/ns_event_factory.go @@ -0,0 +1,199 @@ +// Copyright (c) 2022 Cisco and/or its affiliates. +// +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at: +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package begin + +import ( + "context" + + "github.com/edwarnicke/serialize" + "github.com/networkservicemesh/api/pkg/api/registry" + "google.golang.org/grpc" + + "github.com/networkservicemesh/sdk/pkg/registry/core/next" + "github.com/networkservicemesh/sdk/pkg/tools/postpone" +) + +type eventNSFactoryClient struct { + state connectionState + executor serialize.Executor + ctxFunc func() (context.Context, context.CancelFunc) + registration *registry.NetworkService + response *registry.NetworkService + opts []grpc.CallOption + client registry.NetworkServiceRegistryClient + afterCloseFunc func() +} + +func newEventNSFactoryClient(ctx context.Context, afterClose func(), opts ...grpc.CallOption) *eventNSFactoryClient { + f := &eventNSFactoryClient{ + client: next.NetworkServiceRegistryClient(ctx), + opts: opts, + } + ctxFunc := postpone.ContextWithValues(ctx) + f.ctxFunc = func() (context.Context, context.CancelFunc) { + eventCtx, cancel := ctxFunc() + return withEventFactory(eventCtx, f), cancel + } + + f.afterCloseFunc = func() { + f.state = closed + if afterClose != nil { + afterClose() + } + } + return f +} + +func (f *eventNSFactoryClient) Register(opts ...Option) <-chan error { + o := &option{ + cancelCtx: context.Background(), + } + for _, opt := range opts { + opt(o) + } + ch := make(chan error, 1) + f.executor.AsyncExec(func() { + defer close(ch) + if f.state != established { + return + } + select { + case <-o.cancelCtx.Done(): + default: + registration := f.registration.Clone() + ctx, cancel := f.ctxFunc() + defer cancel() + resp, err := f.client.Register(ctx, registration, f.opts...) + if err == nil && f.registration != nil { + f.registration = mergeNS(f.registration, resp) + } + ch <- err + } + }) + return ch +} + +func (f *eventNSFactoryClient) Unregister(opts ...Option) <-chan error { + o := &option{ + cancelCtx: context.Background(), + } + for _, opt := range opts { + opt(o) + } + ch := make(chan error, 1) + f.executor.AsyncExec(func() { + defer close(ch) + if f.registration == nil { + return + } + select { + case <-o.cancelCtx.Done(): + default: + ctx, cancel := f.ctxFunc() + defer cancel() + _, err := f.client.Unregister(ctx, f.response, f.opts...) + f.afterCloseFunc() + ch <- err + } + }) + return ch +} + +var _ EventFactory = &eventNSFactoryClient{} + +type eventNSFactoryServer struct { + state connectionState + executor serialize.Executor + ctxFunc func() (context.Context, context.CancelFunc) + registration *registry.NetworkService + response *registry.NetworkService + afterCloseFunc func() + server registry.NetworkServiceRegistryServer +} + +func newNSEventFactoryServer(ctx context.Context, afterClose func()) *eventNSFactoryServer { + f := &eventNSFactoryServer{ + server: next.NetworkServiceRegistryServer(ctx), + } + ctxFunc := postpone.ContextWithValues(ctx) + f.ctxFunc = func() (context.Context, context.CancelFunc) { + eventCtx, cancel := ctxFunc() + return withEventFactory(eventCtx, f), cancel + } + + f.afterCloseFunc = func() { + f.state = closed + afterClose() + } + return f +} + +func (f *eventNSFactoryServer) Register(opts ...Option) <-chan error { + o := &option{ + cancelCtx: context.Background(), + } + for _, opt := range opts { + opt(o) + } + ch := make(chan error, 1) + f.executor.AsyncExec(func() { + defer close(ch) + if f.state != established { + return + } + select { + case <-o.cancelCtx.Done(): + default: + ctx, cancel := f.ctxFunc() + defer cancel() + resp, err := f.server.Register(ctx, f.registration) + if err == nil && f.registration != nil { + f.registration = resp + } + ch <- err + } + }) + return ch +} + +func (f *eventNSFactoryServer) Unregister(opts ...Option) <-chan error { + o := &option{ + cancelCtx: context.Background(), + } + for _, opt := range opts { + opt(o) + } + ch := make(chan error, 1) + f.executor.AsyncExec(func() { + defer close(ch) + if f.registration == nil { + return + } + select { + case <-o.cancelCtx.Done(): + default: + ctx, cancel := f.ctxFunc() + defer cancel() + _, err := f.server.Unregister(ctx, f.registration) + f.afterCloseFunc() + ch <- err + } + }) + return ch +} + +var _ EventFactory = &eventNSFactoryServer{} diff --git a/pkg/registry/common/begin/ns_server.go b/pkg/registry/common/begin/ns_server.go new file mode 100644 index 000000000..711c59696 --- /dev/null +++ b/pkg/registry/common/begin/ns_server.go @@ -0,0 +1,113 @@ +// Copyright (c) 2022 Cisco and/or its affiliates. +// +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at: +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package begin + +import ( + "context" + + "github.com/golang/protobuf/ptypes/empty" + "github.com/networkservicemesh/api/pkg/api/registry" + "github.com/pkg/errors" + "google.golang.org/protobuf/types/known/emptypb" + + "github.com/networkservicemesh/sdk/pkg/registry/core/next" + "github.com/networkservicemesh/sdk/pkg/tools/log" +) + +type beginNSServer struct { + nsServerMap +} + +func (b *beginNSServer) Register(ctx context.Context, in *registry.NetworkService) (*registry.NetworkService, error) { + id := in.GetName() + if id == "" { + return nil, errors.New("NetworkService.Name can not be zero valued") + } + // If some other EventFactory is already in the ctx... we are already running in an executor, and can just execute normally + if fromContext(ctx) != nil { + return next.NetworkServiceRegistryServer(ctx).Register(ctx, in) + } + eventFactoryServer, _ := b.LoadOrStore(id, + newNSEventFactoryServer( + ctx, + func() { + b.Delete(id) + }, + ), + ) + + var resp *registry.NetworkService + var err error + + <-eventFactoryServer.executor.AsyncExec(func() { + currentEventFactoryServer, _ := b.LoadOrStore(id, eventFactoryServer) + if currentEventFactoryServer != eventFactoryServer { + log.FromContext(ctx).Debug("recalling begin.Request because currentEventFactoryServer != eventFactoryServer") + resp, err = b.Register(ctx, in) + return + } + ctx = withEventFactory(ctx, eventFactoryServer) + resp, err = next.NetworkServiceRegistryServer(ctx).Register(ctx, in) + if err != nil { + if eventFactoryServer.state != established { + eventFactoryServer.state = closed + b.Delete(id) + } + return + } + eventFactoryServer.registration = mergeNS(in, resp) + eventFactoryServer.state = established + eventFactoryServer.response = resp + }) + return resp, err +} + +func (b *beginNSServer) Find(query *registry.NetworkServiceQuery, server registry.NetworkServiceRegistry_FindServer) error { + return next.NetworkServiceRegistryServer(server.Context()).Find(query, server) +} + +func (b *beginNSServer) Unregister(ctx context.Context, in *registry.NetworkService) (*empty.Empty, error) { + id := in.GetName() + // // If some other EventFactory is already in the ctx... we are already running in an executor, and can just execute normally + if fromContext(ctx) != nil { + return next.NetworkServiceRegistryServer(ctx).Unregister(ctx, in) + } + eventFactoryServer, ok := b.Load(id) + if !ok { + // If we don't have a connection to Close, just let it be + return &emptypb.Empty{}, nil + } + var err error + <-eventFactoryServer.executor.AsyncExec(func() { + if eventFactoryServer.state != established || eventFactoryServer.registration == nil { + return + } + currentServerClient, _ := b.LoadOrStore(id, eventFactoryServer) + if currentServerClient != eventFactoryServer { + return + } + ctx = withEventFactory(ctx, eventFactoryServer) + _, err = next.NetworkServiceRegistryServer(ctx).Unregister(ctx, eventFactoryServer.registration) + eventFactoryServer.afterCloseFunc() + }) + return &emptypb.Empty{}, err +} + +// NewNetworkServiceRegistryServer - returns a new null server that does nothing but call next.NetworkServiceRegistryServer(ctx). +func NewNetworkServiceRegistryServer() registry.NetworkServiceRegistryServer { + return new(beginNSServer) +} diff --git a/pkg/registry/common/connect/ns_info_map.gen.go b/pkg/registry/common/begin/ns_server_map.gen.go similarity index 55% rename from pkg/registry/common/connect/ns_info_map.gen.go rename to pkg/registry/common/begin/ns_server_map.gen.go index 453c15f1f..601134046 100644 --- a/pkg/registry/common/connect/ns_info_map.gen.go +++ b/pkg/registry/common/begin/ns_server_map.gen.go @@ -1,5 +1,5 @@ -// Code generated by "-output ns_info_map.gen.go -type nsInfoMap -output ns_info_map.gen.go -type nsInfoMap"; DO NOT EDIT. -package connect +// Code generated by "-output ns_server_map.gen.go -type nsServerMap -output ns_server_map.gen.go -type nsServerMap"; DO NOT EDIT. +package begin import ( "sync" // Used by sync.Map. @@ -7,52 +7,52 @@ import ( // Generate code that will fail if the constants change value. func _() { - // An "cannot convert nsInfoMap literal (type nsInfoMap) to type sync.Map" compiler error signifies that the base type have changed. + // An "cannot convert nsServerMap literal (type nsServerMap) to type sync.Map" compiler error signifies that the base type have changed. // Re-run the go-syncmap command to generate them again. - _ = (sync.Map)(nsInfoMap{}) + _ = (sync.Map)(nsServerMap{}) } -var _nil_nsInfoMap_nsInfo_value = func() (val *nsInfo) { return }() +var _nil_nsServerMap_eventNSFactoryServer_value = func() (val *eventNSFactoryServer) { return }() // Load returns the value stored in the map for a key, or nil if no // value is present. // The ok result indicates whether value was found in the map. -func (m *nsInfoMap) Load(key string) (*nsInfo, bool) { +func (m *nsServerMap) Load(key string) (*eventNSFactoryServer, bool) { value, ok := (*sync.Map)(m).Load(key) if value == nil { - return _nil_nsInfoMap_nsInfo_value, ok + return _nil_nsServerMap_eventNSFactoryServer_value, ok } - return value.(*nsInfo), ok + return value.(*eventNSFactoryServer), ok } // Store sets the value for a key. -func (m *nsInfoMap) Store(key string, value *nsInfo) { +func (m *nsServerMap) Store(key string, value *eventNSFactoryServer) { (*sync.Map)(m).Store(key, value) } // LoadOrStore returns the existing value for the key if present. // Otherwise, it stores and returns the given value. // The loaded result is true if the value was loaded, false if stored. -func (m *nsInfoMap) LoadOrStore(key string, value *nsInfo) (*nsInfo, bool) { +func (m *nsServerMap) LoadOrStore(key string, value *eventNSFactoryServer) (*eventNSFactoryServer, bool) { actual, loaded := (*sync.Map)(m).LoadOrStore(key, value) if actual == nil { - return _nil_nsInfoMap_nsInfo_value, loaded + return _nil_nsServerMap_eventNSFactoryServer_value, loaded } - return actual.(*nsInfo), loaded + return actual.(*eventNSFactoryServer), loaded } // LoadAndDelete deletes the value for a key, returning the previous value if any. // The loaded result reports whether the key was present. -func (m *nsInfoMap) LoadAndDelete(key string) (value *nsInfo, loaded bool) { +func (m *nsServerMap) LoadAndDelete(key string) (value *eventNSFactoryServer, loaded bool) { actual, loaded := (*sync.Map)(m).LoadAndDelete(key) if actual == nil { - return _nil_nsInfoMap_nsInfo_value, loaded + return _nil_nsServerMap_eventNSFactoryServer_value, loaded } - return actual.(*nsInfo), loaded + return actual.(*eventNSFactoryServer), loaded } // Delete deletes the value for a key. -func (m *nsInfoMap) Delete(key string) { +func (m *nsServerMap) Delete(key string) { (*sync.Map)(m).Delete(key) } @@ -66,8 +66,8 @@ func (m *nsInfoMap) Delete(key string) { // // Range may be O(N) with the number of elements in the map even if f returns // false after a constant number of calls. -func (m *nsInfoMap) Range(f func(key string, value *nsInfo) bool) { +func (m *nsServerMap) Range(f func(key string, value *eventNSFactoryServer) bool) { (*sync.Map)(m).Range(func(key, value interface{}) bool { - return f(key.(string), value.(*nsInfo)) + return f(key.(string), value.(*eventNSFactoryServer)) }) } diff --git a/pkg/registry/common/begin/nse_client.go b/pkg/registry/common/begin/nse_client.go new file mode 100644 index 000000000..078887e68 --- /dev/null +++ b/pkg/registry/common/begin/nse_client.go @@ -0,0 +1,120 @@ +// Copyright (c) 2022 Cisco and/or its affiliates. +// +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at: +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package begin + +import ( + "context" + + "github.com/golang/protobuf/ptypes/empty" + "github.com/networkservicemesh/api/pkg/api/registry" + "github.com/pkg/errors" + "google.golang.org/grpc" + + "github.com/networkservicemesh/sdk/pkg/registry/core/next" + "github.com/networkservicemesh/sdk/pkg/tools/log" +) + +type beginNSEClient struct { + nseClientMap +} + +func (b *beginNSEClient) Register(ctx context.Context, in *registry.NetworkServiceEndpoint, opts ...grpc.CallOption) (*registry.NetworkServiceEndpoint, error) { + id := in.GetName() + if id == "" { + return nil, errors.New("registry.NetworkServiceEndpoint.Name must not be zero valued") + } + // If some other EventFactory is already in the ctx... we are already running in an executor, and can just execute normally + if fromContext(ctx) != nil { + return next.NetworkServiceEndpointRegistryClient(ctx).Register(ctx, in, opts...) + } + eventFactoryClient, _ := b.LoadOrStore(id, + newEventNSEFactoryClient( + ctx, + func() { + b.Delete(id) + }, + opts..., + ), + ) + var resp *registry.NetworkServiceEndpoint + var err error + <-eventFactoryClient.executor.AsyncExec(func() { + // If the eventFactory has changed, usually because the connection has been Closed and re-established + // go back to the beginning and try again. + currentEventFactoryClient, _ := b.LoadOrStore(id, eventFactoryClient) + if currentEventFactoryClient != eventFactoryClient { + log.FromContext(ctx).Debug("recalling begin.Request because currentEventFactoryClient != eventFactoryClient") + resp, err = b.Register(ctx, in, opts...) + return + } + + ctx = withEventFactory(ctx, eventFactoryClient) + resp, err = next.NetworkServiceEndpointRegistryClient(ctx).Register(ctx, in, opts...) + if err != nil { + if eventFactoryClient.state != established { + eventFactoryClient.state = closed + b.Delete(id) + } + return + } + eventFactoryClient.opts = opts + eventFactoryClient.state = established + eventFactoryClient.registration = mergeNSE(in, resp.Clone()) + eventFactoryClient.response = resp.Clone() + }) + return resp, err +} + +func (b *beginNSEClient) Find(ctx context.Context, in *registry.NetworkServiceEndpointQuery, opts ...grpc.CallOption) (registry.NetworkServiceEndpointRegistry_FindClient, error) { + return next.NetworkServiceEndpointRegistryClient(ctx).Find(ctx, in, opts...) +} + +func (b *beginNSEClient) Unregister(ctx context.Context, in *registry.NetworkServiceEndpoint, opts ...grpc.CallOption) (*empty.Empty, error) { + id := in.GetName() + if fromContext(ctx) != nil { + return next.NetworkServiceEndpointRegistryClient(ctx).Unregister(ctx, in, opts...) + } + eventFactoryClient, ok := b.Load(id) + if !ok { + return new(empty.Empty), nil + } + var emp *empty.Empty + var err error + <-eventFactoryClient.executor.AsyncExec(func() { + // If the connection is not established, don't do anything + if eventFactoryClient.state != established || eventFactoryClient.client == nil || eventFactoryClient.registration == nil { + return + } + + // If this isn't the connection we started with, do nothing + currentEventFactoryClient, _ := b.LoadOrStore(id, eventFactoryClient) + if currentEventFactoryClient != eventFactoryClient { + return + } + // Always close with the last valid Connection we got + ctx = withEventFactory(ctx, eventFactoryClient) + emp, err = next.NetworkServiceEndpointRegistryClient(ctx).Unregister(ctx, eventFactoryClient.registration, opts...) + // afterCloseFunc() is used to cleanup things like the entry in the Map for EventFactories + eventFactoryClient.afterCloseFunc() + }) + return emp, err +} + +// NewNetworkServiceEndpointRegistryClient - returns a new null client that does nothing but call next.NetworkServiceEndpointRegistryClient(ctx). +func NewNetworkServiceEndpointRegistryClient() registry.NetworkServiceEndpointRegistryClient { + return new(beginNSEClient) +} diff --git a/pkg/registry/common/connect/nse_client_map.gen.go b/pkg/registry/common/begin/nse_client_map.gen.go similarity index 64% rename from pkg/registry/common/connect/nse_client_map.gen.go rename to pkg/registry/common/begin/nse_client_map.gen.go index 47004656c..62ef33c38 100644 --- a/pkg/registry/common/connect/nse_client_map.gen.go +++ b/pkg/registry/common/begin/nse_client_map.gen.go @@ -1,5 +1,5 @@ -// Code generated by "-output nse_client_map.gen.go -type nseClientMap -output nse_client_map.gen.go -type nseClientMap"; DO NOT EDIT. -package connect +// Code generated by "-output nse_client_map.gen.go -type nseClientMap -output nse_client_map.gen.go -type nseClientMap"; DO NOT EDIT. +package begin import ( "sync" // Used by sync.Map. @@ -12,43 +12,43 @@ func _() { _ = (sync.Map)(nseClientMap{}) } -var _nil_nseClientMap_nseClient_value = func() (val *nseClient) { return }() +var _nil_nseClientMap_eventNSEFactoryClient_value = func() (val *eventNSEFactoryClient) { return }() // Load returns the value stored in the map for a key, or nil if no // value is present. // The ok result indicates whether value was found in the map. -func (m *nseClientMap) Load(key string) (*nseClient, bool) { +func (m *nseClientMap) Load(key string) (*eventNSEFactoryClient, bool) { value, ok := (*sync.Map)(m).Load(key) if value == nil { - return _nil_nseClientMap_nseClient_value, ok + return _nil_nseClientMap_eventNSEFactoryClient_value, ok } - return value.(*nseClient), ok + return value.(*eventNSEFactoryClient), ok } // Store sets the value for a key. -func (m *nseClientMap) Store(key string, value *nseClient) { +func (m *nseClientMap) Store(key string, value *eventNSEFactoryClient) { (*sync.Map)(m).Store(key, value) } // LoadOrStore returns the existing value for the key if present. // Otherwise, it stores and returns the given value. // The loaded result is true if the value was loaded, false if stored. -func (m *nseClientMap) LoadOrStore(key string, value *nseClient) (*nseClient, bool) { +func (m *nseClientMap) LoadOrStore(key string, value *eventNSEFactoryClient) (*eventNSEFactoryClient, bool) { actual, loaded := (*sync.Map)(m).LoadOrStore(key, value) if actual == nil { - return _nil_nseClientMap_nseClient_value, loaded + return _nil_nseClientMap_eventNSEFactoryClient_value, loaded } - return actual.(*nseClient), loaded + return actual.(*eventNSEFactoryClient), loaded } // LoadAndDelete deletes the value for a key, returning the previous value if any. // The loaded result reports whether the key was present. -func (m *nseClientMap) LoadAndDelete(key string) (value *nseClient, loaded bool) { +func (m *nseClientMap) LoadAndDelete(key string) (value *eventNSEFactoryClient, loaded bool) { actual, loaded := (*sync.Map)(m).LoadAndDelete(key) if actual == nil { - return _nil_nseClientMap_nseClient_value, loaded + return _nil_nseClientMap_eventNSEFactoryClient_value, loaded } - return actual.(*nseClient), loaded + return actual.(*eventNSEFactoryClient), loaded } // Delete deletes the value for a key. @@ -66,8 +66,8 @@ func (m *nseClientMap) Delete(key string) { // // Range may be O(N) with the number of elements in the map even if f returns // false after a constant number of calls. -func (m *nseClientMap) Range(f func(key string, value *nseClient) bool) { +func (m *nseClientMap) Range(f func(key string, value *eventNSEFactoryClient) bool) { (*sync.Map)(m).Range(func(key, value interface{}) bool { - return f(key.(string), value.(*nseClient)) + return f(key.(string), value.(*eventNSEFactoryClient)) }) } diff --git a/pkg/registry/common/begin/nse_event_factory.go b/pkg/registry/common/begin/nse_event_factory.go new file mode 100644 index 000000000..c6dddca00 --- /dev/null +++ b/pkg/registry/common/begin/nse_event_factory.go @@ -0,0 +1,199 @@ +// Copyright (c) 2022 Cisco and/or its affiliates. +// +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at: +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package begin + +import ( + "context" + + "github.com/edwarnicke/serialize" + "github.com/networkservicemesh/api/pkg/api/registry" + "google.golang.org/grpc" + + "github.com/networkservicemesh/sdk/pkg/registry/core/next" + "github.com/networkservicemesh/sdk/pkg/tools/postpone" +) + +type eventNSEFactoryClient struct { + state connectionState + executor serialize.Executor + ctxFunc func() (context.Context, context.CancelFunc) + registration *registry.NetworkServiceEndpoint + response *registry.NetworkServiceEndpoint + opts []grpc.CallOption + client registry.NetworkServiceEndpointRegistryClient + afterCloseFunc func() +} + +func newEventNSEFactoryClient(ctx context.Context, afterClose func(), opts ...grpc.CallOption) *eventNSEFactoryClient { + f := &eventNSEFactoryClient{ + client: next.NetworkServiceEndpointRegistryClient(ctx), + opts: opts, + } + ctxFunc := postpone.ContextWithValues(ctx) + f.ctxFunc = func() (context.Context, context.CancelFunc) { + eventCtx, cancel := ctxFunc() + return withEventFactory(eventCtx, f), cancel + } + + f.afterCloseFunc = func() { + f.state = closed + if afterClose != nil { + afterClose() + } + } + return f +} + +func (f *eventNSEFactoryClient) Register(opts ...Option) <-chan error { + o := &option{ + cancelCtx: context.Background(), + } + for _, opt := range opts { + opt(o) + } + ch := make(chan error, 1) + f.executor.AsyncExec(func() { + defer close(ch) + if f.state != established { + return + } + select { + case <-o.cancelCtx.Done(): + default: + registration := f.registration.Clone() + ctx, cancel := f.ctxFunc() + defer cancel() + resp, err := f.client.Register(ctx, registration, f.opts...) + if err == nil && f.registration != nil { + f.registration = mergeNSE(f.registration, resp) + } + ch <- err + } + }) + return ch +} + +func (f *eventNSEFactoryClient) Unregister(opts ...Option) <-chan error { + o := &option{ + cancelCtx: context.Background(), + } + for _, opt := range opts { + opt(o) + } + ch := make(chan error, 1) + f.executor.AsyncExec(func() { + defer close(ch) + if f.registration == nil { + return + } + select { + case <-o.cancelCtx.Done(): + default: + ctx, cancel := f.ctxFunc() + defer cancel() + _, err := f.client.Unregister(ctx, f.response, f.opts...) + f.afterCloseFunc() + ch <- err + } + }) + return ch +} + +var _ EventFactory = &eventNSEFactoryClient{} + +type eventNSEFactoryServer struct { + state connectionState + executor serialize.Executor + ctxFunc func() (context.Context, context.CancelFunc) + registration *registry.NetworkServiceEndpoint + response *registry.NetworkServiceEndpoint + afterCloseFunc func() + server registry.NetworkServiceEndpointRegistryServer +} + +func newNSEEventFactoryServer(ctx context.Context, afterClose func()) *eventNSEFactoryServer { + f := &eventNSEFactoryServer{ + server: next.NetworkServiceEndpointRegistryServer(ctx), + } + ctxFunc := postpone.ContextWithValues(ctx) + f.ctxFunc = func() (context.Context, context.CancelFunc) { + eventCtx, cancel := ctxFunc() + return withEventFactory(eventCtx, f), cancel + } + + f.afterCloseFunc = func() { + f.state = closed + afterClose() + } + return f +} + +func (f *eventNSEFactoryServer) Register(opts ...Option) <-chan error { + o := &option{ + cancelCtx: context.Background(), + } + for _, opt := range opts { + opt(o) + } + ch := make(chan error, 1) + f.executor.AsyncExec(func() { + defer close(ch) + if f.state != established { + return + } + select { + case <-o.cancelCtx.Done(): + default: + ctx, cancel := f.ctxFunc() + defer cancel() + resp, err := f.server.Register(ctx, f.registration) + if err == nil && f.registration != nil { + f.registration = resp + } + ch <- err + } + }) + return ch +} + +func (f *eventNSEFactoryServer) Unregister(opts ...Option) <-chan error { + o := &option{ + cancelCtx: context.Background(), + } + for _, opt := range opts { + opt(o) + } + ch := make(chan error, 1) + f.executor.AsyncExec(func() { + defer close(ch) + if f.registration == nil { + return + } + select { + case <-o.cancelCtx.Done(): + default: + ctx, cancel := f.ctxFunc() + defer cancel() + _, err := f.server.Unregister(ctx, f.registration) + f.afterCloseFunc() + ch <- err + } + }) + return ch +} + +var _ EventFactory = &eventNSEFactoryServer{} diff --git a/pkg/registry/common/begin/nse_server.go b/pkg/registry/common/begin/nse_server.go new file mode 100644 index 000000000..bba97fb54 --- /dev/null +++ b/pkg/registry/common/begin/nse_server.go @@ -0,0 +1,113 @@ +// Copyright (c) 2022 Cisco and/or its affiliates. +// +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at: +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package begin + +import ( + "context" + + "github.com/golang/protobuf/ptypes/empty" + "github.com/networkservicemesh/api/pkg/api/registry" + "github.com/pkg/errors" + "google.golang.org/protobuf/types/known/emptypb" + + "github.com/networkservicemesh/sdk/pkg/registry/core/next" + "github.com/networkservicemesh/sdk/pkg/tools/log" +) + +type beginNSEServer struct { + nseServerMap +} + +func (b *beginNSEServer) Register(ctx context.Context, in *registry.NetworkServiceEndpoint) (*registry.NetworkServiceEndpoint, error) { + id := in.GetName() + if id == "" { + return nil, errors.New("NetworkServiceEndpoint.Name can not be zero valued") + } + // If some other EventFactory is already in the ctx... we are already running in an executor, and can just execute normally + if fromContext(ctx) != nil { + return next.NetworkServiceEndpointRegistryServer(ctx).Register(ctx, in) + } + eventFactoryServer, _ := b.LoadOrStore(id, + newNSEEventFactoryServer( + ctx, + func() { + b.Delete(id) + }, + ), + ) + + var resp *registry.NetworkServiceEndpoint + var err error + + <-eventFactoryServer.executor.AsyncExec(func() { + currentEventFactoryServer, _ := b.LoadOrStore(id, eventFactoryServer) + if currentEventFactoryServer != eventFactoryServer { + log.FromContext(ctx).Debug("recalling begin.Request because currentEventFactoryServer != eventFactoryServer") + resp, err = b.Register(ctx, in) + return + } + ctx = withEventFactory(ctx, eventFactoryServer) + resp, err = next.NetworkServiceEndpointRegistryServer(ctx).Register(ctx, in) + if err != nil { + if eventFactoryServer.state != established { + eventFactoryServer.state = closed + b.Delete(id) + } + return + } + eventFactoryServer.registration = mergeNSE(in, resp) + eventFactoryServer.state = established + eventFactoryServer.response = resp + }) + return resp, err +} + +func (b *beginNSEServer) Find(query *registry.NetworkServiceEndpointQuery, server registry.NetworkServiceEndpointRegistry_FindServer) error { + return next.NetworkServiceEndpointRegistryServer(server.Context()).Find(query, server) +} + +func (b *beginNSEServer) Unregister(ctx context.Context, in *registry.NetworkServiceEndpoint) (*empty.Empty, error) { + id := in.GetName() + // // If some other EventFactory is already in the ctx... we are already running in an executor, and can just execute normally + if fromContext(ctx) != nil { + return next.NetworkServiceEndpointRegistryServer(ctx).Unregister(ctx, in) + } + eventFactoryServer, ok := b.Load(id) + if !ok { + // If we don't have a connection to Close, just let it be + return &emptypb.Empty{}, nil + } + var err error + <-eventFactoryServer.executor.AsyncExec(func() { + if eventFactoryServer.state != established || eventFactoryServer.registration == nil { + return + } + currentServerClient, _ := b.LoadOrStore(id, eventFactoryServer) + if currentServerClient != eventFactoryServer { + return + } + ctx = withEventFactory(ctx, eventFactoryServer) + _, err = next.NetworkServiceEndpointRegistryServer(ctx).Unregister(ctx, eventFactoryServer.registration) + eventFactoryServer.afterCloseFunc() + }) + return &emptypb.Empty{}, err +} + +// NewNetworkServiceEndpointRegistryServer - returns a new null server that does nothing but call next.NetworkServiceEndpointRegistryServer(ctx). +func NewNetworkServiceEndpointRegistryServer() registry.NetworkServiceEndpointRegistryServer { + return new(beginNSEServer) +} diff --git a/pkg/registry/common/begin/nse_server_map.gen.go b/pkg/registry/common/begin/nse_server_map.gen.go new file mode 100644 index 000000000..d35e82dff --- /dev/null +++ b/pkg/registry/common/begin/nse_server_map.gen.go @@ -0,0 +1,73 @@ +// Code generated by "-output nse_server_map.gen.go -type nseServerMap -output nse_server_map.gen.go -type nseServerMap"; DO NOT EDIT. +package begin + +import ( + "sync" // Used by sync.Map. +) + +// Generate code that will fail if the constants change value. +func _() { + // An "cannot convert nseServerMap literal (type nseServerMap) to type sync.Map" compiler error signifies that the base type have changed. + // Re-run the go-syncmap command to generate them again. + _ = (sync.Map)(nseServerMap{}) +} + +var _nil_nseServerMap_eventNSEFactoryServer_value = func() (val *eventNSEFactoryServer) { return }() + +// Load returns the value stored in the map for a key, or nil if no +// value is present. +// The ok result indicates whether value was found in the map. +func (m *nseServerMap) Load(key string) (*eventNSEFactoryServer, bool) { + value, ok := (*sync.Map)(m).Load(key) + if value == nil { + return _nil_nseServerMap_eventNSEFactoryServer_value, ok + } + return value.(*eventNSEFactoryServer), ok +} + +// Store sets the value for a key. +func (m *nseServerMap) Store(key string, value *eventNSEFactoryServer) { + (*sync.Map)(m).Store(key, value) +} + +// LoadOrStore returns the existing value for the key if present. +// Otherwise, it stores and returns the given value. +// The loaded result is true if the value was loaded, false if stored. +func (m *nseServerMap) LoadOrStore(key string, value *eventNSEFactoryServer) (*eventNSEFactoryServer, bool) { + actual, loaded := (*sync.Map)(m).LoadOrStore(key, value) + if actual == nil { + return _nil_nseServerMap_eventNSEFactoryServer_value, loaded + } + return actual.(*eventNSEFactoryServer), loaded +} + +// LoadAndDelete deletes the value for a key, returning the previous value if any. +// The loaded result reports whether the key was present. +func (m *nseServerMap) LoadAndDelete(key string) (value *eventNSEFactoryServer, loaded bool) { + actual, loaded := (*sync.Map)(m).LoadAndDelete(key) + if actual == nil { + return _nil_nseServerMap_eventNSEFactoryServer_value, loaded + } + return actual.(*eventNSEFactoryServer), loaded +} + +// Delete deletes the value for a key. +func (m *nseServerMap) Delete(key string) { + (*sync.Map)(m).Delete(key) +} + +// Range calls f sequentially for each key and value present in the map. +// If f returns false, range stops the iteration. +// +// Range does not necessarily correspond to any consistent snapshot of the Map's +// contents: no key will be visited more than once, but if the value for any key +// is stored or deleted concurrently, Range may reflect any mapping for that key +// from any point during the Range call. +// +// Range may be O(N) with the number of elements in the map even if f returns +// false after a constant number of calls. +func (m *nseServerMap) Range(f func(key string, value *eventNSEFactoryServer) bool) { + (*sync.Map)(m).Range(func(key, value interface{}) bool { + return f(key.(string), value.(*eventNSEFactoryServer)) + }) +} diff --git a/pkg/registry/common/proxy/doc.go b/pkg/registry/common/begin/options.go similarity index 59% rename from pkg/registry/common/proxy/doc.go rename to pkg/registry/common/begin/options.go index b08634898..cd3017481 100644 --- a/pkg/registry/common/proxy/doc.go +++ b/pkg/registry/common/begin/options.go @@ -1,4 +1,4 @@ -// Copyright (c) 2020 Doc.ai and/or its affiliates. +// Copyright (c) 2022 Cisco and/or its affiliates. // // SPDX-License-Identifier: Apache-2.0 // @@ -14,5 +14,22 @@ // See the License for the specific language governing permissions and // limitations under the License. -// Package proxy provides registry chain elements that can put URL to the proxy registry to the context in case of interdomain upstream. -package proxy +package begin + +import ( + "context" +) + +type option struct { + cancelCtx context.Context +} + +// Option - event option +type Option func(*option) + +// CancelContext - optionally provide a context that, when canceled will preclude the event from running +func CancelContext(cancelCtx context.Context) Option { + return func(o *option) { + o.cancelCtx = cancelCtx + } +} diff --git a/pkg/registry/common/begin/serialize_both_test.go b/pkg/registry/common/begin/serialize_both_test.go new file mode 100644 index 000000000..e08bec23e --- /dev/null +++ b/pkg/registry/common/begin/serialize_both_test.go @@ -0,0 +1,66 @@ +// Copyright (c) 2022 Cisco and/or its affiliates. +// +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at: +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package begin_test + +import ( + "context" + "fmt" + "sync" + "testing" + + "github.com/networkservicemesh/api/pkg/api/registry" + "github.com/stretchr/testify/assert" + "go.uber.org/goleak" + + "github.com/networkservicemesh/sdk/pkg/registry/common/begin" + "github.com/networkservicemesh/sdk/pkg/registry/core/adapters" + "github.com/networkservicemesh/sdk/pkg/registry/core/chain" +) + +func TestSerializeBoth_StressTest(t *testing.T) { + t.Cleanup(func() { goleak.VerifyNone(t) }) + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + server := chain.NewNetworkServiceEndpointRegistryServer( + begin.NewNetworkServiceEndpointRegistryServer(), + newParallelServer(t), + adapters.NetworkServiceEndpointClientToServer(chain.NewNetworkServiceEndpointRegistryClient( + begin.NewNetworkServiceEndpointRegistryClient(), + newParallelClient(t), + ), + ), + ) + + wg := new(sync.WaitGroup) + wg.Add(parallelCount) + for i := 0; i < parallelCount; i++ { + go func(id string) { + defer wg.Done() + + resp, err := server.Register(ctx, ®istry.NetworkServiceEndpoint{ + Name: id, + }) + assert.NoError(t, err) + + _, err = server.Unregister(ctx, resp) + assert.NoError(t, err) + }(fmt.Sprint(i % 20)) + } + wg.Wait() +} diff --git a/pkg/registry/common/begin/serialize_client_test.go b/pkg/registry/common/begin/serialize_client_test.go new file mode 100644 index 000000000..2429bb989 --- /dev/null +++ b/pkg/registry/common/begin/serialize_client_test.go @@ -0,0 +1,108 @@ +// Copyright (c) 2022 Cisco and/or its affiliates. +// +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at: +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package begin_test + +import ( + "context" + "fmt" + "sync" + "sync/atomic" + "testing" + + "github.com/golang/protobuf/ptypes/empty" + "github.com/networkservicemesh/api/pkg/api/registry" + "github.com/stretchr/testify/assert" + "go.uber.org/goleak" + "google.golang.org/grpc" + + "github.com/networkservicemesh/sdk/pkg/registry/common/begin" + "github.com/networkservicemesh/sdk/pkg/registry/core/chain" + "github.com/networkservicemesh/sdk/pkg/registry/core/next" +) + +func TestSerializeClient_StressTest(t *testing.T) { + t.Cleanup(func() { goleak.VerifyNone(t) }) + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + client := chain.NewNetworkServiceEndpointRegistryClient( + begin.NewNetworkServiceEndpointRegistryClient(), + newParallelClient(t), + ) + + wg := new(sync.WaitGroup) + wg.Add(parallelCount) + for i := 0; i < parallelCount; i++ { + go func(id string) { + defer wg.Done() + + resp, err := client.Register(ctx, ®istry.NetworkServiceEndpoint{ + Name: id, + }) + assert.NoError(t, err) + + _, err = client.Unregister(ctx, resp) + assert.NoError(t, err) + }(fmt.Sprint(i % 20)) + } + wg.Wait() +} + +type parallelClient struct { + t *testing.T + states sync.Map +} + +func newParallelClient(t *testing.T) *parallelClient { + return ¶llelClient{ + t: t, + } +} + +func (s *parallelClient) Register(ctx context.Context, in *registry.NetworkServiceEndpoint, opts ...grpc.CallOption) (*registry.NetworkServiceEndpoint, error) { + raw, _ := s.states.LoadOrStore(in.GetName(), new(int32)) + statePtr := raw.(*int32) + + state := atomic.LoadInt32(statePtr) + if !atomic.CompareAndSwapInt32(statePtr, state, state+1) { + assert.Failf(s.t, "", "state has been changed for connection %s expected %d actual %d", in.GetName(), state, atomic.LoadInt32(statePtr)) + } + + return next.NetworkServiceEndpointRegistryClient(ctx).Register(ctx, in, opts...) +} + +func (s *parallelClient) Find(ctx context.Context, in *registry.NetworkServiceEndpointQuery, opts ...grpc.CallOption) (registry.NetworkServiceEndpointRegistry_FindClient, error) { + return next.NetworkServiceEndpointRegistryClient(ctx).Find(ctx, in, opts...) +} + +func (s *parallelClient) Unregister(ctx context.Context, in *registry.NetworkServiceEndpoint, opts ...grpc.CallOption) (*empty.Empty, error) { + raw, _ := s.states.LoadOrStore(in.GetName(), new(int32)) + statePtr := raw.(*int32) + + state := atomic.LoadInt32(statePtr) + if !atomic.CompareAndSwapInt32(statePtr, state, state+1) { + assert.Failf(s.t, "", "state has been changed for connection %s expected %d actual %d", in.GetName(), state, atomic.LoadInt32(statePtr)) + } + + return next.NetworkServiceEndpointRegistryClient(ctx).Unregister(ctx, in, opts...) +} + +// NewNetworkServiceEndpointRegistryClient - returns a new null client that does nothing but call next.NetworkServiceEndpointRegistryClient(ctx). +func NewNetworkServiceEndpointRegistryClient() registry.NetworkServiceEndpointRegistryClient { + return new(parallelClient) +} diff --git a/pkg/registry/common/begin/serialize_server_test.go b/pkg/registry/common/begin/serialize_server_test.go new file mode 100644 index 000000000..2d45f1a7f --- /dev/null +++ b/pkg/registry/common/begin/serialize_server_test.go @@ -0,0 +1,102 @@ +// Copyright (c) 2022 Cisco and/or its affiliates. +// +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at: +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package begin_test + +import ( + "context" + "fmt" + "sync" + "sync/atomic" + "testing" + + "github.com/golang/protobuf/ptypes/empty" + "github.com/networkservicemesh/api/pkg/api/registry" + "github.com/stretchr/testify/assert" + "go.uber.org/goleak" + + "github.com/networkservicemesh/sdk/pkg/registry/common/begin" + "github.com/networkservicemesh/sdk/pkg/registry/core/chain" + "github.com/networkservicemesh/sdk/pkg/registry/core/next" +) + +const ( + parallelCount = 1000 +) + +func TestSerializeServer_StressTest(t *testing.T) { + t.Cleanup(func() { goleak.VerifyNone(t) }) + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + server := chain.NewNetworkServiceEndpointRegistryServer( + begin.NewNetworkServiceEndpointRegistryServer(), + newParallelServer(t), + ) + + wg := new(sync.WaitGroup) + wg.Add(parallelCount) + for i := 0; i < parallelCount; i++ { + go func(id string) { + defer wg.Done() + + resp, err := server.Register(ctx, ®istry.NetworkServiceEndpoint{ + Name: id, + }) + assert.NoError(t, err) + + _, err = server.Unregister(ctx, resp) + assert.NoError(t, err) + }(fmt.Sprint(i % 20)) + } + wg.Wait() +} + +func newParallelServer(t *testing.T) *parallelServer { + return ¶llelServer{ + t: t, + } +} + +type parallelServer struct { + t *testing.T + states sync.Map +} + +func (s *parallelServer) Register(ctx context.Context, in *registry.NetworkServiceEndpoint) (*registry.NetworkServiceEndpoint, error) { + raw, _ := s.states.LoadOrStore(in.GetName(), new(int32)) + statePtr := raw.(*int32) + + state := atomic.LoadInt32(statePtr) + assert.True(s.t, atomic.CompareAndSwapInt32(statePtr, state, state+1), "state has been changed for connection %s expected %d actual %d", in.GetName(), state, atomic.LoadInt32(statePtr)) + + return next.NetworkServiceEndpointRegistryServer(ctx).Register(ctx, in) +} + +func (s *parallelServer) Find(query *registry.NetworkServiceEndpointQuery, server registry.NetworkServiceEndpointRegistry_FindServer) error { + return next.NetworkServiceEndpointRegistryServer(server.Context()).Find(query, server) +} + +func (s *parallelServer) Unregister(ctx context.Context, in *registry.NetworkServiceEndpoint) (*empty.Empty, error) { + raw, _ := s.states.LoadOrStore(in.GetName(), new(int32)) + statePtr := raw.(*int32) + + state := atomic.LoadInt32(statePtr) + assert.True(s.t, atomic.CompareAndSwapInt32(statePtr, state, state+1), "state has been changed for connection %s expected %d actual %d", in.GetName(), state, atomic.LoadInt32(statePtr)) + + return next.NetworkServiceEndpointRegistryServer(ctx).Unregister(ctx, in) +} diff --git a/pkg/registry/common/clientconn/context.go b/pkg/registry/common/clientconn/context.go new file mode 100644 index 000000000..d8ebe8ee0 --- /dev/null +++ b/pkg/registry/common/clientconn/context.go @@ -0,0 +1,100 @@ +// Copyright (c) 2022 Cisco and/or its affiliates. +// +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at: +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package clientconn + +import ( + "context" + + "google.golang.org/grpc" + + "github.com/networkservicemesh/sdk/pkg/tools/clienturlctx" +) + +type mapKey struct{} +type nameKey struct{} + +func withClientConnMetadata(ctx context.Context, m *stringCCMap, key string) context.Context { + ctx = context.WithValue(ctx, nameKey{}, key) + ctx = context.WithValue(ctx, mapKey{}, m) + return ctx +} + +func nameFromContext(ctx context.Context) string { + if v := ctx.Value(nameKey{}); v != nil { + return v.(string) + } + if u := clienturlctx.ClientURL(ctx); u != nil { + return u.String() + } + return "" +} + +// LoadAndDelete deletes the value for a key, returning the previous value if any. +// The loaded result reports whether the key was present. +func LoadAndDelete(ctx context.Context) (grpc.ClientConnInterface, bool) { + k := nameFromContext(ctx) + + if v, ok := ctx.Value(mapKey{}).(*stringCCMap); ok && k != "" { + return v.LoadAndDelete(k) + } + + return nil, false +} + +// Store sets the value for a key. +func Store(ctx context.Context, cc grpc.ClientConnInterface) { + k := nameFromContext(ctx) + + if v, ok := ctx.Value(mapKey{}).(*stringCCMap); ok && k != "" { + v.Store(k, cc) + } +} + +// Delete deletes the value for a key. +func Delete(ctx context.Context) { + k := nameFromContext(ctx) + + if v, ok := ctx.Value(mapKey{}).(*stringCCMap); ok && k != "" { + v.Delete(k) + } +} + +// Load returns the value stored in the map for a key, or nil if no +// value is present. +// The ok result indicates whether value was found in the map. +func Load(ctx context.Context) (grpc.ClientConnInterface, bool) { + k := nameFromContext(ctx) + + if v, ok := ctx.Value(mapKey{}).(*stringCCMap); ok && k != "" { + return v.Load(k) + } + + return nil, false +} + +// LoadOrStore returns the existing value for the key if present. +// Otherwise, it stores and returns the given value. +// The loaded result is true if the value was loaded, false if stored. +func LoadOrStore(ctx context.Context, cc grpc.ClientConnInterface) (grpc.ClientConnInterface, bool) { + k := nameFromContext(ctx) + + if v, ok := ctx.Value(mapKey{}).(*stringCCMap); ok && k != "" { + return v.LoadOrStore(k, cc) + } + + return cc, false +} diff --git a/pkg/tools/expire/gen.go b/pkg/registry/common/clientconn/gen.go similarity index 66% rename from pkg/tools/expire/gen.go rename to pkg/registry/common/clientconn/gen.go index 53189ff7a..792b22b80 100644 --- a/pkg/tools/expire/gen.go +++ b/pkg/registry/common/clientconn/gen.go @@ -1,4 +1,4 @@ -// Copyright (c) 2020-2021 Doc.ai and/or its affiliates. +// Copyright (c) 2022 Cisco and/or its affiliates. // // SPDX-License-Identifier: Apache-2.0 // @@ -14,10 +14,13 @@ // See the License for the specific language governing permissions and // limitations under the License. -package expire +package clientconn -import "sync" +import ( + "sync" +) -//go:generate go-syncmap -output timer_map.gen.go -type timerMap +//go:generate go-syncmap -output map.gen.go -type stringCCMap -type timerMap sync.Map +// clientMap - sync.Map with key == url.URL and value == grpc.ClientConnInterface +type stringCCMap sync.Map diff --git a/pkg/tools/expire/timer_map.gen.go b/pkg/registry/common/clientconn/map.gen.go similarity index 53% rename from pkg/tools/expire/timer_map.gen.go rename to pkg/registry/common/clientconn/map.gen.go index 21a1e58b6..7a0b20f3a 100644 --- a/pkg/tools/expire/timer_map.gen.go +++ b/pkg/registry/common/clientconn/map.gen.go @@ -1,58 +1,60 @@ -// Code generated by "-output timer_map.gen.go -type timerMap -output timer_map.gen.go -type timerMap"; DO NOT EDIT. -package expire +// Code generated by "-output map.gen.go -type stringCCMap -output map.gen.go -type stringCCMap"; DO NOT EDIT. +package clientconn import ( "sync" // Used by sync.Map. + + "google.golang.org/grpc" ) // Generate code that will fail if the constants change value. func _() { - // An "cannot convert timerMap literal (type timerMap) to type sync.Map" compiler error signifies that the base type have changed. + // An "cannot convert stringCCMap literal (type stringCCMap) to type sync.Map" compiler error signifies that the base type have changed. // Re-run the go-syncmap command to generate them again. - _ = (sync.Map)(timerMap{}) + _ = (sync.Map)(stringCCMap{}) } -var _nil_timerMap_timer_value = func() (val *timer) { return }() +var _nil_stringCCMap_grpc_ClientConnInterface_value = func() (val grpc.ClientConnInterface) { return }() // Load returns the value stored in the map for a key, or nil if no // value is present. // The ok result indicates whether value was found in the map. -func (m *timerMap) Load(key string) (*timer, bool) { +func (m *stringCCMap) Load(key string) (grpc.ClientConnInterface, bool) { value, ok := (*sync.Map)(m).Load(key) if value == nil { - return _nil_timerMap_timer_value, ok + return _nil_stringCCMap_grpc_ClientConnInterface_value, ok } - return value.(*timer), ok + return value.(grpc.ClientConnInterface), ok } // Store sets the value for a key. -func (m *timerMap) Store(key string, value *timer) { +func (m *stringCCMap) Store(key string, value grpc.ClientConnInterface) { (*sync.Map)(m).Store(key, value) } // LoadOrStore returns the existing value for the key if present. // Otherwise, it stores and returns the given value. // The loaded result is true if the value was loaded, false if stored. -func (m *timerMap) LoadOrStore(key string, value *timer) (*timer, bool) { +func (m *stringCCMap) LoadOrStore(key string, value grpc.ClientConnInterface) (grpc.ClientConnInterface, bool) { actual, loaded := (*sync.Map)(m).LoadOrStore(key, value) if actual == nil { - return _nil_timerMap_timer_value, loaded + return _nil_stringCCMap_grpc_ClientConnInterface_value, loaded } - return actual.(*timer), loaded + return actual.(grpc.ClientConnInterface), loaded } // LoadAndDelete deletes the value for a key, returning the previous value if any. // The loaded result reports whether the key was present. -func (m *timerMap) LoadAndDelete(key string) (value *timer, loaded bool) { +func (m *stringCCMap) LoadAndDelete(key string) (value grpc.ClientConnInterface, loaded bool) { actual, loaded := (*sync.Map)(m).LoadAndDelete(key) if actual == nil { - return _nil_timerMap_timer_value, loaded + return _nil_stringCCMap_grpc_ClientConnInterface_value, loaded } - return actual.(*timer), loaded + return actual.(grpc.ClientConnInterface), loaded } // Delete deletes the value for a key. -func (m *timerMap) Delete(key string) { +func (m *stringCCMap) Delete(key string) { (*sync.Map)(m).Delete(key) } @@ -66,8 +68,8 @@ func (m *timerMap) Delete(key string) { // // Range may be O(N) with the number of elements in the map even if f returns // false after a constant number of calls. -func (m *timerMap) Range(f func(key string, value *timer) bool) { +func (m *stringCCMap) Range(f func(key string, value grpc.ClientConnInterface) bool) { (*sync.Map)(m).Range(func(key, value interface{}) bool { - return f(key.(string), value.(*timer)) + return f(key.(string), value.(grpc.ClientConnInterface)) }) } diff --git a/pkg/registry/common/clientconn/ns_client.go b/pkg/registry/common/clientconn/ns_client.go new file mode 100644 index 000000000..98564e2b5 --- /dev/null +++ b/pkg/registry/common/clientconn/ns_client.go @@ -0,0 +1,53 @@ +// Copyright (c) 2022 Cisco and/or its affiliates. +// +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at: +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package clientconn - chain element for injecting a grpc.ClientConnInterface into the client chain +package clientconn + +import ( + "context" + + "github.com/golang/protobuf/ptypes/empty" + "github.com/google/uuid" + "github.com/networkservicemesh/api/pkg/api/registry" + "google.golang.org/grpc" + + "github.com/networkservicemesh/sdk/pkg/registry/core/next" +) + +type clientConnNSClient struct { + stringCCMap +} + +func (c *clientConnNSClient) Register(ctx context.Context, in *registry.NetworkService, opts ...grpc.CallOption) (*registry.NetworkService, error) { + ctx = withClientConnMetadata(ctx, &c.stringCCMap, in.GetName()) + return next.NetworkServiceRegistryClient(ctx).Register(ctx, in, opts...) +} + +func (c *clientConnNSClient) Unregister(ctx context.Context, in *registry.NetworkService, opts ...grpc.CallOption) (*empty.Empty, error) { + ctx = withClientConnMetadata(ctx, &c.stringCCMap, in.GetName()) + return next.NetworkServiceRegistryClient(ctx).Unregister(ctx, in) +} + +func (c *clientConnNSClient) Find(ctx context.Context, in *registry.NetworkServiceQuery, opts ...grpc.CallOption) (registry.NetworkServiceRegistry_FindClient, error) { + ctx = withClientConnMetadata(ctx, &c.stringCCMap, uuid.New().String()) + return next.NetworkServiceRegistryClient(ctx).Find(ctx, in, opts...) +} + +// NewNetworkServiceRegistryClient - returns a new null client that does nothing but call next.NetworkServiceRegistryClient(ctx). +func NewNetworkServiceRegistryClient() registry.NetworkServiceRegistryClient { + return new(clientConnNSClient) +} diff --git a/pkg/registry/common/clientconn/nse_client.go b/pkg/registry/common/clientconn/nse_client.go new file mode 100644 index 000000000..947c57eb7 --- /dev/null +++ b/pkg/registry/common/clientconn/nse_client.go @@ -0,0 +1,53 @@ +// Copyright (c) 2022 Cisco and/or its affiliates. +// +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at: +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package clientconn - chain element for injecting a grpc.ClientConnInterface into the client chain +package clientconn + +import ( + "context" + + "github.com/golang/protobuf/ptypes/empty" + "github.com/google/uuid" + "github.com/networkservicemesh/api/pkg/api/registry" + "google.golang.org/grpc" + + "github.com/networkservicemesh/sdk/pkg/registry/core/next" +) + +type clientConnNSEClient struct { + stringCCMap +} + +func (c *clientConnNSEClient) Register(ctx context.Context, in *registry.NetworkServiceEndpoint, opts ...grpc.CallOption) (*registry.NetworkServiceEndpoint, error) { + ctx = withClientConnMetadata(ctx, &c.stringCCMap, in.GetName()) + return next.NetworkServiceEndpointRegistryClient(ctx).Register(ctx, in, opts...) +} + +func (c *clientConnNSEClient) Unregister(ctx context.Context, in *registry.NetworkServiceEndpoint, opts ...grpc.CallOption) (*empty.Empty, error) { + ctx = withClientConnMetadata(ctx, &c.stringCCMap, in.GetName()) + return next.NetworkServiceEndpointRegistryClient(ctx).Unregister(ctx, in) +} + +func (c *clientConnNSEClient) Find(ctx context.Context, in *registry.NetworkServiceEndpointQuery, opts ...grpc.CallOption) (registry.NetworkServiceEndpointRegistry_FindClient, error) { + ctx = withClientConnMetadata(ctx, &c.stringCCMap, uuid.New().String()) + return next.NetworkServiceEndpointRegistryClient(ctx).Find(ctx, in, opts...) +} + +// NewNetworkServiceEndpointRegistryClient - returns a new null client that does nothing but call next.NetworkServiceEndpointRegistryClient(ctx). +func NewNetworkServiceEndpointRegistryClient() registry.NetworkServiceEndpointRegistryClient { + return new(clientConnNSEClient) +} diff --git a/pkg/registry/common/clienturl/ns_client.go b/pkg/registry/common/clienturl/ns_client.go new file mode 100644 index 000000000..d72636170 --- /dev/null +++ b/pkg/registry/common/clienturl/ns_client.go @@ -0,0 +1,55 @@ +// Copyright (c) 2020-2022 Doc.ai and/or its affiliates. +// +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at: +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package clienturl + +import ( + "context" + "net/url" + + "github.com/golang/protobuf/ptypes/empty" + "github.com/networkservicemesh/api/pkg/api/registry" + "google.golang.org/grpc" + + "github.com/networkservicemesh/sdk/pkg/registry/core/next" + "github.com/networkservicemesh/sdk/pkg/tools/clienturlctx" +) + +type clientURLNSClient struct { + u *url.URL +} + +func (c *clientURLNSClient) Register(ctx context.Context, in *registry.NetworkService, opts ...grpc.CallOption) (*registry.NetworkService, error) { + ctx = clienturlctx.WithClientURL(ctx, c.u) + return next.NetworkServiceRegistryClient(ctx).Register(ctx, in, opts...) +} + +func (c *clientURLNSClient) Find(ctx context.Context, in *registry.NetworkServiceQuery, opts ...grpc.CallOption) (registry.NetworkServiceRegistry_FindClient, error) { + ctx = clienturlctx.WithClientURL(ctx, c.u) + return next.NetworkServiceRegistryClient(ctx).Find(ctx, in, opts...) +} + +func (c *clientURLNSClient) Unregister(ctx context.Context, in *registry.NetworkService, opts ...grpc.CallOption) (*empty.Empty, error) { + ctx = clienturlctx.WithClientURL(ctx, c.u) + return next.NetworkServiceRegistryClient(ctx).Unregister(ctx, in, opts...) +} + +// NewNetworkServiceRegistryClient - returns a new null client that does nothing but call next.NetworkServiceRegistryClient(ctx). +func NewNetworkServiceRegistryClient(u *url.URL) registry.NetworkServiceRegistryClient { + return &clientURLNSClient{ + u: u, + } +} diff --git a/pkg/registry/common/clienturl/nse_client.go b/pkg/registry/common/clienturl/nse_client.go new file mode 100644 index 000000000..673562a2d --- /dev/null +++ b/pkg/registry/common/clienturl/nse_client.go @@ -0,0 +1,58 @@ +// Copyright (c) 2020-2022 Doc.ai and/or its affiliates. +// +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at: +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package clienturl + +import ( + "context" + "net/url" + + "github.com/golang/protobuf/ptypes/empty" + "github.com/networkservicemesh/api/pkg/api/registry" + "google.golang.org/grpc" + + "github.com/networkservicemesh/sdk/pkg/registry/core/next" + "github.com/networkservicemesh/sdk/pkg/tools/clienturlctx" +) + +type clientURLNSEClient struct { + u *url.URL +} + +func (c *clientURLNSEClient) Register(ctx context.Context, in *registry.NetworkServiceEndpoint, opts ...grpc.CallOption) (*registry.NetworkServiceEndpoint, error) { + ctx = clienturlctx.WithClientURL(ctx, c.u) + + return next.NetworkServiceEndpointRegistryClient(ctx).Register(ctx, in, opts...) +} + +func (c *clientURLNSEClient) Find(ctx context.Context, in *registry.NetworkServiceEndpointQuery, opts ...grpc.CallOption) (registry.NetworkServiceEndpointRegistry_FindClient, error) { + ctx = clienturlctx.WithClientURL(ctx, c.u) + + return next.NetworkServiceEndpointRegistryClient(ctx).Find(ctx, in, opts...) +} + +func (c *clientURLNSEClient) Unregister(ctx context.Context, in *registry.NetworkServiceEndpoint, opts ...grpc.CallOption) (*empty.Empty, error) { + ctx = clienturlctx.WithClientURL(ctx, c.u) + + return next.NetworkServiceEndpointRegistryClient(ctx).Unregister(ctx, in, opts...) +} + +// NewNetworkServiceEndpointRegistryClient - returns a new null client that does nothing but call next.NetworkServiceEndpointRegistryClient(ctx). +func NewNetworkServiceEndpointRegistryClient(u *url.URL) registry.NetworkServiceEndpointRegistryClient { + return &clientURLNSEClient{ + u: u, + } +} diff --git a/pkg/registry/common/connect/doc.go b/pkg/registry/common/connect/doc.go deleted file mode 100644 index a6966d416..000000000 --- a/pkg/registry/common/connect/doc.go +++ /dev/null @@ -1,18 +0,0 @@ -// Copyright (c) 2020-2021 Doc.ai and/or its affiliates. -// -// SPDX-License-Identifier: Apache-2.0 -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at: -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Package connect provides NS, NSE registry server chain elements providing access to remote registry servers -package connect diff --git a/pkg/registry/common/connect/ns_client.go b/pkg/registry/common/connect/ns_client.go index f2d171ac9..07959e3b9 100644 --- a/pkg/registry/common/connect/ns_client.go +++ b/pkg/registry/common/connect/ns_client.go @@ -1,4 +1,4 @@ -// Copyright (c) 2021 Doc.ai and/or its affiliates. +// Copyright (c) 2021-2022 Cisco and/or its affiliates. // // SPDX-License-Identifier: Apache-2.0 // @@ -18,125 +18,41 @@ package connect import ( "context" - "net/url" - "sync" - - "google.golang.org/grpc" - "google.golang.org/grpc/connectivity" - "google.golang.org/protobuf/types/known/emptypb" + "github.com/golang/protobuf/ptypes/empty" "github.com/networkservicemesh/api/pkg/api/registry" + "google.golang.org/grpc" - "github.com/networkservicemesh/sdk/pkg/registry/core/chain" - "github.com/networkservicemesh/sdk/pkg/tools/grpcutils" + "github.com/networkservicemesh/sdk/pkg/registry/common/clientconn" ) -type connectNSClient struct { - ctx context.Context - client registry.NetworkServiceRegistryClient - connectTo string - dialOptions []grpc.DialOption - - cc *grpc.ClientConn - lock sync.RWMutex -} - -// NewNetworkServiceRegistryClient returns a new NS registry client chain element connecting to the remote -// NS registry server -func NewNetworkServiceRegistryClient(ctx context.Context, connectTo *url.URL, opts ...Option) registry.NetworkServiceRegistryClient { - connectOpts := new(connectOptions) - for _, opt := range opts { - opt(connectOpts) - } - - c := &connectNSClient{ - ctx: ctx, - client: chain.NewNetworkServiceRegistryClient( - append( - connectOpts.nsAdditionalFunctionality, - new(grpcNSClient), - )..., - ), - connectTo: grpcutils.URLToTarget(connectTo), - dialOptions: append(append([]grpc.DialOption{}, connectOpts.dialOptions...), grpc.WithReturnConnectionError()), - } - - go func() { - <-ctx.Done() - - c.lock.Lock() - defer c.lock.Unlock() - - if c.cc != nil { - _ = c.cc.Close() - } - }() +type connectNSClient struct{} - return c -} - -func (c *connectNSClient) Register(ctx context.Context, ns *registry.NetworkService, opts ...grpc.CallOption) (*registry.NetworkService, error) { - cc, err := c.getCC() - if err != nil { - return nil, err +func (n *connectNSClient) Register(ctx context.Context, in *registry.NetworkService, opts ...grpc.CallOption) (*registry.NetworkService, error) { + cc, loaded := clientconn.Load(ctx) + if !loaded { + return nil, errNoCCProvided } - return c.client.Register(withCC(ctx, cc), ns, opts...) + return registry.NewNetworkServiceRegistryClient(cc).Register(ctx, in, opts...) } -func (c *connectNSClient) Find(ctx context.Context, query *registry.NetworkServiceQuery, opts ...grpc.CallOption) (registry.NetworkServiceRegistry_FindClient, error) { - cc, err := c.getCC() - if err != nil { - return nil, err +func (n *connectNSClient) Find(ctx context.Context, in *registry.NetworkServiceQuery, opts ...grpc.CallOption) (registry.NetworkServiceRegistry_FindClient, error) { + cc, loaded := clientconn.Load(ctx) + if !loaded { + return nil, errNoCCProvided } - return c.client.Find(withCC(ctx, cc), query, opts...) + return registry.NewNetworkServiceRegistryClient(cc).Find(ctx, in, opts...) } -func (c *connectNSClient) Unregister(ctx context.Context, ns *registry.NetworkService, opts ...grpc.CallOption) (*emptypb.Empty, error) { - cc, err := c.getCC() - if err != nil { - return nil, err +func (n *connectNSClient) Unregister(ctx context.Context, in *registry.NetworkService, opts ...grpc.CallOption) (*empty.Empty, error) { + cc, loaded := clientconn.Load(ctx) + if !loaded { + return nil, errNoCCProvided } - return c.client.Unregister(withCC(ctx, cc), ns, opts...) + return registry.NewNetworkServiceRegistryClient(cc).Unregister(ctx, in, opts...) } -func (c *connectNSClient) getCC() (*grpc.ClientConn, error) { - c.lock.RLock() - cc := c.cc - c.lock.RUnlock() - - if cc != nil { - return cc, nil - } - - c.lock.Lock() - defer c.lock.Unlock() - - if c.cc != nil { - return c.cc, nil - } - - var err error - if c.cc, err = grpc.DialContext(c.ctx, c.connectTo, c.dialOptions...); err != nil { - return nil, err - } - - go func() { - defer func() { - c.lock.Lock() - defer c.lock.Unlock() - - _ = c.cc.Close() - c.cc = nil - }() - for c.cc.WaitForStateChange(c.ctx, c.cc.GetState()) { - switch c.cc.GetState() { - case connectivity.Connecting, connectivity.Idle, connectivity.Ready: - continue - default: - return - } - } - }() - - return c.cc, nil +// NewNetworkServiceRegistryClient - returns a new null client that does nothing but call next.NetworkServiceRegistryClient(ctx). +func NewNetworkServiceRegistryClient() registry.NetworkServiceRegistryClient { + return new(connectNSClient) } diff --git a/pkg/registry/common/connect/ns_client_test.go b/pkg/registry/common/connect/ns_client_test.go deleted file mode 100644 index 22c4203a5..000000000 --- a/pkg/registry/common/connect/ns_client_test.go +++ /dev/null @@ -1,139 +0,0 @@ -// Copyright (c) 2021 Doc.ai and/or its affiliates. -// -// SPDX-License-Identifier: Apache-2.0 -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at: -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package connect_test - -import ( - "context" - "net/url" - "testing" - "time" - - "github.com/stretchr/testify/require" - "go.uber.org/goleak" - "google.golang.org/grpc" - - "github.com/networkservicemesh/api/pkg/api/registry" - - "github.com/networkservicemesh/sdk/pkg/registry/common/connect" - "github.com/networkservicemesh/sdk/pkg/registry/common/memory" - "github.com/networkservicemesh/sdk/pkg/registry/core/streamchannel" - "github.com/networkservicemesh/sdk/pkg/tools/sandbox" -) - -func TestConnectNSClient(t *testing.T) { - t.Cleanup(func() { goleak.VerifyNone(t) }) - - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - - mem := memory.NewNetworkServiceRegistryServer() - - u := &url.URL{Scheme: "tcp", Host: "127.0.0.1:0"} - require.NoError(t, startNSServer(ctx, u, mem)) - require.NoError(t, waitNSServerStarted(u)) - - // 1. Register remote NS - _, err := mem.Register(ctx, ®istry.NetworkService{Name: "ns-remote"}) - require.NoError(t, err) - - c := connect.NewNetworkServiceRegistryClient(ctx, u, - connect.WithDialOptions(grpc.WithInsecure()), - ) - - // 2. Register local NS - _, err = c.Register(ctx, ®istry.NetworkService{Name: "ns-local"}) - require.NoError(t, err) - - // 3. Find both local, remote NSs from client - stream, err := c.Find(ctx, ®istry.NetworkServiceQuery{ - NetworkService: new(registry.NetworkService), - }) - require.NoError(t, err) - - var nsNames []string - for _, ns := range registry.ReadNetworkServiceList(stream) { - nsNames = append(nsNames, ns.Name) - } - require.Len(t, nsNames, 2) - require.Subset(t, []string{"ns-remote", "ns-local"}, nsNames) - - // 4. Unregister remote NS from client - _, err = c.Unregister(ctx, ®istry.NetworkService{Name: "ns-remote"}) - require.NoError(t, err) - - // 5. Find only local NS in memory - ch := make(chan *registry.NetworkServiceResponse, 2) - err = mem.Find(®istry.NetworkServiceQuery{ - NetworkService: new(registry.NetworkService), - }, streamchannel.NewNetworkServiceFindServer(ctx, ch)) - require.NoError(t, err) - - require.Len(t, ch, 1) - require.Equal(t, "ns-local", (<-ch).NetworkService.Name) -} - -func TestConnectNSClient_Restart(t *testing.T) { - t.Cleanup(func() { goleak.VerifyNone(t) }) - - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - - serverCtx, serverCancel := context.WithCancel(ctx) - - mem := memory.NewNetworkServiceRegistryServer() - - u := &url.URL{Scheme: "tcp", Host: "127.0.0.1:0"} - require.NoError(t, startNSServer(serverCtx, u, mem)) - require.NoError(t, waitNSServerStarted(u)) - - c := connect.NewNetworkServiceRegistryClient(ctx, u, - connect.WithDialOptions(grpc.WithInsecure()), - ) - - // 1. Register NS-1 with client - _, err := c.Register(ctx, ®istry.NetworkService{Name: "ns-1"}) - require.NoError(t, err) - - // 2. Restart remote - serverCancel() - require.Eventually(t, func() bool { - return sandbox.CheckURLFree(u) - }, time.Second, 10*time.Millisecond) - - require.NoError(t, startNSServer(ctx, u, mem)) - require.NoError(t, waitNSServerStarted(u)) - - // 3. Register NS-2 with client - require.Eventually(t, func() bool { - _, err = c.Register(ctx, ®istry.NetworkService{Name: "ns-2"}) - return err == nil - }, time.Second, 10*time.Millisecond) - - // 4. Find both NS-1, NS-2 in memory - ch := make(chan *registry.NetworkServiceResponse, 2) - err = mem.Find(®istry.NetworkServiceQuery{ - NetworkService: new(registry.NetworkService), - }, streamchannel.NewNetworkServiceFindServer(ctx, ch)) - require.NoError(t, err) - - var nsNames []string - for i := len(ch); i > 0; i-- { - nsNames = append(nsNames, (<-ch).NetworkService.Name) - } - require.Len(t, nsNames, 2) - require.Subset(t, []string{"ns-1", "ns-2"}, nsNames) -} diff --git a/pkg/registry/common/connect/ns_grpc_client.go b/pkg/registry/common/connect/ns_grpc_client.go deleted file mode 100644 index e60423b8f..000000000 --- a/pkg/registry/common/connect/ns_grpc_client.go +++ /dev/null @@ -1,40 +0,0 @@ -// Copyright (c) 2021 Doc.ai and/or its affiliates. -// -// SPDX-License-Identifier: Apache-2.0 -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at: -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package connect - -import ( - "context" - - "google.golang.org/grpc" - "google.golang.org/protobuf/types/known/emptypb" - - "github.com/networkservicemesh/api/pkg/api/registry" -) - -type grpcNSClient struct{} - -func (c *grpcNSClient) Register(ctx context.Context, ns *registry.NetworkService, opts ...grpc.CallOption) (*registry.NetworkService, error) { - return registry.NewNetworkServiceRegistryClient(ccFromContext(ctx)).Register(ctx, ns, opts...) -} - -func (c *grpcNSClient) Find(ctx context.Context, query *registry.NetworkServiceQuery, opts ...grpc.CallOption) (registry.NetworkServiceRegistry_FindClient, error) { - return registry.NewNetworkServiceRegistryClient(ccFromContext(ctx)).Find(ctx, query, opts...) -} - -func (c *grpcNSClient) Unregister(ctx context.Context, ns *registry.NetworkService, opts ...grpc.CallOption) (*emptypb.Empty, error) { - return registry.NewNetworkServiceRegistryClient(ccFromContext(ctx)).Unregister(ctx, ns, opts...) -} diff --git a/pkg/registry/common/connect/ns_server.go b/pkg/registry/common/connect/ns_server.go index 07e585e5a..bbd9f9d19 100644 --- a/pkg/registry/common/connect/ns_server.go +++ b/pkg/registry/common/connect/ns_server.go @@ -1,4 +1,4 @@ -// Copyright (c) 2020-2021 Doc.ai and/or its affiliates. +// Copyright (c) 2021-2022 Cisco and/or its affiliates. // // SPDX-License-Identifier: Apache-2.0 // @@ -14,166 +14,76 @@ // See the License for the specific language governing permissions and // limitations under the License. +// Package connect provides chain elements to 'connect' clients package connect import ( "context" - "net/url" "github.com/golang/protobuf/ptypes/empty" "github.com/pkg/errors" + "google.golang.org/grpc" - "github.com/networkservicemesh/api/pkg/api/registry" + "github.com/networkservicemesh/sdk/pkg/registry/core/next" + "github.com/networkservicemesh/sdk/pkg/tools/postpone" - "github.com/networkservicemesh/sdk/pkg/registry/core/adapters" - "github.com/networkservicemesh/sdk/pkg/tools/clienturlctx" - "github.com/networkservicemesh/sdk/pkg/tools/multiexecutor" + "github.com/networkservicemesh/api/pkg/api/registry" ) type connectNSServer struct { - ctx context.Context - clientOptions []Option - - nsInfos nsInfoMap - clients nsClientMap - executor multiexecutor.MultiExecutor -} - -type nsInfo struct { - clientURL *url.URL - client *nsClient -} - -type nsClient struct { - client registry.NetworkServiceRegistryClient - count int - onClose context.CancelFunc -} - -// NewNetworkServiceRegistryServer - server chain element that creates client subchains and requests them selecting by -// clienturlctx.ClientURL(ctx) -func NewNetworkServiceRegistryServer( - ctx context.Context, - clientOptions ...Option, -) registry.NetworkServiceRegistryServer { - return &connectNSServer{ - ctx: ctx, - clientOptions: clientOptions, - } + client registry.NetworkServiceRegistryClient + callOptions []grpc.CallOption } -func (s *connectNSServer) Register(ctx context.Context, ns *registry.NetworkService) (*registry.NetworkService, error) { - clientURL := clienturlctx.ClientURL(ctx) - if clientURL == nil { - return nil, errors.Errorf("clientURL not found for incoming service: %+v", ns) +func (c *connectNSServer) Register(ctx context.Context, in *registry.NetworkService) (*registry.NetworkService, error) { + closeCtxFunc := postpone.ContextWithValues(ctx) + clientResp, clientErr := c.client.Register(ctx, in, c.callOptions...) + if clientErr != nil { + return nil, clientErr } - _, loaded := s.nsInfos.Load(ns.Name) - - c := s.client(ctx, ns) - reg, err := c.client.Register(ctx, ns) - if err != nil { - if !loaded { - s.closeClient(c, clientURL.String()) - } - return nil, err + serverResp, serverErr := next.NetworkServiceRegistryServer(ctx).Register(ctx, clientResp) + if serverErr != nil { + closeCtx, closeCancel := closeCtxFunc() + defer closeCancel() + _, _ = c.client.Unregister(closeCtx, clientResp, c.callOptions...) } - - s.nsInfos.Store(ns.Name, &nsInfo{ - clientURL: clientURL, - client: c, - }) - - return reg, nil + return serverResp, serverErr } -func (s *connectNSServer) Find(query *registry.NetworkServiceQuery, server registry.NetworkServiceRegistry_FindServer) error { - clientURL := clienturlctx.ClientURL(server.Context()) - if clientURL == nil { - return errors.Errorf("clientURL not found for incoming query: %+v", query) - } - - c := s.client(server.Context(), nil) - - err := adapters.NetworkServiceClientToServer(c.client).Find(query, server) +func (c *connectNSServer) Find(query *registry.NetworkServiceQuery, server registry.NetworkServiceRegistry_FindServer) error { + ctx := server.Context() - s.closeClient(c, clientURL.String()) - - return err -} - -func (s *connectNSServer) Unregister(ctx context.Context, ns *registry.NetworkService) (*empty.Empty, error) { - clientURL := clienturlctx.ClientURL(ctx) - if clientURL == nil { - return nil, errors.Errorf("clientURL not found for incoming service: %+v", ns) + clientResp, clientErr := c.client.Find(ctx, query, c.callOptions...) + if clientErr != nil { + return clientErr } - c := s.client(ctx, ns) - - _, err := c.client.Unregister(ctx, ns) - - s.closeClient(c, clientURL.String()) - s.nsInfos.Delete(ns.Name) - - return new(empty.Empty), err -} - -func (s *connectNSServer) client(ctx context.Context, ns *registry.NetworkService) *nsClient { - clientURL := clienturlctx.ClientURL(ctx) - - if ns != nil { - // First check if we have already registered on some clientURL with this ns.Name. - if info, ok := s.nsInfos.Load(ns.Name); ok { - if *info.clientURL == *clientURL { - return info.client - } - - // For some reason we have changed the clientURL, so we need to close the existing client. - s.closeClient(info.client, info.clientURL.String()) + for resp := range registry.ReadNetworkServiceChannel(clientResp) { + if err := server.Send(resp); err != nil { + return err } } - var c *nsClient - <-s.executor.AsyncExec(clientURL.String(), func() { - // Fast path if we already have client for the clientURL and we should not reconnect, use it. - var loaded bool - c, loaded = s.clients.Load(clientURL.String()) - if !loaded { - // If not, create and LoadOrStore a new one. - c = s.newClient(clientURL) - s.clients.Store(clientURL.String(), c) - } - c.count++ - }) - return c + return next.NetworkServiceRegistryServer(ctx).Find(query, server) } -func (s *connectNSServer) newClient(clientURL *url.URL) *nsClient { - ctx, cancel := context.WithCancel(s.ctx) - return &nsClient{ - client: NewNetworkServiceRegistryClient(ctx, clientURL, s.clientOptions...), - count: 0, - onClose: cancel, +func (c *connectNSServer) Unregister(ctx context.Context, in *registry.NetworkService) (*empty.Empty, error) { + _, clientErr := c.client.Unregister(ctx, in, c.callOptions...) + _, serverErr := next.NetworkServiceRegistryClient(ctx).Unregister(ctx, in) + if clientErr != nil && serverErr != nil { + return nil, errors.Wrapf(serverErr, "errors during client close: %v", clientErr) } + if clientErr != nil { + return nil, errors.Wrap(clientErr, "errors during client close") + } + return &empty.Empty{}, serverErr } -func (s *connectNSServer) closeClient(c *nsClient, clientURL string) { - <-s.executor.AsyncExec(clientURL, func() { - c.count-- - if c.count == 0 { - if loadedClient, ok := s.clients.Load(clientURL); ok && c == loadedClient { - s.clients.Delete(clientURL) - } - c.onClose() - } - }) -} - -func (s *connectNSServer) deleteClient(c *nsClient, clientURL string) { - <-s.executor.AsyncExec(clientURL, func() { - if loadedClient, ok := s.clients.Load(clientURL); ok && c == loadedClient { - s.clients.Delete(clientURL) - } - c.onClose() - }) +// NewNetworkServiceRegistryServer - returns a connect chain element +func NewNetworkServiceRegistryServer(client registry.NetworkServiceRegistryClient, callOptions ...grpc.CallOption) registry.NetworkServiceRegistryServer { + return &connectNSServer{ + client: client, + callOptions: callOptions, + } } diff --git a/pkg/registry/common/connect/ns_server_test.go b/pkg/registry/common/connect/ns_server_test.go index 99cb99996..9e7c1df43 100644 --- a/pkg/registry/common/connect/ns_server_test.go +++ b/pkg/registry/common/connect/ns_server_test.go @@ -1,4 +1,4 @@ -// Copyright (c) 2020-2021 Doc.ai and/or its affiliates. +// Copyright (c) 2020-2022 Doc.ai and/or its affiliates. // // SPDX-License-Identifier: Apache-2.0 // @@ -29,9 +29,15 @@ import ( "github.com/networkservicemesh/api/pkg/api/registry" + "github.com/networkservicemesh/sdk/pkg/registry/common/begin" + "github.com/networkservicemesh/sdk/pkg/registry/common/clientconn" + "github.com/networkservicemesh/sdk/pkg/registry/common/clienturl" "github.com/networkservicemesh/sdk/pkg/registry/common/connect" + "github.com/networkservicemesh/sdk/pkg/registry/common/dial" "github.com/networkservicemesh/sdk/pkg/registry/common/memory" "github.com/networkservicemesh/sdk/pkg/registry/common/null" + "github.com/networkservicemesh/sdk/pkg/registry/core/adapters" + "github.com/networkservicemesh/sdk/pkg/registry/core/chain" "github.com/networkservicemesh/sdk/pkg/registry/core/streamchannel" "github.com/networkservicemesh/sdk/pkg/tools/clienturlctx" "github.com/networkservicemesh/sdk/pkg/tools/grpcutils" @@ -70,11 +76,11 @@ func waitNSServerStarted(target *url.URL) error { client := grpc_health_v1.NewHealthClient(cc) for ctx.Err() == nil { - respons, err := client.Check(ctx, healthCheckRequest) + response, err := client.Check(ctx, healthCheckRequest) if err != nil { return err } - if respons.Status == grpc_health_v1.HealthCheckResponse_SERVING { + if response.Status == grpc_health_v1.HealthCheckResponse_SERVING { return nil } } @@ -112,7 +118,17 @@ func TestConnectNSServer_AllUnregister(t *testing.T) { ignoreCurrent := goleak.IgnoreCurrent() - s := connect.NewNetworkServiceRegistryServer(ctx, connect.WithDialOptions(grpc.WithInsecure())) + s := connect.NewNetworkServiceRegistryServer( + chain.NewNetworkServiceRegistryClient( + begin.NewNetworkServiceRegistryClient(), + clientconn.NewNetworkServiceRegistryClient(), + dial.NewNetworkServiceRegistryClient(ctx, + dial.WithDialOptions(grpc.WithInsecure()), + dial.WithDialTimeout(time.Second), + ), + connect.NewNetworkServiceRegistryClient(), + ), + ) _, err := s.Register(clienturlctx.WithClientURL(context.Background(), url1), ®istry.NetworkService{Name: "ns-1"}) require.NoError(t, err) @@ -150,8 +166,17 @@ func TestConnectNSServer_AllDead_Register(t *testing.T) { url1, url2, cancel1, cancel2 := startTestNSServers(ctx, t) - s := connect.NewNetworkServiceRegistryServer(ctx, connect.WithDialOptions(grpc.WithInsecure())) - + s := connect.NewNetworkServiceRegistryServer( + chain.NewNetworkServiceRegistryClient( + begin.NewNetworkServiceRegistryClient(), + clientconn.NewNetworkServiceRegistryClient(), + dial.NewNetworkServiceRegistryClient(ctx, + dial.WithDialOptions(grpc.WithInsecure()), + dial.WithDialTimeout(time.Second), + ), + connect.NewNetworkServiceRegistryClient(), + ), + ) _, err := s.Register(clienturlctx.WithClientURL(ctx, url1), ®istry.NetworkService{Name: "ns-1"}) require.NoError(t, err) @@ -172,7 +197,19 @@ func TestConnectNSServer_AllDead_WatchingFind(t *testing.T) { url1, url2, cancel1, cancel2 := startTestNSServers(ctx, t) - s := connect.NewNetworkServiceRegistryServer(ctx, connect.WithDialOptions(grpc.WithInsecure())) + s := connect.NewNetworkServiceRegistryServer( + chain.NewNetworkServiceRegistryClient( + begin.NewNetworkServiceRegistryClient(), + clientconn.NewNetworkServiceRegistryClient(), + dial.NewNetworkServiceRegistryClient(ctx, + dial.WithDialOptions(grpc.WithInsecure()), + dial.WithDialTimeout(time.Second), + ), + connect.NewNetworkServiceRegistryClient(), + ), + ) + + errCh := make(chan error, 2) go func() { ch := make(chan *registry.NetworkServiceResponse, 1) @@ -181,7 +218,7 @@ func TestConnectNSServer_AllDead_WatchingFind(t *testing.T) { NetworkService: new(registry.NetworkService), Watch: true, }, findSrv) - require.Error(t, err) + errCh <- err }() go func() { @@ -191,12 +228,95 @@ func TestConnectNSServer_AllDead_WatchingFind(t *testing.T) { NetworkService: new(registry.NetworkService), Watch: true, }, findSrv) - require.Error(t, err) + errCh <- err }() cancel1() cancel2() + <-errCh + <-errCh + for err, i := goleak.Find(), 0; err != nil && i < 3; err, i = goleak.Find(), i+1 { } } + +func Test_NSConenctChain_Find(t *testing.T) { + for depth := 2; depth < 11; depth++ { + for killIndex := 1; killIndex < depth; killIndex++ { + var ctx, cancel = context.WithTimeout(context.Background(), time.Second) + defer cancel() + + var urls = make([]*url.URL, depth) + + var servers = make([]*struct { + registry.NetworkServiceRegistryServer + kill func() + }, depth) + + for i := 0; i < depth; i++ { + var serverCtx, serverCancel = context.WithCancel(ctx) + + servers[i] = &struct { + registry.NetworkServiceRegistryServer + kill func() + }{ + kill: serverCancel, + } + + urls[i] = new(url.URL) + + require.NoError(t, + startNSServer( + serverCtx, + urls[i], + servers[i], + ), + ) + } + + for i := 0; i < depth-1; i++ { + servers[i].NetworkServiceRegistryServer = chain.NewNetworkServiceRegistryServer( + clienturl.NewNetworkServiceRegistryServer(urls[i+1]), + connect.NewNetworkServiceRegistryServer( + chain.NewNetworkServiceRegistryClient( + begin.NewNetworkServiceRegistryClient(), + clientconn.NewNetworkServiceRegistryClient(), + dial.NewNetworkServiceRegistryClient(ctx, + dial.WithDialOptions(grpc.WithInsecure()), + dial.WithDialTimeout(time.Second), + ), + connect.NewNetworkServiceRegistryClient(), + ), + ), + ) + } + + servers[len(servers)-1].NetworkServiceRegistryServer = memory.NewNetworkServiceRegistryServer() + + c := adapters.NetworkServiceServerToClient(servers[0].NetworkServiceRegistryServer) + + _, err := c.Register(ctx, ®istry.NetworkService{ + Name: "testing", + }) + + require.NoError(t, err) + + stream, err := c.Find(ctx, ®istry.NetworkServiceQuery{ + Watch: true, + NetworkService: ®istry.NetworkService{ + Name: "testing", + }, + }) + require.NoError(t, err) + + _, err = stream.Recv() + require.NoError(t, err) + + servers[killIndex].kill() + + _, err = stream.Recv() + require.Error(t, err) + } + } +} diff --git a/pkg/registry/common/connect/nse_client.go b/pkg/registry/common/connect/nse_client.go index 052b90bdf..cde60b35b 100644 --- a/pkg/registry/common/connect/nse_client.go +++ b/pkg/registry/common/connect/nse_client.go @@ -1,4 +1,4 @@ -// Copyright (c) 2021 Doc.ai and/or its affiliates. +// Copyright (c) 2021-2022 Cisco and/or its affiliates. // // SPDX-License-Identifier: Apache-2.0 // @@ -18,125 +18,44 @@ package connect import ( "context" - "net/url" - "sync" - - "google.golang.org/grpc" - "google.golang.org/grpc/connectivity" - "google.golang.org/protobuf/types/known/emptypb" + "github.com/golang/protobuf/ptypes/empty" "github.com/networkservicemesh/api/pkg/api/registry" + "github.com/pkg/errors" + "google.golang.org/grpc" - "github.com/networkservicemesh/sdk/pkg/registry/core/chain" - "github.com/networkservicemesh/sdk/pkg/tools/grpcutils" + "github.com/networkservicemesh/sdk/pkg/registry/common/clientconn" ) -type connectNSEClient struct { - ctx context.Context - client registry.NetworkServiceEndpointRegistryClient - connectTo string - dialOptions []grpc.DialOption +var errNoCCProvided = errors.New("no grpc.ClientConnInterface provided") - cc *grpc.ClientConn - lock sync.RWMutex -} - -// NewNetworkServiceEndpointRegistryClient returns a new NSE registry client chain element connecting to the remote -// NSE registry server -func NewNetworkServiceEndpointRegistryClient(ctx context.Context, connectTo *url.URL, opts ...Option) registry.NetworkServiceEndpointRegistryClient { - connectOpts := new(connectOptions) - for _, opt := range opts { - opt(connectOpts) - } +type connectNSEClient struct{} - c := &connectNSEClient{ - ctx: ctx, - client: chain.NewNetworkServiceEndpointRegistryClient( - append( - connectOpts.nseAdditionalFunctionality, - new(grpcNSEClient), - )..., - ), - connectTo: grpcutils.URLToTarget(connectTo), - dialOptions: append(append([]grpc.DialOption{}, connectOpts.dialOptions...), grpc.WithReturnConnectionError()), +func (n *connectNSEClient) Register(ctx context.Context, in *registry.NetworkServiceEndpoint, opts ...grpc.CallOption) (*registry.NetworkServiceEndpoint, error) { + cc, loaded := clientconn.Load(ctx) + if !loaded { + return nil, errNoCCProvided } - - go func() { - <-ctx.Done() - - c.lock.Lock() - defer c.lock.Unlock() - - if c.cc != nil { - _ = c.cc.Close() - } - }() - - return c + return registry.NewNetworkServiceEndpointRegistryClient(cc).Register(ctx, in, opts...) } -func (c *connectNSEClient) Register(ctx context.Context, nse *registry.NetworkServiceEndpoint, opts ...grpc.CallOption) (*registry.NetworkServiceEndpoint, error) { - cc, err := c.getCC() - if err != nil { - return nil, err +func (n *connectNSEClient) Find(ctx context.Context, in *registry.NetworkServiceEndpointQuery, opts ...grpc.CallOption) (registry.NetworkServiceEndpointRegistry_FindClient, error) { + cc, loaded := clientconn.Load(ctx) + if !loaded { + return nil, errNoCCProvided } - return c.client.Register(withCC(ctx, cc), nse, opts...) + return registry.NewNetworkServiceEndpointRegistryClient(cc).Find(ctx, in, opts...) } -func (c *connectNSEClient) Find(ctx context.Context, query *registry.NetworkServiceEndpointQuery, opts ...grpc.CallOption) (registry.NetworkServiceEndpointRegistry_FindClient, error) { - cc, err := c.getCC() - if err != nil { - return nil, err +func (n *connectNSEClient) Unregister(ctx context.Context, in *registry.NetworkServiceEndpoint, opts ...grpc.CallOption) (*empty.Empty, error) { + cc, loaded := clientconn.Load(ctx) + if !loaded { + return nil, errNoCCProvided } - return c.client.Find(withCC(ctx, cc), query, opts...) + return registry.NewNetworkServiceEndpointRegistryClient(cc).Unregister(ctx, in, opts...) } -func (c *connectNSEClient) Unregister(ctx context.Context, nse *registry.NetworkServiceEndpoint, opts ...grpc.CallOption) (*emptypb.Empty, error) { - cc, err := c.getCC() - if err != nil { - return nil, err - } - return c.client.Unregister(withCC(ctx, cc), nse, opts...) -} - -func (c *connectNSEClient) getCC() (*grpc.ClientConn, error) { - c.lock.RLock() - cc := c.cc - c.lock.RUnlock() - - if cc != nil { - return cc, nil - } - - c.lock.Lock() - defer c.lock.Unlock() - - if c.cc != nil { - return c.cc, nil - } - - var err error - if c.cc, err = grpc.DialContext(c.ctx, c.connectTo, c.dialOptions...); err != nil { - return nil, err - } - - go func() { - defer func() { - c.lock.Lock() - defer c.lock.Unlock() - - _ = c.cc.Close() - c.cc = nil - }() - for c.cc.WaitForStateChange(c.ctx, c.cc.GetState()) { - switch c.cc.GetState() { - case connectivity.Connecting, connectivity.Idle, connectivity.Ready: - continue - default: - return - } - } - }() - - return c.cc, nil +// NewNetworkServiceEndpointRegistryClient - returns a new null client that does nothing but call next.NetworkServiceEndpointRegistryClient(ctx). +func NewNetworkServiceEndpointRegistryClient() registry.NetworkServiceEndpointRegistryClient { + return new(connectNSEClient) } diff --git a/pkg/registry/common/connect/nse_client_test.go b/pkg/registry/common/connect/nse_client_test.go deleted file mode 100644 index 9619e1a59..000000000 --- a/pkg/registry/common/connect/nse_client_test.go +++ /dev/null @@ -1,139 +0,0 @@ -// Copyright (c) 2021 Doc.ai and/or its affiliates. -// -// SPDX-License-Identifier: Apache-2.0 -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at: -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package connect_test - -import ( - "context" - "net/url" - "testing" - "time" - - "github.com/stretchr/testify/require" - "go.uber.org/goleak" - "google.golang.org/grpc" - - "github.com/networkservicemesh/api/pkg/api/registry" - - "github.com/networkservicemesh/sdk/pkg/registry/common/connect" - "github.com/networkservicemesh/sdk/pkg/registry/common/memory" - "github.com/networkservicemesh/sdk/pkg/registry/core/streamchannel" - "github.com/networkservicemesh/sdk/pkg/tools/sandbox" -) - -func TestConnectNSEClient(t *testing.T) { - t.Cleanup(func() { goleak.VerifyNone(t) }) - - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - - mem := memory.NewNetworkServiceEndpointRegistryServer() - - u := &url.URL{Scheme: "tcp", Host: "127.0.0.1:0"} - require.NoError(t, startNSEServer(ctx, u, mem)) - require.NoError(t, waitNSEServerStarted(u)) - - // 1. Register remote NSE - _, err := mem.Register(ctx, ®istry.NetworkServiceEndpoint{Name: "nse-remote"}) - require.NoError(t, err) - - c := connect.NewNetworkServiceEndpointRegistryClient(ctx, u, - connect.WithDialOptions(grpc.WithInsecure()), - ) - - // 2. Register local NSE - _, err = c.Register(ctx, ®istry.NetworkServiceEndpoint{Name: "nse-local"}) - require.NoError(t, err) - - // 3. Find both local, remote NSEs from client - stream, err := c.Find(ctx, ®istry.NetworkServiceEndpointQuery{ - NetworkServiceEndpoint: new(registry.NetworkServiceEndpoint), - }) - require.NoError(t, err) - - var nseNames []string - for _, nse := range registry.ReadNetworkServiceEndpointList(stream) { - nseNames = append(nseNames, nse.Name) - } - require.Len(t, nseNames, 2) - require.Subset(t, []string{"nse-remote", "nse-local"}, nseNames) - - // 4. Unregister remote NSE from client - _, err = c.Unregister(ctx, ®istry.NetworkServiceEndpoint{Name: "nse-remote"}) - require.NoError(t, err) - - // 5. Find only local NSE in memory - ch := make(chan *registry.NetworkServiceEndpointResponse, 2) - err = mem.Find(®istry.NetworkServiceEndpointQuery{ - NetworkServiceEndpoint: new(registry.NetworkServiceEndpoint), - }, streamchannel.NewNetworkServiceEndpointFindServer(ctx, ch)) - require.NoError(t, err) - - require.Len(t, ch, 1) - require.Equal(t, "nse-local", (<-ch).NetworkServiceEndpoint.Name) -} - -func TestConnectNSEClient_Restart(t *testing.T) { - t.Cleanup(func() { goleak.VerifyNone(t) }) - - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - - serverCtx, serverCancel := context.WithCancel(ctx) - - mem := memory.NewNetworkServiceEndpointRegistryServer() - - u := &url.URL{Scheme: "tcp", Host: "127.0.0.1:0"} - require.NoError(t, startNSEServer(serverCtx, u, mem)) - require.NoError(t, waitNSEServerStarted(u)) - - c := connect.NewNetworkServiceEndpointRegistryClient(ctx, u, - connect.WithDialOptions(grpc.WithInsecure()), - ) - - // 1. Register NSE-1 with client - _, err := c.Register(ctx, ®istry.NetworkServiceEndpoint{Name: "nse-1"}) - require.NoError(t, err) - - // 2. Restart remote - serverCancel() - require.Eventually(t, func() bool { - return sandbox.CheckURLFree(u) - }, time.Second, 10*time.Millisecond) - - require.NoError(t, startNSEServer(ctx, u, mem)) - require.NoError(t, waitNSEServerStarted(u)) - - // 3. Register NSE-2 with client - require.Eventually(t, func() bool { - _, err = c.Register(ctx, ®istry.NetworkServiceEndpoint{Name: "nse-2"}) - return err == nil - }, time.Second, 10*time.Millisecond) - - // 4. Find both NSE-1, NSE-2 in memory - ch := make(chan *registry.NetworkServiceEndpointResponse, 2) - err = mem.Find(®istry.NetworkServiceEndpointQuery{ - NetworkServiceEndpoint: new(registry.NetworkServiceEndpoint), - }, streamchannel.NewNetworkServiceEndpointFindServer(ctx, ch)) - require.NoError(t, err) - - var nseNames []string - for i := len(ch); i > 0; i-- { - nseNames = append(nseNames, (<-ch).NetworkServiceEndpoint.Name) - } - require.Len(t, nseNames, 2) - require.Subset(t, []string{"nse-1", "nse-2"}, nseNames) -} diff --git a/pkg/registry/common/connect/nse_grpc_client.go b/pkg/registry/common/connect/nse_grpc_client.go deleted file mode 100644 index 0aff677fd..000000000 --- a/pkg/registry/common/connect/nse_grpc_client.go +++ /dev/null @@ -1,40 +0,0 @@ -// Copyright (c) 2021 Doc.ai and/or its affiliates. -// -// SPDX-License-Identifier: Apache-2.0 -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at: -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package connect - -import ( - "context" - - "google.golang.org/grpc" - "google.golang.org/protobuf/types/known/emptypb" - - "github.com/networkservicemesh/api/pkg/api/registry" -) - -type grpcNSEClient struct{} - -func (c *grpcNSEClient) Register(ctx context.Context, nse *registry.NetworkServiceEndpoint, opts ...grpc.CallOption) (*registry.NetworkServiceEndpoint, error) { - return registry.NewNetworkServiceEndpointRegistryClient(ccFromContext(ctx)).Register(ctx, nse, opts...) -} - -func (c *grpcNSEClient) Find(ctx context.Context, query *registry.NetworkServiceEndpointQuery, opts ...grpc.CallOption) (registry.NetworkServiceEndpointRegistry_FindClient, error) { - return registry.NewNetworkServiceEndpointRegistryClient(ccFromContext(ctx)).Find(ctx, query, opts...) -} - -func (c *grpcNSEClient) Unregister(ctx context.Context, nse *registry.NetworkServiceEndpoint, opts ...grpc.CallOption) (*emptypb.Empty, error) { - return registry.NewNetworkServiceEndpointRegistryClient(ccFromContext(ctx)).Unregister(ctx, nse, opts...) -} diff --git a/pkg/registry/common/connect/nse_server.go b/pkg/registry/common/connect/nse_server.go index 0aa5b199b..96dc1a8fa 100644 --- a/pkg/registry/common/connect/nse_server.go +++ b/pkg/registry/common/connect/nse_server.go @@ -1,4 +1,4 @@ -// Copyright (c) 2020-2021 Doc.ai and/or its affiliates. +// Copyright (c) 2021-2022 Cisco and/or its affiliates. // // SPDX-License-Identifier: Apache-2.0 // @@ -14,157 +14,76 @@ // See the License for the specific language governing permissions and // limitations under the License. +// Package connect TODO package connect import ( "context" - "net/url" "github.com/golang/protobuf/ptypes/empty" "github.com/pkg/errors" + "google.golang.org/grpc" - "github.com/networkservicemesh/api/pkg/api/registry" + "github.com/networkservicemesh/sdk/pkg/registry/core/next" + "github.com/networkservicemesh/sdk/pkg/tools/postpone" - "github.com/networkservicemesh/sdk/pkg/registry/core/adapters" - "github.com/networkservicemesh/sdk/pkg/tools/clienturlctx" - "github.com/networkservicemesh/sdk/pkg/tools/multiexecutor" + "github.com/networkservicemesh/api/pkg/api/registry" ) type connectNSEServer struct { - ctx context.Context - clientOptions []Option - - nseInfos nseInfoMap - clients nseClientMap - executor multiexecutor.MultiExecutor -} - -type nseInfo struct { - clientURL *url.URL - client *nseClient + client registry.NetworkServiceEndpointRegistryClient + callOptions []grpc.CallOption } -type nseClient struct { - client registry.NetworkServiceEndpointRegistryClient - count int - onClose context.CancelFunc -} - -// NewNetworkServiceEndpointRegistryServer - server chain element that creates client subchains and requests them selecting by -// clienturlctx.ClientURL(ctx) -func NewNetworkServiceEndpointRegistryServer( - ctx context.Context, - clientOptions ...Option, -) registry.NetworkServiceEndpointRegistryServer { - return &connectNSEServer{ - ctx: ctx, - clientOptions: clientOptions, +func (c *connectNSEServer) Register(ctx context.Context, in *registry.NetworkServiceEndpoint) (*registry.NetworkServiceEndpoint, error) { + closeCtxFunc := postpone.ContextWithValues(ctx) + clientResp, clientErr := c.client.Register(ctx, in, c.callOptions...) + if clientErr != nil { + return nil, clientErr } -} -func (s *connectNSEServer) Register(ctx context.Context, nse *registry.NetworkServiceEndpoint) (*registry.NetworkServiceEndpoint, error) { - clientURL := clienturlctx.ClientURL(ctx) - if clientURL == nil { - return nil, errors.Errorf("clientURL not found for incoming endpoint: %+v", nse) + serverResp, serverErr := next.NetworkServiceEndpointRegistryServer(ctx).Register(ctx, clientResp) + if serverErr != nil { + closeCtx, closeCancel := closeCtxFunc() + defer closeCancel() + _, _ = c.client.Unregister(closeCtx, clientResp, c.callOptions...) } - - _, loaded := s.nseInfos.Load(nse.Name) - - c := s.client(ctx, nse) - reg, err := c.client.Register(ctx, nse) - if err != nil { - if !loaded { - s.closeClient(c, clientURL.String()) - } - return nil, err - } - - s.nseInfos.Store(nse.Name, &nseInfo{ - clientURL: clientURL, - client: c, - }) - - return reg, nil + return serverResp, serverErr } -func (s *connectNSEServer) Find(query *registry.NetworkServiceEndpointQuery, server registry.NetworkServiceEndpointRegistry_FindServer) error { - clientURL := clienturlctx.ClientURL(server.Context()) - if clientURL == nil { - return errors.Errorf("clientURL not found for incoming query: %+v", query) - } - - c := s.client(server.Context(), nil) +func (c *connectNSEServer) Find(query *registry.NetworkServiceEndpointQuery, server registry.NetworkServiceEndpointRegistry_FindServer) error { + ctx := server.Context() - err := adapters.NetworkServiceEndpointClientToServer(c.client).Find(query, server) - - s.closeClient(c, clientURL.String()) - - return err -} - -func (s *connectNSEServer) Unregister(ctx context.Context, nse *registry.NetworkServiceEndpoint) (*empty.Empty, error) { - clientURL := clienturlctx.ClientURL(ctx) - if clientURL == nil { - return nil, errors.Errorf("clientURL not found for incoming endpoint: %+v", nse) + clientResp, clientErr := c.client.Find(ctx, query, c.callOptions...) + if clientErr != nil { + return clientErr } - c := s.client(ctx, nse) - - _, err := c.client.Unregister(ctx, nse) - - s.closeClient(c, clientURL.String()) - s.nseInfos.Delete(nse.Name) - - return new(empty.Empty), err -} - -func (s *connectNSEServer) client(ctx context.Context, nse *registry.NetworkServiceEndpoint) *nseClient { - clientURL := clienturlctx.ClientURL(ctx) - - if nse != nil { - // First check if we have already registered on some clientURL with this nse.Name. - if info, ok := s.nseInfos.Load(nse.Name); ok { - if *info.clientURL == *clientURL { - return info.client - } - - // For some reason we have changed the clientURL, so we need to close the existing client. - s.closeClient(info.client, info.clientURL.String()) + for resp := range registry.ReadNetworkServiceEndpointChannel(clientResp) { + if err := server.Send(resp); err != nil { + return err } } - var c *nseClient - <-s.executor.AsyncExec(clientURL.String(), func() { - // Fast path if we already have client for the clientURL and we should not reconnect, use it. - var loaded bool - c, loaded = s.clients.Load(clientURL.String()) - if !loaded { - // If not, create and LoadOrStore a new one. - c = s.newClient(clientURL) - s.clients.Store(clientURL.String(), c) - } - c.count++ - }) - return c + return next.NetworkServiceEndpointRegistryServer(ctx).Find(query, server) } -func (s *connectNSEServer) newClient(clientURL *url.URL) *nseClient { - ctx, cancel := context.WithCancel(s.ctx) - return &nseClient{ - client: NewNetworkServiceEndpointRegistryClient(ctx, clientURL, s.clientOptions...), - count: 0, - onClose: cancel, +func (c *connectNSEServer) Unregister(ctx context.Context, in *registry.NetworkServiceEndpoint) (*empty.Empty, error) { + _, clientErr := c.client.Unregister(ctx, in, c.callOptions...) + _, serverErr := next.NetworkServiceEndpointRegistryClient(ctx).Unregister(ctx, in) + if clientErr != nil && serverErr != nil { + return nil, errors.Wrapf(serverErr, "errors during client close: %v", clientErr) } + if clientErr != nil { + return nil, errors.Wrap(clientErr, "errors during client close") + } + return &empty.Empty{}, serverErr } -func (s *connectNSEServer) closeClient(c *nseClient, clientURL string) { - <-s.executor.AsyncExec(clientURL, func() { - c.count-- - if c.count == 0 { - if loadedClient, ok := s.clients.Load(clientURL); ok && c == loadedClient { - s.clients.Delete(clientURL) - } - c.onClose() - } - }) +// NewNetworkServiceEndpointRegistryServer - returns a connect chain element +func NewNetworkServiceEndpointRegistryServer(client registry.NetworkServiceEndpointRegistryClient, callOptions ...grpc.CallOption) registry.NetworkServiceEndpointRegistryServer { + return &connectNSEServer{ + client: client, + callOptions: callOptions, + } } diff --git a/pkg/registry/common/connect/nse_server_test.go b/pkg/registry/common/connect/nse_server_test.go index 13479e8d5..2e4989b0d 100644 --- a/pkg/registry/common/connect/nse_server_test.go +++ b/pkg/registry/common/connect/nse_server_test.go @@ -1,4 +1,4 @@ -// Copyright (c) 2020-2021 Doc.ai and/or its affiliates. +// Copyright (c) 2020-2022 Doc.ai and/or its affiliates. // // SPDX-License-Identifier: Apache-2.0 // @@ -29,9 +29,15 @@ import ( "github.com/networkservicemesh/api/pkg/api/registry" + "github.com/networkservicemesh/sdk/pkg/registry/common/begin" + "github.com/networkservicemesh/sdk/pkg/registry/common/clientconn" + "github.com/networkservicemesh/sdk/pkg/registry/common/clienturl" "github.com/networkservicemesh/sdk/pkg/registry/common/connect" + "github.com/networkservicemesh/sdk/pkg/registry/common/dial" "github.com/networkservicemesh/sdk/pkg/registry/common/memory" "github.com/networkservicemesh/sdk/pkg/registry/common/null" + "github.com/networkservicemesh/sdk/pkg/registry/core/adapters" + "github.com/networkservicemesh/sdk/pkg/registry/core/chain" "github.com/networkservicemesh/sdk/pkg/registry/core/streamchannel" "github.com/networkservicemesh/sdk/pkg/tools/clienturlctx" "github.com/networkservicemesh/sdk/pkg/tools/grpcutils" @@ -112,7 +118,17 @@ func TestConnectNSEServer_AllUnregister(t *testing.T) { ignoreCurrent := goleak.IgnoreCurrent() - s := connect.NewNetworkServiceEndpointRegistryServer(ctx, connect.WithDialOptions(grpc.WithInsecure())) + s := connect.NewNetworkServiceEndpointRegistryServer( + chain.NewNetworkServiceEndpointRegistryClient( + begin.NewNetworkServiceEndpointRegistryClient(), + clientconn.NewNetworkServiceEndpointRegistryClient(), + dial.NewNetworkServiceEndpointRegistryClient(ctx, + dial.WithDialOptions(grpc.WithInsecure()), + dial.WithDialTimeout(time.Second), + ), + connect.NewNetworkServiceEndpointRegistryClient(), + ), + ) _, err := s.Register(clienturlctx.WithClientURL(context.Background(), url1), ®istry.NetworkServiceEndpoint{Name: "nse-1"}) require.NoError(t, err) @@ -150,8 +166,17 @@ func TestConnectNSEServer_AllDead_Register(t *testing.T) { url1, url2, cancel1, cancel2 := startTestNSEServers(ctx, t) - s := connect.NewNetworkServiceEndpointRegistryServer(ctx, connect.WithDialOptions(grpc.WithInsecure())) - + s := connect.NewNetworkServiceEndpointRegistryServer( + chain.NewNetworkServiceEndpointRegistryClient( + begin.NewNetworkServiceEndpointRegistryClient(), + clientconn.NewNetworkServiceEndpointRegistryClient(), + dial.NewNetworkServiceEndpointRegistryClient(ctx, + dial.WithDialOptions(grpc.WithInsecure()), + dial.WithDialTimeout(time.Second), + ), + connect.NewNetworkServiceEndpointRegistryClient(), + ), + ) _, err := s.Register(clienturlctx.WithClientURL(ctx, url1), ®istry.NetworkServiceEndpoint{Name: "nse-1"}) require.NoError(t, err) @@ -172,7 +197,19 @@ func TestConnectNSEServer_AllDead_WatchingFind(t *testing.T) { url1, url2, cancel1, cancel2 := startTestNSEServers(ctx, t) - s := connect.NewNetworkServiceEndpointRegistryServer(ctx, connect.WithDialOptions(grpc.WithInsecure())) + s := connect.NewNetworkServiceEndpointRegistryServer( + chain.NewNetworkServiceEndpointRegistryClient( + begin.NewNetworkServiceEndpointRegistryClient(), + clientconn.NewNetworkServiceEndpointRegistryClient(), + dial.NewNetworkServiceEndpointRegistryClient(ctx, + dial.WithDialOptions(grpc.WithInsecure()), + dial.WithDialTimeout(time.Second), + ), + connect.NewNetworkServiceEndpointRegistryClient(), + ), + ) + + errCh := make(chan error, 2) go func() { ch := make(chan *registry.NetworkServiceEndpointResponse, 1) @@ -181,7 +218,7 @@ func TestConnectNSEServer_AllDead_WatchingFind(t *testing.T) { NetworkServiceEndpoint: new(registry.NetworkServiceEndpoint), Watch: true, }, findSrv) - require.Error(t, err) + errCh <- err }() go func() { @@ -191,12 +228,95 @@ func TestConnectNSEServer_AllDead_WatchingFind(t *testing.T) { NetworkServiceEndpoint: new(registry.NetworkServiceEndpoint), Watch: true, }, findSrv) - require.Error(t, err) + errCh <- err }() cancel1() cancel2() + <-errCh + <-errCh + for err, i := goleak.Find(), 0; err != nil && i < 3; err, i = goleak.Find(), i+1 { } } + +func Test_ConenctNSEChain_Find(t *testing.T) { + for depth := 2; depth < 11; depth++ { + for killIndex := 1; killIndex < depth; killIndex++ { + var ctx, cancel = context.WithTimeout(context.Background(), time.Second) + defer cancel() + + var urls = make([]*url.URL, depth) + + var servers = make([]*struct { + registry.NetworkServiceEndpointRegistryServer + kill func() + }, depth) + + for i := 0; i < depth; i++ { + var serverCtx, serverCancel = context.WithCancel(ctx) + + servers[i] = &struct { + registry.NetworkServiceEndpointRegistryServer + kill func() + }{ + kill: serverCancel, + } + + urls[i] = new(url.URL) + + require.NoError(t, + startNSEServer( + serverCtx, + urls[i], + servers[i], + ), + ) + } + + for i := 0; i < depth-1; i++ { + servers[i].NetworkServiceEndpointRegistryServer = chain.NewNetworkServiceEndpointRegistryServer( + clienturl.NewNetworkServiceEndpointRegistryServer(urls[i+1]), + connect.NewNetworkServiceEndpointRegistryServer( + chain.NewNetworkServiceEndpointRegistryClient( + begin.NewNetworkServiceEndpointRegistryClient(), + clientconn.NewNetworkServiceEndpointRegistryClient(), + dial.NewNetworkServiceEndpointRegistryClient(ctx, + dial.WithDialOptions(grpc.WithInsecure()), + dial.WithDialTimeout(time.Second), + ), + connect.NewNetworkServiceEndpointRegistryClient(), + ), + ), + ) + } + + servers[len(servers)-1].NetworkServiceEndpointRegistryServer = memory.NewNetworkServiceEndpointRegistryServer() + + c := adapters.NetworkServiceEndpointServerToClient(servers[0].NetworkServiceEndpointRegistryServer) + + _, err := c.Register(ctx, ®istry.NetworkServiceEndpoint{ + Name: "testing", + }) + + require.NoError(t, err) + + stream, err := c.Find(ctx, ®istry.NetworkServiceEndpointQuery{ + Watch: true, + NetworkServiceEndpoint: ®istry.NetworkServiceEndpoint{ + Name: "testing", + }, + }) + require.NoError(t, err) + + _, err = stream.Recv() + require.NoError(t, err) + + servers[killIndex].kill() + + _, err = stream.Recv() + require.Error(t, err) + } + } +} diff --git a/pkg/registry/common/connect/option.go b/pkg/registry/common/connect/option.go deleted file mode 100644 index 290c2c47d..000000000 --- a/pkg/registry/common/connect/option.go +++ /dev/null @@ -1,53 +0,0 @@ -// Copyright (c) 2021 Doc.ai and/or its affiliates. -// -// SPDX-License-Identifier: Apache-2.0 -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at: -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package connect - -import ( - "google.golang.org/grpc" - - "github.com/networkservicemesh/api/pkg/api/registry" -) - -// Option is an option pattern for NewNetworkServiceRegistryClient, NewNetworkServiceEndpointRegistryClient -type Option func(connectOpts *connectOptions) - -// WithNSAdditionalFunctionality sets additional functionality -func WithNSAdditionalFunctionality(additionalFunctionality ...registry.NetworkServiceRegistryClient) Option { - return func(connectOpts *connectOptions) { - connectOpts.nsAdditionalFunctionality = additionalFunctionality - } -} - -// WithNSEAdditionalFunctionality sets additional functionality -func WithNSEAdditionalFunctionality(additionalFunctionality ...registry.NetworkServiceEndpointRegistryClient) Option { - return func(connectOpts *connectOptions) { - connectOpts.nseAdditionalFunctionality = additionalFunctionality - } -} - -// WithDialOptions sets dial options -func WithDialOptions(dialOptions ...grpc.DialOption) Option { - return func(connectOpts *connectOptions) { - connectOpts.dialOptions = dialOptions - } -} - -type connectOptions struct { - nsAdditionalFunctionality []registry.NetworkServiceRegistryClient - nseAdditionalFunctionality []registry.NetworkServiceEndpointRegistryClient - dialOptions []grpc.DialOption -} diff --git a/pkg/registry/common/dial/dialer.go b/pkg/registry/common/dial/dialer.go new file mode 100644 index 000000000..351ed6f92 --- /dev/null +++ b/pkg/registry/common/dial/dialer.go @@ -0,0 +1,108 @@ +// Copyright (c) 2021-2022 Cisco and/or its affiliates. +// +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at: +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package dial + +import ( + "context" + "net/url" + "runtime" + "time" + + "github.com/pkg/errors" + "google.golang.org/grpc" + + "github.com/networkservicemesh/sdk/pkg/tools/clock" + "github.com/networkservicemesh/sdk/pkg/tools/grpcutils" +) + +type dialer struct { + ctx context.Context + cleanupContext context.Context + clientURL *url.URL + cleanupCancel context.CancelFunc + *grpc.ClientConn + dialOptions []grpc.DialOption + dialTimeout time.Duration +} + +func newDialer(ctx context.Context, dialTimeout time.Duration, dialOptions ...grpc.DialOption) *dialer { + return &dialer{ + ctx: ctx, + dialOptions: dialOptions, + dialTimeout: dialTimeout, + } +} + +func (di *dialer) Dial(ctx context.Context, clientURL *url.URL) error { + if di == nil { + return errors.New("cannot call dialer.Dial on nil dialer") + } + // Cleanup any previous grpc.ClientConn + if di.cleanupCancel != nil { + di.cleanupCancel() + } + + // Set the clientURL + di.clientURL = clientURL + + // Setup dialTimeout if needed + dialCtx := ctx + if di.dialTimeout != 0 { + dialCtx, _ = clock.FromContext(di.ctx).WithTimeout(dialCtx, di.dialTimeout) + } + + // Dial + target := grpcutils.URLToTarget(di.clientURL) + cc, err := grpc.DialContext(dialCtx, target, di.dialOptions...) + if err != nil { + if cc != nil { + _ = cc.Close() + } + return errors.Wrapf(err, "failed to dial %s", target) + } + di.ClientConn = cc + + di.cleanupContext, di.cleanupCancel = context.WithCancel(di.ctx) + + go func(cleanupContext context.Context, cc *grpc.ClientConn) { + <-cleanupContext.Done() + _ = cc.Close() + }(di.cleanupContext, cc) + return nil +} + +func (di *dialer) Close() error { + if di != nil && di.cleanupCancel != nil { + di.cleanupCancel() + runtime.Gosched() + } + return nil +} + +func (di *dialer) Invoke(ctx context.Context, method string, args, reply interface{}, opts ...grpc.CallOption) error { + if di.ClientConn == nil { + return errors.New("no dialer.ClientConn found") + } + return di.ClientConn.Invoke(ctx, method, args, reply, opts...) +} + +func (di *dialer) NewStream(ctx context.Context, desc *grpc.StreamDesc, method string, opts ...grpc.CallOption) (grpc.ClientStream, error) { + if di.ClientConn == nil { + return nil, errors.New("no dialer.ClientConn found") + } + return di.ClientConn.NewStream(ctx, desc, method, opts...) +} diff --git a/pkg/registry/common/dial/ns_client.go b/pkg/registry/common/dial/ns_client.go new file mode 100644 index 000000000..e520d9f4d --- /dev/null +++ b/pkg/registry/common/dial/ns_client.go @@ -0,0 +1,173 @@ +// Copyright (c) 2021-2022 Cisco and/or its affiliates. +// +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at: +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package dial will dial up a grpc.ClientConnInterface if a client *url.URL is provided in the ctx, retrievable by +// clienturlctx.ClientURL(ctx) and put the resulting grpc.ClientConnInterface into the ctx using clientconn.Store(..) +// where it can be retrieved by other chain elements using clientconn.Load(...) +package dial + +import ( + "context" + "time" + + "github.com/golang/protobuf/ptypes/empty" + "github.com/networkservicemesh/api/pkg/api/registry" + "google.golang.org/grpc" + + "github.com/networkservicemesh/sdk/pkg/registry/common/clientconn" + "github.com/networkservicemesh/sdk/pkg/registry/core/next" + "github.com/networkservicemesh/sdk/pkg/tools/postpone" + + "github.com/networkservicemesh/sdk/pkg/tools/clienturlctx" + "github.com/networkservicemesh/sdk/pkg/tools/grpcutils" + "github.com/networkservicemesh/sdk/pkg/tools/log" +) + +type dialNSClient struct { + chainCtx context.Context + dialOptions []grpc.DialOption + dialTimeout time.Duration +} + +func (c *dialNSClient) Register(ctx context.Context, in *registry.NetworkService, opts ...grpc.CallOption) (*registry.NetworkService, error) { + closeContextFunc := postpone.ContextWithValues(ctx) + // If no clientURL, we have no work to do + // call the next in the chain + clientURL := clienturlctx.ClientURL(ctx) + if clientURL == nil { + return next.NetworkServiceRegistryClient(ctx).Register(ctx, in, opts...) + } + + cc, _ := clientconn.LoadOrStore(ctx, newDialer(c.chainCtx, c.dialTimeout, c.dialOptions...)) + + // If there's an existing grpc.ClientConnInterface and it's not ours, call the next in the chain + di, ok := cc.(*dialer) + if !ok { + return next.NetworkServiceRegistryClient(ctx).Register(ctx, in, opts...) + } + + // If our existing dialer has a different URL close down the chain + if di.clientURL != nil && di.clientURL.String() != clientURL.String() { + closeCtx, closeCancel := closeContextFunc() + defer closeCancel() + err := di.Dial(closeCtx, di.clientURL) + if err != nil { + log.FromContext(ctx).Errorf("can not redial to %v, err %v. Deleting clientconn...", grpcutils.URLToTarget(di.clientURL), err) + clientconn.Delete(ctx) + return nil, err + } + _, _ = next.NetworkServiceRegistryClient(ctx).Unregister(clienturlctx.WithClientURL(closeCtx, di.clientURL), in, opts...) + } + + err := di.Dial(ctx, clientURL) + if err != nil { + log.FromContext(ctx).Errorf("can not dial to %v, err %v. Deleting clientconn...", grpcutils.URLToTarget(clientURL), err) + clientconn.Delete(ctx) + return nil, err + } + + conn, err := next.NetworkServiceRegistryClient(ctx).Register(ctx, in, opts...) + if err != nil { + _ = di.Close() + return nil, err + } + return conn, nil +} +func (c *dialNSClient) Unregister(ctx context.Context, in *registry.NetworkService, opts ...grpc.CallOption) (*empty.Empty, error) { + // If no clientURL, we have no work to do + // call the next in the chain + clientURL := clienturlctx.ClientURL(ctx) + if clientURL == nil { + return next.NetworkServiceRegistryClient(ctx).Unregister(ctx, in, opts...) + } + + cc, _ := clientconn.Load(ctx) + + di, ok := cc.(*dialer) + if !ok { + return next.NetworkServiceRegistryClient(ctx).Unregister(ctx, in, opts...) + } + defer func() { + _ = di.Close() + clientconn.Delete(ctx) + }() + _ = di.Dial(ctx, clientURL) + + return next.NetworkServiceRegistryClient(ctx).Unregister(ctx, in, opts...) +} + +type dialNSFindClient struct { + registry.NetworkServiceRegistry_FindClient + closeFn func() +} + +func (c *dialNSFindClient) Recv() (*registry.NetworkServiceResponse, error) { + resp, err := c.NetworkServiceRegistry_FindClient.Recv() + if err != nil { + c.closeFn() + } + return resp, err +} + +func (c *dialNSClient) Find(ctx context.Context, in *registry.NetworkServiceQuery, opts ...grpc.CallOption) (registry.NetworkServiceRegistry_FindClient, error) { + clientURL := clienturlctx.ClientURL(ctx) + if clientURL == nil { + return next.NetworkServiceRegistryClient(ctx).Find(ctx, in, opts...) + } + + di := newDialer(c.chainCtx, c.dialTimeout, c.dialOptions...) + + findCtx, cancel := context.WithCancel(ctx) + + err := di.Dial(findCtx, clientURL) + if err != nil { + log.FromContext(ctx).Errorf("can not dial to %v, err %v. Deleting clientconn...", grpcutils.URLToTarget(clientURL), err) + cancel() + return nil, err + } + + clientconn.Store(ctx, di) + + resp, err := next.NetworkServiceRegistryClient(ctx).Find(ctx, in, opts...) + if err != nil { + _ = di.Close() + cancel() + return nil, err + } + + return &dialNSFindClient{ + NetworkServiceRegistry_FindClient: resp, + closeFn: func() { + cancel() + _ = di.Close() + }, + }, nil +} + +// NewNetworkServiceRegistryClient - returns a new null client that does nothing but call next.NetworkServiceRegistryClient(ctx). +func NewNetworkServiceRegistryClient(chainCtx context.Context, opts ...Option) registry.NetworkServiceRegistryClient { + o := &option{ + dialTimeout: time.Millisecond * 100, + } + for _, opt := range opts { + opt(o) + } + return &dialNSClient{ + chainCtx: chainCtx, + dialOptions: o.dialOptions, + dialTimeout: o.dialTimeout, + } +} diff --git a/pkg/registry/common/dial/nse_client.go b/pkg/registry/common/dial/nse_client.go new file mode 100644 index 000000000..acf9bca24 --- /dev/null +++ b/pkg/registry/common/dial/nse_client.go @@ -0,0 +1,171 @@ +// Copyright (c) 2021-2022 Cisco and/or its affiliates. +// +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at: +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package dial will dial up a grpc.ClientConnInterface if a client *url.URL is provided in the ctx, retrievable by +// clienturlctx.ClientURL(ctx) and put the resulting grpc.ClientConnInterface into the ctx using clientconn.Store(..) +// where it can be retrieved by other chain elements using clientconn.Load(...) +package dial + +import ( + "context" + "time" + + "github.com/golang/protobuf/ptypes/empty" + "github.com/networkservicemesh/api/pkg/api/registry" + "google.golang.org/grpc" + + "github.com/networkservicemesh/sdk/pkg/registry/common/clientconn" + "github.com/networkservicemesh/sdk/pkg/registry/core/next" + "github.com/networkservicemesh/sdk/pkg/tools/postpone" + + "github.com/networkservicemesh/sdk/pkg/tools/clienturlctx" + "github.com/networkservicemesh/sdk/pkg/tools/grpcutils" + "github.com/networkservicemesh/sdk/pkg/tools/log" +) + +type dialNSEClient struct { + chainCtx context.Context + dialOptions []grpc.DialOption + dialTimeout time.Duration +} + +func (c *dialNSEClient) Register(ctx context.Context, in *registry.NetworkServiceEndpoint, opts ...grpc.CallOption) (*registry.NetworkServiceEndpoint, error) { + closeContextFunc := postpone.ContextWithValues(ctx) + // If no clientURL, we have no work to do + // call the next in the chain + clientURL := clienturlctx.ClientURL(ctx) + if clientURL == nil { + return next.NetworkServiceEndpointRegistryClient(ctx).Register(ctx, in, opts...) + } + + cc, _ := clientconn.LoadOrStore(ctx, newDialer(c.chainCtx, c.dialTimeout, c.dialOptions...)) + + // If there's an existing grpc.ClientConnInterface and it's not ours, call the next in the chain + di, ok := cc.(*dialer) + if !ok { + return next.NetworkServiceEndpointRegistryClient(ctx).Register(ctx, in, opts...) + } + + // If our existing dialer has a different URL close down the chain + if di.clientURL != nil && di.clientURL.String() != clientURL.String() { + closeCtx, closeCancel := closeContextFunc() + defer closeCancel() + err := di.Dial(closeCtx, di.clientURL) + if err != nil { + log.FromContext(ctx).Errorf("can not redial to %v, err %v. Deleting clientconn...", grpcutils.URLToTarget(di.clientURL), err) + clientconn.Delete(ctx) + return nil, err + } + _, _ = next.NetworkServiceEndpointRegistryClient(ctx).Unregister(clienturlctx.WithClientURL(closeCtx, di.clientURL), in, opts...) + } + + err := di.Dial(ctx, clientURL) + if err != nil { + log.FromContext(ctx).Errorf("can not dial to %v, err %v. Deleting clientconn...", grpcutils.URLToTarget(clientURL), err) + clientconn.Delete(ctx) + return nil, err + } + + conn, err := next.NetworkServiceEndpointRegistryClient(ctx).Register(ctx, in, opts...) + if err != nil { + _ = di.Close() + return nil, err + } + return conn, nil +} +func (c *dialNSEClient) Unregister(ctx context.Context, in *registry.NetworkServiceEndpoint, opts ...grpc.CallOption) (*empty.Empty, error) { + // If no clientURL, we have no work to do + // call the next in the chain + clientURL := clienturlctx.ClientURL(ctx) + if clientURL == nil { + return next.NetworkServiceEndpointRegistryClient(ctx).Unregister(ctx, in, opts...) + } + + cc, _ := clientconn.Load(ctx) + + di, ok := cc.(*dialer) + if !ok { + return next.NetworkServiceEndpointRegistryClient(ctx).Unregister(ctx, in, opts...) + } + defer func() { + _ = di.Close() + clientconn.Delete(ctx) + }() + _ = di.Dial(ctx, clientURL) + + return next.NetworkServiceEndpointRegistryClient(ctx).Unregister(ctx, in, opts...) +} + +type dialNSEFindClient struct { + registry.NetworkServiceEndpointRegistry_FindClient + closeFn func() +} + +func (c *dialNSEFindClient) Recv() (*registry.NetworkServiceEndpointResponse, error) { + resp, err := c.NetworkServiceEndpointRegistry_FindClient.Recv() + if err != nil { + c.closeFn() + } + return resp, err +} + +func (c *dialNSEClient) Find(ctx context.Context, in *registry.NetworkServiceEndpointQuery, opts ...grpc.CallOption) (registry.NetworkServiceEndpointRegistry_FindClient, error) { + clientURL := clienturlctx.ClientURL(ctx) + if clientURL == nil { + return next.NetworkServiceEndpointRegistryClient(ctx).Find(ctx, in, opts...) + } + + di := newDialer(c.chainCtx, c.dialTimeout, c.dialOptions...) + + findCtx, cancel := context.WithCancel(ctx) + + err := di.Dial(findCtx, clientURL) + if err != nil { + log.FromContext(ctx).Errorf("can not dial to %v, err %v. Deleting clientconn...", grpcutils.URLToTarget(clientURL), err) + cancel() + return nil, err + } + + clientconn.Store(ctx, di) + + resp, err := next.NetworkServiceEndpointRegistryClient(ctx).Find(ctx, in, opts...) + if err != nil { + _ = di.Close() + cancel() + return nil, err + } + + return &dialNSEFindClient{ + NetworkServiceEndpointRegistry_FindClient: resp, + closeFn: func() { + cancel() + _ = di.Close() + }, + }, nil +} + +// NewNetworkServiceEndpointRegistryClient - returns a new null client that does nothing but call next.NetworkServiceEndpointRegistryClient(ctx). +func NewNetworkServiceEndpointRegistryClient(chainCtx context.Context, opts ...Option) registry.NetworkServiceEndpointRegistryClient { + o := &option{} + for _, opt := range opts { + opt(o) + } + return &dialNSEClient{ + chainCtx: chainCtx, + dialOptions: o.dialOptions, + dialTimeout: o.dialTimeout, + } +} diff --git a/pkg/registry/common/localbypass/find_server.go b/pkg/registry/common/dial/options.go similarity index 50% rename from pkg/registry/common/localbypass/find_server.go rename to pkg/registry/common/dial/options.go index 1599ce7cf..bcb44b964 100644 --- a/pkg/registry/common/localbypass/find_server.go +++ b/pkg/registry/common/dial/options.go @@ -1,4 +1,4 @@ -// Copyright (c) 2020-2021 Doc.ai and/or its affiliates. +// Copyright (c) 2021-2022 Cisco and/or its affiliates. // // SPDX-License-Identifier: Apache-2.0 // @@ -14,25 +14,32 @@ // See the License for the specific language governing permissions and // limitations under the License. -package localbypass +package dial import ( - "github.com/networkservicemesh/api/pkg/api/registry" + "time" + + "google.golang.org/grpc" ) -type localBypassNSEFindServer struct { - *localBypassNSEServer - registry.NetworkServiceEndpointRegistry_FindServer +type option struct { + dialOptions []grpc.DialOption + dialTimeout time.Duration } -func (s *localBypassNSEFindServer) Send(nseResp *registry.NetworkServiceEndpointResponse) error { - if u, ok := s.nseURLs.Load(nseResp.NetworkServiceEndpoint.Name); ok { - nseResp.NetworkServiceEndpoint.Url = u.String() - } +// Option - options for the dial chain element +type Option func(*option) - if nseResp.GetNetworkServiceEndpoint().GetUrl() == s.nsmgrURL && !nseResp.Deleted { - return nil +// WithDialOptions - grpc.DialOptions for use by the dial chain element +func WithDialOptions(dialOptions ...grpc.DialOption) Option { + return func(o *option) { + o.dialOptions = dialOptions } +} - return s.NetworkServiceEndpointRegistry_FindServer.Send(nseResp) +// WithDialTimeout - dialTimeout for use by dial chain element. +func WithDialTimeout(dialTimeout time.Duration) Option { + return func(o *option) { + o.dialTimeout = dialTimeout + } } diff --git a/pkg/tools/serializectx/executor.go b/pkg/registry/common/expire/gen.go similarity index 62% rename from pkg/tools/serializectx/executor.go rename to pkg/registry/common/expire/gen.go index fd2d0e6aa..508af29c9 100644 --- a/pkg/tools/serializectx/executor.go +++ b/pkg/registry/common/expire/gen.go @@ -1,4 +1,4 @@ -// Copyright (c) 2020-2021 Doc.ai and/or its affiliates. +// Copyright (c) 2020-2022 Doc.ai and/or its affiliates. // // SPDX-License-Identifier: Apache-2.0 // @@ -14,15 +14,12 @@ // See the License for the specific language governing permissions and // limitations under the License. -package serializectx +package expire -// Executor is a wrapper around `serialize.Executor.AsyncExec` + ID -type Executor struct { - id string - asyncExec func(f func()) <-chan struct{} -} +import "sync" -// AsyncExec is a `serialize.Executor.AsyncExec` -func (e *Executor) AsyncExec(f func()) <-chan struct{} { - return e.asyncExec(f) -} +//go:generate go-syncmap -output sync_map.gen.go -type cancelsMap + +// cancelsMap is like a Go map[string]context.CancelFunc but is safe for concurrent use +// by multiple goroutines without additional locking or coordination +type cancelsMap sync.Map diff --git a/pkg/registry/common/expire/nse_server.go b/pkg/registry/common/expire/nse_server.go index 9e8a5fe24..bca5375f4 100644 --- a/pkg/registry/common/expire/nse_server.go +++ b/pkg/registry/common/expire/nse_server.go @@ -1,4 +1,4 @@ -// Copyright (c) 2020-2021 Doc.ai and/or its affiliates. +// Copyright (c) 2020-2022 Doc.ai and/or its affiliates. // // SPDX-License-Identifier: Apache-2.0 // @@ -25,64 +25,71 @@ import ( "github.com/networkservicemesh/api/pkg/api/registry" + "github.com/networkservicemesh/sdk/pkg/registry/common/begin" "github.com/networkservicemesh/sdk/pkg/registry/core/next" "github.com/networkservicemesh/sdk/pkg/tools/clock" - "github.com/networkservicemesh/sdk/pkg/tools/expire" "github.com/networkservicemesh/sdk/pkg/tools/log" - "github.com/networkservicemesh/sdk/pkg/tools/serializectx" ) type expireNSEServer struct { - expireManager *expire.Manager nseExpiration time.Duration + ctx context.Context + cancelsMap } // NewNetworkServiceEndpointRegistryServer creates a new NetworkServiceServer chain element that implements unregister // of expired connections for the subsequent chain elements. func NewNetworkServiceEndpointRegistryServer(ctx context.Context, nseExpiration time.Duration) registry.NetworkServiceEndpointRegistryServer { return &expireNSEServer{ - expireManager: expire.NewManager(ctx), nseExpiration: nseExpiration, + ctx: ctx, } } func (s *expireNSEServer) Register(ctx context.Context, nse *registry.NetworkServiceEndpoint) (*registry.NetworkServiceEndpoint, error) { - clockTime := clock.FromContext(ctx) - logger := log.FromContext(ctx).WithField("expireNSEServer", "Register") + factory := begin.FromContext(ctx) + timeClock := clock.FromContext(ctx) + expirationTime := timeClock.Now().Add(s.nseExpiration).Local() - s.expireManager.Stop(nse.Name) + logger := log.FromContext(ctx).WithField("expireNSEServer", "Register") - expirationTime := clockTime.Now().Add(s.nseExpiration) - if nse.ExpirationTime != nil { - if nseExpirationTime := nse.ExpirationTime.AsTime().Local(); nseExpirationTime.Before(expirationTime) { + if nse.GetExpirationTime() != nil { + if nseExpirationTime := nse.GetExpirationTime().AsTime().Local(); nseExpirationTime.Before(expirationTime) { expirationTime = nseExpirationTime + logger.Infof("selected expiration time %v for %v", expirationTime, nse.GetName()) } } + nse.ExpirationTime = timestamppb.New(expirationTime) - reg, err := next.NetworkServiceEndpointRegistryServer(ctx).Register(ctx, nse) + resp, err := next.NetworkServiceEndpointRegistryServer(ctx).Register(ctx, nse) if err != nil { - s.expireManager.Start(nse.Name) return nil, err } - unregisterNSE := reg.Clone() - if unregisterNSE.Name != nse.Name { - s.expireManager.Delete(nse.Name) + if nseExpirationTime := resp.GetExpirationTime().AsTime().Local(); nseExpirationTime.Before(expirationTime) { + expirationTime = nseExpirationTime + logger.Infof("selected expiration time %v for %v", expirationTime, resp.GetName()) } - s.expireManager.New( - serializectx.GetExecutor(ctx, unregisterNSE.Name), - unregisterNSE.Name, - unregisterNSE.ExpirationTime.AsTime().Local(), - func(unregisterCtx context.Context) { - if _, unregisterErr := next.NetworkServiceEndpointRegistryServer(ctx).Unregister(unregisterCtx, unregisterNSE); unregisterErr != nil { - logger.Errorf("failed to unregister expired endpoint: %s %s", unregisterNSE.Name, unregisterErr.Error()) - } - }, - ) - - return reg, nil + expireContext, cancel := context.WithCancel(s.ctx) + if v, ok := s.cancelsMap.LoadAndDelete(nse.GetName()); ok { + v() + } + s.cancelsMap.Store(nse.GetName(), cancel) + + expireCh := timeClock.After(timeClock.Until(expirationTime.Local())) + + go func() { + select { + case <-expireContext.Done(): + return + case <-expireCh: + factory.Unregister(begin.CancelContext(expireContext)) + } + }() + + return resp, nil } func (s *expireNSEServer) Find(query *registry.NetworkServiceEndpointQuery, server registry.NetworkServiceEndpointRegistry_FindServer) error { @@ -90,17 +97,8 @@ func (s *expireNSEServer) Find(query *registry.NetworkServiceEndpointQuery, serv } func (s *expireNSEServer) Unregister(ctx context.Context, nse *registry.NetworkServiceEndpoint) (*empty.Empty, error) { - logger := log.FromContext(ctx).WithField("expireNSEServer", "Unregister") - - if s.expireManager.Stop(nse.Name) { - if _, err := next.NetworkServiceEndpointRegistryServer(ctx).Unregister(ctx, nse); err != nil { - s.expireManager.Start(nse.Name) - return nil, err - } - s.expireManager.Delete(nse.Name) - } else { - logger.Warnf("endpoint has been already unregistered: %s", nse.Name) + if oldCancel, loaded := s.LoadAndDelete(nse.Name); loaded { + oldCancel() } - - return new(empty.Empty), nil + return next.NetworkServiceEndpointRegistryServer(ctx).Unregister(ctx, nse) } diff --git a/pkg/registry/common/expire/nse_server_test.go b/pkg/registry/common/expire/nse_server_test.go index 47dce1994..0b53f7646 100644 --- a/pkg/registry/common/expire/nse_server_test.go +++ b/pkg/registry/common/expire/nse_server_test.go @@ -1,4 +1,4 @@ -// Copyright (c) 2020-2021 Doc.ai and/or its affiliates. +// Copyright (c) 2020-2022 Doc.ai and/or its affiliates. // // SPDX-License-Identifier: Apache-2.0 // @@ -31,21 +31,20 @@ import ( "github.com/networkservicemesh/api/pkg/api/registry" + "github.com/networkservicemesh/sdk/pkg/registry/common/begin" "github.com/networkservicemesh/sdk/pkg/registry/common/expire" "github.com/networkservicemesh/sdk/pkg/registry/common/localbypass" "github.com/networkservicemesh/sdk/pkg/registry/common/memory" "github.com/networkservicemesh/sdk/pkg/registry/common/refresh" - "github.com/networkservicemesh/sdk/pkg/registry/common/serialize" "github.com/networkservicemesh/sdk/pkg/registry/core/adapters" "github.com/networkservicemesh/sdk/pkg/registry/core/next" - "github.com/networkservicemesh/sdk/pkg/registry/utils/checks/checknse" "github.com/networkservicemesh/sdk/pkg/registry/utils/inject/injecterror" "github.com/networkservicemesh/sdk/pkg/tools/clock" "github.com/networkservicemesh/sdk/pkg/tools/clockmock" ) const ( - expireTimeout = time.Minute + expireTimeout = time.Second nseName = "nse" testWait = 100 * time.Millisecond testTick = testWait / 100 @@ -81,7 +80,7 @@ func TestExpireNSEServer_ShouldCorrectlySetExpirationTime_InRemoteCase(t *testin ctx = clock.WithClock(ctx, clockMock) s := next.NewNetworkServiceEndpointRegistryServer( - serialize.NewNetworkServiceEndpointRegistryServer(), + begin.NewNetworkServiceEndpointRegistryServer(), expire.NewNetworkServiceEndpointRegistryServer(ctx, expireTimeout), new(remoteNSEServer), ) @@ -91,7 +90,7 @@ func TestExpireNSEServer_ShouldCorrectlySetExpirationTime_InRemoteCase(t *testin }) require.NoError(t, err) - require.Equal(t, clockMock.Until(resp.ExpirationTime.AsTime()), expireTimeout) + require.Equal(t, expireTimeout, clockMock.Until(resp.ExpirationTime.AsTime().Local())) } func TestExpireNSEServer_ShouldUseLessExpirationTimeFromInput_AndWork(t *testing.T) { @@ -106,7 +105,7 @@ func TestExpireNSEServer_ShouldUseLessExpirationTimeFromInput_AndWork(t *testing mem := memory.NewNetworkServiceEndpointRegistryServer() s := next.NewNetworkServiceEndpointRegistryServer( - serialize.NewNetworkServiceEndpointRegistryServer(), + begin.NewNetworkServiceEndpointRegistryServer(), expire.NewNetworkServiceEndpointRegistryServer(ctx, expireTimeout), mem, ) @@ -136,17 +135,17 @@ func TestExpireNSEServer_ShouldUseLessExpirationTimeFromResponse(t *testing.T) { ctx = clock.WithClock(ctx, clockMock) s := next.NewNetworkServiceEndpointRegistryServer( - serialize.NewNetworkServiceEndpointRegistryServer(), + begin.NewNetworkServiceEndpointRegistryServer(), expire.NewNetworkServiceEndpointRegistryServer(ctx, expireTimeout), new(remoteNSEServer), // <-- GRPC invocation - serialize.NewNetworkServiceEndpointRegistryServer(), + begin.NewNetworkServiceEndpointRegistryServer(), expire.NewNetworkServiceEndpointRegistryServer(ctx, expireTimeout/2), ) resp, err := s.Register(ctx, ®istry.NetworkServiceEndpoint{Name: "nse-1"}) require.NoError(t, err) - require.Equal(t, clockMock.Until(resp.ExpirationTime.AsTime()), expireTimeout/2) + require.Equal(t, expireTimeout/2, clockMock.Until(resp.ExpirationTime.AsTime())) } func TestExpireNSEServer_ShouldRemoveNSEAfterExpirationTime(t *testing.T) { @@ -161,7 +160,7 @@ func TestExpireNSEServer_ShouldRemoveNSEAfterExpirationTime(t *testing.T) { mem := memory.NewNetworkServiceEndpointRegistryServer() s := next.NewNetworkServiceEndpointRegistryServer( - serialize.NewNetworkServiceEndpointRegistryServer(), + begin.NewNetworkServiceEndpointRegistryServer(), expire.NewNetworkServiceEndpointRegistryServer(ctx, expireTimeout), new(remoteNSEServer), // <-- GRPC invocation mem, @@ -195,7 +194,7 @@ func TestExpireNSEServer_DataRace(t *testing.T) { mem := memory.NewNetworkServiceEndpointRegistryServer() s := next.NewNetworkServiceEndpointRegistryServer( - serialize.NewNetworkServiceEndpointRegistryServer(), + begin.NewNetworkServiceEndpointRegistryServer(), expire.NewNetworkServiceEndpointRegistryServer(ctx, 0), localbypass.NewNetworkServiceEndpointRegistryServer("tcp://0.0.0.0"), mem, @@ -225,11 +224,11 @@ func TestExpireNSEServer_RefreshFailure(t *testing.T) { ctx = clock.WithClock(ctx, clockMock) c := next.NewNetworkServiceEndpointRegistryClient( - serialize.NewNetworkServiceEndpointRegistryClient(), + begin.NewNetworkServiceEndpointRegistryClient(), refresh.NewNetworkServiceEndpointRegistryClient(ctx), adapters.NetworkServiceEndpointServerToClient(next.NewNetworkServiceEndpointRegistryServer( new(remoteNSEServer), // <-- GRPC invocation - serialize.NewNetworkServiceEndpointRegistryServer(), + begin.NewNetworkServiceEndpointRegistryServer(), expire.NewNetworkServiceEndpointRegistryServer(ctx, expireTimeout), injecterror.NewNetworkServiceEndpointRegistryServer( injecterror.WithRegisterErrorTimes(1, -1), @@ -262,13 +261,14 @@ func TestExpireNSEServer_UnregisterFailure(t *testing.T) { mem := memory.NewNetworkServiceEndpointRegistryServer() s := next.NewNetworkServiceEndpointRegistryServer( - serialize.NewNetworkServiceEndpointRegistryServer(), + begin.NewNetworkServiceEndpointRegistryServer(), expire.NewNetworkServiceEndpointRegistryServer(ctx, expireTimeout), injecterror.NewNetworkServiceEndpointRegistryServer( injecterror.WithRegisterErrorTimes(), injecterror.WithFindErrorTimes(), injecterror.WithUnregisterErrorTimes(0), ), + expire.NewNetworkServiceEndpointRegistryServer(ctx, expireTimeout), mem, ) @@ -306,18 +306,15 @@ func TestExpireNSEServer_RefreshKeepsNoUnregister(t *testing.T) { unregisterServer := new(unregisterNSEServer) c := next.NewNetworkServiceEndpointRegistryClient( - serialize.NewNetworkServiceEndpointRegistryClient(), + begin.NewNetworkServiceEndpointRegistryClient(), refresh.NewNetworkServiceEndpointRegistryClient(ctx), - adapters.NetworkServiceEndpointServerToClient(next.NewNetworkServiceEndpointRegistryServer( - // NSMgr chain - new(remoteNSEServer), // <-- GRPC invocation - serialize.NewNetworkServiceEndpointRegistryServer(), - expire.NewNetworkServiceEndpointRegistryServer(ctx, expireTimeout), - checknse.NewServer(t, func(*testing.T, *registry.NetworkServiceEndpoint) { - clockMock.Add(expireTimeout / 2) - }), - unregisterServer, - )), + adapters.NetworkServiceEndpointServerToClient( + next.NewNetworkServiceEndpointRegistryServer( + // NSMgr chain + new(remoteNSEServer), // <-- GRPC invocation + expire.NewNetworkServiceEndpointRegistryServer(ctx, expireTimeout), + unregisterServer, + )), ) _, err := c.Register(ctx, ®istry.NetworkServiceEndpoint{ @@ -326,7 +323,7 @@ func TestExpireNSEServer_RefreshKeepsNoUnregister(t *testing.T) { require.NoError(t, err) for i := 0; i < 3; i++ { - clockMock.Add(expireTimeout/2 - time.Millisecond) + clockMock.Add(expireTimeout*2/3 + time.Millisecond) require.Never(t, func() bool { return atomic.LoadInt32(&unregisterServer.unregisterCount) > 0 }, testWait, testTick) diff --git a/pkg/registry/common/connect/nse_info_map.gen.go b/pkg/registry/common/expire/sync_map.gen.go similarity index 59% rename from pkg/registry/common/connect/nse_info_map.gen.go rename to pkg/registry/common/expire/sync_map.gen.go index 2b8ea8016..d3a5dc632 100644 --- a/pkg/registry/common/connect/nse_info_map.gen.go +++ b/pkg/registry/common/expire/sync_map.gen.go @@ -1,58 +1,59 @@ -// Code generated by "-output nse_info_map.gen.go -type nseInfoMap -output nse_info_map.gen.go -type nseInfoMap"; DO NOT EDIT. -package connect +// Code generated by "-output sync_map.gen.go -type cancelsMap -output sync_map.gen.go -type cancelsMap"; DO NOT EDIT. +package expire import ( + "context" "sync" // Used by sync.Map. ) // Generate code that will fail if the constants change value. func _() { - // An "cannot convert nseInfoMap literal (type nseInfoMap) to type sync.Map" compiler error signifies that the base type have changed. + // An "cannot convert cancelsMap literal (type cancelsMap) to type sync.Map" compiler error signifies that the base type have changed. // Re-run the go-syncmap command to generate them again. - _ = (sync.Map)(nseInfoMap{}) + _ = (sync.Map)(cancelsMap{}) } -var _nil_nseInfoMap_nseInfo_value = func() (val *nseInfo) { return }() +var _nil_cancelsMap_context_CancelFunc_value = func() (val context.CancelFunc) { return }() // Load returns the value stored in the map for a key, or nil if no // value is present. // The ok result indicates whether value was found in the map. -func (m *nseInfoMap) Load(key string) (*nseInfo, bool) { +func (m *cancelsMap) Load(key string) (context.CancelFunc, bool) { value, ok := (*sync.Map)(m).Load(key) if value == nil { - return _nil_nseInfoMap_nseInfo_value, ok + return _nil_cancelsMap_context_CancelFunc_value, ok } - return value.(*nseInfo), ok + return value.(context.CancelFunc), ok } // Store sets the value for a key. -func (m *nseInfoMap) Store(key string, value *nseInfo) { +func (m *cancelsMap) Store(key string, value context.CancelFunc) { (*sync.Map)(m).Store(key, value) } // LoadOrStore returns the existing value for the key if present. // Otherwise, it stores and returns the given value. // The loaded result is true if the value was loaded, false if stored. -func (m *nseInfoMap) LoadOrStore(key string, value *nseInfo) (*nseInfo, bool) { +func (m *cancelsMap) LoadOrStore(key string, value context.CancelFunc) (context.CancelFunc, bool) { actual, loaded := (*sync.Map)(m).LoadOrStore(key, value) if actual == nil { - return _nil_nseInfoMap_nseInfo_value, loaded + return _nil_cancelsMap_context_CancelFunc_value, loaded } - return actual.(*nseInfo), loaded + return actual.(context.CancelFunc), loaded } // LoadAndDelete deletes the value for a key, returning the previous value if any. // The loaded result reports whether the key was present. -func (m *nseInfoMap) LoadAndDelete(key string) (value *nseInfo, loaded bool) { +func (m *cancelsMap) LoadAndDelete(key string) (value context.CancelFunc, loaded bool) { actual, loaded := (*sync.Map)(m).LoadAndDelete(key) if actual == nil { - return _nil_nseInfoMap_nseInfo_value, loaded + return _nil_cancelsMap_context_CancelFunc_value, loaded } - return actual.(*nseInfo), loaded + return actual.(context.CancelFunc), loaded } // Delete deletes the value for a key. -func (m *nseInfoMap) Delete(key string) { +func (m *cancelsMap) Delete(key string) { (*sync.Map)(m).Delete(key) } @@ -66,8 +67,8 @@ func (m *nseInfoMap) Delete(key string) { // // Range may be O(N) with the number of elements in the map even if f returns // false after a constant number of calls. -func (m *nseInfoMap) Range(f func(key string, value *nseInfo) bool) { +func (m *cancelsMap) Range(f func(key string, value context.CancelFunc) bool) { (*sync.Map)(m).Range(func(key, value interface{}) bool { - return f(key.(string), value.(*nseInfo)) + return f(key.(string), value.(context.CancelFunc)) }) } diff --git a/pkg/registry/common/heal/gen.go b/pkg/registry/common/heal/gen.go index 8ea93fe15..dc1e6aec9 100644 --- a/pkg/registry/common/heal/gen.go +++ b/pkg/registry/common/heal/gen.go @@ -1,4 +1,4 @@ -// Copyright (c) 2021 Doc.ai and/or its affiliates. +// Copyright (c) 2021-2022 Doc.ai and/or its affiliates. // // SPDX-License-Identifier: Apache-2.0 // @@ -20,8 +20,8 @@ import ( "sync" ) -//go:generate go-syncmap -output nse_info_map.gen.go -type nseInfoMap -//go:generate go-syncmap -output ns_info_map.gen.go -type nsInfoMap +//go:generate go-syncmap -output sync_map.gen.go -type cancelsMap -type nseInfoMap sync.Map -type nsInfoMap sync.Map +// cancelsMap is like a Go map[string]context.CancelFunc but is safe for concurrent use +// by multiple goroutines without additional locking or coordination +type cancelsMap sync.Map diff --git a/pkg/registry/common/heal/ns_client.go b/pkg/registry/common/heal/ns_client.go index 40add4ebe..64f1e62c2 100644 --- a/pkg/registry/common/heal/ns_client.go +++ b/pkg/registry/common/heal/ns_client.go @@ -1,4 +1,4 @@ -// Copyright (c) 2021 Doc.ai and/or its affiliates. +// Copyright (c) 2021-2022 Doc.ai and/or its affiliates. // // SPDX-License-Identifier: Apache-2.0 // @@ -18,73 +18,62 @@ package heal import ( "context" - "sync" "github.com/golang/protobuf/ptypes/empty" "github.com/networkservicemesh/api/pkg/api/registry" "google.golang.org/grpc" "google.golang.org/protobuf/proto" + "github.com/networkservicemesh/sdk/pkg/registry/common/begin" "github.com/networkservicemesh/sdk/pkg/registry/core/next" - "github.com/networkservicemesh/sdk/pkg/tools/addressof" - "github.com/networkservicemesh/sdk/pkg/tools/extend" - "github.com/networkservicemesh/sdk/pkg/tools/log" ) type healNSClient struct { - ctx context.Context - onHeal *registry.NetworkServiceRegistryClient - nsInfos nsInfoMap - - stream registry.NetworkServiceRegistry_FindClient - healCancel context.CancelFunc - lock sync.RWMutex -} - -type nsInfo struct { - ns *registry.NetworkService - ctx context.Context - cancel context.CancelFunc + ctx context.Context + cancelsMap } // NewNetworkServiceRegistryClient returns a new NS registry client responsible for healing -func NewNetworkServiceRegistryClient(ctx context.Context, onHeal *registry.NetworkServiceRegistryClient) registry.NetworkServiceRegistryClient { - c := &healNSClient{ - ctx: ctx, - onHeal: onHeal, - healCancel: func() {}, - } - if c.onHeal == nil { - c.onHeal = addressof.NetworkServiceRegistryClient(c) +func NewNetworkServiceRegistryClient(ctx context.Context) registry.NetworkServiceRegistryClient { + return &healNSClient{ + ctx: ctx, } - return c } func (c *healNSClient) Register(ctx context.Context, ns *registry.NetworkService, opts ...grpc.CallOption) (*registry.NetworkService, error) { - nsCtx, nsCancel := context.WithCancel(c.ctx) - _, loaded := c.nsInfos.LoadOrStore(ns.Name, &nsInfo{ - ns: ns.Clone(), - ctx: nsCtx, - cancel: nsCancel, - }) - if loaded { - nsCancel() - } + resp, err := next.NetworkServiceRegistryClient(ctx).Register(ctx, ns, opts...) - if err := c.startMonitor(ctx, opts); err != nil { + if err != nil { return nil, err } - reg, err := next.NetworkServiceRegistryClient(ctx).Register(ctx, ns, opts...) - if err != nil { - if !loaded { - nsCancel() - c.nsInfos.Delete(ns.Name) - } - return nil, err + factory := begin.FromContext(ctx) + + if v, ok := c.LoadAndDelete(ns.GetName()); ok { + v() } + healCtx, cancel := context.WithCancel(c.ctx) + + stream, streamErr := next.NetworkServiceRegistryClient(ctx).Find(healCtx, ®istry.NetworkServiceQuery{NetworkService: ®istry.NetworkService{Name: ns.GetName()}, Watch: true}, opts...) + + if streamErr != nil { + cancel() + return nil, streamErr + } + + c.Store(ns.GetName(), cancel) - return reg, nil + go func() { + for { + _, recvErr := stream.Recv() + if recvErr != nil { + factory.Register(begin.CancelContext(healCtx)) + return + } + } + }() + + return resp, err } func (c *healNSClient) Find(ctx context.Context, query *registry.NetworkServiceQuery, opts ...grpc.CallOption) (registry.NetworkServiceRegistry_FindClient, error) { @@ -94,13 +83,15 @@ func (c *healNSClient) Find(ctx context.Context, query *registry.NetworkServiceQ query = proto.Clone(query).(*registry.NetworkServiceQuery) + nextClient := next.NetworkServiceRegistryClient(ctx) + createStream := func() (registry.NetworkServiceRegistry_FindClient, error) { queryClone := proto.Clone(query).(*registry.NetworkServiceQuery) - return (*c.onHeal).Find(withNSFindHealing(ctx), queryClone, opts...) + return nextClient.Find(withNSFindHealing(ctx), queryClone, opts...) } queryClone := proto.Clone(query).(*registry.NetworkServiceQuery) - stream, err := next.NetworkServiceRegistryClient(ctx).Find(ctx, queryClone, opts...) + stream, err := nextClient.Find(ctx, queryClone, opts...) if err != nil { return nil, err } @@ -122,101 +113,8 @@ func (c *healNSClient) Find(ctx context.Context, query *registry.NetworkServiceQ } func (c *healNSClient) Unregister(ctx context.Context, ns *registry.NetworkService, opts ...grpc.CallOption) (*empty.Empty, error) { - info, loaded := c.nsInfos.LoadAndDelete(ns.Name) - if !loaded { - return new(empty.Empty), nil + if v, loaded := c.LoadAndDelete(ns.Name); loaded { + v() } - - info.cancel() - return next.NetworkServiceRegistryClient(ctx).Unregister(ctx, ns, opts...) } - -func (c *healNSClient) startMonitor(ctx context.Context, opts []grpc.CallOption) error { - logger := log.FromContext(c.ctx).WithField("healNSClient", "startMonitor") - - c.lock.RLock() - stream := c.stream - c.lock.RUnlock() - - if stream != nil { - return nil - } - - c.lock.Lock() - - if c.stream != nil { - c.lock.Unlock() - return nil - } - - findCtx, findCancel := context.WithCancel(c.ctx) - findCtx = extend.WithValuesFromContext(findCtx, ctx) - - query := ®istry.NetworkServiceQuery{ - NetworkService: new(registry.NetworkService), - Watch: true, - } - - var err error - c.stream, err = next.NetworkServiceRegistryClient(ctx).Find(findCtx, query, opts...) - - c.lock.Unlock() - - if err != nil { - logger.Warn("NS client failed") - findCancel() - return err - } - - logger.Info("NS client ready") - - go func() { - defer findCancel() - c.monitor(opts) - }() - - return nil -} - -func (c *healNSClient) monitor(opts []grpc.CallOption) { - for _, err := c.stream.Recv(); err == nil; _, err = c.stream.Recv() { - } - c.healCancel() - - c.lock.Lock() - defer c.lock.Unlock() - - c.restore(opts) -} - -func (c *healNSClient) restore(opts []grpc.CallOption) { - log.FromContext(c.ctx).WithField("healNSClient", "restore").Warn("NS client restoring") - - c.stream = nil - - var healCtx context.Context - healCtx, c.healCancel = context.WithCancel(c.ctx) - - c.nsInfos.Range(func(name string, info *nsInfo) bool { - go func() { - nsCtx, nsCancel := context.WithCancel(extend.WithValuesFromContext(healCtx, context.Background())) - defer nsCancel() - - go func() { - select { - case <-nsCtx.Done(): - case <-info.ctx.Done(): - } - nsCancel() - }() - - for nsCtx.Err() == nil { - if _, err := (*c.onHeal).Register(nsCtx, info.ns.Clone(), opts...); err == nil { - return - } - } - }() - return true - }) -} diff --git a/pkg/registry/common/heal/ns_info_map.gen.go b/pkg/registry/common/heal/ns_info_map.gen.go deleted file mode 100644 index 6ceb986f2..000000000 --- a/pkg/registry/common/heal/ns_info_map.gen.go +++ /dev/null @@ -1,73 +0,0 @@ -// Code generated by "-output ns_info_map.gen.go -type nsInfoMap -output ns_info_map.gen.go -type nsInfoMap"; DO NOT EDIT. -package heal - -import ( - "sync" // Used by sync.Map. -) - -// Generate code that will fail if the constants change value. -func _() { - // An "cannot convert nsInfoMap literal (type nsInfoMap) to type sync.Map" compiler error signifies that the base type have changed. - // Re-run the go-syncmap command to generate them again. - _ = (sync.Map)(nsInfoMap{}) -} - -var _nil_nsInfoMap_nsInfo_value = func() (val *nsInfo) { return }() - -// Load returns the value stored in the map for a key, or nil if no -// value is present. -// The ok result indicates whether value was found in the map. -func (m *nsInfoMap) Load(key string) (*nsInfo, bool) { - value, ok := (*sync.Map)(m).Load(key) - if value == nil { - return _nil_nsInfoMap_nsInfo_value, ok - } - return value.(*nsInfo), ok -} - -// Store sets the value for a key. -func (m *nsInfoMap) Store(key string, value *nsInfo) { - (*sync.Map)(m).Store(key, value) -} - -// LoadOrStore returns the existing value for the key if present. -// Otherwise, it stores and returns the given value. -// The loaded result is true if the value was loaded, false if stored. -func (m *nsInfoMap) LoadOrStore(key string, value *nsInfo) (*nsInfo, bool) { - actual, loaded := (*sync.Map)(m).LoadOrStore(key, value) - if actual == nil { - return _nil_nsInfoMap_nsInfo_value, loaded - } - return actual.(*nsInfo), loaded -} - -// LoadAndDelete deletes the value for a key, returning the previous value if any. -// The loaded result reports whether the key was present. -func (m *nsInfoMap) LoadAndDelete(key string) (value *nsInfo, loaded bool) { - actual, loaded := (*sync.Map)(m).LoadAndDelete(key) - if actual == nil { - return _nil_nsInfoMap_nsInfo_value, loaded - } - return actual.(*nsInfo), loaded -} - -// Delete deletes the value for a key. -func (m *nsInfoMap) Delete(key string) { - (*sync.Map)(m).Delete(key) -} - -// Range calls f sequentially for each key and value present in the map. -// If f returns false, range stops the iteration. -// -// Range does not necessarily correspond to any consistent snapshot of the Map's -// contents: no key will be visited more than once, but if the value for any key -// is stored or deleted concurrently, Range may reflect any mapping for that key -// from any point during the Range call. -// -// Range may be O(N) with the number of elements in the map even if f returns -// false after a constant number of calls. -func (m *nsInfoMap) Range(f func(key string, value *nsInfo) bool) { - (*sync.Map)(m).Range(func(key, value interface{}) bool { - return f(key.(string), value.(*nsInfo)) - }) -} diff --git a/pkg/registry/common/heal/nse_client.go b/pkg/registry/common/heal/nse_client.go index e7136e48c..b60fa77a0 100644 --- a/pkg/registry/common/heal/nse_client.go +++ b/pkg/registry/common/heal/nse_client.go @@ -1,4 +1,4 @@ -// Copyright (c) 2021 Doc.ai and/or its affiliates. +// Copyright (c) 2021-2022 Doc.ai and/or its affiliates. // // SPDX-License-Identifier: Apache-2.0 // @@ -18,73 +18,62 @@ package heal import ( "context" - "sync" "github.com/golang/protobuf/ptypes/empty" "github.com/networkservicemesh/api/pkg/api/registry" "google.golang.org/grpc" "google.golang.org/protobuf/proto" + "github.com/networkservicemesh/sdk/pkg/registry/common/begin" "github.com/networkservicemesh/sdk/pkg/registry/core/next" - "github.com/networkservicemesh/sdk/pkg/tools/addressof" - "github.com/networkservicemesh/sdk/pkg/tools/extend" - "github.com/networkservicemesh/sdk/pkg/tools/log" ) type healNSEClient struct { - ctx context.Context - onHeal *registry.NetworkServiceEndpointRegistryClient - nseInfos nseInfoMap - - stream registry.NetworkServiceEndpointRegistry_FindClient - healCancel context.CancelFunc - lock sync.RWMutex -} - -type nseInfo struct { - nse *registry.NetworkServiceEndpoint - ctx context.Context - cancel context.CancelFunc + ctx context.Context + cancelsMap } // NewNetworkServiceEndpointRegistryClient returns a new NSE registry client responsible for healing -func NewNetworkServiceEndpointRegistryClient(ctx context.Context, onHeal *registry.NetworkServiceEndpointRegistryClient) registry.NetworkServiceEndpointRegistryClient { - c := &healNSEClient{ - ctx: ctx, - onHeal: onHeal, - healCancel: func() {}, - } - if c.onHeal == nil { - c.onHeal = addressof.NetworkServiceEndpointRegistryClient(c) +func NewNetworkServiceEndpointRegistryClient(ctx context.Context) registry.NetworkServiceEndpointRegistryClient { + return &healNSEClient{ + ctx: ctx, } - return c } func (c *healNSEClient) Register(ctx context.Context, nse *registry.NetworkServiceEndpoint, opts ...grpc.CallOption) (*registry.NetworkServiceEndpoint, error) { - nseCtx, nseCancel := context.WithCancel(c.ctx) - _, loaded := c.nseInfos.LoadOrStore(nse.Name, &nseInfo{ - nse: nse.Clone(), - ctx: nseCtx, - cancel: nseCancel, - }) - if loaded { - nseCancel() - } + resp, err := next.NetworkServiceEndpointRegistryClient(ctx).Register(ctx, nse, opts...) - if err := c.startMonitor(ctx, opts); err != nil { + if err != nil { return nil, err } - reg, err := next.NetworkServiceEndpointRegistryClient(ctx).Register(ctx, nse, opts...) - if err != nil { - if !loaded { - nseCancel() - c.nseInfos.Delete(nse.Name) - } - return nil, err + factory := begin.FromContext(ctx) + + if v, ok := c.LoadAndDelete(nse.GetName()); ok { + v() } + healCtx, cancel := context.WithCancel(c.ctx) + + stream, streamErr := next.NetworkServiceEndpointRegistryClient(ctx).Find(healCtx, ®istry.NetworkServiceEndpointQuery{NetworkServiceEndpoint: ®istry.NetworkServiceEndpoint{Name: nse.GetName()}, Watch: true}, opts...) + + if streamErr != nil { + cancel() + return nil, streamErr + } + + c.Store(nse.GetName(), cancel) - return reg, nil + go func() { + for { + _, recvErr := stream.Recv() + if recvErr != nil { + factory.Register(begin.CancelContext(healCtx)) + return + } + } + }() + + return resp, err } func (c *healNSEClient) Find(ctx context.Context, query *registry.NetworkServiceEndpointQuery, opts ...grpc.CallOption) (registry.NetworkServiceEndpointRegistry_FindClient, error) { @@ -94,13 +83,15 @@ func (c *healNSEClient) Find(ctx context.Context, query *registry.NetworkService query = proto.Clone(query).(*registry.NetworkServiceEndpointQuery) + nextClient := next.NetworkServiceEndpointRegistryClient(ctx) + createStream := func() (registry.NetworkServiceEndpointRegistry_FindClient, error) { queryClone := proto.Clone(query).(*registry.NetworkServiceEndpointQuery) - return (*c.onHeal).Find(withNSEFindHealing(ctx), queryClone, opts...) + return nextClient.Find(withNSEFindHealing(ctx), queryClone, opts...) } queryClone := proto.Clone(query).(*registry.NetworkServiceEndpointQuery) - stream, err := next.NetworkServiceEndpointRegistryClient(ctx).Find(ctx, queryClone, opts...) + stream, err := nextClient.Find(ctx, queryClone, opts...) if err != nil { return nil, err } @@ -122,101 +113,8 @@ func (c *healNSEClient) Find(ctx context.Context, query *registry.NetworkService } func (c *healNSEClient) Unregister(ctx context.Context, nse *registry.NetworkServiceEndpoint, opts ...grpc.CallOption) (*empty.Empty, error) { - info, loaded := c.nseInfos.LoadAndDelete(nse.Name) - if !loaded { - return new(empty.Empty), nil + if v, loaded := c.LoadAndDelete(nse.Name); loaded { + v() } - - info.cancel() - return next.NetworkServiceEndpointRegistryClient(ctx).Unregister(ctx, nse, opts...) } - -func (c *healNSEClient) startMonitor(ctx context.Context, opts []grpc.CallOption) error { - logger := log.FromContext(c.ctx).WithField("healNSEClient", "startMonitor") - - c.lock.RLock() - stream := c.stream - c.lock.RUnlock() - - if stream != nil { - return nil - } - - c.lock.Lock() - - if c.stream != nil { - c.lock.Unlock() - return nil - } - - findCtx, findCancel := context.WithCancel(c.ctx) - findCtx = extend.WithValuesFromContext(findCtx, ctx) - - query := ®istry.NetworkServiceEndpointQuery{ - NetworkServiceEndpoint: new(registry.NetworkServiceEndpoint), - Watch: true, - } - - var err error - c.stream, err = next.NetworkServiceEndpointRegistryClient(ctx).Find(findCtx, query, opts...) - - c.lock.Unlock() - - if err != nil { - logger.Warn("NSE client failed") - findCancel() - return err - } - - logger.Info("NSE client ready") - - go func() { - defer findCancel() - c.monitor(opts) - }() - - return nil -} - -func (c *healNSEClient) monitor(opts []grpc.CallOption) { - for _, err := c.stream.Recv(); err == nil; _, err = c.stream.Recv() { - } - c.healCancel() - - c.lock.Lock() - defer c.lock.Unlock() - - c.restore(opts) -} - -func (c *healNSEClient) restore(opts []grpc.CallOption) { - log.FromContext(c.ctx).WithField("healNSEClient", "restore").Warn("NSE client restoring") - - c.stream = nil - - var healCtx context.Context - healCtx, c.healCancel = context.WithCancel(c.ctx) - - c.nseInfos.Range(func(name string, info *nseInfo) bool { - go func() { - nseCtx, nseCancel := context.WithCancel(extend.WithValuesFromContext(healCtx, context.Background())) - defer nseCancel() - - go func() { - select { - case <-nseCtx.Done(): - case <-info.ctx.Done(): - } - nseCancel() - }() - - for nseCtx.Err() == nil { - if _, err := (*c.onHeal).Register(nseCtx, info.nse.Clone(), opts...); err == nil { - return - } - } - }() - return true - }) -} diff --git a/pkg/registry/common/heal/nse_find_client.go b/pkg/registry/common/heal/nse_find_client.go index 375006bf6..6420f7279 100644 --- a/pkg/registry/common/heal/nse_find_client.go +++ b/pkg/registry/common/heal/nse_find_client.go @@ -1,4 +1,4 @@ -// Copyright (c) 2021 Doc.ai and/or its affiliates. +// Copyright (c) 2021-2022 Doc.ai and/or its affiliates. // // SPDX-License-Identifier: Apache-2.0 // diff --git a/pkg/registry/common/heal/nse_info_map.gen.go b/pkg/registry/common/heal/sync_map.gen.go similarity index 59% rename from pkg/registry/common/heal/nse_info_map.gen.go rename to pkg/registry/common/heal/sync_map.gen.go index 52aa28954..e37bb8fc8 100644 --- a/pkg/registry/common/heal/nse_info_map.gen.go +++ b/pkg/registry/common/heal/sync_map.gen.go @@ -1,58 +1,59 @@ -// Code generated by "-output nse_info_map.gen.go -type nseInfoMap -output nse_info_map.gen.go -type nseInfoMap"; DO NOT EDIT. +// Code generated by "-output sync_map.gen.go -type cancelsMap -output sync_map.gen.go -type cancelsMap"; DO NOT EDIT. package heal import ( + "context" "sync" // Used by sync.Map. ) // Generate code that will fail if the constants change value. func _() { - // An "cannot convert nseInfoMap literal (type nseInfoMap) to type sync.Map" compiler error signifies that the base type have changed. + // An "cannot convert cancelsMap literal (type cancelsMap) to type sync.Map" compiler error signifies that the base type have changed. // Re-run the go-syncmap command to generate them again. - _ = (sync.Map)(nseInfoMap{}) + _ = (sync.Map)(cancelsMap{}) } -var _nil_nseInfoMap_nseInfo_value = func() (val *nseInfo) { return }() +var _nil_cancelsMap_context_CancelFunc_value = func() (val context.CancelFunc) { return }() // Load returns the value stored in the map for a key, or nil if no // value is present. // The ok result indicates whether value was found in the map. -func (m *nseInfoMap) Load(key string) (*nseInfo, bool) { +func (m *cancelsMap) Load(key string) (context.CancelFunc, bool) { value, ok := (*sync.Map)(m).Load(key) if value == nil { - return _nil_nseInfoMap_nseInfo_value, ok + return _nil_cancelsMap_context_CancelFunc_value, ok } - return value.(*nseInfo), ok + return value.(context.CancelFunc), ok } // Store sets the value for a key. -func (m *nseInfoMap) Store(key string, value *nseInfo) { +func (m *cancelsMap) Store(key string, value context.CancelFunc) { (*sync.Map)(m).Store(key, value) } // LoadOrStore returns the existing value for the key if present. // Otherwise, it stores and returns the given value. // The loaded result is true if the value was loaded, false if stored. -func (m *nseInfoMap) LoadOrStore(key string, value *nseInfo) (*nseInfo, bool) { +func (m *cancelsMap) LoadOrStore(key string, value context.CancelFunc) (context.CancelFunc, bool) { actual, loaded := (*sync.Map)(m).LoadOrStore(key, value) if actual == nil { - return _nil_nseInfoMap_nseInfo_value, loaded + return _nil_cancelsMap_context_CancelFunc_value, loaded } - return actual.(*nseInfo), loaded + return actual.(context.CancelFunc), loaded } // LoadAndDelete deletes the value for a key, returning the previous value if any. // The loaded result reports whether the key was present. -func (m *nseInfoMap) LoadAndDelete(key string) (value *nseInfo, loaded bool) { +func (m *cancelsMap) LoadAndDelete(key string) (value context.CancelFunc, loaded bool) { actual, loaded := (*sync.Map)(m).LoadAndDelete(key) if actual == nil { - return _nil_nseInfoMap_nseInfo_value, loaded + return _nil_cancelsMap_context_CancelFunc_value, loaded } - return actual.(*nseInfo), loaded + return actual.(context.CancelFunc), loaded } // Delete deletes the value for a key. -func (m *nseInfoMap) Delete(key string) { +func (m *cancelsMap) Delete(key string) { (*sync.Map)(m).Delete(key) } @@ -66,8 +67,8 @@ func (m *nseInfoMap) Delete(key string) { // // Range may be O(N) with the number of elements in the map even if f returns // false after a constant number of calls. -func (m *nseInfoMap) Range(f func(key string, value *nseInfo) bool) { +func (m *cancelsMap) Range(f func(key string, value context.CancelFunc) bool) { (*sync.Map)(m).Range(func(key, value interface{}) bool { - return f(key.(string), value.(*nseInfo)) + return f(key.(string), value.(context.CancelFunc)) }) } diff --git a/pkg/registry/common/interdomainbypass/server.go b/pkg/registry/common/interdomainbypass/server.go new file mode 100644 index 000000000..dfa61c1f3 --- /dev/null +++ b/pkg/registry/common/interdomainbypass/server.go @@ -0,0 +1,96 @@ +// Copyright (c) 2022 Cisco and/or its affiliates. +// +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at: +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package interdomainbypass provides registry chain element that sets to outgoing NSE the public nsmgr-proxy and stores into the shared map the public nsmgr URL from the incoming endpoint. +package interdomainbypass + +import ( + "context" + "net/url" + + "github.com/golang/protobuf/ptypes/empty" + "github.com/networkservicemesh/api/pkg/api/registry" + + "github.com/networkservicemesh/sdk/pkg/registry/core/next" + "github.com/networkservicemesh/sdk/pkg/tools/stringurl" +) + +type interdomainBypassNSEServer struct { + m *stringurl.Map + u *url.URL +} + +type interdomainBypassNSEFindServer struct { + m *stringurl.Map + u *url.URL + registry.NetworkServiceEndpointRegistry_FindServer +} + +func (n *interdomainBypassNSEServer) Register(ctx context.Context, service *registry.NetworkServiceEndpoint) (*registry.NetworkServiceEndpoint, error) { + var originalURL = service.Url + service.Url = n.u.String() + + resp, err := next.NetworkServiceEndpointRegistryServer(ctx).Register(ctx, service) + + if err != nil { + return nil, err + } + + u, _ := url.Parse(originalURL) + + n.m.Store(service.Name, u) + + resp.Url = originalURL + + return resp, err +} + +func (n *interdomainBypassNSEServer) Find(query *registry.NetworkServiceEndpointQuery, server registry.NetworkServiceEndpointRegistry_FindServer) error { + return next.NetworkServiceEndpointRegistryServer(server.Context()).Find(query, &interdomainBypassNSEFindServer{NetworkServiceEndpointRegistry_FindServer: server, m: n.m, u: n.u}) +} + +func (n *interdomainBypassNSEServer) Unregister(ctx context.Context, service *registry.NetworkServiceEndpoint) (*empty.Empty, error) { + n.m.Delete(service.Name) + var originalURL = service.Url + service.Url = n.u.String() + defer func() { + service.Url = originalURL + }() + return next.NetworkServiceEndpointRegistryServer(ctx).Unregister(ctx, service) +} + +// NewNetworkServiceEndpointRegistryServer creates new instance of interdomainbypass NSE server. +// It simply stores into passed stringurl.Map all incoming nse.Name:nse.URL entries. +// And sets passed URL for outgoing NSEs. +func NewNetworkServiceEndpointRegistryServer(m *stringurl.Map, u *url.URL) registry.NetworkServiceEndpointRegistryServer { + if m == nil { + panic("m can not be nil") + } + if u == nil { + panic("u can not be nil") + } + return &interdomainBypassNSEServer{m: m, u: u} +} + +func (s *interdomainBypassNSEFindServer) Send(nseResp *registry.NetworkServiceEndpointResponse) error { + u, err := url.Parse(nseResp.GetNetworkServiceEndpoint().GetUrl()) + if err != nil { + return err + } + s.m.LoadOrStore(nseResp.NetworkServiceEndpoint.GetName(), u) + nseResp.GetNetworkServiceEndpoint().Url = s.u.String() + return s.NetworkServiceEndpointRegistry_FindServer.Send(nseResp) +} diff --git a/pkg/registry/common/localbypass/server.go b/pkg/registry/common/localbypass/server.go index 9fa1b0dfd..2cf99264d 100644 --- a/pkg/registry/common/localbypass/server.go +++ b/pkg/registry/common/localbypass/server.go @@ -1,4 +1,4 @@ -// Copyright (c) 2020-2021 Doc.ai and/or its affiliates. +// Copyright (c) 2020-2022 Doc.ai and/or its affiliates. // // SPDX-License-Identifier: Apache-2.0 // @@ -31,6 +31,23 @@ import ( "github.com/networkservicemesh/sdk/pkg/tools/stringurl" ) +type localBypassNSEFindServer struct { + *localBypassNSEServer + registry.NetworkServiceEndpointRegistry_FindServer +} + +func (s *localBypassNSEFindServer) Send(nseResp *registry.NetworkServiceEndpointResponse) error { + if u, ok := s.nseURLs.Load(nseResp.NetworkServiceEndpoint.Name); ok { + nseResp.NetworkServiceEndpoint.Url = u.String() + } + + if nseResp.GetNetworkServiceEndpoint().GetUrl() == s.nsmgrURL && !nseResp.Deleted { + return nil + } + + return s.NetworkServiceEndpointRegistry_FindServer.Send(nseResp) +} + type localBypassNSEServer struct { nsmgrURL string nseURLs stringurl.Map diff --git a/pkg/registry/common/memory/nse_server.go b/pkg/registry/common/memory/nse_server.go index 871e9c2e1..f74fcc909 100644 --- a/pkg/registry/common/memory/nse_server.go +++ b/pkg/registry/common/memory/nse_server.go @@ -1,4 +1,4 @@ -// Copyright (c) 2020-2021 Doc.ai and/or its affiliates. +// Copyright (c) 2020-2022 Doc.ai and/or its affiliates. // // SPDX-License-Identifier: Apache-2.0 // @@ -161,7 +161,7 @@ func (s *memoryNSEServer) receiveEvent( } func (s *memoryNSEServer) Unregister(ctx context.Context, nse *registry.NetworkServiceEndpoint) (*empty.Empty, error) { - if unregisterNSE, ok := s.networkServiceEndpoints.LoadAndDelete(nse.Name); ok { + if unregisterNSE, ok := s.networkServiceEndpoints.LoadAndDelete(nse.GetName()); ok { unregisterNSE = unregisterNSE.Clone() s.sendEvent(®istry.NetworkServiceEndpointResponse{NetworkServiceEndpoint: unregisterNSE, Deleted: true}) } diff --git a/pkg/registry/common/proxy/common_test.go b/pkg/registry/common/proxy/common_test.go deleted file mode 100644 index bca0a8a93..000000000 --- a/pkg/registry/common/proxy/common_test.go +++ /dev/null @@ -1,87 +0,0 @@ -// Copyright (c) 2020-2021 Doc.ai and/or its affiliates. -// -// SPDX-License-Identifier: Apache-2.0 -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at: -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package proxy_test - -import ( - "context" - "net" - "net/url" - "testing" - - "github.com/networkservicemesh/api/pkg/api/registry" - "github.com/stretchr/testify/require" - "google.golang.org/grpc" - - "github.com/networkservicemesh/sdk/pkg/registry/common/connect" - "github.com/networkservicemesh/sdk/pkg/registry/common/proxy" - "github.com/networkservicemesh/sdk/pkg/registry/core/next" - "github.com/networkservicemesh/sdk/pkg/tools/grpcutils" -) - -func startNSEServer(t *testing.T, chain registry.NetworkServiceEndpointRegistryServer) (u *url.URL, closeFunc func()) { - s := grpc.NewServer() - registry.RegisterNetworkServiceEndpointRegistryServer(s, chain) - grpcutils.RegisterHealthServices(s, chain) - l, err := net.Listen("tcp", "127.0.0.1:0") - require.Nil(t, err) - closeFunc = func() { - _ = l.Close() - } - go func() { - _ = s.Serve(l) - }() - u, err = url.Parse("tcp://" + l.Addr().String()) - if err != nil { - closeFunc() - } - require.Nil(t, err) - return u, closeFunc -} - -func startNSServer(t *testing.T, chain registry.NetworkServiceRegistryServer) (u *url.URL, closeFunc func()) { - s := grpc.NewServer() - registry.RegisterNetworkServiceRegistryServer(s, chain) - grpcutils.RegisterHealthServices(s, chain) - l, err := net.Listen("tcp", "127.0.0.1:0") - require.Nil(t, err) - closeFunc = func() { - _ = l.Close() - } - go func() { - _ = s.Serve(l) - }() - u, err = url.Parse("tcp://" + l.Addr().String()) - if err != nil { - closeFunc() - } - require.Nil(t, err) - return u, closeFunc -} - -func testingNSEServerChain(ctx context.Context, u *url.URL) registry.NetworkServiceEndpointRegistryServer { - return next.NewNetworkServiceEndpointRegistryServer( - proxy.NewNetworkServiceEndpointRegistryServer(u), - connect.NewNetworkServiceEndpointRegistryServer(ctx, connect.WithDialOptions(grpc.WithInsecure())), - ) -} - -func testingNSServerChain(ctx context.Context, u *url.URL) registry.NetworkServiceRegistryServer { - return next.NewNetworkServiceRegistryServer( - proxy.NewNetworkServiceRegistryServer(u), - connect.NewNetworkServiceRegistryServer(ctx, connect.WithDialOptions(grpc.WithInsecure())), - ) -} diff --git a/pkg/registry/common/proxy/ns_server.go b/pkg/registry/common/proxy/ns_server.go deleted file mode 100644 index 53f69eb62..000000000 --- a/pkg/registry/common/proxy/ns_server.go +++ /dev/null @@ -1,68 +0,0 @@ -// Copyright (c) 2020-2021 Doc.ai and/or its affiliates. -// -// SPDX-License-Identifier: Apache-2.0 -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at: -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package proxy - -import ( - "context" - "net/url" - - "github.com/networkservicemesh/sdk/pkg/tools/clienturlctx" - - "github.com/golang/protobuf/ptypes/empty" - "github.com/networkservicemesh/api/pkg/api/registry" - - "github.com/networkservicemesh/sdk/pkg/registry/core/next" - "github.com/networkservicemesh/sdk/pkg/registry/core/streamcontext" - "github.com/networkservicemesh/sdk/pkg/tools/interdomain" -) - -type nsServer struct { - proxyRegistryURL *url.URL - matchFunc func(name string) bool -} - -func (n *nsServer) Register(ctx context.Context, nse *registry.NetworkService) (*registry.NetworkService, error) { - if !interdomain.Is(nse.Name) { - return nse, nil - } - ctx = clienturlctx.WithClientURL(ctx, n.proxyRegistryURL) - return next.NetworkServiceRegistryServer(ctx).Register(ctx, nse) -} - -func (n *nsServer) Find(q *registry.NetworkServiceQuery, s registry.NetworkServiceRegistry_FindServer) error { - if !interdomain.Is(q.NetworkService.Name) { - return nil - } - ctx := clienturlctx.WithClientURL(s.Context(), n.proxyRegistryURL) - return next.NetworkServiceRegistryServer(ctx).Find(q, streamcontext.NetworkServiceRegistryFindServer(ctx, s)) -} - -func (n *nsServer) Unregister(ctx context.Context, nse *registry.NetworkService) (*empty.Empty, error) { - if !interdomain.Is(nse.Name) { - return new(empty.Empty), nil - } - ctx = clienturlctx.WithClientURL(ctx, n.proxyRegistryURL) - return next.NetworkServiceRegistryServer(ctx).Unregister(ctx, nse) -} - -// NewNetworkServiceRegistryServer creates new NetworkServiceRegistryServer that can proxying interdomain upstream to the remote registry by URL -func NewNetworkServiceRegistryServer(proxyRegistryURL *url.URL) registry.NetworkServiceRegistryServer { - return &nsServer{ - proxyRegistryURL: proxyRegistryURL, - matchFunc: interdomain.Is, - } -} diff --git a/pkg/registry/common/proxy/ns_server_test.go b/pkg/registry/common/proxy/ns_server_test.go deleted file mode 100644 index d0d21a7f1..000000000 --- a/pkg/registry/common/proxy/ns_server_test.go +++ /dev/null @@ -1,130 +0,0 @@ -// Copyright (c) 2020 Doc.ai and/or its affiliates. -// -// SPDX-License-Identifier: Apache-2.0 -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at: -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package proxy_test - -import ( - "context" - "runtime" - "testing" - "time" - - "github.com/networkservicemesh/api/pkg/api/registry" - "github.com/stretchr/testify/require" - "go.uber.org/goleak" - - "github.com/networkservicemesh/sdk/pkg/registry/common/memory" - "github.com/networkservicemesh/sdk/pkg/registry/core/adapters" -) - -func TestNewProxyNetworkServiceRegistryServer_Register(t *testing.T) { - m := memory.NewNetworkServiceRegistryServer() - u, closeServer := startNSServer(t, m) - - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - chain := testingNSServerChain(ctx, u) - - _, err := chain.Register(context.Background(), ®istry.NetworkService{Name: "nse-1"}) - require.NoError(t, err) - _, err = chain.Register(context.Background(), ®istry.NetworkService{Name: "nse-2@domain"}) - require.NoError(t, err) - _, err = chain.Register(context.Background(), ®istry.NetworkService{Name: "nse-3"}) - require.NoError(t, err) - - client := adapters.NetworkServiceServerToClient(m) - - stream, err := client.Find(context.Background(), ®istry.NetworkServiceQuery{NetworkService: ®istry.NetworkService{Name: "nse"}}) - require.NoError(t, err) - list := registry.ReadNetworkServiceList(stream) - require.Len(t, list, 1) - require.Equal(t, "nse-2@domain", list[0].Name) - - closeServer() - - require.Eventually(t, func() bool { - runtime.GC() - return goleak.Find() != nil - }, time.Second, time.Microsecond*100) -} - -func TestNewProxyNetworkServiceRegistryServer_Unregister(t *testing.T) { - m := memory.NewNetworkServiceRegistryServer() - _, err := m.Register(context.Background(), ®istry.NetworkService{Name: "nse-1@domain1"}) - require.Nil(t, err) - u, closeServer := startNSServer(t, m) - - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - - chain := testingNSServerChain(ctx, u) - - checkLen := func(expected int) { - client := adapters.NetworkServiceServerToClient(m) - stream, findErr := client.Find(context.Background(), ®istry.NetworkServiceQuery{NetworkService: ®istry.NetworkService{Name: "nse"}}) - require.NoError(t, findErr) - list := registry.ReadNetworkServiceList(stream) - require.Len(t, list, expected) - } - - _, err = chain.Unregister(context.Background(), ®istry.NetworkService{Name: "nse-1"}) - require.Nil(t, err) - checkLen(1) - _, err = chain.Unregister(context.Background(), ®istry.NetworkService{Name: "nse"}) - require.Nil(t, err) - checkLen(1) - _, err = chain.Unregister(context.Background(), ®istry.NetworkService{Name: "nse-1@domain2"}) - require.Nil(t, err) - checkLen(1) - _, err = chain.Unregister(context.Background(), ®istry.NetworkService{Name: "nse-1@domain1"}) - require.Nil(t, err) - checkLen(0) - - closeServer() - - require.Eventually(t, func() bool { - runtime.GC() - return goleak.Find() != nil - }, time.Second, time.Microsecond*100) -} - -func TestNewProxyNetworkServiceRegistryServer_Find(t *testing.T) { - m := memory.NewNetworkServiceRegistryServer() - _, err := m.Register(context.Background(), ®istry.NetworkService{Name: "nse-1@domain1"}) - require.Nil(t, err) - u, closeServer := startNSServer(t, m) - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - chain := testingNSServerChain(ctx, u) - - checkLen := func(nseName string, expected int) { - client := adapters.NetworkServiceServerToClient(chain) - stream, err := client.Find(context.Background(), ®istry.NetworkServiceQuery{NetworkService: ®istry.NetworkService{Name: nseName}}) - require.NoError(t, err) - list := registry.ReadNetworkServiceList(stream) - require.Len(t, list, expected) - } - - checkLen("nse", 0) - checkLen("nse-1@domain1", 1) - - closeServer() - - require.Eventually(t, func() bool { - runtime.GC() - return goleak.Find() != nil - }, time.Second, time.Microsecond*100) -} diff --git a/pkg/registry/common/proxy/nse_server.go b/pkg/registry/common/proxy/nse_server.go deleted file mode 100644 index 29c573bf6..000000000 --- a/pkg/registry/common/proxy/nse_server.go +++ /dev/null @@ -1,78 +0,0 @@ -// Copyright (c) 2020-2021 Doc.ai and/or its affiliates. -// -// SPDX-License-Identifier: Apache-2.0 -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at: -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package proxy - -import ( - "context" - "net/url" - - "github.com/networkservicemesh/sdk/pkg/tools/clienturlctx" - - "github.com/golang/protobuf/ptypes/empty" - "github.com/networkservicemesh/api/pkg/api/registry" - - "github.com/networkservicemesh/sdk/pkg/registry/core/next" - "github.com/networkservicemesh/sdk/pkg/registry/core/streamcontext" - "github.com/networkservicemesh/sdk/pkg/tools/interdomain" -) - -type nseServer struct { - proxyRegistryURL *url.URL -} - -func (n *nseServer) Register(ctx context.Context, nse *registry.NetworkServiceEndpoint) (*registry.NetworkServiceEndpoint, error) { - if !interdomain.Is(nse.Name) { - return nse, nil - } - ctx = clienturlctx.WithClientURL(ctx, n.proxyRegistryURL) - return next.NetworkServiceEndpointRegistryServer(ctx).Register(ctx, nse) -} - -func (n nseServer) Find(q *registry.NetworkServiceEndpointQuery, s registry.NetworkServiceEndpointRegistry_FindServer) error { - if !isInterdomain(q.NetworkServiceEndpoint) { - return nil - } - ctx := clienturlctx.WithClientURL(s.Context(), n.proxyRegistryURL) - return next.NetworkServiceEndpointRegistryServer(ctx).Find(q, streamcontext.NetworkServiceEndpointRegistryFindServer(ctx, s)) -} - -func (n *nseServer) Unregister(ctx context.Context, nse *registry.NetworkServiceEndpoint) (*empty.Empty, error) { - if !interdomain.Is(nse.Name) { - return new(empty.Empty), nil - } - ctx = clienturlctx.WithClientURL(ctx, n.proxyRegistryURL) - return next.NetworkServiceEndpointRegistryServer(ctx).Unregister(ctx, nse) -} - -// NewNetworkServiceEndpointRegistryServer creates new NetworkServiceEndpointRegistryServer that can proxying interdomain upstream to the remote registry by URL -func NewNetworkServiceEndpointRegistryServer(proxyRegistryURL *url.URL) registry.NetworkServiceEndpointRegistryServer { - return &nseServer{ - proxyRegistryURL: proxyRegistryURL, - } -} - -func isInterdomain(nse *registry.NetworkServiceEndpoint) bool { - if interdomain.Is(nse.Name) { - return true - } - for _, ns := range nse.NetworkServiceNames { - if interdomain.Is(ns) { - return true - } - } - return false -} diff --git a/pkg/registry/common/proxy/nse_server_test.go b/pkg/registry/common/proxy/nse_server_test.go deleted file mode 100644 index f0fd06e9e..000000000 --- a/pkg/registry/common/proxy/nse_server_test.go +++ /dev/null @@ -1,127 +0,0 @@ -// Copyright (c) 2020 Doc.ai and/or its affiliates. -// -// SPDX-License-Identifier: Apache-2.0 -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at: -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package proxy_test - -import ( - "context" - "runtime" - "testing" - "time" - - "github.com/networkservicemesh/api/pkg/api/registry" - "github.com/stretchr/testify/require" - "go.uber.org/goleak" - - "github.com/networkservicemesh/sdk/pkg/registry/common/memory" - "github.com/networkservicemesh/sdk/pkg/registry/core/adapters" -) - -func TestNewProxyNetworkServiceEndpointRegistryServer_Register(t *testing.T) { - m := memory.NewNetworkServiceEndpointRegistryServer() - u, closeServer := startNSEServer(t, m) - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - chain := testingNSEServerChain(ctx, u) - - _, err := chain.Register(context.Background(), ®istry.NetworkServiceEndpoint{Name: "nse-1"}) - require.NoError(t, err) - _, err = chain.Register(context.Background(), ®istry.NetworkServiceEndpoint{Name: "nse-2@domain"}) - require.NoError(t, err) - _, err = chain.Register(context.Background(), ®istry.NetworkServiceEndpoint{Name: "nse-3"}) - require.NoError(t, err) - - client := adapters.NetworkServiceEndpointServerToClient(m) - - stream, err := client.Find(context.Background(), ®istry.NetworkServiceEndpointQuery{NetworkServiceEndpoint: ®istry.NetworkServiceEndpoint{Name: "nse"}}) - require.NoError(t, err) - list := registry.ReadNetworkServiceEndpointList(stream) - require.Len(t, list, 1) - require.Equal(t, "nse-2@domain", list[0].Name) - - closeServer() - - require.Eventually(t, func() bool { - runtime.GC() - return goleak.Find() != nil - }, time.Second, time.Microsecond*100) -} - -func TestNewProxyNetworkServiceEndpointRegistryServer_Unregister(t *testing.T) { - m := memory.NewNetworkServiceEndpointRegistryServer() - _, err := m.Register(context.Background(), ®istry.NetworkServiceEndpoint{Name: "nse-1@domain1"}) - require.Nil(t, err) - u, closeServer := startNSEServer(t, m) - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - chain := testingNSEServerChain(ctx, u) - - checkLen := func(expected int) { - client := adapters.NetworkServiceEndpointServerToClient(m) - stream, findErr := client.Find(context.Background(), ®istry.NetworkServiceEndpointQuery{NetworkServiceEndpoint: ®istry.NetworkServiceEndpoint{Name: "nse"}}) - require.NoError(t, findErr) - list := registry.ReadNetworkServiceEndpointList(stream) - require.Len(t, list, expected) - } - - _, err = chain.Unregister(context.Background(), ®istry.NetworkServiceEndpoint{Name: "nse-1"}) - require.NoError(t, err) - checkLen(1) - _, err = chain.Unregister(context.Background(), ®istry.NetworkServiceEndpoint{Name: "nse"}) - require.NoError(t, err) - checkLen(1) - _, err = chain.Unregister(context.Background(), ®istry.NetworkServiceEndpoint{Name: "nse-1@domain2"}) - require.NoError(t, err) - checkLen(1) - _, err = chain.Unregister(context.Background(), ®istry.NetworkServiceEndpoint{Name: "nse-1@domain1"}) - require.NoError(t, err) - checkLen(0) - - closeServer() - - require.Eventually(t, func() bool { - runtime.GC() - return goleak.Find() != nil - }, time.Second, time.Microsecond*100) -} - -func TestNewProxyNetworkServiceEndpointRegistryServer_Find(t *testing.T) { - m := memory.NewNetworkServiceEndpointRegistryServer() - _, err := m.Register(context.Background(), ®istry.NetworkServiceEndpoint{Name: "nse-1@domain1"}) - require.Nil(t, err) - u, closeServer := startNSEServer(t, m) - - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - chain := testingNSEServerChain(ctx, u) - - checkLen := func(nseName string, expected int) { - client := adapters.NetworkServiceEndpointServerToClient(chain) - stream, err := client.Find(context.Background(), ®istry.NetworkServiceEndpointQuery{NetworkServiceEndpoint: ®istry.NetworkServiceEndpoint{Name: nseName}}) - require.NoError(t, err) - list := registry.ReadNetworkServiceEndpointList(stream) - require.Len(t, list, expected) - } - - checkLen("nse", 0) - checkLen("nse-1@domain1", 1) - - closeServer() - require.Eventually(t, func() bool { - runtime.GC() - return true - }, time.Second, time.Microsecond*100) -} diff --git a/pkg/registry/common/recvfd/server_test.go b/pkg/registry/common/recvfd/server_test.go index d70007991..8e316cc46 100644 --- a/pkg/registry/common/recvfd/server_test.go +++ b/pkg/registry/common/recvfd/server_test.go @@ -35,13 +35,16 @@ import ( "google.golang.org/grpc/credentials/insecure" registryserver "github.com/networkservicemesh/sdk/pkg/registry" + "github.com/networkservicemesh/sdk/pkg/registry/common/begin" + "github.com/networkservicemesh/sdk/pkg/registry/common/clientconn" + "github.com/networkservicemesh/sdk/pkg/registry/common/clienturl" "github.com/networkservicemesh/sdk/pkg/registry/common/connect" + "github.com/networkservicemesh/sdk/pkg/registry/common/dial" "github.com/networkservicemesh/sdk/pkg/registry/common/memory" registryrecvfd "github.com/networkservicemesh/sdk/pkg/registry/common/recvfd" "github.com/networkservicemesh/sdk/pkg/registry/common/refresh" "github.com/networkservicemesh/sdk/pkg/registry/common/sendfd" - registryserialize "github.com/networkservicemesh/sdk/pkg/registry/common/serialize" - registrychain "github.com/networkservicemesh/sdk/pkg/registry/core/chain" + "github.com/networkservicemesh/sdk/pkg/registry/core/chain" "github.com/networkservicemesh/sdk/pkg/registry/utils/checks/checkcontext" "github.com/networkservicemesh/sdk/pkg/tools/grpcfdutils" "github.com/networkservicemesh/sdk/pkg/tools/grpcutils" @@ -71,15 +74,15 @@ func TestNseRecvfdServerClosesFile(t *testing.T) { var ctx, cancel = context.WithTimeout(context.Background(), 1*time.Second) defer cancel() - var nsRegistry = registrychain.NewNetworkServiceRegistryServer( - registryserialize.NewNetworkServiceRegistryServer(), + var nsRegistry = chain.NewNetworkServiceRegistryServer( + begin.NewNetworkServiceRegistryServer(), memory.NewNetworkServiceRegistryServer(), ) var onFileClosedCallbacks = make(map[string]func()) - var nseRegistry = registrychain.NewNetworkServiceEndpointRegistryServer( - registryserialize.NewNetworkServiceEndpointRegistryServer(), + var nseRegistry = chain.NewNetworkServiceEndpointRegistryServer( + begin.NewNetworkServiceEndpointRegistryServer(), checkcontext.NewNSEServer(t, func(t *testing.T, c context.Context) { err := grpcfdutils.InjectOnFileReceivedCallback(c, func(inodeURLStr string, file *os.File) { runtime.SetFinalizer(file, func(file *os.File) { @@ -113,14 +116,21 @@ func TestNseRecvfdServerClosesFile(t *testing.T) { sandbox.WithInsecureStreamRPCCredentials(), } - var nseClient = registrychain.NewNetworkServiceEndpointRegistryClient( - registryserialize.NewNetworkServiceEndpointRegistryClient(), + var nseClient = chain.NewNetworkServiceEndpointRegistryClient( + begin.NewNetworkServiceEndpointRegistryClient(), refresh.NewNetworkServiceEndpointRegistryClient(ctx), - connect.NewNetworkServiceEndpointRegistryClient(ctx, regURL, - connect.WithNSEAdditionalFunctionality( - sendfd.NewNetworkServiceEndpointRegistryClient()), - connect.WithDialOptions(dialOptions...), - )) + + chain.NewNetworkServiceEndpointRegistryClient( + clienturl.NewNetworkServiceEndpointRegistryClient(regURL), + clientconn.NewNetworkServiceEndpointRegistryClient(), + dial.NewNetworkServiceEndpointRegistryClient(ctx, + dial.WithDialOptions(dialOptions...), + dial.WithDialTimeout(time.Second), + ), + sendfd.NewNetworkServiceEndpointRegistryClient(), + connect.NewNetworkServiceEndpointRegistryClient(), + ), + ) startServer(ctx, t, registryserver.NewServer(nsRegistry, nseRegistry), regURL) diff --git a/pkg/registry/common/refresh/doc.go b/pkg/registry/common/refresh/doc.go index 85e0b1d57..90b420f7a 100644 --- a/pkg/registry/common/refresh/doc.go +++ b/pkg/registry/common/refresh/doc.go @@ -1,4 +1,4 @@ -// Copyright (c) 2020 Doc.ai and/or its affiliates. +// Copyright (c) 2022 Cisco and/or its affiliates. // // SPDX-License-Identifier: Apache-2.0 // diff --git a/pkg/registry/common/refresh/gen.go b/pkg/registry/common/refresh/gen.go index 9b3d88d56..1ed4bfa1a 100644 --- a/pkg/registry/common/refresh/gen.go +++ b/pkg/registry/common/refresh/gen.go @@ -1,4 +1,4 @@ -// Copyright (c) 2020 Doc.ai and/or its affiliates. +// Copyright (c) 2022 Cisco and/or its affiliates. // // SPDX-License-Identifier: Apache-2.0 // diff --git a/pkg/registry/common/refresh/nse_registry_client.go b/pkg/registry/common/refresh/nse_registry_client.go index 34a797cec..ce154c377 100644 --- a/pkg/registry/common/refresh/nse_registry_client.go +++ b/pkg/registry/common/refresh/nse_registry_client.go @@ -1,4 +1,4 @@ -// Copyright (c) 2020-2021 Doc.ai and/or its affiliates. +// Copyright (c) 2022 Cisco and/or its affiliates. // // SPDX-License-Identifier: Apache-2.0 // @@ -18,25 +18,20 @@ package refresh import ( "context" - "time" "github.com/golang/protobuf/ptypes/empty" - "github.com/pkg/errors" "google.golang.org/grpc" - "google.golang.org/protobuf/types/known/timestamppb" "github.com/networkservicemesh/api/pkg/api/registry" + "github.com/networkservicemesh/sdk/pkg/registry/common/begin" "github.com/networkservicemesh/sdk/pkg/registry/core/next" "github.com/networkservicemesh/sdk/pkg/tools/clock" - "github.com/networkservicemesh/sdk/pkg/tools/log" - "github.com/networkservicemesh/sdk/pkg/tools/postpone" - "github.com/networkservicemesh/sdk/pkg/tools/serializectx" ) type refreshNSEClient struct { - ctx context.Context - nseCancels cancelsMap + ctx context.Context + cancelsMap } // NewNetworkServiceEndpointRegistryClient creates new NetworkServiceEndpointRegistryClient that will refresh expiration @@ -48,103 +43,38 @@ func NewNetworkServiceEndpointRegistryClient(ctx context.Context) registry.Netwo } func (c *refreshNSEClient) Register(ctx context.Context, nse *registry.NetworkServiceEndpoint, opts ...grpc.CallOption) (*registry.NetworkServiceEndpoint, error) { - clockTime := clock.FromContext(ctx) - logger := log.FromContext(ctx).WithField("refreshNSEClient", "Register") + var factory = begin.FromContext(ctx) - var expirationDuration time.Duration - if nse.ExpirationTime != nil { - expirationDuration = clockTime.Until(nse.ExpirationTime.AsTime().Local()) - } - - cancel, ok := c.nseCancels.LoadAndDelete(nse.Name) - if ok { - cancel() - } + resp, err := next.NetworkServiceEndpointRegistryClient(ctx).Register(ctx, nse, opts...) - postponeCtxFunc := postpone.ContextWithValues(ctx) - - reg, err := next.NetworkServiceEndpointRegistryClient(ctx).Register(ctx, nse, opts...) if err != nil { return nil, err } - if reg.ExpirationTime != nil { - refreshNSE := nse.Clone() - refreshNSE.ExpirationTime = reg.ExpirationTime - refreshNSE.InitialRegistrationTime = reg.InitialRegistrationTime - - cancel, err = c.startRefresh(ctx, refreshNSE, expirationDuration) - if err != nil { - unregisterCtx, cancelUnregister := postponeCtxFunc() - defer cancelUnregister() + refreshCtx, cancel := context.WithCancel(c.ctx) - if _, unregisterErr := next.NetworkServiceEndpointRegistryServer(ctx).Unregister(unregisterCtx, reg); unregisterErr != nil { - logger.Errorf("failed to unregister endpoint on error: %s %s", reg.Name, unregisterErr.Error()) - } - return nil, err - } - - c.nseCancels.Store(refreshNSE.Name, cancel) + if cancelPrevious, ok := c.LoadAndDelete(nse.Name); ok { + cancelPrevious() } - return reg, err -} + c.Store(nse.Name, cancel) -func (c *refreshNSEClient) startRefresh(ctx context.Context, nse *registry.NetworkServiceEndpoint, expirationDuration time.Duration) (context.CancelFunc, error) { - clockTime := clock.FromContext(ctx) - logger := log.FromContext(ctx).WithField("refreshNSEClient", "startRefresh") + var clockTime = clock.FromContext(ctx) - executor := serializectx.GetExecutor(ctx, nse.Name) - if executor == nil { - return nil, errors.Errorf("failed to get executor from context") - } + if resp.GetExpirationTime() != nil { + var refreshCh = clockTime.After(2 * clockTime.Until(resp.GetExpirationTime().AsTime().Local()) / 3) - expirationTime := nse.ExpirationTime.AsTime().Local() - refreshCh := clockTime.After(2 * clockTime.Until(expirationTime) / 3) - - refreshCtx, refreshCancel := context.WithCancel(c.ctx) - go func() { - defer refreshCancel() - for { + go func() { select { case <-refreshCtx.Done(): return case <-refreshCh: - <-executor.AsyncExec(func() { - if refreshCtx.Err() != nil { - return - } - - var registerCtx context.Context - var cancel context.CancelFunc - if expirationDuration != 0 { - nse.ExpirationTime = timestamppb.New(clockTime.Now().Add(expirationDuration)) - registerCtx, cancel = clockTime.WithTimeout(refreshCtx, expirationDuration) - } else { - nse.ExpirationTime = nil - registerCtx, cancel = context.WithCancel(refreshCtx) - } - defer cancel() - - reg, err := next.NetworkServiceEndpointRegistryClient(ctx).Register(registerCtx, nse.Clone()) - if err != nil { - logger.Errorf("failed to refresh endpoint registration: %s %s", nse.Name, err.Error()) - return - } - - if reg.ExpirationTime == nil { - logger.Warnf("received nil expiration time: %s", nse.Name) - return - } - - expirationTime = reg.ExpirationTime.AsTime().Local() - refreshCh = clockTime.After(2 * clockTime.Until(expirationTime) / 3) - }) + <-factory.Register(begin.CancelContext(refreshCtx)) } - } - }() + }() + } - return refreshCancel, nil + return resp, err } func (c *refreshNSEClient) Find(ctx context.Context, query *registry.NetworkServiceEndpointQuery, opts ...grpc.CallOption) (registry.NetworkServiceEndpointRegistry_FindClient, error) { @@ -152,8 +82,8 @@ func (c *refreshNSEClient) Find(ctx context.Context, query *registry.NetworkServ } func (c *refreshNSEClient) Unregister(ctx context.Context, nse *registry.NetworkServiceEndpoint, opts ...grpc.CallOption) (*empty.Empty, error) { - if cancel, ok := c.nseCancels.LoadAndDelete(nse.Name); ok { - cancel() + if v, ok := c.LoadAndDelete(nse.GetName()); ok { + v() } return next.NetworkServiceEndpointRegistryClient(ctx).Unregister(ctx, nse, opts...) } diff --git a/pkg/registry/common/refresh/nse_registry_client_test.go b/pkg/registry/common/refresh/nse_registry_client_test.go index 43236bce0..d81f2cae7 100644 --- a/pkg/registry/common/refresh/nse_registry_client_test.go +++ b/pkg/registry/common/refresh/nse_registry_client_test.go @@ -1,4 +1,4 @@ -// Copyright (c) 2020-2021 Doc.ai and/or its affiliates. +// Copyright (c) 2022 Cisco and/or its affiliates. // // SPDX-License-Identifier: Apache-2.0 // @@ -18,6 +18,7 @@ package refresh_test import ( "context" + "fmt" "sync/atomic" "testing" "time" @@ -30,9 +31,9 @@ import ( "github.com/networkservicemesh/api/pkg/api/registry" + "github.com/networkservicemesh/sdk/pkg/registry/common/begin" "github.com/networkservicemesh/sdk/pkg/registry/common/null" "github.com/networkservicemesh/sdk/pkg/registry/common/refresh" - "github.com/networkservicemesh/sdk/pkg/registry/common/serialize" "github.com/networkservicemesh/sdk/pkg/registry/core/next" "github.com/networkservicemesh/sdk/pkg/registry/utils/checks/checknse" "github.com/networkservicemesh/sdk/pkg/tools/clock" @@ -74,7 +75,7 @@ func Test_NetworkServiceEndpointRefreshClient_ShouldWorkCorrectlyWithFloatingSce var registerCount int32 client := next.NewNetworkServiceEndpointRegistryClient( - serialize.NewNetworkServiceEndpointRegistryClient(), + begin.NewNetworkServiceEndpointRegistryClient(), refresh.NewNetworkServiceEndpointRegistryClient(ctx), &injectNSERegisterClient{ NetworkServiceEndpointRegistryClient: null.NewNetworkServiceEndpointRegistryClient(), @@ -113,7 +114,7 @@ func TestNewNetworkServiceEndpointRegistryClient(t *testing.T) { countClient := new(requestCountClient) client := next.NewNetworkServiceEndpointRegistryClient( - serialize.NewNetworkServiceEndpointRegistryClient(), + begin.NewNetworkServiceEndpointRegistryClient(), refresh.NewNetworkServiceEndpointRegistryClient(ctx), countClient, ) @@ -141,7 +142,7 @@ func Test_RefreshNSEClient_CalledRegisterTwice(t *testing.T) { countClient := new(requestCountClient) client := next.NewNetworkServiceEndpointRegistryClient( - serialize.NewNetworkServiceEndpointRegistryClient(), + begin.NewNetworkServiceEndpointRegistryClient(), refresh.NewNetworkServiceEndpointRegistryClient(ctx), countClient, ) @@ -161,6 +162,64 @@ func Test_RefreshNSEClient_CalledRegisterTwice(t *testing.T) { require.NoError(t, err) } +func Test_RefreshNSEClient_StopsRefreshOnUnregister(t *testing.T) { + t.Cleanup(func() { goleak.VerifyNone(t) }) + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + goleak.VerifyNone(t) + + clockMock := clockmock.New(ctx) + ctx = clock.WithClock(ctx, clockMock) + + ignoreClockMockGoroutine := goleak.IgnoreCurrent() + + countClient := new(requestCountClient) + client := next.NewNetworkServiceEndpointRegistryClient( + begin.NewNetworkServiceEndpointRegistryClient(), + refresh.NewNetworkServiceEndpointRegistryClient(ctx), + countClient, + ) + + const registerCount = 100 + + var regs []*registry.NetworkServiceEndpoint + + for i := 0; i < registerCount; i++ { + regs = append(regs, testNSE(clockMock)) + regs[i].Name = fmt.Sprint(i) + resp, err := client.Register(ctx, regs[i]) + require.NoError(t, err) + regs[i] = resp + regs[i].ExpirationTime = timestamppb.New(clockMock.Now().Add(expireTimeout)) + } + + clockMock.Add(expireTimeout / 3 * 2) + + require.Eventually(t, func() bool { + return atomic.LoadInt32(&countClient.requestCount) >= 2*int32(len(regs)) + }, testWait, testTick) + + for i := 0; i < registerCount; i++ { + _, err := client.Unregister(ctx, regs[i]) + require.NoError(t, err) + } + + goleak.VerifyNone(t, ignoreClockMockGoroutine) + + for i := 0; i < 5; i++ { + clockMock.Add(expireTimeout / 3 * 2) + + require.Never(t, func() bool { + return atomic.LoadInt32(&countClient.requestCount) > registerCount*3 + }, testWait, testTick) + + // Wait for the Refresh to fully happen + time.Sleep(testWait) + } +} + func Test_RefreshNSEClient_SetsCorrectExpireTime(t *testing.T) { t.Cleanup(func() { goleak.VerifyNone(t) }) @@ -172,12 +231,12 @@ func Test_RefreshNSEClient_SetsCorrectExpireTime(t *testing.T) { countClient := new(requestCountClient) client := next.NewNetworkServiceEndpointRegistryClient( - serialize.NewNetworkServiceEndpointRegistryClient(), + begin.NewNetworkServiceEndpointRegistryClient(), refresh.NewNetworkServiceEndpointRegistryClient(ctx), + countClient, checknse.NewClient(t, func(t *testing.T, nse *registry.NetworkServiceEndpoint) { - require.Equal(t, expireTimeout, clockMock.Until(nse.ExpirationTime.AsTime().Local())) + nse.ExpirationTime = testNSE(clockMock).ExpirationTime }), - countClient, ) reg, err := client.Register(ctx, testNSE(clockMock)) @@ -214,7 +273,7 @@ func Test_RefreshNSEClient_CorrectInitialRegTime(t *testing.T) { var registerCount int32 client := next.NewNetworkServiceEndpointRegistryClient( - serialize.NewNetworkServiceEndpointRegistryClient(), + begin.NewNetworkServiceEndpointRegistryClient(), refresh.NewNetworkServiceEndpointRegistryClient(ctx), &injectNSERegisterClient{ NetworkServiceEndpointRegistryClient: null.NewNetworkServiceEndpointRegistryClient(), diff --git a/pkg/registry/common/retry/ns_client.go b/pkg/registry/common/retry/ns_client.go index fa9d46b3c..14454d016 100644 --- a/pkg/registry/common/retry/ns_client.go +++ b/pkg/registry/common/retry/ns_client.go @@ -1,4 +1,4 @@ -// Copyright (c) 2021 Cisco and/or its affiliates. +// Copyright (c) 2021-2022 Cisco and/or its affiliates. // // SPDX-License-Identifier: Apache-2.0 // @@ -32,10 +32,11 @@ import ( type retryNSClient struct { interval time.Duration tryTimeout time.Duration + chainCtx context.Context } // NewNetworkServiceRegistryClient - returns a retry chain element -func NewNetworkServiceRegistryClient(opts ...Option) registry.NetworkServiceRegistryClient { +func NewNetworkServiceRegistryClient(ctx context.Context, opts ...Option) registry.NetworkServiceRegistryClient { clientOpts := &options{ interval: time.Millisecond * 200, tryTimeout: time.Second * 15, @@ -46,6 +47,7 @@ func NewNetworkServiceRegistryClient(opts ...Option) registry.NetworkServiceRegi } return &retryNSClient{ + chainCtx: ctx, interval: clientOpts.interval, tryTimeout: clientOpts.tryTimeout, } @@ -55,7 +57,7 @@ func (r *retryNSClient) Register(ctx context.Context, in *registry.NetworkServic logger := log.FromContext(ctx).WithField("retryNSClient", "Register") c := clock.FromContext(ctx) - for ctx.Err() == nil { + for ctx.Err() == nil && r.chainCtx.Err() == nil { registerCtx, cancel := c.WithTimeout(ctx, r.tryTimeout) resp, err := next.NetworkServiceRegistryClient(registerCtx).Register(registerCtx, in, opts...) cancel() @@ -64,6 +66,8 @@ func (r *retryNSClient) Register(ctx context.Context, in *registry.NetworkServic logger.Errorf("try attempt has failed: %v", err.Error()) select { + case <-r.chainCtx.Done(): + return nil, r.chainCtx.Err() case <-ctx.Done(): return nil, ctx.Err() case <-c.After(r.interval): @@ -74,6 +78,10 @@ func (r *retryNSClient) Register(ctx context.Context, in *registry.NetworkServic return resp, err } + if r.chainCtx.Err() != nil { + return nil, r.chainCtx.Err() + } + return nil, ctx.Err() } @@ -81,7 +89,7 @@ func (r *retryNSClient) Find(ctx context.Context, query *registry.NetworkService logger := log.FromContext(ctx).WithField("retryNSClient", "Find") c := clock.FromContext(ctx) - for ctx.Err() == nil { + for ctx.Err() == nil && r.chainCtx.Err() == nil { stream, err := next.NetworkServiceRegistryClient(ctx).Find(ctx, query, opts...) if err != nil { @@ -93,6 +101,10 @@ func (r *retryNSClient) Find(ctx context.Context, query *registry.NetworkService return stream, err } + if r.chainCtx.Err() != nil { + return nil, r.chainCtx.Err() + } + return nil, ctx.Err() } @@ -100,7 +112,7 @@ func (r *retryNSClient) Unregister(ctx context.Context, in *registry.NetworkServ logger := log.FromContext(ctx).WithField("retryNSClient", "Unregister") c := clock.FromContext(ctx) - for ctx.Err() == nil { + for ctx.Err() == nil && r.chainCtx.Err() == nil { closeCtx, cancel := c.WithTimeout(ctx, r.tryTimeout) resp, err := next.NetworkServiceRegistryClient(closeCtx).Unregister(closeCtx, in, opts...) cancel() @@ -109,6 +121,8 @@ func (r *retryNSClient) Unregister(ctx context.Context, in *registry.NetworkServ logger.Errorf("try attempt has failed: %v", err.Error()) select { + case <-r.chainCtx.Done(): + return nil, r.chainCtx.Err() case <-ctx.Done(): return nil, ctx.Err() case <-c.After(r.interval): @@ -118,6 +132,9 @@ func (r *retryNSClient) Unregister(ctx context.Context, in *registry.NetworkServ return resp, err } + if r.chainCtx.Err() != nil { + return nil, r.chainCtx.Err() + } return nil, ctx.Err() } diff --git a/pkg/registry/common/retry/ns_client_test.go b/pkg/registry/common/retry/ns_client_test.go index de116bed5..991a38e95 100644 --- a/pkg/registry/common/retry/ns_client_test.go +++ b/pkg/registry/common/retry/ns_client_test.go @@ -1,4 +1,4 @@ -// Copyright (c) 2021 Cisco and/or its affiliates. +// Copyright (c) 2021-2022 Cisco and/or its affiliates. // // SPDX-License-Identifier: Apache-2.0 // @@ -41,6 +41,7 @@ func TestNSRetryClient_Register(t *testing.T) { var client = chain.NewNetworkServiceRegistryClient( retry.NewNetworkServiceRegistryClient( + context.Background(), retry.WithInterval(time.Millisecond*10), retry.WithTryTimeout(time.Second/30)), counter, @@ -66,7 +67,7 @@ func TestNSRetryClient_Register_ContextHasCorrectDeadline(t *testing.T) { expectedDeadline := clockMock.Now().Add(time.Hour) var client = chain.NewNetworkServiceRegistryClient( - retry.NewNetworkServiceRegistryClient(retry.WithTryTimeout(time.Hour)), + retry.NewNetworkServiceRegistryClient(context.Background(), retry.WithTryTimeout(time.Hour)), checkcontext.NewNSClient(t, func(t *testing.T, c context.Context) { v, ok := c.Deadline() require.True(t, ok) @@ -91,7 +92,7 @@ func TestNSRetryClient_Unregister_ContextHasCorrectDeadline(t *testing.T) { expectedDeadline := clockMock.Now().Add(time.Hour) var client = chain.NewNetworkServiceRegistryClient( - retry.NewNetworkServiceRegistryClient(retry.WithTryTimeout(time.Hour)), + retry.NewNetworkServiceRegistryClient(context.Background(), retry.WithTryTimeout(time.Hour)), checkcontext.NewNSClient(t, func(t *testing.T, c context.Context) { v, ok := c.Deadline() require.True(t, ok) @@ -110,6 +111,7 @@ func TestNSRetryClient_Unregister(t *testing.T) { var client = chain.NewNetworkServiceRegistryClient( retry.NewNetworkServiceRegistryClient( + context.Background(), retry.WithInterval(time.Millisecond*10), retry.WithTryTimeout(time.Second/30)), counter, @@ -129,6 +131,7 @@ func TestNSRetryClient_Find(t *testing.T) { var client = chain.NewNetworkServiceRegistryClient( retry.NewNetworkServiceRegistryClient( + context.Background(), retry.WithInterval(time.Millisecond*10), retry.WithTryTimeout(time.Second/30)), counter, @@ -148,6 +151,7 @@ func TestNSRetryClient_RegisterCompletesOnParentContextTimeout(t *testing.T) { var client = chain.NewNetworkServiceRegistryClient( retry.NewNetworkServiceRegistryClient( + context.Background(), retry.WithInterval(time.Millisecond*10), retry.WithTryTimeout(time.Second/30)), counter, @@ -170,6 +174,7 @@ func TestNSRetryClient_UnregisterCompletesOnParentContextTimeout(t *testing.T) { var client = chain.NewNetworkServiceRegistryClient( retry.NewNetworkServiceRegistryClient( + context.Background(), retry.WithInterval(time.Millisecond*10), retry.WithTryTimeout(time.Second/30)), counter, @@ -192,6 +197,7 @@ func TestNSRetryClient_FindCompletesOnParentContextTimeout(t *testing.T) { var client = chain.NewNetworkServiceRegistryClient( retry.NewNetworkServiceRegistryClient( + context.Background(), retry.WithInterval(time.Millisecond*10), retry.WithTryTimeout(time.Second/30)), counter, diff --git a/pkg/registry/common/retry/nse_client.go b/pkg/registry/common/retry/nse_client.go index 22cf138a0..20e331e8c 100644 --- a/pkg/registry/common/retry/nse_client.go +++ b/pkg/registry/common/retry/nse_client.go @@ -1,4 +1,4 @@ -// Copyright (c) 2021 Cisco and/or its affiliates. +// Copyright (c) 2021-2022 Cisco and/or its affiliates. // // SPDX-License-Identifier: Apache-2.0 // @@ -32,10 +32,11 @@ import ( type retryNSEClient struct { interval time.Duration tryTimeout time.Duration + chainCtx context.Context } // NewNetworkServiceEndpointRegistryClient - returns a retry chain element -func NewNetworkServiceEndpointRegistryClient(opts ...Option) registry.NetworkServiceEndpointRegistryClient { +func NewNetworkServiceEndpointRegistryClient(ctx context.Context, opts ...Option) registry.NetworkServiceEndpointRegistryClient { clientOpts := &options{ interval: time.Millisecond * 200, tryTimeout: time.Second * 15, @@ -48,6 +49,7 @@ func NewNetworkServiceEndpointRegistryClient(opts ...Option) registry.NetworkSer return &retryNSEClient{ interval: clientOpts.interval, tryTimeout: clientOpts.tryTimeout, + chainCtx: ctx, } } @@ -64,6 +66,8 @@ func (r *retryNSEClient) Register(ctx context.Context, nse *registry.NetworkServ logger.Errorf("try attempt has failed: %v", err.Error()) select { + case <-r.chainCtx.Done(): + return nil, err case <-ctx.Done(): return nil, ctx.Err() case <-c.After(r.interval): @@ -85,7 +89,7 @@ func (r *retryNSEClient) Find(ctx context.Context, query *registry.NetworkServic if query != nil { cloneQuery.NetworkServiceEndpoint = query.NetworkServiceEndpoint.Clone() } - for ctx.Err() == nil { + for ctx.Err() == nil && r.chainCtx.Err() == nil { stream, err := next.NetworkServiceEndpointRegistryClient(ctx).Find(ctx, cloneQuery, opts...) if err != nil { @@ -97,6 +101,10 @@ func (r *retryNSEClient) Find(ctx context.Context, query *registry.NetworkServic return stream, err } + if r.chainCtx.Err() != nil { + return nil, ctx.Err() + } + return nil, ctx.Err() } @@ -113,6 +121,8 @@ func (r *retryNSEClient) Unregister(ctx context.Context, in *registry.NetworkSer logger.Errorf("try attempt has failed: %v", err.Error()) select { + case <-r.chainCtx.Done(): + return nil, err case <-ctx.Done(): return nil, ctx.Err() case <-c.After(r.interval): diff --git a/pkg/registry/common/retry/nse_client_test.go b/pkg/registry/common/retry/nse_client_test.go index 7a5c8be57..79e46a098 100644 --- a/pkg/registry/common/retry/nse_client_test.go +++ b/pkg/registry/common/retry/nse_client_test.go @@ -1,4 +1,4 @@ -// Copyright (c) 2021 Cisco and/or its affiliates. +// Copyright (c) 2021-2022 Cisco and/or its affiliates. // // SPDX-License-Identifier: Apache-2.0 // @@ -41,6 +41,7 @@ func TestNSERetryClient_Register(t *testing.T) { var client = chain.NewNetworkServiceEndpointRegistryClient( retry.NewNetworkServiceEndpointRegistryClient( + context.Background(), retry.WithInterval(time.Millisecond*10), retry.WithTryTimeout(time.Second/30)), counter, @@ -66,7 +67,7 @@ func TestNSERetryClient_Register_ContextHasCorrectDeadline(t *testing.T) { expectedDeadline := clockMock.Now().Add(time.Hour) var client = chain.NewNetworkServiceEndpointRegistryClient( - retry.NewNetworkServiceEndpointRegistryClient(retry.WithTryTimeout(time.Hour)), + retry.NewNetworkServiceEndpointRegistryClient(context.Background(), retry.WithTryTimeout(time.Hour)), checkcontext.NewNSEClient(t, func(t *testing.T, c context.Context) { v, ok := c.Deadline() require.True(t, ok) @@ -91,7 +92,7 @@ func TestNSERetryClient_Unregister_ContextHasCorrectDeadline(t *testing.T) { expectedDeadline := clockMock.Now().Add(time.Hour) var client = chain.NewNetworkServiceEndpointRegistryClient( - retry.NewNetworkServiceEndpointRegistryClient(retry.WithTryTimeout(time.Hour)), + retry.NewNetworkServiceEndpointRegistryClient(context.Background(), retry.WithTryTimeout(time.Hour)), checkcontext.NewNSEClient(t, func(t *testing.T, c context.Context) { v, ok := c.Deadline() require.True(t, ok) @@ -110,6 +111,7 @@ func TestNSERetryClient_Unregister(t *testing.T) { var client = chain.NewNetworkServiceEndpointRegistryClient( retry.NewNetworkServiceEndpointRegistryClient( + context.Background(), retry.WithInterval(time.Millisecond*10), retry.WithTryTimeout(time.Second/30)), counter, @@ -129,6 +131,7 @@ func TestNSERetryClient_Find(t *testing.T) { var client = chain.NewNetworkServiceEndpointRegistryClient( retry.NewNetworkServiceEndpointRegistryClient( + context.Background(), retry.WithInterval(time.Millisecond*10), retry.WithTryTimeout(time.Second/30)), counter, @@ -148,6 +151,7 @@ func TestNSERetryClient_RegisterCompletesOnParentContextTimeout(t *testing.T) { var client = chain.NewNetworkServiceEndpointRegistryClient( retry.NewNetworkServiceEndpointRegistryClient( + context.Background(), retry.WithInterval(time.Millisecond*10), retry.WithTryTimeout(time.Second/30)), counter, @@ -170,6 +174,7 @@ func TestNSERetryClient_UnregisterCompletesOnParentContextTimeout(t *testing.T) var client = chain.NewNetworkServiceEndpointRegistryClient( retry.NewNetworkServiceEndpointRegistryClient( + context.Background(), retry.WithInterval(time.Millisecond*10), retry.WithTryTimeout(time.Second/30)), counter, @@ -192,6 +197,7 @@ func TestNSERetryClient_FindCompletesOnParentContextTimeout(t *testing.T) { var client = chain.NewNetworkServiceEndpointRegistryClient( retry.NewNetworkServiceEndpointRegistryClient( + context.Background(), retry.WithInterval(time.Millisecond*10), retry.WithTryTimeout(time.Second/30)), counter, diff --git a/pkg/registry/common/serialize/README.md b/pkg/registry/common/serialize/README.md deleted file mode 100644 index 778802428..000000000 --- a/pkg/registry/common/serialize/README.md +++ /dev/null @@ -1,11 +0,0 @@ -# Functional requirements - -`Register`, `Unregister` events for the same `NetworkService.Name`, `NetworkServiceEndpoint.Name` should be executed in -registry chain serially. - -# Implementation - -## serializeNSServer, serializeNSClient, serializeNSEServer, serializeNSEclient - -It is just the same as serialize chain elements for the network service chain. Please see [serialize](https://github.com/networkservicemesh/sdk/blob/master/pkg/networkservice/common/serialize) -for more details. diff --git a/pkg/registry/common/serialize/doc.go b/pkg/registry/common/serialize/doc.go deleted file mode 100644 index 46a14ffe4..000000000 --- a/pkg/registry/common/serialize/doc.go +++ /dev/null @@ -1,18 +0,0 @@ -// Copyright (c) 2020-2021 Doc.ai and/or its affiliates. -// -// SPDX-License-Identifier: Apache-2.0 -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at: -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Package serialize provides NSE, NS registry chain elements for serial Register, Unregister event processing -package serialize diff --git a/pkg/registry/common/serialize/ns_client.go b/pkg/registry/common/serialize/ns_client.go deleted file mode 100644 index 36c3ca838..000000000 --- a/pkg/registry/common/serialize/ns_client.go +++ /dev/null @@ -1,58 +0,0 @@ -// Copyright (c) 2020-2021 Doc.ai and/or its affiliates. -// -// SPDX-License-Identifier: Apache-2.0 -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at: -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package serialize - -import ( - "context" - - "github.com/golang/protobuf/ptypes/empty" - "google.golang.org/grpc" - - "github.com/networkservicemesh/api/pkg/api/registry" - - "github.com/networkservicemesh/sdk/pkg/registry/core/next" - "github.com/networkservicemesh/sdk/pkg/tools/multiexecutor" - "github.com/networkservicemesh/sdk/pkg/tools/serializectx" -) - -type serializeNSClient struct { - executor multiexecutor.MultiExecutor -} - -// NewNetworkServiceRegistryClient returns a new serialize NS registry client chain element -func NewNetworkServiceRegistryClient() registry.NetworkServiceRegistryClient { - return new(serializeNSClient) -} - -func (c *serializeNSClient) Register(ctx context.Context, ns *registry.NetworkService, opts ...grpc.CallOption) (reg *registry.NetworkService, err error) { - <-c.executor.AsyncExec(ns.Name, func() { - registerCtx := serializectx.WithMultiExecutor(ctx, &c.executor) - reg, err = next.NetworkServiceRegistryClient(ctx).Register(registerCtx, ns, opts...) - }) - return reg, err -} - -func (c *serializeNSClient) Find(ctx context.Context, query *registry.NetworkServiceQuery, opts ...grpc.CallOption) (registry.NetworkServiceRegistry_FindClient, error) { - return next.NetworkServiceRegistryClient(ctx).Find(ctx, query, opts...) -} - -func (c *serializeNSClient) Unregister(ctx context.Context, ns *registry.NetworkService, opts ...grpc.CallOption) (_ *empty.Empty, err error) { - <-c.executor.AsyncExec(ns.Name, func() { - _, err = next.NetworkServiceRegistryClient(ctx).Unregister(ctx, ns, opts...) - }) - return new(empty.Empty), err -} diff --git a/pkg/registry/common/serialize/ns_server.go b/pkg/registry/common/serialize/ns_server.go deleted file mode 100644 index aab9773ee..000000000 --- a/pkg/registry/common/serialize/ns_server.go +++ /dev/null @@ -1,57 +0,0 @@ -// Copyright (c) 2020-2021 Doc.ai and/or its affiliates. -// -// SPDX-License-Identifier: Apache-2.0 -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at: -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package serialize - -import ( - "context" - - "github.com/golang/protobuf/ptypes/empty" - - "github.com/networkservicemesh/api/pkg/api/registry" - - "github.com/networkservicemesh/sdk/pkg/registry/core/next" - "github.com/networkservicemesh/sdk/pkg/tools/multiexecutor" - "github.com/networkservicemesh/sdk/pkg/tools/serializectx" -) - -type serializeNSServer struct { - executor multiexecutor.MultiExecutor -} - -// NewNetworkServiceRegistryServer returns a new serialize NS registry server chain element -func NewNetworkServiceRegistryServer() registry.NetworkServiceRegistryServer { - return new(serializeNSServer) -} - -func (s *serializeNSServer) Register(ctx context.Context, ns *registry.NetworkService) (reg *registry.NetworkService, err error) { - <-s.executor.AsyncExec(ns.Name, func() { - registerCtx := serializectx.WithMultiExecutor(ctx, &s.executor) - reg, err = next.NetworkServiceRegistryServer(ctx).Register(registerCtx, ns) - }) - return reg, err -} - -func (s *serializeNSServer) Find(query *registry.NetworkServiceQuery, server registry.NetworkServiceRegistry_FindServer) error { - return next.NetworkServiceRegistryServer(server.Context()).Find(query, server) -} - -func (s *serializeNSServer) Unregister(ctx context.Context, ns *registry.NetworkService) (_ *empty.Empty, err error) { - <-s.executor.AsyncExec(ns.Name, func() { - _, err = next.NetworkServiceRegistryServer(ctx).Unregister(ctx, ns) - }) - return new(empty.Empty), err -} diff --git a/pkg/registry/common/serialize/nse_client.go b/pkg/registry/common/serialize/nse_client.go deleted file mode 100644 index 4e8d225fd..000000000 --- a/pkg/registry/common/serialize/nse_client.go +++ /dev/null @@ -1,58 +0,0 @@ -// Copyright (c) 2020-2021 Doc.ai and/or its affiliates. -// -// SPDX-License-Identifier: Apache-2.0 -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at: -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package serialize - -import ( - "context" - - "github.com/golang/protobuf/ptypes/empty" - "google.golang.org/grpc" - - "github.com/networkservicemesh/api/pkg/api/registry" - - "github.com/networkservicemesh/sdk/pkg/registry/core/next" - "github.com/networkservicemesh/sdk/pkg/tools/multiexecutor" - "github.com/networkservicemesh/sdk/pkg/tools/serializectx" -) - -type serializeNSEClient struct { - executor multiexecutor.MultiExecutor -} - -// NewNetworkServiceEndpointRegistryClient returns a new serialize NSE registry client chain element -func NewNetworkServiceEndpointRegistryClient() registry.NetworkServiceEndpointRegistryClient { - return new(serializeNSEClient) -} - -func (c *serializeNSEClient) Register(ctx context.Context, nse *registry.NetworkServiceEndpoint, opts ...grpc.CallOption) (reg *registry.NetworkServiceEndpoint, err error) { - <-c.executor.AsyncExec(nse.Name, func() { - registerCtx := serializectx.WithMultiExecutor(ctx, &c.executor) - reg, err = next.NetworkServiceEndpointRegistryClient(ctx).Register(registerCtx, nse, opts...) - }) - return reg, err -} - -func (c *serializeNSEClient) Find(ctx context.Context, query *registry.NetworkServiceEndpointQuery, opts ...grpc.CallOption) (registry.NetworkServiceEndpointRegistry_FindClient, error) { - return next.NetworkServiceEndpointRegistryClient(ctx).Find(ctx, query, opts...) -} - -func (c *serializeNSEClient) Unregister(ctx context.Context, nse *registry.NetworkServiceEndpoint, opts ...grpc.CallOption) (_ *empty.Empty, err error) { - <-c.executor.AsyncExec(nse.Name, func() { - _, err = next.NetworkServiceEndpointRegistryClient(ctx).Unregister(ctx, nse, opts...) - }) - return new(empty.Empty), err -} diff --git a/pkg/registry/common/serialize/nse_server.go b/pkg/registry/common/serialize/nse_server.go deleted file mode 100644 index a69236a6c..000000000 --- a/pkg/registry/common/serialize/nse_server.go +++ /dev/null @@ -1,57 +0,0 @@ -// Copyright (c) 2020-2021 Doc.ai and/or its affiliates. -// -// SPDX-License-Identifier: Apache-2.0 -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at: -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package serialize - -import ( - "context" - - "github.com/golang/protobuf/ptypes/empty" - - "github.com/networkservicemesh/api/pkg/api/registry" - - "github.com/networkservicemesh/sdk/pkg/registry/core/next" - "github.com/networkservicemesh/sdk/pkg/tools/multiexecutor" - "github.com/networkservicemesh/sdk/pkg/tools/serializectx" -) - -type serializeNSEServer struct { - executor multiexecutor.MultiExecutor -} - -// NewNetworkServiceEndpointRegistryServer returns a new serialize NSE registry server chain element -func NewNetworkServiceEndpointRegistryServer() registry.NetworkServiceEndpointRegistryServer { - return new(serializeNSEServer) -} - -func (s *serializeNSEServer) Register(ctx context.Context, nse *registry.NetworkServiceEndpoint) (reg *registry.NetworkServiceEndpoint, err error) { - <-s.executor.AsyncExec(nse.Name, func() { - registerCtx := serializectx.WithMultiExecutor(ctx, &s.executor) - reg, err = next.NetworkServiceEndpointRegistryServer(ctx).Register(registerCtx, nse) - }) - return reg, err -} - -func (s *serializeNSEServer) Find(query *registry.NetworkServiceEndpointQuery, server registry.NetworkServiceEndpointRegistry_FindServer) error { - return next.NetworkServiceEndpointRegistryServer(server.Context()).Find(query, server) -} - -func (s *serializeNSEServer) Unregister(ctx context.Context, nse *registry.NetworkServiceEndpoint) (_ *empty.Empty, err error) { - <-s.executor.AsyncExec(nse.Name, func() { - _, err = next.NetworkServiceEndpointRegistryServer(ctx).Unregister(ctx, nse) - }) - return new(empty.Empty), err -} diff --git a/pkg/registry/common/serialize/nse_server_test.go b/pkg/registry/common/serialize/nse_server_test.go deleted file mode 100644 index 20f6f4f22..000000000 --- a/pkg/registry/common/serialize/nse_server_test.go +++ /dev/null @@ -1,149 +0,0 @@ -// Copyright (c) 2020-2021 Doc.ai and/or its affiliates. -// -// SPDX-License-Identifier: Apache-2.0 -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at: -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package serialize_test - -import ( - "context" - "fmt" - "sync" - "sync/atomic" - "testing" - - "github.com/golang/protobuf/ptypes/empty" - "github.com/stretchr/testify/assert" - "go.uber.org/goleak" - - "github.com/networkservicemesh/api/pkg/api/registry" - - "github.com/networkservicemesh/sdk/pkg/registry/common/serialize" - "github.com/networkservicemesh/sdk/pkg/registry/core/chain" - "github.com/networkservicemesh/sdk/pkg/registry/core/next" - "github.com/networkservicemesh/sdk/pkg/tools/serializectx" -) - -const ( - parallelCount = 1000 -) - -func TestSerializeNSEServer_StressTest(t *testing.T) { - t.Cleanup(func() { goleak.VerifyNone(t) }) - - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - - server := chain.NewNetworkServiceEndpointRegistryServer( - serialize.NewNetworkServiceEndpointRegistryServer(), - new(eventNSEServer), - newParallelServer(t), - ) - - wg := new(sync.WaitGroup) - wg.Add(parallelCount) - for i := 0; i < parallelCount; i++ { - go func(name string) { - defer wg.Done() - - reg, err := server.Register(ctx, ®istry.NetworkServiceEndpoint{Name: name}) - assert.NoError(t, err) - - _, err = server.Unregister(ctx, reg.Clone()) - assert.NoError(t, err) - }(fmt.Sprint(i % 20)) - } - wg.Wait() -} - -type eventNSEServer struct{} - -func (s *eventNSEServer) Register(ctx context.Context, nse *registry.NetworkServiceEndpoint) (*registry.NetworkServiceEndpoint, error) { - reg, err := next.NetworkServiceEndpointRegistryServer(ctx).Register(ctx, nse) - if err != nil { - return nil, err - } - - executor := serializectx.GetExecutor(ctx, reg.Name) - - go func() { - executor.AsyncExec(func() { - registerCtx, cancel := context.WithCancel(context.TODO()) - defer cancel() - - registerCtx = serializectx.WithExecutor(registerCtx, executor) - - _, _ = next.NetworkServiceEndpointRegistryServer(ctx).Register(registerCtx, reg) - }) - }() - - go func() { - executor.AsyncExec(func() { - unregisterCtx, cancel := context.WithCancel(context.TODO()) - defer cancel() - - _, _ = next.NetworkServiceEndpointRegistryServer(ctx).Unregister(unregisterCtx, reg) - }) - }() - - return reg, nil -} - -func (s *eventNSEServer) Find(query *registry.NetworkServiceEndpointQuery, server registry.NetworkServiceEndpointRegistry_FindServer) error { - return next.NetworkServiceEndpointRegistryServer(server.Context()).Find(query, server) -} - -func (s *eventNSEServer) Unregister(ctx context.Context, nse *registry.NetworkServiceEndpoint) (*empty.Empty, error) { - return next.NetworkServiceEndpointRegistryServer(ctx).Unregister(ctx, nse) -} - -type parallelNSEServer struct { - t *testing.T - states sync.Map -} - -func newParallelServer(t *testing.T) *parallelNSEServer { - return ¶llelNSEServer{ - t: t, - } -} - -func (s *parallelNSEServer) Register(ctx context.Context, nse *registry.NetworkServiceEndpoint) (*registry.NetworkServiceEndpoint, error) { - reg, err := next.NetworkServiceEndpointRegistryServer(ctx).Register(ctx, nse) - if err != nil { - return nil, err - } - - raw, _ := s.states.LoadOrStore(reg.Name, new(int32)) - statePtr := raw.(*int32) - - state := atomic.LoadInt32(statePtr) - assert.True(s.t, atomic.CompareAndSwapInt32(statePtr, state, state+1), "state has been changed") - - return reg, nil -} - -func (s *parallelNSEServer) Find(query *registry.NetworkServiceEndpointQuery, server registry.NetworkServiceEndpointRegistry_FindServer) error { - return next.NetworkServiceEndpointRegistryServer(server.Context()).Find(query, server) -} - -func (s *parallelNSEServer) Unregister(ctx context.Context, nse *registry.NetworkServiceEndpoint) (*empty.Empty, error) { - raw, _ := s.states.LoadOrStore(nse.Name, new(int32)) - statePtr := raw.(*int32) - - state := atomic.LoadInt32(statePtr) - assert.True(s.t, atomic.CompareAndSwapInt32(statePtr, state, state+1), "state has been changed") - - return next.NetworkServiceEndpointRegistryServer(ctx).Unregister(ctx, nse) -} diff --git a/pkg/registry/common/seturl/nse_server.go b/pkg/registry/common/seturl/nse_server.go deleted file mode 100644 index 09708662d..000000000 --- a/pkg/registry/common/seturl/nse_server.go +++ /dev/null @@ -1,71 +0,0 @@ -// Copyright (c) 2021 Doc.ai and/or its affiliates. -// -// SPDX-License-Identifier: Apache-2.0 -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at: -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Package seturl provides registry.NetworkServiceEndpointRegistryServer that sets passed url for each found nse -package seturl - -import ( - "context" - "net/url" - - "github.com/golang/protobuf/ptypes/empty" - "github.com/networkservicemesh/api/pkg/api/registry" - - "github.com/networkservicemesh/sdk/pkg/registry/core/next" -) - -type seturlNSEServer struct { - u *url.URL -} - -func (s *setURLNSEServer) Send(nse *registry.NetworkServiceEndpointResponse) error { - nse.NetworkServiceEndpoint.Url = s.u.String() - return s.NetworkServiceEndpointRegistry_FindServer.Send(nse) -} - -type setURLNSEServer struct { - u *url.URL - registry.NetworkServiceEndpointRegistry_FindServer -} - -func (n *seturlNSEServer) Register(ctx context.Context, nse *registry.NetworkServiceEndpoint) (*registry.NetworkServiceEndpoint, error) { - u := nse.Url - nse.Url = n.u.String() - resp, err := next.NetworkServiceEndpointRegistryServer(ctx).Register(ctx, nse) - - if resp != nil { - resp.Url = u - } - - return resp, err -} - -func (n *seturlNSEServer) Find(query *registry.NetworkServiceEndpointQuery, server registry.NetworkServiceEndpointRegistry_FindServer) error { - return next.NetworkServiceEndpointRegistryServer(server.Context()).Find(query, &setURLNSEServer{NetworkServiceEndpointRegistry_FindServer: server, u: n.u}) -} - -func (n *seturlNSEServer) Unregister(ctx context.Context, service *registry.NetworkServiceEndpoint) (*empty.Empty, error) { - service.Url = n.u.String() - return next.NetworkServiceEndpointRegistryServer(ctx).Unregister(ctx, service) -} - -// NewNetworkServiceEndpointRegistryServer creates a new seturl registry.NetworkServiceEndpointRegistryServer -func NewNetworkServiceEndpointRegistryServer(u *url.URL) registry.NetworkServiceEndpointRegistryServer { - if u == nil { - panic("u can not be nil") - } - return &seturlNSEServer{u: u} -} diff --git a/pkg/registry/common/seturl/nse_server_test.go b/pkg/registry/common/seturl/nse_server_test.go deleted file mode 100644 index 7a0915fe3..000000000 --- a/pkg/registry/common/seturl/nse_server_test.go +++ /dev/null @@ -1,64 +0,0 @@ -// Copyright (c) 2021 Doc.ai and/or its affiliates. -// -// SPDX-License-Identifier: Apache-2.0 -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at: -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package seturl_test - -import ( - "context" - "net/url" - "testing" - "time" - - "github.com/networkservicemesh/api/pkg/api/registry" - "github.com/stretchr/testify/require" - "go.uber.org/goleak" - - "github.com/networkservicemesh/sdk/pkg/registry/common/memory" - "github.com/networkservicemesh/sdk/pkg/registry/common/seturl" - "github.com/networkservicemesh/sdk/pkg/registry/core/adapters" - "github.com/networkservicemesh/sdk/pkg/registry/core/next" -) - -func Test_StoreUrlNSEServer(t *testing.T) { - defer goleak.VerifyNone(t) - - ctx, cancel := context.WithTimeout(context.Background(), time.Second) - - defer cancel() - - s := next.NewNetworkServiceEndpointRegistryServer( - seturl.NewNetworkServiceEndpointRegistryServer(&url.URL{Scheme: "tcp", Host: "127.0.0.1"}), - memory.NewNetworkServiceEndpointRegistryServer(), - ) - - _, err := s.Register(ctx, ®istry.NetworkServiceEndpoint{ - Name: "nse-1", - Url: "unix://file.sock", - }) - require.NoError(t, err) - - stream, err := adapters.NetworkServiceEndpointServerToClient(s).Find(ctx, ®istry.NetworkServiceEndpointQuery{ - NetworkServiceEndpoint: ®istry.NetworkServiceEndpoint{ - Name: "nse-1", - }, - }) - require.NoError(t, err) - - list := registry.ReadNetworkServiceEndpointList(stream) - require.Len(t, list, 1) - - require.Equal(t, "tcp://127.0.0.1", list[0].Url) -} diff --git a/pkg/registry/common/storeurl/nse_server.go b/pkg/registry/common/storeurl/nse_server.go deleted file mode 100644 index c7aa0e2f5..000000000 --- a/pkg/registry/common/storeurl/nse_server.go +++ /dev/null @@ -1,67 +0,0 @@ -// Copyright (c) 2021 Doc.ai and/or its affiliates. -// -// SPDX-License-Identifier: Apache-2.0 -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at: -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Package storeurl provides chain element that stores incoming NSE URLs into map -package storeurl - -import ( - "context" - "net/url" - - "github.com/golang/protobuf/ptypes/empty" - "github.com/networkservicemesh/api/pkg/api/registry" - - "github.com/networkservicemesh/sdk/pkg/registry/core/next" - "github.com/networkservicemesh/sdk/pkg/tools/stringurl" -) - -type storeurl struct { - m *stringurl.Map -} - -type urlstockFindServer struct { - m *stringurl.Map - registry.NetworkServiceEndpointRegistry_FindServer -} - -func (n *storeurl) Register(ctx context.Context, service *registry.NetworkServiceEndpoint) (*registry.NetworkServiceEndpoint, error) { - return next.NetworkServiceEndpointRegistryServer(ctx).Register(ctx, service) -} - -func (n *storeurl) Find(query *registry.NetworkServiceEndpointQuery, server registry.NetworkServiceEndpointRegistry_FindServer) error { - return next.NetworkServiceEndpointRegistryServer(server.Context()).Find(query, &urlstockFindServer{NetworkServiceEndpointRegistry_FindServer: server, m: n.m}) -} - -func (n *storeurl) Unregister(ctx context.Context, service *registry.NetworkServiceEndpoint) (*empty.Empty, error) { - return next.NetworkServiceEndpointRegistryServer(ctx).Unregister(ctx, service) -} - -// NewNetworkServiceEndpointRegistryServer creates new instance of storeurl NSE server -func NewNetworkServiceEndpointRegistryServer(m *stringurl.Map) registry.NetworkServiceEndpointRegistryServer { - if m == nil { - panic("m can not be nil") - } - return &storeurl{m: m} -} - -func (s *urlstockFindServer) Send(nseResp *registry.NetworkServiceEndpointResponse) error { - u, err := url.Parse(nseResp.NetworkServiceEndpoint.Url) - if err != nil { - return err - } - s.m.Store(nseResp.NetworkServiceEndpoint.Name, u) - return s.NetworkServiceEndpointRegistry_FindServer.Send(nseResp) -} diff --git a/pkg/registry/common/storeurl/nse_server_test.go b/pkg/registry/common/storeurl/nse_server_test.go deleted file mode 100644 index cce0e9177..000000000 --- a/pkg/registry/common/storeurl/nse_server_test.go +++ /dev/null @@ -1,70 +0,0 @@ -// Copyright (c) 2021 Doc.ai and/or its affiliates. -// -// SPDX-License-Identifier: Apache-2.0 -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at: -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package storeurl_test - -import ( - "context" - "net/url" - "testing" - "time" - - "github.com/networkservicemesh/api/pkg/api/registry" - "github.com/stretchr/testify/require" - "go.uber.org/goleak" - - "github.com/networkservicemesh/sdk/pkg/registry/common/memory" - "github.com/networkservicemesh/sdk/pkg/registry/common/storeurl" - "github.com/networkservicemesh/sdk/pkg/registry/core/adapters" - "github.com/networkservicemesh/sdk/pkg/registry/core/next" - "github.com/networkservicemesh/sdk/pkg/tools/stringurl" -) - -func Test_StoreUrlNSEServer(t *testing.T) { - defer goleak.VerifyNone(t) - - ctx, cancel := context.WithTimeout(context.Background(), time.Second) - - defer cancel() - - var m stringurl.Map - - s := next.NewNetworkServiceEndpointRegistryServer( - storeurl.NewNetworkServiceEndpointRegistryServer(&m), - memory.NewNetworkServiceEndpointRegistryServer(), - ) - - _, err := s.Register(ctx, ®istry.NetworkServiceEndpoint{ - Name: "nse-1", - Url: "unix://file.sock", - }) - require.NoError(t, err) - - stream, err := adapters.NetworkServiceEndpointServerToClient(s).Find(ctx, ®istry.NetworkServiceEndpointQuery{ - NetworkServiceEndpoint: ®istry.NetworkServiceEndpoint{ - Name: "nse-1", - }, - }) - require.NoError(t, err) - - list := registry.ReadNetworkServiceEndpointList(stream) - require.Len(t, list, 1) - - v, ok := m.Load("nse-1") - require.True(t, ok) - - require.Equal(t, url.URL{Scheme: "unix", Host: "file.sock"}, *v) -} diff --git a/pkg/registry/switchcase/ns_client.go b/pkg/registry/switchcase/ns_client.go new file mode 100644 index 000000000..125f26040 --- /dev/null +++ b/pkg/registry/switchcase/ns_client.go @@ -0,0 +1,81 @@ +// Copyright (c) 2022 Doc.ai and/or its affiliates. +// +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at: +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package switchcase + +import ( + "context" + "fmt" + + "github.com/golang/protobuf/ptypes/empty" + "github.com/networkservicemesh/api/pkg/api/registry" + "google.golang.org/grpc" + + "github.com/networkservicemesh/sdk/pkg/registry/core/next" +) + +// NSClientCase repsenets NetworkService case for clients. +type NSClientCase struct { + Condition func(context.Context, *registry.NetworkService) bool + Action registry.NetworkServiceRegistryClient +} + +type switchCaseNSClient struct { + cases []NSClientCase +} + +func (n *switchCaseNSClient) Register(ctx context.Context, in *registry.NetworkService, opts ...grpc.CallOption) (*registry.NetworkService, error) { + for _, c := range n.cases { + if c.Condition(ctx, in) { + return c.Action.Register(ctx, in) + } + } + return next.NetworkServiceRegistryServer(ctx).Register(ctx, in) +} + +func (n *switchCaseNSClient) Find(ctx context.Context, in *registry.NetworkServiceQuery, opts ...grpc.CallOption) (registry.NetworkServiceRegistry_FindClient, error) { + for _, c := range n.cases { + if c.Condition(ctx, in.NetworkService) { + return c.Action.Find(ctx, in) + } + } + return next.NetworkServiceRegistryClient(ctx).Find(ctx, in, opts...) +} + +func (n *switchCaseNSClient) Unregister(ctx context.Context, in *registry.NetworkService, opts ...grpc.CallOption) (*empty.Empty, error) { + for _, c := range n.cases { + if c.Condition(ctx, in) { + return c.Action.Unregister(ctx, in) + } + } + return next.NetworkServiceRegistryClient(ctx).Unregister(ctx, in, opts...) +} + +// NewNetworkServiceRegistryClient - returns a new switchcase client. +func NewNetworkServiceRegistryClient(cases ...NSClientCase) registry.NetworkServiceRegistryClient { + for index, c := range cases { + if c.Action == nil { + panic(fmt.Sprintf("index: %v, %v.Action is nil", index, c)) + } + if c.Condition == nil { + panic(fmt.Sprintf("index: %v, %v.Condition is nil", index, c)) + } + } + + return &switchCaseNSClient{ + cases: cases, + } +} diff --git a/pkg/registry/switchcase/ns_server.go b/pkg/registry/switchcase/ns_server.go new file mode 100644 index 000000000..2a9344d98 --- /dev/null +++ b/pkg/registry/switchcase/ns_server.go @@ -0,0 +1,82 @@ +// Copyright (c) 2022 Doc.ai and/or its affiliates. +// +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at: +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package switchcase provides chain elements acting like a switch-case statement, selecting a chain element with first +// succeed condition +package switchcase + +import ( + "context" + "fmt" + + "github.com/golang/protobuf/ptypes/empty" + "github.com/networkservicemesh/api/pkg/api/registry" + + "github.com/networkservicemesh/sdk/pkg/registry/core/next" +) + +// NSServerCase repsenets NetworkService case for servers. +type NSServerCase struct { + Condition func(context.Context, *registry.NetworkService) bool + Action registry.NetworkServiceRegistryServer +} + +type switchCaseNSServer struct { + cases []NSServerCase +} + +func (n *switchCaseNSServer) Register(ctx context.Context, service *registry.NetworkService) (*registry.NetworkService, error) { + for _, c := range n.cases { + if c.Condition(ctx, service) { + return c.Action.Register(ctx, service) + } + } + return next.NetworkServiceRegistryServer(ctx).Register(ctx, service) +} + +func (n *switchCaseNSServer) Find(query *registry.NetworkServiceQuery, server registry.NetworkServiceRegistry_FindServer) error { + for _, c := range n.cases { + if c.Condition(server.Context(), query.NetworkService) { + return c.Action.Find(query, server) + } + } + return next.NetworkServiceRegistryServer(server.Context()).Find(query, server) +} + +func (n *switchCaseNSServer) Unregister(ctx context.Context, service *registry.NetworkService) (*empty.Empty, error) { + for _, c := range n.cases { + if c.Condition(ctx, service) { + return c.Action.Unregister(ctx, service) + } + } + return next.NetworkServiceRegistryServer(ctx).Unregister(ctx, service) +} + +// NewNetworkServiceRegistryServer - returns a new switchcase server. +func NewNetworkServiceRegistryServer(cases ...NSServerCase) registry.NetworkServiceRegistryServer { + for index, c := range cases { + if c.Action == nil { + panic(fmt.Sprintf("index: %v, %v.Action is nil", index, c)) + } + if c.Condition == nil { + panic(fmt.Sprintf("index: %v, %v.Condition is nil", index, c)) + } + } + + return &switchCaseNSServer{ + cases: cases, + } +} diff --git a/pkg/registry/switchcase/nse_client.go b/pkg/registry/switchcase/nse_client.go new file mode 100644 index 000000000..8d46ebd45 --- /dev/null +++ b/pkg/registry/switchcase/nse_client.go @@ -0,0 +1,81 @@ +// Copyright (c) 2022 Doc.ai and/or its affiliates. +// +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at: +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package switchcase + +import ( + "context" + "fmt" + + "github.com/golang/protobuf/ptypes/empty" + "github.com/networkservicemesh/api/pkg/api/registry" + "google.golang.org/grpc" + + "github.com/networkservicemesh/sdk/pkg/registry/core/next" +) + +// NSEClientCase repsenets NetworkServiceEndpoint case for clients. +type NSEClientCase struct { + Condition func(context.Context, *registry.NetworkServiceEndpoint) bool + Action registry.NetworkServiceEndpointRegistryClient +} + +type switchCaseNSEClient struct { + cases []NSEClientCase +} + +func (n *switchCaseNSEClient) Register(ctx context.Context, in *registry.NetworkServiceEndpoint, opts ...grpc.CallOption) (*registry.NetworkServiceEndpoint, error) { + for _, c := range n.cases { + if c.Condition(ctx, in) { + return c.Action.Register(ctx, in) + } + } + return next.NetworkServiceEndpointRegistryServer(ctx).Register(ctx, in) +} + +func (n *switchCaseNSEClient) Find(ctx context.Context, in *registry.NetworkServiceEndpointQuery, opts ...grpc.CallOption) (registry.NetworkServiceEndpointRegistry_FindClient, error) { + for _, c := range n.cases { + if c.Condition(ctx, in.NetworkServiceEndpoint) { + return c.Action.Find(ctx, in) + } + } + return next.NetworkServiceEndpointRegistryClient(ctx).Find(ctx, in, opts...) +} + +func (n *switchCaseNSEClient) Unregister(ctx context.Context, in *registry.NetworkServiceEndpoint, opts ...grpc.CallOption) (*empty.Empty, error) { + for _, c := range n.cases { + if c.Condition(ctx, in) { + return c.Action.Unregister(ctx, in) + } + } + return next.NetworkServiceEndpointRegistryClient(ctx).Unregister(ctx, in, opts...) +} + +// NewNetworkServiceEndpointRegistryClient - returns a new switchcase client. +func NewNetworkServiceEndpointRegistryClient(cases ...NSEClientCase) registry.NetworkServiceEndpointRegistryClient { + for index, c := range cases { + if c.Action == nil { + panic(fmt.Sprintf("index: %v, %v.Action is nil", index, c)) + } + if c.Condition == nil { + panic(fmt.Sprintf("index: %v, %v.Condition is nil", index, c)) + } + } + + return &switchCaseNSEClient{ + cases: cases, + } +} diff --git a/pkg/registry/switchcase/nse_server.go b/pkg/registry/switchcase/nse_server.go new file mode 100644 index 000000000..a748c310e --- /dev/null +++ b/pkg/registry/switchcase/nse_server.go @@ -0,0 +1,80 @@ +// Copyright (c) 2022 Doc.ai and/or its affiliates. +// +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at: +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package switchcase + +import ( + "context" + "fmt" + + "github.com/golang/protobuf/ptypes/empty" + "github.com/networkservicemesh/api/pkg/api/registry" + + "github.com/networkservicemesh/sdk/pkg/registry/core/next" +) + +// NSEServerCase repsenets NetworkServiceEndpoint case for servers. +type NSEServerCase struct { + Condition func(context.Context, *registry.NetworkServiceEndpoint) bool + Action registry.NetworkServiceEndpointRegistryServer +} + +type switchCaseNSEServer struct { + cases []NSEServerCase +} + +func (n *switchCaseNSEServer) Register(ctx context.Context, service *registry.NetworkServiceEndpoint) (*registry.NetworkServiceEndpoint, error) { + for _, c := range n.cases { + if c.Condition(ctx, service) { + return c.Action.Register(ctx, service) + } + } + return next.NetworkServiceEndpointRegistryServer(ctx).Register(ctx, service) +} + +func (n *switchCaseNSEServer) Find(query *registry.NetworkServiceEndpointQuery, server registry.NetworkServiceEndpointRegistry_FindServer) error { + for _, c := range n.cases { + if c.Condition(server.Context(), query.NetworkServiceEndpoint) { + return c.Action.Find(query, server) + } + } + return next.NetworkServiceEndpointRegistryServer(server.Context()).Find(query, server) +} + +func (n *switchCaseNSEServer) Unregister(ctx context.Context, service *registry.NetworkServiceEndpoint) (*empty.Empty, error) { + for _, c := range n.cases { + if c.Condition(ctx, service) { + return c.Action.Unregister(ctx, service) + } + } + return next.NetworkServiceEndpointRegistryServer(ctx).Unregister(ctx, service) +} + +// NewNetworkServiceEndpointRegistryServer - returns a new switchcase server. +func NewNetworkServiceEndpointRegistryServer(cases ...NSEServerCase) registry.NetworkServiceEndpointRegistryServer { + for index, c := range cases { + if c.Action == nil { + panic(fmt.Sprintf("index: %v, %v.Action is nil", index, c)) + } + if c.Condition == nil { + panic(fmt.Sprintf("index: %v, %v.Condition is nil", index, c)) + } + } + + return &switchCaseNSEServer{ + cases: cases, + } +} diff --git a/pkg/tools/expire/README.md b/pkg/tools/expire/README.md deleted file mode 100644 index 604ced2e0..000000000 --- a/pkg/tools/expire/README.md +++ /dev/null @@ -1,72 +0,0 @@ -# Functional requirements - -For some entities we need "expire on timeout if not being updated" logic. There are few requirements for the algorithm -we need to match: -1. Update can take some time and can fail: - 1. Expiration should be paused for the Update processing. - 2. If Update fails, expiration should be resumed for the same expiration time. -2. If some error occurs after the Update has been successfully finished, entity should be gracefully closed. -3. Close event should be performed on context with event scope lifetime, to prevent leaks. -4. If entity has been already closed or expired it should never be closed or expired until it will have been updated. - -# Implementation - -## Manager - -`Manager` can be used for managing expiration for some set of entities. Here is an example for its usage with some -abstract `expireServer`: - -```go -type expireServer struct { - expireManager Manager -} - -func (s *expireServer) Open(ctx context.Context, req *request) (*response, error) { - logger := log.FromContext(ctx).WithField("expireServer", "Open") - - // 1. Stop expiration. - s.expireManager.Stop(req.Id) - - // 2. Send Open event. - resp, err := nextServer(ctx).Open(ctx, req) - if err != nil { - // 2.1. Reset expiration if Open event has failed. - s.expireManager.Reset(req.Id) - return nil, err - } - - // 3. Delete the old expiration if we need to create a new one for the new ID. - closeResp := resp.Clone() - if closeResp.Id != req.Id { - s.expireManager.Delete(req.Id) - } - - // 4. Create a new expiration. - s.expireManager.New( - serializectx.GetExecutor(ctx, closeResp.Id), - closeResp.Id, - s.computeExpirationTime(closeResp), - func (closeCtx context.Context) { - if err := nextServer(ctx).Close(closeCtx, closeResp); err != nil { - logger.Errorf("failed to close expired response: %s %s", closeResp.Id, err.Error()) - } - }, - ) - - return resp, nil -} - -func (s *expireServer) Close(ctx context.Context, resp *response) error { - logger := log.FromContext(ctx).WithField("expireServer", "Close") - - // 1. Check if we have an expiration. - if !s.expireManager.DeleteExpiration(conn.GetId()) { - // 1.1. If there is no expiration, there is nothing to do. - logger.Warnf("response has been already closed: %s", resp.Id) - return nil - } - - // 2. Send Close event. - return nextServer(ctx).Close(ctx, resp) -} -``` \ No newline at end of file diff --git a/pkg/tools/expire/manager.go b/pkg/tools/expire/manager.go deleted file mode 100644 index 7ea349252..000000000 --- a/pkg/tools/expire/manager.go +++ /dev/null @@ -1,116 +0,0 @@ -// Copyright (c) 2021 Doc.ai and/or its affiliates. -// -// SPDX-License-Identifier: Apache-2.0 -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at: -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Package expire provide expiration manager -package expire - -import ( - "context" - "time" - - "github.com/networkservicemesh/sdk/pkg/tools/clock" -) - -// Executor is a serialize.Executor interface -type Executor interface { - AsyncExec(f func()) <-chan struct{} -} - -// Manager manages expiration for some entities -type Manager struct { - ctx context.Context - clockTime clock.Clock - timers timerMap -} - -type timer struct { - expirationTime time.Time - stopped bool - - clock.Timer -} - -// NewManager creates a new Manager -func NewManager(ctx context.Context) *Manager { - return &Manager{ - ctx: ctx, - clockTime: clock.FromContext(ctx), - } -} - -// New creates a new expiration for the `id`, on expiration it would call `closeFunc` -func (m *Manager) New(executor Executor, id string, expirationTime time.Time, closeFunc func(context.Context)) { - if executor == nil { - panic("cannot create a new expiration with nil Executor") - } - - var t *timer - t = &timer{ - expirationTime: expirationTime, - Timer: m.clockTime.AfterFunc(m.clockTime.Until(expirationTime), func() { - executor.AsyncExec(func() { - if tt, ok := m.timers.Load(id); !ok || tt != t { - return - } - m.timers.Delete(id) - - closeCtx, cancel := context.WithCancel(m.ctx) - defer cancel() - - closeFunc(closeCtx) - }) - }), - } - - m.timers.Store(id, t) -} - -// Stop stops expiration for the `id` -func (m *Manager) Stop(id string) bool { - t, loaded := m.timers.Load(id) - if loaded { - t.stopped = t.Stop() - } - return loaded -} - -// Start starts stopped expiration for the `id` with the same expiration time -func (m *Manager) Start(id string) { - if t, ok := m.timers.Load(id); ok && t.stopped { - t.stopped = false - t.Reset(m.clockTime.Until(t.expirationTime)) - } -} - -// Expire force expires stopped expiration for the `id` -func (m *Manager) Expire(id string) { - if t, ok := m.timers.Load(id); ok && t.stopped { - t.stopped = false - t.Reset(0) - } -} - -// Delete deletes expiration for the `id` -func (m *Manager) Delete(id string) bool { - t, ok := m.timers.LoadAndDelete(id) - if !ok { - return false - } - - t.Stop() - - return true -} diff --git a/pkg/tools/multiexecutor/multi_executor.go b/pkg/tools/multiexecutor/multi_executor.go deleted file mode 100644 index ff8c17bd6..000000000 --- a/pkg/tools/multiexecutor/multi_executor.go +++ /dev/null @@ -1,64 +0,0 @@ -// Copyright (c) 2020-2021 Doc.ai and/or its affiliates. -// -// SPDX-License-Identifier: Apache-2.0 -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at: -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Package multiexecutor provides a structure MultiExecutor that can be used to guarantee exclusive by ID, in order execution of functions. -package multiexecutor - -import ( - "sync" - - "github.com/edwarnicke/serialize" -) - -// MultiExecutor - a struct that can be used to guarantee exclusive by ID, in order execution of functions. -type MultiExecutor struct { - executors map[string]*refCountExecutor - executor serialize.Executor - once sync.Once -} - -type refCountExecutor struct { - count int - executor serialize.Executor -} - -// AsyncExec - guarantees f() will be executed Exclusively for specified ID and in the Order submitted. -// It immediately returns a channel that will be closed when f() has completed execution. -func (e *MultiExecutor) AsyncExec(id string, f func()) (ch <-chan struct{}) { - e.once.Do(func() { - e.executors = make(map[string]*refCountExecutor) - }) - - <-e.executor.AsyncExec(func() { - exec, ok := e.executors[id] - if !ok { - exec = new(refCountExecutor) - e.executors[id] = exec - } - exec.count++ - - ch = exec.executor.AsyncExec(func() { - f() - e.executor.AsyncExec(func() { - exec.count-- - if exec.count == 0 { - delete(e.executors, id) - } - }) - }) - }) - return ch -} diff --git a/pkg/tools/multiexecutor/multi_executor_test.go b/pkg/tools/multiexecutor/multi_executor_test.go deleted file mode 100644 index fcfcfd742..000000000 --- a/pkg/tools/multiexecutor/multi_executor_test.go +++ /dev/null @@ -1,60 +0,0 @@ -// Copyright (c) 2021 Doc.ai and/or its affiliates. -// -// SPDX-License-Identifier: Apache-2.0 -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at: -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package multiexecutor_test - -import ( - "strconv" - "sync" - "testing" - - "github.com/stretchr/testify/require" - - "github.com/networkservicemesh/sdk/pkg/tools/multiexecutor" -) - -const ( - totalCount = 1000 - parallelCount = 10 -) - -func TestMultiExecutor_AsyncExec(t *testing.T) { - e := new(multiexecutor.MultiExecutor) - - var data sync.Map - for i := 0; i < totalCount; i++ { - id := strconv.Itoa(i % parallelCount) - - k := i - e.AsyncExec(id, func() { - val, ok := data.Load(id) - if ok == false { - val = 0 - } - require.Equal(t, k/parallelCount, val.(int)) - data.Store(id, val.(int)+1) - }) - } - - for i := 0; i < parallelCount; i++ { - id := strconv.Itoa(i) - <-e.AsyncExec(id, func() { - val, ok := data.Load(id) - require.True(t, ok) - require.Equal(t, totalCount/parallelCount, val.(int)) - }) - } -} diff --git a/pkg/tools/sandbox/builder.go b/pkg/tools/sandbox/builder.go index 21295c75e..8f9517eab 100644 --- a/pkg/tools/sandbox/builder.go +++ b/pkg/tools/sandbox/builder.go @@ -1,4 +1,4 @@ -// Copyright (c) 2020-2021 Doc.ai and/or its affiliates. +// Copyright (c) 2020-2022 Doc.ai and/or its affiliates. // // SPDX-License-Identifier: Apache-2.0 // @@ -33,7 +33,6 @@ import ( "github.com/networkservicemesh/sdk/pkg/networkservice/chains/nsmgrproxy" "github.com/networkservicemesh/sdk/pkg/registry/chains/memory" "github.com/networkservicemesh/sdk/pkg/registry/chains/proxydns" - registryconnect "github.com/networkservicemesh/sdk/pkg/registry/common/connect" "github.com/networkservicemesh/sdk/pkg/registry/common/dnsresolve" "github.com/networkservicemesh/sdk/pkg/tools/grpcutils" "github.com/networkservicemesh/sdk/pkg/tools/log" @@ -294,9 +293,6 @@ func (b *Builder) newNSMgrProxy() *NSMgrEntry { nsmgrproxy.WithListenOn(entry.URL), nsmgrproxy.WithName(entry.Name), nsmgrproxy.WithDialOptions(dialOptions...), - nsmgrproxy.WithRegistryConnectOptions( - registryconnect.WithDialOptions(dialOptions...), - ), ) serve(ctx, b.t, entry.URL, entry.Register) diff --git a/pkg/tools/sandbox/node.go b/pkg/tools/sandbox/node.go index 3ee3ea552..e5e508143 100644 --- a/pkg/tools/sandbox/node.go +++ b/pkg/tools/sandbox/node.go @@ -1,4 +1,4 @@ -// Copyright (c) 2020-2021 Doc.ai and/or its affiliates. +// Copyright (c) 2020-2022 Doc.ai and/or its affiliates. // // SPDX-License-Identifier: Apache-2.0 // @@ -159,7 +159,6 @@ func (n *Node) NewForwarder( ctx, CloneURL(n.NSMgr.URL), registryclient.WithDialOptions(dialOptions...), - registryclient.WithNSEAdditionalFunctionality(), ) n.registerEndpoint(ctx, nse, nseClone, entry.NetworkServiceEndpointRegistryClient) diff --git a/pkg/tools/sandbox/utils.go b/pkg/tools/sandbox/utils.go index e6e561969..044ce3924 100644 --- a/pkg/tools/sandbox/utils.go +++ b/pkg/tools/sandbox/utils.go @@ -1,4 +1,4 @@ -// Copyright (c) 2020-2021 Doc.ai and/or its affiliates. +// Copyright (c) 2020-2022 Doc.ai and/or its affiliates. // // SPDX-License-Identifier: Apache-2.0 // @@ -31,8 +31,6 @@ import ( ) const ( - // RegistryExpiryDuration is a duration that should be used for expire tests - RegistryExpiryDuration = time.Second // DialTimeout is a default dial timeout for the sandbox tests DialTimeout = 2 * time.Second ) diff --git a/pkg/tools/serializectx/context.go b/pkg/tools/serializectx/context.go deleted file mode 100644 index a75b6a305..000000000 --- a/pkg/tools/serializectx/context.go +++ /dev/null @@ -1,63 +0,0 @@ -// Copyright (c) 2021 Doc.ai and/or its affiliates. -// -// SPDX-License-Identifier: Apache-2.0 -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at: -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Package serializectx allows to set executor in the context -package serializectx - -import ( - "context" - - "github.com/networkservicemesh/sdk/pkg/tools/multiexecutor" -) - -const ( - multiExecutorKey contextKeyType = "multiExecutor" - executorKey contextKeyType = "executor" -) - -type contextKeyType string - -// WithMultiExecutor wraps `parent` in a new context with multiexecutor.MultiExecutor -func WithMultiExecutor(parent context.Context, multiExecutor *multiexecutor.MultiExecutor) context.Context { - if parent == nil { - panic("cannot create context from nil parent") - } - return context.WithValue(parent, multiExecutorKey, multiExecutor) -} - -// WithExecutor wraps `parent` in a new context with Executor -func WithExecutor(parent context.Context, executor *Executor) context.Context { - if parent == nil { - panic("cannot create context from nil parent") - } - return context.WithValue(parent, executorKey, executor) -} - -// GetExecutor returns Executor -func GetExecutor(ctx context.Context, id string) *Executor { - if executor, ok := ctx.Value(executorKey).(*Executor); ok && executor.id == id { - return executor - } - if multiExecutor, ok := ctx.Value(multiExecutorKey).(*multiexecutor.MultiExecutor); ok { - return &Executor{ - id: id, - asyncExec: func(f func()) <-chan struct{} { - return multiExecutor.AsyncExec(id, f) - }, - } - } - return nil -}