Skip to content
This repository has been archived by the owner on Jun 27, 2023. It is now read-only.

Commit

Permalink
Use "." to refer to the current path's package in reflect mode (#387)
Browse files Browse the repository at this point in the history
* feat: use "." to refer to the current path's package

* doc: update reflect mode

* fix: generated code lose package name
  • Loading branch information
XSAM committed Feb 2, 2020
1 parent 3dcdcb6 commit 5c85495
Show file tree
Hide file tree
Showing 3 changed files with 64 additions and 12 deletions.
5 changes: 5 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -51,10 +51,15 @@ that uses reflection to understand interfaces. It is enabled
by passing two non-flag arguments: an import path, and a
comma-separated list of symbols.

You can use "." to refer to the current path's package.

Example:

```bash
mockgen database/sql/driver Conn,Driver

# Convenient for `go:generate`.
mockgen . Conn,Driver
```

The `mockgen` command is used to generate source code for a mock
Expand Down
16 changes: 14 additions & 2 deletions mockgen/mockgen.go
Original file line number Diff line number Diff line change
Expand Up @@ -69,14 +69,26 @@ func main() {

var pkg *model.Package
var err error
var packageName string
if *source != "" {
pkg, err = sourceMode(*source)
} else {
if flag.NArg() != 2 {
usage()
log.Fatal("Expected exactly two arguments")
}
pkg, err = reflectMode(flag.Arg(0), strings.Split(flag.Arg(1), ","))
packageName = flag.Arg(0)
if packageName == "." {
dir, err := os.Getwd()
if err != nil {
log.Fatalf("Get current directory failed: %v", err)
}
packageName, err = packageNameOfDir(dir)
if err != nil {
log.Fatalf("Parse package name failed: %v", err)
}
}
pkg, err = reflectMode(packageName, strings.Split(flag.Arg(1), ","))
}
if err != nil {
log.Fatalf("Loading input failed: %v", err)
Expand Down Expand Up @@ -130,7 +142,7 @@ func main() {
if *source != "" {
g.filename = *source
} else {
g.srcPackage = flag.Arg(0)
g.srcPackage = packageName
g.srcInterfaces = flag.Arg(1)
}

Expand Down
55 changes: 45 additions & 10 deletions mockgen/parse.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ import (
"go/build"
"go/parser"
"go/token"
"io/ioutil"
"log"
"path"
"path/filepath"
Expand All @@ -48,19 +49,10 @@ func sourceMode(source string) (*model.Package, error) {
return nil, fmt.Errorf("failed getting source directory: %v", err)
}

cfg := &packages.Config{Mode: packages.LoadFiles, Tests: true, Dir: srcDir}
pkgs, err := packages.Load(cfg, "file="+source)
packageImport, err := parsePackageImport(source, srcDir)
if err != nil {
return nil, err
}
if packages.PrintErrors(pkgs) > 0 || len(pkgs) == 0 {
return nil, errors.New("loading package failed")
}

packageImport := pkgs[0].PkgPath

// It is illegal to import a _test package.
packageImport = strings.TrimSuffix(packageImport, "_test")

fs := token.NewFileSet()
file, err := parser.ParseFile(fs, source, nil, 0)
Expand Down Expand Up @@ -519,3 +511,46 @@ func isVariadic(f *ast.FuncType) bool {
_, ok := f.Params.List[nargs-1].Type.(*ast.Ellipsis)
return ok
}

// packageNameOfDir get package import path via dir
func packageNameOfDir(srcDir string) (string, error) {
files, err := ioutil.ReadDir(srcDir)
if err != nil {
log.Fatal(err)
}

var goFilePath string
for _, file := range files {
if !file.IsDir() && strings.HasSuffix(file.Name(), ".go") {
goFilePath = file.Name()
break
}
}
if goFilePath == "" {
return "", fmt.Errorf("go source file not found %s", srcDir)
}

packageImport, err := parsePackageImport(goFilePath, srcDir)
if err != nil {
return "", err
}
return packageImport, nil
}

// parseImportPackage get package import path via source file
func parsePackageImport(source, srcDir string) (string, error) {
cfg := &packages.Config{Mode: packages.LoadFiles, Tests: true, Dir: srcDir}
pkgs, err := packages.Load(cfg, "file="+source)
if err != nil {
return "", err
}
if packages.PrintErrors(pkgs) > 0 || len(pkgs) == 0 {
return "", errors.New("loading package failed")
}

packageImport := pkgs[0].PkgPath

// It is illegal to import a _test package.
packageImport = strings.TrimSuffix(packageImport, "_test")
return packageImport, nil
}

0 comments on commit 5c85495

Please sign in to comment.