diff --git a/go.mod b/go.mod index 105e64417..d7405820e 100644 --- a/go.mod +++ b/go.mod @@ -20,7 +20,7 @@ require ( github.com/lib/pq v1.10.9 github.com/libp2p/go-reuseport v0.4.0 github.com/opentracing/opentracing-go v1.2.0 - github.com/pg-sharding/lyx v0.0.0-20240819153240-bbdc782d01c1 + github.com/pg-sharding/lyx v0.0.0-20240823123817-e655173c284c github.com/pkg/errors v0.9.1 github.com/rs/zerolog v1.33.0 github.com/sevlyar/go-daemon v0.1.6 diff --git a/go.sum b/go.sum index 8bc4d86f1..6329f36c4 100644 --- a/go.sum +++ b/go.sum @@ -164,8 +164,8 @@ github.com/opencontainers/image-spec v1.0.2 h1:9yCKha/T5XdGtO0q9Q9a6T5NUCsTn/DrB github.com/opencontainers/image-spec v1.0.2/go.mod h1:BtxoFyWECRxE4U/7sNtV5W15zMzWCbyJoFRP3s7yZA0= github.com/opentracing/opentracing-go v1.2.0 h1:uEJPy/1a5RIPAJ0Ov+OIO8OxWu77jEv+1B0VhjKrZUs= github.com/opentracing/opentracing-go v1.2.0/go.mod h1:GxEUsuufX4nBwe+T+Wl9TAgYrxe9dPLANfrWvHYVTgc= -github.com/pg-sharding/lyx v0.0.0-20240819153240-bbdc782d01c1 h1:AwlQkwnrqRyL8lqZTTAzfQ09niEc+6oFiDvQkMImTPE= -github.com/pg-sharding/lyx v0.0.0-20240819153240-bbdc782d01c1/go.mod h1:2dPBQAhqv/30mhzj2yBXQkXhsGJQ8GhM+oWOfbGua58= +github.com/pg-sharding/lyx v0.0.0-20240823123817-e655173c284c h1:4sXBG7ZDtG/rN2jqgmzsMawfcTKQvTCTTo8iQ7eR6VU= +github.com/pg-sharding/lyx v0.0.0-20240823123817-e655173c284c/go.mod h1:2dPBQAhqv/30mhzj2yBXQkXhsGJQ8GhM+oWOfbGua58= github.com/pkg/diff v0.0.0-20210226163009-20ebb0f2a09e/go.mod h1:pJLUxLENpZxwdsKMEsNbx1VGcRFpLqf3715MtcvvzbA= github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4= diff --git a/router/frontend/frontend_test.go b/router/frontend/frontend_test.go index cdb8daa1a..4b2de149b 100644 --- a/router/frontend/frontend_test.go +++ b/router/frontend/frontend_test.go @@ -329,6 +329,15 @@ func TestFrontendSimpleCopyIn(t *testing.T) { qr := mockqr.NewMockQueryRouter(ctrl) cmngr := mockcmgr.NewMockPoolMgr(ctrl) + 
sh1 := mocksh.NewMockShard(ctrl) + sh1.EXPECT().Name().AnyTimes().Return("sh1") + sh1.EXPECT().SHKey().AnyTimes().Return(kr.ShardKey{Name: "sh1"}) + sh1.EXPECT().ID().AnyTimes().Return(uint(1)) + sh2 := mocksh.NewMockShard(ctrl) + sh2.EXPECT().Name().AnyTimes().Return("sh2") + sh2.EXPECT().SHKey().AnyTimes().Return(kr.ShardKey{Name: "sh2"}) + sh2.EXPECT().ID().AnyTimes().Return(uint(2)) + frrule := &config.FrontendRule{ DB: "db1", Usr: "user1", @@ -337,7 +346,7 @@ func TestFrontendSimpleCopyIn(t *testing.T) { beRule := &config.BackendRule{} srv.EXPECT().Name().AnyTimes().Return("serv1") - srv.EXPECT().Datashards().AnyTimes().Return([]shard.Shard{}) + srv.EXPECT().Datashards().AnyTimes().Return([]shard.Shard{sh1, sh2}) cl.EXPECT().Server().AnyTimes().Return(srv) cl.EXPECT().MaintainParams().AnyTimes().Return(false) @@ -375,22 +384,29 @@ func TestFrontendSimpleCopyIn(t *testing.T) { cmngr.EXPECT().TXEndCB(gomock.Any()).AnyTimes() + tableref := &lyx.RangeVar{ + RelationName: "xx", + } + qr.EXPECT().Route(gomock.Any(), &lyx.Copy{ - TableRef: &lyx.RangeVar{ - RelationName: "xx", - }, - Where: &lyx.AExprEmpty{}, - IsFrom: true, - }, gomock.Any()).Return(routingstate.ShardMatchState{ - Route: &routingstate.DataShardRoute{ - Shkey: kr.ShardKey{ - Name: "sh1", - }, - }, - }, nil).Times(1) + TableRef: tableref, + Where: &lyx.AExprEmpty{}, + IsFrom: true, + }, gomock.Any()).Return(routingstate.MultiMatchState{}, nil).Times(1) + + qr.EXPECT().Route(gomock.Any(), &lyx.Insert{ + TableRef: tableref, + SubSelect: &lyx.ValueClause{Values: []lyx.Node{&lyx.AExprSConst{Value: "1"}}}, + }, cl).Times(4).Return(routingstate.ShardMatchState{Route: &routingstate.DataShardRoute{Shkey: sh1.SHKey()}}, nil) + + qr.EXPECT().DataShardsRoutes().AnyTimes().Return([]*routingstate.DataShardRoute{ + &routingstate.DataShardRoute{Shkey: sh1.SHKey()}, + &routingstate.DataShardRoute{Shkey: sh2.SHKey()}}, + ) route := route.NewRoute(beRule, frrule, map[string]*config.Shard{ "sh1": {}, + "sh2": {}, }) 
cl.EXPECT().Route().AnyTimes().Return(route) @@ -401,12 +417,12 @@ func TestFrontendSimpleCopyIn(t *testing.T) { cl.EXPECT().Receive().Times(1).Return(query, nil) - cl.EXPECT().Receive().Times(4).Return(&pgproto3.CopyData{}, nil) + cl.EXPECT().Receive().Times(4).Return(&pgproto3.CopyData{Data: []byte("1\n")}, nil) cl.EXPECT().Receive().Times(1).Return(&pgproto3.CopyDone{}, nil) srv.EXPECT().Send(query).Times(1).Return(nil) - srv.EXPECT().Send(&pgproto3.CopyData{}).Times(4).Return(nil) + sh1.EXPECT().Send(&pgproto3.CopyData{Data: []byte("1\n")}).Times(4).Return(nil) srv.EXPECT().Send(&pgproto3.CopyDone{}).Times(1).Return(nil) srv.EXPECT().Receive().Times(1).Return(&pgproto3.CopyInResponse{}, nil) diff --git a/router/qrouter/proxy_routing.go b/router/qrouter/proxy_routing.go index b1d630fc7..9a2b0cd9c 100644 --- a/router/qrouter/proxy_routing.go +++ b/router/qrouter/proxy_routing.go @@ -667,21 +667,6 @@ func (qr *ProxyQrouter) deparseShardingMapping( _ = qr.deparseFromNode(stmt.TableRef, meta) - return qr.routeByClause(ctx, clause, meta) - case *lyx.Copy: - if !stmt.IsFrom { - return fmt.Errorf("copy from stdin is not implemented") - } - - _ = qr.deparseFromNode(stmt.TableRef, meta) - - clause := stmt.Where - - if clause == nil { - // will not work - return nil - } - return qr.routeByClause(ctx, clause, meta) } @@ -1001,13 +986,15 @@ func (qr *ProxyQrouter) routeWithRules(ctx context.Context, stmt lyx.Node, sph s return routingstate.RandomMatchState{}, nil } - case *lyx.Delete, *lyx.Update, *lyx.Copy: + case *lyx.Delete, *lyx.Update: // UPDATE and/or DELETE, COPY stmts, which // would be routed with their WHERE clause err := qr.deparseShardingMapping(ctx, stmt, meta) if err != nil { return nil, err } + case *lyx.Copy: + return routingstate.MultiMatchState{}, nil default: spqrlog.Zero.Debug().Interface("statement", stmt).Msg("proxy-routing message to all shards") } diff --git a/router/qrouter/proxy_routing_test.go b/router/qrouter/proxy_routing_test.go index 
53c3a1f2f..7865bf672 100644
--- a/router/qrouter/proxy_routing_test.go
+++ b/router/qrouter/proxy_routing_test.go
@@ -1426,25 +1426,8 @@ func TestCopySingleShard(t *testing.T) {
 	for _, tt := range []tcase{
 		{
 			query: "COPY xx FROM STDIN WHERE i = 1;",
-			exp: routingstate.ShardMatchState{
-				Route: &routingstate.DataShardRoute{
-					Shkey: kr.ShardKey{
-						Name: "sh1",
-					},
-					Matchedkr: &kr.KeyRange{
-						ShardID:      "sh1",
-						ID:           "id1",
-						Distribution: distribution,
-						LowerBound: []interface{}{
-							int64(1),
-						},
-
-						ColumnTypes: []string{qdb.ColumnTypeInteger},
-					},
-				},
-				TargetSessionAttrs: "any",
-			},
-			err: nil,
+			exp:   routingstate.MultiMatchState{},
+			err:   nil,
 		},
 	} {
 		parserRes, err := lyx.Parse(tt.query)
diff --git a/router/relay/relay.go b/router/relay/relay.go
index bf7b9a2f3..396991a3c 100644
--- a/router/relay/relay.go
+++ b/router/relay/relay.go
@@ -4,6 +4,7 @@ import (
 	"context"
 	"fmt"
 	"math/rand"
+	"strings"
 	"time"
 
 	"github.com/pg-sharding/lyx/lyx"
@@ -67,7 +68,7 @@ type RelayStateMgr interface {
 	RelayRunCommand(msg pgproto3.FrontendMessage, waitForResp bool, replyCl bool) error
 
 	ProcQuery(query pgproto3.FrontendMessage, waitForResp bool, replyCl bool) (txstatus.TXStatus, []pgproto3.BackendMessage, bool, error)
-	ProcCopy(query pgproto3.FrontendMessage) error
+	ProcCopy(stmt *lyx.Copy, data *pgproto3.CopyData, expRoute *routingstate.DataShardRoute) error
 
 	ProcCommand(query pgproto3.FrontendMessage, waitForResp bool, replyCl bool) error
 
@@ -656,15 +657,111 @@ func (rst *RelayStateImpl) RelayRunCommand(msg pgproto3.FrontendMessage, waitFor
 }
 
+// ProcCopy routes the rows of a COPY ... FROM STDIN payload and forwards them
+// to the shard that owns them.
+//
+// data is the caller's accumulation buffer in COPY text format. Every complete
+// (newline-terminated) row in it is rewritten as an INSERT-shaped statement and
+// passed to the query router. The first routed row pins the target shard in
+// expRoute; a row routed anywhere else fails the whole COPY, because multishard
+// copy is not supported. The routed prefix of data is sent to the pinned shard
+// and cut off data, so a trailing partial row stays buffered for the next call.
+//
 // TODO : unit tests
-func (rst *RelayStateImpl) ProcCopy(query pgproto3.FrontendMessage) error {
+func (rst *RelayStateImpl) ProcCopy(stmt *lyx.Copy, data *pgproto3.CopyData, expRoute *routingstate.DataShardRoute) error {
+	if stmt == nil || data == nil || expRoute == nil {
+		return fmt.Errorf("copy: statement, data and expected route must not be nil")
+	}
+
 	spqrlog.Zero.Debug().
 		Uint("client", rst.Client().ID()).
-		Type("query-type", query).
 		Msg("client process copy")
-	_ = rst.Client().ReplyDebugNotice(fmt.Sprintf("executing your query %v", query)) // TODO perfomance issue
+	_ = rst.Client().ReplyDebugNotice(fmt.Sprintf("executing your query %v", data)) // TODO performance issue
 	rst.Client().RLock()
 	defer rst.Client().RUnlock()
-	return rst.Client().Server().Send(query)
+
+	delimiter, err := copyDelimiter(stmt)
+	if err != nil {
+		return err
+	}
+
+	// Parse data
+	// and decide where to route
+	prevDelimiter := 0
+	prevLine := 0
+	valueClause := &lyx.ValueClause{}
+	for i, b := range data.Data {
+		// "<=", not "<": the end-of-data marker may be the last two bytes
+		// of the buffer.
+		if i+2 <= len(data.Data) && string(data.Data[i:i+2]) == "\\." {
+			prevLine = len(data.Data)
+			break
+		}
+		if b == '\n' || b == delimiter {
+			valueClause.Values = append(valueClause.Values, &lyx.AExprSConst{Value: string(data.Data[prevDelimiter:i])})
+			prevDelimiter = i + 1
+		}
+		if b != '\n' {
+			continue
+		}
+
+		// check where this tuple should go
+		r, err := rst.QueryRouter().Route(context.TODO(), &lyx.Insert{TableRef: stmt.TableRef, Columns: stmt.Columns, SubSelect: valueClause}, rst.Cl)
+		if err != nil {
+			return err
+		}
+
+		smt, ok := r.(routingstate.ShardMatchState)
+		if !ok || smt.Route == nil {
+			return fmt.Errorf("multishard copy is not supported")
+		}
+
+		if expRoute.Shkey.Name == "" {
+			*expRoute = *smt.Route
+		}
+		if smt.Route.Shkey.Name != expRoute.Shkey.Name {
+			return fmt.Errorf("multishard copy is not supported")
+		}
+
+		valueClause = &lyx.ValueClause{}
+		prevLine = i + 1
+	}
+
+	if prevLine == 0 || expRoute.Shkey.Name == "" {
+		// no complete row to forward yet, or no row has been routed
+		return nil
+	}
+
+	for _, sh := range rst.Client().Server().Datashards() {
+		if sh.Name() != expRoute.Shkey.Name {
+			continue
+		}
+		if err := sh.Send(&pgproto3.CopyData{Data: data.Data[:prevLine]}); err != nil {
+			return err
+		}
+		data.Data = data.Data[prevLine:]
+		return nil
+	}
+
+	return fmt.Errorf("copy: shard %q is not among the active datashards", expRoute.Shkey.Name)
 }
+
+// copyDelimiter returns the field delimiter of a text-format COPY statement:
+// the first byte of its DELIMITER option, or a tab when the option is not set.
+// A malformed option yields an error instead of a panic.
+func copyDelimiter(stmt *lyx.Copy) (byte, error) {
+	delimiter := byte('\t')
+	for _, opt := range stmt.Options {
+		o, ok := opt.(*lyx.Option)
+		if !ok || !strings.EqualFold(o.Name, "delimiter") {
+			continue
+		}
+		arg, ok := o.Arg.(*lyx.AExprSConst)
+		if !ok || len(arg.Value) == 0 {
+			return 0, fmt.Errorf("copy delimiter must be a non-empty string constant")
+		}
+		delimiter = arg.Value[0]
+	}
+	return delimiter, nil
+}
 
 // TODO : unit tests
@@ -680,16 +777,16 @@ func (rst *RelayStateImpl) ProcCopyComplete(query
*pgproto3.FrontendMessage) err } for { - if msg, err := rst.Client().Server().Receive(); err != nil { + msg, err := rst.Client().Server().Receive() + if err != nil { return err - } else { - switch msg.(type) { - case *pgproto3.CommandComplete, *pgproto3.ErrorResponse: - return rst.Client().Send(msg) - default: - if err := rst.Client().Send(msg); err != nil { - return err - } + } + switch msg.(type) { + case *pgproto3.CommandComplete, *pgproto3.ErrorResponse: + return rst.Client().Send(msg) + default: + if err := rst.Client().Send(msg); err != nil { + return err } } } @@ -748,16 +806,21 @@ func (rst *RelayStateImpl) ProcQuery(query pgproto3.FrontendMessage, waitForResp return txstatus.TXERR, nil, false, err } + q := rst.qp.Stmt().(*lyx.Copy) + if err := func() error { + msg := &pgproto3.CopyData{Data: make([]byte, 0)} + route := &routingstate.DataShardRoute{} for { cpMsg, err := rst.Client().Receive() if err != nil { return err } - switch cpMsg.(type) { + switch newMsg := cpMsg.(type) { case *pgproto3.CopyData: - if err := rst.ProcCopy(cpMsg); err != nil { + msg.Data = append(msg.Data, newMsg.Data...) 
+ if err = rst.ProcCopy(q, msg, route); err != nil { return err } case *pgproto3.CopyDone, *pgproto3.CopyFail: diff --git a/router/server/multishard.go b/router/server/multishard.go index 37e952df9..47dc69ab8 100644 --- a/router/server/multishard.go +++ b/router/server/multishard.go @@ -31,7 +31,8 @@ const ( RunningState ServerErrorState CommandCompleteState - CopyState + CopyOutState + CopyInState ) type MultiShardServer struct { @@ -178,7 +179,8 @@ func (m *MultiShardServer) Receive() (pgproto3.BackendMessage, error) { var saveRd *pgproto3.RowDescription = nil var saveCC *pgproto3.CommandComplete = nil var saveRFQ *pgproto3.ReadyForQuery = nil - /* Step one: ensure all shard backend are stared */ + var saveCIn *pgproto3.CopyInResponse = nil + /* Step one: ensure all shard backend are started */ for i := range m.activeShards { for { // all shards should be in rfq state @@ -200,12 +202,19 @@ func (m *MultiShardServer) Receive() (pgproto3.BackendMessage, error) { switch retMsg := msg.(type) { case *pgproto3.CopyOutResponse: - if m.multistate != InitialState && m.multistate != CopyState { + if m.multistate != InitialState && m.multistate != CopyOutState { return nil, MultiShardSyncBroken } m.states[i] = ShardCopyState - m.multistate = CopyState + m.multistate = CopyOutState m.copyBuf = append(m.copyBuf, retMsg) + case *pgproto3.CopyInResponse: + if m.multistate != InitialState && m.multistate != CopyInState { + return nil, MultiShardSyncBroken + } + m.states[i] = ShardCopyState + m.multistate = CopyInState + saveCIn = retMsg case *pgproto3.CommandComplete: m.states[i] = ShardCCState saveCC = retMsg // @@ -257,7 +266,7 @@ func (m *MultiShardServer) Receive() (pgproto3.BackendMessage, error) { m.multistate = InitialState return saveRFQ, nil } - if m.multistate == CopyState { + if m.multistate == CopyOutState { n := len(m.copyBuf) var currMsg *pgproto3.CopyOutResponse m.copyBuf, currMsg = m.copyBuf[n-2:], m.copyBuf[n-1] @@ -267,10 +276,14 @@ func (m *MultiShardServer) 
Receive() (pgproto3.BackendMessage, error) { Msg("miltishard server: flush copy buff") return currMsg, nil } + if m.multistate == CopyInState { + m.multistate = RunningState + return saveCIn, nil + } m.multistate = RunningState return saveRd, nil - case CopyState: + case CopyOutState: if len(m.copyBuf) > 0 { spqrlog.Zero.Debug().Msg("miltishard server: flush copy buff") n := len(m.copyBuf) @@ -302,7 +315,7 @@ func (m *MultiShardServer) Receive() (pgproto3.BackendMessage, error) { spqrlog.Zero.Info(). Uint("shard", m.activeShards[i].ID()). Type("message-type", msg). - Msg("multishard server: recived message from shard") + Msg("multishard server: received message from shard") switch msg.(type) { case *pgproto3.CommandComplete: @@ -322,6 +335,10 @@ func (m *MultiShardServer) Receive() (pgproto3.BackendMessage, error) { return &pgproto3.CommandComplete{ CommandTag: []byte{}, // XXX : fix this }, nil + case CopyInState: + return &pgproto3.CommandComplete{ + CommandTag: []byte{}, + }, nil case RunningState: /* Step two: fetch all datarow ms gs */ for i := range m.activeShards { diff --git a/test/regress/tests/router/expected/copy_routing.out b/test/regress/tests/router/expected/copy_routing.out index 89e5df1b7..114a33580 100644 --- a/test/regress/tests/router/expected/copy_routing.out +++ b/test/regress/tests/router/expected/copy_routing.out @@ -34,8 +34,8 @@ ALTER DISTRIBUTION ds1 ATTACH RELATION copy_test DISTRIBUTION KEY id; \c regress CREATE TABLE copy_test (id int); NOTICE: send query to shard(s) : sh1,sh2 -COPY copy_test FROM STDIN WHERE id <= 10; -NOTICE: send query to shard(s) : sh1 +COPY copy_test(id) FROM STDIN WHERE id <= 10; +NOTICE: send query to shard(s) : sh1,sh2 SELECT * FROM copy_test WHERE id <= 10; NOTICE: send query to shard(s) : sh1 id @@ -47,10 +47,24 @@ NOTICE: send query to shard(s) : sh1 5 (5 rows) -COPY copy_test FROM STDIN WHERE id <= 30; -NOTICE: send query to shard(s) : sh2 -SELECT * FROM copy_test WHERE id <= 30 ORDER BY copy_test; -NOTICE: 
send query to shard(s) : sh2 +COPY copy_test(id) FROM STDIN; +NOTICE: send query to shard(s) : sh1,sh2 +ERROR: client processing error: multishard copy is not supported, tx status IDLE +SELECT * FROM copy_test; +NOTICE: send query to shard(s) : sh1,sh2 + id +---- + 1 + 2 + 3 + 4 + 5 +(5 rows) + +COPY copy_test(id) FROM STDIN; +NOTICE: send query to shard(s) : sh1,sh2 +SELECT * FROM copy_test; +NOTICE: send query to shard(s) : sh1,sh2 id ---- 1 @@ -58,10 +72,12 @@ NOTICE: send query to shard(s) : sh2 3 4 5 - 12 - 22 - 23 -(8 rows) + 41 + 42 + 43 + 44 + 45 +(10 rows) DROP TABLE copy_test; NOTICE: send query to shard(s) : sh1,sh2 diff --git a/test/regress/tests/router/sql/copy_routing.sql b/test/regress/tests/router/sql/copy_routing.sql index 8e4010484..8b8ef16dc 100644 --- a/test/regress/tests/router/sql/copy_routing.sql +++ b/test/regress/tests/router/sql/copy_routing.sql @@ -7,20 +7,17 @@ ALTER DISTRIBUTION ds1 ATTACH RELATION copy_test DISTRIBUTION KEY id; \c regress CREATE TABLE copy_test (id int); -COPY copy_test FROM STDIN WHERE id <= 10; +COPY copy_test(id) FROM STDIN WHERE id <= 10; 1 2 3 4 5 -12 -3434 -43 \. SELECT * FROM copy_test WHERE id <= 10; -COPY copy_test FROM STDIN WHERE id <= 30; +COPY copy_test(id) FROM STDIN; 1 2 3 @@ -35,7 +32,17 @@ COPY copy_test FROM STDIN WHERE id <= 30; 43 \. -SELECT * FROM copy_test WHERE id <= 30 ORDER BY copy_test; +SELECT * FROM copy_test; + +COPY copy_test(id) FROM STDIN; +41 +42 +43 +44 +45 +\. + +SELECT * FROM copy_test; DROP TABLE copy_test;