Skip to content

Commit

Permalink
*: simplify options
Browse files Browse the repository at this point in the history
Signed-off-by: xhe <xw897002528@gmail.com>
  • Loading branch information
xhebox committed Sep 13, 2022
1 parent 529a1a8 commit 691e919
Show file tree
Hide file tree
Showing 13 changed files with 79 additions and 101 deletions.
3 changes: 2 additions & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -47,8 +47,9 @@ cmd_%:
test: ./bin/gocovmerge
go test -coverprofile=.cover.pkg ./...
cd lib && go test -coverprofile=../.cover.lib ./...
rm -f .cover.*
./bin/gocovmerge .cover.* > .cover
rm .cover.*
rm -f .cover.*
go tool cover -html=.cover -o .cover.html

./bin/gocovmerge:
Expand Down
33 changes: 14 additions & 19 deletions conf/weirproxy.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -20,31 +20,26 @@ log:
max-backups: 1
security:
rsa-key-size: 4096
# tls object
# for client:
# 1. requires ca or skip-ca
# 2. set certs will be used for server-side client verification, if any
# 3. auto-certs is useless
# for server:
# 1. requires cert/key or auto-certs
# 2. set ca will enable server-side client verification
# 3. skip-ca is useless
# tls:
# tls object is either of type server or client
# xxxx:
# ca: ca.pem
# cert: c.pem
# key: k.pem
# auto-certs: true
# ca: ca.pem
# skip-ca: trure
tidb-tls: # client object
# client-tls:
# 1. requires: ca or skip-ca(skip verify server certs)
# 2. optionally: cert/key will be used for server-side client verification
# 3. useless/forbid: auto-certs
# server object:
# 1. requires: cert/key or auto-certs(generate a temporary cert, mostly for testing)
# 2. optionally: ca will enable server-side client verification
# 3. useless/forbid: skip-ca
cluster-tls: # client object
# access to other components like TiDB or PD, will use this
skip-ca: true
pd-tls: # client object
skip-ca: true
cluster:
# cluster will be used for internal peer communication:
# that said, it is both a client and a server object
client: # server object
server-tls: # server object
# proxy SQL or internal HTTP port will all use this
skip-ca: true
auto-certs: true
advance:
# ignore-wrong-namespace: true
Expand Down
8 changes: 4 additions & 4 deletions lib/config/namespace.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,11 +22,11 @@ type Namespace struct {
}

type FrontendNamespace struct {
Security TLSCert `yaml:"security" json:"security" toml:"security"`
Security TLSConfig `yaml:"security" json:"security" toml:"security"`
}

