Skip to content
This repository has been archived by the owner on Aug 2, 2022. It is now read-only.

Return Correct Type Information for Fields #365

Merged
merged 8 commits into from
Feb 21, 2020
Original file line number Diff line number Diff line change
Expand Up @@ -33,28 +33,28 @@
public enum ScalarFunction implements TypeExpression {

ABS(func(T(NUMBER)).to(T)), // translate to Java: <T extends Number> T ABS(T)
ACOS(func(T(NUMBER)).to(T)),
ACOS(func(T(NUMBER)).to(DOUBLE)),
ADD(func(T(NUMBER), NUMBER).to(T)),
ASCII(func(T(STRING)).to(INTEGER)),
ASIN(func(T(NUMBER)).to(T)),
ATAN(func(T(NUMBER)).to(T)),
ATAN2(func(T(NUMBER), NUMBER).to(T)),
ASIN(func(T(NUMBER)).to(DOUBLE)),
ATAN(func(T(NUMBER)).to(DOUBLE)),
ATAN2(func(T(NUMBER), NUMBER).to(DOUBLE)),
CAST(),
CBRT(func(T(NUMBER)).to(T)),
CEIL(func(T(NUMBER)).to(T)),
CONCAT(), // TODO: varargs support required
CONCAT_WS(),
COS(func(T(NUMBER)).to(T)),
COSH(func(T(NUMBER)).to(T)),
COT(func(T(NUMBER)).to(T)),
COS(func(T(NUMBER)).to(DOUBLE)),
COSH(func(T(NUMBER)).to(DOUBLE)),
COT(func(T(NUMBER)).to(DOUBLE)),
CURDATE(func().to(ESDataType.DATE)),
DATE(func(ESDataType.DATE).to(ESDataType.DATE)),
DATE_FORMAT(
func(ESDataType.DATE, STRING).to(STRING),
func(ESDataType.DATE, STRING, STRING).to(STRING)
),
DAYOFMONTH(func(ESDataType.DATE).to(INTEGER)),
DEGREES(func(T(NUMBER)).to(T)),
DEGREES(func(T(NUMBER)).to(DOUBLE)),
DIVIDE(func(T(NUMBER), NUMBER).to(T)),
E(func().to(DOUBLE)),
EXP(func(T(NUMBER)).to(T)),
Expand Down Expand Up @@ -96,7 +96,7 @@ public enum ScalarFunction implements TypeExpression {
func(T(NUMBER)).to(T),
func(T(NUMBER), NUMBER).to(T)
),
RADIANS(func(T(NUMBER)).to(T)),
RADIANS(func(T(NUMBER)).to(DOUBLE)),
RAND(
func().to(NUMBER),
func(T(NUMBER)).to(T)
Expand All @@ -108,12 +108,12 @@ public enum ScalarFunction implements TypeExpression {
RTRIM(func(T(STRING)).to(T)),
SIGN(func(T(NUMBER)).to(T)),
SIGNUM(func(T(NUMBER)).to(T)),
SIN(func(T(NUMBER)).to(T)),
SINH(func(T(NUMBER)).to(T)),
SIN(func(T(NUMBER)).to(DOUBLE)),
SINH(func(T(NUMBER)).to(DOUBLE)),
SQRT(func(T(NUMBER)).to(T)),
SUBSTRING(func(T(STRING), INTEGER, INTEGER).to(T)),
SUBTRACT(func(T(NUMBER), NUMBER).to(T)),
TAN(func(T(NUMBER)).to(T)),
TAN(func(T(NUMBER)).to(DOUBLE)),
TIMESTAMP(func(ESDataType.DATE).to(ESDataType.DATE)),
TRIM(func(T(STRING)).to(T)),
UPPER(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,17 +15,19 @@

package com.amazon.opendistroforelasticsearch.sql.executor.format;

import com.amazon.opendistroforelasticsearch.sql.domain.ColumnTypeProvider;
import com.amazon.opendistroforelasticsearch.sql.domain.Delete;
import com.amazon.opendistroforelasticsearch.sql.domain.IndexStatement;
import com.amazon.opendistroforelasticsearch.sql.domain.Query;
import com.amazon.opendistroforelasticsearch.sql.domain.QueryStatement;
import com.amazon.opendistroforelasticsearch.sql.executor.format.DataRows.Row;
import com.amazon.opendistroforelasticsearch.sql.executor.format.Schema.Column;
import com.amazon.opendistroforelasticsearch.sql.executor.adapter.QueryPlanQueryAction;
import com.amazon.opendistroforelasticsearch.sql.executor.adapter.QueryPlanRequestBuilder;
import com.amazon.opendistroforelasticsearch.sql.executor.format.DataRows.Row;
import com.amazon.opendistroforelasticsearch.sql.executor.format.Schema.Column;
import com.amazon.opendistroforelasticsearch.sql.expression.domain.BindingTuple;
import com.amazon.opendistroforelasticsearch.sql.query.planner.core.ColumnNode;
import com.amazon.opendistroforelasticsearch.sql.query.DefaultQueryAction;
import com.amazon.opendistroforelasticsearch.sql.query.QueryAction;
import com.amazon.opendistroforelasticsearch.sql.query.planner.core.ColumnNode;
import org.elasticsearch.client.Client;
import org.json.JSONArray;
import org.json.JSONObject;
Expand All @@ -48,11 +50,14 @@ public class Protocol {
private ResultSet resultSet;
private ErrorMessage error;
private List<ColumnNode> columnNodeList;
private ColumnTypeProvider scriptColumnType = new ColumnTypeProvider();

public Protocol(Client client, QueryAction queryAction, Object queryResult, String formatType) {
if (queryAction instanceof QueryPlanQueryAction) {
this.columnNodeList =
((QueryPlanRequestBuilder) (((QueryPlanQueryAction) queryAction).explain())).outputColumns();
} else if (queryAction instanceof DefaultQueryAction) {
davidcui1225 marked this conversation as resolved.
Show resolved Hide resolved
scriptColumnType = queryAction.getScriptColumnType();
}
this.formatType = formatType;
QueryStatement query = queryAction.getQueryStatement();
Expand All @@ -75,7 +80,7 @@ private ResultSet loadResultSet(Client client, QueryStatement queryStatement, Ob
if (queryStatement instanceof Delete) {
return new DeleteResultSet(client, (Delete) queryStatement, queryResult);
} else if (queryStatement instanceof Query) {
return new SelectResultSet(client, (Query) queryStatement, queryResult);
return new SelectResultSet(client, (Query) queryStatement, queryResult, scriptColumnType);
} else if (queryStatement instanceof IndexStatement) {
IndexStatement statement = (IndexStatement) queryStatement;
StatementType statementType = statement.getStatementType();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
package com.amazon.opendistroforelasticsearch.sql.executor.format;

import com.alibaba.druid.sql.ast.expr.SQLCaseExpr;
import com.amazon.opendistroforelasticsearch.sql.domain.ColumnTypeProvider;
import com.amazon.opendistroforelasticsearch.sql.domain.Field;
import com.amazon.opendistroforelasticsearch.sql.domain.JoinSelect;
import com.amazon.opendistroforelasticsearch.sql.domain.MethodField;
Expand Down Expand Up @@ -65,17 +66,19 @@ public class SelectResultSet extends ResultSet {
private String indexName;
private String typeName;
private List<Schema.Column> columns = new ArrayList<>();
private ColumnTypeProvider outputColumnType;

private List<String> head;
private long size;
private long totalHits;
private List<DataRows.Row> rows;

public SelectResultSet(Client client, Query query, Object queryResult) {
public SelectResultSet(Client client, Query query, Object queryResult, ColumnTypeProvider outputColumnType) {
this.client = client;
this.query = query;
this.queryResult = queryResult;
this.selectAll = false;
this.outputColumnType = outputColumnType;

if (isJoinQuery()) {
JoinSelect joinQuery = (JoinSelect) query;
Expand Down Expand Up @@ -308,7 +311,7 @@ private String[] emptyArrayIfNull(String typeName) {
}
}

private Schema.Type fetchMethodReturnType(Field field) {
private Schema.Type fetchMethodReturnType(int fieldIndex, MethodField field) {
switch (field.getName().toLowerCase()) {
case "count":
return Schema.Type.LONG;
Expand All @@ -325,7 +328,8 @@ private Schema.Type fetchMethodReturnType(Field field) {
if (field.getExpression() instanceof SQLCaseExpr) {
return Schema.Type.TEXT;
}
return SQLFunctions.getScriptFunctionReturnType(field);
Schema.Type resolvedType = outputColumnType.get(fieldIndex);
return SQLFunctions.getScriptFunctionReturnType(field, resolvedType);
}
default:
throw new UnsupportedOperationException(
Expand Down Expand Up @@ -374,12 +378,13 @@ private List<Schema.Column> populateColumns(Query query, String[] fieldNames, Ma
* name instead.
*/
if (fieldMap.get(fieldName) instanceof MethodField) {
Field methodField = fieldMap.get(fieldName);
MethodField methodField = (MethodField) fieldMap.get(fieldName);
int fieldIndex = fieldNameList.indexOf(fieldName);
columns.add(
new Schema.Column(
methodField.getAlias(),
null,
fetchMethodReturnType(methodField)
fetchMethodReturnType(fieldIndex, methodField)
)
);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,7 @@ private static QueryAction explainRequest(final NodeClient client, final SqlRequ
final QueryAction queryAction = new SearchDao(client)
.explain(new QueryActionRequest(sqlRequest.getSql(), typeProvider, format));
queryAction.setSqlRequest(sqlRequest);
queryAction.setColumnTypeProvider(typeProvider);
return queryAction;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -267,13 +267,7 @@ private String getNullOrderString(SQLBinaryOpExpr expr) {

private ScriptSortType getScriptSortType(Order order) {
ScriptSortType scriptSortType;
Schema.Type scriptFunctionReturnType;
if (order.getSortField().getExpression() instanceof SQLCastExpr) {
scriptFunctionReturnType = SQLFunctions.getCastFunctionReturnType(
((SQLCastExpr) order.getSortField().getExpression()).getDataType().getName());
} else {
scriptFunctionReturnType = SQLFunctions.getScriptFunctionReturnType(order.getSortField());
}
Schema.Type scriptFunctionReturnType = SQLFunctions.getOrderByFieldType(order.getSortField());


// as of now script function return type returns only text and double
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@

package com.amazon.opendistroforelasticsearch.sql.query;

import com.amazon.opendistroforelasticsearch.sql.domain.ColumnTypeProvider;
import com.amazon.opendistroforelasticsearch.sql.domain.Query;
import com.amazon.opendistroforelasticsearch.sql.domain.QueryStatement;
import com.amazon.opendistroforelasticsearch.sql.domain.Select;
Expand Down Expand Up @@ -48,6 +49,7 @@ public abstract class QueryAction {
protected Query query;
protected Client client;
protected SqlRequest sqlRequest = SqlRequest.NULL;
protected ColumnTypeProvider scriptColumnType;

public QueryAction(Client client, Query query) {
this.client = client;
Expand All @@ -66,10 +68,18 @@ public void setSqlRequest(SqlRequest sqlRequest) {
this.sqlRequest = sqlRequest;
}

public void setColumnTypeProvider(ColumnTypeProvider scriptColumnType) {
this.scriptColumnType = scriptColumnType;
}

public SqlRequest getSqlRequest() {
return sqlRequest;
}

public ColumnTypeProvider getScriptColumnType() {
return scriptColumnType;
}

/**
* @return List of field names produced by the query
*/
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -522,7 +522,7 @@ private Tuple<String, String> date_format(SQLExpr field, String pattern, String
if (valueName == null) {
return new Tuple<>(name, "def " + name + " = DateTimeFormatter.ofPattern('" + pattern + "').withZone("
+ (zoneId != null ? "ZoneId.of('" + zoneId + "')" : "ZoneId.systemDefault()")
+ ").format(Instant.ofEpochMilli(" + getPropertyOrValue(field) + ".getMillis()))");
+ ").format(Instant.ofEpochMilli(" + getPropertyOrValue(field) + ".toInstant().toEpochMilli()))");
} else {
return new Tuple<>(name, exprString(field) + "; "
+ "def " + name + " = new SimpleDateFormat('" + pattern + "').format("
Expand Down Expand Up @@ -973,34 +973,16 @@ public String getCastScriptStatement(String name, String castType, List<KVValue>
* approach will return type of result column as DOUBLE, although there is enough information to understand that
* it might be safely treated as INTEGER.
*/
public static Schema.Type getScriptFunctionReturnType(Field field) {
public static Schema.Type getScriptFunctionReturnType(MethodField field, Schema.Type resolvedType) {
Schema.Type returnType;
String functionName = ((ScriptMethodField) field).getFunctionName().toLowerCase();
if (functionName.equals("cast")) {
String castType = ((SQLCastExpr) field.getExpression()).getDataType().getName();
return getCastFunctionReturnType(castType);
} else {
returnType = resolvedType;
}
if (dateFunctions.contains(functionName) || stringOperators.contains(functionName)) {
return Schema.Type.TEXT;
}

if (mathConstants.contains(functionName) || numberOperators.contains(functionName)
|| trigFunctions.contains(functionName) || binaryOperators.contains(functionName)
|| utilityFunctions.contains(functionName)) {
return Schema.Type.DOUBLE;
}

if (stringFunctions.contains(functionName)) {
return Schema.Type.INTEGER;
}

if (conditionalFunctions.contains(functionName)) {
return Schema.Type.KEYWORD;
}

throw new UnsupportedOperationException(
String.format(
"The following method is not supported in Schema: %s",
functionName));
return returnType;
}

public static Schema.Type getCastFunctionReturnType(String castType) {
Expand All @@ -1023,4 +1005,38 @@ public static Schema.Type getCastFunctionReturnType(String castType) {
);
}
}

/**
*
* @param field
* @return Schema.Type.TEXT or DOUBLE
* There are only two ORDER BY types (TEXT, NUMBER) in Elasticsearch, so the Type that is returned here essentially
* indicates the category of the function as opposed to the actual return type.
*/
public static Schema.Type getOrderByFieldType(Field field) {
davidcui1225 marked this conversation as resolved.
Show resolved Hide resolved
String functionName = ((ScriptMethodField) field).getFunctionName().toLowerCase();
if (functionName.equals("cast")) {
String castType = ((SQLCastExpr) field.getExpression()).getDataType().getName();
return getCastFunctionReturnType(castType);
}

if (numberOperators.contains(functionName) || mathConstants.contains(functionName)
|| trigFunctions.contains(functionName) || binaryOperators.contains(functionName)) {
return Schema.Type.DOUBLE;
davidcui1225 marked this conversation as resolved.
Show resolved Hide resolved
} else if (dateFunctions.contains(functionName)) {
if (functionName.equals("date_format") || functionName.equals("now")
|| functionName.equals("curdate") || functionName.equals("date")
|| functionName.equals("timestamp") || functionName.equals("monthname")) {
return Schema.Type.TEXT;
}
return Schema.Type.DOUBLE;
} else if (stringFunctions.contains(functionName) || stringOperators.contains(functionName)) {
return Schema.Type.TEXT;
}

throw new UnsupportedOperationException(
String.format(
"The following method is not supported in Schema for Order By: %s",
functionName));
}
}
Loading