Skip to content

Commit

Permalink
feat: add removeSMBMappingDuringUnmount in chart config
Browse files Browse the repository at this point in the history
fix

fix ut
  • Loading branch information
andyzhangx committed Aug 14, 2022
1 parent eb9ddbc commit e767c59
Show file tree
Hide file tree
Showing 12 changed files with 68 additions and 61 deletions.
1 change: 1 addition & 0 deletions charts/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,7 @@ The following table lists the configurable parameters of the latest SMB CSI Driv
| `linux.resources.smb.requests.memory` | smb-csi-driver memory requests limits | `20Mi`
| `windows.enabled` | whether enable windows feature | `false` |
| `windows.dsName` | name of driver daemonset on windows | `csi-smb-node-win` |
| `windows.removeSMBMappingDuringUnmount` | remove SMBMapping during unmount on Windows node windows | `true` |
| `windows.resources.livenessProbe.limits.memory` | liveness-probe memory limits | `200Mi` |
| `windows.resources.livenessProbe.requests.cpu` | liveness-probe cpu requests limits | `10m` |
| `windows.resources.livenessProbe.requests.memory` | liveness-probe memory requests limits | `20Mi` |
Expand Down
Binary file modified charts/latest/csi-driver-smb-v0.0.0.tgz
Binary file not shown.
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,7 @@ spec:
- --nodeid=$(KUBE_NODE_NAME)
- "--metrics-address=0.0.0.0:{{ .Values.node.metricsPort }}"
- "--enable-get-volume-stats={{ .Values.feature.enableGetVolumeStats }}"
- "--remove-smb-mapping-during-unmount={{ .Values.windows.removeSMBMappingDuringUnmount }}"
ports:
- containerPort: {{ .Values.node.livenessProbe.healthPort }}
name: healthz
Expand Down
1 change: 1 addition & 0 deletions charts/latest/csi-driver-smb/values.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,7 @@ windows:
enabled: false
dsName: csi-smb-node-win # daemonset name
kubelet: 'C:\var\lib\kubelet'
removeSMBMappingDuringUnmount: true
tolerations:
- key: "node.kubernetes.io/os"
operator: "Exists"
Expand Down
26 changes: 14 additions & 12 deletions cmd/smbplugin/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,14 +36,15 @@ func init() {
}

var (
endpoint = flag.String("endpoint", "unix://tmp/csi.sock", "CSI endpoint")
nodeID = flag.String("nodeid", "", "node id")
driverName = flag.String("drivername", smb.DefaultDriverName, "name of the driver")
ver = flag.Bool("ver", false, "Print the version and exit.")
metricsAddress = flag.String("metrics-address", "0.0.0.0:29644", "export the metrics")
kubeconfig = flag.String("kubeconfig", "", "Absolute path to the kubeconfig file. Required only when running out of cluster.")
enableGetVolumeStats = flag.Bool("enable-get-volume-stats", true, "allow GET_VOLUME_STATS on agent node")
workingMountDir = flag.String("working-mount-dir", "/tmp", "working directory for provisioner to mount smb shares temporarily")
endpoint = flag.String("endpoint", "unix://tmp/csi.sock", "CSI endpoint")
nodeID = flag.String("nodeid", "", "node id")
driverName = flag.String("drivername", smb.DefaultDriverName, "name of the driver")
ver = flag.Bool("ver", false, "Print the version and exit.")
metricsAddress = flag.String("metrics-address", "0.0.0.0:29644", "export the metrics")
kubeconfig = flag.String("kubeconfig", "", "Absolute path to the kubeconfig file. Required only when running out of cluster.")
enableGetVolumeStats = flag.Bool("enable-get-volume-stats", true, "allow GET_VOLUME_STATS on agent node")
removeSMBMappingDuringUnmount = flag.Bool("remove-smb-mapping-during-unmount", true, "remove SMBMapping during unmount on Windows node")
workingMountDir = flag.String("working-mount-dir", "/tmp", "working directory for provisioner to mount smb shares temporarily")
)

