diff --git a/go/test/endtoend/vtgate/queries/aggregation/aggregation_test.go b/go/test/endtoend/vtgate/queries/aggregation/aggregation_test.go index d206f58e17c..731cbde651d 100644 --- a/go/test/endtoend/vtgate/queries/aggregation/aggregation_test.go +++ b/go/test/endtoend/vtgate/queries/aggregation/aggregation_test.go @@ -107,6 +107,32 @@ func TestAggregateTypes(t *testing.T) { }) } +func TestTraceAggregateTypes(t *testing.T) { + mcmp, closer := start(t) + defer closer() + test := func(q string) { + res := utils.Exec(t, mcmp.VtConn, "vexplain trace "+q) + fmt.Printf("Query: %s\n", q) + fmt.Printf("Result: %s\n", res.Rows[0][0].ToString()) + } + mcmp.Exec("insert into aggr_test(id, val1, val2) values(1,'a',1), (2,'A',1), (3,'b',1), (4,'c',3), (5,'c',4)") + mcmp.Exec("insert into aggr_test(id, val1, val2) values(6,'d',null), (7,'e',null), (8,'E',1)") + test("select val1, count(distinct val2), count(*) from aggr_test group by val1") + test("select val1, count(distinct val2), count(*) from aggr_test group by val1") + test("select val1, sum(distinct val2), sum(val2) from aggr_test group by val1") + test("select val1, count(distinct val2) k, count(*) from aggr_test group by val1 order by k desc, val1") + test("select val1, count(distinct val2) k, count(*) from aggr_test group by val1 order by k desc, val1 limit 4") + + test("select ascii(val1) as a, count(*) from aggr_test group by a") + test("select ascii(val1) as a, count(*) from aggr_test group by a order by a") + test("select ascii(val1) as a, count(*) from aggr_test group by a order by 2, a") + + test("select val1 as a, count(*) from aggr_test group by a") + test("select val1 as a, count(*) from aggr_test group by a order by a") + test("select val1 as a, count(*) from aggr_test group by a order by 2, a") + test("select sum(val1) from aggr_test") +} + func TestGroupBy(t *testing.T) { mcmp, closer := start(t) defer closer() diff --git a/go/test/endtoend/vtgate/queries/tpch/tpch_test.go b/go/test/endtoend/vtgate/queries/tpch/tpch_test.go index c4bf71cafa1..fd6136d3dce 100644 --- a/go/test/endtoend/vtgate/queries/tpch/tpch_test.go +++ b/go/test/endtoend/vtgate/queries/tpch/tpch_test.go @@ -17,7 +17,11 @@ limitations under the License. package union import ( + "fmt" + "golang.org/x/exp/rand" + "strings" "testing" + "time" "github.com/stretchr/testify/require" @@ -25,7 +29,7 @@ import ( "vitess.io/vitess/go/test/endtoend/utils" ) -func start(t *testing.T) (utils.MySQLCompare, func()) { +func start(t testing.TB) (utils.MySQLCompare, func()) { mcmp, err := utils.NewMySQLCompare(t, vtParams, mysqlParams) require.NoError(t, err) @@ -40,6 +44,33 @@ func start(t *testing.T) (utils.MySQLCompare, func()) { deleteAll() + err = utils.WaitForColumn(t, clusterInstance.VtgateProcess, keyspaceName, "region", `R_COMMENT`) + require.NoError(t, err) + + // Set the size parameter here. Increase for more data. + size := 100 + + // Seed the random number generator + rand.Seed(12345) + + // Generate dynamic data + regions := generateRegions() + nations := generateNations(size, regions) + suppliers := generateSuppliers(size, nations) + parts := generateParts(size) + customers := generateCustomers(size, nations) + orders := generateOrders(size, customers) + lineitems := generateLineItems(orders, parts, suppliers) + + // Execute inserts + mcmp.Exec(buildInsertQuery("region", regions)) + mcmp.Exec(buildInsertQuery("nation", nations)) + mcmp.Exec(buildInsertQuery("supplier", suppliers)) + mcmp.Exec(buildInsertQuery("part", parts)) + mcmp.Exec(buildInsertQuery("customer", customers)) + mcmp.Exec(buildInsertQuery("orders", orders)) + mcmp.Exec(buildInsertQuery("lineitem", lineitems)) + return mcmp, func() { deleteAll() mcmp.Close() @@ -234,3 +265,226 @@ from (select l.l_extendedprice * o.o_totalprice }) } } + +func BenchmarkQuery(b *testing.B) { + mcmp, closer := start(b) + defer closer() + + for i := 0; i < b.N; i++ { + _ = utils.Exec(b, mcmp.VtConn, "vexplain trace "+q) + } +} + +const q = `SELECT + o.o_orderpriority, + EXTRACT(YEAR FROM o.o_orderdate) AS order_year, + COUNT(DISTINCT o.o_orderkey) AS order_count, + SUM(l.l_extendedprice * (1 - l.l_discount)) AS total_revenue +FROM + orders o +JOIN + lineitem l ON o.o_orderkey > l.l_orderkey +WHERE + o.o_orderdate BETWEEN '1995-01-01' AND '1996-12-31' +GROUP BY + o.o_orderpriority, + EXTRACT(YEAR FROM o.o_orderdate) +ORDER BY + o.o_orderpriority, + order_year` + +func TestVexplain(t *testing.T) { + mcmp, closer := start(t) + defer closer() + err := utils.WaitForColumn(t, clusterInstance.VtgateProcess, keyspaceName, "region", `R_COMMENT`) + require.NoError(t, err) + + res := utils.Exec(t, mcmp.VtConn, "vexplain trace "+q) + fmt.Printf("Query: %s\n", q) + fmt.Printf("Result: %s\n", res.Rows[0][0].ToString()) +} + +func generateRegions() [][]interface{} { + regions := [][]interface{}{ + {1, "AMERICA", "New World"}, + {2, "ASIA", "Eastern Asia"}, + {3, "EUROPE", "Old World"}, + {4, "AFRICA", "Dark Continent"}, + {5, "AUSTRALIA", "Down Under"}, + } + return regions +} + +func generateNations(size int, regions [][]interface{}) [][]interface{} { + var nations [][]interface{} + for i := 0; i < size/5; i++ { + for _, region := range regions { + nationKey := len(nations) + 1 + regionKey := region[0].(int) + name := fmt.Sprintf("Nation_%d_%d", regionKey, i) + if regionKey == 1 && i == 0 { + name = "BRAZIL" + } + nations = append(nations, []interface{}{nationKey, name, regionKey, fmt.Sprintf("Comment for %s", name)}) + } + } + return nations +} + +func generateSuppliers(size int, nations [][]interface{}) [][]interface{} { + var suppliers [][]interface{} + for i := 0; i < size; i++ { + nation := nations[rand.Intn(len(nations))] + suppliers = append(suppliers, []interface{}{ + i + 1, + fmt.Sprintf("Supplier_%d", i+1), + fmt.Sprintf("Address_%d", i+1), + nation[0], + fmt.Sprintf("%d-123-4567", rand.Intn(100)), + float64(rand.Intn(10000)) + rand.Float64(), + fmt.Sprintf("Comment for Supplier_%d", i+1), + }) + } + return suppliers +} + +func generateParts(size int) [][]interface{} { + var parts [][]interface{} + types := []string{"ECONOMY ANODIZED STEEL", "LARGE BRUSHED BRASS", "STANDARD POLISHED COPPER", "SMALL PLATED STEEL", "MEDIUM BURNISHED TIN"} + for i := 0; i < size; i++ { + parts = append(parts, []interface{}{ + i + 1, + fmt.Sprintf("Part_%d", i+1), + fmt.Sprintf("Manufacturer_%d", rand.Intn(5)+1), + fmt.Sprintf("Brand_%d", rand.Intn(5)+1), + types[rand.Intn(len(types))], + rand.Intn(50) + 1, + fmt.Sprintf("%s BOX", []string{"SM", "LG", "MED", "JUMBO", "WRAP"}[rand.Intn(5)]), + float64(rand.Intn(1000)) + rand.Float64(), + fmt.Sprintf("Comment for Part_%d", i+1), + }) + } + return parts +} + +func generateCustomers(size int, nations [][]interface{}) [][]interface{} { + var customers [][]interface{} + for i := 0; i < size; i++ { + nation := nations[rand.Intn(len(nations))] + customers = append(customers, []interface{}{ + i + 1, + fmt.Sprintf("Customer_%d", i+1), + fmt.Sprintf("Address_%d", i+1), + nation[0], + fmt.Sprintf("%d-987-6543", rand.Intn(100)), + float64(rand.Intn(10000)) + rand.Float64(), + []string{"AUTOMOBILE", "BUILDING", "FURNITURE", "MACHINERY", "HOUSEHOLD"}[rand.Intn(5)], + fmt.Sprintf("Comment for Customer_%d", i+1), + }) + } + return customers +} + +func generateOrders(size int, customers [][]interface{}) [][]interface{} { + var orders [][]interface{} + startDate := time.Date(1995, 1, 1, 0, 0, 0, 0, time.UTC) + endDate := time.Date(1996, 12, 31, 0, 0, 0, 0, time.UTC) + for i := 0; i < size*10; i++ { + customer := customers[rand.Intn(len(customers))] + orderDate := startDate.Add(time.Duration(rand.Int63n(int64(endDate.Sub(startDate))))) + orders = append(orders, []interface{}{ + i + 1, + customer[0], + []string{"O", "F", "P"}[rand.Intn(3)], + float64(rand.Intn(100000)) + rand.Float64(), + orderDate.Format("2006-01-02"), + fmt.Sprintf("%d-URGENT", rand.Intn(5)+1), + fmt.Sprintf("Clerk#%05d", rand.Intn(1000)), + rand.Intn(5), + fmt.Sprintf("Comment for Order_%d", i+1), + }) + } + return orders +} + +func generateLineItems(orders [][]interface{}, parts [][]interface{}, suppliers [][]interface{}) [][]interface{} { + var lineItems [][]interface{} + for _, order := range orders { + for j := 0; j < rand.Intn(7)+1; j++ { + part := parts[rand.Intn(len(parts))] + supplier := suppliers[rand.Intn(len(suppliers))] + orderDate, _ := time.Parse("2006-01-02", order[4].(string)) + shipDate := orderDate.Add(time.Duration(rand.Intn(30)) * 24 * time.Hour) + commitDate := orderDate.Add(time.Duration(rand.Intn(30)) * 24 * time.Hour) + receiptDate := shipDate.Add(time.Duration(rand.Intn(30)) * 24 * time.Hour) + lineItems = append(lineItems, []interface{}{ + order[0], + part[0], + supplier[0], + j + 1, + rand.Intn(50) + 1, + float64(rand.Intn(100000)) + rand.Float64(), + rand.Float64(), + rand.Float64(), + []string{"N", "R", "A"}[rand.Intn(3)], + []string{"O", "F"}[rand.Intn(2)], + shipDate.Format("2006-01-02"), + commitDate.Format("2006-01-02"), + receiptDate.Format("2006-01-02"), + []string{"DELIVER IN PERSON", "COLLECT COD", "NONE", "TAKE BACK RETURN"}[rand.Intn(4)], + []string{"TRUCK", "MAIL", "RAIL", "AIR", "SHIP"}[rand.Intn(5)], + fmt.Sprintf("Comment for Lineitem_%d_%d", order[0], j+1), + }) + } + } + return lineItems +} + +func buildInsertQuery(tableName string, data [][]interface{}) string { + if len(data) == 0 { + return "" + } + columns := getColumns(tableName) + valueStrings := make([]string, 0, len(data)) + for _, row := range data { + valueStrings = append(valueStrings, formatRow(row)) + } + query := fmt.Sprintf("INSERT INTO %s (%s) VALUES %s", tableName, strings.Join(columns, ", "), strings.Join(valueStrings, ",\n")) + return query +} + +func getColumns(tableName string) []string { + switch tableName { + case "region": + return []string{"R_REGIONKEY", "R_NAME", "R_COMMENT"} + case "nation": + return []string{"N_NATIONKEY", "N_NAME", "N_REGIONKEY", "N_COMMENT"} + case "supplier": + return []string{"S_SUPPKEY", "S_NAME", "S_ADDRESS", "S_NATIONKEY", "S_PHONE", "S_ACCTBAL", "S_COMMENT"} + case "part": + return []string{"P_PARTKEY", "P_NAME", "P_MFGR", "P_BRAND", "P_TYPE", "P_SIZE", "P_CONTAINER", "P_RETAILPRICE", "P_COMMENT"} + case "customer": + return []string{"C_CUSTKEY", "C_NAME", "C_ADDRESS", "C_NATIONKEY", "C_PHONE", "C_ACCTBAL", "C_MKTSEGMENT", "C_COMMENT"} + case "orders": + return []string{"O_ORDERKEY", "O_CUSTKEY", "O_ORDERSTATUS", "O_TOTALPRICE", "O_ORDERDATE", "O_ORDERPRIORITY", "O_CLERK", "O_SHIPPRIORITY", "O_COMMENT"} + case "lineitem": + return []string{"L_ORDERKEY", "L_PARTKEY", "L_SUPPKEY", "L_LINENUMBER", "L_QUANTITY", "L_EXTENDEDPRICE", "L_DISCOUNT", "L_TAX", "L_RETURNFLAG", "L_LINESTATUS", "L_SHIPDATE", "L_COMMITDATE", "L_RECEIPTDATE", "L_SHIPINSTRUCT", "L_SHIPMODE", "L_COMMENT"} + default: + return []string{} + } +} + +func formatRow(row []interface{}) string { + values := make([]string, len(row)) + for i, v := range row { + switch v := v.(type) { + case string: + values[i] = fmt.Sprintf("'%s'", strings.Replace(v, "'", "''", -1)) + case float64: + values[i] = fmt.Sprintf("%.2f", v) + default: + values[i] = fmt.Sprintf("%v", v) + } + } + return "(" + strings.Join(values, ", ") + ")" +} diff --git a/go/vt/vtgate/engine/plan.go b/go/vt/vtgate/engine/plan.go index 769c69aaa06..9ea9f07655c 100644 --- a/go/vt/vtgate/engine/plan.go +++ b/go/vt/vtgate/engine/plan.go @@ -72,7 +72,7 @@ func (p *Plan) Stats() (execCount uint64, execTime time.Duration, shardQueries, func (p *Plan) MarshalJSON() ([]byte, error) { var instructions *PrimitiveDescription if p.Instructions != nil { - description := PrimitiveToPlanDescription(p.Instructions) + description := PrimitiveToPlanDescription(p.Instructions, nil) instructions = &description } diff --git a/go/vt/vtgate/engine/plan_description.go b/go/vt/vtgate/engine/plan_description.go index 9edeae0453a..863a93b0608 100644 --- a/go/vt/vtgate/engine/plan_description.go +++ b/go/vt/vtgate/engine/plan_description.go @@ -45,7 +45,6 @@ type PrimitiveDescription struct { TargetTabletType topodatapb.TabletType Other map[string]any - ID PrimitiveID InputName string Inputs []PrimitiveDescription @@ -97,7 +96,11 @@ func (pd PrimitiveDescription) MarshalJSON() ([]byte, error) { if err := marshalAdd(prepend, buf, "NoOfCalls", len(pd.Stats)); err != nil { return nil, err } - if err := marshalAdd(prepend, buf, "Rows", pd.Stats); err != nil { + + if err := marshalAdd(prepend, buf, "AvgRowSize", average(pd.Stats)); err != nil { + return nil, err + } + if err := marshalAdd(prepend, buf, "MedianRowSize", median(pd.Stats)); err != nil { return nil, err } } @@ -117,6 +120,28 @@ func (pd PrimitiveDescription) MarshalJSON() ([]byte, error) { return buf.Bytes(), nil } +func average(nums []int) float64 { + total := 0 + for _, num := range nums { + total += num + } + return float64(total) / float64(len(nums)) +} + +func median(nums []int) float64 { + sortedNums := make([]int, len(nums)) + copy(sortedNums, nums) + sort.Ints(sortedNums) + + n := len(sortedNums) + if n%2 == 0 { + mid1 := sortedNums[n/2-1] + mid2 := sortedNums[n/2] + return float64(mid1+mid2) / 2.0 + } + return float64(sortedNums[n/2]) +} + func (pd PrimitiveDescription) addToGraph(g *graphviz.Graph) (*graphviz.Node, error) { var nodes []*graphviz.Node for _, input := range pd.Inputs { @@ -157,7 +182,7 @@ func (pd PrimitiveDescription) addToGraph(g *graphviz.Graph) (*graphviz.Node, er func GraphViz(p Primitive) (*graphviz.Graph, error) { g := graphviz.New() - description := PrimitiveToPlanDescription(p) + description := PrimitiveToPlanDescription(p, nil) _, err := description.addToGraph(g) if err != nil { return nil, err @@ -193,15 +218,16 @@ func marshalAdd(prepend string, buf *bytes.Buffer, name string, obj any) error { } // PrimitiveToPlanDescription transforms a primitive tree into a corresponding PlanDescription tree -func PrimitiveToPlanDescription(in Primitive) PrimitiveDescription { +// If stats is not nil, it will be used to populate the stats field of the PlanDescription +func PrimitiveToPlanDescription(in Primitive, stats map[int]RowsReceived) PrimitiveDescription { this := in.description() - if id := in.GetID(); id > 0 { - this.ID = id + if id := in.GetID(); stats != nil && id > 0 { + this.Stats = stats[int(id)] } inputs, infos := in.Inputs() for idx, input := range inputs { - pd := PrimitiveToPlanDescription(input) + pd := PrimitiveToPlanDescription(input, stats) if infos != nil { for k, v := range infos[idx] { if k == inputName { diff --git a/go/vt/vtgate/engine/plan_description_test.go b/go/vt/vtgate/engine/plan_description_test.go index dfed7d7f675..9f20e37976a 100644 --- a/go/vt/vtgate/engine/plan_description_test.go +++ b/go/vt/vtgate/engine/plan_description_test.go @@ -31,7 +31,7 @@ import ( func TestCreateRoutePlanDescription(t *testing.T) { route := createRoute() - planDescription := PrimitiveToPlanDescription(route) + planDescription := PrimitiveToPlanDescription(route, nil) expected := PrimitiveDescription{ OperatorType: "Route", @@ -76,7 +76,7 @@ func TestPlanDescriptionWithInputs(t *testing.T) { Input: route, } - planDescription := PrimitiveToPlanDescription(limit) + planDescription := PrimitiveToPlanDescription(limit, nil) expected := PrimitiveDescription{ OperatorType: "Limit", diff --git a/go/vt/vtgate/engine/trace.go b/go/vt/vtgate/engine/trace.go index ec5249c07da..a30407fb3b1 100644 --- a/go/vt/vtgate/engine/trace.go +++ b/go/vt/vtgate/engine/trace.go @@ -64,16 +64,9 @@ func (t *Trace) NeedsTransaction() bool { return t.Inner.NeedsTransaction() } -func preWalk(desc *PrimitiveDescription, f func(*PrimitiveDescription)) { - f(desc) - for _, input := range desc.Inputs { - preWalk(&input, f) - } -} - func (t *Trace) TryExecute(ctx context.Context, vcursor VCursor, bindVars map[string]*querypb.BindVariable, wantfields bool) (*sqltypes.Result, error) { getOpStats := vcursor.StartPrimitiveTrace() - _, err := t.Inner.TryExecute(ctx, vcursor, bindVars, wantfields) + _, err := vcursor.ExecutePrimitive(ctx, t.Inner, bindVars, wantfields) if err != nil { return nil, err } @@ -84,7 +77,7 @@ func (t *Trace) TryExecute(ctx context.Context, vcursor VCursor, bindVars map[st func (t *Trace) TryStreamExecute(ctx context.Context, vcursor VCursor, bindVars map[string]*querypb.BindVariable, wantfields bool, callback func(*sqltypes.Result) error) error { getOpsStats := vcursor.StartPrimitiveTrace() noop := func(result *sqltypes.Result) error { return nil } - err := t.Inner.TryStreamExecute(ctx, vcursor, bindVars, wantfields, noop) + err := vcursor.StreamExecutePrimitive(ctx, t.Inner, bindVars, wantfields, noop) if err != nil { return err } @@ -98,17 +91,7 @@ func (t *Trace) TryStreamExecute(ctx context.Context, vcursor VCursor, bindVars } func (t *Trace) getExplainTraceOutput(getOpStats func() map[int]RowsReceived) (*sqltypes.Result, error) { - description := PrimitiveToPlanDescription(t.Inner) - statsMap := getOpStats() - - // let's add the stats to the description - preWalk(&description, func(desc *PrimitiveDescription) { - stats, found := statsMap[int(desc.ID)] - if !found { - return - } - desc.Stats = stats - }) + description := PrimitiveToPlanDescription(t.Inner, getOpStats()) output, err := json.MarshalIndent(description, "", "\t") if err != nil { diff --git a/go/vt/vtgate/executor_vexplain_test.go b/go/vt/vtgate/executor_vexplain_test.go index a8691142b44..642656271c9 100644 --- a/go/vt/vtgate/executor_vexplain_test.go +++ b/go/vt/vtgate/executor_vexplain_test.go @@ -74,7 +74,7 @@ func TestSimpleVexplainTrace(t *testing.T) { executor := createExecutor(ctx, serv, cell, resolver) defer executor.Close() - query := "vexplain trace select col1, col2 from music order by col2 desc" + query := "vexplain trace select count(*), col2 from music group by col2" session := &vtgatepb.Session{ TargetString: "@primary", } @@ -82,7 +82,7 @@ func TestSimpleVexplainTrace(t *testing.T) { require.NoError(t, err) wantQueries := []*querypb.BoundQuery{{ - Sql: "select col1, col2, weight_string(col2) from music order by music.col2 desc", + Sql: "select count(*), col2, weight_string(col2) from music group by col2, weight_string(col2) order by col2 asc", BindVariables: map[string]*querypb.BindVariable{}, }} for _, conn := range conns { @@ -90,21 +90,34 @@ func TestSimpleVexplainTrace(t *testing.T) { } expectedRowString := `{ - "OperatorType": "Route", - "Variant": "Scatter", - "Keyspace": { - "Name": "TestExecutor", - "Sharded": true - }, + "OperatorType": "Aggregate", + "Variant": "Ordered", "NoOfCalls": 1, "Rows": [ - 16 + 4 ], - "FieldQuery": "select col1, col2, weight_string(col2) from music where 1 != 1", - "OrderBy": "(1|2) DESC", - "Query": "select col1, col2, weight_string(col2) from music order by music.col2 desc", + "Aggregates": "sum_count_star(0) AS count(*)", + "GroupBy": "(1|2)", "ResultColumns": 2, - "Table": "music" + "Inputs": [ + { + "OperatorType": "Route", + "Variant": "Scatter", + "Keyspace": { + "Name": "TestExecutor", + "Sharded": true + }, + "NoOfCalls": 2, + "Rows": [ + 16, + 16 + ], + "FieldQuery": "select count(*), col2, weight_string(col2) from music where 1 != 1 group by col2, weight_string(col2)", + "OrderBy": "(1|2) ASC", + "Query": "select count(*), col2, weight_string(col2) from music group by col2, weight_string(col2) order by col2 asc", + "Table": "music" + } + ] }` gotRowString := gotResult.Rows[0][0].ToString() diff --git a/go/vt/vtgate/planbuilder/vexplain.go b/go/vt/vtgate/planbuilder/vexplain.go index c27002a29e9..1b3a9754ad7 100644 --- a/go/vt/vtgate/planbuilder/vexplain.go +++ b/go/vt/vtgate/planbuilder/vexplain.go @@ -89,7 +89,7 @@ func buildVExplainVtgatePlan(ctx context.Context, explainStatement sqlparser.Sta if err != nil { return nil, err } - description := engine.PrimitiveToPlanDescription(innerInstruction.primitive) + description := engine.PrimitiveToPlanDescription(innerInstruction.primitive, nil) output, err := json.MarshalIndent(description, "", "\t") if err != nil { return nil, err