Skip to content

Commit

Permalink
add validation for config, tls private key and cert file values
Browse files Browse the repository at this point in the history
  • Loading branch information
devang-gaur committed May 11, 2021
1 parent 9ac1667 commit 2ff24ca
Show file tree
Hide file tree
Showing 5 changed files with 78 additions and 11 deletions.
8 changes: 7 additions & 1 deletion pkg/cli/register.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ package cli

import (
"flag"
httpserver "github.com/accurics/terrascan/pkg/http-server"
"io/ioutil"
"log"
"os"
Expand All @@ -38,14 +39,19 @@ func Execute() {
rootCmd.PersistentFlags().StringVarP(&LogLevel, "log-level", "l", "info", "log level (debug, info, warn, error, panic, fatal)")
rootCmd.PersistentFlags().StringVarP(&LogType, "log-type", "x", "console", "log output type (console, json)")
rootCmd.PersistentFlags().StringVarP(&OutputType, "output", "o", "human", "output type (human, json, yaml, xml, junit-xml)")
rootCmd.PersistentFlags().StringVarP(&ConfigFile, "config-path", "c", "", "config file path")
rootCmd.PersistentFlags().StringVarP(&ConfigFile, "config-path", "c", httpserver.ConfigFilePlaceholder, "config file path")

// Function to execute before processing commands
cobra.OnInitialize(func() {
// Set up the logger
logging.Init(LogType, LogLevel)

if len(ConfigFile) == 0 {
zap.S().Error("value of --config-path or -c flag left blank")
os.Exit(1)
}

if ConfigFile == httpserver.ConfigFilePlaceholder {
ConfigFile = os.Getenv(config.ConfigEnvvarName)
zap.S().Debugf("%s:%s", config.ConfigEnvvarName, os.Getenv(config.ConfigEnvvarName))
}
Expand Down
4 changes: 2 additions & 2 deletions pkg/cli/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -50,8 +50,8 @@ func server(cmd *cobra.Command, args []string) {
}

func init() {
serverCmd.Flags().StringVarP(&privateKeyFile, "key-path", "", "", "private key file path")
serverCmd.Flags().StringVarP(&certFile, "cert-path", "", "", "certificate file path")
serverCmd.Flags().StringVarP(&privateKeyFile, "key-path", "", httpserver.TLSPrivateKeyFilePlaceholder, "private key file path")
serverCmd.Flags().StringVarP(&certFile, "cert-path", "", httpserver.TLSCertFilePlaceholder, "certificate file path")
serverCmd.Flags().StringVarP(&port, "port", "p", httpserver.GatewayDefaultPort, "server port")
RegisterCommand(rootCmd, serverCmd)
}
7 changes: 7 additions & 0 deletions pkg/http-server/constants.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,13 @@ const (
// GatewayDefaultPort - default port at which the http server listens
GatewayDefaultPort = "9010"

// TLSPrivateKeyFilePlaceholder - placeholder
TLSPrivateKeyFilePlaceholder = "<key>"
// TLSCertFilePlaceholder - placeholder
TLSCertFilePlaceholder = "<cert>"
// ConfigFilePlaceholder - placeholder
ConfigFilePlaceholder = "<config>"

// APIVersion - default api version for REST endpoints
APIVersion = "v1"

Expand Down
66 changes: 60 additions & 6 deletions pkg/http-server/start.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,13 @@ package httpserver

import (
"context"
"github.com/accurics/terrascan/pkg/utils"
"github.com/go-errors/errors"
"go.uber.org/zap"
"net/http"
"os"
"os/signal"
"strings"
"time"

"github.com/accurics/terrascan/pkg/logging"
Expand All @@ -29,7 +33,15 @@ import (

// Start initializes api routes and starts http server
func Start(port, configFile, certFile, privateKeyFile string) {
logger := logging.GetDefaultLogger() // new logger

if privateKeyFile != TLSPrivateKeyFilePlaceholder || certFile != TLSCertFilePlaceholder {
logger.Debugf("certfile is %s, privateKeyFile is %s", certFile, privateKeyFile)

if err := validateTLSKeyAndCert(privateKeyFile, certFile); err != nil {
logger.Fatal(err)
}
}
// create a new API server
server := NewAPIServer()

Expand All @@ -42,16 +54,15 @@ func Start(port, configFile, certFile, privateKeyFile string) {
}

// register routes and start the http server
server.start(routes, port, certFile, privateKeyFile)
server.start(routes, logger, port, certFile, privateKeyFile)
}

// start http server
func (g *APIServer) start(routes []*Route, port, certFile, privateKeyFile string) {
func (g *APIServer) start(routes []*Route, logger *zap.SugaredLogger, port, certFile, privateKeyFile string) {

var (
err error
logger = logging.GetDefaultLogger() // new logger
router = mux.NewRouter() // new router
router = mux.NewRouter() // new router
)

logger.Info("registering routes...")
Expand All @@ -72,19 +83,30 @@ func (g *APIServer) start(routes []*Route, port, certFile, privateKeyFile string
Handler: router,
}

var useHTTPS chan bool
go func() {
var err error
if certFile != "" && privateKeyFile != "" {
if certFile != TLSCertFilePlaceholder && privateKeyFile != TLSPrivateKeyFilePlaceholder {
// In case a certificate file is specified, the server support TLS
useHTTPS <- true
err = server.ListenAndServeTLS(certFile, privateKeyFile)
} else {
useHTTPS <- false
err = server.ListenAndServe()
}
if err != nil && err != http.ErrServerClosed {
logger.Fatal(err)
}
}()
logger.Infof("http server listening at port %v", port)

var message string
if <-useHTTPS {
message = "https server listening at port %v"
} else {
message = "http server listening at port %v"
}

logger.Infof(message, port)

// Wait for interrupt signal to gracefully shutdown the server
quit := make(chan os.Signal, 1)
Expand All @@ -100,3 +122,35 @@ func (g *APIServer) start(routes []*Route, port, certFile, privateKeyFile string
}
logger.Info("server exiting gracefully")
}

func validateTLSKeyAndCert(privateKeyFile, certFile string) error {
e1 := validateFileName(privateKeyFile)
e2 := validateFileName(certFile)

if e1 != nil && e2 != nil {
return errors.Errorf("error with privateKey filename: %s, error with certFile filename: %s", e1.Error(), e2.Error())
} else if e1 != nil && e2 == nil {
return errors.Errorf("error with privateKey filename: %s", e1.Error())
} else if e1 == nil && e2 != nil {
return errors.Errorf("error with certFile filename: %s", e2.Error())
} else {
return nil
}
}

func validateFileName(file string) error {
if len(file) == 0 {
return errors.New("filename is an empty string")
}

if strings.HasPrefix(file, "-") {
return errors.Errorf("file name mistakenly assigned with some another flag, %s", file)
}

if _, err := utils.GetAbsPath(file); err != nil {
return errors.Errorf("filename '%s' is incorrect: %v", file, err)
}

zap.S().Debugf("validated filename %s", file)
return nil
}
4 changes: 2 additions & 2 deletions pkg/utils/path.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,8 +43,8 @@ func GetAbsPath(path string) (string, error) {
}

// get absolute file path
path, _ = filepath.Abs(path)
return path, nil
path, err := filepath.Abs(path)
return path, err
}

// FindAllDirectories Walks the file path and returns a list of all directories within
Expand Down

0 comments on commit 2ff24ca

Please sign in to comment.