Skip to content

Commit

Permalink
feat(hz): support extends for thrift (#765)
Browse files Browse the repository at this point in the history
  • Loading branch information
FGYFFFF committed May 31, 2023
1 parent 4b43f98 commit 5800385
Show file tree
Hide file tree
Showing 4 changed files with 186 additions and 5 deletions.
4 changes: 4 additions & 0 deletions cmd/hz/app/app.go
Original file line number Diff line number Diff line change
Expand Up @@ -178,6 +178,7 @@ func Init() *cli.App {
protoPluginsFlag := cli.StringSliceFlag{Name: "protoc-plugins", Usage: "Specify plugins for the protoc. ({plugin_name}:{options}:{out_dir})"}
noRecurseFlag := cli.BoolFlag{Name: "no_recurse", Usage: "Generate master model only.", Destination: &globalArgs.NoRecurse}
forceNewFlag := cli.BoolFlag{Name: "force", Aliases: []string{"f"}, Usage: "Force new a project, which will overwrite the generated files", Destination: &globalArgs.ForceNew}
enableExtendsFlag := cli.BoolFlag{Name: "enable_extends", Usage: "Parse 'extends' for thrift IDL", Destination: &globalArgs.EnableExtends}

jsonEnumStrFlag := cli.BoolFlag{Name: "json_enumstr", Usage: "Use string instead of num for json enums when idl is thrift.", Destination: &globalArgs.JSONEnumStr}
unsetOmitemptyFlag := cli.BoolFlag{Name: "unset_omitempty", Usage: "Remove 'omitempty' tag for generated struct.", Destination: &globalArgs.UnsetOmitempty}
Expand Down Expand Up @@ -223,6 +224,7 @@ func Init() *cli.App {
&optPkgFlag,
&noRecurseFlag,
&forceNewFlag,
&enableExtendsFlag,

&jsonEnumStrFlag,
&unsetOmitemptyFlag,
Expand Down Expand Up @@ -255,6 +257,7 @@ func Init() *cli.App {
&protoOptionsFlag,
&optPkgFlag,
&noRecurseFlag,
&enableExtendsFlag,

&jsonEnumStrFlag,
&unsetOmitemptyFlag,
Expand Down Expand Up @@ -306,6 +309,7 @@ func Init() *cli.App {
&thriftOptionsFlag,
&protoOptionsFlag,
&noRecurseFlag,
&enableExtendsFlag,

&jsonEnumStrFlag,
&unsetOmitemptyFlag,
Expand Down
1 change: 1 addition & 0 deletions cmd/hz/config/argument.go
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@ type Argument struct {
HandlerByMethod bool
ForceNew bool
SnakeStyleMiddleware bool
EnableExtends bool

CustomizeLayout string
CustomizeLayoutData string
Expand Down
184 changes: 180 additions & 4 deletions cmd/hz/thrift/ast.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ import (
"fmt"
"strings"

"github.com/cloudwego/hertz/cmd/hz/config"
"github.com/cloudwego/hertz/cmd/hz/generator"
"github.com/cloudwego/hertz/cmd/hz/generator/model"
"github.com/cloudwego/hertz/cmd/hz/meta"
Expand All @@ -28,6 +29,7 @@ import (
"github.com/cloudwego/thriftgo/generator/golang"
"github.com/cloudwego/thriftgo/generator/golang/styles"
"github.com/cloudwego/thriftgo/parser"
"github.com/cloudwego/thriftgo/semantic"
)

/*---------------------------Import-----------------------------*/
Expand All @@ -48,25 +50,39 @@ func getGoPackage(ast *parser.Thrift, pkgMap map[string]string) string {

/*---------------------------Service-----------------------------*/

func astToService(ast *parser.Thrift, resolver *Resolver, cmdType string) ([]*generator.Service, error) {
func astToService(ast *parser.Thrift, resolver *Resolver, args *config.Argument) ([]*generator.Service, error) {
ss := ast.GetServices()
out := make([]*generator.Service, 0, len(ss))
var models model.Models

extendServices := getExtendServices(ast)
for _, s := range ss {
// if the service is extended, it is not processed
if extendServices.exist(s.Name) && args.EnableExtends {
logs.Debugf("%s is extended, so skip it\n", s.Name)
continue
}

resolver.ExportReferred(true, false)
service := &generator.Service{
Name: s.GetName(),
}
service.BaseDomain = ""
domainAnno := getAnnotation(s.Annotations, ApiBaseDomain)
if len(domainAnno) == 1 {
if cmdType == meta.CmdClient {
if args.CmdType == meta.CmdClient {
service.BaseDomain = domainAnno[0]
}
}

ms := s.GetFunctions()
if len(s.Extends) != 0 && args.EnableExtends {
// all the services that are extended to the current service
extendsFuncs, err := getAllExtendFunction(s, ast, resolver, args)
if err != nil {
return nil, fmt.Errorf("parser extend function failed, err=%v", err)
}
ms = append(ms, extendsFuncs...)
}
methods := make([]*generator.HttpMethod, 0, len(ms))
clientMethods := make([]*generator.ClientMethod, 0, len(ms))
for _, m := range ms {
Expand Down Expand Up @@ -134,7 +150,7 @@ func astToService(ast *parser.Thrift, resolver *Resolver, cmdType string) ([]*ge
}
models.MergeMap(method.Models)
methods = append(methods, method)
if cmdType == meta.CmdClient {
if args.CmdType == meta.CmdClient {
clientMethod := &generator.ClientMethod{}
clientMethod.HttpMethod = method
rt, err := resolver.ResolveIdentifier(m.Arguments[0].GetType().GetName())
Expand Down Expand Up @@ -250,6 +266,166 @@ func parseAnnotationToClient(clientMethod *generator.ClientMethod, p *parser.Typ
return nil
}

type extendServiceList []string

func (svr extendServiceList) exist(serviceName string) bool {
for _, s := range svr {
if s == serviceName {
return true
}
}
return false
}

func getExtendServices(ast *parser.Thrift) (res extendServiceList) {
for a := range ast.DepthFirstSearch() {
for _, svc := range a.Services {
if len(svc.Extends) > 0 {
res = append(res, svc.Extends)
}
}
}
return
}

func getAllExtendFunction(svc *parser.Service, ast *parser.Thrift, resolver *Resolver, args *config.Argument) (res []*parser.Function, err error) {
if len(svc.Extends) == 0 {
return
}
parts := semantic.SplitType(svc.Extends)
switch len(parts) {
case 1:
if resolver.mainPkg.Ast.Filename == ast.Filename { // extended current service for master IDL
extendSvc, found := ast.GetService(parts[0])
if found {
funcs := extendSvc.GetFunctions()
// determine if it still has extends
extendFuncs, err := getAllExtendFunction(extendSvc, ast, resolver, args)
if err != nil {
return nil, err
}
res = append(res, append(funcs, extendFuncs...)...)
}
return res, nil
} else { // extended current service for other IDL
extendSvc, found := ast.GetService(parts[0])
if found {
base, err := addResolverDependency(resolver, ast, args)
if err != nil {
return nil, err
}
funcs := extendSvc.GetFunctions()
for _, f := range funcs {
// the method of other file is extended, and the package of req/resp needs to be changed
// ex. base.thrift -> Resp Method(Req){}
// base.Resp Method(base.Req){}
// todo: support container for Struct
if len(f.Arguments) > 0 {
if !strings.Contains(f.Arguments[0].Type.Name, ".") && f.Arguments[0].Type.Category.IsStruct() {
f.Arguments[0].Type.Name = base + "." + f.Arguments[0].Type.Name
}
}
if !strings.Contains(f.FunctionType.Name, ".") && f.FunctionType.Category.IsStruct() {
f.FunctionType.Name = base + "." + f.FunctionType.Name
}
}
extendFuncs, err := getAllExtendFunction(extendSvc, ast, resolver, args)
if err != nil {
return nil, err
}
res = append(res, append(funcs, extendFuncs...)...)
}
return res, nil
}
case 2:
refAst, found := ast.GetReference(parts[0])
base, err := addResolverDependency(resolver, refAst, args)
if err != nil {
return nil, err
}
// ff the service extends from other files, it has to resolve the dependencies of other files as well
for _, dep := range refAst.Includes {
_, err := addResolverDependency(resolver, dep.Reference, args)
if err != nil {
return nil, err
}
}
if found {
extendSvc, found := refAst.GetService(parts[1])
if found {
funcs := extendSvc.GetFunctions()
for _, f := range funcs {
// the method of other file is extended, and the package of req/resp needs to be changed
// ex. base.thrift -> Resp Method(Req){}
// base.Resp Method(base.Req){}
// todo: support container for Struct
if len(f.Arguments) > 0 {
if !strings.Contains(f.Arguments[0].Type.Name, ".") && f.Arguments[0].Type.Category.IsStruct() {
f.Arguments[0].Type.Name = base + "." + f.Arguments[0].Type.Name
}
}
if !strings.Contains(f.FunctionType.Name, ".") && f.FunctionType.Category.IsStruct() {
f.FunctionType.Name = base + "." + f.FunctionType.Name
}
}
extendFuncs, err := getAllExtendFunction(extendSvc, refAst, resolver, args)
if err != nil {
return nil, err
}
res = append(res, append(funcs, extendFuncs...)...)
}
}
return res, nil
}

return res, nil
}

func getUniqueResolveDependentName(name string, resolver *Resolver) string {
rawName := name
for i := 0; i < 10000; i++ {
if _, exist := resolver.deps[name]; !exist {
return name
}
name = rawName + fmt.Sprint(i)
}

return name
}

func addResolverDependency(resolver *Resolver, ast *parser.Thrift, args *config.Argument) (string, error) {
namespace, err := resolver.LoadOne(ast)
if err != nil {
return "", err
}
baseName := util.BaseName(ast.Filename, ".thrift")
if refPkg, exist := resolver.refPkgs[baseName]; !exist {
resolver.deps[baseName] = namespace
} else {
if ast.Filename != refPkg.Ast.Filename {
baseName = getUniqueResolveDependentName(baseName, resolver)
resolver.deps[baseName] = namespace
}
}
pkg := getGoPackage(ast, args.OptPkgMap)
impt := ast.Filename
pkgName := util.SplitPackageName(pkg, "")
pkgName, err = util.GetPackageUniqueName(pkgName)
if err != nil {
return "", err
}
ref := &PackageReference{baseName, impt, &model.Model{
FilePath: ast.Filename,
Package: pkg,
PackageName: pkgName,
}, ast, false}
if _, exist := resolver.refPkgs[baseName]; !exist {
resolver.refPkgs[baseName] = ref
}

return baseName, nil
}

/*---------------------------Model-----------------------------*/

var BaseThrift = parser.Thrift{}
Expand Down
2 changes: 1 addition & 1 deletion cmd/hz/thrift/plugin.go
Original file line number Diff line number Diff line change
Expand Up @@ -293,7 +293,7 @@ func (plugin *Plugin) getPackageInfo() (*generator.HttpPackage, error) {
return nil, fmt.Errorf("go package for '%s' is not defined", ast.GetFilename())
}

services, err := astToService(ast, rs, args.CmdType)
services, err := astToService(ast, rs, args)
if err != nil {
return nil, err
}
Expand Down

0 comments on commit 5800385

Please sign in to comment.