diff --git a/cmd/common.go b/cmd/common.go index 5ac15cd40..f0cf4f553 100644 --- a/cmd/common.go +++ b/cmd/common.go @@ -53,6 +53,15 @@ func workspaceExists(ctx context.Context, config utils.KongClientConfig, workspa return exists, nil } +func getWorkspaceName(workspaceFlag string, targetContent *file.Content) string { + if workspaceFlag != targetContent.Workspace && workspaceFlag != "" { + cprint.DeletePrintf("Warning: Workspace '%v' specified via --workspace flag is "+ + "different from workspace '%v' found in state file(s).\n", workspaceFlag, targetContent.Workspace) + return workspaceFlag + } + return targetContent.Workspace +} + func syncMain(ctx context.Context, filenames []string, dry bool, parallelism, delay int, workspace string) error { @@ -70,16 +79,9 @@ func syncMain(ctx context.Context, filenames []string, dry bool, parallelism, return err } - var wsConfig utils.KongClientConfig - var workspaceName string // prepare to read the current state from Kong - if workspace != targetContent.Workspace && workspace != "" { - cprint.DeletePrintf("Warning: Workspace '%v' specified via --workspace flag is "+ - "different from workspace '%v' found in state file(s).\n", workspace, targetContent.Workspace) - workspaceName = workspace - } else { - workspaceName = targetContent.Workspace - } + var wsConfig utils.KongClientConfig + workspaceName := getWorkspaceName(workspace, targetContent) wsConfig = rootConfig.ForWorkspace(workspaceName) // load Kong version after workspace diff --git a/cmd/validate.go b/cmd/validate.go index e6981e558..9a1f3c1dd 100644 --- a/cmd/validate.go +++ b/cmd/validate.go @@ -6,12 +6,17 @@ import ( "github.com/kong/deck/dump" "github.com/kong/deck/file" "github.com/kong/deck/state" + "github.com/kong/deck/utils" + "github.com/kong/deck/validate" "github.com/spf13/cobra" ) var ( validateCmdKongStateFile []string validateCmdRBACResourcesOnly bool + validateOnline bool + validateWorkspace string + validateParallelism int ) // validateCmd represents the diff command @@ -19,12 +24,12 @@ var validateCmd = &cobra.Command{ Use: "validate", Short: "Validate the state file", Long: `The validate command reads the state file and ensures validity. - It reads all the specified state files and reports YAML/JSON parsing issues. It also checks for foreign relationships and alerts if there are broken relationships, or missing links present. + No communication takes places between decK and Kong during the execution of -this command. +this command unless --online flag is used. `, Args: validateNoArgs, RunE: func(cmd *cobra.Command, args []string) error { @@ -51,11 +56,16 @@ this command. return err } // this catches foreign relation errors - _, err = state.Get(rawState) + ks, err := state.Get(rawState) if err != nil { return err } + if validateOnline { + if errs := validateWithKong(cmd, ks, targetContent); len(errs) != 0 { + return validate.ErrorsWrapper{Errors: errs} + } + } return nil }, PreRunE: func(cmd *cobra.Command, args []string) error { @@ -67,6 +77,107 @@ this command. }, } +func validateWithKong(cmd *cobra.Command, ks *state.KongState, targetContent *file.Content) []error { + ctx := cmd.Context() + // make sure we are able to connect to Kong + _, err := fetchKongVersion(ctx, rootConfig) + if err != nil { + return []error{fmt.Errorf("couldn't fetch Kong version: %w", err)} + } + + workspaceName := validateWorkspace + if validateWorkspace != "" { + // check if workspace exists + workspaceName := getWorkspaceName(validateWorkspace, targetContent) + workspaceExists, err := workspaceExists(ctx, rootConfig, workspaceName) + if err != nil { + return []error{err} + } + if !workspaceExists { + return []error{fmt.Errorf("workspace doesn't exist: %s", workspaceName)} + } + } + + wsConfig := rootConfig.ForWorkspace(workspaceName) + kongClient, err := utils.GetKongClient(wsConfig) + if err != nil { + return []error{err} + } + + opts := validate.ValidatorOpts{ + Ctx: ctx, + State: ks, + Client: kongClient, + Parallelism: validateParallelism, + RBACResourcesOnly: validateCmdRBACResourcesOnly, + } + validator := validate.NewValidator(opts) + return validator.Validate() +} + +// ensureGetAllMethod ensures at init time that `GetAll()` method exists on the relevant structs. +// If the method doesn't exist, the code will panic. This increases the likelihood of catching such an +// error during manual testing. +func ensureGetAllMethods() error { + // let's make sure ASAP that all resources have the expected GetAll method + dummyEmptyState, _ := state.NewKongState() + if _, err := utils.CallGetAll(dummyEmptyState.Services); err != nil { + return err + } + if _, err := utils.CallGetAll(dummyEmptyState.ACLGroups); err != nil { + return err + } + if _, err := utils.CallGetAll(dummyEmptyState.BasicAuths); err != nil { + return err + } + if _, err := utils.CallGetAll(dummyEmptyState.CACertificates); err != nil { + return err + } + if _, err := utils.CallGetAll(dummyEmptyState.Certificates); err != nil { + return err + } + if _, err := utils.CallGetAll(dummyEmptyState.Consumers); err != nil { + return err + } + if _, err := utils.CallGetAll(dummyEmptyState.Documents); err != nil { + return err + } + if _, err := utils.CallGetAll(dummyEmptyState.HMACAuths); err != nil { + return err + } + if _, err := utils.CallGetAll(dummyEmptyState.JWTAuths); err != nil { + return err + } + if _, err := utils.CallGetAll(dummyEmptyState.KeyAuths); err != nil { + return err + } + if _, err := utils.CallGetAll(dummyEmptyState.Oauth2Creds); err != nil { + return err + } + if _, err := utils.CallGetAll(dummyEmptyState.Plugins); err != nil { + return err + } + if _, err := utils.CallGetAll(dummyEmptyState.Routes); err != nil { + return err + } + if _, err := utils.CallGetAll(dummyEmptyState.SNIs); err != nil { + return err + } + if _, err := utils.CallGetAll(dummyEmptyState.Targets); err != nil { + return err + } + if _, err := utils.CallGetAll(dummyEmptyState.Upstreams); err != nil { + return err + } + if _, err := utils.CallGetAll(dummyEmptyState.RBACEndpointPermissions); err != nil { + return err + } + if _, err := utils.CallGetAll(dummyEmptyState.RBACRoles); err != nil { + return err + } + return nil +} + func init() { rootCmd.AddCommand(validateCmd) validateCmd.Flags().BoolVar(&validateCmdRBACResourcesOnly, "rbac-resources-only", @@ -75,4 +186,18 @@ func init() { "state", "s", []string{"kong.yaml"}, "file(s) containing Kong's configuration.\n"+ "This flag can be specified multiple times for multiple files.\n"+ "Use '-' to read from stdin.") + validateCmd.Flags().BoolVar(&validateOnline, "online", + false, "perform validations against Kong API. When this flag is used, validation is done\n"+ + "via communication with Kong. This increases the time for validation but catches \n"+ + "significant errors. No resource is created in Kong.") + validateCmd.Flags().StringVarP(&validateWorkspace, "workspace", "w", + "", "validate configuration of a specific workspace "+ + "(Kong Enterprise only).\n"+ + "This takes precedence over _workspace fields in state files.") + validateCmd.Flags().IntVar(&validateParallelism, "parallelism", + 10, "Maximum number of concurrent requests to Kong.") + + if err := ensureGetAllMethods(); err != nil { + panic(err.Error()) + } } diff --git a/main.go b/main.go index 66712e8bb..88c92d0f3 100644 --- a/main.go +++ b/main.go @@ -18,6 +18,7 @@ func registerSignalHandler() context.Context { signal.Notify(sigs, syscall.SIGINT, syscall.SIGTERM) go func() { + defer signal.Stop(sigs) sig := <-sigs fmt.Println("received", sig, ", terminating...") cancel() diff --git a/utils/utils.go b/utils/utils.go index 7d48a3c8e..2ca0a444e 100644 --- a/utils/utils.go +++ b/utils/utils.go @@ -5,6 +5,7 @@ import ( "net/url" "os" "path/filepath" + "reflect" "regexp" "strings" ) @@ -52,3 +53,16 @@ func NameToFilename(name string) string { func FilenameToName(filename string) string { return strings.ReplaceAll(filename, url.PathEscape(string(os.PathSeparator)), string(os.PathSeparator)) } + +func CallGetAll(obj interface{}) (reflect.Value, error) { + // call GetAll method on entity + var result reflect.Value + method := reflect.ValueOf(obj).MethodByName("GetAll") + if !method.IsValid() { + return result, fmt.Errorf("GetAll() method not found for type '%v'. "+ + "Please file a bug with Kong Inc", reflect.ValueOf(obj).Type()) + } + entities := method.Call([]reflect.Value{})[0].Interface() + result = reflect.ValueOf(entities) + return result, nil +} diff --git a/validate/validate.go b/validate/validate.go new file mode 100644 index 000000000..8b8a9e28f --- /dev/null +++ b/validate/validate.go @@ -0,0 +1,168 @@ +package validate + +import ( + "context" + "fmt" + "net/http" + "sync" + + "github.com/kong/deck/state" + "github.com/kong/deck/utils" + "github.com/kong/go-kong/kong" +) + +type Validator struct { + ctx context.Context + state *state.KongState + client *kong.Client + parallelism int + rbacResourcesOnly bool +} + +type ValidatorOpts struct { + Ctx context.Context + State *state.KongState + Client *kong.Client + Parallelism int + RBACResourcesOnly bool +} + +func NewValidator(opt ValidatorOpts) *Validator { + return &Validator{ + ctx: opt.Ctx, + state: opt.State, + client: opt.Client, + parallelism: opt.Parallelism, + rbacResourcesOnly: opt.RBACResourcesOnly, + } +} + +type ErrorsWrapper struct { + Errors []error +} + +func (v ErrorsWrapper) Error() string { + var errStr string + for _, e := range v.Errors { + errStr += e.Error() + if e != v.Errors[len(v.Errors)-1] { + errStr += "\n" + } + } + return errStr +} + +func (v *Validator) validateEntity(entityType string, entity interface{}) (bool, error) { + errWrap := "validate entity '%s': %s" + endpoint := fmt.Sprintf("/schemas/%s/validate", entityType) + req, err := v.client.NewRequest(http.MethodPost, endpoint, nil, entity) + if err != nil { + return false, fmt.Errorf(errWrap, entityType, err) + } + resp, err := v.client.Do(v.ctx, req, nil) + if err != nil { + return false, fmt.Errorf(errWrap, entityType, err) + } + return resp.StatusCode == http.StatusOK, nil +} + +func (v *Validator) entities(obj interface{}, entityType string) []error { + entities, err := utils.CallGetAll(obj) + if err != nil { + return []error{err} + } + errors := []error{} + + // create a buffer of channels. Creation of new coroutines + // are allowed only if the buffer is not full. + chanBuff := make(chan struct{}, v.parallelism) + + var wg sync.WaitGroup + wg.Add(entities.Len()) + // each coroutine will append on a slice of errors. + // since slices are not thread-safe, let's add a mutex + // to handle access to the slice. + mu := &sync.Mutex{} + for i := 0; i < entities.Len(); i++ { + // reserve a slot + chanBuff <- struct{}{} + go func(i int) { + defer wg.Done() + // release a slot when completed + defer func() { <-chanBuff }() + _, err := v.validateEntity(entityType, entities.Index(i).Interface()) + if err != nil { + mu.Lock() + errors = append(errors, err) + mu.Unlock() + } + }(i) + } + wg.Wait() + return errors +} + +func (v *Validator) Validate() []error { + allErr := []error{} + + // validate RBAC resources first. + if err := v.entities(v.state.RBACEndpointPermissions, "rbac-endpointpermission"); err != nil { + allErr = append(allErr, err...) + } + if err := v.entities(v.state.RBACRoles, "rbac-role"); err != nil { + allErr = append(allErr, err...) + } + if v.rbacResourcesOnly { + return allErr + } + + if err := v.entities(v.state.Services, "services"); err != nil { + allErr = append(allErr, err...) + } + if err := v.entities(v.state.ACLGroups, "acls"); err != nil { + allErr = append(allErr, err...) + } + if err := v.entities(v.state.BasicAuths, "basicauth_credentials"); err != nil { + allErr = append(allErr, err...) + } + if err := v.entities(v.state.CACertificates, "ca_certificates"); err != nil { + allErr = append(allErr, err...) + } + if err := v.entities(v.state.Certificates, "certificates"); err != nil { + allErr = append(allErr, err...) + } + if err := v.entities(v.state.Consumers, "consumers"); err != nil { + allErr = append(allErr, err...) + } + if err := v.entities(v.state.Documents, "documents"); err != nil { + allErr = append(allErr, err...) + } + if err := v.entities(v.state.HMACAuths, "hmacauth_credentials"); err != nil { + allErr = append(allErr, err...) + } + if err := v.entities(v.state.JWTAuths, "jwt_secrets"); err != nil { + allErr = append(allErr, err...) + } + if err := v.entities(v.state.KeyAuths, "keyauth_credentials"); err != nil { + allErr = append(allErr, err...) + } + if err := v.entities(v.state.Oauth2Creds, "oauth2_credentials"); err != nil { + allErr = append(allErr, err...) + } + if err := v.entities(v.state.Plugins, "plugins"); err != nil { + allErr = append(allErr, err...) + } + if err := v.entities(v.state.Routes, "routes"); err != nil { + allErr = append(allErr, err...) + } + if err := v.entities(v.state.SNIs, "snis"); err != nil { + allErr = append(allErr, err...) + } + if err := v.entities(v.state.Targets, "targets"); err != nil { + allErr = append(allErr, err...) + } + if err := v.entities(v.state.Upstreams, "upstreams"); err != nil { + allErr = append(allErr, err...) + } + return allErr +}