Skip to content

Commit

Permalink
opt: add rule to replace ST_Distance with ST_DWithin
Browse files Browse the repository at this point in the history
`ST_DWithin` is equivalent to the expression `ST_Distance <= x`.
Similar reformulations can be made for different comparison operators
(<, >, >=). This commit adds two rules that replace expressions of the
latter form with either `ST_DWithin` or `ST_DWithinExclusive`. This
replacement is desirable because the `ST_DWithin` function can exit early.
`ST_DWithin` can also be used to generate an inverted index scan.

Fixes cockroachdb#52028

Release note: None
  • Loading branch information
DrewKimball committed Aug 19, 2020
1 parent 2088d55 commit 1e3a70a
Show file tree
Hide file tree
Showing 3 changed files with 260 additions and 0 deletions.
95 changes: 95 additions & 0 deletions pkg/sql/opt/norm/comp_funcs.go
Original file line number Diff line number Diff line change
Expand Up @@ -113,3 +113,98 @@ func findTimeZoneFunction(typ *types.T) (*tree.FunctionProperties, *tree.Overloa
}
panic(errors.AssertionFailedf("could not find overload for timezone"))
}

// MakeSTDWithinLeft returns an ST_DWithin function that replaces an expression
// of the following form: ST_Distance(a,b) <= x. Note that the ST_Distance
// function is on the left side of the inequality.
func (c *CustomFuncs) MakeSTDWithinLeft(
op opt.Operator, args memo.ScalarListExpr, bound opt.ScalarExpr,
) opt.ScalarExpr {
return c.makeSTDWithin(op, args, bound, true /* fnIsLeftArg */)
}

// MakeSTDWithinRight returns an ST_DWithin function that replaces an expression
// of the following form: x <= ST_Distance(a,b). Note that the ST_Distance
// function is on the right side of the inequality.
func (c *CustomFuncs) MakeSTDWithinRight(
op opt.Operator, args memo.ScalarListExpr, bound opt.ScalarExpr,
) opt.ScalarExpr {
return c.makeSTDWithin(op, args, bound, false /* fnIsLeftArg */)
}

// makeSTDWithin returns an ST_DWithin function that replaces an expression of
// the following form: ST_Distance(a,b) <= x. The ST_Distance function can be on
// either side of the inequality, and the inequality can be one of the
// following: '<', '<=', '>', '>='. This replacement allows early-exit behavior,
// and may enable use of an inverted index scan.
func (c *CustomFuncs) makeSTDWithin(
op opt.Operator, args memo.ScalarListExpr, bound opt.ScalarExpr, fnIsLeftArg bool,
) opt.ScalarExpr {
var not bool
var name string
const incName = "st_dwithin"
const exName = "_st_dwithinexclusive"
switch op {
case opt.GeOp:
if fnIsLeftArg {
// Matched expression: ST_Distance(a,b) >= x.
not = true
name = exName
} else {
// Matched expression: x >= ST_Distance(a,b).
not = false
name = incName
}

case opt.GtOp:
if fnIsLeftArg {
// Matched expression: ST_Distance(a,b) > x.
not = true
name = incName
} else {
// Matched expression: x > ST_Distance(a,b).
not = false
name = exName
}

case opt.LeOp:
if fnIsLeftArg {
// Matched expression: ST_Distance(a,b) <= x.
not = false
name = incName
} else {
// Matched expression: x <= ST_Distance(a,b).
not = true
name = exName
}

case opt.LtOp:
if fnIsLeftArg {
// Matched expression: ST_Distance(a,b) < x.
not = false
name = exName
} else {
// Matched expression: x < ST_Distance(a,b).
not = true
name = incName
}
}
props, overload, ok := memo.FindFunction(&args, name)
if !ok {
panic(errors.AssertionFailedf("could not find overload for %s", name))
}
within := c.f.ConstructFunction(append(args, bound), &memo.FunctionPrivate{
Name: name,
Typ: types.Bool,
Properties: props,
Overload: overload,
})
if not {
// ST_DWithin and ST_DWithinExclusive are equivalent to ST_Distance <= x and
// ST_Distance < x respectively. The comparison operator in the matched
// expression (if ST_Distance is normalized to be on the left) is either '>'
// or '>='. Therefore, we have to take the opposite of within.
within = c.f.ConstructNot(within)
}
return within
}
29 changes: 29 additions & 0 deletions pkg/sql/opt/norm/rules/comp.opt
Original file line number Diff line number Diff line change
Expand Up @@ -234,3 +234,32 @@
$tz
(MakeTimeZoneFunction (FirstScalarListExpr $args) $right)
)

# FoldCmpSTDistanceLeft replaces an expression of the form:
# 'ST_Distance(...) <= x' with a call to ST_DWithin or ST_DWithinExclusive. This
# replacement allows early-exit behavior, and may enable use of an inverted
# index scan. See the MakeSTDWithin method for the specific variation on
# ST_DWithin that is used to replace expressions with different comparison
# operators (e.g. '<' vs '<=').
[FoldCmpSTDistanceLeft, Normalize]
(Ge | Gt | Le | Lt
$left:(Function
$args:*
$private:(FunctionPrivate "st_distance")
)
$right:*
)
=>
(MakeSTDWithinLeft (OpName) $args $right)

# FoldCmpSTDistanceRight mirrors FoldCmpSTDistanceLeft.
[FoldCmpSTDistanceRight, Normalize]
(Ge | Gt | Le | Lt
$left:*
$right:(Function
$args:*
$private:(FunctionPrivate "st_distance")
)
)
=>
(MakeSTDWithinRight (OpName) $args $left)
136 changes: 136 additions & 0 deletions pkg/sql/opt/norm/testdata/rules/comp
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,14 @@ exec-ddl
CREATE TABLE a (k INT PRIMARY KEY, i INT, f FLOAT, s STRING, j JSON, d DATE)
----

