From 52994088a403735b5cf008741c5aea43bcc9f4a0 Mon Sep 17 00:00:00 2001 From: djshow832 <873581766@qq.com> Date: Tue, 23 Aug 2022 15:59:53 +0800 Subject: [PATCH 1/2] support CLIENT_DEPRECATE_EOF --- go.mod | 14 +-- go.sum | 29 ++--- pkg/proxy/backend/authenticator.go | 15 +-- pkg/proxy/backend/backend_conn_mgr.go | 11 ++ pkg/proxy/backend/backend_conn_mgr_test.go | 53 +++++++++ pkg/proxy/backend/cmd_processor.go | 2 + pkg/proxy/backend/cmd_processor_exec.go | 128 +++++++++------------ pkg/proxy/backend/cmd_processor_query.go | 46 +++++--- pkg/proxy/backend/cmd_processor_test.go | 85 ++++++++------ pkg/proxy/backend/mock_backend_test.go | 59 ++++++---- pkg/proxy/backend/mock_client_test.go | 105 +++++++++-------- pkg/proxy/backend/mock_proxy_test.go | 6 +- pkg/proxy/backend/testsuite_test.go | 10 +- pkg/proxy/net/mysql.go | 9 +- pkg/proxy/net/packetio_mysql.go | 4 +- 15 files changed, 345 insertions(+), 231 deletions(-) diff --git a/go.mod b/go.mod index 7479501e..8ff1bd80 100644 --- a/go.mod +++ b/go.mod @@ -9,9 +9,10 @@ require ( github.com/go-mysql-org/go-mysql v1.6.0 github.com/goccy/go-yaml v1.9.5 github.com/pingcap/errors v0.11.5-0.20211224045212-9687c2b0f87c - github.com/pingcap/tidb v1.1.0-beta.0.20220804075006-e071841317c5 - github.com/pingcap/tidb/parser v0.0.0-20220804082206-fff748348776 + github.com/pingcap/tidb v1.1.0-beta.0.20220819110652-8b5b724d8a93 + github.com/pingcap/tidb/parser v0.0.0-20220819110652-8b5b724d8a93 github.com/prometheus/client_golang v1.12.2 + github.com/siddontang/go v0.0.0-20180604090527-bdc77568d726 github.com/spf13/cobra v1.5.0 github.com/stretchr/testify v1.8.0 go.etcd.io/etcd/api/v3 v3.5.2 @@ -23,7 +24,7 @@ require ( ) require ( - github.com/BurntSushi/toml v1.1.0 // indirect + github.com/BurntSushi/toml v1.2.0 // indirect github.com/benbjohnson/clock v1.3.0 // indirect github.com/beorn7/perks v1.0.1 // indirect github.com/cespare/xxhash/v2 v2.1.2 // indirect @@ -67,7 +68,7 @@ require ( github.com/opentracing/opentracing-go v1.2.0 // indirect github.com/pelletier/go-toml/v2 v2.0.1 // indirect github.com/pingcap/failpoint v0.0.0-20220423142525-ae43b7f4e5c3 // indirect - github.com/pingcap/kvproto v0.0.0-20220711062932-08b02befd813 // indirect + github.com/pingcap/kvproto v0.0.0-20220804022843-f006036b1277 // indirect github.com/pingcap/log v1.1.0 // indirect github.com/pingcap/tipb v0.0.0-20220718022156-3e2483c20a9e // indirect github.com/pkg/errors v0.9.1 // indirect @@ -78,13 +79,12 @@ require ( github.com/prometheus/procfs v0.7.3 // indirect github.com/remyoudompheng/bigfft v0.0.0-20200410134404-eec4a21b6bb0 // indirect github.com/shirou/gopsutil/v3 v3.22.6 // indirect - github.com/siddontang/go v0.0.0-20180604090527-bdc77568d726 // indirect github.com/siddontang/go-log v0.0.0-20180807004314-8d05993dda07 // indirect - github.com/sirupsen/logrus v1.8.1 // indirect + github.com/sirupsen/logrus v1.9.0 // indirect github.com/soheilhy/cmux v0.1.5 // indirect github.com/spf13/pflag v1.0.5 // indirect github.com/stathat/consistent v1.0.0 // indirect - github.com/tikv/client-go/v2 v2.0.1-0.20220729034404-e10841f2d158 // indirect + github.com/tikv/client-go/v2 v2.0.1-0.20220818084834-0d0ae0dcfb1f // indirect github.com/tikv/pd/client v0.0.0-20220725055910-7187a7ab72db // indirect github.com/tmc/grpc-websocket-proxy v0.0.0-20201229170055-e5319fda7802 // indirect github.com/twmb/murmur3 v1.1.3 // indirect diff --git a/go.sum b/go.sum index 5cc20c76..1eaa81ed 100644 --- a/go.sum +++ b/go.sum @@ -41,8 +41,8 @@ github.com/Azure/azure-sdk-for-go/sdk/azidentity v0.12.0 h1:VBvHGLJbaY0+c66NZHdS github.com/Azure/azure-sdk-for-go/sdk/internal v0.8.1 h1:BUYIbDf/mMZ8945v3QkG3OuqGVyS4Iek0AOLwdRAYoc= github.com/Azure/azure-sdk-for-go/sdk/storage/azblob v0.2.0 h1:62Ew5xXg5UCGIXDOM7+y4IL5/6mQJq1nenhBCJAeGX8= github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU= -github.com/BurntSushi/toml v1.1.0 h1:ksErzDEI1khOiGPgpwuI7x2ebx/uXQNw7xJpn9Eq1+I= -github.com/BurntSushi/toml v1.1.0/go.mod h1:CxXYINrC8qIiEnFrOxCa7Jy5BFHlXnUU2pbicEuybxQ= +github.com/BurntSushi/toml v1.2.0 h1:Rt8g24XnyGTyglgET/PRUNlrUeu9F5L+7FilkXfZgs0= +github.com/BurntSushi/toml v1.2.0/go.mod h1:CxXYINrC8qIiEnFrOxCa7Jy5BFHlXnUU2pbicEuybxQ= github.com/BurntSushi/xgb v0.0.0-20160522181843-27f122750802/go.mod h1:IVnqGOEym/WlBOVXweHU+Q+/VP0lqqI8lqeDx9IjBqo= github.com/HdrHistogram/hdrhistogram-go v1.1.2 h1:5IcZpTvzydCQeHzK4Ef/D5rrSqwxob0t8PQPMybUNFM= github.com/OneOfOne/xxhash v1.2.2/go.mod h1:HSdplMjZKSmBqAxg5vPj2TmRDmfkzw+cTzAElWljhcU= @@ -69,7 +69,7 @@ github.com/beorn7/perks v1.0.1 h1:VlbKKnNfV8bJzeqoa4cOKqO6bYr3WgKZxO8Z16+hsOM= github.com/beorn7/perks v1.0.1/go.mod h1:G2ZrVWU2WbWT9wwq4/hrbKbnv/1ERSJQ0ibhJ6rlkpw= github.com/bgentry/speakeasy v0.1.0/go.mod h1:+zsyZBPWlz7T6j88CTgSN5bM796AkVf0kBD4zp0CCIs= github.com/bketelsen/crypt v0.0.3-0.20200106085610-5cbc8cc4026c/go.mod h1:MKsuJmJgSg28kpZDP6UIiPt0e0Oz0kqKNGyRaWEPv84= -github.com/blacktear23/go-proxyprotocol v1.0.0 h1:WmMmtZanGEfIHnJN9N3A4Pl6mM69D+GxEph2eOaCf7g= +github.com/blacktear23/go-proxyprotocol v1.0.2 h1:zR7PZeoU0wAkElcIXenFiy3R56WB6A+UEVi4c6RH8wo= github.com/census-instrumentation/opencensus-proto v0.2.1/go.mod h1:f6KPmirojxKA12rnyqOA5BBL4O983OfeGPqjHWSTneU= github.com/certifi/gocertifi v0.0.0-20191021191039-0944d244cd40/go.mod h1:sGbDF6GwGcLpkNXPUTkMRoywsNa/ol15pxFe6ERfguA= github.com/certifi/gocertifi v0.0.0-20200922220541-2c3bb06c6054/go.mod h1:sGbDF6GwGcLpkNXPUTkMRoywsNa/ol15pxFe6ERfguA= @@ -405,8 +405,8 @@ github.com/pingcap/failpoint v0.0.0-20220423142525-ae43b7f4e5c3/go.mod h1:4qGtCB github.com/pingcap/fn v0.0.0-20200306044125-d5540d389059 h1:Pe2LbxRmbTfAoKJ65bZLmhahmvHm7n9DUxGRQT00208= github.com/pingcap/goleveldb v0.0.0-20191226122134-f82aafb29989 h1:surzm05a8C9dN8dIUmo4Be2+pMRb6f55i+UIYrluu2E= github.com/pingcap/kvproto v0.0.0-20220510035547-0e2f26c0a46a/go.mod h1:OYtxs0786qojVTmkVeufx93xe+jUgm56GUYRIKnmaGI= -github.com/pingcap/kvproto v0.0.0-20220711062932-08b02befd813 h1:PAXtUVMJnyQQS8t9GzihIFmh6FBXu0JziWbIVknLniA= -github.com/pingcap/kvproto v0.0.0-20220711062932-08b02befd813/go.mod h1:OYtxs0786qojVTmkVeufx93xe+jUgm56GUYRIKnmaGI= +github.com/pingcap/kvproto v0.0.0-20220804022843-f006036b1277 h1:4UQdx1acoUrQD0Q5Etz1ABd31duzSgp3XwEnb/cvV9I= +github.com/pingcap/kvproto v0.0.0-20220804022843-f006036b1277/go.mod h1:OYtxs0786qojVTmkVeufx93xe+jUgm56GUYRIKnmaGI= github.com/pingcap/log v0.0.0-20200511115504-543df19646ad/go.mod h1:4rbK1p9ILyIfb6hU7OG2CiWSqMXnp3JMbiaVJ6mvoY8= github.com/pingcap/log v0.0.0-20210317133921-96f4fcab92a4/go.mod h1:4rbK1p9ILyIfb6hU7OG2CiWSqMXnp3JMbiaVJ6mvoY8= github.com/pingcap/log v0.0.0-20211215031037-e024ba4eb0ee/go.mod h1:DWQW5jICDR7UJh4HtxXSM20Churx4CQL0fwL/SoOSA4= @@ -414,10 +414,10 @@ github.com/pingcap/log v1.1.0 h1:ELiPxACz7vdo1qAvvaWJg1NrYFoY6gqAh/+Uo6aXdD8= github.com/pingcap/log v1.1.0/go.mod h1:DWQW5jICDR7UJh4HtxXSM20Churx4CQL0fwL/SoOSA4= github.com/pingcap/parser v0.0.0-20210415081931-48e7f467fd74/go.mod h1:xZC8I7bug4GJ5KtHhgAikjTfU4kBv1Sbo3Pf1MZ6lVw= github.com/pingcap/sysutil v0.0.0-20220114020952-ea68d2dbf5b4 h1:HYbcxtnkN3s5tqrZ/z3eJS4j3Db8wMphEm1q10lY/TM= -github.com/pingcap/tidb v1.1.0-beta.0.20220804075006-e071841317c5 h1:ykpPpdxg7bJpsKUiMpvieknpPECqZVgChTUBENfQgTw= -github.com/pingcap/tidb v1.1.0-beta.0.20220804075006-e071841317c5/go.mod h1:VaKuKfDKAu2fJA7U1WqGcrcNfOcTdIF9oX0IylG0FA4= -github.com/pingcap/tidb/parser v0.0.0-20220804082206-fff748348776 h1:dvxE+3Vle5PH5TfDm+07VufqROyZatphja10bnA7I3c= -github.com/pingcap/tidb/parser v0.0.0-20220804082206-fff748348776/go.mod h1:wjvp+T3/T9XYt0nKqGX3Kc1AKuyUcfno6LTc6b2A6ew= +github.com/pingcap/tidb v1.1.0-beta.0.20220819110652-8b5b724d8a93 h1:t0wrJTmfqTPe8e7+AaUNiS0LJIegigflRWP6jnGxAco= +github.com/pingcap/tidb v1.1.0-beta.0.20220819110652-8b5b724d8a93/go.mod h1:ibrqg2O6i98YbT6al8tpoz824bcHQlQKyV7VxpC1RH0= +github.com/pingcap/tidb/parser v0.0.0-20220819110652-8b5b724d8a93 h1:nqdE7w2y4UNCfvudEZec6ijA6Ju+1AyvSISQgSvX8Ps= +github.com/pingcap/tidb/parser v0.0.0-20220819110652-8b5b724d8a93/go.mod h1:wjvp+T3/T9XYt0nKqGX3Kc1AKuyUcfno6LTc6b2A6ew= github.com/pingcap/tipb v0.0.0-20220718022156-3e2483c20a9e h1:FBaTXU8C3xgt/drM58VHxojHo/QoG1oPsgWTGvaSpO4= github.com/pingcap/tipb v0.0.0-20220718022156-3e2483c20a9e/go.mod h1:A7mrd7WHBl1o63LE2bIBGEJMTNWXqhgmYiOvMLxozfs= github.com/pkg/browser v0.0.0-20180916011732-0a3d74bf9ce4 h1:49lOXmGaUpV9Fz3gd7TFZY106KVlPVa5jcYD1gaQf98= @@ -478,7 +478,7 @@ github.com/shirou/gopsutil/v3 v3.22.6/go.mod h1:EdIubSnZhbAvBS1yJ7Xi+AShB/hxwLHO github.com/shopspring/decimal v0.0.0-20180709203117-cd690d0c9e24/go.mod h1:M+9NzErvs504Cn4c5DxATwIqPbtswREoFCre64PpcG4= github.com/shurcooL/httpfs v0.0.0-20190707220628-8d4bc4ba7749 h1:bUGsEnyNbVPw06Bs80sCeARAlK8lhwqGyi6UT8ymuGk= github.com/shurcooL/sanitized_anchor_name v1.0.0/go.mod h1:1NzhyTcUVG4SuEtjjoZeVRXNmyL/1OwPU0+IJeTBvfc= -github.com/shurcooL/vfsgen v0.0.0-20200824052919-0d455de96546 h1:pXY9qYc/MP5zdvqWEUH6SjNiu7VhSjuVFTFiTcphaLU= +github.com/shurcooL/vfsgen v0.0.0-20180711163814-62bca832be04 h1:y0cMJ0qjii33BnD6tMGcF/+gHYsoKQ6tbwQpy233OII= github.com/siddontang/go v0.0.0-20180604090527-bdc77568d726 h1:xT+JlYxNGqyT+XcU8iUrN18JYed2TvG9yN5ULG2jATM= github.com/siddontang/go v0.0.0-20180604090527-bdc77568d726/go.mod h1:3yhqj7WBBfRhbBlzyOC3gUxftwsU0u8gqevxwIHQpMw= github.com/siddontang/go-log v0.0.0-20180807004314-8d05993dda07 h1:oI+RNwuC9jF2g2lP0u0cVEEZrc/AYBCuFdvwrLWM/6Q= @@ -487,8 +487,8 @@ github.com/sirupsen/logrus v1.2.0/go.mod h1:LxeOpSwHxABJmUn/MG1IvRgCAasNZTLOkJPx github.com/sirupsen/logrus v1.4.2/go.mod h1:tLMulIdttU9McNUspp0xgXVQah82FyeX6MwdIuYE2rE= github.com/sirupsen/logrus v1.6.0/go.mod h1:7uNnSEd1DgxDLC74fIahvMZmmYsHGZGEOFrfsX/uA88= github.com/sirupsen/logrus v1.7.0/go.mod h1:yWOB1SBYBC5VeMP7gHvWumXLIWorT60ONWic61uBYv0= -github.com/sirupsen/logrus v1.8.1 h1:dJKuHgqk1NNQlqoA6BTlM1Wf9DOH3NBjQyu0h9+AZZE= -github.com/sirupsen/logrus v1.8.1/go.mod h1:yWOB1SBYBC5VeMP7gHvWumXLIWorT60ONWic61uBYv0= +github.com/sirupsen/logrus v1.9.0 h1:trlNQbNUG3OdDrDil03MCb1H2o9nJ1x4/5LYw7byDE0= +github.com/sirupsen/logrus v1.9.0/go.mod h1:naHLuLoDiP4jHNo9R0sCBMtWGeIprob74mVsIT4qYEQ= github.com/smartystreets/assertions v0.0.0-20180927180507-b2de0cb4f26d/go.mod h1:OnSkiWE9lh6wB0YB77sQom3nweQdgAjqCqsofrRNTgc= github.com/smartystreets/goconvey v1.6.4/go.mod h1:syvi0/a8iFYH4r/RixwvyeAJjdLS9QV7WQ/tjFTllLA= github.com/soheilhy/cmux v0.1.4/go.mod h1:IM3LyeVVIOuxMH7sFAkER9+bJ4dT7Ms6E4xg4kGIyLM= @@ -522,8 +522,8 @@ github.com/stretchr/testify v1.8.0 h1:pSgiaMZlXftHpm5L7V1+rVB+AZJydKsMxsQBIJw4PK github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= github.com/subosito/gotenv v1.2.0/go.mod h1:N0PQaV/YGNqwC0u51sEeR/aUtSLEXKX9iv69rRypqCw= github.com/tiancaiamao/appdash v0.0.0-20181126055449-889f96f722a2 h1:mbAskLJ0oJfDRtkanvQPiooDH8HvJ2FBh+iKT/OmiQQ= -github.com/tikv/client-go/v2 v2.0.1-0.20220729034404-e10841f2d158 h1:oCtRW/f0FZabdoLuvqxIewcmHR83RlsdN37dS0EBRTU= -github.com/tikv/client-go/v2 v2.0.1-0.20220729034404-e10841f2d158/go.mod h1:v3DEt8LS9olI6D6El17pYBWq7B28hw3NnDFTxQHDLpY= +github.com/tikv/client-go/v2 v2.0.1-0.20220818084834-0d0ae0dcfb1f h1:/nr7P8uzJQ7u3wPEBHCokrsVmuDvi/1x/zI/ydk5n8U= +github.com/tikv/client-go/v2 v2.0.1-0.20220818084834-0d0ae0dcfb1f/go.mod h1:v3DEt8LS9olI6D6El17pYBWq7B28hw3NnDFTxQHDLpY= github.com/tikv/pd/client v0.0.0-20220725055910-7187a7ab72db h1:r1eMh9Rny3hfWuBuxOnbsCRrR4FhthiNxLQ5rAUtaww= github.com/tikv/pd/client v0.0.0-20220725055910-7187a7ab72db/go.mod h1:ew8kS0yIcEaSetuuywkTLIUBR+sz3J5XvAYRae11qwc= github.com/tklauser/go-sysconf v0.3.10 h1:IJ1AZGZRWbY8T5Vfk04D9WOA5WSejdflXxP03OUqALw= @@ -786,6 +786,7 @@ golang.org/x/sys v0.0.0-20210927094055-39ccf1dd6fa6/go.mod h1:oPkhp1MJrh7nUepCBc golang.org/x/sys v0.0.0-20220114195835-da31bd327af9/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220128215802-99c3d69c2c27/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220728004956-3c1f35247d10 h1:WIoqL4EROvwiPdUtaip4VcDdpZ4kha7wBWZrbVKCIZg= golang.org/x/sys v0.0.0-20220728004956-3c1f35247d10/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= diff --git a/pkg/proxy/backend/authenticator.go b/pkg/proxy/backend/authenticator.go index 82087362..4ffe005b 100644 --- a/pkg/proxy/backend/authenticator.go +++ b/pkg/proxy/backend/authenticator.go @@ -30,7 +30,8 @@ const supportedServerCapabilities = mysql.ClientLongPassword | mysql.ClientFound mysql.ClientConnectWithDB | mysql.ClientNoSchema | mysql.ClientODBC | mysql.ClientLocalFiles | mysql.ClientIgnoreSpace | mysql.ClientProtocol41 | mysql.ClientInteractive | mysql.ClientSSL | mysql.ClientIgnoreSigpipe | mysql.ClientTransactions | mysql.ClientReserved | mysql.ClientSecureConnection | mysql.ClientMultiStatements | - mysql.ClientMultiResults | mysql.ClientPluginAuth | mysql.ClientConnectAtts | mysql.ClientPluginAuthLenencClientData + mysql.ClientMultiResults | mysql.ClientPluginAuth | mysql.ClientConnectAtts | mysql.ClientPluginAuthLenencClientData | + mysql.ClientDeprecateEOF // Authenticator handshakes with the client and the backend. type Authenticator struct { @@ -103,12 +104,9 @@ func (auth *Authenticator) handshakeFirstTime(clientIO, backendIO *pnet.PacketIO if err = backendIO.WritePacket(clientPkt, true); err != nil { return err } - if err = auth.readHandshakeResponse(clientPkt); err != nil { + if err = auth.readHandshakeResponse(clientPkt, serverCapability); err != nil { return err } - if unsupported := serverCapability & auth.capability &^ supportedServerCapabilities; unsupported > 0 { - return errors.Errorf("server capability is not supported: %d", unsupported) - } // verify password for { @@ -138,14 +136,17 @@ func forwardMsg(srcIO, destIO *pnet.PacketIO) (data []byte, err error) { return } -func (auth *Authenticator) readHandshakeResponse(data []byte) error { +func (auth *Authenticator) readHandshakeResponse(data []byte, serverCapability uint32) error { capability := uint32(binary.LittleEndian.Uint16(data[:2])) if capability&mysql.ClientProtocol41 == 0 { // TiDB doesn't support it now. return errors.New("pre-4.1 MySQL client versions are not supported") } resp := pnet.ParseHandshakeResponse(data) - auth.capability = resp.Capability + auth.capability = resp.Capability & serverCapability + if unsupported := auth.capability &^ supportedServerCapabilities; unsupported > 0 { + return errors.Errorf("capability is not supported: %d", unsupported) + } auth.user = resp.User auth.dbname = resp.DB auth.collation = resp.Collation diff --git a/pkg/proxy/backend/backend_conn_mgr.go b/pkg/proxy/backend/backend_conn_mgr.go index 844f0d0e..674d2220 100644 --- a/pkg/proxy/backend/backend_conn_mgr.go +++ b/pkg/proxy/backend/backend_conn_mgr.go @@ -17,6 +17,7 @@ package backend import ( "context" "crypto/tls" + "encoding/binary" "encoding/json" "fmt" "strings" @@ -99,6 +100,7 @@ func (mgr *BackendConnManager) Connect(ctx context.Context, serverAddr string, c if err := mgr.authenticator.handshakeFirstTime(clientIO, backendIO, serverTLSConfig, backendTLSConfig); err != nil { return err } + mgr.cmdProcessor.capability = mgr.authenticator.capability childCtx, cancelFunc := context.WithCancel(ctx) go mgr.processSignals(childCtx) mgr.cancelFunc = cancelFunc @@ -119,6 +121,15 @@ func (mgr *BackendConnManager) ExecuteCmd(ctx context.Context, request []byte, c switch request[0] { case mysql.ComQuit: return nil + case mysql.ComSetOption: + switch binary.LittleEndian.Uint16(request[1:]) { + case 0: + mgr.authenticator.capability |= mysql.ClientMultiStatements + mgr.cmdProcessor.capability |= mysql.ClientMultiStatements + case 1: + mgr.authenticator.capability &^= mysql.ClientMultiStatements + mgr.cmdProcessor.capability &^= mysql.ClientMultiStatements + } case mysql.ComChangeUser: username, db := pnet.ParseChangeUser(request) mgr.authenticator.changeUser(username, db) diff --git a/pkg/proxy/backend/backend_conn_mgr_test.go b/pkg/proxy/backend/backend_conn_mgr_test.go index c2223dd5..382619fb 100644 --- a/pkg/proxy/backend/backend_conn_mgr_test.go +++ b/pkg/proxy/backend/backend_conn_mgr_test.go @@ -423,3 +423,56 @@ func TestRedirectFail(t *testing.T) { } ts.runTests(runners) } + +// Test that the proxy sends the right handshake info after COM_CHANGE_USER and COM_SET_OPTION. +func TestSpecialCmds(t *testing.T) { + ts := newBackendMgrTester(t) + runners := []runner{ + // 1st handshake + { + client: ts.mc.authenticate, + proxy: ts.firstHandshake4Proxy, + backend: ts.handshake4Backend, + }, + // change user + { + client: func(packetIO *pnet.PacketIO) error { + ts.mc.cmd = mysql.ComChangeUser + ts.mc.username = "another_user" + ts.mc.dbName = "another_db" + return ts.mc.request(packetIO) + }, + proxy: ts.forwardCmd4Proxy, + backend: ts.respondWithNoTxn4Backend, + }, + // disable multi-stmts + { + client: func(packetIO *pnet.PacketIO) error { + ts.mc.cmd = mysql.ComSetOption + ts.mc.dataBytes = []byte{1, 0} + return ts.mc.request(packetIO) + }, + proxy: ts.forwardCmd4Proxy, + backend: ts.respondWithNoTxn4Backend, + }, + // 2nd handshake + { + client: nil, + proxy: func(clientIO, backendIO *pnet.PacketIO) error { + backend1 := ts.mp.backendConn + ts.mp.Redirect(ts.tc.backendListener.Addr().String()) + ts.mp.eventReceiver.(*mockEventReceiver).checkEvent(t, eventSucceed) + require.NotEqual(t, backend1, ts.mp.backendConn) + return nil + }, + backend: func(packetIO *pnet.PacketIO) error { + require.NoError(t, ts.redirectSucceed4Backend(packetIO)) + require.Equal(t, "another_user", ts.mb.username) + require.Equal(t, "another_db", ts.mb.db) + require.Equal(t, defaultClientCapability&^mysql.ClientMultiStatements, ts.mb.clientCapability) + return nil + }, + }, + } + ts.runTests(runners) +} diff --git a/pkg/proxy/backend/cmd_processor.go b/pkg/proxy/backend/cmd_processor.go index 5e560c69..dd0e26ba 100644 --- a/pkg/proxy/backend/cmd_processor.go +++ b/pkg/proxy/backend/cmd_processor.go @@ -31,6 +31,8 @@ const ( // CmdProcessor maintains the transaction and prepared statement status and decides whether the session can be redirected. type CmdProcessor struct { + capability uint32 + // Only includes in_trans or quit status. serverStatus uint32 // Each prepared statement has an independent status. preparedStmtStatus map[int]uint32 diff --git a/pkg/proxy/backend/cmd_processor_exec.go b/pkg/proxy/backend/cmd_processor_exec.go index 988792eb..3c66e398 100644 --- a/pkg/proxy/backend/cmd_processor_exec.go +++ b/pkg/proxy/backend/cmd_processor_exec.go @@ -86,7 +86,11 @@ func (cp *CmdProcessor) forwardCommand(clientIO, backendIO *pnet.PacketIO, reque case mysql.ErrHeader: return cp.handleErrorPacket(response) case mysql.EOFHeader: - cp.handleEOFPacket(request, response) + if cp.capability&mysql.ClientDeprecateEOF == 0 { + cp.handleEOFPacket(request, response) + } else { + cp.handleOKPacket(request, response) + } return nil } // impossible here @@ -100,14 +104,27 @@ func forwardOnePacket(destIO, srcIO *pnet.PacketIO, flush bool) (data []byte, er return data, destIO.WritePacket(data, flush) } -func forwardUntilEOF(clientIO, backendIO *pnet.PacketIO) (eofPacket []byte, err error) { - var response []byte +func (cp *CmdProcessor) forwardUntilResultEnd(clientIO, backendIO *pnet.PacketIO, request []byte) (uint16, error) { for { - if response, err = forwardOnePacket(clientIO, backendIO, false); err != nil { - return + response, err := forwardOnePacket(clientIO, backendIO, false) + if err != nil { + return 0, err } - if pnet.IsEOFPacket(response) { - return response, nil + if response[0] == mysql.ErrHeader { + if err := clientIO.Flush(); err != nil { + return 0, err + } + return 0, cp.handleErrorPacket(response) + } + if cp.capability&mysql.ClientDeprecateEOF == 0 { + if pnet.IsEOFPacket(response) { + return cp.handleEOFPacket(request, response), clientIO.Flush() + } + } else { + if pnet.IsResultSetOKPacket(response) { + rs := cp.handleOKPacket(request, response) + return rs.Status, clientIO.Flush() + } } } } @@ -117,21 +134,24 @@ func (cp *CmdProcessor) forwardPrepareCmd(clientIO, backendIO *pnet.PacketIO) er if err != nil { return err } - // The OK packet doesn't contain a server status. switch response[0] { case mysql.OKHeader: - expectedEOFNum := 0 + // The OK packet doesn't contain a server status. + // See https://mariadb.com/kb/en/com_stmt_prepare/ numColumns := binary.LittleEndian.Uint16(response[5:]) - if numColumns > 0 { - expectedEOFNum++ - } numParams := binary.LittleEndian.Uint16(response[7:]) - if numParams > 0 { - expectedEOFNum++ + expectedPackets := int(numColumns) + int(numParams) + if cp.capability&mysql.ClientDeprecateEOF == 0 { + if numColumns > 0 { + expectedPackets++ + } + if numParams > 0 { + expectedPackets++ + } } - for i := 0; i < expectedEOFNum; i++ { + for i := 0; i < expectedPackets; i++ { // Ignore this status because PREPARE doesn't affect status. - if _, err = forwardUntilEOF(clientIO, backendIO); err != nil { + if _, err = forwardOnePacket(clientIO, backendIO, false); err != nil { return err } } @@ -147,23 +167,13 @@ func (cp *CmdProcessor) forwardPrepareCmd(clientIO, backendIO *pnet.PacketIO) er } func (cp *CmdProcessor) forwardFetchCmd(clientIO, backendIO *pnet.PacketIO, request []byte) error { - response, err := forwardOnePacket(clientIO, backendIO, false) - if err != nil { - return err - } - if response[0] == mysql.ErrHeader { - if err := clientIO.Flush(); err != nil { - return err - } - return cp.handleErrorPacket(response) - } - if !pnet.IsEOFPacket(response) { - if response, err = forwardUntilEOF(clientIO, backendIO); err != nil { - return err - } - } - cp.handleEOFPacket(request, response) - return clientIO.Flush() + _, err := cp.forwardUntilResultEnd(clientIO, backendIO, request) + return err +} + +func (cp *CmdProcessor) forwardFieldListCmd(clientIO, backendIO *pnet.PacketIO, request []byte) error { + _, err := cp.forwardUntilResultEnd(clientIO, backendIO, request) + return err } func (cp *CmdProcessor) forwardQueryCmd(clientIO, backendIO *pnet.PacketIO, request []byte) error { @@ -229,36 +239,27 @@ func (cp *CmdProcessor) forwardLoadInFile(clientIO, backendIO *pnet.PacketIO, re return serverStatus, errors.Errorf("unexpected response, cmd:%d resp:%d", mysql.ComQuery, response[0]) } -func (cp *CmdProcessor) forwardResultSet(clientIO, backendIO *pnet.PacketIO, request, response []byte) (serverStatus uint16, err error) { - if !pnet.IsEOFPacket(response) { +func (cp *CmdProcessor) forwardResultSet(clientIO, backendIO *pnet.PacketIO, request, response []byte) (uint16, error) { + if cp.capability&mysql.ClientDeprecateEOF == 0 { // read columns - if response, err = forwardUntilEOF(clientIO, backendIO); err != nil { - return - } - } - serverStatus = binary.LittleEndian.Uint16(response[3:]) - // If a cursor exists, only columns are sent this time. The client will then send COM_STMT_FETCH to fetch rows. - // Otherwise, columns and rows are both sent once. - if serverStatus&mysql.ServerStatusCursorExists == 0 { - // read rows for { + var err error if response, err = forwardOnePacket(clientIO, backendIO, false); err != nil { - return - } - // An error may occur when the backend writes rows. - if response[0] == mysql.ErrHeader { - if err = clientIO.Flush(); err != nil { - return serverStatus, err - } - return serverStatus, cp.handleErrorPacket(response) + return 0, err } if pnet.IsEOFPacket(response) { break } } + serverStatus := binary.LittleEndian.Uint16(response[3:]) + // If a cursor exists, only columns are sent this time. The client will then send COM_STMT_FETCH to fetch rows. + // Otherwise, columns and rows are both sent once. + if serverStatus&mysql.ServerStatusCursorExists > 0 { + return cp.handleEOFPacket(request, response), clientIO.Flush() + } } - serverStatus = cp.handleEOFPacket(request, response) - return serverStatus, clientIO.Flush() + // Deprecate EOF or no cursor. + return cp.forwardUntilResultEnd(clientIO, backendIO, request) } func (cp *CmdProcessor) forwardCloseCmd(request []byte) error { @@ -296,25 +297,6 @@ func (cp *CmdProcessor) forwardChangeUserCmd(clientIO, backendIO *pnet.PacketIO, } } -func (cp *CmdProcessor) forwardFieldListCmd(clientIO, backendIO *pnet.PacketIO, request []byte) error { - response, err := forwardOnePacket(clientIO, backendIO, false) - if err != nil { - return err - } - if response[0] == mysql.ErrHeader { - if err = clientIO.Flush(); err != nil { - return err - } - return cp.handleErrorPacket(response) - } - // It sends some columns and an EOF packet. - if !pnet.IsEOFPacket(response) { - response, err = forwardUntilEOF(clientIO, backendIO) - } - cp.handleEOFPacket(request, response) - return clientIO.Flush() -} - func (cp *CmdProcessor) forwardStatisticsCmd(clientIO, backendIO *pnet.PacketIO) error { // It just sends a string. _, err := forwardOnePacket(clientIO, backendIO, true) diff --git a/pkg/proxy/backend/cmd_processor_query.go b/pkg/proxy/backend/cmd_processor_query.go index 2860cee0..b2899c94 100644 --- a/pkg/proxy/backend/cmd_processor_query.go +++ b/pkg/proxy/backend/cmd_processor_query.go @@ -19,7 +19,7 @@ import ( gomysql "github.com/go-mysql-org/go-mysql/mysql" pnet "github.com/pingcap/TiProxy/pkg/proxy/net" - "github.com/pingcap/errors" + "github.com/pingcap/TiProxy/pkg/util/errors" "github.com/pingcap/tidb/parser/mysql" "github.com/siddontang/go/hack" ) @@ -47,10 +47,8 @@ func (cp *CmdProcessor) query(packetIO *pnet.PacketIO, sql string) (result *gomy result = cp.handleOKPacket(request, response) case mysql.ErrHeader: err = cp.handleErrorPacket(response) - case mysql.EOFHeader: - cp.handleEOFPacket(request, response) case mysql.LocalInFileHeader: - err = mysql.ErrMalformPacket + err = errors.WithStack(mysql.ErrMalformPacket) default: result, err = cp.readResultSet(packetIO, response) } @@ -61,7 +59,7 @@ func (cp *CmdProcessor) query(packetIO *pnet.PacketIO, sql string) (result *gomy func (cp *CmdProcessor) readResultSet(packetIO *pnet.PacketIO, data []byte) (*gomysql.Result, error) { columnCount, _, n := pnet.ParseLengthEncodedInt(data) if n-len(data) != 0 { - return nil, mysql.ErrMalformPacket + return nil, errors.WithStack(mysql.ErrMalformPacket) } result := &gomysql.Result{ @@ -81,22 +79,26 @@ func (cp *CmdProcessor) readResultColumns(packetIO *pnet.PacketIO, result *gomys var data []byte for { + if fieldIndex == len(result.Fields) { + if cp.capability&mysql.ClientDeprecateEOF == 0 { + if data, err = packetIO.ReadPacket(); err != nil { + return err + } + if !pnet.IsEOFPacket(data) { + return errors.WithStack(mysql.ErrMalformPacket) + } + result.Status = binary.LittleEndian.Uint16(data[3:]) + } + return nil + } if data, err = packetIO.ReadPacket(); err != nil { return err } - if pnet.IsEOFPacket(data) { - result.Status = binary.LittleEndian.Uint16(data[3:]) - if fieldIndex != len(result.Fields) { - err = errors.Trace(mysql.ErrMalformPacket) - } - return - } - if result.Fields[fieldIndex] == nil { result.Fields[fieldIndex] = &gomysql.Field{} } if err = result.Fields[fieldIndex].Parse(data); err != nil { - return errors.Trace(err) + return errors.WithStack(err) } fieldName := hack.String(result.Fields[fieldIndex].Name) result.FieldNames[fieldName] = fieldIndex @@ -111,9 +113,17 @@ func (cp *CmdProcessor) readResultRows(packetIO *pnet.PacketIO, result *gomysql. if data, err = packetIO.ReadPacket(); err != nil { return err } - if pnet.IsEOFPacket(data) { - result.Status = binary.LittleEndian.Uint16(data[3:]) - break + if cp.capability&mysql.ClientDeprecateEOF == 0 { + if pnet.IsEOFPacket(data) { + result.Status = binary.LittleEndian.Uint16(data[3:]) + break + } + } else { + if pnet.IsResultSetOKPacket(data) { + rs := pnet.ParseOKPacket(data) + result.Status = rs.Status + break + } } // An error may occur when the backend writes rows. if data[0] == mysql.ErrHeader { @@ -131,7 +141,7 @@ func (cp *CmdProcessor) readResultRows(packetIO *pnet.PacketIO, result *gomysql. for i := range result.Values { result.Values[i], err = result.RowDatas[i].Parse(result.Fields, false, result.Values[i]) if err != nil { - return errors.Trace(err) + return errors.WithStack(err) } } return nil diff --git a/pkg/proxy/backend/cmd_processor_test.go b/pkg/proxy/backend/cmd_processor_test.go index 23ace73e..d6eb2bb6 100644 --- a/pkg/proxy/backend/cmd_processor_test.go +++ b/pkg/proxy/backend/cmd_processor_test.go @@ -85,55 +85,60 @@ func TestForwardCommands(t *testing.T) { // Test every respond type for every command. for cmd, respondTypes := range cmdResponseTypes { for _, respondType := range respondTypes { - cfgOvr := func(cfg *testConfig) { - cfg.clientConfig.cmd = cmd - cfg.backendConfig.respondType = respondType - } - // Test more variables for some special response types. - switch respondType { - case responseTypeColumn: - for _, columns := range []int{1, 4096} { - extraCfgOvr := func(cfg *testConfig) { - cfg.backendConfig.columns = columns - } - runTest(cfgOvr, extraCfgOvr) - } - case responseTypeRow: - for _, rows := range []int{0, 1, 3} { - extraCfgOvr := func(cfg *testConfig) { - cfg.backendConfig.rows = rows - } - runTest(cfgOvr, extraCfgOvr) + for _, capability := range []uint32{defaultBackendCapability &^ mysql.ClientDeprecateEOF, defaultBackendCapability | mysql.ClientDeprecateEOF} { + cfgOvr := func(cfg *testConfig) { + cfg.clientConfig.cmd = cmd + cfg.backendConfig.respondType = respondType + cfg.backendConfig.capability = capability + cfg.clientConfig.capability = capability + cfg.proxyConfig.capability = capability } - case responseTypePrepareOK: - for _, columns := range []int{0, 1, 4096} { - for _, params := range []int{0, 1, 3} { + // Test more variables for some special response types. + switch respondType { + case responseTypeColumn: + for _, columns := range []int{1, 4096} { extraCfgOvr := func(cfg *testConfig) { cfg.backendConfig.columns = columns - cfg.backendConfig.params = params } runTest(cfgOvr, extraCfgOvr) } - } - case responseTypeResultSet: - for _, columns := range []int{1, 4096} { + case responseTypeRow: for _, rows := range []int{0, 1, 3} { extraCfgOvr := func(cfg *testConfig) { - cfg.backendConfig.columns = columns cfg.backendConfig.rows = rows } runTest(cfgOvr, extraCfgOvr) } - } - case responseTypeLoadFile: - for _, filePkts := range []int{0, 1, 3} { - extraCfgOvr := func(cfg *testConfig) { - cfg.clientConfig.filePkts = filePkts + case responseTypePrepareOK: + for _, columns := range []int{0, 1, 4096} { + for _, params := range []int{0, 1, 3} { + extraCfgOvr := func(cfg *testConfig) { + cfg.backendConfig.columns = columns + cfg.backendConfig.params = params + } + runTest(cfgOvr, extraCfgOvr) + } } - runTest(cfgOvr, extraCfgOvr) + case responseTypeResultSet: + for _, columns := range []int{1, 4096} { + for _, rows := range []int{0, 1, 3} { + extraCfgOvr := func(cfg *testConfig) { + cfg.backendConfig.columns = columns + cfg.backendConfig.rows = rows + } + runTest(cfgOvr, extraCfgOvr) + } + } + case responseTypeLoadFile: + for _, filePkts := range []int{0, 1, 3} { + extraCfgOvr := func(cfg *testConfig) { + cfg.clientConfig.filePkts = filePkts + } + runTest(cfgOvr, extraCfgOvr) + } + default: + runTest(cfgOvr, cfgOvr) } - default: - runTest(cfgOvr, cfgOvr) } } } @@ -153,6 +158,16 @@ func TestDirectQuery(t *testing.T) { cfg.backendConfig.respondType = responseTypeResultSet }, }, + { + cfg: func(cfg *testConfig) { + cfg.clientConfig.capability = defaultBackendCapability &^ mysql.ClientDeprecateEOF + cfg.proxyConfig.capability = cfg.clientConfig.capability + cfg.backendConfig.capability = cfg.clientConfig.capability + cfg.backendConfig.columns = 2 + cfg.backendConfig.rows = 1 + cfg.backendConfig.respondType = responseTypeResultSet + }, + }, { cfg: func(cfg *testConfig) { cfg.backendConfig.respondType = responseTypeErr diff --git a/pkg/proxy/backend/mock_backend_test.go b/pkg/proxy/backend/mock_backend_test.go index 648ac208..b915d1da 100644 --- a/pkg/proxy/backend/mock_backend_test.go +++ b/pkg/proxy/backend/mock_backend_test.go @@ -59,11 +59,12 @@ type mockBackend struct { // Inputs that assigned by the test and will be sent to the client. *backendConfig // Outputs that received from the client and will be checked by the test. - username string - authData []byte - db string - attrs []byte - err error + username string + authData []byte + db string + attrs []byte + clientCapability uint32 + err error } func newMockBackend(cfg *backendConfig) *mockBackend { @@ -103,6 +104,7 @@ func (mb *mockBackend) authenticate(packetIO *pnet.PacketIO) error { mb.db = resp.DB mb.authData = resp.AuthData mb.attrs = resp.Attrs + mb.clientCapability = resp.Capability // verify password return mb.verifyPassword(packetIO, resp) } @@ -127,7 +129,7 @@ func (mb *mockBackend) verifyPassword(packetIO *pnet.PacketIO, resp *pnet.Handsh } } if mb.authSucceed { - if err := packetIO.WriteOKPacket(mb.status); err != nil { + if err := packetIO.WriteOKPacket(mb.status, mysql.OKHeader); err != nil { return err } } else { @@ -181,7 +183,7 @@ func (mb *mockBackend) respondOnce(packetIO *pnet.PacketIO) error { if _, err := packetIO.ReadPacket(); err != nil { return err } - return packetIO.WriteOKPacket(mb.status) + return packetIO.WriteOKPacket(mb.status, mysql.OKHeader) case responseTypePrepareOK: return mb.respondPrepare(packetIO) case responseTypeRow: @@ -200,7 +202,7 @@ func (mb *mockBackend) respondOK(packetIO *pnet.PacketIO) error { } else { status &= ^mysql.ServerMoreResultsExists } - if err := packetIO.WriteOKPacket(status); err != nil { + if err := packetIO.WriteOKPacket(status, mysql.OKHeader); err != nil { return err } } @@ -214,7 +216,14 @@ func (mb *mockBackend) respondColumns(packetIO *pnet.PacketIO) error { return err } } - return packetIO.WriteEOFPacket(mb.status) + return mb.writeResultEndPacket(packetIO, mb.status) +} + +func (mb *mockBackend) writeResultEndPacket(packetIO *pnet.PacketIO, status uint16) error { + if mb.capability&mysql.ClientDeprecateEOF > 0 { + return packetIO.WriteOKPacket(status, mysql.EOFHeader) + } + return packetIO.WriteEOFPacket(status) } // respond to Fetch @@ -224,7 +233,7 @@ func (mb *mockBackend) respondRows(packetIO *pnet.PacketIO) error { return err } } - return packetIO.WriteEOFPacket(mb.status) + return mb.writeResultEndPacket(packetIO, mb.status) } // respond to Query @@ -265,11 +274,13 @@ func (mb *mockBackend) writeResultSet(packetIO *pnet.PacketIO, names []string, v return err } } - if err := packetIO.WriteEOFPacket(status); err != nil { - return err - } if status&mysql.ServerStatusCursorExists == 0 { + if mb.capability&mysql.ClientDeprecateEOF == 0 { + if err := packetIO.WriteEOFPacket(status); err != nil { + return err + } + } for _, row := range values { var data []byte for _, value := range row { @@ -279,9 +290,9 @@ func (mb *mockBackend) writeResultSet(packetIO *pnet.PacketIO, names []string, v return err } } - if err := packetIO.WriteEOFPacket(status); err != nil { - return err - } + } + if err := mb.writeResultEndPacket(packetIO, status); err != nil { + return err } } return nil @@ -313,7 +324,7 @@ func (mb *mockBackend) respondLoadFile(packetIO *pnet.PacketIO) error { break } } - if err := packetIO.WriteOKPacket(status); err != nil { + if err := packetIO.WriteOKPacket(status, mysql.OKHeader); err != nil { return err } } @@ -337,8 +348,10 @@ func (mb *mockBackend) respondPrepare(packetIO *pnet.PacketIO) error { return err } } - if err := packetIO.WriteEOFPacket(mb.status); err != nil { - return err + if mb.capability&mysql.ClientDeprecateEOF == 0 { + if err := packetIO.WriteEOFPacket(mb.status); err != nil { + return err + } } } if mb.columns > 0 { @@ -347,11 +360,13 @@ func (mb *mockBackend) respondPrepare(packetIO *pnet.PacketIO) error { return err } } - if err := packetIO.WriteEOFPacket(mb.status); err != nil { - return err + if mb.capability&mysql.ClientDeprecateEOF == 0 { + if err := packetIO.WriteEOFPacket(mb.status); err != nil { + return err + } } } - return nil + return packetIO.Flush() } func (mb *mockBackend) respondSessionStates(packetIO *pnet.PacketIO) error { diff --git a/pkg/proxy/backend/mock_client_test.go b/pkg/proxy/backend/mock_client_test.go index 2854171e..ad4d7e52 100644 --- a/pkg/proxy/backend/mock_client_test.go +++ b/pkg/proxy/backend/mock_client_test.go @@ -36,6 +36,7 @@ type clientConfig struct { attrs []byte // for cmd cmd byte + dataBytes []byte filePkts int prepStmtID int sql string @@ -50,6 +51,7 @@ func newClientConfig() *clientConfig { authData: mockAuthData, attrs: make([]byte, 0), cmd: mysql.ComQuery, + dataBytes: mockCmdBytes, sql: mockCmdStr, } } @@ -129,7 +131,7 @@ func (mc *mockClient) request(packetIO *pnet.PacketIO) error { case mysql.ComFieldList: return mc.requestFieldList(packetIO) case mysql.ComRefresh, mysql.ComSetOption: - data = append(data, mockCmdByte) + data = append(data, mc.dataBytes...) case mysql.ComProcessKill: data = pnet.DumpUint32(data, uint32(mockCmdInt)) case mysql.ComChangeUser: @@ -138,7 +140,7 @@ func (mc *mockClient) request(packetIO *pnet.PacketIO) error { return mc.requestPrepare(packetIO) case mysql.ComStmtSendLongData: data = pnet.DumpUint32(data, uint32(mc.prepStmtID)) - data = append(data, mockCmdBytes...) + data = append(data, mc.dataBytes...) case mysql.ComStmtExecute: return mc.requestExecute(packetIO) case mysql.ComStmtFetch: @@ -158,7 +160,7 @@ func (mc *mockClient) request(packetIO *pnet.PacketIO) error { } func (mc *mockClient) requestChangeUser(packetIO *pnet.PacketIO) error { - data := pnet.MakeChangeUser(mockUsername, mockDBName, mockAuthData) + data := pnet.MakeChangeUser(mc.username, mc.dbName, mc.authData) if err := packetIO.WritePacket(data, true); err != nil { return err } @@ -189,35 +191,33 @@ func (mc *mockClient) requestPrepare(packetIO *pnet.PacketIO) error { if err != nil { return err } - expectedEOFNum := 0 + expectedPacketNum := 0 if response[0] == mysql.OKHeader { numColumns := binary.LittleEndian.Uint16(response[5:]) - if numColumns > 0 { - expectedEOFNum++ - } numParams := binary.LittleEndian.Uint16(response[7:]) - if numParams > 0 { - expectedEOFNum++ - } - } - for i := 0; i < expectedEOFNum; i++ { - for { - if response, err = packetIO.ReadPacket(); err != nil { - return err + expectedPacketNum = int(numColumns) + int(numParams) + if mc.capability&mysql.ClientDeprecateEOF == 0 { + if numColumns > 0 { + expectedPacketNum++ } - if pnet.IsEOFPacket(response) { - break + if numParams > 0 { + expectedPacketNum++ } } } + for i := 0; i < expectedPacketNum; i++ { + if response, err = packetIO.ReadPacket(); err != nil { + return err + } + } return nil } func (mc *mockClient) requestExecute(packetIO *pnet.PacketIO) error { - data := make([]byte, 0, len(mockCmdBytes)+5) + data := make([]byte, 0, len(mc.dataBytes)+5) data = append(data, mysql.ComStmtExecute) data = pnet.DumpUint32(data, uint32(mc.prepStmtID)) - data = append(data, mockCmdBytes...) + data = append(data, mc.dataBytes...) if err := packetIO.WritePacket(data, true); err != nil { return err } @@ -225,14 +225,15 @@ func (mc *mockClient) requestExecute(packetIO *pnet.PacketIO) error { } func (mc *mockClient) requestFetch(packetIO *pnet.PacketIO) error { - data := make([]byte, 0, len(mockCmdBytes)+5) + data := make([]byte, 0, len(mc.dataBytes)+5) data = append(data, mysql.ComStmtFetch) data = pnet.DumpUint32(data, uint32(mc.prepStmtID)) - data = append(data, mockCmdBytes...) + data = append(data, mc.dataBytes...) if err := packetIO.WritePacket(data, true); err != nil { return err } - return mc.readErrOrUntilEOF(packetIO) + _, err := mc.readUntilResultEnd(packetIO) + return err } func (mc *mockClient) requestFieldList(packetIO *pnet.PacketIO) error { @@ -244,26 +245,30 @@ func (mc *mockClient) requestFieldList(packetIO *pnet.PacketIO) error { if err := packetIO.WritePacket(data, true); err != nil { return err } - return mc.readErrOrUntilEOF(packetIO) + _, err := mc.readUntilResultEnd(packetIO) + return err } -func (mc *mockClient) readErrOrUntilEOF(packetIO *pnet.PacketIO) error { - pkt, err := packetIO.ReadPacket() - if err != nil { - return err - } - if pkt[0] == mysql.ErrHeader || pnet.IsEOFPacket(pkt) { - return nil - } +func (mc *mockClient) readUntilResultEnd(packetIO *pnet.PacketIO) (pkt []byte, err error) { for { - if pkt, err = packetIO.ReadPacket(); err != nil { - return err + pkt, err = packetIO.ReadPacket() + if err != nil { + return } - if pnet.IsEOFPacket(pkt) { - break + if pkt[0] == mysql.ErrHeader { + return + } + if mc.capability&mysql.ClientDeprecateEOF == 0 { + if pnet.IsEOFPacket(pkt) { + break + } + } else { + if pnet.IsResultSetOKPacket(pkt) { + break + } } } - return nil + return } func (mc *mockClient) requestProcessInfo(packetIO *pnet.PacketIO) error { @@ -297,7 +302,7 @@ func (mc *mockClient) readResultSet(packetIO *pnet.PacketIO) error { return nil case mysql.LocalInFileHeader: for i := 0; i < mc.filePkts; i++ { - if err = packetIO.WritePacket(mockCmdBytes, false); err != nil { + if err = packetIO.WritePacket(mc.dataBytes, false); err != nil { return err } } @@ -314,19 +319,29 @@ func (mc *mockClient) readResultSet(packetIO *pnet.PacketIO) error { } default: // read result set - for { - if pkt, err = packetIO.ReadPacket(); err != nil { + if mc.capability&mysql.ClientDeprecateEOF == 0 { + if pkt, err = mc.readUntilResultEnd(packetIO); err != nil { return err } - if pnet.IsEOFPacket(pkt) { + if pkt[0] == mysql.ErrHeader { + return nil + } + serverStatus = binary.LittleEndian.Uint16(pkt[3:]) + if serverStatus&mysql.ServerStatusCursorExists > 0 { break } } - serverStatus = binary.LittleEndian.Uint16(pkt[3:]) - if serverStatus&mysql.ServerStatusCursorExists == 0 { - if err = mc.readErrOrUntilEOF(packetIO); err != nil { - return err - } + if pkt, err = mc.readUntilResultEnd(packetIO); err != nil { + return err + } + if pkt[0] == mysql.ErrHeader { + return nil + } + if mc.capability&mysql.ClientDeprecateEOF == 0 { + serverStatus = binary.LittleEndian.Uint16(pkt[3:]) + } else { + rs := pnet.ParseOKPacket(pkt) + serverStatus = rs.Status } } if serverStatus&mysql.ServerMoreResultsExists == 0 { diff --git a/pkg/proxy/backend/mock_proxy_test.go b/pkg/proxy/backend/mock_proxy_test.go index 39b3f4a4..069f67e7 100644 --- a/pkg/proxy/backend/mock_proxy_test.go +++ b/pkg/proxy/backend/mock_proxy_test.go @@ -24,12 +24,14 @@ import ( type proxyConfig struct { frontendTLSConfig *tls.Config backendTLSConfig *tls.Config + capability uint32 sessionToken string waitRedirect bool } func newProxyConfig() *proxyConfig { return &proxyConfig{ + capability: defaultBackendCapability, sessionToken: mockToken, } } @@ -45,10 +47,12 @@ type mockProxy struct { } func newMockProxy(cfg *proxyConfig) *mockProxy { - return &mockProxy{ + mp := &mockProxy{ proxyConfig: cfg, BackendConnManager: NewBackendConnManager(0), } + mp.cmdProcessor.capability = cfg.capability + return mp } func (mp *mockProxy) authenticateFirstTime(clientIO, backendIO *pnet.PacketIO) error { diff --git a/pkg/proxy/backend/testsuite_test.go b/pkg/proxy/backend/testsuite_test.go index c797d36e..3f56b4bf 100644 --- a/pkg/proxy/backend/testsuite_test.go +++ b/pkg/proxy/backend/testsuite_test.go @@ -33,11 +33,10 @@ import ( // sent from the server and vice versa. const ( - defaultBackendCapability = mysql.ClientLongPassword | mysql.ClientLongFlag | - mysql.ClientConnectWithDB | mysql.ClientProtocol41 | mysql.ClientSSL | - mysql.ClientTransactions | mysql.ClientSecureConnection | mysql.ClientFoundRows | - mysql.ClientMultiStatements | mysql.ClientMultiResults | mysql.ClientLocalFiles | - mysql.ClientConnectAtts | mysql.ClientPluginAuth | mysql.ClientInteractive + defaultBackendCapability = mysql.ClientLongPassword | mysql.ClientLongFlag | mysql.ClientConnectWithDB | + mysql.ClientProtocol41 | mysql.ClientSSL | mysql.ClientTransactions | mysql.ClientSecureConnection | + mysql.ClientFoundRows | mysql.ClientMultiStatements | mysql.ClientMultiResults | mysql.ClientLocalFiles | + mysql.ClientConnectAtts | mysql.ClientPluginAuth | mysql.ClientInteractive | mysql.ClientDeprecateEOF defaultClientCapability = defaultBackendCapability ) @@ -49,7 +48,6 @@ var ( mockToken = strings.Repeat("t", 512) mockCmdStr = "str" mockCmdInt = 100 - mockCmdByte = byte(1) mockCmdBytes = []byte("01234567890123456789") mockSessionStates = "{\"current-db\":\"test_db\"}" ) diff --git a/pkg/proxy/net/mysql.go b/pkg/proxy/net/mysql.go index fc56b07b..4db03600 100644 --- a/pkg/proxy/net/mysql.go +++ b/pkg/proxy/net/mysql.go @@ -231,7 +231,7 @@ func MakeChangeUser(username, db string, authData []byte) []byte { // ParseChangeUser parses the data of COM_CHANGE_USER. func ParseChangeUser(data []byte) (username, db string) { - user, data := ParseNullTermString(data) + user, data := ParseNullTermString(data[1:]) username = string(user) passLen := int(data[0]) data = data[passLen+1:] @@ -271,3 +271,10 @@ func ParseErrorPacket(data []byte) error { func IsEOFPacket(data []byte) bool { return data[0] == mysql.EOFHeader && len(data) <= 5 } + +// IsResultSetOKPacket returns true if it's an OK packet after the result set when CLIENT_DEPRECATE_EOF is enabled. +// A row packet may also begin with 0xfe, so we need to judge it with the packet length. +// See https://mariadb.com/kb/en/result-set-packets/ +func IsResultSetOKPacket(data []byte) bool { + return data[0] == mysql.EOFHeader && len(data) > 5 && len(data) < 0xFFFFFF +} diff --git a/pkg/proxy/net/packetio_mysql.go b/pkg/proxy/net/packetio_mysql.go index f4b3039a..527364b5 100644 --- a/pkg/proxy/net/packetio_mysql.go +++ b/pkg/proxy/net/packetio_mysql.go @@ -106,9 +106,9 @@ func (p *PacketIO) WriteErrPacket(merr *mysql.SQLError) error { } // WriteOKPacket writes an OK packet. It's only for testing. -func (p *PacketIO) WriteOKPacket(status uint16) error { +func (p *PacketIO) WriteOKPacket(status uint16, header byte) error { data := make([]byte, 0, 7) - data = append(data, mysql.OKHeader) + data = append(data, header) data = append(data, 0, 0) // ClientProtocol41 must be enabled. data = DumpUint16(data, status) From 145e96ebc8a15e2b93d1deb14ebb47c8d30774d4 Mon Sep 17 00:00:00 2001 From: djshow832 <873581766@qq.com> Date: Wed, 24 Aug 2022 15:56:54 +0800 Subject: [PATCH 2/2] handle invalid values in set_option --- pkg/proxy/backend/authenticator.go | 2 +- pkg/proxy/backend/backend_conn_mgr.go | 5 ++++- pkg/proxy/backend/cmd_processor_exec.go | 2 +- pkg/proxy/backend/cmd_processor_query.go | 2 +- pkg/proxy/net/mysql.go | 13 ++++++++++++- 5 files changed, 19 insertions(+), 5 deletions(-) diff --git a/pkg/proxy/backend/authenticator.go b/pkg/proxy/backend/authenticator.go index 4ffe005b..6440756f 100644 --- a/pkg/proxy/backend/authenticator.go +++ b/pkg/proxy/backend/authenticator.go @@ -179,7 +179,7 @@ func (auth *Authenticator) readInitialHandshake(backendIO *pnet.PacketIO) (serve if serverPkt, err = backendIO.ReadPacket(); err != nil { return } - if serverPkt[0] == mysql.ErrHeader { + if pnet.IsErrorPacket(serverPkt) { err = pnet.ParseErrorPacket(serverPkt) return } diff --git a/pkg/proxy/backend/backend_conn_mgr.go b/pkg/proxy/backend/backend_conn_mgr.go index 674d2220..87710911 100644 --- a/pkg/proxy/backend/backend_conn_mgr.go +++ b/pkg/proxy/backend/backend_conn_mgr.go @@ -122,13 +122,16 @@ func (mgr *BackendConnManager) ExecuteCmd(ctx context.Context, request []byte, c case mysql.ComQuit: return nil case mysql.ComSetOption: - switch binary.LittleEndian.Uint16(request[1:]) { + val := binary.LittleEndian.Uint16(request[1:]) + switch val { case 0: mgr.authenticator.capability |= mysql.ClientMultiStatements mgr.cmdProcessor.capability |= mysql.ClientMultiStatements case 1: mgr.authenticator.capability &^= mysql.ClientMultiStatements mgr.cmdProcessor.capability &^= mysql.ClientMultiStatements + default: + return errors.Errorf("unrecognized set_option value:%d", val) } case mysql.ComChangeUser: username, db := pnet.ParseChangeUser(request) diff --git a/pkg/proxy/backend/cmd_processor_exec.go b/pkg/proxy/backend/cmd_processor_exec.go index 3c66e398..77321cc2 100644 --- a/pkg/proxy/backend/cmd_processor_exec.go +++ b/pkg/proxy/backend/cmd_processor_exec.go @@ -110,7 +110,7 @@ func (cp *CmdProcessor) forwardUntilResultEnd(clientIO, backendIO *pnet.PacketIO if err != nil { return 0, err } - if response[0] == mysql.ErrHeader { + if pnet.IsErrorPacket(response) { if err := clientIO.Flush(); err != nil { return 0, err } diff --git a/pkg/proxy/backend/cmd_processor_query.go b/pkg/proxy/backend/cmd_processor_query.go index b2899c94..7c11f7ed 100644 --- a/pkg/proxy/backend/cmd_processor_query.go +++ b/pkg/proxy/backend/cmd_processor_query.go @@ -126,7 +126,7 @@ func (cp *CmdProcessor) readResultRows(packetIO *pnet.PacketIO, result *gomysql. } } // An error may occur when the backend writes rows. - if data[0] == mysql.ErrHeader { + if pnet.IsErrorPacket(data) { return cp.handleErrorPacket(data) } result.RowDatas = append(result.RowDatas, data) diff --git a/pkg/proxy/net/mysql.go b/pkg/proxy/net/mysql.go index 4db03600..dd4b09c4 100644 --- a/pkg/proxy/net/mysql.go +++ b/pkg/proxy/net/mysql.go @@ -267,6 +267,11 @@ func ParseErrorPacket(data []byte) error { return e } +// IsOKPacket returns true if it's an OK packet (but not ResultSet OK). +func IsOKPacket(data []byte) bool { + return data[0] == mysql.OKHeader +} + // IsEOFPacket returns true if it's an EOF packet. func IsEOFPacket(data []byte) bool { return data[0] == mysql.EOFHeader && len(data) <= 5 @@ -276,5 +281,11 @@ func IsEOFPacket(data []byte) bool { // A row packet may also begin with 0xfe, so we need to judge it with the packet length. // See https://mariadb.com/kb/en/result-set-packets/ func IsResultSetOKPacket(data []byte) bool { - return data[0] == mysql.EOFHeader && len(data) > 5 && len(data) < 0xFFFFFF + // With CLIENT_PROTOCOL_41 enabled, the least length is 7. + return data[0] == mysql.EOFHeader && len(data) >= 7 && len(data) < 0xFFFFFF +} + +// IsErrorPacket returns true if it's an error packet. +func IsErrorPacket(data []byte) bool { + return data[0] == mysql.ErrHeader }