diff --git a/foo/caller_not_allowed_test.go b/foo/caller_not_allowed_test.go new file mode 100644 index 00000000..208fad71 --- /dev/null +++ b/foo/caller_not_allowed_test.go @@ -0,0 +1,15 @@ +package foo + +import ( + "errors" + "testing" + + "knative.dev/hack/shell" +) + +func TestCallerNotAllowed(t *testing.T) { + _, err := shell.NewProjectLocation("..") + if !errors.Is(err, shell.ErrCallerNotAllowed) { + t.Error("usage should be blocked") + } +} diff --git a/shell/project.go b/shell/project.go index 2d7ba973..f217fab2 100644 --- a/shell/project.go +++ b/shell/project.go @@ -18,12 +18,21 @@ package shell import ( "errors" + "fmt" "path" + "regexp" "runtime" ) -// ErrCantGetCaller is raised when we can't calculate a caller of NewProjectLocation. -var ErrCantGetCaller = errors.New("can't get caller") +var ( + // ErrCantGetCaller is raised when we can't calculate a caller of NewProjectLocation. + ErrCantGetCaller = errors.New("can't get caller") + + // ErrCallerNotAllowed is raised when user tries to use this shell-out package + // outside of allowed places. This package is deprecated from start and was + // introduced to allow rewriting of shell code to Golang in small chunks. + ErrCallerNotAllowed = errors.New("don't try use knative.dev/hack/shell package outside of allowed places") +) // NewProjectLocation creates a ProjectLocation that is used to calculate // relative paths within the project. @@ -32,6 +41,10 @@ func NewProjectLocation(pathToRoot string) (ProjectLocation, error) { if !ok { return nil, ErrCantGetCaller } + err := ensureIsValid(filename) + if err != nil { + return nil, err + } return &callerLocation{ caller: filename, pathToRoot: pathToRoot, @@ -50,3 +63,17 @@ type callerLocation struct { caller string pathToRoot string } + +func ensureIsValid(filename string) error { + validPaths := []string{ + "knative.+/test/upgrade/", + "knative[/-]hack/shell/", + } + for _, validPath := range validPaths { + r := regexp.MustCompile(validPath) + if loc := r.FindStringIndex(filename); loc != nil { + return nil + } + } + return fmt.Errorf("%w: %s", ErrCallerNotAllowed, filename) +}