diff --git a/ast/dml.go b/ast/dml.go index 2046f3ad2..7f4e24fd0 100644 --- a/ast/dml.go +++ b/ast/dml.go @@ -739,11 +739,13 @@ func (n *InsertStmt) Accept(v Visitor) (Node, bool) { n.Select = node.(ResultSetNode) } - node, ok := n.Table.Accept(v) - if !ok { - return n, false + if n.Table != nil { + node, ok := n.Table.Accept(v) + if !ok { + return n, false + } + n.Table = node.(*TableRefsClause) } - n.Table = node.(*TableRefsClause) for i, val := range n.Columns { node, ok := val.Accept(v) @@ -807,34 +809,36 @@ func (n *DeleteStmt) Accept(v Visitor) (Node, bool) { } n = newNode.(*DeleteStmt) - node, ok := n.TableRefs.Accept(v) - if !ok { - return n, false - } - n.TableRefs = node.(*TableRefsClause) + if n.TableRefs != nil { + node, ok := n.TableRefs.Accept(v) + if !ok { + return n, false + } + n.TableRefs = node.(*TableRefsClause) - node, ok = n.Tables.Accept(v) - if !ok { - return n, false + node, ok = n.Tables.Accept(v) + if !ok { + return n, false + } + n.Tables = node.(*DeleteTableList) } - n.Tables = node.(*DeleteTableList) if n.Where != nil { - node, ok = n.Where.Accept(v) + node, ok := n.Where.Accept(v) if !ok { return n, false } n.Where = node.(ExprNode) } if n.Order != nil { - node, ok = n.Order.Accept(v) + node, ok := n.Order.Accept(v) if !ok { return n, false } n.Order = node.(*OrderByClause) } if n.Limit != nil { - node, ok = n.Limit.Accept(v) + node, ok := n.Limit.Accept(v) if !ok { return n, false } @@ -866,34 +870,36 @@ func (n *UpdateStmt) Accept(v Visitor) (Node, bool) { return v.Leave(newNode) } n = newNode.(*UpdateStmt) - node, ok := n.TableRefs.Accept(v) - if !ok { - return n, false - } - n.TableRefs = node.(*TableRefsClause) - for i, val := range n.List { - node, ok = val.Accept(v) + if n.TableRefs != nil { + node, ok := n.TableRefs.Accept(v) if !ok { return n, false } - n.List[i] = node.(*Assignment) + n.TableRefs = node.(*TableRefsClause) + for i, val := range n.List { + node, ok = val.Accept(v) + if !ok { + return n, false + } + n.List[i] = node.(*Assignment) + } } if n.Where != nil { - node, ok = n.Where.Accept(v) + node, ok := n.Where.Accept(v) if !ok { return n, false } n.Where = node.(ExprNode) } if n.Order != nil { - node, ok = n.Order.Accept(v) + node, ok := n.Order.Accept(v) if !ok { return n, false } n.Order = node.(*OrderByClause) } if n.Limit != nil { - node, ok = n.Limit.Accept(v) + node, ok := n.Limit.Accept(v) if !ok { return n, false }