func main() {
Expand All @@ -67,10 +68,11 @@ func main() {

func handle() {
driverOptions := smb.DriverOptions{
NodeID: *nodeID,
DriverName: *driverName,
EnableGetVolumeStats: *enableGetVolumeStats,
WorkingMountDir: *workingMountDir,
NodeID: *nodeID,
DriverName: *driverName,
EnableGetVolumeStats: *enableGetVolumeStats,
RemoveSMBMappingDuringUnmount: *removeSMBMappingDuringUnmount,
WorkingMountDir: *workingMountDir,
}
driver := smb.NewDriver(&driverOptions)
driver.Run(*endpoint, *kubeconfig, false)
Expand Down
3 changes: 2 additions & 1 deletion pkg/csi-common/server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,9 @@ func TestNewNonBlockingGRPCServer(t *testing.T) {
func TestStart(t *testing.T) {
s := NewNonBlockingGRPCServer()
// sleep a while to avoid race condition in unit test
time.Sleep(time.Millisecond * 2000)
time.Sleep(time.Millisecond * 500)
s.Start("tcp://127.0.0.1:0", nil, nil, nil, true)
time.Sleep(time.Millisecond * 500)
}

func TestServe(t *testing.T) {
Expand Down
2 changes: 1 addition & 1 deletion pkg/mounter/safe_mounter_unix.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ import (
utilexec "k8s.io/utils/exec"
)

func NewSafeMounter() (*mount.SafeFormatAndMount, error) {
func NewSafeMounter(removeSMBMappingDuringUnmount bool) (*mount.SafeFormatAndMount, error) {
return &mount.SafeFormatAndMount{
Interface: mount.New(""),
Exec: utilexec.New(),
Expand Down
2 changes: 1 addition & 1 deletion pkg/mounter/safe_mounter_unix_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ import (
)

func TestNewSafeMounter(t *testing.T) {
resp, err := NewSafeMounter()
resp, err := NewSafeMounter(true)
assert.NotNil(t, resp)
assert.Nil(t, err)
}
76 changes: 36 additions & 40 deletions pkg/mounter/safe_mounter_windows.go
Original file line number Diff line number Diff line change
Expand Up @@ -55,8 +55,9 @@ type CSIProxyMounter interface {
var _ CSIProxyMounter = &csiProxyMounter{}

type csiProxyMounter struct {
FsClient *fsclient.Client
SMBClient *smbclient.Client
FsClient *fsclient.Client
SMBClient *smbclient.Client
removeSMBMappingDuringUnmount bool
}

func normalizeWindowsPath(path string) string {
Expand Down Expand Up @@ -101,12 +102,9 @@ func (mounter *csiProxyMounter) SMBMount(source, target, fsType string, mountOpt
}

source = strings.Replace(source, "/", "\\", -1)
if strings.HasSuffix(source, "\\") {
source = strings.TrimSuffix(source, "\\")
}

source = strings.TrimSuffix(source, "\\")
mappingPath, err := getRootMappingPath(source)
if err != nil {
if mounter.removeSMBMappingDuringUnmount && err != nil {
return fmt.Errorf("getRootMappingPath(%s) failed with error: %v", source, err)
}
unlock := lock(mappingPath)
Expand All @@ -119,16 +117,17 @@ func (mounter *csiProxyMounter) SMBMount(source, target, fsType string, mountOpt
Username: mountOptions[0],
Password: sensitiveMountOptions[0],
}
klog.V(2).Infof("begin to mount %s on %s", source, normalizedTarget)
klog.V(2).Infof("begin to NewSmbGlobalMapping %s on %s", source, normalizedTarget)
if _, err := mounter.SMBClient.NewSmbGlobalMapping(context.Background(), smbMountRequest); err != nil {
return fmt.Errorf("smb mapping failed with error: %v", err)
return fmt.Errorf("NewSmbGlobalMapping(%s, %s) failed with error: %v", source, normalizedTarget, err)
}
klog.V(2).Infof("mount %s on %s successfully", source, normalizedTarget)
klog.V(2).Infof("NewSmbGlobalMapping %s on %s successfully", source, normalizedTarget)

if err = incementRemotePathReferencesCount(mappingPath, source); err != nil {
klog.Warningf("incementMappingPathCount(%s, %s) failed with error: %v", mappingPath, source, err)
if mounter.removeSMBMappingDuringUnmount {
if err := incementRemotePathReferencesCount(mappingPath, source); err != nil {
return fmt.Errorf("incementMappingPathCount(%s, %s) failed with error: %v", mappingPath, source, err)
}
}

return nil
}

Expand All @@ -138,35 +137,32 @@ func (mounter *csiProxyMounter) SMBUnmount(target string) error {
if remotePath, err := os.Readlink(target); err != nil {
klog.Warningf("SMBUnmount: can't get remote path: %v", err)
} else {
if strings.HasSuffix(remotePath, "\\") {
remotePath = strings.TrimSuffix(remotePath, "\\")
}
remotePath = strings.TrimSuffix(remotePath, "\\")
mappingPath, err := getRootMappingPath(remotePath)
if err != nil {
klog.Warningf("getRootMappingPath(%s) failed with error: %v", remotePath, err)
} else {
klog.V(4).Infof("SMBUnmount: remote path: %s, mapping path: %s", remotePath, mappingPath)
if mounter.removeSMBMappingDuringUnmount && err != nil {
return fmt.Errorf("getRootMappingPath(%s) failed with error: %v", remotePath, err)
}
klog.V(4).Infof("SMBUnmount: remote path: %s, mapping path: %s", remotePath, mappingPath)

unlock := lock(mappingPath)
defer unlock()
unlock := lock(mappingPath)
defer unlock()

if mounter.removeSMBMappingDuringUnmount {
if err := decrementRemotePathReferencesCount(mappingPath, remotePath); err != nil {
klog.Warningf("decrementMappingPathCount(%s, %d) failed with error: %v", mappingPath, remotePath, err)
} else {
count := getRemotePathReferencesCount(mappingPath)
if count == 0 {
smbUnmountRequest := &smb.RemoveSmbGlobalMappingRequest{
RemotePath: remotePath,
}
klog.V(2).Infof("begin to unmount %s on %s", remotePath, target)
if _, err := mounter.SMBClient.RemoveSmbGlobalMapping(context.Background(), smbUnmountRequest); err != nil {
return fmt.Errorf("smb unmapping failed with error: %v", err)
} else {
klog.V(2).Infof("unmount %s on %s successfully", remotePath, target)
}
} else {
klog.Infof("SMBUnmount: found %f links to %s", count, mappingPath)
return fmt.Errorf("decrementMappingPathCount(%s, %s) failed with error: %v", mappingPath, remotePath, err)
}
count := getRemotePathReferencesCount(mappingPath)
if count == 0 {
smbUnmountRequest := &smb.RemoveSmbGlobalMappingRequest{
RemotePath: remotePath,
}
klog.V(2).Infof("begin to RemoveSmbGlobalMapping %s on %s", remotePath, target)
if _, err := mounter.SMBClient.RemoveSmbGlobalMapping(context.Background(), smbUnmountRequest); err != nil {
return fmt.Errorf("RemoveSmbGlobalMapping failed with error: %v", err)
}
klog.V(2).Infof("RemoveSmbGlobalMapping %s on %s successfully", remotePath, target)
} else {
klog.Infof("SMBUnmount: found %d links to %s", count, mappingPath)
}
}
}
Expand Down Expand Up @@ -342,7 +338,7 @@ func (mounter *csiProxyMounter) MountSensitiveWithoutSystemdWithMountFlags(sourc

// NewCSIProxyMounter - creates a new CSI Proxy mounter struct which encompassed all the
// clients to the CSI proxy - filesystem, disk and volume clients.
func NewCSIProxyMounter() (*csiProxyMounter, error) {
func NewCSIProxyMounter(removeSMBMappingDuringUnmount bool) (*csiProxyMounter, error) {
fsClient, err := fsclient.NewClient()
if err != nil {
return nil, err
Expand All @@ -358,8 +354,8 @@ func NewCSIProxyMounter() (*csiProxyMounter, error) {
}, nil
}

func NewSafeMounter() (*mount.SafeFormatAndMount, error) {
csiProxyMounter, err := NewCSIProxyMounter()
func NewSafeMounter(removeSMBMappingDuringUnmount bool) (*mount.SafeFormatAndMount, error) {
csiProxyMounter, err := NewCSIProxyMounter(removeSMBMappingDuringUnmount)
if err == nil {
klog.V(2).Infof("using CSIProxyMounterV1, %s", csiProxyMounter.GetAPIVersions())
return &mount.SafeFormatAndMount{
Expand Down
2 changes: 1 addition & 1 deletion pkg/smb/fake_mounter.go
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ func (f *fakeMounter) IsLikelyNotMountPoint(file string) (bool, error) {

func NewFakeMounter() (*mount.SafeFormatAndMount, error) {
if runtime.GOOS == "windows" {
return mounter.NewSafeMounter()
return mounter.NewSafeMounter(true)
}
return &mount.SafeFormatAndMount{
Interface: &fakeMounter{},
Expand Down
6 changes: 3 additions & 3 deletions pkg/smb/nodeserver_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -153,7 +153,7 @@ func TestNodeStageVolume(t *testing.T) {
VolumeCapability: &stdVolCap,
VolumeContext: volContext,
Secrets: secrets},
flakyWindowsErrorMessage: fmt.Sprintf("volume(vol_1##) mount \"%s\" on %#v failed "+
flakyWindowsErrorMessage: fmt.Sprintf("rpc error: code = Internal desc = volume(vol_1##) mount \"%s\" on %#v failed "+
"with smb mapping failed with error: rpc error: code = Unknown desc = NewSmbGlobalMapping failed.",
strings.Replace(testSource, "\\", "\\\\", -1), errorMountSensSource),
expectedErr: testutil.TestError{
Expand All @@ -169,7 +169,7 @@ func TestNodeStageVolume(t *testing.T) {
VolumeCapability: &stdVolCap,
VolumeContext: volContext,
Secrets: secrets},
flakyWindowsErrorMessage: fmt.Sprintf("volume(vol_1##) mount \"%s\" on %#v failed with "+
flakyWindowsErrorMessage: fmt.Sprintf("rpc error: code = Internal desc = volume(vol_1##) mount \"%s\" on %#v failed with "+
"smb mapping failed with error: rpc error: code = Unknown desc = NewSmbGlobalMapping failed.",
strings.Replace(testSource, "\\", "\\\\", -1), sourceTest),
expectedErr: testutil.TestError{},
Expand All @@ -180,7 +180,7 @@ func TestNodeStageVolume(t *testing.T) {
VolumeCapability: &stdVolCap,
VolumeContext: volContextWithMetadata,
Secrets: secrets},
flakyWindowsErrorMessage: fmt.Sprintf("volume(vol_1##) mount \"%s\" on %#v failed with "+
flakyWindowsErrorMessage: fmt.Sprintf("rpc error: code = Internal desc = volume(vol_1##) mount \"%s\" on %#v failed with "+
"smb mapping failed with error: rpc error: code = Unknown desc = NewSmbGlobalMapping failed.",
strings.Replace(testSource, "\\", "\\\\", -1), sourceTest),
expectedErr: testutil.TestError{},
Expand Down
9 changes: 7 additions & 2 deletions pkg/smb/smb.go
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,9 @@ type DriverOptions struct {
NodeID string
DriverName string
EnableGetVolumeStats bool
WorkingMountDir string
// this only applies to Windows node
RemoveSMBMappingDuringUnmount bool
WorkingMountDir string
}

// Driver implements all interfaces of CSI drivers
Expand All @@ -62,6 +64,8 @@ type Driver struct {
volumeLocks *volumeLocks
workingMountDir string
enableGetVolumeStats bool
// this only applies to Windows node
removeSMBMappingDuringUnmount bool
}

// NewDriver Creates a NewCSIDriver object. Assumes vendor version is equal to driver version &
Expand All @@ -72,6 +76,7 @@ func NewDriver(options *DriverOptions) *Driver {
driver.Version = driverVersion
driver.NodeID = options.NodeID
driver.enableGetVolumeStats = options.EnableGetVolumeStats
driver.removeSMBMappingDuringUnmount = options.RemoveSMBMappingDuringUnmount
driver.workingMountDir = options.WorkingMountDir
driver.volumeLocks = newVolumeLocks()
return &driver
Expand All @@ -85,7 +90,7 @@ func (d *Driver) Run(endpoint, kubeconfig string, testMode bool) {
}
klog.V(2).Infof("\nDRIVER INFORMATION:\n-------------------\n%s\n\nStreaming logs below:", versionMeta)

d.mounter, err = mounter.NewSafeMounter()
d.mounter, err = mounter.NewSafeMounter(d.removeSMBMappingDuringUnmount)
if err != nil {
klog.Fatalf("Failed to get safe mounter. Error: %v", err)
}
Expand Down

0 comments on commit e767c59

Please sign in to comment.