diff --git a/mocks/github.com/vektra/mockery/v2/pkg/fixtures/Example_mock.go b/mocks/github.com/vektra/mockery/v2/pkg/fixtures/Example_mock.go index ee625e78..8e3b5d57 100644 --- a/mocks/github.com/vektra/mockery/v2/pkg/fixtures/Example_mock.go +++ b/mocks/github.com/vektra/mockery/v2/pkg/fixtures/Example_mock.go @@ -5,8 +5,8 @@ package mocks import ( - "github.com/vektra/mockery/v2/pkg/fixtures/12345678/http" - "github.com/vektra/mockery/v2/pkg/fixtures/http" + http1 "github.com/vektra/mockery/v2/pkg/fixtures/12345678/http" + http0 "github.com/vektra/mockery/v2/pkg/fixtures/http" "net/http" mock "github.com/stretchr/testify/mock" ) @@ -95,7 +95,7 @@ func (_c *Example_A_Call) RunAndReturn(run func()http.Flusher) *Example_A_Call { // B provides a mock function for the type Example -func (_mock *Example) B(fixtureshttp string) http.MyStruct { +func (_mock *Example) B(fixtureshttp string) http0.MyStruct { ret := _mock.Called(fixtureshttp) if len(ret) == 0 { @@ -103,11 +103,11 @@ func (_mock *Example) B(fixtureshttp string) http.MyStruct { } - var r0 http.MyStruct - if returnFunc, ok := ret.Get(0).(func(string) http.MyStruct); ok { + var r0 http0.MyStruct + if returnFunc, ok := ret.Get(0).(func(string) http0.MyStruct); ok { r0 = returnFunc(fixtureshttp) } else { - r0 = ret.Get(0).(http.MyStruct) + r0 = ret.Get(0).(http0.MyStruct) } return r0 } @@ -134,19 +134,19 @@ func (_c *Example_B_Call) Run(run func(fixtureshttp string)) *Example_B_Call { return _c } -func (_c *Example_B_Call) Return(myStruct http.MyStruct) *Example_B_Call { +func (_c *Example_B_Call) Return(myStruct http0.MyStruct) *Example_B_Call { _c.Call.Return(myStruct) return _c } -func (_c *Example_B_Call) RunAndReturn(run func(fixtureshttp string)http.MyStruct) *Example_B_Call { +func (_c *Example_B_Call) RunAndReturn(run func(fixtureshttp string)http0.MyStruct) *Example_B_Call { _c.Call.Return(run) return _c } // C provides a mock function for the type Example -func (_mock *Example) C(fixtureshttp string) http.MyStruct { +func (_mock *Example) C(fixtureshttp string) http1.MyStruct { ret := _mock.Called(fixtureshttp) if len(ret) == 0 { @@ -154,11 +154,11 @@ func (_mock *Example) C(fixtureshttp string) http.MyStruct { } - var r0 http.MyStruct - if returnFunc, ok := ret.Get(0).(func(string) http.MyStruct); ok { + var r0 http1.MyStruct + if returnFunc, ok := ret.Get(0).(func(string) http1.MyStruct); ok { r0 = returnFunc(fixtureshttp) } else { - r0 = ret.Get(0).(http.MyStruct) + r0 = ret.Get(0).(http1.MyStruct) } return r0 } @@ -185,12 +185,12 @@ func (_c *Example_C_Call) Run(run func(fixtureshttp string)) *Example_C_Call { return _c } -func (_c *Example_C_Call) Return(myStruct http.MyStruct) *Example_C_Call { +func (_c *Example_C_Call) Return(myStruct http1.MyStruct) *Example_C_Call { _c.Call.Return(myStruct) return _c } -func (_c *Example_C_Call) RunAndReturn(run func(fixtureshttp string)http.MyStruct) *Example_C_Call { +func (_c *Example_C_Call) RunAndReturn(run func(fixtureshttp string)http1.MyStruct) *Example_C_Call { _c.Call.Return(run) return _c } diff --git a/mocks/github.com/vektra/mockery/v2/pkg/fixtures/HasConflictingNestedImports_mock.go b/mocks/github.com/vektra/mockery/v2/pkg/fixtures/HasConflictingNestedImports_mock.go index 101ae6b7..075666f4 100644 --- a/mocks/github.com/vektra/mockery/v2/pkg/fixtures/HasConflictingNestedImports_mock.go +++ b/mocks/github.com/vektra/mockery/v2/pkg/fixtures/HasConflictingNestedImports_mock.go @@ -5,12 +5,12 @@ package mocks import ( - "github.com/vektra/mockery/v2/pkg/fixtures/http" + http0 "github.com/vektra/mockery/v2/pkg/fixtures/http" "net/http" mock "github.com/stretchr/testify/mock" ) - + // NewHasConflictingNestedImports creates a new instance of HasConflictingNestedImports. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. // The first argument is typically a *testing.T value. func NewHasConflictingNestedImports (t interface { @@ -39,32 +39,32 @@ func (_m *HasConflictingNestedImports) EXPECT() *HasConflictingNestedImports_Exp return &HasConflictingNestedImports_Expecter{mock: &_m.Mock} } - + // Get provides a mock function for the type HasConflictingNestedImports -func (_mock *HasConflictingNestedImports) Get(path string) (http.Response, error) { +func (_mock *HasConflictingNestedImports) Get(path string) (http.Response, error) { ret := _mock.Called(path) if len(ret) == 0 { panic("no return value specified for Get") } - + var r0 http.Response var r1 error if returnFunc, ok := ret.Get(0).(func(string) (http.Response, error)); ok { return returnFunc(path) - } + } if returnFunc, ok := ret.Get(0).(func(string) http.Response); ok { r0 = returnFunc(path) } else { r0 = ret.Get(0).(http.Response) - } + } if returnFunc, ok := ret.Get(1).(func(string) error); ok { r1 = returnFunc(path) } else { r1 = ret.Error(1) - } + } return r0, r1 } @@ -99,23 +99,23 @@ func (_c *HasConflictingNestedImports_Get_Call) RunAndReturn(run func(path strin _c.Call.Return(run) return _c } - + // Z provides a mock function for the type HasConflictingNestedImports -func (_mock *HasConflictingNestedImports) Z() http.MyStruct { +func (_mock *HasConflictingNestedImports) Z() http0.MyStruct { ret := _mock.Called() if len(ret) == 0 { panic("no return value specified for Z") } - - var r0 http.MyStruct - if returnFunc, ok := ret.Get(0).(func() http.MyStruct); ok { + + var r0 http0.MyStruct + if returnFunc, ok := ret.Get(0).(func() http0.MyStruct); ok { r0 = returnFunc() } else { - r0 = ret.Get(0).(http.MyStruct) - } + r0 = ret.Get(0).(http0.MyStruct) + } return r0 } @@ -140,14 +140,14 @@ func (_c *HasConflictingNestedImports_Z_Call) Run(run func()) *HasConflictingNes return _c } -func (_c *HasConflictingNestedImports_Z_Call) Return(myStruct http.MyStruct) *HasConflictingNestedImports_Z_Call { +func (_c *HasConflictingNestedImports_Z_Call) Return(myStruct http0.MyStruct) *HasConflictingNestedImports_Z_Call { _c.Call.Return(myStruct) return _c } -func (_c *HasConflictingNestedImports_Z_Call) RunAndReturn(run func()http.MyStruct) *HasConflictingNestedImports_Z_Call { +func (_c *HasConflictingNestedImports_Z_Call) RunAndReturn(run func()http0.MyStruct) *HasConflictingNestedImports_Z_Call { _c.Call.Return(run) return _c } - + diff --git a/mocks/github.com/vektra/mockery/v2/pkg/fixtures/ImportsSameAsPackage_mock.go b/mocks/github.com/vektra/mockery/v2/pkg/fixtures/ImportsSameAsPackage_mock.go index 147aba1d..612a19b7 100644 --- a/mocks/github.com/vektra/mockery/v2/pkg/fixtures/ImportsSameAsPackage_mock.go +++ b/mocks/github.com/vektra/mockery/v2/pkg/fixtures/ImportsSameAsPackage_mock.go @@ -5,7 +5,7 @@ package mocks import ( - "github.com/vektra/mockery/v2/pkg/fixtures" + test0 "github.com/vektra/mockery/v2/pkg/fixtures" "github.com/vektra/mockery/v2/pkg/fixtures/redefined_type_b" mock "github.com/stretchr/testify/mock" ) @@ -92,7 +92,7 @@ func (_c *ImportsSameAsPackage_A_Call) RunAndReturn(run func()test.B) *ImportsSa // B provides a mock function for the type ImportsSameAsPackage -func (_mock *ImportsSameAsPackage) B() test.KeyManager { +func (_mock *ImportsSameAsPackage) B() test0.KeyManager { ret := _mock.Called() if len(ret) == 0 { @@ -100,12 +100,12 @@ func (_mock *ImportsSameAsPackage) B() test.KeyManager { } - var r0 test.KeyManager - if returnFunc, ok := ret.Get(0).(func() test.KeyManager); ok { + var r0 test0.KeyManager + if returnFunc, ok := ret.Get(0).(func() test0.KeyManager); ok { r0 = returnFunc() } else { if ret.Get(0) != nil { - r0 = ret.Get(0).(test.KeyManager) + r0 = ret.Get(0).(test0.KeyManager) } } return r0 @@ -132,19 +132,19 @@ func (_c *ImportsSameAsPackage_B_Call) Run(run func()) *ImportsSameAsPackage_B_C return _c } -func (_c *ImportsSameAsPackage_B_Call) Return(keyManager test.KeyManager) *ImportsSameAsPackage_B_Call { +func (_c *ImportsSameAsPackage_B_Call) Return(keyManager test0.KeyManager) *ImportsSameAsPackage_B_Call { _c.Call.Return(keyManager) return _c } -func (_c *ImportsSameAsPackage_B_Call) RunAndReturn(run func()test.KeyManager) *ImportsSameAsPackage_B_Call { +func (_c *ImportsSameAsPackage_B_Call) RunAndReturn(run func()test0.KeyManager) *ImportsSameAsPackage_B_Call { _c.Call.Return(run) return _c } // C provides a mock function for the type ImportsSameAsPackage -func (_mock *ImportsSameAsPackage) C(c test.C) { _mock.Called(c) +func (_mock *ImportsSameAsPackage) C(c test0.C) { _mock.Called(c) return } @@ -163,9 +163,9 @@ func (_e *ImportsSameAsPackage_Expecter) C(c interface{}, ) *ImportsSameAsPackag return &ImportsSameAsPackage_C_Call{Call: _e.mock.On("C",c, )} } -func (_c *ImportsSameAsPackage_C_Call) Run(run func(c test.C)) *ImportsSameAsPackage_C_Call { +func (_c *ImportsSameAsPackage_C_Call) Run(run func(c test0.C)) *ImportsSameAsPackage_C_Call { _c.Call.Run(func(args mock.Arguments) { - run(args[0].(test.C),) + run(args[0].(test0.C),) }) return _c } @@ -175,7 +175,7 @@ func (_c *ImportsSameAsPackage_C_Call) Return() *ImportsSameAsPackage_C_Call { return _c } -func (_c *ImportsSameAsPackage_C_Call) RunAndReturn(run func(c test.C)) *ImportsSameAsPackage_C_Call { +func (_c *ImportsSameAsPackage_C_Call) RunAndReturn(run func(c test0.C)) *ImportsSameAsPackage_C_Call { _c.Run(run) return _c } diff --git a/pkg/registry/method_scope.go b/pkg/registry/method_scope.go index c09b4c2e..1fc57fad 100644 --- a/pkg/registry/method_scope.go +++ b/pkg/registry/method_scope.go @@ -39,6 +39,19 @@ func NewMethodScope(r *Registry) *MethodScope { } } +func (m *MethodScope) ResolveVariableNameCollisions(ctx context.Context) { + log := zerolog.Ctx(ctx) + for _, v := range m.vars { + varLog := log.With().Str("variable-name", v.Name).Logger() + newName := m.AllocateName(v.Name) + if newName != v.Name { + varLog.Debug().Str("new-name", newName).Msg("variable was found to conflict with previously allocated name. Giving new name.") + } + v.Name = newName + m.visibleNames[v.Name] = nil + } +} + // AllocateName creates a new variable name in the lexical scope of the method. // It ensures the returned name does not conflict with any other name visible // to the scope. It registers the returned name in the lexical scope such that @@ -57,7 +70,6 @@ func (m *MethodScope) AllocateName(prefix string) string { } break } - m.visibleNames[suggestion] = nil return suggestion } @@ -79,7 +91,9 @@ func (m *MethodScope) AddVar(ctx context.Context, vr *types.Var, prefix string) log.Debug().Str("visible-name", key).Msg("visible name") } name := m.AllocateName(varName(vr, prefix)) - log.Debug().Str("allocated-name", name).Msg("allocated name for variable in method") + // This suggested name is subject to change because it might come into conflict + // with a future package import. + log.Debug().Str("suggested-name", name).Msg("suggested name for variable in method") v := Var{ vr: vr, @@ -91,7 +105,7 @@ func (m *MethodScope) AddVar(ctx context.Context, vr *types.Var, prefix string) return &v } -func (m MethodScope) populateImportNamedType( +func (m *MethodScope) populateImportNamedType( ctx context.Context, t interface { Obj() *types.TypeName @@ -114,7 +128,7 @@ func (m MethodScope) populateImportNamedType( } } -func (m MethodScope) PopulateImports(ctx context.Context, t types.Type) map[string]*Package { +func (m *MethodScope) PopulateImports(ctx context.Context, t types.Type) map[string]*Package { imports := map[string]*Package{} m.populateImports(ctx, t, imports) return imports @@ -125,7 +139,7 @@ func (m MethodScope) PopulateImports(ctx context.Context, t types.Type) map[stri // one (ex: map[a.Type]b.Type). // // Returned are the imports that were added for the given type. -func (m MethodScope) populateImports(ctx context.Context, t types.Type, imports map[string]*Package) { +func (m *MethodScope) populateImports(ctx context.Context, t types.Type, imports map[string]*Package) { log := zerolog.Ctx(ctx).With(). Str("type-str", t.String()).Logger() switch t := t.(type) { diff --git a/pkg/registry/registry.go b/pkg/registry/registry.go index 40acf1a2..16de1b81 100644 --- a/pkg/registry/registry.go +++ b/pkg/registry/registry.go @@ -2,12 +2,9 @@ package registry import ( "context" - "errors" "fmt" - "go/ast" "go/types" "sort" - "strings" "github.com/rs/zerolog" "golang.org/x/tools/go/packages" @@ -21,7 +18,6 @@ type Registry struct { dstPkgPath string srcPkg *packages.Package srcPkgName string - aliases map[string]string imports map[string]*Package importQualifiers map[string]*Package } @@ -33,7 +29,6 @@ func New(srcPkg *packages.Package, dstPkgPath string) (*Registry, error) { dstPkgPath: dstPkgPath, srcPkg: srcPkg, srcPkgName: srcPkg.Name, - aliases: parseImportsAliases(srcPkg.Syntax), imports: make(map[string]*Package), importQualifiers: make(map[string]*Package), }, nil @@ -89,14 +84,17 @@ func (r *Registry) AddImport(ctx context.Context, pkg *types.Package) *Package { return imprt } - imprt := Package{pkg: pkg, Alias: r.aliases[path]} - var aliasSuggestion string + imprt := Package{pkg: pkg} + originalQualifier := imprt.Qualifier() + var aliasSuggestion string = imprt.Qualifier() for i := 0; ; i++ { if _, conflict := r.importQualifiers[aliasSuggestion]; conflict { aliasSuggestion = fmt.Sprintf("%s%d", imprt.Qualifier(), i) continue } - imprt.Alias = aliasSuggestion + if originalQualifier != aliasSuggestion { + imprt.Alias = aliasSuggestion + } break } @@ -117,46 +115,3 @@ func (r Registry) Imports() []*Package { }) return imports } - -func pkgInfoFromPath(srcDir string, mode packages.LoadMode) (*packages.Package, error) { - pkgs, err := packages.Load(&packages.Config{ - Mode: mode, - Dir: srcDir, - }) - if err != nil { - return nil, err - } - if len(pkgs) == 0 { - return nil, errors.New("package not found") - } - if len(pkgs) > 1 { - return nil, errors.New("found more than one package") - } - if errs := pkgs[0].Errors; len(errs) != 0 { - if len(errs) == 1 { - return nil, errs[0] - } - return nil, fmt.Errorf("%s (and %d more errors)", errs[0], len(errs)-1) - } - return pkgs[0], nil -} - -func pkgInDir(pkgName, dir string) bool { - currentPkg, err := pkgInfoFromPath(dir, packages.NeedName) - if err != nil { - return false - } - return currentPkg.Name == pkgName || currentPkg.Name+"_test" == pkgName -} - -func parseImportsAliases(syntaxTree []*ast.File) map[string]string { - aliases := make(map[string]string) - for _, syntax := range syntaxTree { - for _, imprt := range syntax.Imports { - if imprt.Name != nil && imprt.Name.Name != "." && imprt.Name.Name != "_" { - aliases[strings.Trim(imprt.Path.Value, `"`)] = imprt.Name.Name - } - } - } - return aliases -} diff --git a/pkg/template_generator.go b/pkg/template_generator.go index 89a2ad13..6ebd3c40 100644 --- a/pkg/template_generator.go +++ b/pkg/template_generator.go @@ -77,19 +77,6 @@ func (g *TemplateGenerator) methodData(ctx context.Context, method *types.Func) signature := method.Type().(*types.Signature) params := make([]template.ParamData, signature.Params().Len()) - // First pass to populate all imports first. This greatly simplifies name - // collision logic to first allocate the package qualifiers in the file-global - // scope first before allocating variable names. - for j := 0; j < signature.Params().Len(); j++ { - param := signature.Params().At(j) - methodScope.PopulateImports(ctx, param.Type()) - } - for j := 0; j < signature.Results().Len(); j++ { - param := signature.Results().At(j) - methodScope.PopulateImports(ctx, param.Type()) - } - - // Now add parameter names. Their imports have already been processed. for j := 0; j < signature.Params().Len(); j++ { param := signature.Params().At(j) log.Debug().Str("param-string", param.String()).Msg("found parameter") @@ -172,6 +159,17 @@ func (g *TemplateGenerator) Generate( for i := 0; i < iface.NumMethods(); i++ { methods[i] = g.methodData(ctx, iface.Method(i)) } + // Now that all methods have been generated, we need to resolve naming + // conflicts that arise between variable names and package qualifiers. + for _, method := range methods { + method.Scope.ResolveVariableNameCollisions( + zerolog. + Ctx(ctx). + With(). + Str("method-name", method.Name). + Logger(). + WithContext(ctx)) + } mockData = append(mockData, template.MockData{ InterfaceName: ifaceMock.Name,