type BackendNamespace struct {
Instances []string `yaml:"instances" json:"instances" toml:"instances"`
SelectorType string `yaml:"selector-type" json:"selector-type" toml:"selector-type"`
Security TLSCert `yaml:"security" json:"security" toml:"security"`
Instances []string `yaml:"instances" json:"instances" toml:"instances"`
SelectorType string `yaml:"selector-type" json:"selector-type" toml:"selector-type"`
Security TLSConfig `yaml:"security" json:"security" toml:"security"`
}
18 changes: 10 additions & 8 deletions lib/config/namespace_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,19 +24,21 @@ import (
var testNamespaceConfig = Namespace{
Namespace: "test_ns",
Frontend: FrontendNamespace{
Security: TLSCert{
CA: "t",
Cert: "t",
Key: "t",
Security: TLSConfig{
CA: "t",
Cert: "t",
Key: "t",
AutoCerts: true,
},
},
Backend: BackendNamespace{
Instances: []string{"127.0.0.1:4000", "127.0.0.1:4001"},
SelectorType: "random",
Security: TLSCert{
CA: "t",
Cert: "t",
Key: "t",
Security: TLSConfig{
CA: "t",
Cert: "t",
Key: "t",
SkipCA: true,
},
},
}
Expand Down
18 changes: 8 additions & 10 deletions lib/config/proxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -75,28 +75,26 @@ type LogFile struct {
MaxBackups int `yaml:"max-backups,omitempty" toml:"max-backups,omitempty" json:"max-backups,omitempty"`
}

type TLSCert struct {
CA string `yaml:"ca,omitempty" toml:"ca,omitempty" json:"ca,omitempty"`
SkipCA bool `yaml:"skip-ca,omitempty" toml:"skip-ca,omitempty" json:"skip-ca,omitempty"`
type TLSConfig struct {
Cert string `yaml:"cert,omitempty" toml:"cert,omitempty" json:"cert,omitempty"`
Key string `yaml:"key,omitempty" toml:"key,omitempty" json:"key,omitempty"`
AutoCerts bool `yaml:"auto-certs,omitempty" toml:"auto-certs,omitempty" json:"auto-certs,omitempty"`
CA string `yaml:"ca,omitempty" toml:"ca,omitempty" json:"ca,omitempty"`
SkipCA bool `yaml:"skip-ca,omitempty" toml:"skip-ca,omitempty" json:"skip-ca,omitempty"`
}

func (c TLSCert) HasCert() bool {
func (c TLSConfig) HasCert() bool {
return !(c.Cert == "" && c.Key == "")
}

func (c TLSCert) HasCA() bool {
func (c TLSConfig) HasCA() bool {
return c.CA != ""
}

type Security struct {
RSAKeySize int `yaml:"rsa-key-size,omitempty" toml:"rsa-key-size,omitempty" json:"rsa-key-size,omitempty"`
Client TLSCert `yaml:"client,omitempty" toml:"client,omitempty" json:"client,omitempty"`
Cluster TLSCert `yaml:"cluster,omitempty" toml:"cluster,omitempty" json:"cluster,omitempty"`
PDTLS TLSCert `yaml:"pd-tls,omitempty" toml:"pd-tls,omitempty" json:"pd-tls,omitempty"`
TiDBTLS TLSCert `yaml:"tidb-tls,omitempty" toml:"tidb-tls,omitempty" json:"tidb-tls,omitempty"`
RSAKeySize int `yaml:"rsa-key-size,omitempty" toml:"rsa-key-size,omitempty" json:"rsa-key-size,omitempty"`
ServerTLS TLSConfig `yaml:"server-tls,omitempty" toml:"server-tls,omitempty" json:"server-tls,omitempty"`
ClusterTLS TLSConfig `yaml:"cluster-tls,omitempty" toml:"cluster-tls,omitempty" json:"cluster-tls,omitempty"`
}

func NewConfig(data []byte) (*Config, error) {
Expand Down
28 changes: 6 additions & 22 deletions lib/config/proxy_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -57,33 +57,17 @@ var testProxyConfig = Config{
},
Security: Security{
RSAKeySize: 64,
Client: TLSCert{
ServerTLS: TLSConfig{
CA: "a",
SkipCA: true,
Cert: "b",
Key: "c",
AutoCerts: true,
},
Cluster: TLSCert{
CA: "a",
SkipCA: true,
Cert: "b",
Key: "c",
AutoCerts: true,
},
PDTLS: TLSCert{
CA: "a",
SkipCA: true,
Cert: "b",
Key: "c",
AutoCerts: true,
},
TiDBTLS: TLSCert{
CA: "a",
SkipCA: true,
Cert: "b",
Key: "c",
AutoCerts: true,
ClusterTLS: TLSConfig{
CA: "a",
SkipCA: true,
Cert: "b",
Key: "c",
},
},
}
Expand Down
45 changes: 22 additions & 23 deletions lib/util/security/tls.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,12 +36,12 @@ import (
"go.uber.org/zap"
)

func createTLSCertificates(logger *zap.Logger, certpath string, keypath string, rsaKeySize int) error {
func createTLSConfigificates(logger *zap.Logger, certpath string, keypath string, rsaKeySize int) error {
_, e1 := os.Stat(certpath)
_, e2 := os.Stat(keypath)
if errors.Is(e1, os.ErrExist) && errors.Is(e2, os.ErrExist) {
if errors.Is(e1, os.ErrExist) && errors.Is(e2, os.ErrExist) {
return nil
} else if errors.Is(e1, os.ErrExist) || errors.Is(e2, os.ErrExist) {
} else if errors.Is(e1, os.ErrExist) || errors.Is(e2, os.ErrExist) {
return errors.New("cert and key should be present or not at the same time")
}

Expand Down Expand Up @@ -213,12 +213,12 @@ func CreateTLSConfigForTest() (serverTLSConf *tls.Config, clientTLSConf *tls.Con
return
}

func BuildServerTLSConfig(logger *zap.Logger, cfg config.TLSCert, workdir, mod string, keySize int) (*tls.Config, error) {
func BuildServerTLSConfig(logger *zap.Logger, cfg config.TLSConfig, workdir, mod string, keySize int) (*tls.Config, error) {
if !cfg.HasCert() {
if cfg.AutoCerts {
cfg.Cert = filepath.Join(workdir, mod, "cert.pem")
cfg.Key = filepath.Join(workdir, mod, "key.pem")
if err := createTLSCertificates(logger, cfg.Cert, cfg.Key, keySize); err != nil {
if err := createTLSConfigificates(logger, cfg.Cert, cfg.Key, keySize); err != nil {
return nil, err
}
return BuildServerTLSConfig(logger, cfg, workdir, mod, keySize)
Expand Down Expand Up @@ -253,7 +253,7 @@ func BuildServerTLSConfig(logger *zap.Logger, cfg config.TLSCert, workdir, mod s
return tcfg, nil
}

func BuildClientTLSConfig(logger *zap.Logger, cfg config.TLSCert, mod string) (*tls.Config, error) {
func BuildClientTLSConfig(logger *zap.Logger, cfg config.TLSConfig, mod string) (*tls.Config, error) {
if !cfg.HasCA() {
logger.Warn(fmt.Sprintf("require CA to verify %s server connections", mod))
if cfg.SkipCA {
Expand All @@ -265,7 +265,6 @@ func BuildClientTLSConfig(logger *zap.Logger, cfg config.TLSCert, mod string) (*
}

tcfg := &tls.Config{}
tcfg.ClientAuth = tls.RequireAndVerifyClientCert
tcfg.ClientCAs = x509.NewCertPool()
certBytes, err := ioutil.ReadFile(cfg.CA)
if err != nil {
Expand All @@ -288,33 +287,33 @@ func BuildClientTLSConfig(logger *zap.Logger, cfg config.TLSCert, mod string) (*
return tcfg, nil
}

func BuildEtcdTLSConfig(logger *zap.Logger, client, cluster config.TLSCert, workdir, mod string, keySize int) (clientInfo, peerInfo transport.TLSInfo, err error) {
if !client.HasCert() {
if client.AutoCerts {
client.Cert = filepath.Join(workdir, mod, "cert.pem")
client.Key = filepath.Join(workdir, mod, "key.pem")
if err = createTLSCertificates(logger, client.Cert, client.Key, keySize); err != nil {
func BuildEtcdTLSConfig(logger *zap.Logger, server config.TLSConfig, workdir, mod string, keySize int) (clientInfo, peerInfo transport.TLSInfo, err error) {
if !server.HasCert() {
if server.AutoCerts {
server.Cert = filepath.Join(workdir, mod, "cert.pem")
server.Key = filepath.Join(workdir, mod, "key.pem")
if err = createTLSConfigificates(logger, server.Cert, server.Key, keySize); err != nil {
return
}
return BuildEtcdTLSConfig(logger, client, cluster, workdir, mod, keySize)
return BuildEtcdTLSConfig(logger, server, workdir, mod, keySize)
}
} else {
clientInfo.CertFile = client.Cert
clientInfo.KeyFile = client.Key
if client.HasCA() {
clientInfo.CertFile = server.Cert
clientInfo.KeyFile = server.Key
if server.HasCA() {
clientInfo.ClientCertAuth = true
clientInfo.TrustedCAFile = client.CA
clientInfo.TrustedCAFile = server.CA
} else {
logger.Warn("no signed certs for etcd clients, proxy will not authenticate etcd clients (connection is still secured)")
}
}

if cluster.HasCA() && cluster.HasCert() {
peerInfo.CertFile = cluster.Cert
peerInfo.KeyFile = cluster.Key
peerInfo.TrustedCAFile = cluster.CA
if server.HasCA() && server.HasCert() {
peerInfo.CertFile = server.Cert
peerInfo.KeyFile = server.Key
peerInfo.TrustedCAFile = server.CA
peerInfo.ClientCertAuth = true
} else if cluster.HasCA() || cluster.HasCert() {
} else if server.HasCA() || server.HasCert() {
err = errors.New("need a full set of cert/ca/key for secure etcd peer inter-communication")
return
}
Expand Down
2 changes: 1 addition & 1 deletion pkg/manager/config/manager.go
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ func NewConfigManager() *ConfigManager {
}
}

func (srv *ConfigManager) Init(ctx context.Context, addrs []string, cfg config.Advance, scfg config.TLSCert, logger *zap.Logger) error {
func (srv *ConfigManager) Init(ctx context.Context, addrs []string, cfg config.Advance, scfg config.TLSConfig, logger *zap.Logger) error {
srv.logger = logger
srv.ignoreWrongNamespace = cfg.IgnoreWrongNamespace
if cfg.WatchInterval == "" {
Expand Down
2 changes: 1 addition & 1 deletion pkg/manager/config/manager_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ func testConfigManager(t *testing.T, cfg config.Advance) (*ConfigManager, contex
}

cfgmgr := NewConfigManager()
require.NoError(t, cfgmgr.Init(ctx, ends, cfg, config.TLSCert{}, logger))
require.NoError(t, cfgmgr.Init(ctx, ends, cfg, config.TLSConfig{}, logger))

t.Cleanup(func() {
require.NoError(t, cfgmgr.Close())
Expand Down
2 changes: 1 addition & 1 deletion pkg/manager/router/backend_observer.go
Original file line number Diff line number Diff line change
Expand Up @@ -149,7 +149,7 @@ func InitEtcdClient(logger *zap.Logger, cfg *config.Config) (*clientv3.Client, e
pdEndpoints := strings.Split(pdAddr, ",")
logConfig := zap.NewProductionConfig()
logConfig.Level = zap.NewAtomicLevelAt(zap.ErrorLevel)
tlsConfig, err := security.BuildClientTLSConfig(logger, cfg.Security.PDTLS, "pd")
tlsConfig, err := security.BuildClientTLSConfig(logger, cfg.Security.ClusterTLS, "pd")
if err != nil {
return nil, err
}
Expand Down
1 change: 0 additions & 1 deletion pkg/proxy/client/client_conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@ import (
"github.com/pingcap/TiProxy/pkg/manager/namespace"
"github.com/pingcap/TiProxy/pkg/proxy/backend"
pnet "github.com/pingcap/TiProxy/pkg/proxy/net"
"github.com/pingcap/tidb/metrics"
"github.com/pingcap/tidb/parser/mysql"
"github.com/pingcap/tidb/util/logutil"
"go.uber.org/zap"
Expand Down
16 changes: 8 additions & 8 deletions pkg/proxy/proxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,12 +40,12 @@ type serverState struct {
}

type SQLServer struct {
listener net.Listener
logger *zap.Logger
nsmgr *mgrns.NamespaceManager
frontendTLSConfig *tls.Config
backendTLSConfig *tls.Config
wg waitgroup.WaitGroup
listener net.Listener
logger *zap.Logger
nsmgr *mgrns.NamespaceManager
frontendTLSConfig *tls.Config
backendTLSConfig *tls.Config
wg waitgroup.WaitGroup

mu serverState
}
Expand All @@ -65,10 +65,10 @@ func NewSQLServer(logger *zap.Logger, workdir string, cfg config.ProxyServer, sc
},
}

if s.frontendTLSConfig, err = security.BuildServerTLSConfig(logger, scfg.Client, workdir, "frontend", scfg.RSAKeySize); err != nil {
if s.frontendTLSConfig, err = security.BuildServerTLSConfig(logger, scfg.ServerTLS, workdir, "frontend", scfg.RSAKeySize); err != nil {
return nil, err
}
if s.backendTLSConfig, err = security.BuildClientTLSConfig(logger, scfg.TiDBTLS, "backend"); err != nil {
if s.backendTLSConfig, err = security.BuildClientTLSConfig(logger, scfg.ClusterTLS, "backend"); err != nil {
return nil, err
}

Expand Down
4 changes: 2 additions & 2 deletions pkg/server/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,7 @@ func NewServer(ctx context.Context, cfg *config.Config, logger *zap.Logger, pubA
for i := range addrs {
addrs[i] = srv.Etcd.Clients[i].Addr().String()
}
err = srv.ConfigManager.Init(ctx, addrs, cfg.Advance, cfg.Security.Client, logger.Named("config"))
err = srv.ConfigManager.Init(ctx, addrs, cfg.Advance, cfg.Security.ServerTLS, logger.Named("config"))
if err != nil {
err = errors.WithStack(err)
return
Expand Down Expand Up @@ -260,7 +260,7 @@ func buildEtcd(ctx context.Context, cfg *config.Config, logger *zap.Logger, pubA
etcd_cfg.Dir = filepath.Join(cfg.Workdir, "etcd")
etcd_cfg.ZapLoggerBuilder = embed.NewZapLoggerBuilder(logger.Named("etcd"))

if etcd_cfg.ClientTLSInfo, etcd_cfg.PeerTLSInfo, err = security.BuildEtcdTLSConfig(logger, cfg.Security.Client, cfg.Security.Cluster, cfg.Workdir, "frontend", cfg.Security.RSAKeySize); err != nil {
if etcd_cfg.ClientTLSInfo, etcd_cfg.PeerTLSInfo, err = security.BuildEtcdTLSConfig(logger, cfg.Security.ServerTLS, cfg.Workdir, "frontend", cfg.Security.RSAKeySize); err != nil {
return
}

Expand Down

0 comments on commit 691e919

Please sign in to comment.