Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make local files default for fs commands #506

Merged
merged 12 commits into from
Jun 23, 2023
63 changes: 27 additions & 36 deletions cmd/fs/helpers.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,47 +10,38 @@ import (
"github.com/databricks/cli/libs/filer"
)

type Scheme string

const (
DbfsScheme = Scheme("dbfs")
LocalScheme = Scheme("file")
NoScheme = Scheme("")
)

func filerForPath(ctx context.Context, fullPath string) (filer.Filer, string, error) {
parts := strings.SplitN(fullPath, ":/", 2)
// Split path at : to detect any file schemes
parts := strings.SplitN(fullPath, ":", 2)

// If no scheme is specified, then local path
if len(parts) < 2 {
return nil, "", fmt.Errorf(`no scheme specified for path %s. Please specify scheme "dbfs" or "file". Example: file:/foo/bar or file:/c:/foo/bar`, fullPath)
f, err := filer.NewLocalClient("")
return f, fullPath, err
}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is not correct though; if folks specify `foobar:/bla` it shouldn't be interpreted as a local path.

If there is a : in the path they can always use ./foo:bar to specify it and it should work.

As long as the parts[0] is all \w characters, it should be "dbfs". If it isn't, we should return an error.

Otherwise typos on the scheme can cause issues.


// On windows systems, paths start with a drive letter. If the scheme
// is a single letter and the OS is windows, then we conclude the path
// is meant to be a local path.
if runtime.GOOS == "windows" && len(parts[0]) == 1 {
f, err := filer.NewLocalClient("")
return f, fullPath, err
}
scheme := Scheme(parts[0])

if parts[0] != "dbfs" {
return nil, "", fmt.Errorf("invalid scheme: %s", parts[0])
}

path := parts[1]
switch scheme {
case DbfsScheme:
w := root.WorkspaceClient(ctx)
// If the specified path has the "Volumes" prefix, use the Files API.
if strings.HasPrefix(path, "Volumes/") {
f, err := filer.NewFilesClient(w, "/")
return f, path, err
}
f, err := filer.NewDbfsClient(w, "/")
return f, path, err
w := root.WorkspaceClient(ctx)

case LocalScheme:
if runtime.GOOS == "windows" {
parts := strings.SplitN(path, ":", 2)
if len(parts) < 2 {
return nil, "", fmt.Errorf("no volume specfied for path: %s", path)
}
volume := parts[0] + ":"
relPath := parts[1]
f, err := filer.NewLocalClient(volume)
return f, relPath, err
}
f, err := filer.NewLocalClient("/")
// If the specified path has the "Volumes" prefix, use the Files API.
if strings.HasPrefix(path, "/Volumes/") {
f, err := filer.NewFilesClient(w, "/")
return f, path, err

default:
return nil, "", fmt.Errorf(`unsupported scheme %s specified for path %s. Please specify scheme "dbfs" or "file". Example: file:/foo/bar or file:/c:/foo/bar`, scheme, fullPath)
}

// The file is a dbfs file, and uses the DBFS APIs
f, err := filer.NewDbfsClient(w, "/")
return f, path, err
}
56 changes: 46 additions & 10 deletions cmd/fs/helpers_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,22 +5,58 @@ import (
"runtime"
"testing"

"github.com/databricks/cli/libs/filer"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

func TestNotSpecifyingVolumeForWindowsPathErrors(t *testing.T) {
if runtime.GOOS != "windows" {
t.Skip()
}
// TestFilerForPathForLocalPaths checks that a path without a scheme is
// returned unchanged together with a filer that can stat it as a directory.
func TestFilerForPathForLocalPaths(t *testing.T) {
	tmpDir := t.TempDir()
	ctx := context.Background()

	f, path, err := filerForPath(ctx, tmpDir)
	// The filer is used below; abort immediately instead of calling a method
	// on a nil filer if it could not be created.
	require.NoError(t, err)
	assert.Equal(t, tmpDir, path)

	info, err := f.Stat(ctx, path)
	require.NoError(t, err)
	assert.True(t, info.IsDir())
}

func TestFilerForPathForInvalidScheme(t *testing.T) {
ctx := context.Background()
pathWithVolume := `file:/c:/foo/bar`
pathWOVolume := `file:/uno/dos`

_, path, err := filerForPath(ctx, pathWithVolume)
_, _, err := filerForPath(ctx, "dbf:/a")
assert.ErrorContains(t, err, "invalid scheme")

_, _, err = filerForPath(ctx, "foo:a")
assert.ErrorContains(t, err, "invalid scheme")

_, _, err = filerForPath(ctx, "file:/a")
assert.ErrorContains(t, err, "invalid scheme")
}

func testWindowsFilerForPath(t *testing.T, ctx context.Context, fullPath string) {
f, path, err := filerForPath(ctx, fullPath)
assert.NoError(t, err)
assert.Equal(t, `/foo/bar`, path)

_, _, err = filerForPath(ctx, pathWOVolume)
assert.Equal(t, "no volume specfied for path: uno/dos", err.Error())
// Assert path remains unchanged
assert.Equal(t, path, fullPath)

// Assert local client is created
_, ok := f.(*filer.LocalClient)
assert.True(t, ok)
}

// TestFilerForWindowsLocalPaths runs the Windows filer check against a set of
// drive-letter paths. It is skipped on every other operating system.
func TestFilerForWindowsLocalPaths(t *testing.T) {
	if runtime.GOOS != "windows" {
		t.SkipNow()
	}

	ctx := context.Background()

	// Each path is listed once; the original sequence exercised `d:\abc` twice.
	drivePaths := []string{
		`c:\abc`,
		`c:abc`,
		`d:\abc`,
		`f:\abc\ef`,
	}
	for _, p := range drivePaths {
		testWindowsFilerForPath(t, ctx, p)
	}
}
13 changes: 3 additions & 10 deletions internal/fs_cp_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ func setupLocalFiler(t *testing.T) (filer.Filer, string) {
f, err := filer.NewLocalClient(tmp)
require.NoError(t, err)

return f, path.Join("file:/", filepath.ToSlash(tmp))
return f, path.Join(filepath.ToSlash(tmp))
}

func setupDbfsFiler(t *testing.T) (filer.Filer, string) {
Expand Down Expand Up @@ -259,21 +259,14 @@ func TestAccFsCpErrorsWhenSourceIsDirWithoutRecursiveFlag(t *testing.T) {
tmpDir := temporaryDbfsDir(t, w)

_, _, err = RequireErrorRun(t, "fs", "cp", "dbfs:"+tmpDir, "dbfs:/tmp")
assert.Equal(t, fmt.Sprintf("source path %s is a directory. Please specify the --recursive flag", strings.TrimPrefix(tmpDir, "/")), err.Error())
}

func TestAccFsCpErrorsOnNoScheme(t *testing.T) {
t.Log(GetEnvOrSkipTest(t, "CLOUD_ENV"))

_, _, err := RequireErrorRun(t, "fs", "cp", "/a", "/b")
assert.Equal(t, "no scheme specified for path /a. Please specify scheme \"dbfs\" or \"file\". Example: file:/foo/bar or file:/c:/foo/bar", err.Error())
assert.Equal(t, fmt.Sprintf("source path %s is a directory. Please specify the --recursive flag", tmpDir), err.Error())
}

func TestAccFsCpErrorsOnInvalidScheme(t *testing.T) {
t.Log(GetEnvOrSkipTest(t, "CLOUD_ENV"))

_, _, err := RequireErrorRun(t, "fs", "cp", "dbfs:/a", "https:/b")
assert.Equal(t, "unsupported scheme https specified for path https:/b. Please specify scheme \"dbfs\" or \"file\". Example: file:/foo/bar or file:/c:/foo/bar", err.Error())
assert.Equal(t, "invalid scheme: https", err.Error())
}

func TestAccFsCpSourceIsDirectoryButTargetIsFile(t *testing.T) {
Expand Down
4 changes: 2 additions & 2 deletions libs/filer/dbfs_client.go
Original file line number Diff line number Diff line change
Expand Up @@ -68,14 +68,14 @@ type DbfsClient struct {
workspaceClient *databricks.WorkspaceClient

// File operations will be relative to this path.
root RootPath
root WorkspaceRootPath
}

func NewDbfsClient(w *databricks.WorkspaceClient, root string) (Filer, error) {
return &DbfsClient{
workspaceClient: w,

root: NewRootPath(root),
root: NewWorkspaceRootPath(root),
}, nil
}

Expand Down
4 changes: 2 additions & 2 deletions libs/filer/files_client.go
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ type FilesClient struct {
apiClient *client.DatabricksClient

// File operations will be relative to this path.
root RootPath
root WorkspaceRootPath
}

func filesNotImplementedError(fn string) error {
Expand All @@ -77,7 +77,7 @@ func NewFilesClient(w *databricks.WorkspaceClient, root string) (Filer, error) {
workspaceClient: w,
apiClient: apiClient,

root: NewRootPath(root),
root: NewWorkspaceRootPath(root),
}, nil
}

Expand Down
4 changes: 2 additions & 2 deletions libs/filer/local_client.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,12 @@ import (
// LocalClient implements the [Filer] interface for the local filesystem.
type LocalClient struct {
// File operations will be relative to this path.
root RootPath
root localRootPath
}

func NewLocalClient(root string) (Filer, error) {
return &LocalClient{
root: NewRootPath(root),
root: NewLocalRootPath(root),
}, nil
}

Expand Down
27 changes: 27 additions & 0 deletions libs/filer/local_root_path.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
package filer

import (
"fmt"
"path/filepath"
"strings"
)

// localRootPath is a root directory on the local filesystem that other paths
// are joined onto. An empty rootPath means no root is enforced.
type localRootPath struct {
	rootPath string
}

// NewLocalRootPath returns a localRootPath for root. A non-empty root is
// normalized with filepath.Clean; an empty root is kept empty, because
// cleaning it would turn it into ".".
func NewLocalRootPath(root string) localRootPath {
	if root == "" {
		return localRootPath{""}
	}
	return localRootPath{filepath.Clean(root)}
}

// Join joins name onto the root and returns the cleaned result.
//
// It returns an error if the result is neither the root itself nor located
// below it. With an empty root every name is accepted.
func (rp *localRootPath) Join(name string) (string, error) {
	absPath := filepath.Join(rp.rootPath, name)

	// No root to enforce, or the result is exactly the root.
	if rp.rootPath == "" || absPath == rp.rootPath {
		return absPath, nil
	}

	// Compare on a path-separator boundary. A bare string prefix check would
	// accept siblings that merely share the root's name as a prefix, e.g.
	// "/some/rootsibling" for the root "/some/root".
	sep := string(filepath.Separator)
	prefix := rp.rootPath
	if !strings.HasSuffix(prefix, sep) {
		prefix += sep
	}
	if !strings.HasPrefix(absPath, prefix) {
		return "", fmt.Errorf("relative path escapes root: %s", name)
	}
	return absPath, nil
}
shreyas-goenka marked this conversation as resolved.
Show resolved Hide resolved
142 changes: 142 additions & 0 deletions libs/filer/local_root_path_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,142 @@
package filer

import (
"path/filepath"
"runtime"
"testing"

"github.com/stretchr/testify/assert"
)

// testUnixLocalRootPath builds a local root path from uncleanRoot and checks
// Join against slash-separated names: first names expected to resolve to the
// cleaned root or below it, then names expected to be rejected as escaping it.
func testUnixLocalRootPath(t *testing.T, uncleanRoot string) {
	cleanRoot := filepath.Clean(uncleanRoot)
	rp := NewLocalRootPath(uncleanRoot)

	// name is the argument to Join; suffix is what the result must carry
	// after the cleaned root.
	accepted := []struct {
		name   string
		suffix string
	}{
		{"a/b/c", "/a/b/c"},
		{"a/b/../d", "/a/d"},
		{"a/../c", "/c"},
		{"a/b/c/.", "/a/b/c"},
		{"a/b/c/d/./../../f/g", "/a/b/f/g"},
		{".//a/..//./b/..", ""},
		{"a/b/../..", ""},
		{"", ""},
		{".", ""},
		{"/", ""},
	}
	for _, tc := range accepted {
		joined, err := rp.Join(tc.name)
		assert.NoError(t, err)
		assert.Equal(t, cleanRoot+tc.suffix, joined)
	}

	// Each of these must fail with an error message that quotes the name.
	rejected := []string{
		"..",
		"a/../..",
		"./../.",
		"/./.././..",
		"./../.",
		"./..",
		"./../../..",
		"./../a/./b../../..",
		"../..",
	}
	for _, name := range rejected {
		_, err := rp.Join(name)
		assert.ErrorContains(t, err, `relative path escapes root: `+name)
	}
}

// TestUnixLocalRootPath runs the unix root path checks against several
// spellings of a root directory. It only runs on darwin and linux.
func TestUnixLocalRootPath(t *testing.T) {
	switch runtime.GOOS {
	case "darwin", "linux":
		// Supported platform: run the checks below.
	default:
		t.SkipNow()
	}

	roots := []string{
		"/some/root/path",
		"/some/root/path/",
		"/some/root/path/.",
		"/some/root/../path/",
	}
	for _, root := range roots {
		testUnixLocalRootPath(t, root)
	}
}

// testWindowsLocalRootPath builds a local root path from uncleanRoot and
// checks Join against backslash-separated names: first names expected to
// resolve to the cleaned root or below it, then names expected to be rejected.
func testWindowsLocalRootPath(t *testing.T, uncleanRoot string) {
	cleanRoot := filepath.Clean(uncleanRoot)
	rp := NewLocalRootPath(uncleanRoot)

	// name is the argument to Join; suffix is what the result must carry
	// after the cleaned root.
	accepted := []struct {
		name   string
		suffix string
	}{
		{`a\b\c`, `\a\b\c`},
		{`a\b\..\d`, `\a\d`},
		{`a\..\c`, `\c`},
		{`a\b\c\.`, `\a\b\c`},
		{`a\b\..\..`, ``},
		{"", ``},
		{".", ``},
	}
	for _, tc := range accepted {
		joined, err := rp.Join(tc.name)
		assert.NoError(t, err)
		assert.Equal(t, cleanRoot+tc.suffix, joined)
	}

	// Each of these must be reported as escaping the root.
	rejected := []string{"..", `a\..\..`}
	for _, name := range rejected {
		_, err := rp.Join(name)
		assert.ErrorContains(t, err, `relative path escapes root`)
	}
}

// TestWindowsLocalRootPath runs the Windows root path checks against several
// spellings of a root directory. It is skipped on every other operating system.
func TestWindowsLocalRootPath(t *testing.T) {
	if onWindows := runtime.GOOS == "windows"; !onWindows {
		t.SkipNow()
	}

	roots := []string{
		`c:\some\root\path`,
		`c:\some\root\path\`,
		`c:\some\root\path\.`,
		`C:\some\root\..\path\`,
	}
	for _, root := range roots {
		testWindowsLocalRootPath(t, root)
	}
}
Loading