Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

proxy: fix TLS buffering #19

Merged
merged 11 commits into from
Jul 28, 2022
35 changes: 21 additions & 14 deletions pkg/proxy/net/packetio.go
Original file line number Diff line number Diff line change
Expand Up @@ -62,27 +62,39 @@ const (
defaultReaderSize = 16 * 1024
)

type rdbufConn struct {
net.Conn
*bufio.Reader
}

func (f *rdbufConn) Read(b []byte) (int, error) {
return f.Reader.Read(b)
}

// PacketIO is a helper to read and write sql and proxy protocol.
type PacketIO struct {
conn net.Conn
tlsConn net.Conn
buf *bufio.ReadWriter
sequence uint8
proxyInited bool
proxy *Proxy
}

func NewPacketIO(conn net.Conn) *PacketIO {
buf := bufio.NewReadWriter(
xhebox marked this conversation as resolved.
Show resolved Hide resolved
bufio.NewReaderSize(conn, defaultReaderSize),
bufio.NewWriterSize(conn, defaultWriterSize),
)
p := &PacketIO{
conn: conn,
conn: &rdbufConn{
xhebox marked this conversation as resolved.
Show resolved Hide resolved
conn,
buf.Reader,
},
sequence: 0,
// TODO: enable proxy probe for clients only
// disable it by default now
proxyInited: true,
buf: bufio.NewReadWriter(
bufio.NewReaderSize(conn, defaultReaderSize),
bufio.NewWriterSize(conn, defaultWriterSize),
),
buf: buf,
}
return p
}
Expand All @@ -102,7 +114,7 @@ func (p *PacketIO) ResetSequence() {
func (p *PacketIO) ReadOnePacket() ([]byte, bool, error) {
var header [4]byte

if _, err := io.ReadFull(p.buf, header[:]); err != nil {
if _, err := io.ReadFull(p.conn, header[:]); err != nil {
return nil, false, errors.WithStack(errors.Wrap(ErrReadConn, err))
}

Expand All @@ -124,7 +136,7 @@ func (p *PacketIO) ReadOnePacket() ([]byte, bool, error) {

// refill mysql headers
if refill {
if _, err := io.ReadFull(p.buf, header[:]); err != nil {
if _, err := io.ReadFull(p.conn, header[:]); err != nil {
return nil, false, errors.WithStack(errors.Wrap(ErrReadConn, err))
}
}
Expand All @@ -137,7 +149,7 @@ func (p *PacketIO) ReadOnePacket() ([]byte, bool, error) {
length := int(uint32(header[0]) | uint32(header[1])<<8 | uint32(header[2])<<16)

data := make([]byte, length)
if _, err := io.ReadFull(p.buf, data); err != nil {
if _, err := io.ReadFull(p.conn, data); err != nil {
return nil, false, errors.WithStack(errors.Wrap(ErrReadConn, err))
}
return data, length == mysql.MaxPayloadLen, nil
Expand Down Expand Up @@ -224,11 +236,6 @@ func (p *PacketIO) Close() error {
errs = append(errs, err)
}
*/
if p.tlsConn != nil {
if err := p.tlsConn.Close(); err != nil {
errs = append(errs, err)
}
}
if p.conn != nil {
if err := p.conn.Close(); err != nil {
errs = append(errs, err)
Expand Down
92 changes: 81 additions & 11 deletions pkg/proxy/net/packetio_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,29 +19,58 @@ import (
"net"
"testing"

"github.com/pingcap/TiProxy/pkg/util/security"
"github.com/pingcap/TiProxy/pkg/util/waitgroup"
"github.com/pingcap/tidb/parser/mysql"
"github.com/stretchr/testify/require"
)

func testConn(t *testing.T, a func(*testing.T, *PacketIO), b func(*testing.T, *PacketIO)) {
func testPipeConn(t *testing.T, a func(*testing.T, *PacketIO), b func(*testing.T, *PacketIO), loop int) {
var wg waitgroup.WaitGroup
client, server := net.Pipe()
cli, srv := NewPacketIO(client), NewPacketIO(server)
wg.Run(func() {
a(t, cli)
require.NoError(t, cli.Close())
})
wg.Run(func() {
b(t, srv)
require.NoError(t, srv.Close())
})
wg.Wait()
for i := 0; i < loop; i++ {
wg.Run(func() {
a(t, cli)
require.NoError(t, cli.Close())
})
wg.Run(func() {
b(t, srv)
require.NoError(t, srv.Close())
})
wg.Wait()
}
}

func testTCPConn(t *testing.T, a func(*testing.T, *PacketIO), b func(*testing.T, *PacketIO), loop int) {
listener, err := net.Listen("tcp", "0.0.0.0:0")
require.NoError(t, err)
defer func() {
require.NoError(t, listener.Close())
}()
var wg waitgroup.WaitGroup
for i := 0; i < loop; i++ {
wg.Run(func() {
cli, err := net.Dial("tcp", listener.Addr().String())
require.NoError(t, err)
cliIO := NewPacketIO(cli)
a(t, cliIO)
require.NoError(t, cliIO.Close())
})
wg.Run(func() {
srv, err := listener.Accept()
require.NoError(t, err)
srvIO := NewPacketIO(srv)
b(t, srvIO)
require.NoError(t, srvIO.Close())
})
wg.Wait()
}
}

func TestPacketIO(t *testing.T) {
expectMsg := []byte("test")
testConn(t,
testPipeConn(t,
func(t *testing.T, cli *PacketIO) {
var err error

Expand Down Expand Up @@ -94,5 +123,46 @@ func TestPacketIO(t *testing.T) {
_, err = srv.ReadSSLRequest()
require.ErrorIs(t, err, ErrExpectSSLRequest)
},
1,
)
}

func TestTLS(t *testing.T) {
stls, ctls, err := security.CreateTLSConfigForTest()
require.NoError(t, err)
message := []byte("hello wolrd")
testTCPConn(t,
func(t *testing.T, cli *PacketIO) {
data, err := cli.ReadPacket()
require.NoError(t, err)
require.Equal(t, message, data)
err = cli.WritePacket(message, true)
require.NoError(t, err)

require.NoError(t, cli.UpgradeToClientTLS(ctls))

err = cli.WritePacket(message, true)
require.NoError(t, err)
data, err = cli.ReadPacket()
require.NoError(t, err)
require.Equal(t, message, data)
},
func(t *testing.T, srv *PacketIO) {
err = srv.WritePacket(message, true)
require.NoError(t, err)
data, err := srv.ReadPacket()
require.NoError(t, err)
require.Equal(t, message, data)

_, err = srv.UpgradeToServerTLS(stls)
require.NoError(t, err)

data, err = srv.ReadPacket()
require.NoError(t, err)
require.Equal(t, message, data)
err = srv.WritePacket(message, true)
require.NoError(t, err)
},
500,
xhebox marked this conversation as resolved.
Show resolved Hide resolved
)
}
5 changes: 3 additions & 2 deletions pkg/proxy/net/proxy_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ func TestProxy(t *testing.T) {
tcpaddr, err := net.ResolveTCPAddr("tcp", "192.168.1.1:34")
require.NoError(t, err)

testConn(t,
testPipeConn(t,
func(t *testing.T, cli *PacketIO) {
require.NoError(t, cli.writeProxyV2(&Proxy{
Version: ProxyVersion2,
Expand All @@ -48,7 +48,7 @@ func TestProxy(t *testing.T) {
func(t *testing.T, srv *PacketIO) {
// skip 4 bytes of magic
var hdr [4]byte
_, err := io.ReadFull(srv.buf, hdr[:])
_, err := io.ReadFull(srv.conn, hdr[:])
require.NoError(t, err)

// try to parse V2
Expand All @@ -64,5 +64,6 @@ func TestProxy(t *testing.T) {
require.Equal(t, ProxyTlvUniqueID, p.TLV[1].typ)
require.Equal(t, []byte("test"), p.TLV[1].content)
},
1,
)
}
16 changes: 8 additions & 8 deletions pkg/proxy/net/tls.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,27 +22,27 @@ import (
)

func (p *PacketIO) UpgradeToServerTLS(tlsConfig *tls.Config) (tls.ConnectionState, error) {
tlsConfig = tlsConfig.Clone()
tlsConn := tls.Server(p.conn, tlsConfig)
if err := tlsConn.Handshake(); err != nil {
return tlsConn.ConnectionState(), errors.WithStack(errors.Wrap(ErrHandshakeTLS, err))
}
p.buf.Reader.Reset(tlsConn)
p.buf.Writer.Reset(tlsConn)
p.conn = tlsConn
xhebox marked this conversation as resolved.
Show resolved Hide resolved
p.buf.Writer.Reset(p.conn)
return tlsConn.ConnectionState(), nil
}

func (p *PacketIO) UpgradeToClientTLS(tlsConfig *tls.Config) error {
tlsConfig = tlsConfig.Clone()
host, _, err := net.SplitHostPort(p.conn.RemoteAddr().String())
if err != nil {
return errors.WithStack(errors.Wrap(ErrHandshakeTLS, err))
if err == nil {
tlsConfig.ServerName = host
}
tlsConfig = tlsConfig.Clone()
tlsConfig.ServerName = host
tlsConn := tls.Client(p.conn, tlsConfig)
if err := tlsConn.Handshake(); err != nil {
return errors.WithStack(errors.Wrap(ErrHandshakeTLS, err))
}
p.buf.Reader.Reset(tlsConn)
p.buf.Writer.Reset(tlsConn)
p.conn = tlsConn
p.buf.Writer.Reset(p.conn)
return nil
}
110 changes: 109 additions & 1 deletion pkg/util/security/tls.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,18 +15,21 @@
package security

import (
"bytes"
"crypto/rand"
"crypto/rsa"
"crypto/tls"
"crypto/x509"
"crypto/x509/pkix"
"encoding/pem"
"go.uber.org/zap"
"math/big"
"net"
"os"
"path/filepath"
"time"

"go.uber.org/zap"

xhebox marked this conversation as resolved.
Show resolved Hide resolved
"github.com/pingcap/errors"
"github.com/pingcap/tidb/util/logutil"
)
Expand Down Expand Up @@ -192,3 +195,108 @@ func CreateClientTLSConfig(sslCA, sslKey, sslCert string) (tlsConfig *tls.Config
}
return
}

// CreateTLSConfigForTest is from https://gist.github.com/shaneutt/5e1995295cff6721c89a71d13a71c251.
func CreateTLSConfigForTest() (serverTLSConf *tls.Config, clientTLSConf *tls.Config, err error) {
xhebox marked this conversation as resolved.
Show resolved Hide resolved
// set up our CA certificate
ca := &x509.Certificate{
SerialNumber: big.NewInt(2019),
Subject: pkix.Name{
Organization: []string{"Company, INC."},
Country: []string{"US"},
Province: []string{""},
Locality: []string{"San Francisco"},
StreetAddress: []string{"Golden Gate Bridge"},
PostalCode: []string{"94016"},
},
NotBefore: time.Now(),
NotAfter: time.Now().AddDate(10, 0, 0),
IsCA: true,
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth, x509.ExtKeyUsageServerAuth},
KeyUsage: x509.KeyUsageDigitalSignature | x509.KeyUsageCertSign,
BasicConstraintsValid: true,
}

// create our private and public key
caPrivKey, err := rsa.GenerateKey(rand.Reader, 4096)
if err != nil {
return nil, nil, err
}

// create the CA
caBytes, err := x509.CreateCertificate(rand.Reader, ca, ca, &caPrivKey.PublicKey, caPrivKey)
if err != nil {
return nil, nil, err
}

// pem encode
caPEM := new(bytes.Buffer)
pem.Encode(caPEM, &pem.Block{
Type: "CERTIFICATE",
Bytes: caBytes,
})

caPrivKeyPEM := new(bytes.Buffer)
pem.Encode(caPrivKeyPEM, &pem.Block{
Type: "RSA PRIVATE KEY",
Bytes: x509.MarshalPKCS1PrivateKey(caPrivKey),
})

// set up our server certificate
cert := &x509.Certificate{
SerialNumber: big.NewInt(2019),
Subject: pkix.Name{
Organization: []string{"Company, INC."},
Country: []string{"US"},
Province: []string{""},
Locality: []string{"San Francisco"},
StreetAddress: []string{"Golden Gate Bridge"},
PostalCode: []string{"94016"},
},
IPAddresses: []net.IP{net.IPv4(127, 0, 0, 1), net.IPv6loopback},
NotBefore: time.Now(),
NotAfter: time.Now().AddDate(10, 0, 0),
SubjectKeyId: []byte{1, 2, 3, 4, 6},
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth, x509.ExtKeyUsageServerAuth},
KeyUsage: x509.KeyUsageDigitalSignature,
}

certPrivKey, err := rsa.GenerateKey(rand.Reader, 4096)
if err != nil {
return nil, nil, err
}

certBytes, err := x509.CreateCertificate(rand.Reader, cert, ca, &certPrivKey.PublicKey, caPrivKey)
if err != nil {
return nil, nil, err
}

certPEM := new(bytes.Buffer)
pem.Encode(certPEM, &pem.Block{
Type: "CERTIFICATE",
Bytes: certBytes,
})

certPrivKeyPEM := new(bytes.Buffer)
pem.Encode(certPrivKeyPEM, &pem.Block{
Type: "RSA PRIVATE KEY",
Bytes: x509.MarshalPKCS1PrivateKey(certPrivKey),
})

serverCert, err := tls.X509KeyPair(certPEM.Bytes(), certPrivKeyPEM.Bytes())
if err != nil {
return nil, nil, err
}

serverTLSConf = &tls.Config{
Certificates: []tls.Certificate{serverCert},
}

certpool := x509.NewCertPool()
certpool.AppendCertsFromPEM(caPEM.Bytes())
clientTLSConf = &tls.Config{
RootCAs: certpool,
}

return
}