Skip to content

Commit

Permalink
Improve modeling of special date/time functions in parser
Browse files Browse the repository at this point in the history
  • Loading branch information
martint committed Feb 21, 2024
1 parent 111c860 commit f9c1b5e
Show file tree
Hide file tree
Showing 16 changed files with 739 additions and 157 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -588,11 +588,11 @@ primaryExpression
| value=primaryExpression '[' index=valueExpression ']' #subscript
| identifier #columnReference
| base=primaryExpression '.' fieldName=identifier #dereference
| name=CURRENT_DATE #specialDateTimeFunction
| name=CURRENT_TIME ('(' precision=INTEGER_VALUE ')')? #specialDateTimeFunction
| name=CURRENT_TIMESTAMP ('(' precision=INTEGER_VALUE ')')? #specialDateTimeFunction
| name=LOCALTIME ('(' precision=INTEGER_VALUE ')')? #specialDateTimeFunction
| name=LOCALTIMESTAMP ('(' precision=INTEGER_VALUE ')')? #specialDateTimeFunction
| name=CURRENT_DATE #currentDate
| name=CURRENT_TIME ('(' precision=INTEGER_VALUE ')')? #currentTime
| name=CURRENT_TIMESTAMP ('(' precision=INTEGER_VALUE ')')? #currentTimestamp
| name=LOCALTIME ('(' precision=INTEGER_VALUE ')')? #localTime
| name=LOCALTIMESTAMP ('(' precision=INTEGER_VALUE ')')? #localTimestamp
| name=CURRENT_USER #currentUser
| name=CURRENT_CATALOG #currentCatalog
| name=CURRENT_SCHEMA #currentSchema
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -73,9 +73,11 @@
import io.trino.sql.tree.CoalesceExpression;
import io.trino.sql.tree.ComparisonExpression;
import io.trino.sql.tree.CurrentCatalog;
import io.trino.sql.tree.CurrentDate;
import io.trino.sql.tree.CurrentPath;
import io.trino.sql.tree.CurrentSchema;
import io.trino.sql.tree.CurrentTime;
import io.trino.sql.tree.CurrentTimestamp;
import io.trino.sql.tree.CurrentUser;
import io.trino.sql.tree.DataType;
import io.trino.sql.tree.DecimalLiteral;
Expand Down Expand Up @@ -111,6 +113,8 @@
import io.trino.sql.tree.LambdaArgumentDeclaration;
import io.trino.sql.tree.LambdaExpression;
import io.trino.sql.tree.LikePredicate;
import io.trino.sql.tree.LocalTime;
import io.trino.sql.tree.LocalTimestamp;
import io.trino.sql.tree.LogicalExpression;
import io.trino.sql.tree.LongLiteral;
import io.trino.sql.tree.MeasureDefinition;
Expand Down Expand Up @@ -162,7 +166,6 @@
import java.util.function.BiFunction;
import java.util.function.Function;

import static com.google.common.base.MoreObjects.firstNonNull;
import static com.google.common.base.Preconditions.checkArgument;
import static com.google.common.base.Preconditions.checkState;
import static com.google.common.base.Verify.verify;
Expand Down Expand Up @@ -230,6 +233,7 @@
import static io.trino.spi.type.TimeWithTimeZoneType.createTimeWithTimeZoneType;
import static io.trino.spi.type.TimestampType.TIMESTAMP_MILLIS;
import static io.trino.spi.type.TimestampType.createTimestampType;
import static io.trino.spi.type.TimestampWithTimeZoneType.TIMESTAMP_TZ_MILLIS;
import static io.trino.spi.type.TimestampWithTimeZoneType.createTimestampWithTimeZoneType;
import static io.trino.spi.type.TinyintType.TINYINT;
import static io.trino.spi.type.VarbinaryType.VARBINARY;
Expand Down Expand Up @@ -649,34 +653,50 @@ protected Type visitRow(Row node, StackableAstVisitorContext<Context> context)
return setExpressionType(node, type);
}

@Override
protected Type visitCurrentDate(CurrentDate node, StackableAstVisitorContext<Context> context)
{
return setExpressionType(node, DATE);
}

