diff --git a/go/mysql/constants.go b/go/mysql/constants.go index 99cf5ff6273..9eb682cfd4d 100644 --- a/go/mysql/constants.go +++ b/go/mysql/constants.go @@ -571,9 +571,12 @@ const ( // SSLockDeadlock is ER_LOCK_DEADLOCK SSLockDeadlock = "40001" - //SSClientError is the state on client errors + // SSClientError is the state on client errors SSClientError = "42000" + // SSDupFieldName is ER_DUP_FIELD_NAME + SSDupFieldName = "42S21" + // SSBadFieldError is ER_BAD_FIELD_ERROR SSBadFieldError = "42S22" diff --git a/go/mysql/sql_error.go b/go/mysql/sql_error.go index 509663ae78f..8c9905ae1a2 100644 --- a/go/mysql/sql_error.go +++ b/go/mysql/sql_error.go @@ -168,6 +168,7 @@ var stateToMysqlCode = map[vterrors.State]struct { vterrors.DataOutOfRange: {num: ERDataOutOfRange, state: SSDataOutOfRange}, vterrors.DbCreateExists: {num: ERDbCreateExists, state: SSUnknownSQLState}, vterrors.DbDropExists: {num: ERDbDropExists, state: SSUnknownSQLState}, + vterrors.DupFieldName: {num: ERDupFieldName, state: SSDupFieldName}, vterrors.EmptyQuery: {num: EREmptyQuery, state: SSClientError}, vterrors.IncorrectGlobalLocalVar: {num: ERIncorrectGlobalLocalVar, state: SSUnknownSQLState}, vterrors.InnodbReadOnly: {num: ERInnodbReadOnly, state: SSUnknownSQLState}, diff --git a/go/vt/vterrors/state.go b/go/vt/vterrors/state.go index 00c73231d51..eb1657ae537 100644 --- a/go/vt/vterrors/state.go +++ b/go/vt/vterrors/state.go @@ -41,6 +41,7 @@ const ( WrongValueForVar LockOrActiveTransaction MixOfGroupFuncAndFields + DupFieldName // failed precondition NoDB diff --git a/go/vt/vtgate/planbuilder/abstract/derived.go b/go/vt/vtgate/planbuilder/abstract/derived.go new file mode 100644 index 00000000000..49eb2872905 --- /dev/null +++ b/go/vt/vtgate/planbuilder/abstract/derived.go @@ -0,0 +1,50 @@ +/* +Copyright 2021 The Vitess Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package abstract + +import ( + "vitess.io/vitess/go/vt/sqlparser" + "vitess.io/vitess/go/vt/vtgate/semantics" +) + +// Derived represents a derived table in the query +type Derived struct { + Sel *sqlparser.Select + Inner Operator + Alias string +} + +var _ Operator = (*Derived)(nil) + +// TableID implements the Operator interface +func (d *Derived) TableID() semantics.TableSet { + return d.Inner.TableID() +} + +// PushPredicate implements the Operator interface +func (d *Derived) PushPredicate(expr sqlparser.Expr, semTable *semantics.SemTable) error { + tableInfo, err := semTable.TableInfoForExpr(expr) + if err != nil { + return err + } + + newExpr, err := semantics.RewriteDerivedExpression(expr, tableInfo) + if err != nil { + return err + } + return d.Inner.PushPredicate(newExpr, semTable) +} diff --git a/go/vt/vtgate/planbuilder/abstract/join.go b/go/vt/vtgate/planbuilder/abstract/join.go index 1ca55868009..dd49dca8ed1 100644 --- a/go/vt/vtgate/planbuilder/abstract/join.go +++ b/go/vt/vtgate/planbuilder/abstract/join.go @@ -27,9 +27,11 @@ type Join struct { Exp sqlparser.Expr } +var _ Operator = (*Join)(nil) + // PushPredicate implements the Operator interface func (j *Join) PushPredicate(expr sqlparser.Expr, semTable *semantics.SemTable) error { - deps := semTable.Dependencies(expr) + deps := semTable.GetBaseTableDependencies(expr) switch { case deps.IsSolvedBy(j.LHS.TableID()): return j.LHS.PushPredicate(expr, semTable) diff --git a/go/vt/vtgate/planbuilder/abstract/outerjoin.go b/go/vt/vtgate/planbuilder/abstract/left_join.go similarity index 93% rename from 
go/vt/vtgate/planbuilder/abstract/outerjoin.go rename to go/vt/vtgate/planbuilder/abstract/left_join.go index ba75e40c838..886e9691be4 100644 --- a/go/vt/vtgate/planbuilder/abstract/outerjoin.go +++ b/go/vt/vtgate/planbuilder/abstract/left_join.go @@ -27,9 +27,11 @@ type LeftJoin struct { Predicate sqlparser.Expr } +var _ Operator = (*LeftJoin)(nil) + // PushPredicate implements the Operator interface func (oj *LeftJoin) PushPredicate(expr sqlparser.Expr, semTable *semantics.SemTable) error { - deps := semTable.Dependencies(expr) + deps := semTable.GetBaseTableDependencies(expr) if deps.IsSolvedBy(oj.Left.TableID()) { return oj.Left.PushPredicate(expr, semTable) } diff --git a/go/vt/vtgate/planbuilder/abstract/operator.go b/go/vt/vtgate/planbuilder/abstract/operator.go index c1e9395db46..3021b08c7de 100644 --- a/go/vt/vtgate/planbuilder/abstract/operator.go +++ b/go/vt/vtgate/planbuilder/abstract/operator.go @@ -24,6 +24,7 @@ import ( type ( // Operator forms the tree of operators, representing the declarative query provided. // An operator can be: + // * Derived - which represents an expression that generates a table. // * QueryGraph - which represents a group of tables and predicates that can be evaluated in any order // while still preserving the results // * LeftJoin - A left join. 
These can't be evaluated in any order, so we keep them separate @@ -40,17 +41,31 @@ type ( func getOperatorFromTableExpr(tableExpr sqlparser.TableExpr, semTable *semantics.SemTable) (Operator, error) { switch tableExpr := tableExpr.(type) { case *sqlparser.AliasedTableExpr: - qg := newQueryGraph() - tableName := tableExpr.Expr.(sqlparser.TableName) - tableID := semTable.TableSetFor(tableExpr) - tableInfo, err := semTable.TableInfoFor(tableID) - if err != nil { - return nil, err + switch tbl := tableExpr.Expr.(type) { + case sqlparser.TableName: + qg := newQueryGraph() + tableID := semTable.TableSetFor(tableExpr) + tableInfo, err := semTable.TableInfoFor(tableID) + if err != nil { + return nil, err + } + isInfSchema := tableInfo.IsInfSchema() + qt := &QueryTable{Alias: tableExpr, Table: tbl, TableID: tableID, IsInfSchema: isInfSchema} + qg.Tables = append(qg.Tables, qt) + return qg, nil + case *sqlparser.DerivedTable: + sel, isSel := tbl.Select.(*sqlparser.Select) + if !isSel { + return nil, semantics.Gen4NotSupportedF("UNION") + } + inner, err := CreateOperatorFromSelect(sel, semTable) + if err != nil { + return nil, err + } + return &Derived{Alias: tableExpr.As.String(), Inner: inner, Sel: sel}, nil + default: + return nil, semantics.Gen4NotSupportedF("%T", tbl) } - isInfSchema := tableInfo.IsInfSchema() - qt := &QueryTable{Alias: tableExpr, Table: tableName, TableID: tableID, IsInfSchema: isInfSchema} - qg.Tables = append(qg.Tables, qt) - return qg, nil case *sqlparser.JoinTableExpr: switch tableExpr.Join { case sqlparser.NormalJoinType: diff --git a/go/vt/vtgate/planbuilder/abstract/operator_test.go b/go/vt/vtgate/planbuilder/abstract/operator_test.go index bf9971cf973..003e813d1fe 100644 --- a/go/vt/vtgate/planbuilder/abstract/operator_test.go +++ b/go/vt/vtgate/planbuilder/abstract/operator_test.go @@ -154,6 +154,34 @@ JoinPredicates: Predicate: c.id = d.id } Predicate: a.id = c.id +}`, + }, { + input: "select 1 from (select 42 as id from tbl) as t", + output: 
`Derived t: { + Query: select 42 as id from tbl + Inner: QueryGraph: { + Tables: + 1:tbl + } +}`, + }, { + input: "select 1 from (select id from tbl limit 10) as t join (select foo, count(*) from usr group by foo) as s on t.id = s.foo", + output: `Join: { + LHS: Derived t: { + Query: select id from tbl limit 10 + Inner: QueryGraph: { + Tables: + 1:tbl + } + } + RHS: Derived s: { + Query: select foo, count(*) from usr group by foo + Inner: QueryGraph: { + Tables: + 4:usr + } + } + Predicate: t.id = s.foo }`, }} @@ -186,6 +214,10 @@ func testString(op Operator) string { leftStr := indent(testString(op.Left)) rightStr := indent(testString(op.Right)) return fmt.Sprintf("OuterJoin: {\n\tInner: %s\n\tOuter: %s\n\tPredicate: %s\n}", leftStr, rightStr, sqlparser.String(op.Predicate)) + case *Derived: + inner := indent(testString(op.Inner)) + query := sqlparser.String(op.Sel) + return fmt.Sprintf("Derived %s: {\n\tQuery: %s\n\tInner:%s\n}", op.Alias, query, inner) } return "implement me" } @@ -198,6 +230,8 @@ func indent(s string) string { return strings.Join(lines, "\n") } +// the following code is only used by tests + func (qt *QueryTable) testString() string { var alias string if !qt.Alias.As.IsEmpty() { diff --git a/go/vt/vtgate/planbuilder/abstract/querygraph.go b/go/vt/vtgate/planbuilder/abstract/querygraph.go index 98129eb6084..42c5f1c9592 100644 --- a/go/vt/vtgate/planbuilder/abstract/querygraph.go +++ b/go/vt/vtgate/planbuilder/abstract/querygraph.go @@ -52,6 +52,8 @@ type ( } ) +var _ Operator = (*QueryGraph)(nil) + // PushPredicate implements the Operator interface func (qg *QueryGraph) PushPredicate(expr sqlparser.Expr, semTable *semantics.SemTable) error { for _, e := range sqlparser.SplitAndExpression(nil, expr) { @@ -104,7 +106,7 @@ func (qg *QueryGraph) collectPredicates(sel *sqlparser.Select, semTable *semanti } func (qg *QueryGraph) collectPredicateTable(t sqlparser.TableExpr, predicate sqlparser.Expr, semTable *semantics.SemTable) error { - deps := 
semTable.Dependencies(predicate) + deps := semTable.GetBaseTableDependencies(predicate) switch deps.NumberOfTables() { case 0: qg.addNoDepsPredicate(predicate) @@ -135,7 +137,7 @@ func (qg *QueryGraph) collectPredicateTable(t sqlparser.TableExpr, predicate sql } func (qg *QueryGraph) collectPredicate(predicate sqlparser.Expr, semTable *semantics.SemTable) error { - deps := semTable.Dependencies(predicate) + deps := semTable.GetBaseTableDependencies(predicate) switch deps.NumberOfTables() { case 0: qg.addNoDepsPredicate(predicate) diff --git a/go/vt/vtgate/planbuilder/expand_star_test.go b/go/vt/vtgate/planbuilder/expand_star_test.go index 4eb9e566816..01b5b331b4c 100644 --- a/go/vt/vtgate/planbuilder/expand_star_test.go +++ b/go/vt/vtgate/planbuilder/expand_star_test.go @@ -169,14 +169,14 @@ func TestSemTableDependenciesAfterExpandStar(t *testing.T) { assert.Equal(t, tcase.expSQL, sqlparser.String(expandedSelect)) if tcase.otherTbl != -1 { assert.NotEqual(t, - semTable.Dependencies(expandedSelect.SelectExprs[tcase.otherTbl].(*sqlparser.AliasedExpr).Expr), - semTable.Dependencies(expandedSelect.SelectExprs[tcase.expandedCol].(*sqlparser.AliasedExpr).Expr), + semTable.GetBaseTableDependencies(expandedSelect.SelectExprs[tcase.otherTbl].(*sqlparser.AliasedExpr).Expr), + semTable.GetBaseTableDependencies(expandedSelect.SelectExprs[tcase.expandedCol].(*sqlparser.AliasedExpr).Expr), ) } if tcase.sameTbl != -1 { assert.Equal(t, - semTable.Dependencies(expandedSelect.SelectExprs[tcase.sameTbl].(*sqlparser.AliasedExpr).Expr), - semTable.Dependencies(expandedSelect.SelectExprs[tcase.expandedCol].(*sqlparser.AliasedExpr).Expr), + semTable.GetBaseTableDependencies(expandedSelect.SelectExprs[tcase.sameTbl].(*sqlparser.AliasedExpr).Expr), + semTable.GetBaseTableDependencies(expandedSelect.SelectExprs[tcase.expandedCol].(*sqlparser.AliasedExpr).Expr), ) } }) diff --git a/go/vt/vtgate/planbuilder/horizon_planning.go b/go/vt/vtgate/planbuilder/horizon_planning.go index 
1b0e5dddf3a..4afb4be4ad6 100644 --- a/go/vt/vtgate/planbuilder/horizon_planning.go +++ b/go/vt/vtgate/planbuilder/horizon_planning.go @@ -55,7 +55,7 @@ func pushProjection(expr *sqlparser.AliasedExpr, plan logicalPlan, semTable *sem case *joinGen4: lhsSolves := node.Left.ContainsTables() rhsSolves := node.Right.ContainsTables() - deps := semTable.Dependencies(expr.Expr) + deps := semTable.GetBaseTableDependencies(expr.Expr) var column int var appended bool switch { @@ -103,19 +103,27 @@ func removeQualifierFromColName(expr *sqlparser.AliasedExpr) *sqlparser.AliasedE func checkIfAlreadyExists(expr *sqlparser.AliasedExpr, sel *sqlparser.Select) int { for i, selectExpr := range sel.SelectExprs { - if selectExpr, ok := selectExpr.(*sqlparser.AliasedExpr); ok { - if selectExpr.As.IsEmpty() { - // we don't have an alias, so we can compare the expressions - if sqlparser.EqualsExpr(selectExpr.Expr, expr.Expr) { - return i - } - // we have an aliased column, so let's check if the expression is matching the alias - } else if colName, ok := expr.Expr.(*sqlparser.ColName); ok { - if selectExpr.As.Equal(colName.Name) { - return i - } - } + selectExpr, ok := selectExpr.(*sqlparser.AliasedExpr) + if !ok { + continue + } + + selectExprCol, isSelectExprCol := selectExpr.Expr.(*sqlparser.ColName) + exprCol, isExprCol := expr.Expr.(*sqlparser.ColName) + + if selectExpr.As.IsEmpty() { + // we don't have an alias + if isSelectExprCol && isExprCol && exprCol.Name.Equal(selectExprCol.Name) { + // the expressions are ColName, we compare their name + return i + } else if sqlparser.EqualsExpr(selectExpr.Expr, expr.Expr) { + // the expressions are not ColName, so we just compare the expressions + return i + } + } else if isExprCol && selectExpr.As.Equal(exprCol.Name) { + // we have an aliased column, checking if the expression is matching the alias + return i } } return -1 @@ -327,16 +335,15 @@ func wrapAndPushExpr(expr sqlparser.Expr, weightStrExpr sqlparser.Expr, plan log if err != nil
{ return 0, 0, false, err } - colName, ok := expr.(*sqlparser.ColName) + _, ok := expr.(*sqlparser.ColName) if !ok { return 0, 0, false, semantics.Gen4NotSupportedF("group by/order by non-column expression") } - table := semTable.Dependencies(colName) - tbl, err := semTable.TableInfoFor(table) - if err != nil { - return 0, 0, false, err + qt := semTable.TypeFor(expr) + wsNeeded := true + if qt != nil && sqltypes.IsNumber(*qt) { + wsNeeded = false } - wsNeeded := needsWeightString(tbl, colName) weightStringOffset := -1 var wAdded bool @@ -361,15 +368,6 @@ func weightStringFor(expr sqlparser.Expr) sqlparser.Expr { } -func needsWeightString(tbl semantics.TableInfo, colName *sqlparser.ColName) bool { - for _, c := range tbl.GetColumns() { - if colName.Name.String() == c.Name { - return !sqltypes.IsNumber(c.Type) - } - } - return true // we didn't find the column. better to add just to be safe1 -} - func (hp *horizonPlanning) planOrderByForJoin(orderExprs []abstract.OrderBy, plan *joinGen4) (logicalPlan, error) { if allLeft(orderExprs, hp.semTable, plan.Left.ContainsTables()) { newLeft, err := hp.planOrderBy(orderExprs, plan.Left) @@ -455,7 +453,7 @@ func (hp *horizonPlanning) createMemorySortPlan(plan logicalPlan, orderExprs []a func allLeft(orderExprs []abstract.OrderBy, semTable *semantics.SemTable, lhsTables semantics.TableSet) bool { for _, expr := range orderExprs { - exprDependencies := semTable.Dependencies(expr.Inner.Expr) + exprDependencies := semTable.GetBaseTableDependencies(expr.Inner.Expr) if !exprDependencies.IsSolvedBy(lhsTables) { return false } diff --git a/go/vt/vtgate/planbuilder/horizon_planning_test.go b/go/vt/vtgate/planbuilder/horizon_planning_test.go new file mode 100644 index 00000000000..37376347970 --- /dev/null +++ b/go/vt/vtgate/planbuilder/horizon_planning_test.go @@ -0,0 +1,77 @@ +/* +Copyright 2021 The Vitess Authors. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package planbuilder + +import ( + "testing" + + "github.com/stretchr/testify/assert" + + "vitess.io/vitess/go/vt/sqlparser" +) + +func TestCheckIfAlreadyExists(t *testing.T) { + tests := []struct { + name string + expr *sqlparser.AliasedExpr + sel *sqlparser.Select + want int + }{ + { + name: "No alias, both ColName", + want: 0, + expr: &sqlparser.AliasedExpr{Expr: sqlparser.NewColName("id")}, + sel: &sqlparser.Select{SelectExprs: []sqlparser.SelectExpr{&sqlparser.AliasedExpr{Expr: sqlparser.NewColName("id")}}}, + }, + { + name: "Aliased expression and ColName", + want: 0, + expr: &sqlparser.AliasedExpr{Expr: sqlparser.NewColName("user_id")}, + sel: &sqlparser.Select{SelectExprs: []sqlparser.SelectExpr{&sqlparser.AliasedExpr{As: sqlparser.NewColIdent("user_id"), Expr: sqlparser.NewColName("id")}}}, + }, + { + name: "Non-ColName expressions", + want: 0, + expr: &sqlparser.AliasedExpr{Expr: sqlparser.NewStrLiteral("test")}, + sel: &sqlparser.Select{SelectExprs: []sqlparser.SelectExpr{&sqlparser.AliasedExpr{Expr: sqlparser.NewStrLiteral("test")}}}, + }, + { + name: "No alias, multiple ColName in projection", + want: 1, + expr: &sqlparser.AliasedExpr{Expr: sqlparser.NewColName("id")}, + sel: &sqlparser.Select{SelectExprs: []sqlparser.SelectExpr{&sqlparser.AliasedExpr{Expr: sqlparser.NewColName("foo")}, &sqlparser.AliasedExpr{Expr: sqlparser.NewColName("id")}}}, + }, + { + name: "No matching entry", + want: -1, + expr: 
&sqlparser.AliasedExpr{Expr: sqlparser.NewColName("id")}, + sel: &sqlparser.Select{SelectExprs: []sqlparser.SelectExpr{&sqlparser.AliasedExpr{Expr: sqlparser.NewColName("foo")}, &sqlparser.AliasedExpr{Expr: sqlparser.NewColName("name")}}}, + }, + { + name: "No AliasedExpr in projection", + want: -1, + expr: &sqlparser.AliasedExpr{Expr: sqlparser.NewColName("id")}, + sel: &sqlparser.Select{SelectExprs: []sqlparser.SelectExpr{&sqlparser.StarExpr{TableName: sqlparser.TableName{Name: sqlparser.NewTableIdent("user")}}, &sqlparser.StarExpr{TableName: sqlparser.TableName{Name: sqlparser.NewTableIdent("people")}}}}, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := checkIfAlreadyExists(tt.expr, tt.sel) + assert.Equal(t, tt.want, got) + }) + } +} diff --git a/go/vt/vtgate/planbuilder/jointree.go b/go/vt/vtgate/planbuilder/jointree.go index 9ea7c5c07d2..ca58d507ecb 100644 --- a/go/vt/vtgate/planbuilder/jointree.go +++ b/go/vt/vtgate/planbuilder/jointree.go @@ -19,6 +19,9 @@ package planbuilder import ( "strings" + vtrpcpb "vitess.io/vitess/go/vt/proto/vtrpc" + "vitess.io/vitess/go/vt/vterrors" + "vitess.io/vitess/go/vt/vtgate/evalengine" "vitess.io/vitess/go/sqltypes" @@ -29,48 +32,36 @@ import ( "vitess.io/vitess/go/vt/vtgate/vindexes" ) +// queryTree interface and implementations +// These representations help optimize the join planning using tables and predicates.
type ( - joinTree interface { + queryTree interface { // tableID returns the table identifiers that are solved by this plan tableID() semantics.TableSet - // cost is simply the number of routes in the joinTree + // cost is simply the number of routes in the queryTree cost() int - // creates a copy of the joinTree that can be updated without changing the original - clone() joinTree + // creates a copy of the queryTree that can be updated without changing the original + clone() queryTree - pushOutputColumns([]*sqlparser.ColName, *semantics.SemTable) []int + pushOutputColumns([]*sqlparser.ColName, *semantics.SemTable) ([]int, error) } - relation interface { - tableID() semantics.TableSet - tableNames() []string - } - - joinTables struct { - lhs, rhs relation - pred sqlparser.Expr - } + joinTree struct { + // columns needed to feed other plans + columns []int - routeTable struct { - qtable *abstract.QueryTable - vtable *vindexes.Table - } + // arguments that need to be copied from the LHS/RHS + vars map[string]int - outerTable struct { - right relation - pred sqlparser.Expr - } + // the children of this plan + lhs, rhs queryTree - // cost is used to make it easy to compare the cost of two plans with each other - cost struct { - vindexCost int - isUnique bool - opCode engine.RouteOpcode + outer bool } - routePlan struct { + routeTree struct { routeOpCode engine.RouteOpcode solved semantics.TableSet keyspace *vindexes.Keyspace @@ -103,39 +94,133 @@ type ( SysTableTableName map[string]evalengine.Expr } - joinPlan struct { - // columns needed to feed other plans - columns []int + derivedTree struct { + query *sqlparser.Select + inner queryTree + alias string + } +) - // arguments that need to be copied from the LHS/RHS - vars map[string]int +// relation interface and implementations +// They are representations of the tables in a routeTree +// When we are able to merge queryTrees, the result lives as a relation; otherwise it stays as a joinTree +type ( + relation interface { +
tableID() semantics.TableSet + tableNames() []string + } - // the children of this plan - lhs, rhs joinTree + joinTables struct { + lhs, rhs relation + pred sqlparser.Expr + } - outer bool + routeTable struct { + qtable *abstract.QueryTable + vtable *vindexes.Table } parenTables []relation - // vindexPlusPredicates is a struct used to store all the predicates that the vindex can be used to query - vindexPlusPredicates struct { - colVindex *vindexes.ColumnVindex - values []sqltypes.PlanValue + derivedTable struct { + // tables contains inner tables that are solved by this plan. + // the tables also contain any predicates that only depend on that particular table + tables parenTables - // when we have the predicates found, we also know how to interact with this vindex - foundVindex vindexes.Vindex - opcode engine.RouteOpcode - predicates []sqlparser.Expr + // predicates are the predicates evaluated by this plan + predicates []sqlparser.Expr + + // leftJoins are the join conditions evaluated by this plan + leftJoins []*outerTable + + alias string + + query *sqlparser.Select } ) +type outerTable struct { + right relation + pred sqlparser.Expr +} + +// cost is used to make it easy to compare the cost of two plans with each other +type cost struct { + vindexCost int + isUnique bool + opCode engine.RouteOpcode +} + +// vindexPlusPredicates is a struct used to store all the predicates that the vindex can be used to query +type vindexPlusPredicates struct { + colVindex *vindexes.ColumnVindex + values []sqltypes.PlanValue + + // when we have the predicates found, we also know how to interact with this vindex + foundVindex vindexes.Vindex + opcode engine.RouteOpcode + predicates []sqlparser.Expr +} + +func (d *derivedTree) tableID() semantics.TableSet { + return d.inner.tableID() +} + +func (d *derivedTree) cost() int { + panic("implement me") +} + +func (d *derivedTree) clone() queryTree { + other := *d + other.inner = d.inner.clone() + return &other +} + +func (d 
*derivedTree) pushOutputColumns(names []*sqlparser.ColName, _ *semantics.SemTable) (offsets []int, err error) { + for _, name := range names { + offset, err := d.findOutputColumn(name) + if err != nil { + return nil, err + } + offsets = append(offsets, offset) + } + return +} + +func (d *derivedTree) findOutputColumn(name *sqlparser.ColName) (int, error) { + for j, exp := range d.query.SelectExprs { + ae, ok := exp.(*sqlparser.AliasedExpr) + if !ok { + return 0, vterrors.Errorf(vtrpcpb.Code_INTERNAL, "expected AliasedExpr") + } + if !ae.As.IsEmpty() && ae.As.Equal(name.Name) { + return j, nil + } + if ae.As.IsEmpty() { + col, ok := ae.Expr.(*sqlparser.ColName) + if !ok { + return 0, vterrors.Errorf(vtrpcpb.Code_UNIMPLEMENTED, "complex expression needs column alias: %s", sqlparser.String(ae)) + } + if name.Name.Equal(col.Name) { + return j, nil + } + } + } + return 0, vterrors.NewErrorf(vtrpcpb.Code_NOT_FOUND, vterrors.BadFieldError, "Unknown column '%s' in 'field list'", name.Name.String()) +} + // type assertions -var _ joinTree = (*routePlan)(nil) -var _ joinTree = (*joinPlan)(nil) +var _ queryTree = (*routeTree)(nil) +var _ queryTree = (*joinTree)(nil) +var _ queryTree = (*derivedTree)(nil) var _ relation = (*routeTable)(nil) var _ relation = (*joinTables)(nil) var _ relation = (parenTables)(nil) +var _ relation = (*derivedTable)(nil) + +func (d *derivedTable) tableID() semantics.TableSet { return d.tables.tableID() } + +func (d *derivedTable) tableNames() []string { return d.tables.tableNames() } func (rp *routeTable) tableID() semantics.TableSet { return rp.qtable.TableID } @@ -171,39 +256,51 @@ func (p parenTables) tableID() semantics.TableSet { return res } -// visit will traverse the route tables, going inside parenTables and visiting all routeTables -func visitTables(r relation, f func(tbl *routeTable) error) error { +// visitRelations visits all relations recursively and applies the function f on them +// If the function f returns false: the children of 
the current relation will not be visited. +func visitRelations(r relation, f func(tbl relation) (bool, error)) error { + + kontinue, err := f(r) + if err != nil { + return err + } + if !kontinue { + return nil + } + switch r := r.(type) { case *routeTable: - err := f(r) - if err != nil { - return err - } + // already visited when entering this method case parenTables: for _, r := range r { - err := visitTables(r, f) + err := visitRelations(r, f) if err != nil { return err } } return nil case *joinTables: - err := visitTables(r.lhs, f) + err := visitRelations(r.lhs, f) if err != nil { return err } - err = visitTables(r.rhs, f) + err = visitRelations(r.rhs, f) if err != nil { return err } return nil + case *derivedTable: + err := visitRelations(r.tables, f) + if err != nil { + return err + } } return nil } // clone returns a copy of the struct with copies of slices, // so changing the the contents of them will not be reflected in the original -func (rp *routePlan) clone() joinTree { +func (rp *routeTree) clone() queryTree { result := *rp result.vindexPreds = make([]*vindexPlusPredicates, len(rp.vindexPreds)) for i, pred := range rp.vindexPreds { @@ -214,17 +311,17 @@ func (rp *routePlan) clone() joinTree { return &result } -// tables implements the joinTree interface -func (rp *routePlan) tableID() semantics.TableSet { +// tableID implements the queryTree interface +func (rp *routeTree) tableID() semantics.TableSet { return rp.solved } -func (rp *routePlan) hasOuterjoins() bool { +func (rp *routeTree) hasOuterjoins() bool { return len(rp.leftJoins) > 0 } -// cost implements the joinTree interface -func (rp *routePlan) cost() int { +// cost implements the queryTree interface +func (rp *routeTree) cost() int { switch rp.routeOpCode { case // these op codes will never be compared with each other - they are assigned by a rule and not a comparison engine.SelectDBA, @@ -250,7 +347,7 @@ func (rp *routePlan) cost() int { // addPredicate adds these predicates added to it.
if the predicates can help, // they will improve the routeOpCode -func (rp *routePlan) addPredicate(predicates ...sqlparser.Expr) error { +func (rp *routeTree) addPredicate(predicates ...sqlparser.Expr) error { if rp.canImprove() { newVindexFound, err := rp.searchForNewVindexes(predicates) if err != nil { @@ -270,11 +367,11 @@ func (rp *routePlan) addPredicate(predicates ...sqlparser.Expr) error { } // canImprove returns true if additional predicates could help improving this plan -func (rp *routePlan) canImprove() bool { +func (rp *routeTree) canImprove() bool { return rp.routeOpCode != engine.SelectNone } -func (rp *routePlan) isImpossibleIN(node *sqlparser.ComparisonExpr) bool { +func (rp *routeTree) isImpossibleIN(node *sqlparser.ComparisonExpr) bool { switch nodeR := node.Right.(type) { case sqlparser.ValTuple: // WHERE col IN (null) @@ -286,7 +383,7 @@ func (rp *routePlan) isImpossibleIN(node *sqlparser.ComparisonExpr) bool { return false } -func (rp *routePlan) isImpossibleNotIN(node *sqlparser.ComparisonExpr) bool { +func (rp *routeTree) isImpossibleNotIN(node *sqlparser.ComparisonExpr) bool { switch node := node.Right.(type) { case sqlparser.ValTuple: for _, n := range node { @@ -300,7 +397,7 @@ func (rp *routePlan) isImpossibleNotIN(node *sqlparser.ComparisonExpr) bool { return false } -func (rp *routePlan) searchForNewVindexes(predicates []sqlparser.Expr) (bool, error) { +func (rp *routeTree) searchForNewVindexes(predicates []sqlparser.Expr) (bool, error) { newVindexFound := false for _, filter := range predicates { switch node := filter.(type) { @@ -367,7 +464,7 @@ func equalOrEqualUnique(vindex *vindexes.ColumnVindex) engine.RouteOpcode { return engine.SelectEqual } -func (rp *routePlan) planEqualOp(node *sqlparser.ComparisonExpr) (bool, error) { +func (rp *routeTree) planEqualOp(node *sqlparser.ComparisonExpr) (bool, error) { column, ok := node.Left.(*sqlparser.ColName) other := node.Right if !ok { @@ -386,7 +483,7 @@ func (rp *routePlan) 
planEqualOp(node *sqlparser.ComparisonExpr) (bool, error) { return rp.haveMatchingVindex(node, column, *val, equalOrEqualUnique, justTheVindex), err } -func (rp *routePlan) planSimpleInOp(node *sqlparser.ComparisonExpr, left *sqlparser.ColName) (bool, error) { +func (rp *routeTree) planSimpleInOp(node *sqlparser.ComparisonExpr, left *sqlparser.ColName) (bool, error) { value, err := sqlparser.NewPlanValue(node.Right) if err != nil { // if we are unable to create a PlanValue, we can't use a vindex, but we don't have to fail @@ -407,7 +504,7 @@ func (rp *routePlan) planSimpleInOp(node *sqlparser.ComparisonExpr, left *sqlpar return rp.haveMatchingVindex(node, left, value, opcode, justTheVindex), err } -func (rp *routePlan) planCompositeInOp(node *sqlparser.ComparisonExpr, left sqlparser.ValTuple) (bool, error) { +func (rp *routeTree) planCompositeInOp(node *sqlparser.ComparisonExpr, left sqlparser.ValTuple) (bool, error) { right, rightIsValTuple := node.Right.(sqlparser.ValTuple) if !rightIsValTuple { return false, nil @@ -415,7 +512,7 @@ func (rp *routePlan) planCompositeInOp(node *sqlparser.ComparisonExpr, left sqlp return rp.planCompositeInOpRecursive(node, left, right, nil) } -func (rp *routePlan) planCompositeInOpRecursive(node *sqlparser.ComparisonExpr, left, right sqlparser.ValTuple, coordinates []int) (bool, error) { +func (rp *routeTree) planCompositeInOpRecursive(node *sqlparser.ComparisonExpr, left, right sqlparser.ValTuple, coordinates []int) (bool, error) { foundVindex := false cindex := len(coordinates) coordinates = append(coordinates, 0) @@ -460,7 +557,7 @@ func (rp *routePlan) planCompositeInOpRecursive(node *sqlparser.ComparisonExpr, return foundVindex, nil } -func (rp *routePlan) planInOp(node *sqlparser.ComparisonExpr) (bool, error) { +func (rp *routeTree) planInOp(node *sqlparser.ComparisonExpr) (bool, error) { switch left := node.Left.(type) { case *sqlparser.ColName: return rp.planSimpleInOp(node, left) @@ -470,7 +567,7 @@ func (rp *routePlan) 
planInOp(node *sqlparser.ComparisonExpr) (bool, error) { return false, nil } -func (rp *routePlan) planLikeOp(node *sqlparser.ComparisonExpr) (bool, error) { +func (rp *routeTree) planLikeOp(node *sqlparser.ComparisonExpr) (bool, error) { column, ok := node.Left.(*sqlparser.ColName) if !ok { return false, nil @@ -494,7 +591,7 @@ func (rp *routePlan) planLikeOp(node *sqlparser.ComparisonExpr) (bool, error) { return rp.haveMatchingVindex(node, column, *val, selectEqual, vdx), err } -func (rp *routePlan) planIsExpr(node *sqlparser.IsExpr) (bool, error) { +func (rp *routeTree) planIsExpr(node *sqlparser.IsExpr) (bool, error) { // we only handle IS NULL correct. IsExpr can contain other expressions as well if node.Right != sqlparser.IsNullOp { return false, nil @@ -524,7 +621,7 @@ func makePlanValue(n sqlparser.Expr) (*sqltypes.PlanValue, error) { return &value, nil } -func (rp routePlan) hasVindex(column *sqlparser.ColName) bool { +func (rp routeTree) hasVindex(column *sqlparser.ColName) bool { for _, v := range rp.vindexPreds { for _, col := range v.colVindex.Columns { if column.Name.Equal(col) { @@ -535,7 +632,7 @@ func (rp routePlan) hasVindex(column *sqlparser.ColName) bool { return false } -func (rp *routePlan) haveMatchingVindex( +func (rp *routeTree) haveMatchingVindex( node sqlparser.Expr, column *sqlparser.ColName, value sqltypes.PlanValue, @@ -566,7 +663,7 @@ func (rp *routePlan) haveMatchingVindex( } // pickBestAvailableVindex goes over the available vindexes for this route and picks the best one available. 
-func (rp *routePlan) pickBestAvailableVindex() { +func (rp *routeTree) pickBestAvailableVindex() { for _, v := range rp.vindexPreds { if v.foundVindex == nil { continue @@ -584,16 +681,23 @@ func (rp *routePlan) pickBestAvailableVindex() { } // Predicates takes all known predicates for this route and ANDs them together -func (rp *routePlan) Predicates() sqlparser.Expr { +func (rp *routeTree) Predicates() sqlparser.Expr { predicates := rp.predicates - _ = visitTables(rp.tables, func(tbl *routeTable) error { - predicates = append(predicates, tbl.qtable.Predicates...) - return nil + _ = visitRelations(rp.tables, func(tbl relation) (bool, error) { + switch tbl := tbl.(type) { + case *routeTable: + predicates = append(predicates, tbl.qtable.Predicates...) + case *derivedTable: + // no need to copy the inner predicates to the outside + return false, nil + } + + return true, nil }) return sqlparser.AndExpressions(predicates...) } -func (rp *routePlan) pushOutputColumns(col []*sqlparser.ColName, _ *semantics.SemTable) []int { +func (rp *routeTree) pushOutputColumns(col []*sqlparser.ColName, _ *semantics.SemTable) ([]int, error) { idxs := make([]int, len(col)) outer: for i, newCol := range col { @@ -606,19 +710,19 @@ outer: idxs[i] = len(rp.columns) rp.columns = append(rp.columns, newCol) } - return idxs + return idxs, nil } -func (jp *joinPlan) tableID() semantics.TableSet { +func (jp *joinTree) tableID() semantics.TableSet { return jp.lhs.tableID() | jp.rhs.tableID() } -func (jp *joinPlan) cost() int { +func (jp *joinTree) cost() int { return jp.lhs.cost() + jp.rhs.cost() } -func (jp *joinPlan) clone() joinTree { - result := &joinPlan{ +func (jp *joinTree) clone() queryTree { + result := &joinTree{ lhs: jp.lhs.clone(), rhs: jp.rhs.clone(), outer: jp.outer, @@ -626,12 +730,17 @@ func (jp *joinPlan) clone() joinTree { return result } -func (jp *joinPlan) pushOutputColumns(columns []*sqlparser.ColName, semTable *semantics.SemTable) []int { +/* + +select id, t2.b from t1 , 
(select b from t2) t2 where t1.id = 1 +*/ + +func (jp *joinTree) pushOutputColumns(columns []*sqlparser.ColName, semTable *semantics.SemTable) ([]int, error) { var toTheLeft []bool var lhs, rhs []*sqlparser.ColName for _, col := range columns { col.Qualifier.Qualifier = sqlparser.NewTableIdent("") - if semTable.Dependencies(col).IsSolvedBy(jp.lhs.tableID()) { + if semTable.GetBaseTableDependencies(col).IsSolvedBy(jp.lhs.tableID()) { lhs = append(lhs, col) toTheLeft = append(toTheLeft, true) } else { @@ -639,8 +748,15 @@ func (jp *joinPlan) pushOutputColumns(columns []*sqlparser.ColName, semTable *se toTheLeft = append(toTheLeft, false) } } - lhsOffset := jp.lhs.pushOutputColumns(lhs, semTable) - rhsOffset := jp.rhs.pushOutputColumns(rhs, semTable) + lhsOffset, err := jp.lhs.pushOutputColumns(lhs, semTable) + if err != nil { + return nil, err + } + rhsOffset, err := jp.rhs.pushOutputColumns(rhs, semTable) + if err != nil { + return nil, err + } + outputColumns := make([]int, len(toTheLeft)) var l, r int for i, isLeft := range toTheLeft { @@ -653,7 +769,7 @@ func (jp *joinPlan) pushOutputColumns(columns []*sqlparser.ColName, semTable *se r++ } } - return outputColumns + return outputColumns, nil } // costFor returns a cost struct to make route choices easier to compare diff --git a/go/vt/vtgate/planbuilder/jointree_transformers.go b/go/vt/vtgate/planbuilder/jointree_transformers.go index aff58d9d786..89b7c362dcc 100644 --- a/go/vt/vtgate/planbuilder/jointree_transformers.go +++ b/go/vt/vtgate/planbuilder/jointree_transformers.go @@ -29,24 +29,57 @@ import ( "vitess.io/vitess/go/vt/vterrors" ) -func transformToLogicalPlan(tree joinTree, semTable *semantics.SemTable) (logicalPlan, error) { +func transformToLogicalPlan(tree queryTree, semTable *semantics.SemTable, processing *postProcessor) (logicalPlan, error) { switch n := tree.(type) { - case *routePlan: + case *routeTree: return transformRoutePlan(n, semTable) - case *joinPlan: - return transformJoinPlan(n, semTable) 
+ case *joinTree: + return transformJoinPlan(n, semTable, processing) + + case *derivedTree: + // transforming the inner part of the derived table into a logical plan + // so that we can do horizon planning on the inner. If the logical plan + // we've produced is a Route, we set its Select.From field to be an aliased + // expression containing our derived table's inner select and the derived + // table's alias. + + plan, err := transformToLogicalPlan(n.inner, semTable, processing) + if err != nil { + return nil, err + } + processing.inDerived = true + plan, err = processing.planHorizon(plan, n.query) + if err != nil { + return nil, err + } + processing.inDerived = false + + rb, isRoute := plan.(*route) + if !isRoute { + return plan, nil + } + innerSelect := rb.Select + derivedTable := &sqlparser.DerivedTable{Select: innerSelect} + tblExpr := &sqlparser.AliasedTableExpr{ + Expr: derivedTable, + As: sqlparser.NewTableIdent(n.alias), + } + rb.Select = &sqlparser.Select{ + From: []sqlparser.TableExpr{tblExpr}, + } + return plan, nil } return nil, vterrors.Errorf(vtrpcpb.Code_INTERNAL, "[BUG] unknown type encountered: %T", tree) } -func transformJoinPlan(n *joinPlan, semTable *semantics.SemTable) (logicalPlan, error) { - lhs, err := transformToLogicalPlan(n.lhs, semTable) +func transformJoinPlan(n *joinTree, semTable *semantics.SemTable, processing *postProcessor) (logicalPlan, error) { + lhs, err := transformToLogicalPlan(n.lhs, semTable, processing) if err != nil { return nil, err } - rhs, err := transformToLogicalPlan(n.rhs, semTable) + rhs, err := transformToLogicalPlan(n.rhs, semTable, processing) if err != nil { return nil, err } @@ -63,7 +96,7 @@ func transformJoinPlan(n *joinPlan, semTable *semantics.SemTable) (logicalPlan, }, nil } -func transformRoutePlan(n *routePlan, semTable *semantics.SemTable) (*route, error) { +func transformRoutePlan(n *routeTree, semTable *semantics.SemTable) (*route, error) { var tablesForSelect sqlparser.TableExprs tableNameMap := 
map[string]interface{}{} @@ -206,6 +239,26 @@ func relToTableExpr(t relation) (sqlparser.TableExpr, error) { On: t.pred, }, }, nil + case *derivedTable: + innerTables, err := relToTableExpr(t.tables) + if err != nil { + return nil, err + } + tbls := innerTables.(*sqlparser.ParenTableExpr) + + sel := &sqlparser.Select{ + SelectExprs: t.query.SelectExprs, + From: tbls.Exprs, + Where: &sqlparser.Where{Expr: sqlparser.AndExpressions(t.predicates...)}, + } + expr := &sqlparser.DerivedTable{ + Select: sel, + } + return &sqlparser.AliasedTableExpr{ + Expr: expr, + Partitions: nil, + As: sqlparser.NewTableIdent(t.alias), + }, nil default: return nil, vterrors.Errorf(vtrpcpb.Code_INTERNAL, "[BUG] unknown relation type: %T", t) } diff --git a/go/vt/vtgate/planbuilder/route.go b/go/vt/vtgate/planbuilder/route.go index 418d083e8f9..497b26ec67c 100644 --- a/go/vt/vtgate/planbuilder/route.go +++ b/go/vt/vtgate/planbuilder/route.go @@ -132,15 +132,16 @@ func (rb *route) SetLimit(limit *sqlparser.Limit) { rb.Select.SetLimit(limit) } -// Wireup2 implements the logicalPlan interface +// WireupGen4 implements the logicalPlan interface func (rb *route) WireupGen4(semTable *semantics.SemTable) error { rb.prepareTheAST() rb.eroute.Query = sqlparser.String(rb.Select) - buffer := sqlparser.NewTrackedBuffer(nil) - sqlparser.FormatImpossibleQuery(buffer, rb.Select) - rb.eroute.FieldQuery = buffer.ParsedQuery().Query + buffer := sqlparser.NewTrackedBuffer(sqlparser.FormatImpossibleQuery) + node := buffer.WriteNode(rb.Select) + query := node.ParsedQuery() + rb.eroute.FieldQuery = query.Query return nil } diff --git a/go/vt/vtgate/planbuilder/route_planning.go b/go/vt/vtgate/planbuilder/route_planning.go index f9807528001..fb40d1064c8 100644 --- a/go/vt/vtgate/planbuilder/route_planning.go +++ b/go/vt/vtgate/planbuilder/route_planning.go @@ -62,6 +62,34 @@ func gen4Planner(_ string) func(sqlparser.Statement, *sqlparser.ReservedVars, Co } } +type postProcessor struct { + inDerived bool + 
semTable *semantics.SemTable + vschema ContextVSchema +} + +func (pp *postProcessor) planHorizon(plan logicalPlan, sel *sqlparser.Select) (logicalPlan, error) { + hp := horizonPlanning{ + sel: sel, + plan: plan, + semTable: pp.semTable, + vschema: pp.vschema, + inDerived: pp.inDerived, + } + + plan, err := hp.planHorizon() + if err != nil { + return nil, err + } + + plan, err = planLimit(sel.Limit, plan) + if err != nil { + return nil, err + } + return plan, nil + +} + func newBuildSelectPlan(sel *sqlparser.Select, reservedVars *sqlparser.ReservedVars, vschema ContextVSchema) (logicalPlan, error) { ksName := "" if ks, _ := vschema.DefaultKeyspace(); ks != nil { @@ -87,24 +115,16 @@ func newBuildSelectPlan(sel *sqlparser.Select, reservedVars *sqlparser.ReservedV return nil, err } - plan, err := transformToLogicalPlan(tree, semTable) - if err != nil { - return nil, err - } - - hp := horizonPlanning{ - sel: sel, - plan: plan, + postProcessing := &postProcessor{ semTable: semTable, vschema: vschema, } - - plan, err = hp.planHorizon() + plan, err := transformToLogicalPlan(tree, semTable, postProcessing) if err != nil { return nil, err } - plan, err = planLimit(sel.Limit, plan) + plan, err = postProcessing.planHorizon(plan, sel) if err != nil { return nil, err } @@ -127,7 +147,7 @@ func newBuildSelectPlan(sel *sqlparser.Select, reservedVars *sqlparser.ReservedV return plan, nil } -func optimizeQuery(opTree abstract.Operator, reservedVars *sqlparser.ReservedVars, semTable *semantics.SemTable, vschema ContextVSchema) (joinTree, error) { +func optimizeQuery(opTree abstract.Operator, reservedVars *sqlparser.ReservedVars, semTable *semantics.SemTable, vschema ContextVSchema) (queryTree, error) { switch op := opTree.(type) { case *abstract.QueryGraph: switch { @@ -156,7 +176,16 @@ func optimizeQuery(opTree abstract.Operator, reservedVars *sqlparser.ReservedVar return nil, err } return mergeOrJoin(treeInner, treeOuter, []sqlparser.Expr{op.Exp}, semTable, true) - + case 
*abstract.Derived: + treeInner, err := optimizeQuery(op.Inner, reservedVars, semTable, vschema) + if err != nil { + return nil, err + } + return &derivedTree{ + query: op.Sel, + inner: treeInner, + alias: op.Alias, + }, nil default: return nil, semantics.Gen4NotSupportedF("optimizeQuery") } @@ -191,6 +220,7 @@ type horizonPlanning struct { semTable *semantics.SemTable vschema ContextVSchema qp *abstract.QueryProjection + inDerived bool needsTruncation bool vtgateGrouping bool } @@ -201,6 +231,14 @@ func (hp *horizonPlanning) planHorizon() (logicalPlan, error) { return nil, hp.semTable.ProjectionErr } + if hp.inDerived { + for _, expr := range hp.sel.SelectExprs { + if sqlparser.ContainsAggregation(expr) { + return nil, semantics.Gen4NotSupportedF("aggregation inside of derived table") + } + } + } + if ok && rb.isSingleShard() { createSingleShardRoutePlan(hp.sel, rb) return hp.plan, nil @@ -283,7 +321,7 @@ func exprHasUniqueVindex(vschema ContextVSchema, semTable *semantics.SemTable, e if !isCol { return false } - ts := semTable.Dependencies(expr) + ts := semTable.GetBaseTableDependencies(expr) tableInfo, err := semTable.TableInfoFor(ts) if err != nil { return false @@ -328,18 +366,18 @@ func checkUnsupportedConstructs(sel *sqlparser.Select) error { return nil } -func pushJoinPredicate(exprs []sqlparser.Expr, tree joinTree, semTable *semantics.SemTable) (joinTree, error) { +func pushJoinPredicate(exprs []sqlparser.Expr, tree queryTree, semTable *semantics.SemTable) (queryTree, error) { switch node := tree.(type) { - case *routePlan: - plan := node.clone().(*routePlan) + case *routeTree: + plan := node.clone().(*routeTree) err := plan.addPredicate(exprs...) 
if err != nil { return nil, err } return plan, nil - case *joinPlan: - node = node.clone().(*joinPlan) + case *joinTree: + node = node.clone().(*joinTree) // we break up the predicates so that colnames from the LHS are replaced by arguments var rhsPreds []sqlparser.Expr @@ -359,11 +397,34 @@ func pushJoinPredicate(exprs []sqlparser.Expr, tree joinTree, semTable *semantic return nil, err } - return &joinPlan{ + return &joinTree{ lhs: node.lhs, rhs: rhsPlan, outer: node.outer, }, nil + case *derivedTree: + plan := node.clone().(*derivedTree) + + newExpressions := make([]sqlparser.Expr, 0, len(exprs)) + for _, expr := range exprs { + tblInfo, err := semTable.TableInfoForExpr(expr) + if err != nil { + return nil, err + } + rewritten, err := semantics.RewriteDerivedExpression(expr, tblInfo) + if err != nil { + return nil, err + } + newExpressions = append(newExpressions, rewritten) + } + + newInner, err := pushJoinPredicate(newExpressions, plan.inner, semTable) + if err != nil { + return nil, err + } + + plan.inner = newInner + return plan, nil default: panic(fmt.Sprintf("BUG: unknown type %T", node)) } @@ -374,7 +435,7 @@ func breakPredicateInLHSandRHS(expr sqlparser.Expr, semTable *semantics.SemTable _ = sqlparser.Rewrite(predicate, nil, func(cursor *sqlparser.Cursor) bool { switch node := cursor.Node().(type) { case *sqlparser.ColName: - deps := semTable.Dependencies(node) + deps := semTable.GetBaseTableDependencies(node) if deps == 0 { err = vterrors.Errorf(vtrpcpb.Code_INTERNAL, "unknown column. 
has the AST been copied?") return false @@ -394,17 +455,17 @@ func breakPredicateInLHSandRHS(expr sqlparser.Expr, semTable *semantics.SemTable return } -func mergeOrJoinInner(lhs, rhs joinTree, joinPredicates []sqlparser.Expr, semTable *semantics.SemTable) (joinTree, error) { +func mergeOrJoinInner(lhs, rhs queryTree, joinPredicates []sqlparser.Expr, semTable *semantics.SemTable) (queryTree, error) { return mergeOrJoin(lhs, rhs, joinPredicates, semTable, true) } -func mergeOrJoin(lhs, rhs joinTree, joinPredicates []sqlparser.Expr, semTable *semantics.SemTable, inner bool) (joinTree, error) { +func mergeOrJoin(lhs, rhs queryTree, joinPredicates []sqlparser.Expr, semTable *semantics.SemTable, inner bool) (queryTree, error) { newPlan := tryMerge(lhs, rhs, joinPredicates, semTable, inner) if newPlan != nil { return newPlan, nil } - tree := &joinPlan{lhs: lhs.clone(), rhs: rhs.clone(), outer: !inner} + tree := &joinTree{lhs: lhs.clone(), rhs: rhs.clone(), outer: !inner} return pushJoinPredicate(joinPredicates, tree, semTable) } @@ -412,7 +473,7 @@ type ( tableSetPair struct { left, right semantics.TableSet } - cacheMap map[tableSetPair]joinTree + cacheMap map[tableSetPair]queryTree ) /* @@ -421,7 +482,7 @@ type ( and removes the two inputs to this cheapest plan and instead adds the join. 
As an optimization, it first only considers joining tables that have predicates defined between them */ -func greedySolve(qg *abstract.QueryGraph, reservedVars *sqlparser.ReservedVars, semTable *semantics.SemTable, vschema ContextVSchema) (joinTree, error) { +func greedySolve(qg *abstract.QueryGraph, reservedVars *sqlparser.ReservedVars, semTable *semantics.SemTable, vschema ContextVSchema) (queryTree, error) { joinTrees, err := seedPlanList(qg, reservedVars, semTable, vschema) planCache := cacheMap{} if err != nil { @@ -435,7 +496,7 @@ func greedySolve(qg *abstract.QueryGraph, reservedVars *sqlparser.ReservedVars, return tree, nil } -func mergeJoinTrees(qg *abstract.QueryGraph, semTable *semantics.SemTable, joinTrees []joinTree, planCache cacheMap, crossJoinsOK bool) (joinTree, error) { +func mergeJoinTrees(qg *abstract.QueryGraph, semTable *semantics.SemTable, joinTrees []queryTree, planCache cacheMap, crossJoinsOK bool) (queryTree, error) { if len(joinTrees) == 0 { return nil, nil } @@ -465,7 +526,7 @@ func mergeJoinTrees(qg *abstract.QueryGraph, semTable *semantics.SemTable, joinT return joinTrees[0], nil } -func (cm cacheMap) getJoinTreeFor(lhs, rhs joinTree, joinPredicates []sqlparser.Expr, semTable *semantics.SemTable) (joinTree, error) { +func (cm cacheMap) getJoinTreeFor(lhs, rhs queryTree, joinPredicates []sqlparser.Expr, semTable *semantics.SemTable) (queryTree, error) { solves := tableSetPair{left: lhs.tableID(), right: rhs.tableID()} cachedPlan := cm[solves] if cachedPlan != nil { @@ -483,10 +544,10 @@ func (cm cacheMap) getJoinTreeFor(lhs, rhs joinTree, joinPredicates []sqlparser. 
func findBestJoinTree( qg *abstract.QueryGraph, semTable *semantics.SemTable, - plans []joinTree, + plans []queryTree, planCache cacheMap, crossJoinsOK bool, -) (bestPlan joinTree, lIdx int, rIdx int, err error) { +) (bestPlan queryTree, lIdx int, rIdx int, err error) { for i, lhs := range plans { for j, rhs := range plans { if i == j { @@ -514,13 +575,13 @@ func findBestJoinTree( return bestPlan, lIdx, rIdx, nil } -func leftToRightSolve(qg *abstract.QueryGraph, reservedVars *sqlparser.ReservedVars, semTable *semantics.SemTable, vschema ContextVSchema) (joinTree, error) { +func leftToRightSolve(qg *abstract.QueryGraph, reservedVars *sqlparser.ReservedVars, semTable *semantics.SemTable, vschema ContextVSchema) (queryTree, error) { plans, err := seedPlanList(qg, reservedVars, semTable, vschema) if err != nil { return nil, err } - var acc joinTree + var acc queryTree for _, plan := range plans { if acc == nil { acc = plan @@ -536,9 +597,9 @@ func leftToRightSolve(qg *abstract.QueryGraph, reservedVars *sqlparser.ReservedV return acc, nil } -// seedPlanList returns a routePlan for each table in the qg -func seedPlanList(qg *abstract.QueryGraph, reservedVars *sqlparser.ReservedVars, semTable *semantics.SemTable, vschema ContextVSchema) ([]joinTree, error) { - plans := make([]joinTree, len(qg.Tables)) +// seedPlanList returns a routeTree for each table in the qg +func seedPlanList(qg *abstract.QueryGraph, reservedVars *sqlparser.ReservedVars, semTable *semantics.SemTable, vschema ContextVSchema) ([]queryTree, error) { + plans := make([]queryTree, len(qg.Tables)) // we start by seeding the table with the single routes for i, table := range qg.Tables { @@ -555,17 +616,17 @@ func seedPlanList(qg *abstract.QueryGraph, reservedVars *sqlparser.ReservedVars, return plans, nil } -func removeAt(plans []joinTree, idx int) []joinTree { +func removeAt(plans []queryTree, idx int) []queryTree { return append(plans[:idx], plans[idx+1:]...) 
} -func createRoutePlan(table *abstract.QueryTable, solves semantics.TableSet, reservedVars *sqlparser.ReservedVars, vschema ContextVSchema) (*routePlan, error) { +func createRoutePlan(table *abstract.QueryTable, solves semantics.TableSet, reservedVars *sqlparser.ReservedVars, vschema ContextVSchema) (*routeTree, error) { if table.IsInfSchema { ks, err := vschema.AnyKeyspace() if err != nil { return nil, err } - rp := &routePlan{ + rp := &routeTree{ routeOpCode: engine.SelectDBA, solved: solves, keyspace: ks, @@ -604,7 +665,7 @@ func createRoutePlan(table *abstract.QueryTable, solves semantics.TableSet, rese table.Alias.As = sqlparser.NewTableIdent(name.String()) } } - plan := &routePlan{ + plan := &routeTree{ solved: solves, tables: []relation{&routeTable{ qtable: table, @@ -641,32 +702,36 @@ func createRoutePlan(table *abstract.QueryTable, solves semantics.TableSet, rese return plan, nil } -func findColumnVindex(a *routePlan, exp sqlparser.Expr, sem *semantics.SemTable) vindexes.SingleColumn { +func findColumnVindex(a *routeTree, exp sqlparser.Expr, sem *semantics.SemTable) vindexes.SingleColumn { left, isCol := exp.(*sqlparser.ColName) if !isCol { return nil } - leftDep := sem.Dependencies(left) + leftDep := sem.GetBaseTableDependencies(left) var singCol vindexes.SingleColumn - _ = visitTables(a.tables, func(table *routeTable) error { - if leftDep.IsSolvedBy(table.qtable.TableID) { - for _, vindex := range table.vtable.ColumnVindexes { + _ = visitRelations(a.tables, func(rel relation) (bool, error) { + rb, isRoute := rel.(*routeTable) + if !isRoute { + return true, nil + } + if leftDep.IsSolvedBy(rb.qtable.TableID) { + for _, vindex := range rb.vtable.ColumnVindexes { sC, isSingle := vindex.Vindex.(vindexes.SingleColumn) if isSingle && vindex.Columns[0].Equal(left.Name) { singCol = sC - return io.EOF + return false, io.EOF } } } - return nil + return false, nil }) return singCol } -func canMergeOnFilter(a, b *routePlan, predicate sqlparser.Expr, sem 
*semantics.SemTable) bool { +func canMergeOnFilter(a, b *routeTree, predicate sqlparser.Expr, sem *semantics.SemTable) bool { comparison, ok := predicate.(*sqlparser.ComparisonExpr) if !ok { return false @@ -692,7 +757,7 @@ func canMergeOnFilter(a, b *routePlan, predicate sqlparser.Expr, sem *semantics. return rVindex == lVindex } -func canMergeOnFilters(a, b *routePlan, joinPredicates []sqlparser.Expr, semTable *semantics.SemTable) bool { +func canMergeOnFilters(a, b *routeTree, joinPredicates []sqlparser.Expr, semTable *semantics.SemTable) bool { for _, predicate := range joinPredicates { for _, expr := range sqlparser.SplitAndExpression(nil, predicate) { if canMergeOnFilter(a, b, expr, semTable) { @@ -703,7 +768,7 @@ func canMergeOnFilters(a, b *routePlan, joinPredicates []sqlparser.Expr, semTabl return false } -func tryMerge(a, b joinTree, joinPredicates []sqlparser.Expr, semTable *semantics.SemTable, inner bool) joinTree { +func tryMerge(a, b queryTree, joinPredicates []sqlparser.Expr, semTable *semantics.SemTable, inner bool) queryTree { aRoute, bRoute := joinTreesToRoutes(a, b) if aRoute == nil || bRoute == nil { return nil @@ -716,7 +781,7 @@ func tryMerge(a, b joinTree, joinPredicates []sqlparser.Expr, semTable *semantic newTabletSet := aRoute.solved | bRoute.solved - var r *routePlan + var r *routeTree if inner { r = createRoutePlanForInner(aRoute, bRoute, newTabletSet, joinPredicates) } else { @@ -752,19 +817,50 @@ func tryMerge(a, b joinTree, joinPredicates []sqlparser.Expr, semTable *semantic return r } -func joinTreesToRoutes(a, b joinTree) (*routePlan, *routePlan) { - aRoute, ok := a.(*routePlan) +func makeRoute(j queryTree) *routeTree { + rb, ok := j.(*routeTree) + if ok { + return rb + } + + x, ok := j.(*derivedTree) if !ok { + return nil + } + dp := x.clone().(*derivedTree) + + inner := makeRoute(dp.inner) + if inner == nil { + return nil + } + + dt := &derivedTable{ + tables: inner.tables, + query: dp.query, + predicates: inner.predicates, + 
leftJoins: inner.leftJoins, + alias: dp.alias, + } + + inner.tables = parenTables{dt} + inner.predicates = nil + inner.leftJoins = nil + return inner +} + +func joinTreesToRoutes(a, b queryTree) (*routeTree, *routeTree) { + aRoute := makeRoute(a) + if aRoute == nil { return nil, nil } - bRoute, ok := b.(*routePlan) - if !ok { + bRoute := makeRoute(b) + if bRoute == nil { return nil, nil } return aRoute, bRoute } -func createRoutePlanForInner(aRoute *routePlan, bRoute *routePlan, newTabletSet semantics.TableSet, joinPredicates []sqlparser.Expr) *routePlan { +func createRoutePlanForInner(aRoute *routeTree, bRoute *routeTree, newTabletSet semantics.TableSet, joinPredicates []sqlparser.Expr) *routeTree { var tables parenTables if !aRoute.hasOuterjoins() { tables = append(aRoute.tables, bRoute.tables...) @@ -782,7 +878,7 @@ func createRoutePlanForInner(aRoute *routePlan, bRoute *routePlan, newTabletSet } } - return &routePlan{ + return &routeTree{ routeOpCode: aRoute.routeOpCode, solved: newTabletSet, tables: tables, @@ -814,12 +910,12 @@ func findTables(deps semantics.TableSet, tables parenTables) (relation, relation return nil, nil, tables } -func createRoutePlanForOuter(aRoute, bRoute *routePlan, semTable *semantics.SemTable, newTabletSet semantics.TableSet, joinPredicates []sqlparser.Expr) *routePlan { +func createRoutePlanForOuter(aRoute, bRoute *routeTree, semTable *semantics.SemTable, newTabletSet semantics.TableSet, joinPredicates []sqlparser.Expr) *routeTree { // create relation slice with all tables tables := bRoute.tables // we are doing an outer join where the outer part contains multiple tables - we have to turn the outer part into a join or two for _, predicate := range bRoute.predicates { - deps := semTable.Dependencies(predicate) + deps := semTable.GetBaseTableDependencies(predicate) aTbl, bTbl, newTables := findTables(deps, tables) tables = newTables if aTbl != nil && bTbl != nil { @@ -839,7 +935,7 @@ func createRoutePlanForOuter(aRoute, bRoute 
*routePlan, semTable *semantics.SemT outer = tables } - return &routePlan{ + return &routeTree{ routeOpCode: aRoute.routeOpCode, solved: newTabletSet, tables: aRoute.tables, diff --git a/go/vt/vtgate/planbuilder/route_planning_test.go b/go/vt/vtgate/planbuilder/route_planning_test.go index e392ae25035..a5372759423 100644 --- a/go/vt/vtgate/planbuilder/route_planning_test.go +++ b/go/vt/vtgate/planbuilder/route_planning_test.go @@ -31,23 +31,23 @@ import ( "vitess.io/vitess/go/vt/vtgate/vindexes" ) -func unsharded(solved semantics.TableSet, keyspace *vindexes.Keyspace) *routePlan { - return &routePlan{ +func unsharded(solved semantics.TableSet, keyspace *vindexes.Keyspace) *routeTree { + return &routeTree{ routeOpCode: engine.SelectUnsharded, solved: solved, keyspace: keyspace, } } -func selectDBA(solved semantics.TableSet, keyspace *vindexes.Keyspace) *routePlan { - return &routePlan{ +func selectDBA(solved semantics.TableSet, keyspace *vindexes.Keyspace) *routeTree { + return &routeTree{ routeOpCode: engine.SelectDBA, solved: solved, keyspace: keyspace, } } -func selectScatter(solved semantics.TableSet, keyspace *vindexes.Keyspace) *routePlan { - return &routePlan{ +func selectScatter(solved semantics.TableSet, keyspace *vindexes.Keyspace) *routeTree { + return &routeTree{ routeOpCode: engine.SelectScatter, solved: solved, keyspace: keyspace, @@ -59,7 +59,7 @@ func TestMergeJoins(t *testing.T) { ks2 := &vindexes.Keyspace{Name: "banan", Sharded: false} type testCase struct { - l, r, expected joinTree + l, r, expected queryTree predicates []sqlparser.Expr } @@ -108,14 +108,14 @@ func TestMergeJoins(t *testing.T) { } func TestClone(t *testing.T) { - original := &routePlan{ + original := &routeTree{ routeOpCode: engine.SelectEqualUnique, vindexPreds: []*vindexPlusPredicates{{}}, } clone := original.clone() - clonedRP := clone.(*routePlan) + clonedRP := clone.(*routeTree) clonedRP.routeOpCode = engine.SelectDBA assert.Equal(t, clonedRP.routeOpCode, engine.SelectDBA) 
assert.Equal(t, original.routeOpCode, engine.SelectEqualUnique) @@ -148,7 +148,7 @@ func TestCreateRoutePlanForOuter(t *testing.T) { }, vtable: &vindexes.Table{}, } - a := &routePlan{ + a := &routeTree{ routeOpCode: engine.SelectUnsharded, solved: semantics.TableSet(1), tables: []relation{m1}, @@ -159,7 +159,7 @@ func TestCreateRoutePlanForOuter(t *testing.T) { col2 := sqlparser.NewColNameWithQualifier("id", sqlparser.TableName{ Name: sqlparser.NewTableIdent("m2"), }) - b := &routePlan{ + b := &routeTree{ routeOpCode: engine.SelectUnsharded, solved: semantics.TableSet(6), tables: []relation{m2, m3}, diff --git a/go/vt/vtgate/planbuilder/system_tables.go b/go/vt/vtgate/planbuilder/system_tables.go index 136918ff1fd..a6cc18d01ce 100644 --- a/go/vt/vtgate/planbuilder/system_tables.go +++ b/go/vt/vtgate/planbuilder/system_tables.go @@ -44,7 +44,7 @@ func (pb *primitiveBuilder) findSysInfoRoutingPredicates(expr sqlparser.Expr, ru return nil } -func (rp *routePlan) findSysInfoRoutingPredicatesGen4(reservedVars *sqlparser.ReservedVars) error { +func (rp *routeTree) findSysInfoRoutingPredicatesGen4(reservedVars *sqlparser.ReservedVars) error { for _, pred := range rp.predicates { isTableSchema, bvName, out, err := extractInfoSchemaRoutingPredicate(pred, reservedVars) if err != nil { diff --git a/go/vt/vtgate/planbuilder/testdata/from_cases.txt b/go/vt/vtgate/planbuilder/testdata/from_cases.txt index 57a2fd9b4d0..21ea09210f7 100644 --- a/go/vt/vtgate/planbuilder/testdata/from_cases.txt +++ b/go/vt/vtgate/planbuilder/testdata/from_cases.txt @@ -1578,7 +1578,7 @@ Gen4 plan same as above } Gen4 plan same as above -# subquery +# derived table "select id from (select id, col from user where id = 5) as t" { "QueryType": "SELECT", @@ -1599,8 +1599,9 @@ Gen4 plan same as above "Vindex": "user_index" } } +Gen4 plan same as above -# subquery with join +# derived table with join "select t.id from (select id from user where id = 5) as t join user_extra on t.id = user_extra.user_id" { 
"QueryType": "SELECT", @@ -1621,8 +1622,27 @@ Gen4 plan same as above "Vindex": "user_index" } } +{ + "QueryType": "SELECT", + "Original": "select t.id from (select id from user where id = 5) as t join user_extra on t.id = user_extra.user_id", + "Instructions": { + "OperatorType": "Route", + "Variant": "SelectEqualUnique", + "Keyspace": { + "Name": "user", + "Sharded": true + }, + "FieldQuery": "select t.id from (select id from `user` where 1 != 1) as t, user_extra where 1 != 1", + "Query": "select t.id from (select id from `user` where id = 5) as t, user_extra where t.id = user_extra.user_id", + "Table": "`user`, user_extra", + "Values": [ + 5 + ], + "Vindex": "user_index" + } +} -# subquery with join, and aliased references +# derived table with join, and aliased references "select t.id from (select user.id from user where user.id = 5) as t join user_extra on t.id = user_extra.user_id" { "QueryType": "SELECT", @@ -1643,12 +1663,32 @@ Gen4 plan same as above "Vindex": "user_index" } } +{ + "QueryType": "SELECT", + "Original": "select t.id from (select user.id from user where user.id = 5) as t join user_extra on t.id = user_extra.user_id", + "Instructions": { + "OperatorType": "Route", + "Variant": "SelectEqualUnique", + "Keyspace": { + "Name": "user", + "Sharded": true + }, + "FieldQuery": "select t.id from (select `user`.id from `user` where 1 != 1) as t, user_extra where 1 != 1", + "Query": "select t.id from (select `user`.id from `user` where `user`.id = 5) as t, user_extra where t.id = user_extra.user_id", + "Table": "`user`, user_extra", + "Values": [ + 5 + ], + "Vindex": "user_index" + } +} -# subquery with join, duplicate columns +# derived table with join, duplicate columns "select t.id from (select user.id, id from user where user.id = 5) as t join user_extra on t.id = user_extra.user_id" "duplicate column aliases: id" +Gen4 error: Duplicate column name 'id' -# subquery in RHS of join +# derived table in RHS of join "select t.id from user_extra join 
(select id from user where id = 5) as t on t.id = user_extra.user_id" { "QueryType": "SELECT", @@ -1665,8 +1705,27 @@ Gen4 plan same as above "Table": "user_extra, `user`" } } +{ + "QueryType": "SELECT", + "Original": "select t.id from user_extra join (select id from user where id = 5) as t on t.id = user_extra.user_id", + "Instructions": { + "OperatorType": "Route", + "Variant": "SelectEqualUnique", + "Keyspace": { + "Name": "user", + "Sharded": true + }, + "FieldQuery": "select t.id from user_extra, (select id from `user` where 1 != 1) as t where 1 != 1", + "Query": "select t.id from user_extra, (select id from `user` where id = 5) as t where t.id = user_extra.user_id", + "Table": "`user`, user_extra", + "Values": [ + 5 + ], + "Vindex": "user_index" + } +} -# subquery in FROM with cross-shard join +# derived table in FROM with cross-shard join "select t.id from (select id from user where id = 5) as t join user_extra on t.id = user_extra.col" { "QueryType": "SELECT", @@ -1706,8 +1765,9 @@ Gen4 plan same as above ] } } +Gen4 plan same as above -# routing rules for subquery +# routing rules for derived table "select id from (select id, col from route1 where id = 5) as t" { "QueryType": "SELECT", @@ -1728,8 +1788,28 @@ Gen4 plan same as above "Vindex": "user_index" } } +Gen4 plan same as above -# routing rules for subquery where the constraint is in the outer query +# derived table missing columns +"select t.id from (select id from user) as t join user_extra on t.id = user_extra.user_id where t.col = 42" +{ + "QueryType": "SELECT", + "Original": "select t.id from (select id from user) as t join user_extra on t.id = user_extra.user_id where t.col = 42", + "Instructions": { + "OperatorType": "Route", + "Variant": "SelectScatter", + "Keyspace": { + "Name": "user", + "Sharded": true + }, + "FieldQuery": "select t.id from (select id from `user` where 1 != 1) as t join user_extra on t.id = user_extra.user_id where 1 != 1", + "Query": "select t.id from (select id from 
`user`) as t join user_extra on t.id = user_extra.user_id where t.col = 42", + "Table": "`user`, user_extra" + } +} +Gen4 error: symbol t.col not found + +# routing rules for derived table where the constraint is in the outer query "select id from (select id, col from route1) as t where id = 5" { "QueryType": "SELECT", @@ -1750,6 +1830,124 @@ Gen4 plan same as above "Vindex": "user_index" } } +{ + "QueryType": "SELECT", + "Original": "select id from (select id, col from route1) as t where id = 5", + "Instructions": { + "OperatorType": "Route", + "Variant": "SelectEqualUnique", + "Keyspace": { + "Name": "user", + "Sharded": true + }, + "FieldQuery": "select id from (select id, col from `user` as route1 where 1 != 1) as t where 1 != 1", + "Query": "select id from (select id, col from `user` as route1 where id = 5) as t", + "Table": "`user`", + "Values": [ + 5 + ], + "Vindex": "user_index" + } +} + +# routing rules for derived table where the constraint is in the outer query +"select id from (select id+col as foo from route1) as t where foo = 5" +{ + "QueryType": "SELECT", + "Original": "select id from (select id+col as foo from route1) as t where foo = 5", + "Instructions": { + "OperatorType": "Route", + "Variant": "SelectScatter", + "Keyspace": { + "Name": "user", + "Sharded": true + }, + "FieldQuery": "select id from (select id + col as foo from `user` as route1 where 1 != 1) as t where 1 != 1", + "Query": "select id from (select id + col as foo from `user` as route1) as t where foo = 5", + "Table": "`user`" + } +} +{ + "QueryType": "SELECT", + "Original": "select id from (select id+col as foo from route1) as t where foo = 5", + "Instructions": { + "OperatorType": "Route", + "Variant": "SelectScatter", + "Keyspace": { + "Name": "user", + "Sharded": true + }, + "FieldQuery": "select id from (select id + col as foo from `user` as route1 where 1 != 1) as t where 1 != 1", + "Query": "select id from (select id + col as foo from `user` as route1 where id + col = 5) as 
t", + "Table": "`user`" + } +} + +# push predicate on joined derived tables +"select t.id from (select id, textcol1 as baz from route1) as t join (select id, textcol1+textcol1 as baz from user) as s ON t.id = s.id WHERE t.baz = '3' AND s.baz = '3'" +{ + "QueryType": "SELECT", + "Original": "select t.id from (select id, textcol1 as baz from route1) as t join (select id, textcol1+textcol1 as baz from user) as s ON t.id = s.id WHERE t.baz = '3' AND s.baz = '3'", + "Instructions": { + "OperatorType": "Route", + "Variant": "SelectScatter", + "Keyspace": { + "Name": "user", + "Sharded": true + }, + "FieldQuery": "select t.id from (select id, textcol1 as baz from `user` as route1 where 1 != 1) as t join (select id, textcol1 + textcol1 as baz from `user` where 1 != 1) as s on t.id = s.id where 1 != 1", + "Query": "select t.id from (select id, textcol1 as baz from `user` as route1) as t join (select id, textcol1 + textcol1 as baz from `user`) as s on t.id = s.id where t.baz = '3' and s.baz = '3'", + "Table": "`user`" + } +} +{ + "QueryType": "SELECT", + "Original": "select t.id from (select id, textcol1 as baz from route1) as t join (select id, textcol1+textcol1 as baz from user) as s ON t.id = s.id WHERE t.baz = '3' AND s.baz = '3'", + "Instructions": { + "OperatorType": "Route", + "Variant": "SelectScatter", + "Keyspace": { + "Name": "user", + "Sharded": true + }, + "FieldQuery": "select t.id from (select id, textcol1 as baz from `user` as route1 where 1 != 1) as t, (select id, textcol1 + textcol1 as baz from `user` where 1 != 1) as s where 1 != 1", + "Query": "select t.id from (select id, textcol1 as baz from `user` as route1 where textcol1 = '3') as t, (select id, textcol1 + textcol1 as baz from `user` where textcol1 + textcol1 = '3') as s where t.id = s.id", + "Table": "`user`" + } +} + +# recursive derived table predicate push down +"select bar from (select foo+4 as bar from (select colA+colB as foo from user) as u) as t where bar = 5" +{ + "QueryType": "SELECT", + 
"Original": "select bar from (select foo+4 as bar from (select colA+colB as foo from user) as u) as t where bar = 5", + "Instructions": { + "OperatorType": "Route", + "Variant": "SelectScatter", + "Keyspace": { + "Name": "user", + "Sharded": true + }, + "FieldQuery": "select bar from (select foo + 4 as bar from (select colA + colB as foo from `user` where 1 != 1) as u where 1 != 1) as t where 1 != 1", + "Query": "select bar from (select foo + 4 as bar from (select colA + colB as foo from `user`) as u) as t where bar = 5", + "Table": "`user`" + } +} +{ + "QueryType": "SELECT", + "Original": "select bar from (select foo+4 as bar from (select colA+colB as foo from user) as u) as t where bar = 5", + "Instructions": { + "OperatorType": "Route", + "Variant": "SelectScatter", + "Keyspace": { + "Name": "user", + "Sharded": true + }, + "FieldQuery": "select bar from (select foo + 4 as bar from (select colA + colB as foo from `user` where 1 != 1) as u where 1 != 1) as t where 1 != 1", + "Query": "select bar from (select foo + 4 as bar from (select colA + colB as foo from `user` where colA + colB + 4 = 5) as u) as t", + "Table": "`user`" + } +} # recursive derived table lookups "select id from (select id from (select id from user) as u) as t where id = 5" @@ -1772,8 +1970,27 @@ Gen4 plan same as above "Vindex": "user_index" } } +{ + "QueryType": "SELECT", + "Original": "select id from (select id from (select id from user) as u) as t where id = 5", + "Instructions": { + "OperatorType": "Route", + "Variant": "SelectEqualUnique", + "Keyspace": { + "Name": "user", + "Sharded": true + }, + "FieldQuery": "select id from (select id from (select id from `user` where 1 != 1) as u where 1 != 1) as t where 1 != 1", + "Query": "select id from (select id from (select id from `user` where id = 5) as u) as t", + "Table": "`user`", + "Values": [ + 5 + ], + "Vindex": "user_index" + } +} -# merge subqueries with single-shard routes +# merge derived tables with single-shard routes "select 
u.col, e.col from (select col from user where id = 5) as u join (select col from user_extra where user_id = 5) as e" { "QueryType": "SELECT", @@ -1871,7 +2088,7 @@ Gen4 plan same as above } Gen4 plan same as above -# wire-up on join with cross-shard subquery +# wire-up on join with cross-shard derived table "select t.col1 from (select user.id, user.col1 from user join user_extra) as t join unsharded on unsharded.col1 = t.col1 and unsharded.id = t.id" { "QueryType": "SELECT", @@ -1936,7 +2153,7 @@ Gen4 plan same as above } } -# wire-up on within cross-shard subquery +# wire-up on within cross-shard derived table "select t.id from (select user.id, user.col1 from user join user_extra on user_extra.col = user.col) as t" { "QueryType": "SELECT", @@ -1981,7 +2198,7 @@ Gen4 plan same as above } } -# Join with cross-shard subquery on rhs +# Join with cross-shard derived table on rhs "select t.col1 from unsharded_a ua join (select user.id, user.col1 from user join user_extra) as t" { "QueryType": "SELECT", @@ -2044,6 +2261,116 @@ Gen4 plan same as above ] } } +{ + "QueryType": "SELECT", + "Original": "select t.col1 from unsharded_a ua join (select user.id, user.col1 from user join user_extra) as t", + "Instructions": { + "OperatorType": "Join", + "Variant": "Join", + "JoinColumnIndexes": "2", + "TableName": "unsharded_a_`user`_user_extra", + "Inputs": [ + { + "OperatorType": "Route", + "Variant": "SelectUnsharded", + "Keyspace": { + "Name": "main", + "Sharded": false + }, + "FieldQuery": "select 1 from unsharded_a as ua where 1 != 1", + "Query": "select 1 from unsharded_a as ua", + "Table": "unsharded_a" + }, + { + "OperatorType": "Join", + "Variant": "Join", + "JoinColumnIndexes": "-1,-2", + "TableName": "`user`_user_extra", + "Inputs": [ + { + "OperatorType": "Route", + "Variant": "SelectScatter", + "Keyspace": { + "Name": "user", + "Sharded": true + }, + "FieldQuery": "select `user`.id, `user`.col1 from `user` where 1 != 1", + "Query": "select `user`.id, `user`.col1 from 
`user`", + "Table": "`user`" + }, + { + "OperatorType": "Route", + "Variant": "SelectScatter", + "Keyspace": { + "Name": "user", + "Sharded": true + }, + "FieldQuery": "select 1 from user_extra where 1 != 1", + "Query": "select 1 from user_extra", + "Table": "user_extra" + } + ] + } + ] + } +} + +# Join with cross-shard derived table on rhs - push down join predicate to derived table +"select t.col1 from unsharded_a ua join (select user.id, user.col1 from user join user_extra) as t on t.id = ua.id" +"unsupported: filtering on results of cross-shard subquery" +{ + "QueryType": "SELECT", + "Original": "select t.col1 from unsharded_a ua join (select user.id, user.col1 from user join user_extra) as t on t.id = ua.id", + "Instructions": { + "OperatorType": "Join", + "Variant": "Join", + "JoinColumnIndexes": "2", + "TableName": "unsharded_a_`user`_user_extra", + "Inputs": [ + { + "OperatorType": "Route", + "Variant": "SelectUnsharded", + "Keyspace": { + "Name": "main", + "Sharded": false + }, + "FieldQuery": "select ua.id from unsharded_a as ua where 1 != 1", + "Query": "select ua.id from unsharded_a as ua", + "Table": "unsharded_a" + }, + { + "OperatorType": "Join", + "Variant": "Join", + "JoinColumnIndexes": "-1,-2", + "TableName": "`user`_user_extra", + "Inputs": [ + { + "OperatorType": "Route", + "Variant": "SelectScatter", + "Keyspace": { + "Name": "user", + "Sharded": true + }, + "FieldQuery": "select `user`.id, `user`.col1 from `user` where 1 != 1", + "Query": "select `user`.id, `user`.col1 from `user`", + "Table": "`user`" + }, + { + "OperatorType": "Route", + "Variant": "SelectScatter", + "Keyspace": { + "Name": "user", + "Sharded": true + }, + "FieldQuery": "select 1 from user_extra where 1 != 1", + "Query": "select 1 from user_extra where :user_id = :ua_id", + "Table": "user_extra" + } + ] + } + ] + } +} # subquery in ON clause, single route "select unsharded_a.col from unsharded_a join unsharded_b on (select col from user)" @@ -2456,7 +2783,7 @@ Gen4 plan 
same as above } } -# subquery with join primitive (FROM) +# derived table with join primitive (FROM) "select id, t.id from (select user.id from user join user_extra) as t" { "QueryType": "SELECT", diff --git a/go/vt/vtgate/planbuilder/testdata/postprocess_cases.txt b/go/vt/vtgate/planbuilder/testdata/postprocess_cases.txt index 3f081a470f3..c9f459ad8a8 100644 --- a/go/vt/vtgate/planbuilder/testdata/postprocess_cases.txt +++ b/go/vt/vtgate/planbuilder/testdata/postprocess_cases.txt @@ -404,9 +404,9 @@ Gen4 plan same as above "Name": "user", "Sharded": true }, - "FieldQuery": "select a, `user`.textcol1, b, weight_string(a), textcol1, weight_string(textcol1), weight_string(b) from `user` where 1 != 1", - "OrderBy": "(0|3) ASC, (4|5) ASC, (2|6) ASC", - "Query": "select a, `user`.textcol1, b, weight_string(a), textcol1, weight_string(textcol1), weight_string(b) from `user` order by a asc, textcol1 asc, b asc", + "FieldQuery": "select a, `user`.textcol1, b, weight_string(a), weight_string(textcol1), weight_string(b) from `user` where 1 != 1", + "OrderBy": "(0|3) ASC, (1|4) ASC, (2|5) ASC", + "Query": "select a, `user`.textcol1, b, weight_string(a), weight_string(textcol1), weight_string(b) from `user` order by a asc, textcol1 asc, b asc", "ResultColumns": 3, "Table": "`user`" } diff --git a/go/vt/vtgate/semantics/analyzer.go b/go/vt/vtgate/semantics/analyzer.go index 44bc88e0f51..9664c6f5303 100644 --- a/go/vt/vtgate/semantics/analyzer.go +++ b/go/vt/vtgate/semantics/analyzer.go @@ -22,6 +22,8 @@ import ( "strconv" "strings" + querypb "vitess.io/vitess/go/vt/proto/query" + "vitess.io/vitess/go/vt/vtgate/vindexes" vtrpcpb "vitess.io/vitess/go/vt/proto/vtrpc" @@ -34,12 +36,14 @@ type ( analyzer struct { si SchemaInformation - Tables []TableInfo - scopes []*scope - exprDeps ExprDependencies - err error - currentDb string - inProjection []bool + Tables []TableInfo + scopes []*scope + exprRecursiveDeps ExprDependencies + exprDeps ExprDependencies + exprTypes 
map[sqlparser.Expr]querypb.Type + err error + currentDb string + inProjection []bool rScope map[*sqlparser.Select]*scope wScope map[*sqlparser.Select]*scope @@ -50,11 +54,13 @@ type ( // newAnalyzer create the semantic analyzer func newAnalyzer(dbName string, si SchemaInformation) *analyzer { return &analyzer{ - exprDeps: map[sqlparser.Expr]TableSet{}, - rScope: map[*sqlparser.Select]*scope{}, - wScope: map[*sqlparser.Select]*scope{}, - currentDb: dbName, - si: si, + exprRecursiveDeps: map[sqlparser.Expr]TableSet{}, + exprDeps: map[sqlparser.Expr]TableSet{}, + exprTypes: map[sqlparser.Expr]querypb.Type{}, + rScope: map[*sqlparser.Select]*scope{}, + wScope: map[*sqlparser.Select]*scope{}, + currentDb: dbName, + si: si, } } @@ -66,7 +72,15 @@ func Analyze(statement sqlparser.SelectStatement, currentDb string, si SchemaInf if err != nil { return nil, err } - return &SemTable{exprDependencies: analyzer.exprDeps, Tables: analyzer.Tables, selectScope: analyzer.rScope, ProjectionErr: analyzer.projErr, Comments: statement.GetComments()}, nil + return &SemTable{ + ExprBaseTableDeps: analyzer.exprRecursiveDeps, + ExprDeps: analyzer.exprDeps, + exprTypes: analyzer.exprTypes, + Tables: analyzer.Tables, + selectScope: analyzer.rScope, + ProjectionErr: analyzer.projErr, + Comments: statement.GetComments(), + }, nil } func (a *analyzer) setError(err error) { @@ -101,56 +115,27 @@ func (a *analyzer) analyzeDown(cursor *sqlparser.Cursor) bool { a.rScope[node] = currScope a.wScope[node] = newScope(nil) - case *sqlparser.DerivedTable: - a.setError(Gen4NotSupportedF("derived tables")) case *sqlparser.Subquery: a.setError(Gen4NotSupportedF("subquery")) case sqlparser.TableExpr: if isParentSelect(cursor) { a.push(newScope(nil)) } - switch node := node.(type) { - case *sqlparser.AliasedTableExpr: - a.setError(a.bindTable(node, node.Expr)) - case *sqlparser.JoinTableExpr: - if node.Condition.Using != nil { - a.setError(vterrors.New(vtrpcpb.Code_UNIMPLEMENTED, "unsupported: join with 
USING(column_list) clause for complex queries")) - } - if node.Join == sqlparser.NaturalJoinType || node.Join == sqlparser.NaturalRightJoinType || node.Join == sqlparser.NaturalLeftJoinType { - a.setError(vterrors.New(vtrpcpb.Code_UNIMPLEMENTED, "unsupported: "+node.Join.ToString())) - } - - } case *sqlparser.Union: a.push(newScope(current)) case sqlparser.SelectExprs: - if isParentSelect(cursor) { - a.inProjection = append(a.inProjection, true) - } sel, ok := cursor.Parent().(*sqlparser.Select) if !ok { break } + a.inProjection = append(a.inProjection, true) wScope, exists := a.wScope[sel] if !exists { break } - vTbl := &vTableInfo{} - for _, selectExpr := range node { - expr, ok := selectExpr.(*sqlparser.AliasedExpr) - if !ok { - continue - } - vTbl.cols = append(vTbl.cols, expr.Expr) - if !expr.As.IsEmpty() { - vTbl.columnNames = append(vTbl.columnNames, expr.As.String()) - } else { - vTbl.columnNames = append(vTbl.columnNames, sqlparser.String(expr)) - } - } - wScope.tables = append(wScope.tables, vTbl) + wScope.tables = append(wScope.tables, a.createVTableInfoForExpressions(node)) case sqlparser.OrderBy: a.changeScopeForOrderBy(cursor) case *sqlparser.Order: @@ -161,11 +146,15 @@ func (a *analyzer) analyzeDown(cursor *sqlparser.Cursor) bool { a.analyzeOrderByGroupByExprForLiteral(grpExpr, "group statement") } case *sqlparser.ColName: - t, err := a.resolveColumn(node, current) + tsRecursive, ts, qt, err := a.resolveColumn(node, current) if err != nil { a.setError(err) } else { - a.exprDeps[node] = t + a.exprRecursiveDeps[node] = tsRecursive + a.exprDeps[node] = ts + if qt != nil { + a.exprTypes[node] = *qt + } } case *sqlparser.FuncExpr: if node.Distinct { @@ -215,12 +204,12 @@ func (a *analyzer) analyzeOrderByGroupByExprForLiteral(input sqlparser.Expr, cal _ = sqlparser.Walk(func(node sqlparser.SQLNode) (kontinue bool, err error) { expr, ok := node.(sqlparser.Expr) if ok { - deps = deps.Merge(a.exprDeps[expr]) + deps = deps.Merge(a.exprRecursiveDeps[expr]) } 
return true, nil }, expr.Expr) - a.exprDeps[input] = deps + a.exprRecursiveDeps[input] = deps } func (a *analyzer) changeScopeForOrderBy(cursor *sqlparser.Cursor) { @@ -247,65 +236,125 @@ func isParentSelect(cursor *sqlparser.Cursor) bool { return isSelect } -func (a *analyzer) resolveColumn(colName *sqlparser.ColName, current *scope) (TableSet, error) { +func (a *analyzer) resolveColumn(colName *sqlparser.ColName, current *scope) (TableSet, TableSet, *querypb.Type, error) { if colName.Qualifier.IsEmpty() { return a.resolveUnQualifiedColumn(current, colName) } - t, err := a.resolveQualifiedColumn(current, colName) - if err != nil { - return 0, err + return a.resolveQualifiedColumn(current, colName) +} + +// tableInfoFor returns the table info for the table set. It should contains only single table. +func (a *analyzer) tableInfoFor(id TableSet) (TableInfo, error) { + numberOfTables := id.NumberOfTables() + if numberOfTables == 0 { + return nil, nil } - if t == nil { - return 0, vterrors.NewErrorf(vtrpcpb.Code_INVALID_ARGUMENT, vterrors.NonUniqError, fmt.Sprintf("Column '%s' in field list is ambiguous", sqlparser.String(colName))) + if numberOfTables > 1 { + return nil, vterrors.Errorf(vtrpcpb.Code_INTERNAL, "[BUG] should only be used for single tables") } - return a.tableSetFor(t.GetExpr()), nil + return a.Tables[id.TableOffset()], nil } // resolveQualifiedColumn handles `tabl.col` expressions -func (a *analyzer) resolveQualifiedColumn(current *scope, expr *sqlparser.ColName) (TableInfo, error) { +func (a *analyzer) resolveQualifiedColumn(current *scope, expr *sqlparser.ColName) (TableSet, TableSet, *querypb.Type, error) { // search up the scope stack until we find a match for current != nil { for _, table := range current.tables { - if table.Matches(expr.Qualifier) { - return table, nil + if !table.Matches(expr.Qualifier) { + continue + } + if table.IsActualTable() { + actualTable, ts, typ := a.resolveQualifiedColumnOnActualTable(table, expr) + return actualTable, 
ts, typ, nil + } + recursiveTs, typ, err := table.RecursiveDepsFor(expr, a, len(current.tables) == 1) + if err != nil { + return 0, 0, nil, err } + if recursiveTs == nil { + return 0, 0, nil, vterrors.NewErrorf(vtrpcpb.Code_INVALID_ARGUMENT, vterrors.BadFieldError, "symbol %s not found", sqlparser.String(expr)) + } + + ts, err := table.DepsFor(expr, a, len(current.tables) == 1) + if err != nil { + return 0, 0, nil, err + } + return *recursiveTs, *ts, typ, nil } current = current.parent } - return nil, vterrors.NewErrorf(vtrpcpb.Code_INVALID_ARGUMENT, vterrors.BadFieldError, "symbol %s not found", sqlparser.String(expr)) + return 0, 0, nil, vterrors.NewErrorf(vtrpcpb.Code_INVALID_ARGUMENT, vterrors.BadFieldError, "symbol %s not found", sqlparser.String(expr)) +} + +func (a *analyzer) resolveQualifiedColumnOnActualTable(table TableInfo, expr *sqlparser.ColName) (TableSet, TableSet, *querypb.Type) { + ts := a.tableSetFor(table.GetExpr()) + for _, colInfo := range table.GetColumns() { + if expr.Name.EqualString(colInfo.Name) { + // A column can't be of type NULL, that is the default value indicating that we dont know the actual type + // But expressions can be of NULL type, so we use nil to represent an unknown type + if colInfo.Type == querypb.Type_NULL_TYPE { + return ts, ts, nil + } + return ts, ts, &colInfo.Type + } + } + return ts, ts, nil } type originable interface { tableSetFor(t *sqlparser.AliasedTableExpr) TableSet - depsForExpr(expr sqlparser.Expr) TableSet + depsForExpr(expr sqlparser.Expr) (TableSet, *querypb.Type) } -func (a *analyzer) depsForExpr(expr sqlparser.Expr) TableSet { - return a.exprDeps.Dependencies(expr) +func (a *analyzer) depsForExpr(expr sqlparser.Expr) (TableSet, *querypb.Type) { + ts := a.exprRecursiveDeps.Dependencies(expr) + qt, isFound := a.exprTypes[expr] + if !isFound { + return ts, nil + } + return ts, &qt } // resolveUnQualifiedColumn -func (a *analyzer) resolveUnQualifiedColumn(current *scope, expr *sqlparser.ColName) (TableSet, 
error) { - var tsp *TableSet - -tryAgain: - for _, tbl := range current.tables { - ts := tbl.DepsFor(expr, a, len(current.tables) == 1) - if ts != nil && tsp != nil { - return 0, vterrors.NewErrorf(vtrpcpb.Code_INVALID_ARGUMENT, vterrors.NonUniqError, fmt.Sprintf("Column '%s' in field list is ambiguous", sqlparser.String(expr))) - } - if ts != nil { - tsp = ts +func (a *analyzer) resolveUnQualifiedColumn(current *scope, expr *sqlparser.ColName) (TableSet, TableSet, *querypb.Type, error) { + var tspRecursive, tsp *TableSet + var typp *querypb.Type + + for current != nil && tspRecursive == nil { + for _, tbl := range current.tables { + recursiveTs, typ, err := tbl.RecursiveDepsFor(expr, a, len(current.tables) == 1) + if err != nil { + return 0, 0, nil, err + } + if recursiveTs != nil && tspRecursive != nil { + return 0, 0, nil, vterrors.NewErrorf(vtrpcpb.Code_INVALID_ARGUMENT, vterrors.NonUniqError, fmt.Sprintf("Column '%s' in field list is ambiguous", sqlparser.String(expr))) + } + if recursiveTs != nil { + tspRecursive = recursiveTs + typp = typ + } + if tbl.IsActualTable() { + continue + } + ts, err := tbl.DepsFor(expr, a, len(current.tables) == 1) + if err != nil { + return 0, 0, nil, err + } + if ts != nil { + tsp = ts + } } - } - if tsp == nil && current.parent != nil { + current = current.parent - goto tryAgain + } + + if tspRecursive == nil { + return 0, 0, nil, nil } if tsp == nil { - return 0, nil + return *tspRecursive, 0, typp, nil } - return *tsp, nil + return *tspRecursive, *tsp, typp, nil } func (a *analyzer) tableSetFor(t *sqlparser.AliasedTableExpr) TableSet { @@ -342,7 +391,22 @@ func (a *analyzer) createTable(t sqlparser.TableName, alias *sqlparser.AliasedTa func (a *analyzer) bindTable(alias *sqlparser.AliasedTableExpr, expr sqlparser.SimpleTableExpr) error { switch t := expr.(type) { case *sqlparser.DerivedTable: - return Gen4NotSupportedF("derived table") + sel, isSelect := t.Select.(*sqlparser.Select) + if !isSelect { + return 
Gen4NotSupportedF("union in derived table") + } + + tableInfo := a.createVTableInfoForExpressions(sel.SelectExprs) + if err := tableInfo.checkForDuplicates(); err != nil { + return err + } + + tableInfo.ASTNode = alias + tableInfo.tableName = alias.As.String() + + a.Tables = append(a.Tables, tableInfo) + scope := a.currentScope() + return scope.addTable(tableInfo) case sqlparser.TableName: var tbl *vindexes.Table var isInfSchema bool @@ -367,6 +431,20 @@ func (a *analyzer) bindTable(alias *sqlparser.AliasedTableExpr, expr sqlparser.S return nil } +func (v *vTableInfo) checkForDuplicates() error { + for i, name := range v.columnNames { + for j, name2 := range v.columnNames { + if i == j { + continue + } + if name == name2 { + return vterrors.NewErrorf(vtrpcpb.Code_INVALID_ARGUMENT, vterrors.DupFieldName, "Duplicate column name '%s'", name) + } + } + } + return nil +} + func (a *analyzer) analyze(statement sqlparser.Statement) error { _ = sqlparser.Rewrite(statement, a.analyzeDown, a.analyzeUp) return a.err @@ -376,7 +454,7 @@ func (a *analyzer) analyzeUp(cursor *sqlparser.Cursor) bool { if !a.shouldContinue() { return false } - switch cursor.Node().(type) { + switch node := cursor.Node().(type) { case sqlparser.SelectExprs: if isParentSelect(cursor) { a.popProjection() @@ -397,11 +475,45 @@ func (a *analyzer) analyzeUp(cursor *sqlparser.Cursor) bool { } } } + switch node := node.(type) { + case *sqlparser.AliasedTableExpr: + a.setError(a.bindTable(node, node.Expr)) + case *sqlparser.JoinTableExpr: + if node.Condition.Using != nil { + a.setError(vterrors.New(vtrpcpb.Code_UNIMPLEMENTED, "unsupported: join with USING(column_list) clause for complex queries")) + } + if node.Join == sqlparser.NaturalJoinType || node.Join == sqlparser.NaturalRightJoinType || node.Join == sqlparser.NaturalLeftJoinType { + a.setError(vterrors.New(vtrpcpb.Code_UNIMPLEMENTED, "unsupported: "+node.Join.ToString())) + } + } } return a.shouldContinue() } +func (a *analyzer) 
createVTableInfoForExpressions(expressions sqlparser.SelectExprs) *vTableInfo { + vTbl := &vTableInfo{} + for _, selectExpr := range expressions { + expr, ok := selectExpr.(*sqlparser.AliasedExpr) + if !ok { + continue + } + vTbl.cols = append(vTbl.cols, expr.Expr) + if expr.As.IsEmpty() { + switch expr := expr.Expr.(type) { + case *sqlparser.ColName: + // for projections, we strip out the qualifier and keep only the column name + vTbl.columnNames = append(vTbl.columnNames, expr.Name.String()) + default: + vTbl.columnNames = append(vTbl.columnNames, sqlparser.String(expr)) + } + } else { + vTbl.columnNames = append(vTbl.columnNames, expr.As.String()) + } + } + return vTbl +} + func (a *analyzer) popProjection() { a.inProjection = a.inProjection[:len(a.inProjection)-1] } diff --git a/go/vt/vtgate/semantics/analyzer_test.go b/go/vt/vtgate/semantics/analyzer_test.go index 2caec9cf235..2eb6191b751 100644 --- a/go/vt/vtgate/semantics/analyzer_test.go +++ b/go/vt/vtgate/semantics/analyzer_test.go @@ -58,7 +58,7 @@ from x as t` // extract the `t.col2` expression from the subquery sel2 := sel.SelectExprs[1].(*sqlparser.AliasedExpr).Expr.(*sqlparser.Subquery).Select.(*sqlparser.Select) - s1 := semTable.Dependencies(extract(sel2, 0)) + s1 := semTable.GetBaseTableDependencies(extract(sel2, 0)) // if scoping works as expected, we should be able to see the inner table being used by the inner expression assert.Equal(t, T2, s1) @@ -82,7 +82,7 @@ func TestBindingSingleTable(t *testing.T) { ts := semTable.TableSetFor(t1) assert.EqualValues(t, 1, ts) - d := semTable.Dependencies(extract(sel, 0)) + d := semTable.GetBaseTableDependencies(extract(sel, 0)) require.Equal(t, T1, d, query) }) } @@ -136,7 +136,7 @@ func TestOrderByBindingSingleTable(t *testing.T) { stmt, semTable := parseAndAnalyze(t, tc.sql, "d") sel, _ := stmt.(*sqlparser.Select) order := sel.OrderBy[0].Expr - d := semTable.Dependencies(order) + d := semTable.GetBaseTableDependencies(order) require.Equal(t, tc.deps, d, 
tc.sql) }) } @@ -176,14 +176,14 @@ func TestGroupByBindingSingleTable(t *testing.T) { T1, }, { "select t.id from t, t1 group by id", - T2, + T1, }} for _, tc := range tcases { t.Run(tc.sql, func(t *testing.T) { stmt, semTable := parseAndAnalyze(t, tc.sql, "d") sel, _ := stmt.(*sqlparser.Select) grp := sel.GroupBy[0] - d := semTable.Dependencies(grp) + d := semTable.GetBaseTableDependencies(grp) require.Equal(t, tc.deps, d, tc.sql) }) } @@ -205,7 +205,7 @@ func TestBindingSingleAliasedTable(t *testing.T) { ts := semTable.TableSetFor(t1) assert.EqualValues(t, 1, ts) - d := semTable.Dependencies(extract(sel, 0)) + d := semTable.GetBaseTableDependencies(extract(sel, 0)) require.Equal(t, T1, d, query) }) } @@ -248,8 +248,8 @@ func TestUnion(t *testing.T) { assert.EqualValues(t, 1, ts1) assert.EqualValues(t, 2, ts2) - d1 := semTable.Dependencies(extract(sel1, 0)) - d2 := semTable.Dependencies(extract(sel2, 0)) + d1 := semTable.GetBaseTableDependencies(extract(sel1, 0)) + d2 := semTable.GetBaseTableDependencies(extract(sel2, 0)) assert.Equal(t, T1, d1) assert.Equal(t, T2, d2) } @@ -306,7 +306,7 @@ func TestBindingMultiTable(t *testing.T) { t.Run(query.query, func(t *testing.T) { stmt, semTable := parseAndAnalyze(t, query.query, "user") sel, _ := stmt.(*sqlparser.Select) - assert.Equal(t, query.deps, semTable.Dependencies(extract(sel, 0)), query.query) + assert.Equal(t, query.deps, semTable.GetBaseTableDependencies(extract(sel, 0)), query.query) }) } }) @@ -338,7 +338,7 @@ func TestBindingSingleDepPerTable(t *testing.T) { stmt, semTable := parseAndAnalyze(t, query, "") sel, _ := stmt.(*sqlparser.Select) - d := semTable.Dependencies(extract(sel, 0)) + d := semTable.GetBaseTableDependencies(extract(sel, 0)) assert.Equal(t, 1, d.NumberOfTables(), "size wrong") assert.Equal(t, T1, d) } @@ -405,17 +405,29 @@ func TestUnknownColumnMap2(t *testing.T) { Name: sqlparser.NewTableIdent("a"), Columns: []vindexes.Column{{ Name: sqlparser.NewColIdent("col"), - Type: 
querypb.Type_VARCHAR, + Type: querypb.Type_INT32, + }}, + ColumnListAuthoritative: true, + } + authoritativeTblBWithInt := vindexes.Table{ + Name: sqlparser.NewTableIdent("b"), + Columns: []vindexes.Column{{ + Name: sqlparser.NewColIdent("col"), + Type: querypb.Type_INT32, }}, ColumnListAuthoritative: true, } - parse, _ := sqlparser.Parse(query) + varchar := querypb.Type_VARCHAR + int := querypb.Type_INT32 + parse, _ := sqlparser.Parse(query) + expr := extract(parse.(*sqlparser.Select), 0) tests := []struct { name string schema map[string]*vindexes.Table err bool + typ *querypb.Type }{ { name: "no info about tables", @@ -431,16 +443,25 @@ func TestUnknownColumnMap2(t *testing.T) { name: "non authoritative columns - one authoritative and one not", schema: map[string]*vindexes.Table{"a": &nonAuthoritativeTblA, "b": &authoritativeTblB}, err: false, + typ: &varchar, }, { name: "non authoritative columns - one authoritative and one not", schema: map[string]*vindexes.Table{"a": &authoritativeTblA, "b": &nonAuthoritativeTblB}, err: false, + typ: &varchar, }, { name: "authoritative columns", schema: map[string]*vindexes.Table{"a": &authoritativeTblA, "b": &authoritativeTblB}, err: false, + typ: &varchar, + }, + { + name: "authoritative columns", + schema: map[string]*vindexes.Table{"a": &authoritativeTblA, "b": &authoritativeTblBWithInt}, + err: false, + typ: &int, }, { name: "authoritative columns with overlap", @@ -458,6 +479,8 @@ func TestUnknownColumnMap2(t *testing.T) { require.Error(t, tbl.ProjectionErr) } else { require.NoError(t, tbl.ProjectionErr) + typ := tbl.TypeFor(expr) + assert.Equal(t, test.typ, typ) } }) } @@ -526,6 +549,73 @@ func TestScoping(t *testing.T) { } } +func TestScopingWDerivedTables(t *testing.T) { + queries := []struct { + query string + errorMessage string + recursiveExpectation TableSet + expectation TableSet + }{ + { + query: "select id from (select id from user where id = 5) as t", + recursiveExpectation: T1, + expectation: T2, + }, { + 
query: "select id from (select foo as id from user) as t", + recursiveExpectation: T1, + expectation: T2, + }, { + query: "select id from (select foo as id from (select x as foo from user) as c) as t", + recursiveExpectation: T1, + expectation: T3, + }, { + query: "select t.id from (select foo as id from user) as t", + recursiveExpectation: T1, + expectation: T2, + }, { + query: "select t.id2 from (select foo as id from user) as t", + errorMessage: "symbol t.id2 not found", + }, { + query: "select id from (select 42 as id) as t", + recursiveExpectation: T0, + expectation: T2, + }, { + query: "select t.id from (select 42 as id) as t", + recursiveExpectation: T0, + expectation: T2, + }, { + query: "select ks.t.id from (select 42 as id) as t", + errorMessage: "symbol ks.t.id not found", + }, { + query: "select * from (select id, id from user) as t", + errorMessage: "Duplicate column name 'id'", + }, { + query: "select t.baz = 1 from (select id as baz from user) as t", + expectation: T2, + recursiveExpectation: T1, + }, + } + for _, query := range queries { + t.Run(query.query, func(t *testing.T) { + parse, err := sqlparser.Parse(query.query) + require.NoError(t, err) + st, err := Analyze(parse.(sqlparser.SelectStatement), "user", &FakeSI{ + Tables: map[string]*vindexes.Table{ + "t": {Name: sqlparser.NewTableIdent("t")}, + }, + }) + if query.errorMessage != "" { + require.EqualError(t, err, query.errorMessage) + } else { + require.NoError(t, err) + sel := parse.(*sqlparser.Select) + assert.Equal(t, query.recursiveExpectation, st.GetBaseTableDependencies(extract(sel, 0))) + assert.Equal(t, query.expectation, st.Dependencies(extract(sel, 0))) + } + }) + } +} + func parseAndAnalyze(t *testing.T, query, dbName string) (sqlparser.Statement, *SemTable) { t.Helper() parse, err := sqlparser.Parse(query) diff --git a/go/vt/vtgate/semantics/semantic_state.go b/go/vt/vtgate/semantics/semantic_state.go index 7e0321e4266..d93726efffc 100644 --- 
a/go/vt/vtgate/semantics/semantic_state.go +++ b/go/vt/vtgate/semantics/semantic_state.go @@ -35,9 +35,17 @@ type ( Name() (sqlparser.TableName, error) GetExpr() *sqlparser.AliasedTableExpr GetColumns() []ColumnInfo - IsVirtual() bool - DepsFor(col *sqlparser.ColName, org originable, single bool) *TableSet + IsActualTable() bool + + // RecursiveDepsFor returns a pointer to the table set for the table that this column belongs to, if it can be found + // if the column is not found, nil will be returned instead. If the column is a derived table column, this method + // will recursively find the dependencies of the expression inside the derived table + RecursiveDepsFor(col *sqlparser.ColName, org originable, single bool) (*TableSet, *querypb.Type, error) + + // DepsFor finds the table that a column depends on. No recursing is done on derived tables + DepsFor(col *sqlparser.ColName, org originable, single bool) (*TableSet, error) IsInfSchema() bool + GetExprFor(s string) (sqlparser.Expr, error) } // ColumnInfo contains information about columns @@ -62,7 +70,11 @@ type ( isInfSchema bool } + // vTableInfo is used to represent projected results, not real tables. It is used for + // ORDER BY and GROUP BY that need to access result columns, and also for derived tables. vTableInfo struct { + tableName string + ASTNode *sqlparser.AliasedTableExpr columnNames []string cols []sqlparser.Expr } @@ -80,10 +92,20 @@ type ( Tables []TableInfo // ProjectionErr stores the error that we got during the semantic analysis of the SelectExprs. // This is only a real error if we are unable to plan the query as a single route - ProjectionErr error - exprDependencies ExprDependencies - selectScope map[*sqlparser.Select]*scope - Comments sqlparser.Comments + ProjectionErr error + + // ExprBaseTableDeps contains the dependencies from the expression to the actual tables + // in the query (i.e. not including derived tables). 
If an expression is a column on a derived table, + // this map will contain the accumulated dependencies for the column expression inside the derived table + ExprBaseTableDeps ExprDependencies + + // ExprDeps keeps information about dependencies for expressions, no matter if they are + // against real tables or derived tables + ExprDeps ExprDependencies + + exprTypes map[sqlparser.Expr]querypb.Type + selectScope map[*sqlparser.Select]*scope + Comments sqlparser.Comments } scope struct { @@ -98,48 +120,116 @@ type ( } ) -// DepsFor implements the TableInfo interface -func (v *vTableInfo) DepsFor(col *sqlparser.ColName, org originable, single bool) *TableSet { - if !col.Qualifier.IsEmpty() { - return nil +// GetExprFor implements the TableInfo interface +func (v *vTableInfo) GetExprFor(s string) (sqlparser.Expr, error) { + for i, colName := range v.columnNames { + if colName == s { + return v.cols[i], nil + } + } + return nil, vterrors.NewErrorf(vtrpcpb.Code_NOT_FOUND, vterrors.BadFieldError, "Unknown column '%s' in 'field list'", s) +} + +// GetExprFor implements the TableInfo interface +func (a *AliasedTable) GetExprFor(s string) (sqlparser.Expr, error) { + return nil, vterrors.Errorf(vtrpcpb.Code_INTERNAL, "Unknown column '%s' in 'field list'", s) +} + +// GetExprFor implements the TableInfo interface +func (r *RealTable) GetExprFor(s string) (sqlparser.Expr, error) { + return nil, vterrors.Errorf(vtrpcpb.Code_INTERNAL, "Unknown column '%s' in 'field list'", s) +} + +// RecursiveDepsFor implements the TableInfo interface +func (v *vTableInfo) RecursiveDepsFor(col *sqlparser.ColName, org originable, single bool) (*TableSet, *querypb.Type, error) { + if !col.Qualifier.IsEmpty() && (v.ASTNode == nil || v.tableName != col.Qualifier.Name.String()) { + // if we have a table qualifier in the expression, we know that it is not referencing an aliased table + return nil, nil, nil } for i, colName := range v.columnNames { if col.Name.String() == colName { - ts := 
org.depsForExpr(v.cols[i]) - return &ts + ts, qt := org.depsForExpr(v.cols[i]) + return &ts, qt, nil } } - return nil + return nil, nil, nil } // DepsFor implements the TableInfo interface -func (a *AliasedTable) DepsFor(col *sqlparser.ColName, org originable, single bool) *TableSet { - if single { - ts := org.tableSetFor(a.ASTNode) - return &ts +func (v *vTableInfo) DepsFor(col *sqlparser.ColName, org originable, single bool) (*TableSet, error) { + if v.ASTNode == nil { + return nil, nil } - for _, info := range a.GetColumns() { - if col.Name.String() == info.Name { - ts := org.tableSetFor(a.ASTNode) - return &ts + if !col.Qualifier.IsEmpty() && (v.ASTNode == nil || v.tableName != col.Qualifier.Name.String()) { + // if we have a table qualifier in the expression, we know that it is not referencing an aliased table + return nil, nil + } + for _, colName := range v.columnNames { + if col.Name.String() == colName { + ts := org.tableSetFor(v.ASTNode) + return &ts, nil } } - return nil + return nil, nil +} + +// RecursiveDepsFor implements the TableInfo interface +func (a *AliasedTable) RecursiveDepsFor(col *sqlparser.ColName, org originable, single bool) (*TableSet, *querypb.Type, error) { + return depsFor(col, org, single, a.ASTNode, a.GetColumns(), a.Authoritative()) } // DepsFor implements the TableInfo interface -func (r *RealTable) DepsFor(col *sqlparser.ColName, org originable, single bool) *TableSet { +func (a *AliasedTable) DepsFor(col *sqlparser.ColName, org originable, single bool) (*TableSet, error) { + ts, _, err := a.RecursiveDepsFor(col, org, single) + return ts, err +} + +// RecursiveDepsFor implements the TableInfo interface +func (r *RealTable) RecursiveDepsFor(col *sqlparser.ColName, org originable, single bool) (*TableSet, *querypb.Type, error) { + return depsFor(col, org, single, r.ASTNode, r.GetColumns(), r.Authoritative()) +} + +// DepsFor implements the TableInfo interface +func (r *RealTable) DepsFor(col *sqlparser.ColName, org originable, 
single bool) (*TableSet, error) { + ts, _, err := r.RecursiveDepsFor(col, org, single) + return ts, err +} + +// depsFor implements the TableInfo interface for RealTable and AliasedTable +func depsFor( + col *sqlparser.ColName, + org originable, + single bool, + astNode *sqlparser.AliasedTableExpr, + cols []ColumnInfo, + authoritative bool, +) (*TableSet, *querypb.Type, error) { + // if we know that we are the only table in the scope, there is no doubt - the column must belong to the table if single { - ts := org.tableSetFor(r.ASTNode) - return &ts + ts := org.tableSetFor(astNode) + + for _, info := range cols { + if col.Name.String() == info.Name { + return &ts, &info.Type, nil + } + } + + if authoritative { + // if we are authoritative and we can't find the column, we should fail + return nil, nil, vterrors.NewErrorf(vtrpcpb.Code_NOT_FOUND, vterrors.BadFieldError, "Unknown column '%s' in 'field list'", col.Name.String()) + } + + // it's probably the correct table, but we don't have enough info to be sure or figure out the type of the column + return &ts, nil, nil } - for _, info := range r.GetColumns() { + + for _, info := range cols { if col.Name.String() == info.Name { - ts := org.tableSetFor(r.ASTNode) - return &ts + ts := org.tableSetFor(astNode) + return &ts, &info.Type, nil } } - return nil + return nil, nil, nil } // IsInfSchema implements the TableInfo interface @@ -157,19 +247,19 @@ func (r *RealTable) IsInfSchema() bool { return r.isInfSchema } -// IsVirtual implements the TableInfo interface -func (v *vTableInfo) IsVirtual() bool { - return true +// IsActualTable implements the TableInfo interface +func (v *vTableInfo) IsActualTable() bool { + return false } -// IsVirtual implements the TableInfo interface -func (a *AliasedTable) IsVirtual() bool { - return false +// IsActualTable implements the TableInfo interface +func (a *AliasedTable) IsActualTable() bool { + return true } -// IsVirtual implements the TableInfo interface -func (r *RealTable) 
IsVirtual() bool { - return false +// IsActualTable implements the TableInfo interface +func (r *RealTable) IsActualTable() bool { + return true } var _ TableInfo = (*RealTable)(nil) @@ -177,7 +267,7 @@ var _ TableInfo = (*AliasedTable)(nil) var _ TableInfo = (*vTableInfo)(nil) func (v *vTableInfo) Matches(name sqlparser.TableName) bool { - return false + return v.tableName == name.Name.String() && name.Qualifier.IsEmpty() } func (v *vTableInfo) Authoritative() bool { @@ -185,11 +275,11 @@ func (v *vTableInfo) Authoritative() bool { } func (v *vTableInfo) Name() (sqlparser.TableName, error) { - return sqlparser.TableName{}, nil + return v.ASTNode.TableName() } func (v *vTableInfo) GetExpr() *sqlparser.AliasedTableExpr { - return nil + return v.ASTNode } func (v *vTableInfo) GetColumns() []ColumnInfo { @@ -273,7 +363,7 @@ func (r *RealTable) Matches(name sqlparser.TableName) bool { // NewSemTable creates a new empty SemTable func NewSemTable() *SemTable { - return &SemTable{exprDependencies: map[sqlparser.Expr]TableSet{}} + return &SemTable{ExprBaseTableDeps: map[sqlparser.Expr]TableSet{}} } // TableSetFor returns the bitmask for this particular table @@ -294,9 +384,43 @@ func (st *SemTable) TableInfoFor(id TableSet) (TableInfo, error) { return st.Tables[id.TableOffset()], nil } +// GetBaseTableDependencies return the table dependencies of the expression. +func (st *SemTable) GetBaseTableDependencies(expr sqlparser.Expr) TableSet { + return st.ExprBaseTableDeps.Dependencies(expr) +} + // Dependencies return the table dependencies of the expression. func (st *SemTable) Dependencies(expr sqlparser.Expr) TableSet { - return st.exprDependencies.Dependencies(expr) + return st.ExprDeps.Dependencies(expr) +} + +// TableInfoForExpr returns the table info of the table that this expression depends on. 
+// Careful: this only works for expressions that have a single table dependency +func (st *SemTable) TableInfoForExpr(expr sqlparser.Expr) (TableInfo, error) { + return st.TableInfoFor(st.ExprDeps.Dependencies(expr)) +} + +// GetSelectTables returns the table in the select. +func (st *SemTable) GetSelectTables(node *sqlparser.Select) []TableInfo { + scope := st.selectScope[node] + return scope.tables +} + +// AddExprs adds new select exprs to the SemTable. +func (st *SemTable) AddExprs(tbl *sqlparser.AliasedTableExpr, cols sqlparser.SelectExprs) { + tableSet := st.TableSetFor(tbl) + for _, col := range cols { + st.ExprBaseTableDeps[col.(*sqlparser.AliasedExpr).Expr] = tableSet + } +} + +// TypeFor returns the type of expressions in the query +func (st *SemTable) TypeFor(e sqlparser.Expr) *querypb.Type { + typ, found := st.exprTypes[e] + if found { + return &typ + } + return nil } // Dependencies return the table dependencies of the expression. This method finds table dependencies recursively @@ -318,20 +442,6 @@ func (d ExprDependencies) Dependencies(expr sqlparser.Expr) TableSet { return deps } -// GetSelectTables returns the table in the select. -func (st *SemTable) GetSelectTables(node *sqlparser.Select) []TableInfo { - scope := st.selectScope[node] - return scope.tables -} - -// AddExprs adds new select exprs to the SemTable. 
-func (st *SemTable) AddExprs(tbl *sqlparser.AliasedTableExpr, cols sqlparser.SelectExprs) { - tableSet := st.TableSetFor(tbl) - for _, col := range cols { - st.exprDependencies[col.(*sqlparser.AliasedExpr).Expr] = tableSet - } -} - func newScope(parent *scope) *scope { return &scope{parent: parent} } @@ -395,3 +505,24 @@ func (ts TableSet) Constituents() (result []TableSet) { func (ts TableSet) Merge(other TableSet) TableSet { return ts | other } + +// RewriteDerivedExpression rewrites all the ColName instances in the supplied expression with +// the expressions behind the column definition of the derived table +// SELECT foo FROM (SELECT id+42 as foo FROM user) as t +// We need `foo` to be translated to `id+42` on the inside of the derived table +func RewriteDerivedExpression(expr sqlparser.Expr, vt TableInfo) (sqlparser.Expr, error) { + newExpr := sqlparser.CloneExpr(expr) + sqlparser.Rewrite(newExpr, func(cursor *sqlparser.Cursor) bool { + switch node := cursor.Node().(type) { + case *sqlparser.ColName: + exp, err := vt.GetExprFor(node.Name.String()) + if err != nil { + return false + } + cursor.Replace(exp) + return false + } + return true + }, nil) + return newExpr, nil +}