diff --git a/Gopkg.lock b/Gopkg.lock index 36bc13e6..b559d11f 100644 --- a/Gopkg.lock +++ b/Gopkg.lock @@ -116,7 +116,7 @@ revision = "ebd802b4f1c2b52d221ca35f624d3312b765422c" [[projects]] - digest = "1:500d16a939de31dc868a67fa95cbeb5bc7cbd70c59d35e570d7c805c0d5c814f" + digest = "1:2e5c631a7074c11e3d9f4ad68bbeacb6ab51c5e9ad78714c1903bc39c5d40f50" name = "github.com/palantir/conjure-go" packages = [ "conjure", @@ -127,8 +127,8 @@ "conjure/werrorexpressions", ] pruneopts = "UT" - revision = "24cd98b20961d5ecb3734cb114c600807d832861" - version = "v4.0.3" + revision = "a655527e82280aa461c838d52ebc275d9a18eca5" + version = "v4.1.1" [[projects]] digest = "1:62e516988833f683d36991a3747965e49fb8dafb397915257faff88f99b68725" diff --git a/Gopkg.toml b/Gopkg.toml index a96822cb..b9f7896f 100644 --- a/Gopkg.toml +++ b/Gopkg.toml @@ -22,7 +22,7 @@ [[constraint]] name = "github.com/palantir/conjure-go" - version = "4.0.3" + version = "4.1.1" [[constraint]] name = "github.com/go-bindata/go-bindata" diff --git a/vendor/github.com/palantir/conjure-go/conjure/serverwriter.go b/vendor/github.com/palantir/conjure-go/conjure/serverwriter.go index 9cfc5de9..9c312093 100644 --- a/vendor/github.com/palantir/conjure-go/conjure/serverwriter.go +++ b/vendor/github.com/palantir/conjure-go/conjure/serverwriter.go @@ -214,8 +214,7 @@ func getResourceFunction(endpointDefinition spec.EndpointDefinition) string { func AstForServerInterface(serviceDefinition spec.ServiceDefinition, info types.PkgInfo) ([]astgen.ASTDecl, error) { serviceName := serviceDefinition.ServiceName.Name - isClient := false - interfaceAST, _, err := serviceInterfaceAST(serviceDefinition, info, false, isClient) + interfaceAST, _, err := serverServiceInterfaceAST(serviceDefinition, info, serviceASTConfig{}) if err != nil { return nil, errors.Wrapf(err, "failed to generate interface for service %q", serviceName) } diff --git a/vendor/github.com/palantir/conjure-go/conjure/servicewriter.go b/vendor/github.com/palantir/conjure-go/conjure/servicewriter.go index 98501640..1dc9c3cf 100644 --- a/vendor/github.com/palantir/conjure-go/conjure/servicewriter.go +++ b/vendor/github.com/palantir/conjure-go/conjure/servicewriter.go @@ -43,14 +43,21 @@ const ( httpClientImportPath = "github.com/palantir/conjure-go-runtime/conjure-go-client/httpclient" httpClientClientType = expression.Type("httpclient.Client") httpClientPkgName = "httpclient" + + tokenProviderVar = "tokenProvider" + tokenProviderType = expression.Type("httpclient.TokenProvider") + tokenProviderImportPath = "github.com/palantir/conjure-go-runtime/conjure-go-client/httpclient" ) +type serviceASTConfig struct { + withAuth bool + withTokenProvider bool +} + func astForService(serviceDefinition spec.ServiceDefinition, info types.PkgInfo) ([]astgen.ASTDecl, StringSet, error) { allImports := NewStringSet() serviceName := serviceDefinition.ServiceName.Name - isClient := true - - interfaceAST, imports, err := serviceInterfaceAST(serviceDefinition, info, false, isClient) + interfaceAST, imports, err := clientServiceInterfaceAST(serviceDefinition, info, serviceASTConfig{}) if err != nil { return nil, nil, errors.Wrapf(err, "failed to generate interface for service %q", serviceName) } @@ -86,7 +93,7 @@ func astForService(serviceDefinition spec.ServiceDefinition, info types.PkgInfo) if hasHeaderAuth || hasCookieAuth { // at least one endpoint uses authentication: define decorator structures - withAuthInterfaceAST, imports, err := serviceInterfaceAST(serviceDefinition, info, true, isClient) + withAuthInterfaceAST, imports, err := clientServiceInterfaceAST(serviceDefinition, info, serviceASTConfig{withAuth: true}) if err != nil { return nil, nil, errors.Wrapf(err, "failed to generate interface with auth for service %q", serviceName) } @@ -108,6 +115,24 @@ func astForService(serviceDefinition spec.ServiceDefinition, info types.PkgInfo) components = append(components, withAuthMethodsAST...) allImports.AddAll(imports) } + + if canAddTokenInterface(serviceDefinition.Endpoints) { + withTokenServiceNewFunc, tokenServiceNewFuncImports := withTokenServiceNewFuncAST(serviceName, info) + components = append(components, withTokenServiceNewFunc) + allImports.AddAll(tokenServiceNewFuncImports) + + withTokenServiceStruct, tokenServiceStructImports := withTokenServiceStructAST(serviceName, info) + components = append(components, withTokenServiceStruct) + allImports.AddAll(tokenServiceStructImports) + + withTokenMethodsAST, imports, err := withTokenServiceStructMethodsAST(serviceDefinition, info) + if err != nil { + return nil, nil, errors.Wrapf(err, "Failed to generate methods with token provider for service %q", serviceName) + } + components = append(components, withTokenMethodsAST...) + allImports.AddAll(imports) + } + return components, allImports, nil } @@ -116,31 +141,68 @@ func hasAuth(endpoints []spec.EndpointDefinition) (hasHeaderAuth, hasCookieAuth if endpointDefinition.Auth == nil { continue } - possibleHeaderAuth, err := visitors.GetPossibleHeaderAuth(*endpointDefinition.Auth) - if err != nil { + if possibleHeaderAuth, err := visitors.GetPossibleHeaderAuth(*endpointDefinition.Auth); err != nil { return false, false, err - } - if possibleHeaderAuth != nil { + } else if possibleHeaderAuth != nil { hasHeaderAuth = true } - possibleCookieAuth, err := visitors.GetPossibleCookieAuth(*endpointDefinition.Auth) - if err != nil { + if possibleCookieAuth, err := visitors.GetPossibleCookieAuth(*endpointDefinition.Auth); err != nil { return false, false, err - } - if possibleCookieAuth != nil { + } else if possibleCookieAuth != nil { hasCookieAuth = true } } return } -func serviceInterfaceAST(serviceDefinition spec.ServiceDefinition, info types.PkgInfo, withAuth, isClient bool) (astgen.ASTDecl, StringSet, error) { +// Return true if all endpoints that require authentication are of the same auth type (header or cookie) and at least +// one endpoint has auth. The same auth type is required because a single token provider will likely not be useful for +// both auth types. +func canAddTokenInterface(endpoints []spec.EndpointDefinition) bool { + var hasHeaderAuth, hasCookieAuth bool + for _, endpointDefinition := range endpoints { + if endpointDefinition.Auth == nil { + continue + } + possibleHeaderAuth, err := visitors.GetPossibleHeaderAuth(*endpointDefinition.Auth) + if err != nil { + return false + } + hasHeaderAuth = hasHeaderAuth || possibleHeaderAuth != nil + possibleCookieAuth, err := visitors.GetPossibleCookieAuth(*endpointDefinition.Auth) + if err != nil { + return false + } + hasCookieAuth = hasCookieAuth || possibleCookieAuth != nil + if hasHeaderAuth && hasCookieAuth { + return false + } + } + return hasHeaderAuth || hasCookieAuth +} + +type generatorType bool + +const ( + generatorTypeClient generatorType = true + generatorTypeServer generatorType = false +) + +func clientServiceInterfaceAST(serviceDefinition spec.ServiceDefinition, info types.PkgInfo, config serviceASTConfig) (astgen.ASTDecl, StringSet, error) { + return serviceInterfaceAST(serviceDefinition, info, config, generatorTypeClient) +} + +func serverServiceInterfaceAST(serviceDefinition spec.ServiceDefinition, info types.PkgInfo, config serviceASTConfig) (astgen.ASTDecl, StringSet, error) { + return serviceInterfaceAST(serviceDefinition, info, config, generatorTypeServer) +} + +func serviceInterfaceAST(serviceDefinition spec.ServiceDefinition, info types.PkgInfo, config serviceASTConfig, generatorType generatorType) (astgen.ASTDecl, StringSet, error) { allImports := make(StringSet) var interfaceFuncs []*expression.InterfaceFunctionDecl serviceName := serviceDefinition.ServiceName.Name for _, endpointDefinition := range serviceDefinition.Endpoints { endpointName := string(endpointDefinition.EndpointName) - params, imports, err := paramsForEndpoint(endpointDefinition, info, withAuth, isClient) + params, imports, err := paramsForEndpoint(endpointDefinition, info, config, generatorType) if err != nil { return nil, nil, errors.Wrapf(err, "failed to generate parameters for endpoint %q", endpointName) } @@ -161,10 +223,10 @@ func serviceInterfaceAST(serviceDefinition spec.ServiceDefinition, info types.Pk } name := interfaceTypeName(serviceName) - if isClient { + if generatorType == generatorTypeClient { name = clientInterfaceTypeName(name) } - if withAuth { + if config.withAuth { name = withAuthName(name) } return &decl.Interface{ @@ -201,6 +263,22 @@ func withAuthServiceStructAST(serviceName string, hasHeaderAuth, hasCookieAuth b return decl.NewStruct(withAuthName(clientStructTypeName(serviceName)), fields, ""), imports } +func withTokenServiceStructAST(serviceName string, info types.PkgInfo) (astgen.ASTDecl, StringSet) { + imports := NewStringSet(types.Bearertoken.ImportPaths()...) + imports.Add(tokenProviderImportPath) + fields := []*expression.StructField{ + { + Name: wrappedClientVar, + Type: expression.Type(clientInterfaceTypeName(serviceName)), + }, + { + Name: tokenProviderVar, + Type: tokenProviderType, + }, + } + return decl.NewStruct(withTokenProviderName(clientStructTypeName(serviceName)), fields, ""), imports +} + func serviceNewFuncAST(serviceName string) (astgen.ASTDecl, StringSet) { return &decl.Function{ Name: "New" + clientInterfaceTypeName(serviceName), @@ -283,13 +361,49 @@ func withAuthServiceNewFuncAST(serviceName string, hasHeaderAuth, hasCookieAuth }, imports } +func withTokenServiceNewFuncAST(serviceName string, info types.PkgInfo) (astgen.ASTDecl, StringSet) { + funcParams := []*expression.FuncParam{ + expression.NewFuncParam(wrappedClientVar, expression.Type(clientInterfaceTypeName(serviceName))), + expression.NewFuncParam(tokenProviderVar, tokenProviderType), + } + imports := NewStringSet() + + structElems := []astgen.ASTExpr{ + expression.NewKeyValue(wrappedClientVar, expression.VariableVal(wrappedClientVar)), + expression.NewKeyValue(tokenProviderVar, expression.VariableVal(tokenProviderVar)), + } + + return &decl.Function{ + Name: withTokenProviderName("New" + clientInterfaceTypeName(serviceName)), + FuncType: expression.FuncType{ + Params: funcParams, + ReturnTypes: []expression.Type{ + expression.Type(withAuthName(clientInterfaceTypeName(serviceName))), + }, + }, + Body: []astgen.ASTStmt{ + &statement.Return{ + Values: []astgen.ASTExpr{ + &expression.Unary{ + Op: token.AND, + Receiver: &expression.CompositeLit{ + Type: expression.Type(withTokenProviderName(clientStructTypeName(serviceName))), + Elements: structElems, + }, + }, + }, + }, + }, + }, imports +} + func serviceStructMethodsAST(serviceDefinition spec.ServiceDefinition, info types.PkgInfo) ([]astgen.ASTDecl, StringSet, error) { allImports := make(StringSet) var methods []astgen.ASTDecl serviceName := serviceDefinition.ServiceName.Name for _, endpointDefinition := range serviceDefinition.Endpoints { endpointName := string(endpointDefinition.EndpointName) - params, imports, err := paramsForEndpoint(endpointDefinition, info, false, true) + params, imports, err := paramsForEndpoint(endpointDefinition, info, serviceASTConfig{}, generatorTypeClient) if err != nil { return nil, nil, errors.Wrapf(err, "failed to generate parameters for endpoint %q", endpointName) } @@ -330,7 +444,7 @@ func withAuthServiceStructMethodsAST(serviceDefinition spec.ServiceDefinition, i serviceName := serviceDefinition.ServiceName.Name for _, endpointDefinition := range serviceDefinition.Endpoints { endpointName := string(endpointDefinition.EndpointName) - params, imports, err := paramsForEndpoint(endpointDefinition, info, true, true) + params, imports, err := paramsForEndpoint(endpointDefinition, info, serviceASTConfig{withAuth: true}, generatorTypeClient) if err != nil { return nil, nil, errors.Wrapf(err, "failed to generate parameters for endpoint %q", endpointName) } @@ -362,6 +476,45 @@ func withAuthServiceStructMethodsAST(serviceDefinition spec.ServiceDefinition, i return methods, allImports, nil } +func withTokenServiceStructMethodsAST(serviceDefinition spec.ServiceDefinition, info types.PkgInfo) ([]astgen.ASTDecl, StringSet, error) { + allImports := make(StringSet) + var methods []astgen.ASTDecl + serviceName := serviceDefinition.ServiceName.Name + for _, endpointDefinition := range serviceDefinition.Endpoints { + endpointName := string(endpointDefinition.EndpointName) + params, imports, err := paramsForEndpoint(endpointDefinition, info, serviceASTConfig{withTokenProvider: true}, generatorTypeClient) + if err != nil { + return nil, nil, errors.Wrapf(err, "Failed to generate parameters for endpoint %q", endpointName) + } + allImports.AddAll(imports) + + returnTypes, imports, err := returnTypesForEndpoint(endpointDefinition, info) + if err != nil { + return nil, nil, errors.Wrapf(err, "Failed to generate return types for endpoint %q", endpointName) + } + allImports.AddAll(imports) + + body, err := serviceWithTokenStructMethodBodyAST(endpointDefinition, params, returnTypes) + if err != nil { + return nil, nil, errors.Wrapf(err, "Failed to generate token provider structs for endpoint %q", endpointName) + } + + methods = append(methods, &decl.Method{ + ReceiverName: receiverName, + ReceiverType: expression.Type(withTokenProviderName(clientStructTypeName(serviceName))).Pointer(), + Function: decl.Function{ + Name: transforms.Export(endpointName), + FuncType: expression.FuncType{ + Params: params, + ReturnTypes: returnTypes, + }, + Body: body, + }, + }) + } + return methods, allImports, nil +} + func isReturnTypeCollectionType(inType *spec.Type) (bool, error) { if inType == nil { return false, nil @@ -752,6 +905,65 @@ func serviceWithAuthStructMethodBodyAST(endpointDefinition spec.EndpointDefiniti } +func serviceWithTokenStructMethodBodyAST(endpointDefinition spec.EndpointDefinition, params expression.FuncParams, returnTypes expression.Types) ([]astgen.ASTStmt, error) { + endpointName := string(endpointDefinition.EndpointName) + args := []astgen.ASTExpr{expression.VariableVal(ctxName)} + statements := []astgen.ASTStmt{} + if endpointDefinition.Auth != nil { + possibleHeader, err := visitors.GetPossibleHeaderAuth(*endpointDefinition.Auth) + if err != nil { + return nil, err + } + possibleCookie, err := visitors.GetPossibleCookieAuth(*endpointDefinition.Auth) + if err != nil { + return nil, err + } + if possibleHeader != nil || possibleCookie != nil { + tokenInit := &statement.Assignment{ + LHS: []astgen.ASTExpr{ + expression.VariableVal("token"), + expression.VariableVal("err"), + }, + Tok: token.DEFINE, + RHS: expression.NewCallExpression(expression.NewSelector(expression.VariableVal(receiverName), tokenProviderVar), expression.VariableVal(ctxName)), + } + var errReturn *statement.If + if len(returnTypes) > 1 { + statements = append(statements, statement.NewDecl(decl.NewVar("defaultReturnVal", returnTypes[0]))) + errReturn = ifErrNotNilReturnHelper(true, "defaultReturnVal", "err", nil) + } else { + errReturn = ifErrNotNilReturnErrStatement("err", nil) + } + args = append(args, expression.NewCallExpression(expression.Type("bearertoken.Token"), expression.VariableVal("token"))) + statements = append(statements, tokenInit, errReturn) + } + } + + for _, param := range params { + if param.Type == "context.Context" { + // We already added ctx as the first argument. + continue + } + for _, curr := range param.Names { + args = append(args, expression.VariableVal(curr)) + } + } + + return append(statements, + statement.NewReturn( + expression.NewCallExpression( + expression.NewSelector( + expression.NewSelector(expression.VariableVal(receiverName), + wrappedClientVar, + ), + transforms.Export(endpointName), + ), + args..., + ), + ), + ), nil +} + func returnVals(hasReturnVal bool, optional, required astgen.ASTExpr) []astgen.ASTExpr { var rvals []astgen.ASTExpr if hasReturnVal { @@ -811,10 +1023,10 @@ func returnTypesForEndpoint(endpointDefinition spec.EndpointDefinition, info typ return append(returnTypes, expression.ErrorType), imports, nil } -func paramsForEndpoint(endpointDefinition spec.EndpointDefinition, info types.PkgInfo, withAuth, isClient bool) (expression.FuncParams, StringSet, error) { +func paramsForEndpoint(endpointDefinition spec.EndpointDefinition, info types.PkgInfo, config serviceASTConfig, generatorType generatorType) (expression.FuncParams, StringSet, error) { imports := NewStringSet("context") params := []*expression.FuncParam{expression.NewFuncParam(ctxName, expression.Type("context.Context"))} - if endpointDefinition.Auth != nil && !withAuth { + if endpointDefinition.Auth != nil && !config.withAuth && !config.withTokenProvider { if authHeader, err := visitors.GetPossibleHeaderAuth(*endpointDefinition.Auth); err != nil { return nil, nil, err } else if authHeader != nil { @@ -838,7 +1050,7 @@ func paramsForEndpoint(endpointDefinition spec.EndpointDefinition, info types.Pk if binaryParam { // special case: "binary" types resolve to []byte, but this indicates a streaming parameter when // specified as the request argument of a service, so use "io.ReadCloser". - if isClient { + if generatorType == generatorTypeClient { // special case: the client provides "func() io.ReadCloser" instead of "io.ReadCloser" so // that a fresh "io.ReadCloser" can be retrieved for retries. goType = types.GetBodyType.GoType(info) @@ -876,6 +1088,10 @@ func withAuthName(name string) string { return name + "WithAuth" } +func withTokenProviderName(name string) string { + return name + "WithTokenProvider" +} + // argNameTransform returns the input string with "Arg" appended to it. This transformation is done to ensure that // argument variable names do not shadow any package names. func argNameTransform(input string) string {