@Override
protected Type visitCurrentTime(CurrentTime node, StackableAstVisitorContext<Context> context)
{
return switch (node.getFunction()) {
case DATE -> {
checkArgument(node.getPrecision() == null);
yield setExpressionType(node, DATE);
}
case TIME -> {
if (node.getPrecision() != null) {
yield setExpressionType(node, createTimeWithTimeZoneType(node.getPrecision()));
}
yield setExpressionType(node, TIME_TZ_MILLIS);
}
case LOCALTIME -> {
if (node.getPrecision() != null) {
yield setExpressionType(node, createTimeType(node.getPrecision()));
}
yield setExpressionType(node, TIME_MILLIS);
}
case TIMESTAMP -> setExpressionType(node, createTimestampWithTimeZoneType(firstNonNull(node.getPrecision(), TimestampWithTimeZoneType.DEFAULT_PRECISION)));
case LOCALTIMESTAMP -> {
if (node.getPrecision() != null) {
yield setExpressionType(node, createTimestampType(node.getPrecision()));
}
yield setExpressionType(node, TIMESTAMP_MILLIS);
}
};
return setExpressionType(
node,
node.getPrecision()
.map(TimeWithTimeZoneType::createTimeWithTimeZoneType)
.orElse(TIME_TZ_MILLIS));
}

@Override
protected Type visitCurrentTimestamp(CurrentTimestamp node, StackableAstVisitorContext<Context> context)
{
return setExpressionType(
node,
node.getPrecision()
.map(TimestampWithTimeZoneType::createTimestampWithTimeZoneType)
.orElse(TIMESTAMP_TZ_MILLIS));
}

@Override
protected Type visitLocalTime(LocalTime node, StackableAstVisitorContext<Context> context)
{
return setExpressionType(
node,
node.getPrecision()
.map(TimeType::createTimeType)
.orElse(TIME_MILLIS));
}

@Override
protected Type visitLocalTimestamp(LocalTimestamp node, StackableAstVisitorContext<Context> context)
{
return setExpressionType(
node,
node.getPrecision()
.map(TimestampType::createTimestampType)
.orElse(TIMESTAMP_MILLIS));
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -63,9 +63,11 @@
import io.trino.sql.tree.ComparisonExpression;
import io.trino.sql.tree.ComparisonExpression.Operator;
import io.trino.sql.tree.CurrentCatalog;
import io.trino.sql.tree.CurrentDate;
import io.trino.sql.tree.CurrentPath;
import io.trino.sql.tree.CurrentSchema;
import io.trino.sql.tree.CurrentTime;
import io.trino.sql.tree.CurrentTimestamp;
import io.trino.sql.tree.CurrentUser;
import io.trino.sql.tree.DereferenceExpression;
import io.trino.sql.tree.ExistsPredicate;
Expand All @@ -84,6 +86,8 @@
import io.trino.sql.tree.LambdaExpression;
import io.trino.sql.tree.LikePredicate;
import io.trino.sql.tree.Literal;
import io.trino.sql.tree.LocalTime;
import io.trino.sql.tree.LocalTimestamp;
import io.trino.sql.tree.LogicalExpression;
import io.trino.sql.tree.Node;
import io.trino.sql.tree.NodeRef;
Expand Down Expand Up @@ -1024,36 +1028,54 @@ protected Object visitAtTimeZone(AtTimeZone node, Object context)
throw new IllegalArgumentException("Unexpected type: " + valueType);
}

@Override
protected Object visitCurrentDate(CurrentDate node, Object context)
{
return functionInvoker.invoke(
plannerContext.getMetadata()
.resolveBuiltinFunction("current_date", ImmutableList.of()),
connectorSession,
ImmutableList.of());
}

