From acccdf1271d6b76ca7a01777e4a93e81e6ea1ed6 Mon Sep 17 00:00:00 2001 From: Lee ByeongJun Date: Tue, 8 Oct 2024 12:04:33 +0900 Subject: [PATCH] feat: Deprecated Function Checker (#85) # Description Add a deprecated functon checker that identifies the usage and suggests a alternative functions if possible. ## Implementation Details The core of this implementation is the `DeprecatedFuncChecker` struct type, which maintains a map of deprecated functions along with their alternative functions. ### How Deprecated Functions are Recognized 1. **Registration**: Deprecated functions are registered using thr `Register` method, which takes three parameters - Package name - Function name - Alternative function (`packageName.funcName` format) 2. **Function Call Identification**: For each AST node, the checker looks for `*ast.CallExpr` nodes, which represent function calls. 3. **Selector Expression Analysis**: It then checks if the function call is a selector expression (`*ast.SelectorExpr`), which is typical for `package.Function()` calls. 4. **Package and Function Matching**: The checker extracts the package name and function name from the selector expression and compares them against the registered deprecated functions. ## Usage Example ```go checker := NewDeprecatedFuncChecker() checker.Register("fmt", "Println", "fmt.Print") checker.Register("os", "Remove", "os.RemoveAll") deprecated, err := checker.Check(filename, astNode, fileSet) if err != nil { // Handle error } for _, dep := range deprecated { fmt.Printf("Deprecated function %s.%s used. Consider using %s instead.\n", dep.Package, dep.Function, dep.Alternative) } ``` --- formatter/builder.go | 75 +++++---- formatter/deprecated.go | 18 +++ formatter/formatter_test.go | 8 +- internal/checker/deprecate.go | 121 ++++++++++++++ internal/checker/deprecate_test.go | 250 +++++++++++++++++++++++++++++ internal/engine.go | 1 + internal/lints/deprecate_func.go | 51 ++++++ internal/rule_set.go | 10 ++ testdata/deprecate/orig_caller.gno | 15 ++ 9 files changed, 517 insertions(+), 32 deletions(-) create mode 100644 formatter/deprecated.go create mode 100644 internal/checker/deprecate.go create mode 100644 internal/checker/deprecate_test.go create mode 100644 internal/lints/deprecate_func.go create mode 100644 testdata/deprecate/orig_caller.gno diff --git a/formatter/builder.go b/formatter/builder.go index 05345f7..ec4d25c 100644 --- a/formatter/builder.go +++ b/formatter/builder.go @@ -3,6 +3,7 @@ package formatter import ( "fmt" "strings" + "unicode" "github.com/fatih/color" "github.com/gnolang/tlin/internal" @@ -19,6 +20,7 @@ const ( SliceBound = "slice-bounds-check" Defers = "defer-issues" MissingModPackage = "gno-mod-tidy" + DeprecatedFunc = "deprecated" ) const tabWidth = 8 @@ -58,6 +60,8 @@ func GenerateFormattedIssue(issues []tt.Issue, snippet *internal.SourceCode) str // If no specific formatter is found for the given rule, it returns a GeneralIssueFormatter. func getFormatter(rule string) IssueFormatter { switch rule { + case DeprecatedFunc: + return &DeprecatedFuncFormatter{} case EarlyReturn: return &EarlyReturnOpportunityFormatter{} case SimplifySliceExpr: @@ -146,7 +150,7 @@ func (b *IssueFormatterBuilder) AddCodeSnippet() *IssueFormatterBuilder { continue } - line := expandTabs(b.snippet.Lines[i-1]) + line := b.snippet.Lines[i-1] line = strings.TrimPrefix(line, commonIndent) lineNum := fmt.Sprintf("%*d", maxLineNumWidth, i) @@ -213,7 +217,8 @@ func (b *IssueFormatterBuilder) AddSuggestion() *IssueFormatterBuilder { suggestionLines := strings.Split(b.issue.Suggestion, "\n") for i, line := range suggestionLines { lineNum := fmt.Sprintf("%*d", maxLineNumWidth, b.issue.Start.Line+i) - b.result.WriteString(lineStyle.Sprintf("%s | %s\n", lineNum, line)) + b.result.WriteString(lineStyle.Sprintf("%s | ", lineNum)) + b.result.WriteString(line + "\n") } b.result.WriteString(lineStyle.Sprintf("%s|\n", padding)) @@ -244,21 +249,6 @@ func calculateMaxLineNumWidth(endLine int) int { return len(fmt.Sprintf("%d", endLine)) } -// expandTabs replaces tab characters('\t') with spaces. -// Assuming a table width of 8. -func expandTabs(line string) string { - var expanded strings.Builder - for i, ch := range line { - if ch == '\t' { - spaceCount := tabWidth - (i % tabWidth) - expanded.WriteString(strings.Repeat(" ", spaceCount)) - } else { - expanded.WriteRune(ch) - } - } - return expanded.String() -} - // calculateVisualColumn calculates the visual column position // in a string. taking into account tab characters. func calculateVisualColumn(line string, column int) int { @@ -279,26 +269,55 @@ func calculateVisualColumn(line string, column int) int { return visualColumn } +// findCommonIndent finds the common indent in the code snippet. func findCommonIndent(lines []string) string { if len(lines) == 0 { return "" } - commonIndentPrefix := strings.TrimLeft(lines[0], " \t") - commonIndentPrefix = lines[0][:len(lines[0])-len(commonIndentPrefix)] + // find first non-empty line's indent + var firstIndent []rune + for _, line := range lines { + // trimmed := strings.TrimSpace(line) + trimmed := strings.TrimLeftFunc(line, unicode.IsSpace) + if trimmed != "" { + firstIndent = []rune(line[:len(line)-len(trimmed)]) + break + } + } - for _, line := range lines[1:] { - if strings.TrimSpace(line) == "" { - continue // ignore empty lines + if len(firstIndent) == 0 { + return "" + } + + // search common indent for all non-empty lines + for _, line := range lines { + trimmed := strings.TrimLeftFunc(line, unicode.IsSpace) + if trimmed == "" { + continue } - for !strings.HasPrefix(line, commonIndentPrefix) { - commonIndentPrefix = commonIndentPrefix[:len(commonIndentPrefix)-1] - if len(commonIndentPrefix) == 0 { - return "" - } + currentIndent := []rune(line[:len(line)-len(trimmed)]) + firstIndent = commonPrefix(firstIndent, currentIndent) + + if len(firstIndent) == 0 { + break } } - return commonIndentPrefix + return string(firstIndent) +} + +// commonPrefix finds the common prefix of two strings. +func commonPrefix(a, b []rune) []rune { + minLen := len(a) + if len(b) < minLen { + minLen = len(b) + } + for i := 0; i < minLen; i++ { + if a[i] != b[i] { + return a[:i] + } + } + return a[:minLen] } diff --git a/formatter/deprecated.go b/formatter/deprecated.go new file mode 100644 index 0000000..4f89c25 --- /dev/null +++ b/formatter/deprecated.go @@ -0,0 +1,18 @@ +package formatter + +import ( + "github.com/gnolang/tlin/internal" + tt "github.com/gnolang/tlin/internal/types" +) + +type DeprecatedFuncFormatter struct{} + +func (f *DeprecatedFuncFormatter) Format(issue tt.Issue, snippet *internal.SourceCode) string { + builder := NewIssueFormatterBuilder(issue, snippet) + return builder. + AddHeader(errorHeader). + AddCodeSnippet(). + AddUnderlineAndMessage(). + AddNote(). + Build() +} diff --git a/formatter/formatter_test.go b/formatter/formatter_test.go index 84579d5..ec3c046 100644 --- a/formatter/formatter_test.go +++ b/formatter/formatter_test.go @@ -169,7 +169,7 @@ func TestFormatIssuesWithArrows_UnnecessaryElse(t *testing.T) { "package main", "", "func unnecessaryElse() bool {", - " if condition {", + " if condition {", " return true", " } else {", " return false", @@ -280,9 +280,9 @@ func TestFindCommonIndent(t *testing.T) { { name: "tab indent", lines: []string{ - "\tif foo {", - "\t\tprintln()", - "\t}", + " if foo {", + " println()", + " }", }, expected: "\t", }, diff --git a/internal/checker/deprecate.go b/internal/checker/deprecate.go new file mode 100644 index 0000000..06c8600 --- /dev/null +++ b/internal/checker/deprecate.go @@ -0,0 +1,121 @@ +package checker + +import ( + "go/ast" + "go/token" + "strconv" + "strings" +) + +// pkgPath -> funcName -> alternative +type deprecatedFuncMap map[string]map[string]string + +// DeprecatedFunc represents a deprecated function. +type DeprecatedFunc struct { + Package string + Function string + Alternative string + Position token.Position +} + +// DeprecatedFuncChecker checks for deprecated functions. +type DeprecatedFuncChecker struct { + deprecatedFuncs deprecatedFuncMap +} + +func NewDeprecatedFuncChecker() *DeprecatedFuncChecker { + return &DeprecatedFuncChecker{ + deprecatedFuncs: make(deprecatedFuncMap), + } +} + +func (d *DeprecatedFuncChecker) Register(pkgName, funcName, alternative string) { + if _, ok := d.deprecatedFuncs[pkgName]; !ok { + d.deprecatedFuncs[pkgName] = make(map[string]string) + } + d.deprecatedFuncs[pkgName][funcName] = alternative +} + +// Check checks a AST node for deprecated functions. +// +// TODO: use this in the linter rule implementation +func (d *DeprecatedFuncChecker) Check( + filename string, + node *ast.File, + fset *token.FileSet, +) ([]DeprecatedFunc, error) { + var found []DeprecatedFunc + + packageAliases := make(map[string]string) + for _, imp := range node.Imports { + path, err := strconv.Unquote(imp.Path.Value) + if err != nil { + continue + } + name := "" + if imp.Name != nil { + name = imp.Name.Name + } else { + parts := strings.Split(path, "/") + name = parts[len(parts)-1] + } + packageAliases[name] = path + } + + ast.Inspect(node, func(n ast.Node) bool { + call, ok := n.(*ast.CallExpr) + if !ok { + return true + } + + switch fun := call.Fun.(type) { + case *ast.SelectorExpr: + ident, ok := fun.X.(*ast.Ident) + if !ok { + return true + } + pkgAlias := ident.Name + funcName := fun.Sel.Name + + pkgPath, ok := packageAliases[pkgAlias] + if !ok { + // Not a package alias, possibly a method call + return true + } + + if funcs, ok := d.deprecatedFuncs[pkgPath]; ok { + if alt, ok := funcs[funcName]; ok { + found = append(found, DeprecatedFunc{ + Package: pkgPath, + Function: funcName, + Alternative: alt, + Position: fset.Position(call.Pos()), + }) + } + } + case *ast.Ident: + // Handle functions imported via dot imports + funcName := fun.Name + // Check dot-imported packages + for alias, pkgPath := range packageAliases { + if alias != "." { + continue + } + if funcs, ok := d.deprecatedFuncs[pkgPath]; ok { + if alt, ok := funcs[funcName]; ok { + found = append(found, DeprecatedFunc{ + Package: pkgPath, + Function: funcName, + Alternative: alt, + Position: fset.Position(call.Pos()), + }) + break + } + } + } + } + return true + }) + + return found, nil +} diff --git a/internal/checker/deprecate_test.go b/internal/checker/deprecate_test.go new file mode 100644 index 0000000..bc1a491 --- /dev/null +++ b/internal/checker/deprecate_test.go @@ -0,0 +1,250 @@ +package checker + +import ( + "go/parser" + "go/token" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestRegisterDeprecatedFunctions(t *testing.T) { + t.Parallel() + checker := NewDeprecatedFuncChecker() + + checker.Register("fmt", "Println", "fmt.Print") + checker.Register("os", "Remove", "os.RemoveAll") + + expected := deprecatedFuncMap{ + "fmt": {"Println": "fmt.Print"}, + "os": {"Remove": "os.RemoveAll"}, + } + + assert.Equal(t, expected, checker.deprecatedFuncs) +} + +func TestCheck(t *testing.T) { + t.Parallel() + src := ` +package main + +import ( + "fmt" + "os" +) + +func main() { + fmt.Println("Hello, World!") + os.Remove("some_file.txt") +} +` + + fset := token.NewFileSet() + node, err := parser.ParseFile(fset, "example.go", src, 0) + if err != nil { + t.Fatalf("Failed to parse file: %v", err) + } + + checker := NewDeprecatedFuncChecker() + checker.Register("fmt", "Println", "fmt.Print") + checker.Register("os", "Remove", "os.RemoveAll") + + deprecated, err := checker.Check("example.go", node, fset) + if err != nil { + t.Fatalf("Check failed with error: %v", err) + } + + expected := []DeprecatedFunc{ + { + Package: "fmt", + Function: "Println", + Alternative: "fmt.Print", + Position: token.Position{ + Filename: "example.go", + Offset: 55, + Line: 10, + Column: 2, + }, + }, + { + Package: "os", + Function: "Remove", + Alternative: "os.RemoveAll", + Position: token.Position{ + Filename: "example.go", + Offset: 85, + Line: 11, + Column: 2, + }, + }, + } + + assert.Equal(t, expected, deprecated) +} + +func TestCheckNoDeprecated(t *testing.T) { + t.Parallel() + src := ` +package main + +import "fmt" + +func main() { + fmt.Printf("Hello, %s\n", "World") +} +` + + fset := token.NewFileSet() + node, err := parser.ParseFile(fset, "example.go", src, 0) + if err != nil { + t.Fatalf("Failed to parse file: %v", err) + } + + checker := NewDeprecatedFuncChecker() + checker.Register("fmt", "Println", "fmt.Print") + checker.Register("os", "Remove", "os.RemoveAll") + + deprecated, err := checker.Check("example.go", node, fset) + if err != nil { + t.Fatalf("Check failed with error: %v", err) + } + + assert.Equal(t, 0, len(deprecated)) +} + +func TestCheckMultipleDeprecatedCalls(t *testing.T) { + t.Parallel() + src := ` +package main + +import ( + "fmt" + "os" +) + +func main() { + fmt.Println("Hello") + fmt.Println("World") + os.Remove("file1.txt") + os.Remove("file2.txt") +} +` + + fset := token.NewFileSet() + node, err := parser.ParseFile(fset, "example.go", src, 0) + if err != nil { + t.Fatalf("Failed to parse file: %v", err) + } + + checker := NewDeprecatedFuncChecker() + checker.Register("fmt", "Println", "fmt.Print") + checker.Register("os", "Remove", "os.RemoveAll") + + deprecated, err := checker.Check("example.go", node, fset) + if err != nil { + t.Fatalf("Check failed with error: %v", err) + } + + expected := []DeprecatedFunc{ + {Package: "fmt", Function: "Println", Alternative: "fmt.Print"}, + {Package: "fmt", Function: "Println", Alternative: "fmt.Print"}, + {Package: "os", Function: "Remove", Alternative: "os.RemoveAll"}, + {Package: "os", Function: "Remove", Alternative: "os.RemoveAll"}, + } + + assert.Equal(t, len(expected), len(deprecated)) + for i, exp := range expected { + assertDeprecatedFuncEqual(t, exp, deprecated[i]) + } +} + +func TestDeprecatedFuncCheckerWithAlias(t *testing.T) { + t.Parallel() + + c := NewDeprecatedFuncChecker() + c.Register("math", "Sqrt", "math.Pow") + + const src = ` +package main + +import ( + m "math" + "fmt" +) + +type MyStruct struct{} + +func (s *MyStruct) Method() {} + +func main() { + result := m.Sqrt(42) + _ = result + + fmt.Println("Hello") + + s := &MyStruct{} + s.Method() +} +` + + fset := token.NewFileSet() + node, err := parser.ParseFile(fset, "sample.go", src, parser.ParseComments) + assert.NoError(t, err) + + results, err := c.Check("sample.go", node, fset) + assert.NoError(t, err) + + assert.Equal(t, 1, len(results)) + + expected := DeprecatedFunc{ + Package: "math", + Function: "Sqrt", + Alternative: "math.Pow", + } + + assertDeprecatedFuncEqual(t, expected, results[0]) +} + +func TestDeprecatedFuncChecker_Check_DotImport(t *testing.T) { + t.Parallel() + + checker := NewDeprecatedFuncChecker() + checker.Register("fmt", "Println", "Use fmt.Print instead") + + src := ` +package main + +import . "fmt" + +func main() { + Println("Hello, World!") +} +` + + fset := token.NewFileSet() + f, err := parser.ParseFile(fset, "test.go", src, 0) + assert.NoError(t, err) + + found, err := checker.Check("test.go", f, fset) + assert.NoError(t, err) + + assert.Equal(t, 1, len(found)) + + if len(found) > 0 { + df := found[0] + if df.Package != "fmt" || df.Function != "Println" || df.Alternative != "Use fmt.Print instead" { + t.Errorf("unexpected deprecated function info: %+v", df) + } + } +} + +func assertDeprecatedFuncEqual(t *testing.T, expected, actual DeprecatedFunc) { + t.Helper() + assert.Equal(t, expected.Package, actual.Package) + assert.Equal(t, expected.Function, actual.Function) + assert.Equal(t, expected.Alternative, actual.Alternative) + assert.NotEmpty(t, actual.Position.Filename) + assert.Greater(t, actual.Position.Offset, 0) + assert.Greater(t, actual.Position.Line, 0) + assert.Greater(t, actual.Position.Column, 0) +} diff --git a/internal/engine.go b/internal/engine.go index a2bb38c..aab91ba 100644 --- a/internal/engine.go +++ b/internal/engine.go @@ -42,6 +42,7 @@ func (e *Engine) registerDefaultRules() { func (e *Engine) initDefaultRules() { e.defaultRules = []LintRule{ &GolangciLintRule{}, + &DeprecateFuncRule{}, &EarlyReturnOpportunityRule{}, &SimplifySliceExprRule{}, &UnnecessaryConversionRule{}, diff --git a/internal/lints/deprecate_func.go b/internal/lints/deprecate_func.go new file mode 100644 index 0000000..058b7e5 --- /dev/null +++ b/internal/lints/deprecate_func.go @@ -0,0 +1,51 @@ +package lints + +import ( + "fmt" + "go/ast" + "go/token" + + "github.com/gnolang/tlin/internal/checker" + tt "github.com/gnolang/tlin/internal/types" +) + +func DetectDeprecatedFunctions( + filename string, + node *ast.File, + fset *token.FileSet, +) ([]tt.Issue, error) { + deprecated := checker.NewDeprecatedFuncChecker() + + deprecated.Register("std", "SetOrigCaller", "std.PrevRealm") + deprecated.Register("std", "GetOrigCaller", "std.PrevRealm") + deprecated.Register("std", "TestSetOrigCaller", "") + + dfuncs, err := deprecated.Check(filename, node, fset) + if err != nil { + return nil, err + } + + issues := make([]tt.Issue, 0, len(dfuncs)) + for _, df := range dfuncs { + issues = append(issues, tt.Issue{ + Rule: "deprecated", + Filename: filename, + Start: df.Position, + End: df.Position, + Message: createDeprecationMessage(df), + Suggestion: df.Alternative, + }) + } + + return issues, nil +} + +func createDeprecationMessage(df checker.DeprecatedFunc) string { + msg := "Use of deprecated function" + if df.Alternative != "" { + msg = fmt.Sprintf("%s. Please use %s instead.", msg, df.Alternative) + return msg + } + msg = fmt.Sprintf("%s. Please remove it.", msg) + return msg +} diff --git a/internal/rule_set.go b/internal/rule_set.go index b30f06b..e9acf25 100644 --- a/internal/rule_set.go +++ b/internal/rule_set.go @@ -31,6 +31,16 @@ func (r *GolangciLintRule) Name() string { return "golangci-lint" } +type DeprecateFuncRule struct{} + +func (r *DeprecateFuncRule) Check(filename string, node *ast.File, fset *token.FileSet) ([]tt.Issue, error) { + return lints.DetectDeprecatedFunctions(filename, node, fset) +} + +func (r *DeprecateFuncRule) Name() string { + return "deprecated-function" +} + type SimplifySliceExprRule struct{} func (r *SimplifySliceExprRule) Check(filename string, node *ast.File, fset *token.FileSet) ([]tt.Issue, error) { diff --git a/testdata/deprecate/orig_caller.gno b/testdata/deprecate/orig_caller.gno new file mode 100644 index 0000000..f3f9385 --- /dev/null +++ b/testdata/deprecate/orig_caller.gno @@ -0,0 +1,15 @@ +package main + +import ( + "std" +) + +func main() { + origCaller := std.GetOrigCaller() + println(origCaller.String()) + + prev := std.PrevRealm() + prevAddr := prev.Addr() + + println(prevAddr.String()) +}