Skip to content

Commit

Permalink
Move UNION planning to the operators (#13450)
Browse files Browse the repository at this point in the history
Co-authored-by: Harshit Gangal <harshit@planetscale.com>
  • Loading branch information
systay and harshit-gangal authored Aug 8, 2023
1 parent 5ca1064 commit 53b3f80
Show file tree
Hide file tree
Showing 69 changed files with 4,028 additions and 2,452 deletions.
18 changes: 16 additions & 2 deletions go/slices2/slices.go → go/slice/slice.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,9 @@ See the License for the specific language governing permissions and
limitations under the License.
*/

// Package slices2 contains generic Slice helpers;
// Package slice contains generic Slice helpers;
// Some of this code is sourced from https://github.com/luraim/fun (Apache v2)
package slices2
package slice

// All returns true if all elements return true for given predicate
func All[T any](s []T, fn func(T) bool) bool {
Expand Down Expand Up @@ -48,3 +48,17 @@ func Map[From, To any](in []From, f func(From) To) []To {
}
return result
}

func MapWithError[From, To any](in []From, f func(From) (To, error)) (result []To, err error) {
if in == nil {
return nil, nil
}
result = make([]To, len(in))
for i, col := range in {
result[i], err = f(col)
if err != nil {
return nil, err
}
}
return
}
2 changes: 1 addition & 1 deletion go/test/endtoend/vtgate/gen4/gen4_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -164,7 +164,7 @@ func TestDistinct(t *testing.T) {
utils.Exec(t, mcmp.VtConn, `insert into t2(id, tcol1, tcol2) values (1, 'A', 'A'),(2, 'B', 'C'),(3, 'A', 'C'),(4, 'C', 'A'),(5, 'A', 'A'),(6, 'B', 'C'),(7, 'B', 'A'),(8, 'C', 'A')`)

// multi distinct
utils.AssertMatches(t, mcmp.VtConn, `select distinct tcol1, tcol2 from t2`,
utils.AssertMatchesNoOrder(t, mcmp.VtConn, `select distinct tcol1, tcol2 from t2`,
`[[VARCHAR("A") VARCHAR("A")] [VARCHAR("A") VARCHAR("C")] [VARCHAR("B") VARCHAR("A")] [VARCHAR("B") VARCHAR("C")] [VARCHAR("C") VARCHAR("A")]]`)
}

Expand Down
246 changes: 124 additions & 122 deletions go/test/endtoend/vtgate/queries/aggregation/aggregation_test.go

Large diffs are not rendered by default.

4 changes: 2 additions & 2 deletions go/test/endtoend/vtgate/queries/random/query_gen.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ import (

"golang.org/x/exp/slices"

"vitess.io/vitess/go/slices2"
"vitess.io/vitess/go/slice"
"vitess.io/vitess/go/vt/log"
"vitess.io/vitess/go/vt/sqlparser"
)
Expand Down Expand Up @@ -459,7 +459,7 @@ func createLimit() *sqlparser.Limit {
// returns a random expression and its type
func getRandomExpr(tables []tableT) sqlparser.Expr {
seed := time.Now().UnixNano()
g := sqlparser.NewGenerator(seed, 2, slices2.Map(tables, func(t tableT) sqlparser.ExprGenerator { return &t })...)
g := sqlparser.NewGenerator(seed, 2, slice.Map(tables, func(t tableT) sqlparser.ExprGenerator { return &t })...)
return g.Expression()
}

Expand Down
4 changes: 2 additions & 2 deletions go/test/endtoend/vtgate/queries/random/random_expr_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ import (
"testing"
"time"

"vitess.io/vitess/go/slices2"
"vitess.io/vitess/go/slice"
"vitess.io/vitess/go/vt/sqlparser"
)

Expand All @@ -47,6 +47,6 @@ func TestRandomExprWithTables(t *testing.T) {
}...)

seed := time.Now().UnixNano()
g := sqlparser.NewGenerator(seed, 3, slices2.Map(schemaTables, func(t tableT) sqlparser.ExprGenerator { return &t })...)
g := sqlparser.NewGenerator(seed, 3, slice.Map(schemaTables, func(t tableT) sqlparser.ExprGenerator { return &t })...)
g.Expression()
}
4 changes: 2 additions & 2 deletions go/viperutil/debug/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ import (
"github.com/spf13/viper"

"vitess.io/vitess/go/acl"
"vitess.io/vitess/go/slices2"
"vitess.io/vitess/go/slice"
"vitess.io/vitess/go/viperutil/internal/registry"
)

Expand All @@ -53,7 +53,7 @@ func HandlerFunc(w http.ResponseWriter, r *http.Request) {
switch {
case format == "":
v.DebugTo(w)
case slices2.Any(viper.SupportedExts, func(ext string) bool { return ext == format }):
case slice.Any(viper.SupportedExts, func(ext string) bool { return ext == format }):
// Got a supported format; write the config to a tempfile in that format,
// then copy it to the response.
//
Expand Down
1 change: 1 addition & 0 deletions go/vt/sqlparser/ast.go
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@ type (
GetColumnCount() int
GetColumns() SelectExprs
Commented
IsDistinct() bool
}

// DDLStatement represents any DDL Statement
Expand Down
35 changes: 35 additions & 0 deletions go/vt/sqlparser/ast_funcs.go
Original file line number Diff line number Diff line change
Expand Up @@ -1067,6 +1067,11 @@ func (node *Select) MakeDistinct() {
node.Distinct = true
}

// IsDistinct implements the SelectStatement interface
func (node *Select) IsDistinct() bool {
return node.Distinct
}

// GetColumnCount return SelectExprs count.
func (node *Select) GetColumnCount() int {
return len(node.SelectExprs)
Expand Down Expand Up @@ -1192,6 +1197,11 @@ func (node *Union) MakeDistinct() {
node.Distinct = true
}

// IsDistinct implements the SelectStatement interface
func (node *Union) IsDistinct() bool {
return node.Distinct
}

// GetColumnCount implements the SelectStatement interface
func (node *Union) GetColumnCount() int {
return node.Left.GetColumnCount()
Expand Down Expand Up @@ -2502,3 +2512,28 @@ func MakeColumns(colNames ...string) Columns {
}
return cols
}

func VisitAllSelects(in SelectStatement, f func(p *Select, idx int) error) error {
v := visitor{}
return v.visitAllSelects(in, f)
}

type visitor struct {
idx int
}

func (v *visitor) visitAllSelects(in SelectStatement, f func(p *Select, idx int) error) error {
switch sel := in.(type) {
case *Select:
err := f(sel, v.idx)
v.idx++
return err
case *Union:
err := v.visitAllSelects(sel.Left, f)
if err != nil {
return err
}
return v.visitAllSelects(sel.Right, f)
}
panic("switch should be exhaustive")
}
7 changes: 3 additions & 4 deletions go/vt/vtgate/engine/aggregations.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,16 +20,15 @@ import (
"fmt"
"strconv"

"vitess.io/vitess/go/vt/vterrors"

"google.golang.org/protobuf/proto"

"vitess.io/vitess/go/mysql/collations"
"vitess.io/vitess/go/slices2"
"vitess.io/vitess/go/slice"
"vitess.io/vitess/go/sqltypes"
binlogdatapb "vitess.io/vitess/go/vt/proto/binlogdata"
querypb "vitess.io/vitess/go/vt/proto/query"
"vitess.io/vitess/go/vt/sqlparser"
"vitess.io/vitess/go/vt/vterrors"
. "vitess.io/vitess/go/vt/vtgate/engine/opcode"
"vitess.io/vitess/go/vt/vtgate/evalengine"
)
Expand Down Expand Up @@ -279,7 +278,7 @@ func convertFinal(current []sqltypes.Value, aggregates []*AggregateParams) ([]sq
}

func convertFields(fields []*querypb.Field, aggrs []*AggregateParams) []*querypb.Field {
fields = slices2.Map(fields, func(from *querypb.Field) *querypb.Field {
fields = slice.Map(fields, func(from *querypb.Field) *querypb.Field {
return proto.Clone(from).(*querypb.Field)
})
for _, aggr := range aggrs {
Expand Down
2 changes: 1 addition & 1 deletion go/vt/vtgate/engine/cached_size.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

15 changes: 8 additions & 7 deletions go/vt/vtgate/engine/filter.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,10 @@ package engine
import (
"context"

"vitess.io/vitess/go/vt/sqlparser"
"vitess.io/vitess/go/vt/vtgate/evalengine"

"vitess.io/vitess/go/sqltypes"

querypb "vitess.io/vitess/go/vt/proto/query"
"vitess.io/vitess/go/vt/sqlparser"
"vitess.io/vitess/go/vt/vtgate/evalengine"
)

var _ Primitive = (*Filter)(nil)
Expand All @@ -35,6 +33,8 @@ type Filter struct {
ASTPredicate sqlparser.Expr
Input Primitive

Truncate int

noTxNeeded
}

Expand Down Expand Up @@ -73,7 +73,7 @@ func (f *Filter) TryExecute(ctx context.Context, vcursor VCursor, bindVars map[s
}
}
result.Rows = rows
return result, nil
return result.Truncate(f.Truncate), nil
}

// TryStreamExecute satisfies the Primitive interface.
Expand All @@ -96,7 +96,7 @@ func (f *Filter) TryStreamExecute(ctx context.Context, vcursor VCursor, bindVars
}
}
results.Rows = rows
return callback(results)
return callback(results.Truncate(f.Truncate))
}

return vcursor.StreamExecutePrimitive(ctx, f.Input, bindVars, wantfields, filter)
Expand All @@ -114,7 +114,8 @@ func (f *Filter) Inputs() []Primitive {

func (f *Filter) description() PrimitiveDescription {
other := map[string]any{
"Predicate": sqlparser.String(f.ASTPredicate),
"Predicate": sqlparser.String(f.ASTPredicate),
"ResultColumns": f.Truncate,
}

return PrimitiveDescription{
Expand Down
10 changes: 4 additions & 6 deletions go/vt/vtgate/evalengine/compiler_asm.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,18 +35,16 @@ import (

"github.com/google/uuid"

"vitess.io/vitess/go/mysql/hex"

"vitess.io/vitess/go/mysql/icuregex"

"vitess.io/vitess/go/hack"
"vitess.io/vitess/go/mysql/collations"
"vitess.io/vitess/go/mysql/collations/charset"
"vitess.io/vitess/go/mysql/datetime"
"vitess.io/vitess/go/mysql/decimal"
"vitess.io/vitess/go/mysql/fastparse"
"vitess.io/vitess/go/mysql/hex"
"vitess.io/vitess/go/mysql/icuregex"
"vitess.io/vitess/go/mysql/json"
"vitess.io/vitess/go/slices2"
"vitess.io/vitess/go/slice"
"vitess.io/vitess/go/sqltypes"
querypb "vitess.io/vitess/go/vt/proto/query"
"vitess.io/vitess/go/vt/proto/vtrpc"
Expand Down Expand Up @@ -2172,7 +2170,7 @@ func (asm *assembler) Fn_JSON_CONTAINS_PATH(match jsonMatch, paths []*json.Path)
}

func (asm *assembler) Fn_JSON_EXTRACT0(jp []*json.Path) {
multi := len(jp) > 1 || slices2.Any(jp, func(path *json.Path) bool { return path.ContainsWildcards() })
multi := len(jp) > 1 || slice.Any(jp, func(path *json.Path) bool { return path.ContainsWildcards() })

if multi {
asm.emit(func(env *ExpressionEnv) int {
Expand Down
6 changes: 3 additions & 3 deletions go/vt/vtgate/evalengine/fn_json.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ package evalengine
import (
"vitess.io/vitess/go/mysql/collations"
"vitess.io/vitess/go/mysql/json"
"vitess.io/vitess/go/slices2"
"vitess.io/vitess/go/slice"
"vitess.io/vitess/go/sqltypes"
querypb "vitess.io/vitess/go/vt/proto/query"
vtrpcpb "vitess.io/vitess/go/vt/proto/vtrpc"
Expand Down Expand Up @@ -131,7 +131,7 @@ func (call *builtinJSONExtract) compile(c *compiler) (ctype, error) {
return ctype{}, err
}

if slices2.All(call.Arguments[1:], func(expr Expr) bool { return expr.constant() }) {
if slice.All(call.Arguments[1:], func(expr Expr) bool { return expr.constant() }) {
paths := make([]*json.Path, 0, len(call.Arguments[1:]))

for _, arg := range call.Arguments[1:] {
Expand Down Expand Up @@ -406,7 +406,7 @@ func (call *builtinJSONContainsPath) compile(c *compiler) (ctype, error) {
return ctype{}, c.unsupported(call)
}

if !slices2.All(call.Arguments[2:], func(expr Expr) bool { return expr.constant() }) {
if !slice.All(call.Arguments[2:], func(expr Expr) bool { return expr.constant() }) {
return ctype{}, c.unsupported(call)
}

Expand Down
2 changes: 1 addition & 1 deletion go/vt/vtgate/evalengine/translate.go
Original file line number Diff line number Diff line change
Expand Up @@ -211,7 +211,7 @@ func (ast *astCompiler) translateColOffset(col *sqlparser.Offset) (Expr, error)

func (ast *astCompiler) translateColName(colname *sqlparser.ColName) (Expr, error) {
if ast.cfg.ResolveColumn == nil {
return nil, vterrors.Errorf(vtrpcpb.Code_UNIMPLEMENTED, "cannot lookup column (column access not supported here)")
return nil, vterrors.Errorf(vtrpcpb.Code_UNIMPLEMENTED, "cannot lookup column '%s' (column access not supported here)", sqlparser.String(colname))
}
idx, err := ast.cfg.ResolveColumn(colname)
if err != nil {
Expand Down
24 changes: 12 additions & 12 deletions go/vt/vtgate/executor_select_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3642,15 +3642,15 @@ func TestSelectAggregationNoData(t *testing.T) {
},
{
sql: `select count(*) from (select col1, col2 from user limit 2) x`,
sandboxRes: sqltypes.MakeTestResult(sqltypes.MakeTestFields("col1", "int64")),
expSandboxQ: "select 1 from (select col1, col2 from `user`) as x limit :__upper_limit",
sandboxRes: sqltypes.MakeTestResult(sqltypes.MakeTestFields("col1|col2|1", "int64|int64|int64")),
expSandboxQ: "select col1, col2, 1 from (select col1, col2 from `user`) as x limit :__upper_limit",
expField: `[name:"count(*)" type:INT64]`,
expRow: `[[INT64(0)]]`,
},
{
sql: `select col2, count(*) from (select col1, col2 from user limit 2) x group by col2`,
sandboxRes: sqltypes.MakeTestResult(sqltypes.MakeTestFields("col2|1|weight_string(col2)", "int64|int64|varbinary")),
expSandboxQ: "select col2, 1, weight_string(col2) from (select col1, col2 from `user`) as x limit :__upper_limit",
sandboxRes: sqltypes.MakeTestResult(sqltypes.MakeTestFields("col1|col2|1|weight_string(col2)", "int64|int64|int64|varbinary")),
expSandboxQ: "select col1, col2, 1, weight_string(col2) from (select col1, col2 from `user`) as x limit :__upper_limit",
expField: `[name:"col2" type:INT64 name:"count(*)" type:INT64]`,
expRow: `[]`,
},
Expand Down Expand Up @@ -3726,29 +3726,29 @@ func TestSelectAggregationData(t *testing.T) {
},
{
sql: `select count(*) from (select col1, col2 from user limit 2) x`,
sandboxRes: sqltypes.MakeTestResult(sqltypes.MakeTestFields("1", "int64"), "1", "1"),
expSandboxQ: "select 1 from (select col1, col2 from `user`) as x limit :__upper_limit",
sandboxRes: sqltypes.MakeTestResult(sqltypes.MakeTestFields("col1|col2|1", "int64|int64|int64"), "100|200|1", "200|300|1"),
expSandboxQ: "select col1, col2, 1 from (select col1, col2 from `user`) as x limit :__upper_limit",
expField: `[name:"count(*)" type:INT64]`,
expRow: `[[INT64(2)]]`,
},
{
sql: `select col2, count(*) from (select col1, col2 from user limit 9) x group by col2`,
sandboxRes: sqltypes.MakeTestResult(sqltypes.MakeTestFields("col2|1|weight_string(col2)", "int64|int64|varbinary"), "3|1|NULL", "2|1|NULL"),
expSandboxQ: "select col2, 1, weight_string(col2) from (select col1, col2 from `user`) as x limit :__upper_limit",
sandboxRes: sqltypes.MakeTestResult(sqltypes.MakeTestFields("col1|col2|1|weight_string(col2)", "int64|int64|int64|varbinary"), "100|3|1|NULL", "200|2|1|NULL"),
expSandboxQ: "select col1, col2, 1, weight_string(col2) from (select col1, col2 from `user`) as x limit :__upper_limit",
expField: `[name:"col2" type:INT64 name:"count(*)" type:INT64]`,
expRow: `[[INT64(2) INT64(4)] [INT64(3) INT64(5)]]`,
},
{
sql: `select count(col1) from (select id, col1 from user limit 2) x`,
sandboxRes: sqltypes.MakeTestResult(sqltypes.MakeTestFields("col1", "varchar"), "a", "b"),
expSandboxQ: "select col1 from (select id, col1 from `user`) as x limit :__upper_limit",
sandboxRes: sqltypes.MakeTestResult(sqltypes.MakeTestFields("id|col1", "int64|varchar"), "1|a", "2|b"),
expSandboxQ: "select id, col1 from (select id, col1 from `user`) as x limit :__upper_limit",
expField: `[name:"count(col1)" type:INT64]`,
expRow: `[[INT64(2)]]`,
},
{
sql: `select count(col1), col2 from (select col2, col1 from user limit 9) x group by col2`,
sandboxRes: sqltypes.MakeTestResult(sqltypes.MakeTestFields("col1|col2|weight_string(col2)", "varchar|int64|varbinary"), "a|3|NULL", "b|2|NULL"),
expSandboxQ: "select col1, col2, weight_string(col2) from (select col2, col1 from `user`) as x limit :__upper_limit",
sandboxRes: sqltypes.MakeTestResult(sqltypes.MakeTestFields("col2|col1|weight_string(col2)", "int64|varchar|varbinary"), "3|a|NULL", "2|b|NULL"),
expSandboxQ: "select col2, col1, weight_string(col2) from (select col2, col1 from `user`) as x limit :__upper_limit",
expField: `[name:"count(col1)" type:INT64 name:"col2" type:INT64]`,
expRow: `[[INT64(4) INT64(2)] [INT64(5) INT64(3)]]`,
},
Expand Down
6 changes: 3 additions & 3 deletions go/vt/vtgate/planbuilder/collations_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -81,9 +81,9 @@ func TestOrderedAggregateCollations(t *testing.T) {
collations: []collationInTable{{ks: "user", table: "user", collationName: "utf8mb4_bin", colName: "textcol1"}},
query: "select distinct textcol1 from user",
check: func(t *testing.T, colls []collationInTable, primitive engine.Primitive) {
oa, isOA := primitive.(*engine.OrderedAggregate)
require.True(t, isOA, "should be an OrderedAggregate")
require.Equal(t, collid(colls[0].collationName), oa.GroupByKeys[0].CollationID)
distinct, isDistinct := primitive.(*engine.Distinct)
require.True(t, isDistinct, "should be a distinct")
require.Equal(t, collid(colls[0].collationName), distinct.CheckCols[0].Collation)
},
},
{
Expand Down
8 changes: 0 additions & 8 deletions go/vt/vtgate/planbuilder/distinct.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,14 +41,6 @@ func newDistinct(source logicalPlan, checkCols []engine.CheckCol, truncateColumn
}
}

func newDistinctGen4Legacy(source logicalPlan, checkCols []engine.CheckCol, needToTruncate bool) logicalPlan {
return &distinct{
logicalPlanCommon: newBuilderCommon(source),
checkCols: checkCols,
needToTruncate: needToTruncate,
}
}

func (d *distinct) Primitive() engine.Primitive {
truncate := d.truncateColumn
if d.needToTruncate {
Expand Down
Loading

0 comments on commit 53b3f80

Please sign in to comment.