Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: support RemoveSmbGlobalMapping during unmount on Windows node #505

Merged
merged 1 commit into from
Aug 13, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
122 changes: 122 additions & 0 deletions pkg/mounter/refcounter_windows.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,122 @@
//go:build windows
// +build windows

/*
Copyright 2020 The Kubernetes Authors.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/

package mounter

import (
"crypto/md5"
"fmt"
"os"
"path/filepath"
"strings"
"sync"
)

var basePath = "c:\\csi\\smbmounts"
var mutexes sync.Map

func lock(key string) func() {
value, _ := mutexes.LoadOrStore(key, &sync.Mutex{})
mtx := value.(*sync.Mutex)
mtx.Lock()

return func() { mtx.Unlock() }
}

// getRootMappingPath - returns root of smb share path or empty string if the path is invalid. For example:
//
// \\hostname\share\subpath => \\hostname\share, error is nil
// \\hostname\share => \\hostname\share, error is nil
// \\hostname => '', error is 'remote path (\\hostname) is invalid'
func getRootMappingPath(path string) (string, error) {
items := strings.Split(path, "\\")
parts := []string{}
for _, s := range items {
if len(s) > 0 {
parts = append(parts, s)
if len(parts) == 2 {
break
}
}
}
if len(parts) != 2 {
return "", fmt.Errorf("remote path (%s) is invalid", path)
}
// parts[0] is a smb host name
// parts[1] is a smb share name
return strings.ToLower("\\\\" + parts[0] + "\\" + parts[1]), nil
}

// incementRemotePathReferencesCount - adds new reference between mappingPath and remotePath if it doesn't exist.
// How it works:
// 1. MappingPath contains two components: hostname, sharename
// 2. We create directory in basePath related to each mappingPath. It will be used as container for references.
// Example: c:\\csi\\smbmounts\\hostname\\sharename
// 3. Each reference is a file with name based on MD5 of remotePath. For debug it also will contains remotePath in body of the file.
// So, in incementRemotePathReferencesCount we create the file. In decrementRemotePathReferencesCount we remove the file.
// Example: c:\\csi\\smbmounts\\hostname\\sharename\\092f1413e6c1d03af8b5da6f44619af8
func incementRemotePathReferencesCount(mappingPath, remotePath string) error {
remotePath = strings.TrimSuffix(remotePath, "\\")
path := filepath.Join(basePath, strings.TrimPrefix(mappingPath, "\\\\"))
if err := os.MkdirAll(path, os.ModeDir); err != nil {
return err
}
filePath := filepath.Join(path, getMd5(remotePath))
file, err := os.Create(filePath)
if err != nil {
return err
}
defer func() {
file.Close()
}()

_, err = file.WriteString(remotePath)
return err
}

// decrementRemotePathReferencesCount - removes reference between mappingPath and remotePath.
// See incementRemotePathReferencesCount to understand how references work.
func decrementRemotePathReferencesCount(mappingPath, remotePath string) error {
remotePath = strings.TrimSuffix(remotePath, "\\")
path := filepath.Join(basePath, strings.TrimPrefix(mappingPath, "\\\\"))
if err := os.MkdirAll(path, os.ModeDir); err != nil {
return err
}
filePath := filepath.Join(path, getMd5(remotePath))
return os.Remove(filePath)
vitaliy-leschenko marked this conversation as resolved.
Show resolved Hide resolved
}

// getRemotePathReferencesCount - returns count of references between mappingPath and remotePath.
// See incementRemotePathReferencesCount to understand how references work.
func getRemotePathReferencesCount(mappingPath string) int {
path := filepath.Join(basePath, strings.TrimPrefix(mappingPath, "\\\\"))
if os.MkdirAll(path, os.ModeDir) != nil {
return -1
}
files, err := os.ReadDir(path)
if err != nil {
return -1
}
return len(files)
}

func getMd5(path string) string {
data := []byte(strings.ToLower(path))
return fmt.Sprintf("%x", md5.Sum(data))
}
227 changes: 227 additions & 0 deletions pkg/mounter/refcounter_windows_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,227 @@
/*
Copyright 2020 The Kubernetes Authors.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/

package mounter

import (
"os"
"testing"
"time"

"github.com/stretchr/testify/assert"
)

func TestLockUnlock(t *testing.T) {
key := "resource name"

unlock := lock(key)
defer unlock()

_, loaded := mutexes.Load(key)
assert.True(t, loaded)
}

func TestLockLockedResource(t *testing.T) {
locked := true
unlock := lock("a")
go func() {
time.Sleep(500 * time.Microsecond)
locked = false
unlock()
}()

// try to lock already locked resource
unlock2 := lock("a")
defer unlock2()
if locked {
assert.Fail(t, "access to locked resource")
}
}

func TestLockDifferentKeys(t *testing.T) {
unlocka := lock("a")
unlockb := lock("b")
unlocka()
unlockb()
}

func TestGetRootMappingPath(t *testing.T) {
testCases := []struct {
remote string
expectResult string
expectError bool
}{
{
remote: "",
expectResult: "",
expectError: true,
},
{
remote: "hostname",
expectResult: "",
expectError: true,
},
{
remote: "\\\\hostname\\path",
expectResult: "\\\\hostname\\path",
expectError: false,
},
{
remote: "\\\\hostname\\path\\",
expectResult: "\\\\hostname\\path",
expectError: false,
},
{
remote: "\\\\hostname\\path\\subpath",
expectResult: "\\\\hostname\\path",
expectError: false,
},
}
for _, tc := range testCases {
result, err := getRootMappingPath(tc.remote)
if tc.expectError && err == nil {
t.Errorf("Expected error but getRootMappingPath returned a nil error")
}
if !tc.expectError {
if err != nil {
t.Errorf("Expected no errors but getRootMappingPath returned error: %v", err)
}
if tc.expectResult != result {
t.Errorf("Expected (%s) but getRootMappingPath returned (%s)", tc.expectResult, result)
}
}
}
}

func TestRemotePathReferencesCounter(t *testing.T) {
remotePath1 := "\\\\servername\\share\\subpath\\1"
remotePath2 := "\\\\servername\\share\\subpath\\2"
mappingPath, err := getRootMappingPath(remotePath1)
assert.Nil(t, err)

basePath = os.Getenv("TEMP") + "\\TestMappingPathCounter"
os.RemoveAll(basePath)
defer func() {
// cleanup temp folder
os.RemoveAll(basePath)
}()

// by default we have no any files in `mappingPath`. So, `count` should be zero
assert.Zero(t, getRemotePathReferencesCount(mappingPath))
// add reference to `remotePath1`. So, `count` should be equal `1`
assert.Nil(t, incementRemotePathReferencesCount(mappingPath, remotePath1))
assert.Equal(t, 1, getRemotePathReferencesCount(mappingPath))
// add reference to `remotePath2`. So, `count` should be equal `2`
assert.Nil(t, incementRemotePathReferencesCount(mappingPath, remotePath2))
assert.Equal(t, 2, getRemotePathReferencesCount(mappingPath))
// remove reference to `remotePath1`. So, `count` should be equal `1`
assert.Nil(t, decrementRemotePathReferencesCount(mappingPath, remotePath1))
assert.Equal(t, 1, getRemotePathReferencesCount(mappingPath))
// remove reference to `remotePath2`. So, `count` should be equal `0`
assert.Nil(t, decrementRemotePathReferencesCount(mappingPath, remotePath2))
assert.Zero(t, getRemotePathReferencesCount(mappingPath))
}

func TestIncementRemotePathReferencesCount(t *testing.T) {
remotePath := "\\\\servername\\share\\subpath"
mappingPath, err := getRootMappingPath(remotePath)
assert.Nil(t, err)

basePath = os.Getenv("TEMP") + "\\TestMappingPathCounter"
os.RemoveAll(basePath)
defer func() {
// cleanup temp folder
os.RemoveAll(basePath)
}()

assert.Nil(t, incementRemotePathReferencesCount(mappingPath, remotePath))

mappingPathContainer := basePath + "\\servername\\share"
if dir, err := os.Stat(mappingPathContainer); os.IsNotExist(err) || !dir.IsDir() {
t.Error("mapping file container does not exist")
}

reference := mappingPathContainer + "\\" + getMd5(remotePath)
if file, err := os.Stat(reference); os.IsNotExist(err) || file.IsDir() {
t.Error("reference file does not exist")
}
}

func TestDecrementRemotePathReferencesCount(t *testing.T) {
remotePath := "\\\\servername\\share\\subpath"
mappingPath, err := getRootMappingPath(remotePath)
assert.Nil(t, err)

basePath = os.Getenv("TEMP") + "\\TestMappingPathCounter"
os.RemoveAll(basePath)
defer func() {
// cleanup temp folder
os.RemoveAll(basePath)
}()

assert.Nil(t, incementRemotePathReferencesCount(mappingPath, remotePath))
assert.Nil(t, decrementRemotePathReferencesCount(mappingPath, remotePath))

mappingPathContainer := basePath + "\\servername\\share"
if dir, err := os.Stat(mappingPathContainer); os.IsNotExist(err) || !dir.IsDir() {
t.Error("mapping file container does not exist")
}

reference := mappingPathContainer + "\\" + getMd5(remotePath)
if _, err := os.Stat(reference); os.IsExist(err) {
t.Error("reference file exists")
}
}

func TestMultiplyCallsOfIncementRemotePathReferencesCount(t *testing.T) {
remotePath := "\\\\servername\\share\\subpath"
mappingPath, err := getRootMappingPath(remotePath)
assert.Nil(t, err)

basePath = os.Getenv("TEMP") + "\\TestMappingPathCounter"
os.RemoveAll(basePath)
defer func() {
// cleanup temp folder
os.RemoveAll(basePath)
}()

assert.Zero(t, getRemotePathReferencesCount(mappingPath))
assert.Nil(t, incementRemotePathReferencesCount(mappingPath, remotePath))
// next calls of `incementMappingPathCount` with the same arguments should be ignored
assert.Nil(t, incementRemotePathReferencesCount(mappingPath, remotePath))
assert.Nil(t, incementRemotePathReferencesCount(mappingPath, remotePath))
assert.Nil(t, incementRemotePathReferencesCount(mappingPath, remotePath))
assert.Nil(t, incementRemotePathReferencesCount(mappingPath, remotePath))
assert.Equal(t, 1, getRemotePathReferencesCount(mappingPath))
}

func TestMultiplyCallsOfDecrementRemotePathReferencesCount(t *testing.T) {
remotePath := "\\\\servername\\share\\subpath"
mappingPath, err := getRootMappingPath(remotePath)
assert.Nil(t, err)

basePath = os.Getenv("TEMP") + "\\TestMappingPathCounter"
os.RemoveAll(basePath)
defer func() {
// cleanup temp folder
os.RemoveAll(basePath)
}()

assert.Zero(t, getRemotePathReferencesCount(mappingPath))
assert.Nil(t, incementRemotePathReferencesCount(mappingPath, remotePath))
assert.Nil(t, decrementRemotePathReferencesCount(mappingPath, remotePath))
assert.NotNil(t, decrementRemotePathReferencesCount(mappingPath, remotePath))
}
Loading