From 475eb1bc73fd0ffa9279544135ae668e5d3aec89 Mon Sep 17 00:00:00 2001 From: xhe Date: Wed, 28 Sep 2022 14:35:14 +0800 Subject: [PATCH] *: ctl TLS refine (#98) --- cmd/tiproxyctl/main.go | 2 +- go.mod | 4 ++-- lib/cli/main.go | 38 ++++++++++++++++++++++++++++++++++---- lib/cli/util.go | 32 ++++++++++++++++++++++++++------ 4 files changed, 63 insertions(+), 13 deletions(-) diff --git a/cmd/tiproxyctl/main.go b/cmd/tiproxyctl/main.go index 7b7ede49..758b4c83 100644 --- a/cmd/tiproxyctl/main.go +++ b/cmd/tiproxyctl/main.go @@ -23,7 +23,7 @@ import ( ) func main() { - rootCmd := cli.GetRootCmd() + rootCmd := cli.GetRootCmd(nil) rootCmd.Use = strings.Replace(rootCmd.Use, "tiproxyctl", os.Args[0], 1) cmd.RunRootCommand(rootCmd) } diff --git a/go.mod b/go.mod index 83cfc7f5..f5016a3e 100644 --- a/go.mod +++ b/go.mod @@ -8,6 +8,7 @@ require ( github.com/gin-gonic/gin v1.8.1 github.com/go-mysql-org/go-mysql v1.6.0 github.com/pingcap/TiProxy/lib v0.0.0-00010101000000-000000000000 + github.com/pingcap/log v1.1.0 github.com/pingcap/tidb v1.1.0-beta.0.20220908042057-08b1faf2ad1e github.com/pingcap/tidb/parser v0.0.0-20220908042057-08b1faf2ad1e github.com/prometheus/client_golang v1.13.0 @@ -22,6 +23,7 @@ require ( go.uber.org/zap v1.23.0 golang.org/x/exp v0.0.0-20220907003533-145caa8ea1d0 google.golang.org/grpc v1.49.0 + gopkg.in/natefinch/lumberjack.v2 v2.0.0 ) require ( @@ -71,7 +73,6 @@ require ( github.com/pingcap/errors v0.11.5-0.20211224045212-9687c2b0f87c // indirect github.com/pingcap/failpoint v0.0.0-20220423142525-ae43b7f4e5c3 // indirect github.com/pingcap/kvproto v0.0.0-20220906053631-2e37953b2b43 // indirect - github.com/pingcap/log v1.1.0 // indirect github.com/pingcap/tipb v0.0.0-20220824081009-0714a57aff1d // indirect github.com/pkg/errors v0.9.1 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect @@ -121,7 +122,6 @@ require ( golang.org/x/time v0.0.0-20220722155302-e5dcc9cfc0b9 // indirect google.golang.org/genproto v0.0.0-20220822174746-9e6da59bd2fc // indirect google.golang.org/protobuf v1.28.1 // indirect - gopkg.in/natefinch/lumberjack.v2 v2.0.0 // indirect gopkg.in/yaml.v2 v2.4.0 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect sigs.k8s.io/yaml v1.2.0 // indirect diff --git a/lib/cli/main.go b/lib/cli/main.go index 4a8bd5af..fb0d453e 100644 --- a/lib/cli/main.go +++ b/lib/cli/main.go @@ -15,16 +15,20 @@ package cli import ( + "crypto/tls" "net/http" + "github.com/pingcap/TiProxy/lib/config" + "github.com/pingcap/TiProxy/lib/util/security" "github.com/spf13/cobra" "go.uber.org/zap" ) -func GetRootCmd() *cobra.Command { +func GetRootCmd(tlsConfig *tls.Config) *cobra.Command { rootCmd := &cobra.Command{ - Use: "tiproxyctl", - Short: "cli", + Use: "tiproxyctl", + Short: "cli", + SilenceUsage: true, } ctx := &Context{} @@ -32,6 +36,10 @@ func GetRootCmd() *cobra.Command { curls := rootCmd.PersistentFlags().StringArray("curls", []string{"localhost:3080"}, "API gateway addresses") logEncoder := rootCmd.PersistentFlags().String("log_encoder", "tidb", "log in format of tidb, console, or json") logLevel := rootCmd.PersistentFlags().String("log_level", "info", "log level") + insecure := rootCmd.PersistentFlags().BoolP("insecure", "k", false, "enable TLS without CA, useful for testing, or for expired certs") + caPath := rootCmd.PersistentFlags().String("ca", "", "CA to verify server certificates, set to 'skip' if want to enable SSL without verification") + certPath := rootCmd.PersistentFlags().String("cert", "", "cert for server-side client authentication") + keyPath := rootCmd.PersistentFlags().String("key", "", "key for server-side client authentication") rootCmd.PersistentFlags().Bool("indent", true, "whether indent the returned json") rootCmd.PersistentPreRunE = func(_ *cobra.Command, _ []string) error { zapcfg := zap.NewDevelopmentConfig() @@ -43,9 +51,31 @@ func GetRootCmd() *cobra.Command { if err != nil { return err } + if tlsConfig == nil { + skipCA := *insecure + realCAPath := *caPath + if skipCA { + realCAPath = "" + } + tlsConfig, err = security.BuildClientTLSConfig(logger, config.TLSConfig{ + CA: realCAPath, + Cert: *certPath, + Key: *keyPath, + SkipCA: skipCA, + }) + if err != nil { + return err + } + } ctx.Logger = logger.Named("cli") - ctx.Client = &http.Client{} + ctx.Client = &http.Client{ + Transport: &http.Transport{ + Proxy: http.ProxyFromEnvironment, + TLSClientConfig: tlsConfig, + }, + } ctx.CUrls = *curls + ctx.SSL = tlsConfig != nil return nil } diff --git a/lib/cli/util.go b/lib/cli/util.go index a887ce8f..62b397f2 100644 --- a/lib/cli/util.go +++ b/lib/cli/util.go @@ -22,6 +22,7 @@ import ( "math/rand" "net/http" + "github.com/pingcap/TiProxy/lib/util/errors" "go.uber.org/zap" ) @@ -29,6 +30,7 @@ type Context struct { Logger *zap.Logger Client *http.Client CUrls []string + SSL bool } func doRequest(ctx context.Context, bctx *Context, method string, url string, rd io.Reader) (string, error) { @@ -37,18 +39,36 @@ func doRequest(ctx context.Context, bctx *Context, method string, url string, rd sep = "/" } - req, err := http.NewRequestWithContext(ctx, method, fmt.Sprintf("http://localhost%s%s", sep, url), rd) + schema := "http" + if bctx.SSL { + schema = "https" + } + + req, err := http.NewRequestWithContext(ctx, method, fmt.Sprintf("%s://localhost%s%s", schema, sep, url), rd) if err != nil { return "", err } var rete string + var res *http.Response for _, i := range rand.Perm(len(bctx.CUrls)) { req.URL.Host = bctx.CUrls[i] - res, err := bctx.Client.Do(req) + res, err = bctx.Client.Do(req) if err != nil { - return "", err + if errors.Is(err, io.EOF) { + if req.URL.Scheme == "https" { + req.URL.Scheme = "http" + } else if req.URL.Scheme == "http" { + req.URL.Scheme = "https" + } + // probably server did not enable TLS, try again with plain http + // or the reverse, server enabled TLS, but we should try https + res, err = bctx.Client.Do(req) + } + if err != nil { + return "", err + } } resb, _ := ioutil.ReadAll(res.Body) res.Body.Close() @@ -57,9 +77,9 @@ func doRequest(ctx context.Context, bctx *Context, method string, url string, rd case http.StatusOK: return string(resb), nil case http.StatusBadRequest: - return fmt.Sprintf("bad request: %s", string(resb)), nil + return "", errors.Errorf("bad request: %s", string(resb)) case http.StatusInternalServerError: - rete = fmt.Sprintf("internal error: %s", string(resb)) + err = errors.Errorf("internal error: %s", string(resb)) continue default: rete = fmt.Sprintf("%s: %s", res.Status, string(resb)) @@ -67,5 +87,5 @@ func doRequest(ctx context.Context, bctx *Context, method string, url string, rd } } - return rete, nil + return rete, err }