@Override
protected Object visitCurrentTime(CurrentTime node, Object context)
{
return switch (node.getFunction()) {
case DATE -> functionInvoker.invoke(
plannerContext.getMetadata()
.resolveBuiltinFunction("current_date", ImmutableList.of()),
connectorSession,
ImmutableList.of());
case TIME -> functionInvoker.invoke(
plannerContext.getMetadata()
.resolveBuiltinFunction("$current_time", TypeSignatureProvider.fromTypes(type(node))),
connectorSession,
singletonList(null));
case LOCALTIME -> functionInvoker.invoke(
plannerContext.getMetadata()
.resolveBuiltinFunction("$localtime", TypeSignatureProvider.fromTypes(type(node))),
connectorSession,
singletonList(null));
case TIMESTAMP -> functionInvoker.invoke(
plannerContext.getMetadata()
.resolveBuiltinFunction("$current_timestamp", TypeSignatureProvider.fromTypes(type(node))),
connectorSession,
singletonList(null));
case LOCALTIMESTAMP -> functionInvoker.invoke(
plannerContext.getMetadata()
.resolveBuiltinFunction("$localtimestamp", TypeSignatureProvider.fromTypes(type(node))),
connectorSession,
singletonList(null));
};
return functionInvoker.invoke(
plannerContext.getMetadata()
.resolveBuiltinFunction("$current_time", TypeSignatureProvider.fromTypes(type(node))),
connectorSession,
singletonList(null));
}

@Override
protected Object visitCurrentTimestamp(CurrentTimestamp node, Object context)
{
return functionInvoker.invoke(
plannerContext.getMetadata()
.resolveBuiltinFunction("$current_timestamp", TypeSignatureProvider.fromTypes(type(node))),
connectorSession,
singletonList(null));
}

@Override
protected Object visitLocalTime(LocalTime node, Object context)
{
return functionInvoker.invoke(
plannerContext.getMetadata()
.resolveBuiltinFunction("$localtime", TypeSignatureProvider.fromTypes(type(node))),
connectorSession,
singletonList(null));
}

@Override
protected Object visitLocalTimestamp(LocalTimestamp node, Object context)
{
return functionInvoker.invoke(
plannerContext.getMetadata()
.resolveBuiltinFunction("$localtimestamp", TypeSignatureProvider.fromTypes(type(node))),
connectorSession,
singletonList(null));
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,14 @@

import io.trino.metadata.Metadata;
import io.trino.metadata.ResolvedFunction;
import io.trino.sql.tree.CurrentDate;
import io.trino.sql.tree.CurrentTime;
import io.trino.sql.tree.CurrentTimestamp;
import io.trino.sql.tree.DefaultExpressionTraversalVisitor;
import io.trino.sql.tree.Expression;
import io.trino.sql.tree.FunctionCall;
import io.trino.sql.tree.LocalTime;
import io.trino.sql.tree.LocalTimestamp;

import java.util.concurrent.atomic.AtomicBoolean;
import java.util.function.Function;
Expand Down Expand Up @@ -72,19 +76,47 @@ public static boolean containsCurrentTimeFunctions(Expression expression)
{
requireNonNull(expression, "expression is null");

AtomicBoolean currentTime = new AtomicBoolean(false);
new CurrentTimeVisitor().process(expression, currentTime);
return currentTime.get();
AtomicBoolean hasTemporalFunction = new AtomicBoolean(false);
new TemporalFunctionVisitor().process(expression, hasTemporalFunction);
return hasTemporalFunction.get();
}

private static class CurrentTimeVisitor
private static class TemporalFunctionVisitor
extends DefaultExpressionTraversalVisitor<AtomicBoolean>
{
@Override
protected Void visitCurrentDate(CurrentDate node, AtomicBoolean currentTime)
{
currentTime.set(true);
return null;
}

@Override
protected Void visitCurrentTime(CurrentTime node, AtomicBoolean currentTime)
{
currentTime.set(true);
return null;
}

@Override
protected Void visitCurrentTimestamp(CurrentTimestamp node, AtomicBoolean currentTime)
{
currentTime.set(true);
return null;
}

@Override
protected Void visitLocalTime(LocalTime node, AtomicBoolean currentTime)
{
currentTime.set(true);
return null;
}

@Override
protected Void visitLocalTimestamp(LocalTimestamp node, AtomicBoolean currentTime)
{
currentTime.set(true);
return null;
}
}
}
Loading

0 comments on commit f9c1b5e

Please sign in to comment.