From 9946957987caa5659214d15006211c4a587282b2 Mon Sep 17 00:00:00 2001 From: Sergey Semenov Date: Tue, 14 Jul 2020 13:13:47 +0700 Subject: [PATCH] Add ns close --- pkg/kernel/networkservice/inject/server.go | 46 ++++++++++++---------- 1 file changed, 26 insertions(+), 20 deletions(-) diff --git a/pkg/kernel/networkservice/inject/server.go b/pkg/kernel/networkservice/inject/server.go index cb4987a4..da06fd07 100644 --- a/pkg/kernel/networkservice/inject/server.go +++ b/pkg/kernel/networkservice/inject/server.go @@ -18,6 +18,7 @@ package inject import ( "context" + "runtime" "github.com/networkservicemesh/sdk-kernel/pkg/kernel/utils" @@ -42,21 +43,31 @@ func NewServer() networkservice.NetworkServiceServer { } func (a *injectServer) Request(ctx context.Context, request *networkservice.NetworkServiceRequest) (*networkservice.Connection, error) { - if request.GetConnection().GetMechanism().GetType() != kernel.MECHANISM { - return next.Server(ctx).Request(ctx, request) + conn, err := next.Server(ctx).Request(ctx, request) + if err != nil { + return nil, err } + mech := kernel.ToMechanism(conn.GetMechanism()) + + /* Lock the OS thread so we don't accidentally switch namespaces */ + runtime.LockOSThread() + defer runtime.UnlockOSThread() forwarderNetNSHandle, err := netns.Get() if err != nil { return nil, errors.Wrapf(err, "Unable to obtain Forwarder's network namespace handle") } - clientNetNSHandle, err := a.getClientNetNSHandle(request.GetConnection()) + defer func() { _ = forwarderNetNSHandle.Close() }() + + clientNetNSHandle, err := utils.GetNSHandleFromInode(mech.GetNetNSInode()) if err != nil { return nil, errors.Wrapf(err, "Unable to obtain Client's network namespace handle") } - ifaceName := request.GetConnection().GetMechanism().GetParameters()[kernel.InterfaceNameKey] + defer func() { _ = clientNetNSHandle.Close() }() + + ifaceName := mech.GetParameters()[kernel.InterfaceNameKey] if ifaceName == "" { - return nil, errors.New("Virtual function's interface name is not found") + return nil, errors.New("Interface name is not found") } err = a.moveInterfaceToAnotherNamespace(ifaceName, forwarderNetNSHandle, clientNetNSHandle) @@ -69,23 +80,27 @@ func (a *injectServer) Request(ctx context.Context, request *networkservice.Netw } func (a *injectServer) Close(ctx context.Context, conn *networkservice.Connection) (*empty.Empty, error) { - if conn.GetMechanism().GetType() != kernel.MECHANISM { - return next.Server(ctx).Close(ctx, conn) - } + mech := kernel.ToMechanism(conn.GetMechanism()) + + /* Lock the OS thread so we don't accidentally switch namespaces */ + runtime.LockOSThread() + defer runtime.UnlockOSThread() forwarderNetNSHandle, err := netns.Get() if err != nil { return nil, errors.Wrapf(err, "Unable to obtain Forwarder's network namespace handle") } + defer func() { _ = forwarderNetNSHandle.Close() }() - clientNetNSHandle, err := a.getClientNetNSHandle(conn) + clientNetNSHandle, err := utils.GetNSHandleFromInode(mech.GetNetNSInode()) if err != nil { return nil, errors.Wrapf(err, "Unable to obtain Client's network namespace handle") } + defer func() { _ = clientNetNSHandle.Close() }() ifaceName := conn.GetMechanism().GetParameters()[kernel.InterfaceNameKey] if ifaceName == "" { - return nil, errors.New("Virtual function's interface name is not found") + return nil, errors.New("Interface name is not found") } err = a.moveInterfaceToAnotherNamespace(ifaceName, clientNetNSHandle, forwarderNetNSHandle) @@ -94,7 +109,7 @@ func (a *injectServer) Close(ctx context.Context, conn *networkservice.Connectio } log.Entry(ctx).Infof("Moved network interface %s into the Forwarder's namespace for connection %s", ifaceName, conn.GetId()) - return next.Client(ctx).Close(ctx, conn) + return next.Server(ctx).Close(ctx, conn) } func (a *injectServer) moveInterfaceToAnotherNamespace(ifaceName string, fromNetNS, toNetNS netns.NsHandle) error { @@ -110,12 +125,3 @@ func (a *injectServer) moveInterfaceToAnotherNamespace(ifaceName string, fromNet return nil } - -func (a *injectServer) getClientNetNSHandle(conn *networkservice.Connection) (netns.NsHandle, error) { - clientNetNSInode := conn.GetMechanism().GetParameters()[kernel.NetNSInodeKey] - if clientNetNSInode == "" { - return 0, errors.New("Client's pod net ns inode is not found") - } - - return utils.GetNSHandleFromInode(clientNetNSInode) -}