exec-ddl
CREATE TABLE geom_geog (
geom GEOMETRY,
geog GEOGRAPHY,
val FLOAT
)
----

# --------------------------------------------------
# CommuteVarInequality
# --------------------------------------------------
Expand Down Expand Up @@ -719,3 +727,131 @@ project
│ └── columns: ts:1
└── projections
└── ts:1 <= '2020-06-01 13:35:55' [as="?column?":5, outer=(1)]

# --------------------------------------------------
# FoldCmpSTDistanceLeft
# --------------------------------------------------

# Geometry case with '<=' operator.
norm expect=FoldCmpSTDistanceLeft
SELECT * FROM geom_geog WHERE st_distance(geom, 'point(0.0 0.0)') <= 5
----
select
├── columns: geom:1 geog:2 val:3
├── immutable
├── scan geom_geog
│ └── columns: geom:1 geog:2 val:3
└── filters
└── st_dwithin(geom:1, '010100000000000000000000000000000000000000', 5.0) [outer=(1), immutable]

# Geometry case with '<' operator.
norm expect=FoldCmpSTDistanceLeft
SELECT * FROM geom_geog WHERE st_distance('point(0.0 0.0)', geom) < 5
----
select
├── columns: geom:1 geog:2 val:3
├── immutable
├── scan geom_geog
│ └── columns: geom:1 geog:2 val:3
└── filters
└── _st_dwithinexclusive('010100000000000000000000000000000000000000', geom:1, 5.0) [outer=(1), immutable]

# Geometry case with '>=' operator.
norm expect=FoldCmpSTDistanceLeft
SELECT * FROM geom_geog WHERE st_distance(geom, 'point(0.0 0.0)') >= 5
----
select
├── columns: geom:1 geog:2 val:3
├── immutable
├── scan geom_geog
│ └── columns: geom:1 geog:2 val:3
└── filters
└── NOT _st_dwithinexclusive(geom:1, '010100000000000000000000000000000000000000', 5.0) [outer=(1), immutable]

# Geometry case with '>' operator.
norm expect=FoldCmpSTDistanceLeft
SELECT * FROM geom_geog WHERE st_distance(geom, 'point(0.0 0.0)') > 5
----
select
├── columns: geom:1 geog:2 val:3
├── immutable
├── scan geom_geog
│ └── columns: geom:1 geog:2 val:3
└── filters
└── NOT st_dwithin(geom:1, '010100000000000000000000000000000000000000', 5.0) [outer=(1), immutable]

# Geography case with '<=' operator.
norm expect=FoldCmpSTDistanceLeft
SELECT * FROM geom_geog WHERE st_distance(geog, 'point(0.0 0.0)') <= 5
----
select
├── columns: geom:1 geog:2 val:3
├── immutable
├── scan geom_geog
│ └── columns: geom:1 geog:2 val:3
└── filters
└── st_dwithin(geog:2, '0101000020E610000000000000000000000000000000000000', 5.0) [outer=(2), immutable]

# Geography case with '<' operator.
norm expect=FoldCmpSTDistanceLeft
SELECT * FROM geom_geog WHERE st_distance(geog, 'point(0.0 0.0)') < 5
----
select
├── columns: geom:1 geog:2 val:3
├── immutable
├── scan geom_geog
│ └── columns: geom:1 geog:2 val:3
└── filters
└── _st_dwithinexclusive(geog:2, '0101000020E610000000000000000000000000000000000000', 5.0) [outer=(2), immutable]

# --------------------------------------------------
# FoldCmpSTDistanceRight
# --------------------------------------------------

# Case with '<=' operator.
norm expect=FoldCmpSTDistanceRight
SELECT * FROM geom_geog WHERE val <= st_distance(geom, 'point(0.0 0.0)')
----
select
├── columns: geom:1 geog:2 val:3
├── immutable
├── scan geom_geog
│ └── columns: geom:1 geog:2 val:3
└── filters
└── NOT _st_dwithinexclusive(geom:1, '010100000000000000000000000000000000000000', val:3) [outer=(1,3), immutable]

# Case with '<' operator.
norm expect=FoldCmpSTDistanceRight
SELECT * FROM geom_geog WHERE val < st_distance(geom, 'point(0.0 0.0)')
----
select
├── columns: geom:1 geog:2 val:3
├── immutable
├── scan geom_geog
│ └── columns: geom:1 geog:2 val:3
└── filters
└── NOT st_dwithin(geom:1, '010100000000000000000000000000000000000000', val:3) [outer=(1,3), immutable]

# Case with '>=' operator.
norm expect=FoldCmpSTDistanceRight
SELECT * FROM geom_geog WHERE val >= st_distance(geom, 'point(0.0 0.0)')
----
select
├── columns: geom:1 geog:2 val:3
├── immutable
├── scan geom_geog
│ └── columns: geom:1 geog:2 val:3
└── filters
└── st_dwithin(geom:1, '010100000000000000000000000000000000000000', val:3) [outer=(1,3), immutable]

# Case with '>' operator.
norm expect=FoldCmpSTDistanceRight
SELECT * FROM geom_geog WHERE val > st_distance(geom, 'point(0.0 0.0)')
----
select
├── columns: geom:1 geog:2 val:3
├── immutable
├── scan geom_geog
│ └── columns: geom:1 geog:2 val:3
└── filters
└── _st_dwithinexclusive(geom:1, '010100000000000000000000000000000000000000', val:3) [outer=(1,3), immutable]

0 comments on commit 1e3a70a

Please sign in to comment.