From 7c62f69ac54a40aa33b90a89ad2cc3d54ddfb20f Mon Sep 17 00:00:00 2001 From: James Date: Fri, 28 May 2021 22:36:33 +0200 Subject: [PATCH 1/2] allow external packages --- parse/parse.go | 132 +++++++++++++++++- parse/parse_test.go | 96 +++++++++++++ .../test/extrapackages/built_extrapackage.go | 11 ++ parse/test/extrapackages/extrapkg/my_type.go | 3 + .../extrapackages/generic_extrapackage.go | 9 ++ parse/typesets.go | 1 + 6 files changed, 250 insertions(+), 2 deletions(-) create mode 100644 parse/test/extrapackages/built_extrapackage.go create mode 100644 parse/test/extrapackages/extrapkg/my_type.go create mode 100644 parse/test/extrapackages/generic_extrapackage.go diff --git a/parse/parse.go b/parse/parse.go index bbf8a82..292c139 100644 --- a/parse/parse.go +++ b/parse/parse.go @@ -3,6 +3,7 @@ package parse import ( "bufio" "bytes" + "errors" "fmt" "go/ast" "go/parser" @@ -172,6 +173,7 @@ func generateSpecific(filename string, in io.ReadSeeker, typeSet map[string]stri // Generics parses the source file and generates the bytes replacing the // generic types for the keys map with the specific types (its value). func Generics(filename, outputFilename, pkgName, tag string, in io.ReadSeeker, typeSets []map[string]string) ([]byte, error) { + var err error var localUnwantedLinePrefixes [][]byte localUnwantedLinePrefixes = append(localUnwantedLinePrefixes, unwantedLinePrefixes...) @@ -181,10 +183,28 @@ func Generics(filename, outputFilename, pkgName, tag string, in io.ReadSeeker, t totalOutput := header + extraPkgsMap := make(map[string]struct{}) + for _, typeSet := range typeSets { + // for each typeset, investigate whether it is a same-package import path (like "User"), or a long import path (like "example.com/a/b.User") + typeSetWithoutFullImportPaths := make(map[string]string) + for typeName, fullTypePath := range typeSet { + shortTypePath := fullTypePath + lastSlashIdx := strings.LastIndex(fullTypePath, "/") + if lastSlashIdx != -1 { + // there is a full package name + lastIdxDot := strings.LastIndex(fullTypePath, ".") + if lastIdxDot < lastSlashIdx { + return nil, errors.New("error parsing full package name: last dot index before last slash index") + } + extraPkgsMap[fullTypePath[:lastIdxDot]] = struct{}{} + shortTypePath = fullTypePath[lastSlashIdx+1:] + } + typeSetWithoutFullImportPaths[typeName] = shortTypePath + } // generate the specifics - parsed, err := generateSpecific(filename, in, typeSet) + parsed, err := generateSpecific(filename, in, typeSetWithoutFullImportPaths) if err != nil { return nil, err } @@ -240,7 +260,6 @@ func Generics(filename, outputFilename, pkgName, tag string, in io.ReadSeeker, t cleanOutput := strings.Join(cleanOutputLines, "") output := []byte(cleanOutput) - var err error // change package name if pkgName != "" { @@ -252,9 +271,118 @@ func Generics(filename, outputFilename, pkgName, tag string, in io.ReadSeeker, t return nil, &errImports{Err: err} } + var extraPkgs []string + for importPath := range extraPkgsMap { + extraPkgs = append(extraPkgs, importPath) + } + + // add extra imports + output, err = AddExtraImports(bytes.NewBuffer(output), extraPkgs) + if err != nil { + return nil, err + } + return output, nil } +func AddExtraImports(in io.Reader, importPaths []string) ([]byte, error) { + var outLines []string + var importBlockProcessed bool + + var existingImports []string + var multiLineImportBlockOpen bool + scanner := bufio.NewScanner(in) + for scanner.Scan() { + if importBlockProcessed && !multiLineImportBlockOpen { + outLines = append(outLines, scanner.Text()) + continue + } + + if multiLineImportBlockOpen { + switch scanner.Text() { + case ")": + multiLineImportBlockOpen = false + importBlockProcessed = true + default: + existingImports = append(existingImports, scanner.Text()) + } + continue + } + + if !strings.HasPrefix(scanner.Text(), "import ") { + // not the import line. Not interesting to us. + outLines = append(outLines, scanner.Text()) + continue + } + + // line is an import declaration. Find out if it is a one or multi line declaration + if strings.Contains(scanner.Text(), `"`) { + // one line declaration + importPath := strings.TrimPrefix(scanner.Text(), "import ") + existingImports = append(existingImports, fmt.Sprintf("\t%s", importPath)) + importBlockProcessed = true + continue + } + + // open multi-line declaration + multiLineImportBlockOpen = true + } + if scanner.Err() != nil { + return nil, scanner.Err() + } + + var imports []string + imports = append(imports, existingImports...) + + for _, importPath := range importPaths { + // if not already in imports, add to list. Otherwise skip this one + var isAlreadyInImports bool + for _, existingImport := range imports { + if strings.Trim(strings.TrimSpace(existingImport), `"`) == strings.Trim(strings.TrimSpace(importPath), `"`) { + isAlreadyInImports = true + break + } + } + if isAlreadyInImports { + continue + } + imports = append(imports, fmt.Sprintf("\t\"%s\"", importPath)) + } + + // add imports + if len(imports) != 0 { + var pkgDefProcessed bool + + for i, outLine := range outLines { + if strings.HasPrefix(outLine, "package ") { + pkgDefProcessed = true + continue + } + + if pkgDefProcessed { + var importBlock string + if len(imports) == 1 { + importBlock = fmt.Sprintf("\nimport %s", strings.TrimPrefix(imports[0], "\t")) + } else { + importBlock = fmt.Sprintf("\nimport (\n%s\n)", strings.Join(imports, "\n")) + } + + if len(existingImports) == 0 { + // if there were no existing imports, we have to pad the end of the imports line + importBlock += "\n" + } + outLinesPreImport := outLines[:i] + outLinesPostImport := outLines[i+1:] + + outLines = append(append(outLinesPreImport, importBlock), outLinesPostImport...) + break + } + } + } + + return []byte(strings.Join(outLines, "\n") + "\n"), nil +} + func makeLine(s string) string { return fmt.Sprintln(strings.TrimRight(s, linefeed)) } diff --git a/parse/parse_test.go b/parse/parse_test.go index 79c4f09..9dc0aa1 100644 --- a/parse/parse_test.go +++ b/parse/parse_test.go @@ -1,6 +1,7 @@ package parse_test import ( + "bytes" "io/ioutil" "log" "strings" @@ -125,6 +126,13 @@ var tests = []struct { expectedOut: `test/buildtags/buildtags_expected_nostrip.go`, tag: "", }, + { + filename: "buildtags.go", + in: `test/extrapackages/generic_extrapackage.go`, + types: []map[string]string{{"ForeignType": "github.com/cheekybits/genny/parse/test/extrapackages/extrapkg.MyType"}}, + expectedOut: `test/extrapackages/built_extrapackage.go`, + tag: "", + }, } func TestParse(t *testing.T) { @@ -164,3 +172,91 @@ func contents(s string) string { } return s } + +func Test_AddExtraImports(t *testing.T) { + type testDef struct { + Name string + In string + Out string + ExtraImports []string + } + + tests := []testDef{ + { + Name: "single import", + In: `package x + +import "fmt" + +func sayHello(user userpkg.User) { + return fmt.Sprintf("hello %s", user.Name) +} +`, + Out: `package x + +import ( + "fmt" + "example.com/me/userpkg" +) + +func sayHello(user userpkg.User) { + return fmt.Sprintf("hello %s", user.Name) +} +`, + ExtraImports: []string{"example.com/me/userpkg"}, + }, { + Name: "no imports", + In: `package x + +func sayHello(user userpkg.User) { + return "hello " + user.Name +} +`, + Out: `package x + +import ( + "example.com/me/userpkg" +) + +func sayHello(user userpkg.User) { + return "hello " + user.Name +} +`, + ExtraImports: []string{"example.com/me/userpkg"}, + }, { + Name: "multiple imports", + In: `package x + +import ( + "fmt" + "io" +) + +func sayHello(writer io.Writer, user userpkg.User) { + return fmt.Fprintf(writer, "hello %s", user.Name) +} +`, + Out: `package x + +import ( + "fmt" + "io" + "example.com/me/userpkg" +) + +func sayHello(writer io.Writer, user userpkg.User) { + return fmt.Fprintf(writer, "hello %s", user.Name) +} +`, + ExtraImports: []string{"example.com/me/userpkg"}, + }, + } + for _, test := range tests { + + out, err := parse.AddExtraImports(bytes.NewBufferString(test.In), test.ExtraImports) + assert.NoError(t, err) + + assert.Equal(t, test.Out, string(out), test.Name) + } + +} diff --git a/parse/test/extrapackages/built_extrapackage.go b/parse/test/extrapackages/built_extrapackage.go new file mode 100644 index 0000000..a5314c5 --- /dev/null +++ b/parse/test/extrapackages/built_extrapackage.go @@ -0,0 +1,11 @@ +// This file was automatically generated by genny. +// Any changes will be lost if this file is regenerated. +// see https://github.com/cheekybits/genny + +package extrapackages + +import "github.com/cheekybits/genny/parse/test/extrapackages/extrapkg" + +func ExtrapkgMyTypeSayHello(a extrapkg.MyType) extrapkg.MyType { + return a +} diff --git a/parse/test/extrapackages/extrapkg/my_type.go b/parse/test/extrapackages/extrapkg/my_type.go new file mode 100644 index 0000000..f599f02 --- /dev/null +++ b/parse/test/extrapackages/extrapkg/my_type.go @@ -0,0 +1,3 @@ +package extrapkg + +type MyType struct{} diff --git a/parse/test/extrapackages/generic_extrapackage.go b/parse/test/extrapackages/generic_extrapackage.go new file mode 100644 index 0000000..9a36864 --- /dev/null +++ b/parse/test/extrapackages/generic_extrapackage.go @@ -0,0 +1,9 @@ +package extrapackages + +import "github.com/cheekybits/genny/generic" + +type ForeignType generic.Type + +func ForeignTypeSayHello(a ForeignType) ForeignType { + return a +} diff --git a/parse/typesets.go b/parse/typesets.go index c30b972..382c81f 100644 --- a/parse/typesets.go +++ b/parse/typesets.go @@ -20,6 +20,7 @@ const ( // Person=man Animal=dog Animal2=cat // Person=man,woman Animal=dog,cat // Person=man,woman,child Animal=dog,cat Place=london,paris +// Animal=example.com/a/b.Dog,example.com/a/b.Cat func TypeSet(arg string) ([]map[string]string, error) { types := make(map[string][]string) From 790fadaa3354739734de6195c67adc31d653373f Mon Sep 17 00:00:00 2001 From: James Date: Mon, 7 Jun 2021 20:55:47 +0200 Subject: [PATCH 2/2] extract getting full package names to separate function --- parse/parse.go | 66 +++++++++++------ parse/parse_same_pkg_test.go | 136 +++++++++++++++++++++++++++++++++++ parse/parse_test.go | 89 ----------------------- 3 files changed, 180 insertions(+), 111 deletions(-) create mode 100644 parse/parse_same_pkg_test.go diff --git a/parse/parse.go b/parse/parse.go index 292c139..7ceefba 100644 --- a/parse/parse.go +++ b/parse/parse.go @@ -10,7 +10,6 @@ import ( "go/scanner" "go/token" "io" - "os" "strings" "unicode" @@ -93,7 +92,10 @@ func subTypeIntoLine(line, typeTemplate, specificType string) string { func generateSpecific(filename string, in io.ReadSeeker, typeSet map[string]string) ([]byte, error) { // ensure we are at the beginning of the file - in.Seek(0, os.SEEK_SET) + _, err := in.Seek(0, io.SeekStart) + if err != nil { + return nil, err + } // parse the source file fs := token.NewFileSet() @@ -126,8 +128,10 @@ func generateSpecific(filename string, in io.ReadSeeker, typeSet map[string]stri } } - in.Seek(0, os.SEEK_SET) - + _, err = in.Seek(0, io.SeekStart) + if err != nil { + return nil, err + } var buf bytes.Buffer comment := "" @@ -170,6 +174,29 @@ func generateSpecific(filename string, in io.ReadSeeker, typeSet map[string]stri return buf.Bytes(), nil } +// removeFullyQualifiedImportPathAndAddToMap takes a typeset with import paths such as example.com/a/b.MyType and returns a new typeset with b.MyType. +// It adds the package path to the extraPackagesSet, to be later added into the imports section of the generated file. +// It returns the typeSet without package path, or an error if applicable +func removeFullyQualifiedImportPathAndAddToMap(typeSet map[string]string, extraPackagesSet map[string]struct{}) (map[string]string, error) { + typeSetWithoutFullImportPaths := make(map[string]string) + for typeName, fullTypePath := range typeSet { + shortTypePath := fullTypePath + lastSlashIdx := strings.LastIndex(fullTypePath, "/") + if lastSlashIdx != -1 { + // there is a full package name + lastIdxDot := strings.LastIndex(fullTypePath, ".") + if lastIdxDot < lastSlashIdx { + return nil, errors.New("error parsing full package name: last dot index before last slash index") + } + extraPackagesSet[fullTypePath[:lastIdxDot]] = struct{}{} + shortTypePath = fullTypePath[lastSlashIdx+1:] + } + typeSetWithoutFullImportPaths[typeName] = shortTypePath + } + + return typeSetWithoutFullImportPaths, nil +} + // Generics parses the source file and generates the bytes replacing the // generic types for the keys map with the specific types (its value). func Generics(filename, outputFilename, pkgName, tag string, in io.ReadSeeker, typeSets []map[string]string) ([]byte, error) { @@ -185,26 +212,16 @@ func Generics(filename, outputFilename, pkgName, tag string, in io.ReadSeeker, t extraPkgsMap := make(map[string]struct{}) - for _, typeSet := range typeSets { + for _, typeSetWithPackagePaths := range typeSets { // for each typeset, investigate whether it is a same-package import path (like "User"), or a long import path (like "example.com/a/b.User") - typeSetWithoutFullImportPaths := make(map[string]string) - for typeName, fullTypePath := range typeSet { - shortTypePath := fullTypePath - lastSlashIdx := strings.LastIndex(fullTypePath, "/") - if lastSlashIdx != -1 { - // there is a full package name - lastIdxDot := strings.LastIndex(fullTypePath, ".") - if lastIdxDot < lastSlashIdx { - return nil, errors.New("error parsing full package name: last dot index before last slash index") - } - extraPkgsMap[fullTypePath[:lastIdxDot]] = struct{}{} - shortTypePath = fullTypePath[lastSlashIdx+1:] - } - typeSetWithoutFullImportPaths[typeName] = shortTypePath + typeSet, err := removeFullyQualifiedImportPathAndAddToMap(typeSetWithPackagePaths, extraPkgsMap) + if err != nil { + return nil, err } // generate the specifics - parsed, err := generateSpecific(filename, in, typeSetWithoutFullImportPaths) + // parsed, err := generateSpecific(filename, in, typeSetWithoutFullImportPaths) + parsed, err := generateSpecific(filename, in, typeSet) if err != nil { return nil, err } @@ -277,7 +294,7 @@ func Generics(filename, outputFilename, pkgName, tag string, in io.ReadSeeker, t } // add extra imports - output, err = AddExtraImports(bytes.NewBuffer(output), extraPkgs) + output, err = addExtraImports(bytes.NewReader(output), extraPkgs) if err != nil { return nil, err } @@ -285,10 +302,15 @@ func Generics(filename, outputFilename, pkgName, tag string, in io.ReadSeeker, t return output, nil } -func AddExtraImports(in io.Reader, importPaths []string) ([]byte, error) { +func addExtraImports(in io.ReadSeeker, importPaths []string) ([]byte, error) { var outLines []string var importBlockProcessed bool + _, err := in.Seek(0, io.SeekStart) + if err != nil { + return nil, err + } + var existingImports []string var multiLineImportBlockOpen bool scanner := bufio.NewScanner(in) diff --git a/parse/parse_same_pkg_test.go b/parse/parse_same_pkg_test.go new file mode 100644 index 0000000..db2b862 --- /dev/null +++ b/parse/parse_same_pkg_test.go @@ -0,0 +1,136 @@ +package parse + +import ( + "bytes" + "testing" + + "github.com/stretchr/testify/assert" +) + +func Test_removeFullyQualifiedImportPathAndAddToMap(t *testing.T) { + type args struct { + typeSet map[string]string + extraPackagesSet map[string]struct{} + } + tests := []struct { + name string + args args + want map[string]string + wantExtraPackagesSet map[string]struct{} + wantErr bool + }{ + { + name: "one field example", + args: args{ + typeSet: map[string]string{ + "GenericField": "example.com/a/b.MyType", + }, + extraPackagesSet: map[string]struct{}{}, + }, + want: map[string]string{ + "GenericField": "b.MyType", + }, + wantExtraPackagesSet: map[string]struct{}{ + "example.com/a/b": {}, + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := removeFullyQualifiedImportPathAndAddToMap(tt.args.typeSet, tt.args.extraPackagesSet) + if (err != nil) != tt.wantErr { + t.Errorf("removeFullyQualifiedImportPathAndAddToMap() error = %v, wantErr %v", err, tt.wantErr) + return + } + assert.Equal(t, tt.want, got) + assert.Equal(t, tt.wantExtraPackagesSet, tt.args.extraPackagesSet) + }) + } +} + +func Test_addExtraImports(t *testing.T) { + type testDef struct { + Name string + In string + Out string + ExtraImports []string + } + + tests := []testDef{ + { + Name: "single import", + In: `package x + +import "fmt" + +func sayHello(user userpkg.User) { + return fmt.Sprintf("hello %s", user.Name) +} +`, + Out: `package x + +import ( + "fmt" + "example.com/me/userpkg" +) + +func sayHello(user userpkg.User) { + return fmt.Sprintf("hello %s", user.Name) +} +`, + ExtraImports: []string{"example.com/me/userpkg"}, + }, { + Name: "no imports", + In: `package x + +func sayHello(user userpkg.User) { + return "hello " + user.Name +} +`, + Out: `package x + +import ( + "example.com/me/userpkg" +) + +func sayHello(user userpkg.User) { + return "hello " + user.Name +} +`, + ExtraImports: []string{"example.com/me/userpkg"}, + }, { + Name: "multiple imports", + In: `package x + +import ( + "fmt" + "io" +) + +func sayHello(writer io.Writer, user userpkg.User) { + return fmt.Fprintf(writer, "hello %s", user.Name) +} +`, + Out: `package x + +import ( + "fmt" + "io" + "example.com/me/userpkg" +) + +func sayHello(writer io.Writer, user userpkg.User) { + return fmt.Fprintf(writer, "hello %s", user.Name) +} +`, + ExtraImports: []string{"example.com/me/userpkg"}, + }, + } + for _, test := range tests { + + out, err := addExtraImports(bytes.NewReader([]byte(test.In)), test.ExtraImports) + assert.NoError(t, err) + + assert.Equal(t, test.Out, string(out), test.Name) + } +} diff --git a/parse/parse_test.go b/parse/parse_test.go index 9dc0aa1..bb3bdc6 100644 --- a/parse/parse_test.go +++ b/parse/parse_test.go @@ -1,7 +1,6 @@ package parse_test import ( - "bytes" "io/ioutil" "log" "strings" @@ -172,91 +171,3 @@ func contents(s string) string { } return s } - -func Test_AddExtraImports(t *testing.T) { - type testDef struct { - Name string - In string - Out string - ExtraImports []string - } - - tests := []testDef{ - { - Name: "single import", - In: `package x - -import "fmt" - -func sayHello(user userpkg.User) { - return fmt.Sprintf("hello %s", user.Name) -} -`, - Out: `package x - -import ( - "fmt" - "example.com/me/userpkg" -) - -func sayHello(user userpkg.User) { - return fmt.Sprintf("hello %s", user.Name) -} -`, - ExtraImports: []string{"example.com/me/userpkg"}, - }, { - Name: "no imports", - In: `package x - -func sayHello(user userpkg.User) { - return "hello " + user.Name -} -`, - Out: `package x - -import ( - "example.com/me/userpkg" -) - -func sayHello(user userpkg.User) { - return "hello " + user.Name -} -`, - ExtraImports: []string{"example.com/me/userpkg"}, - }, { - Name: "multiple imports", - In: `package x - -import ( - "fmt" - "io" -) - -func sayHello(writer io.Writer, user userpkg.User) { - return fmt.Fprintf(writer, "hello %s", user.Name) -} -`, - Out: `package x - -import ( - "fmt" - "io" - "example.com/me/userpkg" -) - -func sayHello(writer io.Writer, user userpkg.User) { - return fmt.Fprintf(writer, "hello %s", user.Name) -} -`, - ExtraImports: []string{"example.com/me/userpkg"}, - }, - } - for _, test := range tests { - - out, err := parse.AddExtraImports(bytes.NewBufferString(test.In), test.ExtraImports) - assert.NoError(t, err) - - assert.Equal(t, test.Out, string(out), test.Name) - } - -}