Skip to content

Commit

Permalink
adding DATETIME cast support (opendistro-for-elasticsearch#310)
Browse files Browse the repository at this point in the history
Adding full support for CAST(), and adding it as a function. Fixed Datetime casting to be UTC-timezone default.
  • Loading branch information
davidcui1225 authored Jan 30, 2020
1 parent d9fe9dc commit 68b971f
Show file tree
Hide file tree
Showing 14 changed files with 296 additions and 158 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ public enum ScalarFunction implements TypeExpression {
ASIN(func(T(NUMBER)).to(T)),
ATAN(func(T(NUMBER)).to(T)),
ATAN2(func(T(NUMBER), NUMBER).to(T)),
CAST(),
CBRT(func(T(NUMBER)).to(T)),
CEIL(func(T(NUMBER)).to(T)),
CONCAT(), // TODO: varargs support required
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,10 @@
package com.amazon.opendistroforelasticsearch.sql.executor.format;

import com.alibaba.druid.sql.ast.expr.SQLCaseExpr;
import com.alibaba.druid.sql.ast.expr.SQLCastExpr;
import com.amazon.opendistroforelasticsearch.sql.domain.Field;
import com.amazon.opendistroforelasticsearch.sql.domain.JoinSelect;
import com.amazon.opendistroforelasticsearch.sql.domain.MethodField;
import com.amazon.opendistroforelasticsearch.sql.domain.Query;
import com.amazon.opendistroforelasticsearch.sql.domain.ScriptMethodField;
import com.amazon.opendistroforelasticsearch.sql.domain.Select;
import com.amazon.opendistroforelasticsearch.sql.domain.TableOnJoinSelect;
import com.amazon.opendistroforelasticsearch.sql.esdomain.mapping.FieldMapping;
Expand Down Expand Up @@ -324,14 +322,10 @@ private Schema.Type fetchMethodReturnType(Field field) {
// TODO: return type information is disconnected from the function definitions in SQLFunctions.
// Refactor SQLFunctions to have functions self-explanatory (types, scripts) and pluggable
// (similar to Strategy pattern)
if (field.getExpression() instanceof SQLCastExpr) {
return SQLFunctions.getCastFunctionReturnType(
((SQLCastExpr) field.getExpression()).getDataType().getName());
} else if (field.getExpression() instanceof SQLCaseExpr) {
if (field.getExpression() instanceof SQLCaseExpr) {
return Schema.Type.TEXT;
}
return SQLFunctions.getScriptFunctionReturnType(
((ScriptMethodField) field).getFunctionName());
return SQLFunctions.getScriptFunctionReturnType(field);
}
default:
throw new UnsupportedOperationException(
Expand Down

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -118,11 +118,9 @@ private Field makeFieldImpl(SQLExpr expr, String alias, String tableAlias) throw
if (alias == null) {
alias = "cast_" + castExpr.getExpr().toString();
}
String scriptCode = new CastParser(castExpr, alias, tableAlias).parse(true);
List<KVValue> methodParameters = new ArrayList<>();
methodParameters.add(new KVValue(alias));
methodParameters.add(new KVValue(scriptCode));
return new MethodField("script", methodParameters, null, alias);
ArrayList<SQLExpr> methodParameters = new ArrayList<>();
methodParameters.add(((SQLCastExpr) expr).getExpr());
return makeMethodField("CAST", methodParameters, null, alias, tableAlias, true);
} else if (expr instanceof SQLNumericLiteralExpr) {
SQLMethodInvokeExpr methodInvokeExpr = new SQLMethodInvokeExpr("assign", null);
methodInvokeExpr.addParameter(expr);
Expand Down Expand Up @@ -344,7 +342,12 @@ public MethodField makeMethodField(String name, List<SQLExpr> arguments, SQLAggr
String scriptCode = new CaseWhenParser((SQLCaseExpr) object, alias, tableAlias).parse();
paramers.add(new KVValue("script", new SQLCharExpr(scriptCode)));
} else if (object instanceof SQLCastExpr) {
String scriptCode = new CastParser((SQLCastExpr) object, alias, tableAlias).parse(false);
String castName = sqlFunctions.nextId("cast");
List<KVValue> methodParameters = new ArrayList<>();
methodParameters.add(new KVValue(((SQLCastExpr) object).getExpr().toString()));
String castType = ((SQLCastExpr) object).getDataType().getName();
String scriptCode = sqlFunctions.getCastScriptStatement(castName, castType, methodParameters);
methodParameters.add(new KVValue(scriptCode));
paramers.add(new KVValue("script", new SQLCharExpr(scriptCode)));
} else if (object instanceof SQLAggregateExpr) {
SQLObject parent = object.getParent();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import com.alibaba.druid.sql.ast.expr.SQLBinaryOpExpr;
import com.alibaba.druid.sql.ast.expr.SQLBinaryOperator;
import com.alibaba.druid.sql.ast.expr.SQLBooleanExpr;
import com.alibaba.druid.sql.ast.expr.SQLCastExpr;
import com.alibaba.druid.sql.ast.expr.SQLCharExpr;
import com.alibaba.druid.sql.ast.expr.SQLIdentifierExpr;
import com.alibaba.druid.sql.ast.expr.SQLInListExpr;
Expand Down Expand Up @@ -98,8 +99,6 @@ public Where findWhere() throws SqlParseException {
}

public void parseWhere(SQLExpr expr, Where where) throws SqlParseException {


if (expr instanceof SQLBinaryOpExpr) {
SQLBinaryOpExpr bExpr = (SQLBinaryOpExpr) expr;
if (explainSpecialCondWithBothSidesAreLiterals(bExpr, where)) {
Expand Down Expand Up @@ -199,7 +198,8 @@ private boolean isCond(SQLBinaryOpExpr expr) {
}
return leftSide instanceof SQLIdentifierExpr
|| leftSide instanceof SQLPropertyExpr
|| leftSide instanceof SQLVariantRefExpr;
|| leftSide instanceof SQLVariantRefExpr
|| leftSide instanceof SQLCastExpr;
}

private boolean isAllowedMethodOnConditionLeft(SQLMethodInvokeExpr method, SQLBinaryOperator operator) {
Expand Down Expand Up @@ -233,6 +233,7 @@ private void routeCond(SQLBinaryOpExpr bExpr, SQLExpr sub, Where where) throws S
private void explainCond(String opear, SQLExpr expr, Where where) throws SqlParseException {
if (expr instanceof SQLBinaryOpExpr) {
SQLBinaryOpExpr soExpr = (SQLBinaryOpExpr) expr;

boolean methodAsOpear = false;

boolean isNested = false;
Expand Down Expand Up @@ -522,11 +523,23 @@ private MethodField parseSQLMethodInvokeExprWithFunctionInWhere(SQLMethodInvokeE
return methodField;
}

private MethodField parseSQLCastExprWithFunctionInWhere(SQLCastExpr soExpr) throws SqlParseException {
ArrayList<SQLExpr> parameters = new ArrayList<>();
parameters.add(soExpr.getExpr());
return fieldMaker.makeMethodField(
"CAST",
parameters,
null,
null,
query != null ? query.getFrom().getAlias() : null,
false
);
}

private SQLMethodInvokeExpr parseSQLBinaryOpExprWhoIsConditionInWhere(SQLBinaryOpExpr soExpr)
throws SqlParseException {

if (!(soExpr.getLeft() instanceof SQLMethodInvokeExpr
|| soExpr.getRight() instanceof SQLMethodInvokeExpr)) {
if (bothSideAreNotFunction(soExpr) && bothSidesAreNotCast(soExpr)) {
return null;
}

Expand Down Expand Up @@ -567,6 +580,13 @@ private SQLMethodInvokeExpr parseSQLBinaryOpExprWhoIsConditionInWhere(SQLBinaryO
rightMethod = parseSQLMethodInvokeExprWithFunctionInWhere((SQLMethodInvokeExpr) soExpr.getRight());
}

if (soExpr.getLeft() instanceof SQLCastExpr) {
leftMethod = parseSQLCastExprWithFunctionInWhere((SQLCastExpr) soExpr.getLeft());
}
if (soExpr.getRight() instanceof SQLCastExpr) {
rightMethod = parseSQLCastExprWithFunctionInWhere((SQLCastExpr) soExpr.getRight());
}

String v1 = leftMethod.getParams().get(0).value.toString();
String v1Dec = leftMethod.getParams().size() == 2 ? leftMethod.getParams().get(1).value.toString() + ";" : "";

Expand All @@ -588,6 +608,14 @@ private SQLMethodInvokeExpr parseSQLBinaryOpExprWhoIsConditionInWhere(SQLBinaryO

}

private Boolean bothSideAreNotFunction(SQLBinaryOpExpr soExpr) {
return !(soExpr.getLeft() instanceof SQLMethodInvokeExpr || soExpr.getRight() instanceof SQLMethodInvokeExpr);
}

private Boolean bothSidesAreNotCast(SQLBinaryOpExpr soExpr) {
return !(soExpr.getLeft() instanceof SQLCastExpr || soExpr.getRight() instanceof SQLCastExpr);
}

private Object[] getMethodValuesWithSubQueries(SQLMethodInvokeExpr method) throws SqlParseException {
List<Object> values = new ArrayList<>();
for (SQLExpr innerExpr : method.getParameters()) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@
import com.amazon.opendistroforelasticsearch.sql.domain.KVValue;
import com.amazon.opendistroforelasticsearch.sql.domain.MethodField;
import com.amazon.opendistroforelasticsearch.sql.domain.Order;
import com.amazon.opendistroforelasticsearch.sql.domain.ScriptMethodField;
import com.amazon.opendistroforelasticsearch.sql.domain.Select;
import com.amazon.opendistroforelasticsearch.sql.domain.Where;
import com.amazon.opendistroforelasticsearch.sql.domain.hints.Hint;
Expand Down Expand Up @@ -273,8 +272,7 @@ private ScriptSortType getScriptSortType(Order order) {
scriptFunctionReturnType = SQLFunctions.getCastFunctionReturnType(
((SQLCastExpr) order.getSortField().getExpression()).getDataType().getName());
} else {
ScriptMethodField smf = (ScriptMethodField) order.getSortField();
scriptFunctionReturnType = SQLFunctions.getScriptFunctionReturnType(smf.getFunctionName());
scriptFunctionReturnType = SQLFunctions.getScriptFunctionReturnType(order.getSortField());
}


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,15 +18,19 @@
import com.alibaba.druid.sql.SQLUtils;
import com.alibaba.druid.sql.ast.SQLExpr;
import com.alibaba.druid.sql.ast.expr.SQLBooleanExpr;
import com.alibaba.druid.sql.ast.expr.SQLCastExpr;
import com.alibaba.druid.sql.ast.expr.SQLCharExpr;
import com.alibaba.druid.sql.ast.expr.SQLIdentifierExpr;
import com.alibaba.druid.sql.ast.expr.SQLNullExpr;
import com.alibaba.druid.sql.ast.expr.SQLNumericLiteralExpr;
import com.alibaba.druid.sql.ast.expr.SQLPropertyExpr;
import com.alibaba.druid.sql.ast.expr.SQLTextLiteralExpr;
import com.alibaba.druid.sql.ast.expr.SQLVariantRefExpr;
import com.amazon.opendistroforelasticsearch.sql.domain.Field;
import com.amazon.opendistroforelasticsearch.sql.domain.KVValue;
import com.amazon.opendistroforelasticsearch.sql.domain.MethodField;
import com.amazon.opendistroforelasticsearch.sql.domain.ScriptMethodField;
import com.amazon.opendistroforelasticsearch.sql.exception.SqlParseException;
import com.amazon.opendistroforelasticsearch.sql.executor.format.Schema;
import com.google.common.base.Joiner;
import com.google.common.base.Strings;
Expand Down Expand Up @@ -84,7 +88,7 @@ public class SQLFunctions {
"if", "ifnull", "isnull"
);

private static final Set<String> utilityFunctions = Sets.newHashSet("field", "assign");
private static final Set<String> utilityFunctions = Sets.newHashSet("field", "assign", "cast");

public static final Set<String> builtInFunctions = Stream.of(
numberOperators,
Expand Down Expand Up @@ -117,9 +121,15 @@ public static boolean isFunctionTranslatedToScript(String function) {
}

public Tuple<String, String> function(String methodName, List<KVValue> paramers, String name,
boolean returnValue) {
boolean returnValue) throws SqlParseException {
Tuple<String, String> functionStr = null;
switch (methodName.toLowerCase()) {
case "cast": {
SQLCastExpr castExpr = (SQLCastExpr) ((SQLIdentifierExpr) paramers.get(0).value).getParent();
String typeName = castExpr.getDataType().getName();
functionStr = cast(typeName, paramers);
break;
}
case "lower": {
functionStr = lower(
(SQLExpr) paramers.get(0).value,
Expand Down Expand Up @@ -401,6 +411,12 @@ public String getLocaleForCaseChangingFunction(List<KVValue> paramers) {
return locale;
}

public Tuple<String, String> cast(String castType, List<KVValue> paramers) throws SqlParseException {
String name = nextId("cast");
return new Tuple<>(name, getCastScriptStatement(name, castType, paramers));
}


public Tuple<String, String> upper(SQLExpr field, String locale, String valueName) {
String name = nextId("upper");

Expand Down Expand Up @@ -929,15 +945,40 @@ private Tuple<String, String> isnull(SQLExpr expr) {
return new Tuple<>(name, def(name, resultStr));
}

public String getCastScriptStatement(String name, String castType, List<KVValue> paramers)
throws SqlParseException {
String castFieldName = String.format("doc['%s'].value", paramers.get(0).toString());
switch (StringUtils.toUpper(castType)) {
case "INT":
return String.format("def %s = Double.parseDouble(%s.toString()).intValue()", name, castFieldName);
case "LONG":
return String.format("def %s = Double.parseDouble(%s.toString()).longValue()", name, castFieldName);
case "FLOAT":
return String.format("def %s = Double.parseDouble(%s.toString()).floatValue()", name, castFieldName);
case "DOUBLE":
return String.format("def %s = Double.parseDouble(%s.toString()).doubleValue()", name, castFieldName);
case "STRING":
return String.format("def %s = %s.toString()", name, castFieldName);
case "DATETIME":
return String.format("def %s = DateTimeFormatter.ofPattern(\"yyyy-MM-dd'T'HH:mm:ss.SSS'Z'\").format("
+ "DateTimeFormatter.ISO_DATE_TIME.parse(%s.toString()))", name, castFieldName);
default:
throw new SqlParseException("Unsupported cast type " + castType);
}
}

/**
* Returns return type of script function. This is simple approach, that might be not the best solution in the long
* term. For example - for JDBC, if the column type in index is INTEGER, and the query is "select column+5", current
* approach will return type of result column as DOUBLE, although there is enough information to understand that
* it might be safely treated as INTEGER.
*/
public static Schema.Type getScriptFunctionReturnType(String functionName) {
functionName = functionName.toLowerCase();

public static Schema.Type getScriptFunctionReturnType(Field field) {
String functionName = ((ScriptMethodField) field).getFunctionName().toLowerCase();
if (functionName.equals("cast")) {
String castType = ((SQLCastExpr) field.getExpression()).getDataType().getName();
return getCastFunctionReturnType(castType);
}
if (dateFunctions.contains(functionName) || stringOperators.contains(functionName)) {
return Schema.Type.TEXT;
}
Expand Down
Loading

0 comments on commit 68b971f

Please sign in to comment.