Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support for type names with the full package path #73

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
162 changes: 156 additions & 6 deletions parse/parse.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,13 @@ package parse
import (
"bufio"
"bytes"
"errors"
"fmt"
"go/ast"
"go/parser"
"go/scanner"
"go/token"
"io"
"os"
"strings"
"unicode"

Expand Down Expand Up @@ -92,7 +92,10 @@ func subTypeIntoLine(line, typeTemplate, specificType string) string {
func generateSpecific(filename string, in io.ReadSeeker, typeSet map[string]string) ([]byte, error) {

// ensure we are at the beginning of the file
in.Seek(0, os.SEEK_SET)
_, err := in.Seek(0, io.SeekStart)
if err != nil {
return nil, err
}

// parse the source file
fs := token.NewFileSet()
Expand Down Expand Up @@ -125,8 +128,10 @@ func generateSpecific(filename string, in io.ReadSeeker, typeSet map[string]stri
}
}

in.Seek(0, os.SEEK_SET)

_, err = in.Seek(0, io.SeekStart)
if err != nil {
return nil, err
}
var buf bytes.Buffer

comment := ""
Expand Down Expand Up @@ -169,9 +174,33 @@ func generateSpecific(filename string, in io.ReadSeeker, typeSet map[string]stri
return buf.Bytes(), nil
}

// removeFullyQualifiedImportPathAndAddToMap takes a typeset with import paths such as example.com/a/b.MyType and returns a new typeset with b.MyType.
// It adds the package path to the extraPackagesSet, to be later added into the imports section of the generated file.
// It returns the typeSet without package path, or an error if applicable
func removeFullyQualifiedImportPathAndAddToMap(typeSet map[string]string, extraPackagesSet map[string]struct{}) (map[string]string, error) {
typeSetWithoutFullImportPaths := make(map[string]string)
for typeName, fullTypePath := range typeSet {
shortTypePath := fullTypePath
lastSlashIdx := strings.LastIndex(fullTypePath, "/")
if lastSlashIdx != -1 {
// there is a full package name
lastIdxDot := strings.LastIndex(fullTypePath, ".")
if lastIdxDot < lastSlashIdx {
return nil, errors.New("error parsing full package name: last dot index before last slash index")
}
extraPackagesSet[fullTypePath[:lastIdxDot]] = struct{}{}
shortTypePath = fullTypePath[lastSlashIdx+1:]
}
typeSetWithoutFullImportPaths[typeName] = shortTypePath
}

return typeSetWithoutFullImportPaths, nil
}

// Generics parses the source file and generates the bytes replacing the
// generic types for the keys map with the specific types (its value).
func Generics(filename, outputFilename, pkgName, tag string, in io.ReadSeeker, typeSets []map[string]string) ([]byte, error) {
var err error
var localUnwantedLinePrefixes [][]byte
localUnwantedLinePrefixes = append(localUnwantedLinePrefixes, unwantedLinePrefixes...)

Expand All @@ -181,9 +210,17 @@ func Generics(filename, outputFilename, pkgName, tag string, in io.ReadSeeker, t

totalOutput := header

for _, typeSet := range typeSets {
extraPkgsMap := make(map[string]struct{})

for _, typeSetWithPackagePaths := range typeSets {
// for each typeset, investigate whether it is a same-package import path (like "User"), or a long import path (like "example.com/a/b.User")
typeSet, err := removeFullyQualifiedImportPathAndAddToMap(typeSetWithPackagePaths, extraPkgsMap)
if err != nil {
return nil, err
}

// generate the specifics
// parsed, err := generateSpecific(filename, in, typeSetWithoutFullImportPaths)
parsed, err := generateSpecific(filename, in, typeSet)
if err != nil {
return nil, err
Expand Down Expand Up @@ -240,7 +277,6 @@ func Generics(filename, outputFilename, pkgName, tag string, in io.ReadSeeker, t
cleanOutput := strings.Join(cleanOutputLines, "")

output := []byte(cleanOutput)
var err error

// change package name
if pkgName != "" {
Expand All @@ -252,9 +288,123 @@ func Generics(filename, outputFilename, pkgName, tag string, in io.ReadSeeker, t
return nil, &errImports{Err: err}
}

var extraPkgs []string
for importPath := range extraPkgsMap {
extraPkgs = append(extraPkgs, importPath)
}

// add extra imports
output, err = addExtraImports(bytes.NewReader(output), extraPkgs)
if err != nil {
return nil, err
}

return output, nil
}

func addExtraImports(in io.ReadSeeker, importPaths []string) ([]byte, error) {
var outLines []string
var importBlockProcessed bool

_, err := in.Seek(0, io.SeekStart)
if err != nil {
return nil, err
}

var existingImports []string
var multiLineImportBlockOpen bool
scanner := bufio.NewScanner(in)
for scanner.Scan() {
if importBlockProcessed && !multiLineImportBlockOpen {
outLines = append(outLines, scanner.Text())
continue
}

if multiLineImportBlockOpen {
switch scanner.Text() {
case ")":
multiLineImportBlockOpen = false
importBlockProcessed = true
default:
existingImports = append(existingImports, scanner.Text())
}
continue
}

if !strings.HasPrefix(scanner.Text(), "import ") {
// not the import line. Not interesting to us.
outLines = append(outLines, scanner.Text())
continue
}

// line is an import declaration. Find out if it is a one or multi line declaration
if strings.Contains(scanner.Text(), `"`) {
// one line declaration
importPath := strings.TrimPrefix(scanner.Text(), "import ")
existingImports = append(existingImports, fmt.Sprintf("\t%s", importPath))
importBlockProcessed = true
continue
}

// open multi-line declaration
multiLineImportBlockOpen = true
}
if scanner.Err() != nil {
return nil, scanner.Err()
}

var imports []string
imports = append(imports, existingImports...)

for _, importPath := range importPaths {
// if not already in imports, add to list. Otherwise skip this one
var isAlreadyInImports bool
for _, existingImport := range imports {
if strings.Trim(strings.TrimSpace(existingImport), `"`) == strings.Trim(strings.TrimSpace(importPath), `"`) {
isAlreadyInImports = true
break
}
}
if isAlreadyInImports {
continue
}
imports = append(imports, fmt.Sprintf("\t\"%s\"", importPath))
}

// add imports
if len(imports) != 0 {
var pkgDefProcessed bool

for i, outLine := range outLines {
if strings.HasPrefix(outLine, "package ") {
pkgDefProcessed = true
continue
}

if pkgDefProcessed {
var importBlock string
if len(imports) == 1 {
importBlock = fmt.Sprintf("\nimport %s", strings.TrimPrefix(imports[0], "\t"))
} else {
importBlock = fmt.Sprintf("\nimport (\n%s\n)", strings.Join(imports, "\n"))
}

if len(existingImports) == 0 {
// if there were no existing imports, we have to pad the end of the imports line
importBlock += "\n"
}
outLinesPreImport := outLines[:i]
outLinesPostImport := outLines[i+1:]

outLines = append(append(outLinesPreImport, importBlock), outLinesPostImport...)
break
}
}
}

return []byte(strings.Join(outLines, "\n") + "\n"), nil
}

func makeLine(s string) string {
return fmt.Sprintln(strings.TrimRight(s, linefeed))
}
Expand Down
136 changes: 136 additions & 0 deletions parse/parse_same_pkg_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,136 @@
package parse

import (
"bytes"
"testing"

"github.com/stretchr/testify/assert"
)

func Test_removeFullyQualifiedImportPathAndAddToMap(t *testing.T) {
type args struct {
typeSet map[string]string
extraPackagesSet map[string]struct{}
}
tests := []struct {
name string
args args
want map[string]string
wantExtraPackagesSet map[string]struct{}
wantErr bool
}{
{
name: "one field example",
args: args{
typeSet: map[string]string{
"GenericField": "example.com/a/b.MyType",
},
extraPackagesSet: map[string]struct{}{},
},
want: map[string]string{
"GenericField": "b.MyType",
},
wantExtraPackagesSet: map[string]struct{}{
"example.com/a/b": {},
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := removeFullyQualifiedImportPathAndAddToMap(tt.args.typeSet, tt.args.extraPackagesSet)
if (err != nil) != tt.wantErr {
t.Errorf("removeFullyQualifiedImportPathAndAddToMap() error = %v, wantErr %v", err, tt.wantErr)
return
}
assert.Equal(t, tt.want, got)
assert.Equal(t, tt.wantExtraPackagesSet, tt.args.extraPackagesSet)
})
}
}

func Test_addExtraImports(t *testing.T) {
type testDef struct {
Name string
In string
Out string
ExtraImports []string
}

tests := []testDef{
{
Name: "single import",
In: `package x

import "fmt"

func sayHello(user userpkg.User) {
return fmt.Sprintf("hello %s", user.Name)
}
`,
Out: `package x

import (
"fmt"
"example.com/me/userpkg"
)

func sayHello(user userpkg.User) {
return fmt.Sprintf("hello %s", user.Name)
}
`,
ExtraImports: []string{"example.com/me/userpkg"},
}, {
Name: "no imports",
In: `package x

func sayHello(user userpkg.User) {
return "hello " + user.Name
}
`,
Out: `package x

import (
"example.com/me/userpkg"
)

func sayHello(user userpkg.User) {
return "hello " + user.Name
}
`,
ExtraImports: []string{"example.com/me/userpkg"},
}, {
Name: "multiple imports",
In: `package x

import (
"fmt"
"io"
)

func sayHello(writer io.Writer, user userpkg.User) {
return fmt.Fprintf(writer, "hello %s", user.Name)
}
`,
Out: `package x

import (
"fmt"
"io"
"example.com/me/userpkg"
)

func sayHello(writer io.Writer, user userpkg.User) {
return fmt.Fprintf(writer, "hello %s", user.Name)
}
`,
ExtraImports: []string{"example.com/me/userpkg"},
},
}
for _, test := range tests {

out, err := addExtraImports(bytes.NewReader([]byte(test.In)), test.ExtraImports)
assert.NoError(t, err)

assert.Equal(t, test.Out, string(out), test.Name)
}
}
7 changes: 7 additions & 0 deletions parse/parse_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,13 @@ var tests = []struct {
expectedOut: `test/buildtags/buildtags_expected_nostrip.go`,
tag: "",
},
{
filename: "buildtags.go",
in: `test/extrapackages/generic_extrapackage.go`,
types: []map[string]string{{"ForeignType": "github.com/cheekybits/genny/parse/test/extrapackages/extrapkg.MyType"}},
expectedOut: `test/extrapackages/built_extrapackage.go`,
tag: "",
},
}

func TestParse(t *testing.T) {
Expand Down
11 changes: 11 additions & 0 deletions parse/test/extrapackages/built_extrapackage.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
// This file was automatically generated by genny.
// Any changes will be lost if this file is regenerated.
// see https://github.com/cheekybits/genny

package extrapackages

import "github.com/cheekybits/genny/parse/test/extrapackages/extrapkg"

func ExtrapkgMyTypeSayHello(a extrapkg.MyType) extrapkg.MyType {
return a
}
3 changes: 3 additions & 0 deletions parse/test/extrapackages/extrapkg/my_type.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
package extrapkg

type MyType struct{}
Loading