Skip to content

Commit

Permalink
*: ctl TLS refine (#98)
Browse files Browse the repository at this point in the history
  • Loading branch information
xhebox authored Sep 28, 2022
1 parent b3616a3 commit 475eb1b
Show file tree
Hide file tree
Showing 4 changed files with 63 additions and 13 deletions.
2 changes: 1 addition & 1 deletion cmd/tiproxyctl/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ import (
)

func main() {
rootCmd := cli.GetRootCmd()
rootCmd := cli.GetRootCmd(nil)
rootCmd.Use = strings.Replace(rootCmd.Use, "tiproxyctl", os.Args[0], 1)
cmd.RunRootCommand(rootCmd)
}
4 changes: 2 additions & 2 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ require (
github.com/gin-gonic/gin v1.8.1
github.com/go-mysql-org/go-mysql v1.6.0
github.com/pingcap/TiProxy/lib v0.0.0-00010101000000-000000000000
github.com/pingcap/log v1.1.0
github.com/pingcap/tidb v1.1.0-beta.0.20220908042057-08b1faf2ad1e
github.com/pingcap/tidb/parser v0.0.0-20220908042057-08b1faf2ad1e
github.com/prometheus/client_golang v1.13.0
Expand All @@ -22,6 +23,7 @@ require (
go.uber.org/zap v1.23.0
golang.org/x/exp v0.0.0-20220907003533-145caa8ea1d0
google.golang.org/grpc v1.49.0
gopkg.in/natefinch/lumberjack.v2 v2.0.0
)

require (
Expand Down Expand Up @@ -71,7 +73,6 @@ require (
github.com/pingcap/errors v0.11.5-0.20211224045212-9687c2b0f87c // indirect
github.com/pingcap/failpoint v0.0.0-20220423142525-ae43b7f4e5c3 // indirect
github.com/pingcap/kvproto v0.0.0-20220906053631-2e37953b2b43 // indirect
github.com/pingcap/log v1.1.0 // indirect
github.com/pingcap/tipb v0.0.0-20220824081009-0714a57aff1d // indirect
github.com/pkg/errors v0.9.1 // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect
Expand Down Expand Up @@ -121,7 +122,6 @@ require (
golang.org/x/time v0.0.0-20220722155302-e5dcc9cfc0b9 // indirect
google.golang.org/genproto v0.0.0-20220822174746-9e6da59bd2fc // indirect
google.golang.org/protobuf v1.28.1 // indirect
gopkg.in/natefinch/lumberjack.v2 v2.0.0 // indirect
gopkg.in/yaml.v2 v2.4.0 // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect
sigs.k8s.io/yaml v1.2.0 // indirect
Expand Down
38 changes: 34 additions & 4 deletions lib/cli/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,23 +15,31 @@
package cli

import (
"crypto/tls"
"net/http"

"github.com/pingcap/TiProxy/lib/config"
"github.com/pingcap/TiProxy/lib/util/security"
"github.com/spf13/cobra"
"go.uber.org/zap"
)

func GetRootCmd() *cobra.Command {
func GetRootCmd(tlsConfig *tls.Config) *cobra.Command {
rootCmd := &cobra.Command{
Use: "tiproxyctl",
Short: "cli",
Use: "tiproxyctl",
Short: "cli",
SilenceUsage: true,
}

ctx := &Context{}

curls := rootCmd.PersistentFlags().StringArray("curls", []string{"localhost:3080"}, "API gateway addresses")
logEncoder := rootCmd.PersistentFlags().String("log_encoder", "tidb", "log in format of tidb, console, or json")
logLevel := rootCmd.PersistentFlags().String("log_level", "info", "log level")
insecure := rootCmd.PersistentFlags().BoolP("insecure", "k", false, "enable TLS without CA, useful for testing, or for expired certs")
caPath := rootCmd.PersistentFlags().String("ca", "", "CA to verify server certificates, set to 'skip' if want to enable SSL without verification")
certPath := rootCmd.PersistentFlags().String("cert", "", "cert for server-side client authentication")
keyPath := rootCmd.PersistentFlags().String("key", "", "key for server-side client authentication")
rootCmd.PersistentFlags().Bool("indent", true, "whether indent the returned json")
rootCmd.PersistentPreRunE = func(_ *cobra.Command, _ []string) error {
zapcfg := zap.NewDevelopmentConfig()
Expand All @@ -43,9 +51,31 @@ func GetRootCmd() *cobra.Command {
if err != nil {
return err
}
if tlsConfig == nil {
skipCA := *insecure
realCAPath := *caPath
if skipCA {
realCAPath = ""
}
tlsConfig, err = security.BuildClientTLSConfig(logger, config.TLSConfig{
CA: realCAPath,
Cert: *certPath,
Key: *keyPath,
SkipCA: skipCA,
})
if err != nil {
return err
}
}
ctx.Logger = logger.Named("cli")
ctx.Client = &http.Client{}
ctx.Client = &http.Client{
Transport: &http.Transport{
Proxy: http.ProxyFromEnvironment,
TLSClientConfig: tlsConfig,
},
}
ctx.CUrls = *curls
ctx.SSL = tlsConfig != nil
return nil
}

Expand Down
32 changes: 26 additions & 6 deletions lib/cli/util.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,13 +22,15 @@ import (
"math/rand"
"net/http"

"github.com/pingcap/TiProxy/lib/util/errors"
"go.uber.org/zap"
)

type Context struct {
Logger *zap.Logger
Client *http.Client
CUrls []string
SSL bool
}

func doRequest(ctx context.Context, bctx *Context, method string, url string, rd io.Reader) (string, error) {
Expand All @@ -37,18 +39,36 @@ func doRequest(ctx context.Context, bctx *Context, method string, url string, rd
sep = "/"
}

req, err := http.NewRequestWithContext(ctx, method, fmt.Sprintf("http://localhost%s%s", sep, url), rd)
schema := "http"
if bctx.SSL {
schema = "https"
}

req, err := http.NewRequestWithContext(ctx, method, fmt.Sprintf("%s://localhost%s%s", schema, sep, url), rd)
if err != nil {
return "", err
}

var rete string
var res *http.Response
for _, i := range rand.Perm(len(bctx.CUrls)) {
req.URL.Host = bctx.CUrls[i]

res, err := bctx.Client.Do(req)
res, err = bctx.Client.Do(req)
if err != nil {
return "", err
if errors.Is(err, io.EOF) {
if req.URL.Scheme == "https" {
req.URL.Scheme = "http"
} else if req.URL.Scheme == "http" {
req.URL.Scheme = "https"
}
// probably server did not enable TLS, try again with plain http
// or the reverse, server enabled TLS, but we should try https
res, err = bctx.Client.Do(req)
}
if err != nil {
return "", err
}
}
resb, _ := ioutil.ReadAll(res.Body)
res.Body.Close()
Expand All @@ -57,15 +77,15 @@ func doRequest(ctx context.Context, bctx *Context, method string, url string, rd
case http.StatusOK:
return string(resb), nil
case http.StatusBadRequest:
return fmt.Sprintf("bad request: %s", string(resb)), nil
return "", errors.Errorf("bad request: %s", string(resb))
case http.StatusInternalServerError:
rete = fmt.Sprintf("internal error: %s", string(resb))
err = errors.Errorf("internal error: %s", string(resb))
continue
default:
rete = fmt.Sprintf("%s: %s", res.Status, string(resb))
continue
}
}

return rete, nil
return rete, err
}

0 comments on commit 475eb1b

Please sign in to comment.