diff --git a/legacy/src/main/java/org/opensearch/sql/legacy/antlr/OpenSearchLegacySqlAnalyzer.java b/legacy/src/main/java/org/opensearch/sql/legacy/antlr/OpenSearchLegacySqlAnalyzer.java index b44e2bbb41..bb063f4df4 100644 --- a/legacy/src/main/java/org/opensearch/sql/legacy/antlr/OpenSearchLegacySqlAnalyzer.java +++ b/legacy/src/main/java/org/opensearch/sql/legacy/antlr/OpenSearchLegacySqlAnalyzer.java @@ -3,7 +3,6 @@ * SPDX-License-Identifier: Apache-2.0 */ - package org.opensearch.sql.legacy.antlr; import java.util.Optional; @@ -25,84 +24,77 @@ import org.opensearch.sql.legacy.antlr.visitor.EarlyExitAnalysisException; import org.opensearch.sql.legacy.esdomain.LocalClusterState; -/** - * Entry point for ANTLR generated parser to perform strict syntax and semantic analysis. - */ +/** Entry point for ANTLR generated parser to perform strict syntax and semantic analysis. */ public class OpenSearchLegacySqlAnalyzer { - private static final Logger LOG = LogManager.getLogger(); - - /** Original sql query */ - private final SqlAnalysisConfig config; - - public OpenSearchLegacySqlAnalyzer(SqlAnalysisConfig config) { - this.config = config; - } - - public Optional analyze(String sql, LocalClusterState clusterState) { - // Perform analysis for SELECT only for now because of extra code changes required for SHOW/DESCRIBE. - if (!isSelectStatement(sql) || !config.isAnalyzerEnabled()) { - return Optional.empty(); - } + private static final Logger LOG = LogManager.getLogger(); - try { - return Optional.of(analyzeSemantic( - analyzeSyntax(sql), - clusterState - )); - } catch (EarlyExitAnalysisException e) { - // Expected if configured so log on debug level to avoid always logging stack trace - LOG.debug("Analysis exits early and will skip remaining process", e); - return Optional.empty(); - } - } + /** Original sql query */ + private final SqlAnalysisConfig config; - /** - * Build lexer and parser to perform syntax analysis only. - * Runtime exception with clear message is thrown for any verification error. - * - * @return parse tree - */ - public ParseTree analyzeSyntax(String sql) { - OpenSearchLegacySqlParser parser = createParser(createLexer(sql)); - parser.addErrorListener(new SyntaxAnalysisErrorListener()); - return parser.root(); - } + public OpenSearchLegacySqlAnalyzer(SqlAnalysisConfig config) { + this.config = config; + } - /** - * Perform semantic analysis based on syntax analysis output - parse tree. - * - * @param tree parse tree - * @param clusterState cluster state required for index mapping query - */ - public Type analyzeSemantic(ParseTree tree, LocalClusterState clusterState) { - return tree.accept(new AntlrSqlParseTreeVisitor<>(createAnalyzer(clusterState))); + public Optional analyze(String sql, LocalClusterState clusterState) { + // Perform analysis for SELECT only for now because of extra code changes required for + // SHOW/DESCRIBE. 
+ if (!isSelectStatement(sql) || !config.isAnalyzerEnabled()) { + return Optional.empty(); } - /** Factory method for semantic analyzer to help assemble all required components together */ - private SemanticAnalyzer createAnalyzer(LocalClusterState clusterState) { - SemanticContext context = new SemanticContext(); - OpenSearchMappingLoader - mappingLoader = new OpenSearchMappingLoader(context, clusterState, config.getAnalysisThreshold()); - TypeChecker typeChecker = new TypeChecker(context, config.isFieldSuggestionEnabled()); - return new SemanticAnalyzer(mappingLoader, typeChecker); + try { + return Optional.of(analyzeSemantic(analyzeSyntax(sql), clusterState)); + } catch (EarlyExitAnalysisException e) { + // Expected if configured so log on debug level to avoid always logging stack trace + LOG.debug("Analysis exits early and will skip remaining process", e); + return Optional.empty(); } - - private OpenSearchLegacySqlParser createParser(Lexer lexer) { - return new OpenSearchLegacySqlParser( - new CommonTokenStream(lexer)); - } - - private OpenSearchLegacySqlLexer createLexer(String sql) { - return new OpenSearchLegacySqlLexer( - new CaseInsensitiveCharStream(sql)); - } - - private boolean isSelectStatement(String sql) { - sql = sql.replaceAll("\\R", " ").trim(); - int endOfFirstWord = sql.indexOf(' '); - String firstWord = sql.substring(0, endOfFirstWord > 0 ? endOfFirstWord : sql.length()); - return "SELECT".equalsIgnoreCase(firstWord); - } - + } + + /** + * Build lexer and parser to perform syntax analysis only. Runtime exception with clear message is + * thrown for any verification error. + * + * @return parse tree + */ + public ParseTree analyzeSyntax(String sql) { + OpenSearchLegacySqlParser parser = createParser(createLexer(sql)); + parser.addErrorListener(new SyntaxAnalysisErrorListener()); + return parser.root(); + } + + /** + * Perform semantic analysis based on syntax analysis output - parse tree. + * + * @param tree parse tree + * @param clusterState cluster state required for index mapping query + */ + public Type analyzeSemantic(ParseTree tree, LocalClusterState clusterState) { + return tree.accept(new AntlrSqlParseTreeVisitor<>(createAnalyzer(clusterState))); + } + + /** Factory method for semantic analyzer to help assemble all required components together */ + private SemanticAnalyzer createAnalyzer(LocalClusterState clusterState) { + SemanticContext context = new SemanticContext(); + OpenSearchMappingLoader mappingLoader = + new OpenSearchMappingLoader(context, clusterState, config.getAnalysisThreshold()); + TypeChecker typeChecker = new TypeChecker(context, config.isFieldSuggestionEnabled()); + return new SemanticAnalyzer(mappingLoader, typeChecker); + } + + private OpenSearchLegacySqlParser createParser(Lexer lexer) { + return new OpenSearchLegacySqlParser(new CommonTokenStream(lexer)); + } + + private OpenSearchLegacySqlLexer createLexer(String sql) { + return new OpenSearchLegacySqlLexer(new CaseInsensitiveCharStream(sql)); + } + + private boolean isSelectStatement(String sql) { + sql = sql.replaceAll("\\R", " ").trim(); + int endOfFirstWord = sql.indexOf(' '); + String firstWord = sql.substring(0, endOfFirstWord > 0 ? 
endOfFirstWord : sql.length()); + return "SELECT".equalsIgnoreCase(firstWord); + } } diff --git a/legacy/src/main/java/org/opensearch/sql/legacy/antlr/semantic/types/base/OpenSearchDataType.java b/legacy/src/main/java/org/opensearch/sql/legacy/antlr/semantic/types/base/OpenSearchDataType.java index eab40c2dc7..00ef4afdf1 100644 --- a/legacy/src/main/java/org/opensearch/sql/legacy/antlr/semantic/types/base/OpenSearchDataType.java +++ b/legacy/src/main/java/org/opensearch/sql/legacy/antlr/semantic/types/base/OpenSearchDataType.java @@ -3,7 +3,6 @@ * SPDX-License-Identifier: Apache-2.0 */ - package org.opensearch.sql.legacy.antlr.semantic.types.base; import static org.opensearch.sql.legacy.antlr.semantic.types.base.OpenSearchIndex.IndexType.NESTED_FIELD; @@ -13,105 +12,102 @@ import java.util.Map; import org.opensearch.sql.legacy.antlr.semantic.types.Type; -/** - * Base type hierarchy based on OpenSearch data type - */ +/** Base type hierarchy based on OpenSearch data type */ public enum OpenSearchDataType implements BaseType { - - TYPE_ERROR, - UNKNOWN, - - SHORT, LONG, - INTEGER(SHORT, LONG), - FLOAT(INTEGER), - DOUBLE(FLOAT), - NUMBER(DOUBLE), - - KEYWORD, - TEXT(KEYWORD), - STRING(TEXT), - - DATE_NANOS, - DATE(DATE_NANOS, STRING), - - BOOLEAN, - - OBJECT, NESTED, - COMPLEX(OBJECT, NESTED), - - GEO_POINT, - - OPENSEARCH_TYPE( - NUMBER, - //STRING, move to under DATE because DATE is compatible - DATE, - BOOLEAN, - COMPLEX, - GEO_POINT - ); - - - /** - * Java Enum's valueOf() may thrown "enum constant not found" exception. - * And Java doesn't provide a contains method. - * So this static map is necessary for check and efficiency. - */ - private static final Map ALL_BASE_TYPES; - static { - ImmutableMap.Builder builder = new ImmutableMap.Builder<>(); - for (OpenSearchDataType type : OpenSearchDataType.values()) { - builder.put(type.name(), type); - } - ALL_BASE_TYPES = builder.build(); + TYPE_ERROR, + UNKNOWN, + + SHORT, + LONG, + INTEGER(SHORT, LONG), + FLOAT(INTEGER), + DOUBLE(FLOAT), + NUMBER(DOUBLE), + + KEYWORD, + TEXT(KEYWORD), + STRING(TEXT), + + DATE_NANOS, + DATE(DATE_NANOS, STRING), + + BOOLEAN, + + OBJECT, + NESTED, + COMPLEX(OBJECT, NESTED), + + GEO_POINT, + + OPENSEARCH_TYPE( + NUMBER, + // STRING, move to under DATE because DATE is compatible + DATE, + BOOLEAN, + COMPLEX, + GEO_POINT); + + /** + * Java Enum's valueOf() may thrown "enum constant not found" exception. And Java doesn't provide + * a contains method. So this static map is necessary for check and efficiency. + */ + private static final Map ALL_BASE_TYPES; + + static { + ImmutableMap.Builder builder = new ImmutableMap.Builder<>(); + for (OpenSearchDataType type : OpenSearchDataType.values()) { + builder.put(type.name(), type); } + ALL_BASE_TYPES = builder.build(); + } - public static OpenSearchDataType typeOf(String str) { - return ALL_BASE_TYPES.getOrDefault(toUpper(str), UNKNOWN); - } + public static OpenSearchDataType typeOf(String str) { + return ALL_BASE_TYPES.getOrDefault(toUpper(str), UNKNOWN); + } - /** Parent of current base type */ - private OpenSearchDataType parent; + /** Parent of current base type */ + private OpenSearchDataType parent; - OpenSearchDataType(OpenSearchDataType... compatibleTypes) { - for (OpenSearchDataType subType : compatibleTypes) { - subType.parent = this; - } + OpenSearchDataType(OpenSearchDataType... 
compatibleTypes) { + for (OpenSearchDataType subType : compatibleTypes) { + subType.parent = this; } - - @Override - public String getName() { - return name(); + } + + @Override + public String getName() { + return name(); + } + + /** + * For base type, compatibility means this (current type) is ancestor of other in the base type + * hierarchy. + */ + @Override + public boolean isCompatible(Type other) { + // Skip compatibility check if type is unknown + if (this == UNKNOWN || other == UNKNOWN) { + return true; } - /** - * For base type, compatibility means this (current type) is ancestor of other - * in the base type hierarchy. - */ - @Override - public boolean isCompatible(Type other) { - // Skip compatibility check if type is unknown - if (this == UNKNOWN || other == UNKNOWN) { - return true; - } - - if (!(other instanceof OpenSearchDataType)) { - // Nested data type is compatible with nested index type for type expression use - if (other instanceof OpenSearchIndex && ((OpenSearchIndex) other).type() == NESTED_FIELD) { - return isCompatible(NESTED); - } - return false; - } - - // One way compatibility: parent base type is compatible with children - OpenSearchDataType cur = (OpenSearchDataType) other; - while (cur != null && cur != this) { - cur = cur.parent; - } - return cur != null; + if (!(other instanceof OpenSearchDataType)) { + // Nested data type is compatible with nested index type for type expression use + if (other instanceof OpenSearchIndex && ((OpenSearchIndex) other).type() == NESTED_FIELD) { + return isCompatible(NESTED); + } + return false; } - @Override - public String toString() { - return "OpenSearch Data Type [" + getName() + "]"; + // One way compatibility: parent base type is compatible with children + OpenSearchDataType cur = (OpenSearchDataType) other; + while (cur != null && cur != this) { + cur = cur.parent; } + return cur != null; + } + + @Override + public String toString() { + return "OpenSearch Data Type [" + getName() + "]"; + } } diff --git a/legacy/src/main/java/org/opensearch/sql/legacy/antlr/semantic/types/base/OpenSearchIndex.java b/legacy/src/main/java/org/opensearch/sql/legacy/antlr/semantic/types/base/OpenSearchIndex.java index b3d971100b..2c790f15aa 100644 --- a/legacy/src/main/java/org/opensearch/sql/legacy/antlr/semantic/types/base/OpenSearchIndex.java +++ b/legacy/src/main/java/org/opensearch/sql/legacy/antlr/semantic/types/base/OpenSearchIndex.java @@ -3,68 +3,66 @@ * SPDX-License-Identifier: Apache-2.0 */ - package org.opensearch.sql.legacy.antlr.semantic.types.base; import java.util.Objects; import org.opensearch.sql.legacy.antlr.semantic.types.Type; -/** - * Index type is not Enum because essentially each index is a brand new type. - */ +/** Index type is not Enum because essentially each index is a brand new type. 
*/ public class OpenSearchIndex implements BaseType { - public enum IndexType { - INDEX, NESTED_FIELD, INDEX_PATTERN - } + public enum IndexType { + INDEX, + NESTED_FIELD, + INDEX_PATTERN + } - private final String indexName; - private final IndexType indexType; + private final String indexName; + private final IndexType indexType; - public OpenSearchIndex(String indexName, IndexType indexType) { - this.indexName = indexName; - this.indexType = indexType; - } + public OpenSearchIndex(String indexName, IndexType indexType) { + this.indexName = indexName; + this.indexType = indexType; + } - public IndexType type() { - return indexType; - } + public IndexType type() { + return indexType; + } - @Override - public String getName() { - return indexName; - } + @Override + public String getName() { + return indexName; + } - @Override - public boolean isCompatible(Type other) { - return equals(other); - } + @Override + public boolean isCompatible(Type other) { + return equals(other); + } - @Override - public String usage() { - return indexType.name(); - } + @Override + public String usage() { + return indexType.name(); + } - @Override - public String toString() { - return indexType + " [" + indexName + "]"; - } + @Override + public String toString() { + return indexType + " [" + indexName + "]"; + } - @Override - public boolean equals(Object o) { - if (this == o) { - return true; - } - if (o == null || getClass() != o.getClass()) { - return false; - } - OpenSearchIndex index = (OpenSearchIndex) o; - return Objects.equals(indexName, index.indexName) - && indexType == index.indexType; + @Override + public boolean equals(Object o) { + if (this == o) { + return true; } - - @Override - public int hashCode() { - return Objects.hash(indexName, indexType); + if (o == null || getClass() != o.getClass()) { + return false; } + OpenSearchIndex index = (OpenSearchIndex) o; + return Objects.equals(indexName, index.indexName) && indexType == index.indexType; + } + + @Override + public int hashCode() { + return Objects.hash(indexName, indexType); + } } diff --git a/legacy/src/main/java/org/opensearch/sql/legacy/antlr/semantic/types/function/OpenSearchScalarFunction.java b/legacy/src/main/java/org/opensearch/sql/legacy/antlr/semantic/types/function/OpenSearchScalarFunction.java index 93e1950d50..435a5ca968 100644 --- a/legacy/src/main/java/org/opensearch/sql/legacy/antlr/semantic/types/function/OpenSearchScalarFunction.java +++ b/legacy/src/main/java/org/opensearch/sql/legacy/antlr/semantic/types/function/OpenSearchScalarFunction.java @@ -3,7 +3,6 @@ * SPDX-License-Identifier: Apache-2.0 */ - package org.opensearch.sql.legacy.antlr.semantic.types.function; import static org.opensearch.sql.legacy.antlr.semantic.types.base.OpenSearchDataType.BOOLEAN; @@ -16,87 +15,73 @@ import org.opensearch.sql.legacy.antlr.semantic.types.Type; import org.opensearch.sql.legacy.antlr.semantic.types.TypeExpression; -/** - * OpenSearch special scalar functions - */ +/** OpenSearch special scalar functions */ public enum OpenSearchScalarFunction implements TypeExpression { + DATE_HISTOGRAM(), // this is aggregate function + DAY_OF_MONTH(func(DATE).to(INTEGER)), + DAY_OF_YEAR(func(DATE).to(INTEGER)), + DAY_OF_WEEK(func(DATE).to(INTEGER)), + EXCLUDE(), // can only be used in SELECT? + EXTENDED_STATS(), // need confirm + FIELD(), // couldn't find test cases related + FILTER(), + GEO_BOUNDING_BOX(func(GEO_POINT, NUMBER, NUMBER, NUMBER, NUMBER).to(BOOLEAN)), + GEO_CELL(), // optional arg or overloaded spec is required. 
+ GEO_DISTANCE(func(GEO_POINT, STRING, NUMBER, NUMBER).to(BOOLEAN)), + GEO_DISTANCE_RANGE(func(GEO_POINT, STRING, NUMBER, NUMBER).to(BOOLEAN)), + GEO_INTERSECTS(), // ? + GEO_POLYGON(), // varargs is required for 2nd arg + HISTOGRAM(), // same as date_histogram + HOUR_OF_DAY(func(DATE).to(INTEGER)), + INCLUDE(), // same as exclude + IN_TERMS(), // varargs + MATCHPHRASE(func(STRING, STRING).to(BOOLEAN), func(STRING).to(STRING)), // slop arg is optional + MATCH_PHRASE(MATCHPHRASE.specifications()), + MATCHQUERY(func(STRING, STRING).to(BOOLEAN), func(STRING).to(STRING)), + MATCH_QUERY(MATCHQUERY.specifications()), + MINUTE_OF_DAY(func(DATE).to(INTEGER)), // or long? + MINUTE_OF_HOUR(func(DATE).to(INTEGER)), + MONTH_OF_YEAR(func(DATE).to(INTEGER)), + MULTIMATCH(), // kw arguments + MULTI_MATCH(MULTIMATCH.specifications()), + NESTED(), // overloaded + PERCENTILES(), // ? + REGEXP_QUERY(), // ? + REVERSE_NESTED(), // need overloaded + QUERY(func(STRING).to(BOOLEAN)), + RANGE(), // aggregate function + SCORE(), // semantic problem? + SECOND_OF_MINUTE(func(DATE).to(INTEGER)), + STATS(), + TERM(), // semantic problem + TERMS(), // semantic problem + TOPHITS(), // only available in SELECT + WEEK_OF_YEAR(func(DATE).to(INTEGER)), + WILDCARDQUERY(func(STRING, STRING).to(BOOLEAN), func(STRING).to(STRING)), + WILDCARD_QUERY(WILDCARDQUERY.specifications()); - DATE_HISTOGRAM(), // this is aggregate function - DAY_OF_MONTH(func(DATE).to(INTEGER)), - DAY_OF_YEAR(func(DATE).to(INTEGER)), - DAY_OF_WEEK(func(DATE).to(INTEGER)), - EXCLUDE(), // can only be used in SELECT? - EXTENDED_STATS(), // need confirm - FIELD(), // couldn't find test cases related - FILTER(), - GEO_BOUNDING_BOX(func(GEO_POINT, NUMBER, NUMBER, NUMBER, NUMBER).to(BOOLEAN)), - GEO_CELL(), // optional arg or overloaded spec is required. - GEO_DISTANCE(func(GEO_POINT, STRING, NUMBER, NUMBER).to(BOOLEAN)), - GEO_DISTANCE_RANGE(func(GEO_POINT, STRING, NUMBER, NUMBER).to(BOOLEAN)), - GEO_INTERSECTS(), //? - GEO_POLYGON(), // varargs is required for 2nd arg - HISTOGRAM(), // same as date_histogram - HOUR_OF_DAY(func(DATE).to(INTEGER)), - INCLUDE(), // same as exclude - IN_TERMS(), // varargs - MATCHPHRASE( - func(STRING, STRING).to(BOOLEAN), - func(STRING).to(STRING) - ), //slop arg is optional - MATCH_PHRASE(MATCHPHRASE.specifications()), - MATCHQUERY( - func(STRING, STRING).to(BOOLEAN), - func(STRING).to(STRING) - ), - MATCH_QUERY(MATCHQUERY.specifications()), - MINUTE_OF_DAY(func(DATE).to(INTEGER)), // or long? - MINUTE_OF_HOUR(func(DATE).to(INTEGER)), - MONTH_OF_YEAR(func(DATE).to(INTEGER)), - MULTIMATCH(), // kw arguments - MULTI_MATCH(MULTIMATCH.specifications()), - NESTED(), // overloaded - PERCENTILES(), //? - REGEXP_QUERY(), //? - REVERSE_NESTED(), // need overloaded - QUERY(func(STRING).to(BOOLEAN)), - RANGE(), // aggregate function - SCORE(), // semantic problem? - SECOND_OF_MINUTE(func(DATE).to(INTEGER)), - STATS(), - TERM(), // semantic problem - TERMS(), // semantic problem - TOPHITS(), // only available in SELECT - WEEK_OF_YEAR(func(DATE).to(INTEGER)), - WILDCARDQUERY( - func(STRING, STRING).to(BOOLEAN), - func(STRING).to(STRING) - ), - WILDCARD_QUERY(WILDCARDQUERY.specifications()); - - - private final TypeExpressionSpec[] specifications; - - OpenSearchScalarFunction(TypeExpressionSpec... specifications) { - this.specifications = specifications; - } + private final TypeExpressionSpec[] specifications; - @Override - public String getName() { - return name(); - } + OpenSearchScalarFunction(TypeExpressionSpec... 
specifications) { + this.specifications = specifications; + } - @Override - public TypeExpressionSpec[] specifications() { - return specifications; - } + @Override + public String getName() { + return name(); + } - private static TypeExpressionSpec func(Type... argTypes) { - return new TypeExpressionSpec().map(argTypes); - } + @Override + public TypeExpressionSpec[] specifications() { + return specifications; + } - @Override - public String toString() { - return "Function [" + name() + "]"; - } + private static TypeExpressionSpec func(Type... argTypes) { + return new TypeExpressionSpec().map(argTypes); + } + @Override + public String toString() { + return "Function [" + name() + "]"; + } } diff --git a/legacy/src/main/java/org/opensearch/sql/legacy/antlr/semantic/types/function/ScalarFunction.java b/legacy/src/main/java/org/opensearch/sql/legacy/antlr/semantic/types/function/ScalarFunction.java index e993562df8..5dfada7ca8 100644 --- a/legacy/src/main/java/org/opensearch/sql/legacy/antlr/semantic/types/function/ScalarFunction.java +++ b/legacy/src/main/java/org/opensearch/sql/legacy/antlr/semantic/types/function/ScalarFunction.java @@ -3,7 +3,6 @@ * SPDX-License-Identifier: Apache-2.0 */ - package org.opensearch.sql.legacy.antlr.semantic.types.function; import static org.opensearch.sql.legacy.antlr.semantic.types.base.OpenSearchDataType.BOOLEAN; @@ -18,123 +17,98 @@ import org.opensearch.sql.legacy.antlr.semantic.types.TypeExpression; import org.opensearch.sql.legacy.antlr.semantic.types.base.OpenSearchDataType; -/** - * Scalar SQL function - */ +/** Scalar SQL function */ public enum ScalarFunction implements TypeExpression { + ABS(func(T(NUMBER)).to(T)), // translate to Java: T ABS(T) + ACOS(func(T(NUMBER)).to(DOUBLE)), + ADD(func(T(NUMBER), NUMBER).to(T)), + ASCII(func(T(STRING)).to(INTEGER)), + ASIN(func(T(NUMBER)).to(DOUBLE)), + ATAN(func(T(NUMBER)).to(DOUBLE)), + ATAN2(func(T(NUMBER), NUMBER).to(DOUBLE)), + CAST(), + CBRT(func(T(NUMBER)).to(T)), + CEIL(func(T(NUMBER)).to(T)), + CONCAT(), // TODO: varargs support required + CONCAT_WS(), + COS(func(T(NUMBER)).to(DOUBLE)), + COSH(func(T(NUMBER)).to(DOUBLE)), + COT(func(T(NUMBER)).to(DOUBLE)), + CURDATE(func().to(OpenSearchDataType.DATE)), + DATE(func(OpenSearchDataType.DATE).to(OpenSearchDataType.DATE)), + DATE_FORMAT( + func(OpenSearchDataType.DATE, STRING).to(STRING), + func(OpenSearchDataType.DATE, STRING, STRING).to(STRING)), + DAYOFMONTH(func(OpenSearchDataType.DATE).to(INTEGER)), + DEGREES(func(T(NUMBER)).to(DOUBLE)), + DIVIDE(func(T(NUMBER), NUMBER).to(T)), + E(func().to(DOUBLE)), + EXP(func(T(NUMBER)).to(T)), + EXPM1(func(T(NUMBER)).to(T)), + FLOOR(func(T(NUMBER)).to(T)), + IF(func(BOOLEAN, OPENSEARCH_TYPE, OPENSEARCH_TYPE).to(OPENSEARCH_TYPE)), + IFNULL(func(OPENSEARCH_TYPE, OPENSEARCH_TYPE).to(OPENSEARCH_TYPE)), + ISNULL(func(OPENSEARCH_TYPE).to(INTEGER)), + LEFT(func(T(STRING), INTEGER).to(T)), + LENGTH(func(STRING).to(INTEGER)), + LN(func(T(NUMBER)).to(DOUBLE)), + LOCATE(func(STRING, STRING, INTEGER).to(INTEGER), func(STRING, STRING).to(INTEGER)), + LOG(func(T(NUMBER)).to(DOUBLE), func(T(NUMBER), NUMBER).to(DOUBLE)), + LOG2(func(T(NUMBER)).to(DOUBLE)), + LOG10(func(T(NUMBER)).to(DOUBLE)), + LOWER(func(T(STRING)).to(T), func(T(STRING), STRING).to(T)), + LTRIM(func(T(STRING)).to(T)), + MAKETIME(func(INTEGER, INTEGER, INTEGER).to(OpenSearchDataType.DATE)), + MODULUS(func(T(NUMBER), NUMBER).to(T)), + MONTH(func(OpenSearchDataType.DATE).to(INTEGER)), + MONTHNAME(func(OpenSearchDataType.DATE).to(STRING)), + 
MULTIPLY(func(T(NUMBER), NUMBER).to(NUMBER)), + NOW(func().to(OpenSearchDataType.DATE)), + PI(func().to(DOUBLE)), + POW(func(T(NUMBER)).to(T), func(T(NUMBER), NUMBER).to(T)), + POWER(func(T(NUMBER)).to(T), func(T(NUMBER), NUMBER).to(T)), + RADIANS(func(T(NUMBER)).to(DOUBLE)), + RAND(func().to(NUMBER), func(T(NUMBER)).to(T)), + REPLACE(func(T(STRING), STRING, STRING).to(T)), + RIGHT(func(T(STRING), INTEGER).to(T)), + RINT(func(T(NUMBER)).to(T)), + ROUND(func(T(NUMBER)).to(T)), + RTRIM(func(T(STRING)).to(T)), + SIGN(func(T(NUMBER)).to(T)), + SIGNUM(func(T(NUMBER)).to(T)), + SIN(func(T(NUMBER)).to(DOUBLE)), + SINH(func(T(NUMBER)).to(DOUBLE)), + SQRT(func(T(NUMBER)).to(T)), + SUBSTRING(func(T(STRING), INTEGER, INTEGER).to(T)), + SUBTRACT(func(T(NUMBER), NUMBER).to(T)), + TAN(func(T(NUMBER)).to(DOUBLE)), + TIMESTAMP(func(OpenSearchDataType.DATE).to(OpenSearchDataType.DATE)), + TRIM(func(T(STRING)).to(T)), + UPPER(func(T(STRING)).to(T), func(T(STRING), STRING).to(T)), + YEAR(func(OpenSearchDataType.DATE).to(INTEGER)); - ABS(func(T(NUMBER)).to(T)), // translate to Java: T ABS(T) - ACOS(func(T(NUMBER)).to(DOUBLE)), - ADD(func(T(NUMBER), NUMBER).to(T)), - ASCII(func(T(STRING)).to(INTEGER)), - ASIN(func(T(NUMBER)).to(DOUBLE)), - ATAN(func(T(NUMBER)).to(DOUBLE)), - ATAN2(func(T(NUMBER), NUMBER).to(DOUBLE)), - CAST(), - CBRT(func(T(NUMBER)).to(T)), - CEIL(func(T(NUMBER)).to(T)), - CONCAT(), // TODO: varargs support required - CONCAT_WS(), - COS(func(T(NUMBER)).to(DOUBLE)), - COSH(func(T(NUMBER)).to(DOUBLE)), - COT(func(T(NUMBER)).to(DOUBLE)), - CURDATE(func().to(OpenSearchDataType.DATE)), - DATE(func(OpenSearchDataType.DATE).to(OpenSearchDataType.DATE)), - DATE_FORMAT( - func(OpenSearchDataType.DATE, STRING).to(STRING), - func(OpenSearchDataType.DATE, STRING, STRING).to(STRING) - ), - DAYOFMONTH(func(OpenSearchDataType.DATE).to(INTEGER)), - DEGREES(func(T(NUMBER)).to(DOUBLE)), - DIVIDE(func(T(NUMBER), NUMBER).to(T)), - E(func().to(DOUBLE)), - EXP(func(T(NUMBER)).to(T)), - EXPM1(func(T(NUMBER)).to(T)), - FLOOR(func(T(NUMBER)).to(T)), - IF(func(BOOLEAN, OPENSEARCH_TYPE, OPENSEARCH_TYPE).to(OPENSEARCH_TYPE)), - IFNULL(func(OPENSEARCH_TYPE, OPENSEARCH_TYPE).to(OPENSEARCH_TYPE)), - ISNULL(func(OPENSEARCH_TYPE).to(INTEGER)), - LEFT(func(T(STRING), INTEGER).to(T)), - LENGTH(func(STRING).to(INTEGER)), - LN(func(T(NUMBER)).to(DOUBLE)), - LOCATE( - func(STRING, STRING, INTEGER).to(INTEGER), - func(STRING, STRING).to(INTEGER) - ), - LOG( - func(T(NUMBER)).to(DOUBLE), - func(T(NUMBER), NUMBER).to(DOUBLE) - ), - LOG2(func(T(NUMBER)).to(DOUBLE)), - LOG10(func(T(NUMBER)).to(DOUBLE)), - LOWER( - func(T(STRING)).to(T), - func(T(STRING), STRING).to(T) - ), - LTRIM(func(T(STRING)).to(T)), - MAKETIME(func(INTEGER, INTEGER, INTEGER).to(OpenSearchDataType.DATE)), - MODULUS(func(T(NUMBER), NUMBER).to(T)), - MONTH(func(OpenSearchDataType.DATE).to(INTEGER)), - MONTHNAME(func(OpenSearchDataType.DATE).to(STRING)), - MULTIPLY(func(T(NUMBER), NUMBER).to(NUMBER)), - NOW(func().to(OpenSearchDataType.DATE)), - PI(func().to(DOUBLE)), - POW( - func(T(NUMBER)).to(T), - func(T(NUMBER), NUMBER).to(T) - ), - POWER( - func(T(NUMBER)).to(T), - func(T(NUMBER), NUMBER).to(T) - ), - RADIANS(func(T(NUMBER)).to(DOUBLE)), - RAND( - func().to(NUMBER), - func(T(NUMBER)).to(T) - ), - REPLACE(func(T(STRING), STRING, STRING).to(T)), - RIGHT(func(T(STRING), INTEGER).to(T)), - RINT(func(T(NUMBER)).to(T)), - ROUND(func(T(NUMBER)).to(T)), - RTRIM(func(T(STRING)).to(T)), - SIGN(func(T(NUMBER)).to(T)), - SIGNUM(func(T(NUMBER)).to(T)), - 
SIN(func(T(NUMBER)).to(DOUBLE)), - SINH(func(T(NUMBER)).to(DOUBLE)), - SQRT(func(T(NUMBER)).to(T)), - SUBSTRING(func(T(STRING), INTEGER, INTEGER).to(T)), - SUBTRACT(func(T(NUMBER), NUMBER).to(T)), - TAN(func(T(NUMBER)).to(DOUBLE)), - TIMESTAMP(func(OpenSearchDataType.DATE).to(OpenSearchDataType.DATE)), - TRIM(func(T(STRING)).to(T)), - UPPER( - func(T(STRING)).to(T), - func(T(STRING), STRING).to(T) - ), - YEAR(func(OpenSearchDataType.DATE).to(INTEGER)); - - private final TypeExpressionSpec[] specifications; + private final TypeExpressionSpec[] specifications; - ScalarFunction(TypeExpressionSpec... specifications) { - this.specifications = specifications; - } + ScalarFunction(TypeExpressionSpec... specifications) { + this.specifications = specifications; + } - @Override - public String getName() { - return name(); - } + @Override + public String getName() { + return name(); + } - @Override - public TypeExpressionSpec[] specifications() { - return specifications; - } + @Override + public TypeExpressionSpec[] specifications() { + return specifications; + } - private static TypeExpressionSpec func(Type... argTypes) { - return new TypeExpressionSpec().map(argTypes); - } + private static TypeExpressionSpec func(Type... argTypes) { + return new TypeExpressionSpec().map(argTypes); + } - @Override - public String toString() { - return "Function [" + name() + "]"; - } + @Override + public String toString() { + return "Function [" + name() + "]"; + } } diff --git a/legacy/src/main/java/org/opensearch/sql/legacy/antlr/semantic/types/special/Product.java b/legacy/src/main/java/org/opensearch/sql/legacy/antlr/semantic/types/special/Product.java index ad4d86895b..98f04dc629 100644 --- a/legacy/src/main/java/org/opensearch/sql/legacy/antlr/semantic/types/special/Product.java +++ b/legacy/src/main/java/org/opensearch/sql/legacy/antlr/semantic/types/special/Product.java @@ -3,7 +3,6 @@ * SPDX-License-Identifier: Apache-2.0 */ - package org.opensearch.sql.legacy.antlr.semantic.types.special; import java.util.Collections; @@ -12,62 +11,56 @@ import lombok.Getter; import org.opensearch.sql.legacy.antlr.semantic.types.Type; -/** - * Combination of multiple types, ex. function arguments - */ +/** Combination of multiple types, ex. function arguments */ public class Product implements Type { - @Getter - private final List types; + @Getter private final List types; - public Product(List itemTypes) { - types = Collections.unmodifiableList(itemTypes); - } + public Product(List itemTypes) { + types = Collections.unmodifiableList(itemTypes); + } - @Override - public String getName() { - return "Product of types " + types; - } + @Override + public String getName() { + return "Product of types " + types; + } - @Override - public boolean isCompatible(Type other) { - if (!(other instanceof Product)) { - return false; - } - - Product otherProd = (Product) other; - if (types.size() != otherProd.types.size()) { - return false; - } - - for (int i = 0; i < types.size(); i++) { - Type type = types.get(i); - Type otherType = otherProd.types.get(i); - if (!isCompatibleEitherWay(type, otherType)) { - return false; - } - } - return true; + @Override + public boolean isCompatible(Type other) { + if (!(other instanceof Product)) { + return false; } - @Override - public Type construct(List others) { - return this; + Product otherProd = (Product) other; + if (types.size() != otherProd.types.size()) { + return false; } - @Override - public String usage() { - if (types.isEmpty()) { - return "(*)"; - } - return types.stream(). 
- map(Type::usage). - collect(Collectors.joining(", ", "(", ")")); + for (int i = 0; i < types.size(); i++) { + Type type = types.get(i); + Type otherType = otherProd.types.get(i); + if (!isCompatibleEitherWay(type, otherType)) { + return false; + } } + return true; + } + + @Override + public Type construct(List others) { + return this; + } - /** Perform two-way compatibility check here which is different from normal type expression */ - private boolean isCompatibleEitherWay(Type type1, Type type2) { - return type1.isCompatible(type2) || type2.isCompatible(type1); + @Override + public String usage() { + if (types.isEmpty()) { + return "(*)"; } + return types.stream().map(Type::usage).collect(Collectors.joining(", ", "(", ")")); + } + /** Perform two-way compatibility check here which is different from normal type expression */ + private boolean isCompatibleEitherWay(Type type1, Type type2) { + return type1.isCompatible(type2) || type2.isCompatible(type1); + } } diff --git a/legacy/src/main/java/org/opensearch/sql/legacy/antlr/semantic/visitor/OpenSearchMappingLoader.java b/legacy/src/main/java/org/opensearch/sql/legacy/antlr/semantic/visitor/OpenSearchMappingLoader.java index dca201f25b..4d009dc438 100644 --- a/legacy/src/main/java/org/opensearch/sql/legacy/antlr/semantic/visitor/OpenSearchMappingLoader.java +++ b/legacy/src/main/java/org/opensearch/sql/legacy/antlr/semantic/visitor/OpenSearchMappingLoader.java @@ -3,7 +3,6 @@ * SPDX-License-Identifier: Apache-2.0 */ - package org.opensearch.sql.legacy.antlr.semantic.visitor; import static org.opensearch.sql.legacy.antlr.semantic.types.base.OpenSearchIndex.IndexType.INDEX; @@ -26,178 +25,181 @@ import org.opensearch.sql.legacy.esdomain.mapping.IndexMappings; import org.opensearch.sql.legacy.utils.StringUtils; -/** - * Load index and nested field mapping into semantic context - */ +/** Load index and nested field mapping into semantic context */ public class OpenSearchMappingLoader implements GenericSqlParseTreeVisitor { - /** Semantic context shared in the semantic analysis process */ - private final SemanticContext context; - - /** Local cluster state for mapping query */ - private final LocalClusterState clusterState; - - /** Threshold to decide if continue the analysis */ - private final int threshold; - - public OpenSearchMappingLoader(SemanticContext context, LocalClusterState clusterState, int threshold) { - this.context = context; - this.clusterState = clusterState; - this.threshold = threshold; - } - - /* - * Suppose index 'accounts' includes 'name', 'age' and nested field 'projects' - * which includes 'name' and 'active'. - * - * 1. Define itself: - * ----- new definitions ----- - * accounts -> INDEX - * - * 2. Define without alias no matter if alias given: - * 'accounts' -> INDEX - * ----- new definitions ----- - * 'name' -> TEXT - * 'age' -> INTEGER - * 'projects' -> NESTED - * 'projects.name' -> KEYWORD - * 'projects.active' -> BOOLEAN - */ - @Override - public Type visitIndexName(String indexName) { - if (isNotNested(indexName)) { - defineIndexType(indexName); - loadAllFieldsWithType(indexName); - } - return defaultValue(); - } - - @Override - public void visitAs(String alias, Type type) { - if (!(type instanceof OpenSearchIndex)) { - return; - } - - OpenSearchIndex index = (OpenSearchIndex) type; - String indexName = type.getName(); - - if (index.type() == INDEX) { - String aliasName = alias.isEmpty() ? 
indexName : alias; - defineAllFieldNamesByAppendingAliasPrefix(indexName, aliasName); - } else if (index.type() == NESTED_FIELD) { - if (!alias.isEmpty()) { - defineNestedFieldNamesByReplacingWithAlias(indexName, alias); - } - } // else Do nothing for index pattern - } - - private void defineIndexType(String indexName) { - environment().define(new Symbol(Namespace.FIELD_NAME, indexName), new OpenSearchIndex(indexName, INDEX)); - } - - private void loadAllFieldsWithType(String indexName) { - Set mappings = getFieldMappings(indexName); - mappings.forEach(mapping -> mapping.flat(this::defineFieldName)); - } - - /* - * 3.1 Define with alias if given: ex."SELECT * FROM accounts a". - * 'accounts' -> INDEX - * 'name' -> TEXT - * 'age' -> INTEGER - * 'projects' -> NESTED - * 'projects.name' -> KEYWORD - * 'projects.active' -> BOOLEAN - * ----- new definitions ----- - * ['a' -> INDEX] -- this is done in semantic analyzer - * 'a.name' -> TEXT - * 'a.age' -> INTEGER - * 'a.projects' -> NESTED - * 'a.projects.name' -> KEYWORD - * 'a.projects.active' -> BOOLEAN - * - * 3.2 Otherwise define by index full name: ex."SELECT * FROM account" - * 'accounts' -> INDEX - * 'name' -> TEXT - * 'age' -> INTEGER - * 'projects' -> NESTED - * 'projects.name' -> KEYWORD - * 'projects.active' -> BOOLEAN - * ----- new definitions ----- - * 'accounts.name' -> TEXT - * 'accounts.age' -> INTEGER - * 'accounts.projects' -> NESTED - * 'accounts.projects.name' -> KEYWORD - * 'accounts.projects.active' -> BOOLEAN - */ - private void defineAllFieldNamesByAppendingAliasPrefix(String indexName, String alias) { - Set mappings = getFieldMappings(indexName); - mappings.stream().forEach(mapping -> mapping.flat((fieldName, type) -> - defineFieldName(alias + "." + fieldName, type))); + /** Semantic context shared in the semantic analysis process */ + private final SemanticContext context; + + /** Local cluster state for mapping query */ + private final LocalClusterState clusterState; + + /** Threshold to decide if continue the analysis */ + private final int threshold; + + public OpenSearchMappingLoader( + SemanticContext context, LocalClusterState clusterState, int threshold) { + this.context = context; + this.clusterState = clusterState; + this.threshold = threshold; + } + + /* + * Suppose index 'accounts' includes 'name', 'age' and nested field 'projects' + * which includes 'name' and 'active'. + * + * 1. Define itself: + * ----- new definitions ----- + * accounts -> INDEX + * + * 2. 
Define without alias no matter if alias given: + * 'accounts' -> INDEX + * ----- new definitions ----- + * 'name' -> TEXT + * 'age' -> INTEGER + * 'projects' -> NESTED + * 'projects.name' -> KEYWORD + * 'projects.active' -> BOOLEAN + */ + @Override + public Type visitIndexName(String indexName) { + if (isNotNested(indexName)) { + defineIndexType(indexName); + loadAllFieldsWithType(indexName); } + return defaultValue(); + } - /* - * 3.3 Define with alias if given: ex."SELECT * FROM accounts a, a.project p" - * 'accounts' -> INDEX - * 'name' -> TEXT - * 'age' -> INTEGER - * 'projects' -> NESTED - * 'projects.name' -> KEYWORD - * 'projects.active' -> BOOLEAN - * 'a.name' -> TEXT - * 'a.age' -> INTEGER - * 'a.projects' -> NESTED - * 'a.projects.name' -> KEYWORD - * 'a.projects.active' -> BOOLEAN - * ----- new definitions ----- - * ['p' -> NESTED] -- this is done in semantic analyzer - * 'p.name' -> KEYWORD - * 'p.active' -> BOOLEAN - */ - private void defineNestedFieldNamesByReplacingWithAlias(String nestedFieldName, String alias) { - Map typeByFullName = environment().resolveByPrefix( - new Symbol(Namespace.FIELD_NAME, nestedFieldName)); - typeByFullName.forEach( - (fieldName, fieldType) -> defineFieldName(fieldName.replace(nestedFieldName, alias), fieldType) - ); + @Override + public void visitAs(String alias, Type type) { + if (!(type instanceof OpenSearchIndex)) { + return; } - /** - * Check if index name is NOT nested, for example. return true for index 'accounts' or '.opensearch_dashboards' - * but return false for nested field name 'a.projects'. - */ - private boolean isNotNested(String indexName) { - return indexName.indexOf('.', 1) == -1; // taking care of .opensearch_dashboards + OpenSearchIndex index = (OpenSearchIndex) type; + String indexName = type.getName(); + + if (index.type() == INDEX) { + String aliasName = alias.isEmpty() ? indexName : alias; + defineAllFieldNamesByAppendingAliasPrefix(indexName, aliasName); + } else if (index.type() == NESTED_FIELD) { + if (!alias.isEmpty()) { + defineNestedFieldNamesByReplacingWithAlias(indexName, alias); + } + } // else Do nothing for index pattern + } + + private void defineIndexType(String indexName) { + environment() + .define(new Symbol(Namespace.FIELD_NAME, indexName), new OpenSearchIndex(indexName, INDEX)); + } + + private void loadAllFieldsWithType(String indexName) { + Set mappings = getFieldMappings(indexName); + mappings.forEach(mapping -> mapping.flat(this::defineFieldName)); + } + + /* + * 3.1 Define with alias if given: ex."SELECT * FROM accounts a". 
+ * 'accounts' -> INDEX + * 'name' -> TEXT + * 'age' -> INTEGER + * 'projects' -> NESTED + * 'projects.name' -> KEYWORD + * 'projects.active' -> BOOLEAN + * ----- new definitions ----- + * ['a' -> INDEX] -- this is done in semantic analyzer + * 'a.name' -> TEXT + * 'a.age' -> INTEGER + * 'a.projects' -> NESTED + * 'a.projects.name' -> KEYWORD + * 'a.projects.active' -> BOOLEAN + * + * 3.2 Otherwise define by index full name: ex."SELECT * FROM account" + * 'accounts' -> INDEX + * 'name' -> TEXT + * 'age' -> INTEGER + * 'projects' -> NESTED + * 'projects.name' -> KEYWORD + * 'projects.active' -> BOOLEAN + * ----- new definitions ----- + * 'accounts.name' -> TEXT + * 'accounts.age' -> INTEGER + * 'accounts.projects' -> NESTED + * 'accounts.projects.name' -> KEYWORD + * 'accounts.projects.active' -> BOOLEAN + */ + private void defineAllFieldNamesByAppendingAliasPrefix(String indexName, String alias) { + Set mappings = getFieldMappings(indexName); + mappings.stream() + .forEach( + mapping -> + mapping.flat((fieldName, type) -> defineFieldName(alias + "." + fieldName, type))); + } + + /* + * 3.3 Define with alias if given: ex."SELECT * FROM accounts a, a.project p" + * 'accounts' -> INDEX + * 'name' -> TEXT + * 'age' -> INTEGER + * 'projects' -> NESTED + * 'projects.name' -> KEYWORD + * 'projects.active' -> BOOLEAN + * 'a.name' -> TEXT + * 'a.age' -> INTEGER + * 'a.projects' -> NESTED + * 'a.projects.name' -> KEYWORD + * 'a.projects.active' -> BOOLEAN + * ----- new definitions ----- + * ['p' -> NESTED] -- this is done in semantic analyzer + * 'p.name' -> KEYWORD + * 'p.active' -> BOOLEAN + */ + private void defineNestedFieldNamesByReplacingWithAlias(String nestedFieldName, String alias) { + Map typeByFullName = + environment().resolveByPrefix(new Symbol(Namespace.FIELD_NAME, nestedFieldName)); + typeByFullName.forEach( + (fieldName, fieldType) -> + defineFieldName(fieldName.replace(nestedFieldName, alias), fieldType)); + } + + /** + * Check if index name is NOT nested, for example. return true for index 'accounts' or + * '.opensearch_dashboards' but return false for nested field name 'a.projects'. 
+ */ + private boolean isNotNested(String indexName) { + return indexName.indexOf('.', 1) == -1; // taking care of .opensearch_dashboards + } + + private Set getFieldMappings(String indexName) { + IndexMappings indexMappings = clusterState.getFieldMappings(new String[] {indexName}); + Set fieldMappingsSet = new HashSet<>(indexMappings.allMappings()); + + for (FieldMappings fieldMappings : fieldMappingsSet) { + int size = fieldMappings.data().size(); + if (size > threshold) { + throw new EarlyExitAnalysisException( + StringUtils.format( + "Index [%s] has [%d] fields more than threshold [%d]", indexName, size, threshold)); + } } - - private Set getFieldMappings(String indexName) { - IndexMappings indexMappings = clusterState.getFieldMappings(new String[]{indexName}); - Set fieldMappingsSet = new HashSet<>(indexMappings.allMappings()); - - for (FieldMappings fieldMappings : fieldMappingsSet) { - int size = fieldMappings.data().size(); - if (size > threshold) { - throw new EarlyExitAnalysisException(StringUtils.format( - "Index [%s] has [%d] fields more than threshold [%d]", indexName, size, threshold)); - } - } - return fieldMappingsSet; + return fieldMappingsSet; + } + + private void defineFieldName(String fieldName, String type) { + if ("NESTED".equalsIgnoreCase(type)) { + defineFieldName(fieldName, new OpenSearchIndex(fieldName, NESTED_FIELD)); + } else { + defineFieldName(fieldName, OpenSearchDataType.typeOf(type)); } + } - private void defineFieldName(String fieldName, String type) { - if ("NESTED".equalsIgnoreCase(type)) { - defineFieldName(fieldName, new OpenSearchIndex(fieldName, NESTED_FIELD)); - } else { - defineFieldName(fieldName, OpenSearchDataType.typeOf(type)); - } - } + private void defineFieldName(String fieldName, Type type) { + Symbol symbol = new Symbol(Namespace.FIELD_NAME, fieldName); + environment().define(symbol, type); + } - private void defineFieldName(String fieldName, Type type) { - Symbol symbol = new Symbol(Namespace.FIELD_NAME, fieldName); - environment().define(symbol, type); - } - - private Environment environment() { - return context.peek(); - } + private Environment environment() { + return context.peek(); + } } diff --git a/legacy/src/main/java/org/opensearch/sql/legacy/antlr/visitor/Reducible.java b/legacy/src/main/java/org/opensearch/sql/legacy/antlr/visitor/Reducible.java index 510a76659e..edb4136d49 100644 --- a/legacy/src/main/java/org/opensearch/sql/legacy/antlr/visitor/Reducible.java +++ b/legacy/src/main/java/org/opensearch/sql/legacy/antlr/visitor/Reducible.java @@ -3,21 +3,18 @@ * SPDX-License-Identifier: Apache-2.0 */ - package org.opensearch.sql.legacy.antlr.visitor; import java.util.List; -/** - * Abstraction for anything that can be reduced and used by {@link AntlrSqlParseTreeVisitor}. - */ +/** Abstraction for anything that can be reduced and used by {@link AntlrSqlParseTreeVisitor}. 
*/ public interface Reducible { - /** - * Reduce current and others to generate a new one - * @param others others - * @return reduction - */ - T reduce(List others); - + /** + * Reduce current and others to generate a new one + * + * @param others others + * @return reduction + */ + T reduce(List others); } diff --git a/legacy/src/main/java/org/opensearch/sql/legacy/cursor/NullCursor.java b/legacy/src/main/java/org/opensearch/sql/legacy/cursor/NullCursor.java index fb6beca96d..5b99f49515 100644 --- a/legacy/src/main/java/org/opensearch/sql/legacy/cursor/NullCursor.java +++ b/legacy/src/main/java/org/opensearch/sql/legacy/cursor/NullCursor.java @@ -3,27 +3,24 @@ * SPDX-License-Identifier: Apache-2.0 */ - package org.opensearch.sql.legacy.cursor; -/** - * A placeholder Cursor implementation to work with non-paginated queries. - */ +/** A placeholder Cursor implementation to work with non-paginated queries. */ public class NullCursor implements Cursor { - private final CursorType type = CursorType.NULL; + private final CursorType type = CursorType.NULL; - @Override - public String generateCursorId() { - return null; - } + @Override + public String generateCursorId() { + return null; + } - @Override - public CursorType getType() { - return type; - } + @Override + public CursorType getType() { + return type; + } - public NullCursor from(String cursorId) { - return NULL_CURSOR; - } + public NullCursor from(String cursorId) { + return NULL_CURSOR; + } } diff --git a/legacy/src/main/java/org/opensearch/sql/legacy/domain/Order.java b/legacy/src/main/java/org/opensearch/sql/legacy/domain/Order.java index 2a9be3ce91..f593d6c428 100644 --- a/legacy/src/main/java/org/opensearch/sql/legacy/domain/Order.java +++ b/legacy/src/main/java/org/opensearch/sql/legacy/domain/Order.java @@ -3,56 +3,53 @@ * SPDX-License-Identifier: Apache-2.0 */ - package org.opensearch.sql.legacy.domain; /** - * - * * @author ansj */ public class Order { - private String nestedPath; - private String name; - private String type; - private Field sortField; - - public boolean isScript() { - return sortField != null && sortField.isScriptField(); - } - - public Order(String nestedPath, String name, String type, Field sortField) { - this.nestedPath = nestedPath; - this.name = name; - this.type = type; - this.sortField = sortField; - } - - public String getNestedPath() { - return nestedPath; - } - - public void setNestedPath(String nestedPath) { - this.nestedPath = nestedPath; - } - - public String getName() { - return name; - } - - public void setName(String name) { - this.name = name; - } - - public String getType() { - return type; - } - - public void setType(String type) { - this.type = type; - } - - public Field getSortField() { - return sortField; - } + private String nestedPath; + private String name; + private String type; + private Field sortField; + + public boolean isScript() { + return sortField != null && sortField.isScriptField(); + } + + public Order(String nestedPath, String name, String type, Field sortField) { + this.nestedPath = nestedPath; + this.name = name; + this.type = type; + this.sortField = sortField; + } + + public String getNestedPath() { + return nestedPath; + } + + public void setNestedPath(String nestedPath) { + this.nestedPath = nestedPath; + } + + public String getName() { + return name; + } + + public void setName(String name) { + this.name = name; + } + + public String getType() { + return type; + } + + public void setType(String type) { + this.type = type; + } + + public Field getSortField() { + 
return sortField; + } } diff --git a/legacy/src/main/java/org/opensearch/sql/legacy/domain/Paramer.java b/legacy/src/main/java/org/opensearch/sql/legacy/domain/Paramer.java index 6cdf0148a8..38ca556199 100644 --- a/legacy/src/main/java/org/opensearch/sql/legacy/domain/Paramer.java +++ b/legacy/src/main/java/org/opensearch/sql/legacy/domain/Paramer.java @@ -3,7 +3,6 @@ * SPDX-License-Identifier: Apache-2.0 */ - package org.opensearch.sql.legacy.domain; import com.alibaba.druid.sql.ast.SQLExpr; @@ -25,163 +24,164 @@ import org.opensearch.sql.legacy.exception.SqlParseException; import org.opensearch.sql.legacy.utils.Util; - public class Paramer { - public String analysis; - public Float boost; - public String value; - public Integer slop; - - public Map fieldsBoosts = new HashMap<>(); - public String type; - public Float tieBreaker; - public Operator operator; - - public String default_field; - - public static Paramer parseParamer(SQLMethodInvokeExpr method) throws SqlParseException { - Paramer instance = new Paramer(); - List parameters = method.getParameters(); - for (SQLExpr expr : parameters) { - if (expr instanceof SQLCharExpr) { - if (instance.value == null) { - instance.value = ((SQLCharExpr) expr).getText(); - } else { - instance.analysis = ((SQLCharExpr) expr).getText(); - } - } else if (expr instanceof SQLNumericLiteralExpr) { - instance.boost = ((SQLNumericLiteralExpr) expr).getNumber().floatValue(); - } else if (expr instanceof SQLBinaryOpExpr) { - SQLBinaryOpExpr sqlExpr = (SQLBinaryOpExpr) expr; - switch (Util.expr2Object(sqlExpr.getLeft()).toString()) { - case "query": - instance.value = Util.expr2Object(sqlExpr.getRight()).toString(); - break; - case "analyzer": - instance.analysis = Util.expr2Object(sqlExpr.getRight()).toString(); - break; - case "boost": - instance.boost = Float.parseFloat(Util.expr2Object(sqlExpr.getRight()).toString()); - break; - case "slop": - instance.slop = Integer.parseInt(Util.expr2Object(sqlExpr.getRight()).toString()); - break; - - case "fields": - int index; - for (String f : Strings.splitStringByCommaToArray( - Util.expr2Object(sqlExpr.getRight()).toString())) { - index = f.lastIndexOf('^'); - if (-1 < index) { - instance.fieldsBoosts.put(f.substring(0, index), - Float.parseFloat(f.substring(index + 1))); - } else { - instance.fieldsBoosts.put(f, 1.0F); - } - } - break; - case "type": - instance.type = Util.expr2Object(sqlExpr.getRight()).toString(); - break; - case "tie_breaker": - instance.tieBreaker = Float.parseFloat(Util.expr2Object(sqlExpr.getRight()).toString()); - break; - case "operator": - instance.operator = Operator.fromString(Util.expr2Object(sqlExpr.getRight()).toString()); - break; - - case "default_field": - instance.default_field = Util.expr2Object(sqlExpr.getRight()).toString(); - break; - - default: - break; - } + public String analysis; + public Float boost; + public String value; + public Integer slop; + + public Map fieldsBoosts = new HashMap<>(); + public String type; + public Float tieBreaker; + public Operator operator; + + public String default_field; + + public static Paramer parseParamer(SQLMethodInvokeExpr method) throws SqlParseException { + Paramer instance = new Paramer(); + List parameters = method.getParameters(); + for (SQLExpr expr : parameters) { + if (expr instanceof SQLCharExpr) { + if (instance.value == null) { + instance.value = ((SQLCharExpr) expr).getText(); + } else { + instance.analysis = ((SQLCharExpr) expr).getText(); + } + } else if (expr instanceof SQLNumericLiteralExpr) { + instance.boost = 
((SQLNumericLiteralExpr) expr).getNumber().floatValue(); + } else if (expr instanceof SQLBinaryOpExpr) { + SQLBinaryOpExpr sqlExpr = (SQLBinaryOpExpr) expr; + switch (Util.expr2Object(sqlExpr.getLeft()).toString()) { + case "query": + instance.value = Util.expr2Object(sqlExpr.getRight()).toString(); + break; + case "analyzer": + instance.analysis = Util.expr2Object(sqlExpr.getRight()).toString(); + break; + case "boost": + instance.boost = Float.parseFloat(Util.expr2Object(sqlExpr.getRight()).toString()); + break; + case "slop": + instance.slop = Integer.parseInt(Util.expr2Object(sqlExpr.getRight()).toString()); + break; + + case "fields": + int index; + for (String f : + Strings.splitStringByCommaToArray( + Util.expr2Object(sqlExpr.getRight()).toString())) { + index = f.lastIndexOf('^'); + if (-1 < index) { + instance.fieldsBoosts.put( + f.substring(0, index), Float.parseFloat(f.substring(index + 1))); + } else { + instance.fieldsBoosts.put(f, 1.0F); + } } - } - - return instance; + break; + case "type": + instance.type = Util.expr2Object(sqlExpr.getRight()).toString(); + break; + case "tie_breaker": + instance.tieBreaker = Float.parseFloat(Util.expr2Object(sqlExpr.getRight()).toString()); + break; + case "operator": + instance.operator = + Operator.fromString(Util.expr2Object(sqlExpr.getRight()).toString()); + break; + + case "default_field": + instance.default_field = Util.expr2Object(sqlExpr.getRight()).toString(); + break; + + default: + break; + } + } } - public static ToXContent fullParamer(MatchPhraseQueryBuilder query, Paramer paramer) { - if (paramer.analysis != null) { - query.analyzer(paramer.analysis); - } + return instance; + } - if (paramer.boost != null) { - query.boost(paramer.boost); - } + public static ToXContent fullParamer(MatchPhraseQueryBuilder query, Paramer paramer) { + if (paramer.analysis != null) { + query.analyzer(paramer.analysis); + } - if (paramer.slop != null) { - query.slop(paramer.slop); - } + if (paramer.boost != null) { + query.boost(paramer.boost); + } - return query; + if (paramer.slop != null) { + query.slop(paramer.slop); } - public static ToXContent fullParamer(MatchQueryBuilder query, Paramer paramer) { - if (paramer.analysis != null) { - query.analyzer(paramer.analysis); - } + return query; + } - if (paramer.boost != null) { - query.boost(paramer.boost); - } - return query; + public static ToXContent fullParamer(MatchQueryBuilder query, Paramer paramer) { + if (paramer.analysis != null) { + query.analyzer(paramer.analysis); } - public static ToXContent fullParamer(WildcardQueryBuilder query, Paramer paramer) { - if (paramer.boost != null) { - query.boost(paramer.boost); - } - return query; + if (paramer.boost != null) { + query.boost(paramer.boost); } + return query; + } - public static ToXContent fullParamer(QueryStringQueryBuilder query, Paramer paramer) { - if (paramer.analysis != null) { - query.analyzer(paramer.analysis); - } + public static ToXContent fullParamer(WildcardQueryBuilder query, Paramer paramer) { + if (paramer.boost != null) { + query.boost(paramer.boost); + } + return query; + } - if (paramer.boost != null) { - query.boost(paramer.boost); - } + public static ToXContent fullParamer(QueryStringQueryBuilder query, Paramer paramer) { + if (paramer.analysis != null) { + query.analyzer(paramer.analysis); + } - if (paramer.slop != null) { - query.phraseSlop(paramer.slop); - } + if (paramer.boost != null) { + query.boost(paramer.boost); + } - if (paramer.default_field != null) { - query.defaultField(paramer.default_field); - } + if 
(paramer.slop != null) { + query.phraseSlop(paramer.slop); + } - return query; + if (paramer.default_field != null) { + query.defaultField(paramer.default_field); } - public static ToXContent fullParamer(MultiMatchQueryBuilder query, Paramer paramer) { - if (paramer.analysis != null) { - query.analyzer(paramer.analysis); - } + return query; + } - if (paramer.boost != null) { - query.boost(paramer.boost); - } + public static ToXContent fullParamer(MultiMatchQueryBuilder query, Paramer paramer) { + if (paramer.analysis != null) { + query.analyzer(paramer.analysis); + } - if (paramer.slop != null) { - query.slop(paramer.slop); - } + if (paramer.boost != null) { + query.boost(paramer.boost); + } - if (paramer.type != null) { - query.type(paramer.type); - } + if (paramer.slop != null) { + query.slop(paramer.slop); + } - if (paramer.tieBreaker != null) { - query.tieBreaker(paramer.tieBreaker); - } + if (paramer.type != null) { + query.type(paramer.type); + } - if (paramer.operator != null) { - query.operator(paramer.operator); - } + if (paramer.tieBreaker != null) { + query.tieBreaker(paramer.tieBreaker); + } - return query; + if (paramer.operator != null) { + query.operator(paramer.operator); } + + return query; + } } diff --git a/legacy/src/main/java/org/opensearch/sql/legacy/domain/Query.java b/legacy/src/main/java/org/opensearch/sql/legacy/domain/Query.java index b0538591b8..6f891e7fc5 100644 --- a/legacy/src/main/java/org/opensearch/sql/legacy/domain/Query.java +++ b/legacy/src/main/java/org/opensearch/sql/legacy/domain/Query.java @@ -3,45 +3,39 @@ * SPDX-License-Identifier: Apache-2.0 */ - package org.opensearch.sql.legacy.domain; import java.util.ArrayList; import java.util.List; -/** - * Represents abstract query. every query - * has indexes, types, and where clause. - */ +/** Represents abstract query. every query has indexes, types, and where clause. */ public abstract class Query implements QueryStatement { - private Where where = null; - private List from = new ArrayList<>(); - - - public Where getWhere() { - return this.where; - } - - public void setWhere(Where where) { - this.where = where; - } - - public List getFrom() { - return from; - } - - - /** - * Get the indexes the query refer to. - * - * @return list of strings, the indexes names - */ - public String[] getIndexArr() { - String[] indexArr = new String[this.from.size()]; - for (int i = 0; i < indexArr.length; i++) { - indexArr[i] = this.from.get(i).getIndex(); - } - return indexArr; + private Where where = null; + private List from = new ArrayList<>(); + + public Where getWhere() { + return this.where; + } + + public void setWhere(Where where) { + this.where = where; + } + + public List getFrom() { + return from; + } + + /** + * Get the indexes the query refer to. 
+ * + * @return list of strings, the indexes names + */ + public String[] getIndexArr() { + String[] indexArr = new String[this.from.size()]; + for (int i = 0; i < indexArr.length; i++) { + indexArr[i] = this.from.get(i).getIndex(); } + return indexArr; + } } diff --git a/legacy/src/main/java/org/opensearch/sql/legacy/domain/QueryActionRequest.java b/legacy/src/main/java/org/opensearch/sql/legacy/domain/QueryActionRequest.java index f13e053d92..f536e3ad6f 100644 --- a/legacy/src/main/java/org/opensearch/sql/legacy/domain/QueryActionRequest.java +++ b/legacy/src/main/java/org/opensearch/sql/legacy/domain/QueryActionRequest.java @@ -3,20 +3,17 @@ * SPDX-License-Identifier: Apache-2.0 */ - package org.opensearch.sql.legacy.domain; import lombok.Getter; import lombok.RequiredArgsConstructor; import org.opensearch.sql.legacy.executor.Format; -/** - * The definition of QueryActionRequest. - */ +/** The definition of QueryActionRequest. */ @Getter @RequiredArgsConstructor public class QueryActionRequest { - private final String sql; - private final ColumnTypeProvider typeProvider; - private final Format format; + private final String sql; + private final ColumnTypeProvider typeProvider; + private final Format format; } diff --git a/legacy/src/main/java/org/opensearch/sql/legacy/domain/QueryStatement.java b/legacy/src/main/java/org/opensearch/sql/legacy/domain/QueryStatement.java index 26c0b07517..71fe64906a 100644 --- a/legacy/src/main/java/org/opensearch/sql/legacy/domain/QueryStatement.java +++ b/legacy/src/main/java/org/opensearch/sql/legacy/domain/QueryStatement.java @@ -3,11 +3,7 @@ * SPDX-License-Identifier: Apache-2.0 */ - package org.opensearch.sql.legacy.domain; -/** - * Identifier interface used to encompass Query and IndexStatements - */ -public interface QueryStatement { -} +/** Identifier interface used to encompass Query and IndexStatements */ +public interface QueryStatement {} diff --git a/legacy/src/main/java/org/opensearch/sql/legacy/domain/ScriptMethodField.java b/legacy/src/main/java/org/opensearch/sql/legacy/domain/ScriptMethodField.java index bdc42b4ff3..bb4d17d897 100644 --- a/legacy/src/main/java/org/opensearch/sql/legacy/domain/ScriptMethodField.java +++ b/legacy/src/main/java/org/opensearch/sql/legacy/domain/ScriptMethodField.java @@ -3,29 +3,27 @@ * SPDX-License-Identifier: Apache-2.0 */ - package org.opensearch.sql.legacy.domain; import com.alibaba.druid.sql.ast.expr.SQLAggregateOption; import java.util.List; -/** - * Stores information about function name for script fields - */ +/** Stores information about function name for script fields */ public class ScriptMethodField extends MethodField { - private final String functionName; + private final String functionName; - public ScriptMethodField(String functionName, List params, SQLAggregateOption option, String alias) { - super("script", params, option, alias); - this.functionName = functionName; - } + public ScriptMethodField( + String functionName, List params, SQLAggregateOption option, String alias) { + super("script", params, option, alias); + this.functionName = functionName; + } - public String getFunctionName() { - return functionName; - } + public String getFunctionName() { + return functionName; + } - @Override - public boolean isScriptField() { - return true; - } + @Override + public boolean isScriptField() { + return true; + } } diff --git a/legacy/src/main/java/org/opensearch/sql/legacy/domain/SearchResult.java b/legacy/src/main/java/org/opensearch/sql/legacy/domain/SearchResult.java index 
5b7b73a910..e951c84961 100644 --- a/legacy/src/main/java/org/opensearch/sql/legacy/domain/SearchResult.java +++ b/legacy/src/main/java/org/opensearch/sql/legacy/domain/SearchResult.java @@ -3,7 +3,6 @@ * SPDX-License-Identifier: Apache-2.0 */ - package org.opensearch.sql.legacy.domain; import java.util.ArrayList; @@ -29,128 +28,120 @@ import org.opensearch.sql.legacy.exception.SqlParseException; public class SearchResult { - /** - * - */ - private List> results; - - private long total; - - double maxScore = 0; - - public SearchResult(SearchResponse resp) { - SearchHits hits = resp.getHits(); - this.total = Optional.ofNullable(hits.getTotalHits()).map(totalHits -> totalHits.value).orElse(0L); - results = new ArrayList<>(hits.getHits().length); - for (SearchHit searchHit : hits.getHits()) { - if (searchHit.getSourceAsMap() != null) { - results.add(searchHit.getSourceAsMap()); - } else if (searchHit.getFields() != null) { - Map fields = searchHit.getFields(); - results.add(toFieldsMap(fields)); - } - - } + /** */ + private List> results; + + private long total; + + double maxScore = 0; + + public SearchResult(SearchResponse resp) { + SearchHits hits = resp.getHits(); + this.total = + Optional.ofNullable(hits.getTotalHits()).map(totalHits -> totalHits.value).orElse(0L); + results = new ArrayList<>(hits.getHits().length); + for (SearchHit searchHit : hits.getHits()) { + if (searchHit.getSourceAsMap() != null) { + results.add(searchHit.getSourceAsMap()); + } else if (searchHit.getFields() != null) { + Map fields = searchHit.getFields(); + results.add(toFieldsMap(fields)); + } } + } - public SearchResult(SearchResponse resp, Select select) throws SqlParseException { - Aggregations aggs = resp.getAggregations(); - if (aggs.get("filter") != null) { - InternalFilter inf = aggs.get("filter"); - aggs = inf.getAggregations(); - } - if (aggs.get("group by") != null) { - InternalTerms terms = aggs.get("group by"); - Collection buckets = terms.getBuckets(); - this.total = buckets.size(); - results = new ArrayList<>(buckets.size()); - for (Bucket bucket : buckets) { - Map aggsMap = toAggsMap(bucket.getAggregations().getAsMap()); - aggsMap.put("docCount", bucket.getDocCount()); - results.add(aggsMap); - } - } else { - results = new ArrayList<>(1); - this.total = 1; - Map map = new HashMap<>(); - for (Aggregation aggregation : aggs) { - map.put(aggregation.getName(), covenValue(aggregation)); - } - results.add(map); - } - + public SearchResult(SearchResponse resp, Select select) throws SqlParseException { + Aggregations aggs = resp.getAggregations(); + if (aggs.get("filter") != null) { + InternalFilter inf = aggs.get("filter"); + aggs = inf.getAggregations(); } - - /** - * - * - * @param fields - * @return - */ - private Map toFieldsMap(Map fields) { - Map result = new HashMap<>(); - for (Entry entry : fields.entrySet()) { - if (entry.getValue().getValues().size() > 1) { - result.put(entry.getKey(), entry.getValue().getValues()); - } else { - result.put(entry.getKey(), entry.getValue().getValue()); - } - - } - return result; + if (aggs.get("group by") != null) { + InternalTerms terms = aggs.get("group by"); + Collection buckets = terms.getBuckets(); + this.total = buckets.size(); + results = new ArrayList<>(buckets.size()); + for (Bucket bucket : buckets) { + Map aggsMap = toAggsMap(bucket.getAggregations().getAsMap()); + aggsMap.put("docCount", bucket.getDocCount()); + results.add(aggsMap); + } + } else { + results = new ArrayList<>(1); + this.total = 1; + Map map = new HashMap<>(); + for (Aggregation 
aggregation : aggs) { + map.put(aggregation.getName(), covenValue(aggregation)); + } + results.add(map); } - - /** - * - * - * @param fields - * @return - * @throws SqlParseException - */ - private Map toAggsMap(Map fields) throws SqlParseException { - Map result = new HashMap<>(); - for (Entry entry : fields.entrySet()) { - result.put(entry.getKey(), covenValue(entry.getValue())); - } - return result; + } + + /** + * @param fields + * @return + */ + private Map toFieldsMap(Map fields) { + Map result = new HashMap<>(); + for (Entry entry : fields.entrySet()) { + if (entry.getValue().getValues().size() > 1) { + result.put(entry.getKey(), entry.getValue().getValues()); + } else { + result.put(entry.getKey(), entry.getValue().getValue()); + } } - - private Object covenValue(Aggregation value) throws SqlParseException { - if (value instanceof InternalNumericMetricsAggregation.SingleValue) { - return ((InternalNumericMetricsAggregation.SingleValue) value).value(); - } else if (value instanceof InternalValueCount) { - return ((InternalValueCount) value).getValue(); - } else if (value instanceof InternalTopHits) { - return (value); - } else if (value instanceof LongTerms) { - return value; - } else { - throw new SqlParseException("Unknown aggregation value type: " + value.getClass().getSimpleName()); - } + return result; + } + + /** + * @param fields + * @return + * @throws SqlParseException + */ + private Map toAggsMap(Map fields) throws SqlParseException { + Map result = new HashMap<>(); + for (Entry entry : fields.entrySet()) { + result.put(entry.getKey(), covenValue(entry.getValue())); } - - public List> getResults() { - return results; + return result; + } + + private Object covenValue(Aggregation value) throws SqlParseException { + if (value instanceof InternalNumericMetricsAggregation.SingleValue) { + return ((InternalNumericMetricsAggregation.SingleValue) value).value(); + } else if (value instanceof InternalValueCount) { + return ((InternalValueCount) value).getValue(); + } else if (value instanceof InternalTopHits) { + return (value); + } else if (value instanceof LongTerms) { + return value; + } else { + throw new SqlParseException( + "Unknown aggregation value type: " + value.getClass().getSimpleName()); } + } - public void setResults(List> results) { - this.results = results; - } + public List> getResults() { + return results; + } - public long getTotal() { - return total; - } + public void setResults(List> results) { + this.results = results; + } - public void setTotal(long total) { - this.total = total; - } + public long getTotal() { + return total; + } - public double getMaxScore() { - return maxScore; - } + public void setTotal(long total) { + this.total = total; + } - public void setMaxScore(double maxScore) { - this.maxScore = maxScore; - } + public double getMaxScore() { + return maxScore; + } + public void setMaxScore(double maxScore) { + this.maxScore = maxScore; + } } diff --git a/legacy/src/main/java/org/opensearch/sql/legacy/domain/Select.java b/legacy/src/main/java/org/opensearch/sql/legacy/domain/Select.java index cd600d856e..2faa8cc6e5 100644 --- a/legacy/src/main/java/org/opensearch/sql/legacy/domain/Select.java +++ b/legacy/src/main/java/org/opensearch/sql/legacy/domain/Select.java @@ -3,7 +3,6 @@ * SPDX-License-Identifier: Apache-2.0 */ - package org.opensearch.sql.legacy.domain; import static com.alibaba.druid.sql.ast.statement.SQLJoinTableSource.JoinType; @@ -16,7 +15,6 @@ import org.opensearch.sql.legacy.domain.hints.Hint; import 
org.opensearch.sql.legacy.parser.SubQueryExpression; - /** * sql select * @@ -24,167 +22,169 @@ */ public class Select extends Query { - /** - * Using this functions will cause query to execute as aggregation. - */ - private static final Set AGGREGATE_FUNCTIONS = - ImmutableSet.of( - "SUM", "MAX", "MIN", "AVG", - "TOPHITS", "COUNT", "STATS", "EXTENDED_STATS", - "PERCENTILES", "SCRIPTED_METRIC" - ); - - private List hints = new ArrayList<>(); - private List fields = new ArrayList<>(); - private List> groupBys = new ArrayList<>(); - private Having having; - private List orderBys = new ArrayList<>(); - private int offset; - private Integer rowCount; - private boolean containsSubQueries; - private List subQueries; - private boolean selectAll = false; - private JoinType nestedJoinType = JoinType.COMMA; - - public boolean isQuery = false; - public boolean isAggregate = false; - - public static final int DEFAULT_LIMIT = 200; - - public Select() { - } - - public List getFields() { - return fields; - } - - public void setOffset(int offset) { - this.offset = offset; - } - - public void setRowCount(Integer rowCount) { - this.rowCount = rowCount; - } - - public void addGroupBy(Field field) { - List wrapper = new ArrayList<>(); - wrapper.add(field); - addGroupBy(wrapper); - } - - public void addGroupBy(List fields) { - isAggregate = true; - selectAll = false; - this.groupBys.add(fields); - } + /** Using this functions will cause query to execute as aggregation. */ + private static final Set AGGREGATE_FUNCTIONS = + ImmutableSet.of( + "SUM", + "MAX", + "MIN", + "AVG", + "TOPHITS", + "COUNT", + "STATS", + "EXTENDED_STATS", + "PERCENTILES", + "SCRIPTED_METRIC"); + + private List hints = new ArrayList<>(); + private List fields = new ArrayList<>(); + private List> groupBys = new ArrayList<>(); + private Having having; + private List orderBys = new ArrayList<>(); + private int offset; + private Integer rowCount; + private boolean containsSubQueries; + private List subQueries; + private boolean selectAll = false; + private JoinType nestedJoinType = JoinType.COMMA; + + public boolean isQuery = false; + public boolean isAggregate = false; + + public static final int DEFAULT_LIMIT = 200; + + public Select() {} + + public List getFields() { + return fields; + } + + public void setOffset(int offset) { + this.offset = offset; + } + + public void setRowCount(Integer rowCount) { + this.rowCount = rowCount; + } + + public void addGroupBy(Field field) { + List wrapper = new ArrayList<>(); + wrapper.add(field); + addGroupBy(wrapper); + } + + public void addGroupBy(List fields) { + isAggregate = true; + selectAll = false; + this.groupBys.add(fields); + } + + public List> getGroupBys() { + return groupBys; + } + + public Having getHaving() { + return having; + } + + public void setHaving(Having having) { + this.having = having; + } + + public List getOrderBys() { + return orderBys; + } + + public int getOffset() { + return offset; + } + + public Integer getRowCount() { + return rowCount; + } + + public void addOrderBy(String nestedPath, String name, String type, Field field) { + if ("_score".equals(name)) { + isQuery = true; + } + this.orderBys.add(new Order(nestedPath, name, type, field)); + } + + public void addField(Field field) { + if (field == null) { + return; + } + if (field == STAR && !isAggregate) { + // Ignore GROUP BY since columns present in result are decided by column list in GROUP BY + this.selectAll = true; + return; + } + + if (field instanceof MethodField + && 
AGGREGATE_FUNCTIONS.contains(field.getName().toUpperCase())) { + isAggregate = true; + } + + fields.add(field); + } + + public List getHints() { + return hints; + } + + public JoinType getNestedJoinType() { + return nestedJoinType; + } + + public void setNestedJoinType(JoinType nestedJoinType) { + this.nestedJoinType = nestedJoinType; + } + + public void fillSubQueries() { + subQueries = new ArrayList<>(); + Where where = this.getWhere(); + fillSubQueriesFromWhereRecursive(where); + } - public List> getGroupBys() { - return groupBys; + private void fillSubQueriesFromWhereRecursive(Where where) { + if (where == null) { + return; } - - public Having getHaving() { - return having; - } - - public void setHaving(Having having) { - this.having = having; - } - - public List getOrderBys() { - return orderBys; - } - - public int getOffset() { - return offset; - } - - public Integer getRowCount() { - return rowCount; - } - - public void addOrderBy(String nestedPath, String name, String type, Field field) { - if ("_score".equals(name)) { - isQuery = true; + if (where instanceof Condition) { + Condition condition = (Condition) where; + if (condition.getValue() instanceof SubQueryExpression) { + this.subQueries.add((SubQueryExpression) condition.getValue()); + this.containsSubQueries = true; + } + if (condition.getValue() instanceof Object[]) { + + for (Object o : (Object[]) condition.getValue()) { + if (o instanceof SubQueryExpression) { + this.subQueries.add((SubQueryExpression) o); + this.containsSubQueries = true; + } } - this.orderBys.add(new Order(nestedPath, name, type, field)); + } + } else { + for (Where innerWhere : where.getWheres()) { + fillSubQueriesFromWhereRecursive(innerWhere); + } } + } - public void addField(Field field) { - if (field == null) { - return; - } - if (field == STAR && !isAggregate) { - // Ignore GROUP BY since columns present in result are decided by column list in GROUP BY - this.selectAll = true; - return; - } - - if (field instanceof MethodField && AGGREGATE_FUNCTIONS.contains(field.getName().toUpperCase())) { - isAggregate = true; - } - - fields.add(field); - } - - public List getHints() { - return hints; - } - - - public JoinType getNestedJoinType() { - return nestedJoinType; - } - - public void setNestedJoinType(JoinType nestedJoinType) { - this.nestedJoinType = nestedJoinType; - } + public boolean containsSubQueries() { + return containsSubQueries; + } + public List getSubQueries() { + return subQueries; + } - public void fillSubQueries() { - subQueries = new ArrayList<>(); - Where where = this.getWhere(); - fillSubQueriesFromWhereRecursive(where); - } - - private void fillSubQueriesFromWhereRecursive(Where where) { - if (where == null) { - return; - } - if (where instanceof Condition) { - Condition condition = (Condition) where; - if (condition.getValue() instanceof SubQueryExpression) { - this.subQueries.add((SubQueryExpression) condition.getValue()); - this.containsSubQueries = true; - } - if (condition.getValue() instanceof Object[]) { - - for (Object o : (Object[]) condition.getValue()) { - if (o instanceof SubQueryExpression) { - this.subQueries.add((SubQueryExpression) o); - this.containsSubQueries = true; - } - } - } - } else { - for (Where innerWhere : where.getWheres()) { - fillSubQueriesFromWhereRecursive(innerWhere); - } - } - } - - public boolean containsSubQueries() { - return containsSubQueries; - } - - public List getSubQueries() { - return subQueries; - } + public boolean isOrderdSelect() { + return this.getOrderBys() != null && 
this.getOrderBys().size() > 0; + } - public boolean isOrderdSelect() { - return this.getOrderBys() != null && this.getOrderBys().size() > 0; - } - - public boolean isSelectAll() { - return selectAll; - } + public boolean isSelectAll() { + return selectAll; + } } diff --git a/legacy/src/main/java/org/opensearch/sql/legacy/domain/bucketpath/Path.java b/legacy/src/main/java/org/opensearch/sql/legacy/domain/bucketpath/Path.java index d5c897cf90..4827e0e61c 100644 --- a/legacy/src/main/java/org/opensearch/sql/legacy/domain/bucketpath/Path.java +++ b/legacy/src/main/java/org/opensearch/sql/legacy/domain/bucketpath/Path.java @@ -3,49 +3,49 @@ * SPDX-License-Identifier: Apache-2.0 */ - package org.opensearch.sql.legacy.domain.bucketpath; public class Path { - private final String path; - private final String separator; - private final PathType type; - - private Path(String path, String separator, PathType type) { - this.path = path; - this.separator = separator; - this.type = type; - } - - public String getPath() { - return path; - } - - public String getSeparator() { - return separator; - } - - public PathType getType() { - return type; - } - - public boolean isMetricPath() { - return type == PathType.METRIC; - } - - public boolean isAggPath() { - return type == PathType.AGG; - } - - public static Path getAggPath(String path) { - return new Path(path, ">", PathType.AGG); - } - - public static Path getMetricPath(String path) { - return new Path(path, ".", PathType.METRIC); - } - - public enum PathType { - AGG, METRIC - } + private final String path; + private final String separator; + private final PathType type; + + private Path(String path, String separator, PathType type) { + this.path = path; + this.separator = separator; + this.type = type; + } + + public String getPath() { + return path; + } + + public String getSeparator() { + return separator; + } + + public PathType getType() { + return type; + } + + public boolean isMetricPath() { + return type == PathType.METRIC; + } + + public boolean isAggPath() { + return type == PathType.AGG; + } + + public static Path getAggPath(String path) { + return new Path(path, ">", PathType.AGG); + } + + public static Path getMetricPath(String path) { + return new Path(path, ".", PathType.METRIC); + } + + public enum PathType { + AGG, + METRIC + } } diff --git a/legacy/src/main/java/org/opensearch/sql/legacy/esdomain/OpenSearchClient.java b/legacy/src/main/java/org/opensearch/sql/legacy/esdomain/OpenSearchClient.java index a823947466..fd02486fae 100644 --- a/legacy/src/main/java/org/opensearch/sql/legacy/esdomain/OpenSearchClient.java +++ b/legacy/src/main/java/org/opensearch/sql/legacy/esdomain/OpenSearchClient.java @@ -3,7 +3,6 @@ * SPDX-License-Identifier: Apache-2.0 */ - package org.opensearch.sql.legacy.esdomain; import java.util.ArrayList; @@ -19,47 +18,57 @@ public class OpenSearchClient { - private static final Logger LOG = LogManager.getLogger(); - private static final int[] retryIntervals = new int[]{4, 12, 20, 20}; - private final Client client; + private static final Logger LOG = LogManager.getLogger(); + private static final int[] retryIntervals = new int[] {4, 12, 20, 20}; + private final Client client; - public OpenSearchClient(Client client) { - this.client = client; - } + public OpenSearchClient(Client client) { + this.client = client; + } - public MultiSearchResponse.Item[] multiSearch(MultiSearchRequest multiSearchRequest) { - MultiSearchResponse.Item[] responses = new MultiSearchResponse.Item[multiSearchRequest.requests().size()]; - 
multiSearchRetry(responses, multiSearchRequest, - IntStream.range(0, multiSearchRequest.requests().size()).boxed().collect(Collectors.toList()), 0); + public MultiSearchResponse.Item[] multiSearch(MultiSearchRequest multiSearchRequest) { + MultiSearchResponse.Item[] responses = + new MultiSearchResponse.Item[multiSearchRequest.requests().size()]; + multiSearchRetry( + responses, + multiSearchRequest, + IntStream.range(0, multiSearchRequest.requests().size()) + .boxed() + .collect(Collectors.toList()), + 0); - return responses; - } + return responses; + } - private void multiSearchRetry(MultiSearchResponse.Item[] responses, MultiSearchRequest multiSearchRequest, - List indices, int retry) { - MultiSearchRequest multiSearchRequestRetry = new MultiSearchRequest(); - for (int i : indices) { - multiSearchRequestRetry.add(multiSearchRequest.requests().get(i)); - } - MultiSearchResponse.Item[] res = client.multiSearch(multiSearchRequestRetry).actionGet().getResponses(); - List indicesFailure = new ArrayList<>(); - //Could get EsRejectedExecutionException and OpenSearchException as getCause - for (int i = 0; i < res.length; i++) { - if (res[i].isFailure()) { - indicesFailure.add(indices.get(i)); - if (retry == 3) { - responses[indices.get(i)] = res[i]; - } - } else { - responses[indices.get(i)] = res[i]; - } - } - if (!indicesFailure.isEmpty()) { - LOG.info("OpenSearch multisearch has failures on retry {}", retry); - if (retry < 3) { - BackOffRetryStrategy.backOffSleep(retryIntervals[retry]); - multiSearchRetry(responses, multiSearchRequest, indicesFailure, retry + 1); - } + private void multiSearchRetry( + MultiSearchResponse.Item[] responses, + MultiSearchRequest multiSearchRequest, + List indices, + int retry) { + MultiSearchRequest multiSearchRequestRetry = new MultiSearchRequest(); + for (int i : indices) { + multiSearchRequestRetry.add(multiSearchRequest.requests().get(i)); + } + MultiSearchResponse.Item[] res = + client.multiSearch(multiSearchRequestRetry).actionGet().getResponses(); + List indicesFailure = new ArrayList<>(); + // Could get EsRejectedExecutionException and OpenSearchException as getCause + for (int i = 0; i < res.length; i++) { + if (res[i].isFailure()) { + indicesFailure.add(indices.get(i)); + if (retry == 3) { + responses[indices.get(i)] = res[i]; } + } else { + responses[indices.get(i)] = res[i]; + } + } + if (!indicesFailure.isEmpty()) { + LOG.info("OpenSearch multisearch has failures on retry {}", retry); + if (retry < 3) { + BackOffRetryStrategy.backOffSleep(retryIntervals[retry]); + multiSearchRetry(responses, multiSearchRequest, indicesFailure, retry + 1); + } } + } } diff --git a/legacy/src/main/java/org/opensearch/sql/legacy/executor/QueryActionElasticExecutor.java b/legacy/src/main/java/org/opensearch/sql/legacy/executor/QueryActionElasticExecutor.java index bcb25fd39a..2e45fb45b7 100644 --- a/legacy/src/main/java/org/opensearch/sql/legacy/executor/QueryActionElasticExecutor.java +++ b/legacy/src/main/java/org/opensearch/sql/legacy/executor/QueryActionElasticExecutor.java @@ -3,7 +3,6 @@ * SPDX-License-Identifier: Apache-2.0 */ - package org.opensearch.sql.legacy.executor; import java.io.IOException; @@ -31,81 +30,85 @@ import org.opensearch.sql.legacy.query.multi.MultiQueryAction; import org.opensearch.sql.legacy.query.multi.MultiQueryRequestBuilder; -/** - * Created by Eliran on 3/10/2015. - */ +/** Created by Eliran on 3/10/2015. 
*/ public class QueryActionElasticExecutor { - public static SearchHits executeSearchAction(DefaultQueryAction searchQueryAction) throws SqlParseException { - SqlOpenSearchRequestBuilder builder = searchQueryAction.explain(); - return ((SearchResponse) builder.get()).getHits(); - } + public static SearchHits executeSearchAction(DefaultQueryAction searchQueryAction) + throws SqlParseException { + SqlOpenSearchRequestBuilder builder = searchQueryAction.explain(); + return ((SearchResponse) builder.get()).getHits(); + } - public static SearchHits executeJoinSearchAction(Client client, OpenSearchJoinQueryAction joinQueryAction) - throws IOException, SqlParseException { - SqlElasticRequestBuilder joinRequestBuilder = joinQueryAction.explain(); - ElasticJoinExecutor executor = ElasticJoinExecutor.createJoinExecutor(client, joinRequestBuilder); - executor.run(); - return executor.getHits(); - } + public static SearchHits executeJoinSearchAction( + Client client, OpenSearchJoinQueryAction joinQueryAction) + throws IOException, SqlParseException { + SqlElasticRequestBuilder joinRequestBuilder = joinQueryAction.explain(); + ElasticJoinExecutor executor = + ElasticJoinExecutor.createJoinExecutor(client, joinRequestBuilder); + executor.run(); + return executor.getHits(); + } - public static Aggregations executeAggregationAction(AggregationQueryAction aggregationQueryAction) - throws SqlParseException { - SqlOpenSearchRequestBuilder select = aggregationQueryAction.explain(); - return ((SearchResponse) select.get()).getAggregations(); - } + public static Aggregations executeAggregationAction(AggregationQueryAction aggregationQueryAction) + throws SqlParseException { + SqlOpenSearchRequestBuilder select = aggregationQueryAction.explain(); + return ((SearchResponse) select.get()).getAggregations(); + } - public static List executeQueryPlanQueryAction(QueryPlanQueryAction queryPlanQueryAction) { - QueryPlanRequestBuilder select = (QueryPlanRequestBuilder) queryPlanQueryAction.explain(); - return select.execute(); - } + public static List executeQueryPlanQueryAction( + QueryPlanQueryAction queryPlanQueryAction) { + QueryPlanRequestBuilder select = (QueryPlanRequestBuilder) queryPlanQueryAction.explain(); + return select.execute(); + } - public static ActionResponse executeShowQueryAction(ShowQueryAction showQueryAction) { - return showQueryAction.explain().get(); - } + public static ActionResponse executeShowQueryAction(ShowQueryAction showQueryAction) { + return showQueryAction.explain().get(); + } - public static ActionResponse executeDescribeQueryAction(DescribeQueryAction describeQueryAction) { - return describeQueryAction.explain().get(); - } + public static ActionResponse executeDescribeQueryAction(DescribeQueryAction describeQueryAction) { + return describeQueryAction.explain().get(); + } - public static ActionResponse executeDeleteAction(DeleteQueryAction deleteQueryAction) throws SqlParseException { - return deleteQueryAction.explain().get(); - } + public static ActionResponse executeDeleteAction(DeleteQueryAction deleteQueryAction) + throws SqlParseException { + return deleteQueryAction.explain().get(); + } - public static SearchHits executeMultiQueryAction(Client client, MultiQueryAction queryAction) - throws SqlParseException, IOException { - SqlElasticRequestBuilder multiRequestBuilder = queryAction.explain(); - ElasticHitsExecutor executor = MultiRequestExecutorFactory.createExecutor(client, - (MultiQueryRequestBuilder) multiRequestBuilder); - executor.run(); - return executor.getHits(); - } + 
public static SearchHits executeMultiQueryAction(Client client, MultiQueryAction queryAction) + throws SqlParseException, IOException { + SqlElasticRequestBuilder multiRequestBuilder = queryAction.explain(); + ElasticHitsExecutor executor = + MultiRequestExecutorFactory.createExecutor( + client, (MultiQueryRequestBuilder) multiRequestBuilder); + executor.run(); + return executor.getHits(); + } - public static Object executeAnyAction(Client client, QueryAction queryAction) - throws SqlParseException, IOException { - if (queryAction instanceof DefaultQueryAction) { - return executeSearchAction((DefaultQueryAction) queryAction); - } - if (queryAction instanceof AggregationQueryAction) { - return executeAggregationAction((AggregationQueryAction) queryAction); - } - if (queryAction instanceof QueryPlanQueryAction) { - return executeQueryPlanQueryAction((QueryPlanQueryAction) queryAction); - } - if (queryAction instanceof ShowQueryAction) { - return executeShowQueryAction((ShowQueryAction) queryAction); - } - if (queryAction instanceof DescribeQueryAction) { - return executeDescribeQueryAction((DescribeQueryAction) queryAction); - } - if (queryAction instanceof OpenSearchJoinQueryAction) { - return executeJoinSearchAction(client, (OpenSearchJoinQueryAction) queryAction); - } - if (queryAction instanceof MultiQueryAction) { - return executeMultiQueryAction(client, (MultiQueryAction) queryAction); - } - if (queryAction instanceof DeleteQueryAction) { - return executeDeleteAction((DeleteQueryAction) queryAction); - } - return null; + public static Object executeAnyAction(Client client, QueryAction queryAction) + throws SqlParseException, IOException { + if (queryAction instanceof DefaultQueryAction) { + return executeSearchAction((DefaultQueryAction) queryAction); + } + if (queryAction instanceof AggregationQueryAction) { + return executeAggregationAction((AggregationQueryAction) queryAction); + } + if (queryAction instanceof QueryPlanQueryAction) { + return executeQueryPlanQueryAction((QueryPlanQueryAction) queryAction); + } + if (queryAction instanceof ShowQueryAction) { + return executeShowQueryAction((ShowQueryAction) queryAction); + } + if (queryAction instanceof DescribeQueryAction) { + return executeDescribeQueryAction((DescribeQueryAction) queryAction); + } + if (queryAction instanceof OpenSearchJoinQueryAction) { + return executeJoinSearchAction(client, (OpenSearchJoinQueryAction) queryAction); + } + if (queryAction instanceof MultiQueryAction) { + return executeMultiQueryAction(client, (MultiQueryAction) queryAction); + } + if (queryAction instanceof DeleteQueryAction) { + return executeDeleteAction((DeleteQueryAction) queryAction); } + return null; + } } diff --git a/legacy/src/main/java/org/opensearch/sql/legacy/executor/RestExecutor.java b/legacy/src/main/java/org/opensearch/sql/legacy/executor/RestExecutor.java index e0124fb8be..8a0ab65970 100644 --- a/legacy/src/main/java/org/opensearch/sql/legacy/executor/RestExecutor.java +++ b/legacy/src/main/java/org/opensearch/sql/legacy/executor/RestExecutor.java @@ -3,7 +3,6 @@ * SPDX-License-Identifier: Apache-2.0 */ - package org.opensearch.sql.legacy.executor; import java.util.Map; @@ -11,12 +10,12 @@ import org.opensearch.rest.RestChannel; import org.opensearch.sql.legacy.query.QueryAction; -/** - * Created by Eliran on 26/12/2015. - */ +/** Created by Eliran on 26/12/2015. 
*/ public interface RestExecutor { - void execute(Client client, Map params, QueryAction queryAction, RestChannel channel) - throws Exception; + void execute( + Client client, Map params, QueryAction queryAction, RestChannel channel) + throws Exception; - String execute(Client client, Map params, QueryAction queryAction) throws Exception; + String execute(Client client, Map params, QueryAction queryAction) + throws Exception; } diff --git a/legacy/src/main/java/org/opensearch/sql/legacy/executor/adapter/QueryPlanQueryAction.java b/legacy/src/main/java/org/opensearch/sql/legacy/executor/adapter/QueryPlanQueryAction.java index 091abca554..b0179d3d8d 100644 --- a/legacy/src/main/java/org/opensearch/sql/legacy/executor/adapter/QueryPlanQueryAction.java +++ b/legacy/src/main/java/org/opensearch/sql/legacy/executor/adapter/QueryPlanQueryAction.java @@ -3,7 +3,6 @@ * SPDX-License-Identifier: Apache-2.0 */ - package org.opensearch.sql.legacy.executor.adapter; import com.google.common.base.Strings; @@ -14,27 +13,31 @@ import org.opensearch.sql.legacy.query.SqlElasticRequestBuilder; /** - * The definition of QueryPlan of QueryAction which works as the adapter to the current QueryAction framework. + * The definition of QueryPlan of QueryAction which works as the adapter to the current QueryAction + * framework. */ public class QueryPlanQueryAction extends QueryAction { - private final QueryPlanRequestBuilder requestBuilder; + private final QueryPlanRequestBuilder requestBuilder; - public QueryPlanQueryAction(QueryPlanRequestBuilder requestBuilder) { - super(null, null); - this.requestBuilder = requestBuilder; - } + public QueryPlanQueryAction(QueryPlanRequestBuilder requestBuilder) { + super(null, null); + this.requestBuilder = requestBuilder; + } - @Override - public SqlElasticRequestBuilder explain() { - return requestBuilder; - } + @Override + public SqlElasticRequestBuilder explain() { + return requestBuilder; + } - @Override - public Optional> getFieldNames() { - List fieldNames = ((QueryPlanRequestBuilder) requestBuilder).outputColumns() - .stream() - .map(node -> Strings.isNullOrEmpty(node.getAlias()) ? node.getName() : node.getAlias()) + @Override + public Optional> getFieldNames() { + List fieldNames = + ((QueryPlanRequestBuilder) requestBuilder) + .outputColumns().stream() + .map( + node -> + Strings.isNullOrEmpty(node.getAlias()) ? node.getName() : node.getAlias()) .collect(Collectors.toList()); - return Optional.of(fieldNames); - } + return Optional.of(fieldNames); + } } diff --git a/legacy/src/main/java/org/opensearch/sql/legacy/executor/adapter/QueryPlanRequestBuilder.java b/legacy/src/main/java/org/opensearch/sql/legacy/executor/adapter/QueryPlanRequestBuilder.java index ef0bc85bc1..3933df9bbb 100644 --- a/legacy/src/main/java/org/opensearch/sql/legacy/executor/adapter/QueryPlanRequestBuilder.java +++ b/legacy/src/main/java/org/opensearch/sql/legacy/executor/adapter/QueryPlanRequestBuilder.java @@ -3,7 +3,6 @@ * SPDX-License-Identifier: Apache-2.0 */ - package org.opensearch.sql.legacy.executor.adapter; import java.util.List; @@ -16,38 +15,36 @@ import org.opensearch.sql.legacy.query.planner.core.BindingTupleQueryPlanner; import org.opensearch.sql.legacy.query.planner.core.ColumnNode; -/** - * The definition of QueryPlan SqlElasticRequestBuilder. - */ +/** The definition of QueryPlan SqlElasticRequestBuilder. 
*/ @RequiredArgsConstructor public class QueryPlanRequestBuilder implements SqlElasticRequestBuilder { - private final BindingTupleQueryPlanner queryPlanner; - - public List execute() { - return queryPlanner.execute(); - } - - public List outputColumns() { - return queryPlanner.getColumnNodes(); - } - - @Override - public String explain() { - return queryPlanner.explain(); - } - - @Override - public ActionRequest request() { - throw new RuntimeException("unsupported operation"); - } - - @Override - public ActionResponse get() { - throw new RuntimeException("unsupported operation"); - } - - @Override - public ActionRequestBuilder getBuilder() { - throw new RuntimeException("unsupported operation"); - } + private final BindingTupleQueryPlanner queryPlanner; + + public List execute() { + return queryPlanner.execute(); + } + + public List outputColumns() { + return queryPlanner.getColumnNodes(); + } + + @Override + public String explain() { + return queryPlanner.explain(); + } + + @Override + public ActionRequest request() { + throw new RuntimeException("unsupported operation"); + } + + @Override + public ActionResponse get() { + throw new RuntimeException("unsupported operation"); + } + + @Override + public ActionRequestBuilder getBuilder() { + throw new RuntimeException("unsupported operation"); + } } diff --git a/legacy/src/main/java/org/opensearch/sql/legacy/executor/format/OpenSearchErrorMessage.java b/legacy/src/main/java/org/opensearch/sql/legacy/executor/format/OpenSearchErrorMessage.java index a48ab003dc..8117d241b1 100644 --- a/legacy/src/main/java/org/opensearch/sql/legacy/executor/format/OpenSearchErrorMessage.java +++ b/legacy/src/main/java/org/opensearch/sql/legacy/executor/format/OpenSearchErrorMessage.java @@ -3,7 +3,6 @@ * SPDX-License-Identifier: Apache-2.0 */ - package org.opensearch.sql.legacy.executor.format; import org.opensearch.OpenSearchException; @@ -13,46 +12,53 @@ public class OpenSearchErrorMessage extends ErrorMessage { - OpenSearchErrorMessage(OpenSearchException exception, int status) { - super(exception, status); - } + OpenSearchErrorMessage(OpenSearchException exception, int status) { + super(exception, status); + } - @Override - protected String fetchReason() { - return "Error occurred in OpenSearch engine: " + exception.getMessage(); - } + @Override + protected String fetchReason() { + return "Error occurred in OpenSearch engine: " + exception.getMessage(); + } - /** Currently Sql-Jdbc plugin only supports string type as reason and details in the error messages */ - @Override - protected String fetchDetails() { - StringBuilder details = new StringBuilder(); - if (exception instanceof SearchPhaseExecutionException) { - details.append(fetchSearchPhaseExecutionExceptionDetails((SearchPhaseExecutionException) exception)); - } else { - details.append(defaultDetails(exception)); - } - details.append("\nFor more details, please send request for Json format to see the raw response from " - + "OpenSearch engine."); - return details.toString(); + /** + * Currently Sql-Jdbc plugin only supports string type as reason and details in the error messages + */ + @Override + protected String fetchDetails() { + StringBuilder details = new StringBuilder(); + if (exception instanceof SearchPhaseExecutionException) { + details.append( + fetchSearchPhaseExecutionExceptionDetails((SearchPhaseExecutionException) exception)); + } else { + details.append(defaultDetails(exception)); } + details.append( + "\nFor more details, please send request for Json format to see the raw response 
from " + + "OpenSearch engine."); + return details.toString(); + } - private String defaultDetails(OpenSearchException exception) { - return exception.getDetailedMessage(); - } + private String defaultDetails(OpenSearchException exception) { + return exception.getDetailedMessage(); + } - /** - * Could not deliver the exactly same error messages due to the limit of JDBC types. - * Currently our cases occurred only SearchPhaseExecutionException instances among all types of OpenSearch exceptions - * according to the survey, see all types: OpenSearchException.OpenSearchExceptionHandle. - * Either add methods of fetching details for different types, or re-make a consistent message by not giving - * detailed messages/root causes but only a suggestion message. - */ - private String fetchSearchPhaseExecutionExceptionDetails(SearchPhaseExecutionException exception) { - StringBuilder details = new StringBuilder(); - ShardSearchFailure[] shardFailures = exception.shardFailures(); - for (ShardSearchFailure failure : shardFailures) { - details.append(StringUtils.format("Shard[%d]: %s\n", failure.shardId(), failure.getCause().toString())); - } - return details.toString(); + /** + * Could not deliver the exactly same error messages due to the limit of JDBC types. Currently our + * cases occurred only SearchPhaseExecutionException instances among all types of OpenSearch + * exceptions according to the survey, see all types: + * OpenSearchException.OpenSearchExceptionHandle. Either add methods of fetching details for + * different types, or re-make a consistent message by not giving detailed messages/root causes + * but only a suggestion message. + */ + private String fetchSearchPhaseExecutionExceptionDetails( + SearchPhaseExecutionException exception) { + StringBuilder details = new StringBuilder(); + ShardSearchFailure[] shardFailures = exception.shardFailures(); + for (ShardSearchFailure failure : shardFailures) { + details.append( + StringUtils.format("Shard[%d]: %s\n", failure.shardId(), failure.getCause().toString())); } + return details.toString(); + } } diff --git a/legacy/src/main/java/org/opensearch/sql/legacy/executor/format/PrettyFormatRestExecutor.java b/legacy/src/main/java/org/opensearch/sql/legacy/executor/format/PrettyFormatRestExecutor.java index 411fb90a24..00feabf5d8 100644 --- a/legacy/src/main/java/org/opensearch/sql/legacy/executor/format/PrettyFormatRestExecutor.java +++ b/legacy/src/main/java/org/opensearch/sql/legacy/executor/format/PrettyFormatRestExecutor.java @@ -3,7 +3,6 @@ * SPDX-License-Identifier: Apache-2.0 */ - package org.opensearch.sql.legacy.executor.format; import java.util.Map; @@ -27,82 +26,84 @@ public class PrettyFormatRestExecutor implements RestExecutor { - private static final Logger LOG = LogManager.getLogger(); - - private final String format; - - public PrettyFormatRestExecutor(String format) { - this.format = format.toLowerCase(); + private static final Logger LOG = LogManager.getLogger(); + + private final String format; + + public PrettyFormatRestExecutor(String format) { + this.format = format.toLowerCase(); + } + + /** Execute the QueryAction and return the REST response using the channel. 
*/ + @Override + public void execute( + Client client, Map params, QueryAction queryAction, RestChannel channel) { + String formattedResponse = execute(client, params, queryAction); + BytesRestResponse bytesRestResponse; + if (format.equals("jdbc")) { + bytesRestResponse = + new BytesRestResponse( + RestStatus.OK, "application/json; charset=UTF-8", formattedResponse); + } else { + bytesRestResponse = new BytesRestResponse(RestStatus.OK, formattedResponse); } - /** - * Execute the QueryAction and return the REST response using the channel. - */ - @Override - public void execute(Client client, Map params, QueryAction queryAction, RestChannel channel) { - String formattedResponse = execute(client, params, queryAction); - BytesRestResponse bytesRestResponse; - if (format.equals("jdbc")) { - bytesRestResponse = new BytesRestResponse(RestStatus.OK, - "application/json; charset=UTF-8", - formattedResponse); - } else { - bytesRestResponse = new BytesRestResponse(RestStatus.OK, formattedResponse); - } - - if (!BackOffRetryStrategy.isHealthy(2 * bytesRestResponse.content().length(), this)) { - throw new IllegalStateException( - "[PrettyFormatRestExecutor] Memory could be insufficient when sendResponse()."); - } - - channel.sendResponse(bytesRestResponse); + if (!BackOffRetryStrategy.isHealthy(2 * bytesRestResponse.content().length(), this)) { + throw new IllegalStateException( + "[PrettyFormatRestExecutor] Memory could be insufficient when sendResponse()."); } - @Override - public String execute(Client client, Map params, QueryAction queryAction) { - Protocol protocol; - - try { - if (queryAction instanceof DefaultQueryAction) { - protocol = buildProtocolForDefaultQuery(client, (DefaultQueryAction) queryAction); - } else { - Object queryResult = QueryActionElasticExecutor.executeAnyAction(client, queryAction); - protocol = new Protocol(client, queryAction, queryResult, format, Cursor.NULL_CURSOR); - } - } catch (Exception e) { - if (e instanceof OpenSearchException) { - LOG.warn("An error occurred in OpenSearch engine: " - + ((OpenSearchException) e).getDetailedMessage(), e); - } else { - LOG.warn("Error happened in pretty formatter", e); - } - protocol = new Protocol(e); - } - - return protocol.format(); + channel.sendResponse(bytesRestResponse); + } + + @Override + public String execute(Client client, Map params, QueryAction queryAction) { + Protocol protocol; + + try { + if (queryAction instanceof DefaultQueryAction) { + protocol = buildProtocolForDefaultQuery(client, (DefaultQueryAction) queryAction); + } else { + Object queryResult = QueryActionElasticExecutor.executeAnyAction(client, queryAction); + protocol = new Protocol(client, queryAction, queryResult, format, Cursor.NULL_CURSOR); + } + } catch (Exception e) { + if (e instanceof OpenSearchException) { + LOG.warn( + "An error occurred in OpenSearch engine: " + + ((OpenSearchException) e).getDetailedMessage(), + e); + } else { + LOG.warn("Error happened in pretty formatter", e); + } + protocol = new Protocol(e); } - /** - * QueryActionElasticExecutor.executeAnyAction() returns SearchHits inside SearchResponse. - * In order to get scroll ID if any, we need to execute DefaultQueryAction ourselves for SearchResponse. 
- */ - private Protocol buildProtocolForDefaultQuery(Client client, DefaultQueryAction queryAction) - throws SqlParseException { - - SearchResponse response = (SearchResponse) queryAction.explain().get(); - String scrollId = response.getScrollId(); - - Protocol protocol; - if (!Strings.isNullOrEmpty(scrollId)) { - DefaultCursor defaultCursor = new DefaultCursor(); - defaultCursor.setScrollId(scrollId); - defaultCursor.setLimit(queryAction.getSelect().getRowCount()); - defaultCursor.setFetchSize(queryAction.getSqlRequest().fetchSize()); - protocol = new Protocol(client, queryAction, response.getHits(), format, defaultCursor); - } else { - protocol = new Protocol(client, queryAction, response.getHits(), format, Cursor.NULL_CURSOR); - } - - return protocol; + return protocol.format(); + } + + /** + * QueryActionElasticExecutor.executeAnyAction() returns SearchHits inside SearchResponse. In + * order to get scroll ID if any, we need to execute DefaultQueryAction ourselves for + * SearchResponse. + */ + private Protocol buildProtocolForDefaultQuery(Client client, DefaultQueryAction queryAction) + throws SqlParseException { + + SearchResponse response = (SearchResponse) queryAction.explain().get(); + String scrollId = response.getScrollId(); + + Protocol protocol; + if (!Strings.isNullOrEmpty(scrollId)) { + DefaultCursor defaultCursor = new DefaultCursor(); + defaultCursor.setScrollId(scrollId); + defaultCursor.setLimit(queryAction.getSelect().getRowCount()); + defaultCursor.setFetchSize(queryAction.getSqlRequest().fetchSize()); + protocol = new Protocol(client, queryAction, response.getHits(), format, defaultCursor); + } else { + protocol = new Protocol(client, queryAction, response.getHits(), format, Cursor.NULL_CURSOR); } + + return protocol; + } } diff --git a/legacy/src/main/java/org/opensearch/sql/legacy/executor/format/Protocol.java b/legacy/src/main/java/org/opensearch/sql/legacy/executor/format/Protocol.java index aba0a3c599..e6ea767e17 100644 --- a/legacy/src/main/java/org/opensearch/sql/legacy/executor/format/Protocol.java +++ b/legacy/src/main/java/org/opensearch/sql/legacy/executor/format/Protocol.java @@ -3,7 +3,6 @@ * SPDX-License-Identifier: Apache-2.0 */ - package org.opensearch.sql.legacy.executor.format; import static org.opensearch.sql.legacy.domain.IndexStatement.StatementType; @@ -33,215 +32,223 @@ public class Protocol { - static final int OK_STATUS = 200; - static final int ERROR_STATUS = 500; - - private final String formatType; - private int status; - private long size; - private long total; - private ResultSet resultSet; - private ErrorMessage error; - private List columnNodeList; - private Cursor cursor = new NullCursor(); - private ColumnTypeProvider scriptColumnType = new ColumnTypeProvider(); - - public Protocol(Client client, QueryAction queryAction, Object queryResult, String formatType, Cursor cursor) { - this.cursor = cursor; - - if (queryAction instanceof QueryPlanQueryAction) { - this.columnNodeList = - ((QueryPlanRequestBuilder) (((QueryPlanQueryAction) queryAction).explain())).outputColumns(); - } else if (queryAction instanceof DefaultQueryAction) { - scriptColumnType = queryAction.getScriptColumnType(); - } - - this.formatType = formatType; - QueryStatement query = queryAction.getQueryStatement(); - this.status = OK_STATUS; - this.resultSet = loadResultSet(client, query, queryResult); - this.size = resultSet.getDataRows().getSize(); - this.total = resultSet.getDataRows().getTotalHits(); - } - - - public Protocol(Client client, Object queryResult, String 
formatType, Cursor cursor) { - this.cursor = cursor; - this.status = OK_STATUS; - this.formatType = formatType; - this.resultSet = loadResultSetForCursor(client, queryResult); - } - - public Protocol(Exception e) { - this.formatType = null; - this.status = ERROR_STATUS; - this.error = ErrorMessageFactory.createErrorMessage(e, status); - } - - private ResultSet loadResultSetForCursor(Client client, Object queryResult) { - return new SelectResultSet(client, queryResult, formatType, cursor); - } - - private ResultSet loadResultSet(Client client, QueryStatement queryStatement, Object queryResult) { - if (queryResult instanceof List) { - return new BindingTupleResultSet(columnNodeList, (List) queryResult); - } - if (queryStatement instanceof Delete) { - return new DeleteResultSet(client, (Delete) queryStatement, queryResult); - } else if (queryStatement instanceof Query) { - return new SelectResultSet(client, (Query) queryStatement, queryResult, - scriptColumnType, formatType, cursor); - } else if (queryStatement instanceof IndexStatement) { - IndexStatement statement = (IndexStatement) queryStatement; - StatementType statementType = statement.getStatementType(); - - if (statementType == StatementType.SHOW) { - return new ShowResultSet(client, statement, queryResult); - } else if (statementType == StatementType.DESCRIBE) { - return new DescribeResultSet(client, statement, queryResult); - } - } - - throw new UnsupportedOperationException( - String.format("The following instance of QueryStatement is not supported: %s", - queryStatement.getClass().toString()) - ); - } - - public int getStatus() { - return status; - } - - public ResultSet getResultSet() { - return resultSet; - } - - public String format() { - if (status == OK_STATUS) { - switch (formatType) { - case "jdbc": - return outputInJdbcFormat(); - case "table": - return outputInTableFormat(); - case "raw": - return outputInRawFormat(); - default: - throw new UnsupportedOperationException( - String.format("The following format is not supported: %s", formatType)); - } - } - - return error.toString(); + static final int OK_STATUS = 200; + static final int ERROR_STATUS = 500; + + private final String formatType; + private int status; + private long size; + private long total; + private ResultSet resultSet; + private ErrorMessage error; + private List columnNodeList; + private Cursor cursor = new NullCursor(); + private ColumnTypeProvider scriptColumnType = new ColumnTypeProvider(); + + public Protocol( + Client client, + QueryAction queryAction, + Object queryResult, + String formatType, + Cursor cursor) { + this.cursor = cursor; + + if (queryAction instanceof QueryPlanQueryAction) { + this.columnNodeList = + ((QueryPlanRequestBuilder) (((QueryPlanQueryAction) queryAction).explain())) + .outputColumns(); + } else if (queryAction instanceof DefaultQueryAction) { + scriptColumnType = queryAction.getScriptColumnType(); + } + + this.formatType = formatType; + QueryStatement query = queryAction.getQueryStatement(); + this.status = OK_STATUS; + this.resultSet = loadResultSet(client, query, queryResult); + this.size = resultSet.getDataRows().getSize(); + this.total = resultSet.getDataRows().getTotalHits(); + } + + public Protocol(Client client, Object queryResult, String formatType, Cursor cursor) { + this.cursor = cursor; + this.status = OK_STATUS; + this.formatType = formatType; + this.resultSet = loadResultSetForCursor(client, queryResult); + } + + public Protocol(Exception e) { + this.formatType = null; + this.status = ERROR_STATUS; + this.error 
= ErrorMessageFactory.createErrorMessage(e, status); + } + + private ResultSet loadResultSetForCursor(Client client, Object queryResult) { + return new SelectResultSet(client, queryResult, formatType, cursor); + } + + private ResultSet loadResultSet( + Client client, QueryStatement queryStatement, Object queryResult) { + if (queryResult instanceof List) { + return new BindingTupleResultSet(columnNodeList, (List) queryResult); + } + if (queryStatement instanceof Delete) { + return new DeleteResultSet(client, (Delete) queryStatement, queryResult); + } else if (queryStatement instanceof Query) { + return new SelectResultSet( + client, (Query) queryStatement, queryResult, scriptColumnType, formatType, cursor); + } else if (queryStatement instanceof IndexStatement) { + IndexStatement statement = (IndexStatement) queryStatement; + StatementType statementType = statement.getStatementType(); + + if (statementType == StatementType.SHOW) { + return new ShowResultSet(client, statement, queryResult); + } else if (statementType == StatementType.DESCRIBE) { + return new DescribeResultSet(client, statement, queryResult); + } + } + + throw new UnsupportedOperationException( + String.format( + "The following instance of QueryStatement is not supported: %s", + queryStatement.getClass().toString())); + } + + public int getStatus() { + return status; + } + + public ResultSet getResultSet() { + return resultSet; + } + + public String format() { + if (status == OK_STATUS) { + switch (formatType) { + case "jdbc": + return outputInJdbcFormat(); + case "table": + return outputInTableFormat(); + case "raw": + return outputInRawFormat(); + default: + throw new UnsupportedOperationException( + String.format("The following format is not supported: %s", formatType)); + } + } + + return error.toString(); + } + + private String outputInJdbcFormat() { + JSONObject formattedOutput = new JSONObject(); + + formattedOutput.put("status", status); + formattedOutput.put("size", size); + formattedOutput.put("total", total); + + JSONArray schema = getSchemaAsJson(); + + formattedOutput.put("schema", schema); + formattedOutput.put("datarows", getDataRowsAsJson()); + + String cursorId = cursor.generateCursorId(); + if (!Strings.isNullOrEmpty(cursorId)) { + formattedOutput.put("cursor", cursorId); + } + + return formattedOutput.toString(2); + } + + private String outputInRawFormat() { + Schema schema = resultSet.getSchema(); + DataRows dataRows = resultSet.getDataRows(); + + StringBuilder formattedOutput = new StringBuilder(); + for (Row row : dataRows) { + formattedOutput.append(rawEntry(row, schema)).append("\n"); + } + + return formattedOutput.toString(); + } + + private String outputInTableFormat() { + return null; + } + + public String cursorFormat() { + if (status == OK_STATUS) { + switch (formatType) { + case "jdbc": + return cursorOutputInJDBCFormat(); + default: + throw new UnsupportedOperationException( + String.format( + "The following response format is not supported for cursor: [%s]", formatType)); + } } + return error.toString(); + } - private String outputInJdbcFormat() { - JSONObject formattedOutput = new JSONObject(); + private String cursorOutputInJDBCFormat() { + JSONObject formattedOutput = new JSONObject(); + formattedOutput.put("datarows", getDataRowsAsJson()); - formattedOutput.put("status", status); - formattedOutput.put("size", size); - formattedOutput.put("total", total); - - JSONArray schema = getSchemaAsJson(); - - formattedOutput.put("schema", schema); - formattedOutput.put("datarows", 
getDataRowsAsJson()); - - String cursorId = cursor.generateCursorId(); - if (!Strings.isNullOrEmpty(cursorId)) { - formattedOutput.put("cursor", cursorId); - } - - return formattedOutput.toString(2); + String cursorId = cursor.generateCursorId(); + if (!Strings.isNullOrEmpty(cursorId)) { + formattedOutput.put("cursor", cursorId); } + return formattedOutput.toString(2); + } - private String outputInRawFormat() { - Schema schema = resultSet.getSchema(); - DataRows dataRows = resultSet.getDataRows(); + private String rawEntry(Row row, Schema schema) { + // TODO String separator is being kept to "|" for the time being as using "\t" will require + // formatting since + // TODO tabs are occurring in multiple of 4 (one option is Guava's Strings.padEnd() method) + return StreamSupport.stream(schema.spliterator(), false) + .map(column -> row.getDataOrDefault(column.getName(), "NULL").toString()) + .collect(Collectors.joining("|")); + } - StringBuilder formattedOutput = new StringBuilder(); - for (Row row : dataRows) { - formattedOutput.append(rawEntry(row, schema)).append("\n"); - } - - return formattedOutput.toString(); - } - - private String outputInTableFormat() { - return null; - } + private JSONArray getSchemaAsJson() { + Schema schema = resultSet.getSchema(); + JSONArray schemaJson = new JSONArray(); - public String cursorFormat() { - if (status == OK_STATUS) { - switch (formatType) { - case "jdbc": - return cursorOutputInJDBCFormat(); - default: - throw new UnsupportedOperationException(String.format( - "The following response format is not supported for cursor: [%s]", formatType)); - } - } - return error.toString(); + for (Column column : schema) { + schemaJson.put(schemaEntry(column.getName(), column.getAlias(), column.getType())); } - private String cursorOutputInJDBCFormat() { - JSONObject formattedOutput = new JSONObject(); - formattedOutput.put("datarows", getDataRowsAsJson()); + return schemaJson; + } - String cursorId = cursor.generateCursorId(); - if (!Strings.isNullOrEmpty(cursorId)) { - formattedOutput.put("cursor", cursorId); - } - return formattedOutput.toString(2); + private JSONObject schemaEntry(String name, String alias, String type) { + JSONObject entry = new JSONObject(); + entry.put("name", name); + if (alias != null) { + entry.put("alias", alias); } + entry.put("type", type); - private String rawEntry(Row row, Schema schema) { - // TODO String separator is being kept to "|" for the time being as using "\t" will require formatting since - // TODO tabs are occurring in multiple of 4 (one option is Guava's Strings.padEnd() method) - return StreamSupport.stream(schema.spliterator(), false) - .map(column -> row.getDataOrDefault(column.getName(), "NULL").toString()) - .collect(Collectors.joining("|")); - } - - private JSONArray getSchemaAsJson() { - Schema schema = resultSet.getSchema(); - JSONArray schemaJson = new JSONArray(); - - for (Column column : schema) { - schemaJson.put(schemaEntry(column.getName(), column.getAlias(), column.getType())); - } - - return schemaJson; - } + return entry; + } - private JSONObject schemaEntry(String name, String alias, String type) { - JSONObject entry = new JSONObject(); - entry.put("name", name); - if (alias != null) { - entry.put("alias", alias); - } - entry.put("type", type); + private JSONArray getDataRowsAsJson() { + Schema schema = resultSet.getSchema(); + DataRows dataRows = resultSet.getDataRows(); + JSONArray dataRowsJson = new JSONArray(); - return entry; + for (Row row : dataRows) { + dataRowsJson.put(dataEntry(row, schema)); } 
- private JSONArray getDataRowsAsJson() { - Schema schema = resultSet.getSchema(); - DataRows dataRows = resultSet.getDataRows(); - JSONArray dataRowsJson = new JSONArray(); - - for (Row row : dataRows) { - dataRowsJson.put(dataEntry(row, schema)); - } - - return dataRowsJson; - } + return dataRowsJson; + } - private JSONArray dataEntry(Row dataRow, Schema schema) { - JSONArray entry = new JSONArray(); - for (Column column : schema) { - String columnName = column.getIdentifier(); - entry.put(dataRow.getDataOrDefault(columnName, JSONObject.NULL)); - } - return entry; + private JSONArray dataEntry(Row dataRow, Schema schema) { + JSONArray entry = new JSONArray(); + for (Column column : schema) { + String columnName = column.getIdentifier(); + entry.put(dataRow.getDataOrDefault(columnName, JSONObject.NULL)); } + return entry; + } } diff --git a/legacy/src/main/java/org/opensearch/sql/legacy/executor/format/ResultSet.java b/legacy/src/main/java/org/opensearch/sql/legacy/executor/format/ResultSet.java index 9864f1ffdc..079a738eb3 100644 --- a/legacy/src/main/java/org/opensearch/sql/legacy/executor/format/ResultSet.java +++ b/legacy/src/main/java/org/opensearch/sql/legacy/executor/format/ResultSet.java @@ -3,7 +3,6 @@ * SPDX-License-Identifier: Apache-2.0 */ - package org.opensearch.sql.legacy.executor.format; import java.util.regex.Matcher; @@ -12,47 +11,44 @@ public abstract class ResultSet { - protected Schema schema; - protected DataRows dataRows; - - protected Client client; - protected String clusterName; - - public Schema getSchema() { - return schema; - } - - public DataRows getDataRows() { - return dataRows; - } - - protected String getClusterName() { - return client.admin().cluster() - .prepareHealth() - .get() - .getClusterName(); - } - - /** - * Check if given string matches the pattern. Do this check only if the pattern is a regex. - * Otherwise skip the matching process and consider it's a match. - * This is a quick fix to support SHOW/DESCRIBE alias by skip mismatch between actual index name - * and pattern (alias). - * @param string string to match - * @param pattern pattern - * @return true if match or pattern is not regular expression. otherwise false. - */ - protected boolean matchesPatternIfRegex(String string, String pattern) { - return isNotRegexPattern(pattern) || matchesPattern(string, pattern); - } - - protected boolean matchesPattern(String string, String pattern) { - Pattern p = Pattern.compile(pattern); - Matcher matcher = p.matcher(string); - return matcher.find(); - } - - private boolean isNotRegexPattern(String pattern) { - return !pattern.contains(".") && !pattern.contains("*"); - } + protected Schema schema; + protected DataRows dataRows; + + protected Client client; + protected String clusterName; + + public Schema getSchema() { + return schema; + } + + public DataRows getDataRows() { + return dataRows; + } + + protected String getClusterName() { + return client.admin().cluster().prepareHealth().get().getClusterName(); + } + + /** + * Check if given string matches the pattern. Do this check only if the pattern is a regex. + * Otherwise skip the matching process and consider it's a match. This is a quick fix to support + * SHOW/DESCRIBE alias by skip mismatch between actual index name and pattern (alias). + * + * @param string string to match + * @param pattern pattern + * @return true if match or pattern is not regular expression. otherwise false. 
+ */ + protected boolean matchesPatternIfRegex(String string, String pattern) { + return isNotRegexPattern(pattern) || matchesPattern(string, pattern); + } + + protected boolean matchesPattern(String string, String pattern) { + Pattern p = Pattern.compile(pattern); + Matcher matcher = p.matcher(string); + return matcher.find(); + } + + private boolean isNotRegexPattern(String pattern) { + return !pattern.contains(".") && !pattern.contains("*"); + } } diff --git a/legacy/src/main/java/org/opensearch/sql/legacy/executor/format/Schema.java b/legacy/src/main/java/org/opensearch/sql/legacy/executor/format/Schema.java index e02841fcd6..b29369f713 100644 --- a/legacy/src/main/java/org/opensearch/sql/legacy/executor/format/Schema.java +++ b/legacy/src/main/java/org/opensearch/sql/legacy/executor/format/Schema.java @@ -3,7 +3,6 @@ * SPDX-License-Identifier: Apache-2.0 */ - package org.opensearch.sql.legacy.executor.format; import static java.util.Collections.unmodifiableList; @@ -17,144 +16,155 @@ public class Schema implements Iterable { - private String indexName; - private List columns; + private String indexName; + private List columns; - private static Set types; + private static Set types; - static { - types = getTypes(); - } + static { + types = getTypes(); + } - public Schema(String indexName, List columns) { - this.indexName = indexName; - this.columns = columns; - } + public Schema(String indexName, List columns) { + this.indexName = indexName; + this.columns = columns; + } - public Schema(IndexStatement statement, List columns) { - this.indexName = statement.getIndexPattern(); - this.columns = columns; - } + public Schema(IndexStatement statement, List columns) { + this.indexName = statement.getIndexPattern(); + this.columns = columns; + } - public Schema(List columns){ - this.columns = columns; - } + public Schema(List columns) { + this.columns = columns; + } + + public String getIndexName() { + return indexName; + } - public String getIndexName() { - return indexName; + public List getHeaders() { + return columns.stream().map(column -> column.getName()).collect(Collectors.toList()); + } + + public List getColumns() { + return unmodifiableList(columns); + } + + private static Set getTypes() { + HashSet types = new HashSet<>(); + for (Type type : Type.values()) { + types.add(type.name()); } - public List getHeaders() { - return columns.stream() - .map(column -> column.getName()) - .collect(Collectors.toList()); + return types; + } + + // A method for efficiently checking if a Type exists + public static boolean hasType(String type) { + return types.contains(type); + } + + // Iterator method for Schema + @Override + public Iterator iterator() { + return new Iterator() { + private final Iterator iter = columns.iterator(); + + @Override + public boolean hasNext() { + return iter.hasNext(); + } + + @Override + public Column next() { + return iter.next(); + } + + @Override + public void remove() { + throw new UnsupportedOperationException("No changes allowed to Schema columns"); + } + }; + } + + // Only core OpenSearch datatypes currently supported + public enum Type { + TEXT, + KEYWORD, + IP, // String types + LONG, + INTEGER, + SHORT, + BYTE, + DOUBLE, + FLOAT, + HALF_FLOAT, + SCALED_FLOAT, // Numeric types + DATE, // Date types + BOOLEAN, // Boolean types + BINARY, // Binary types + OBJECT, + NESTED, + INTEGER_RANGE, + FLOAT_RANGE, + LONG_RANGE, + DOUBLE_RANGE, + DATE_RANGE; // Range types + + public String nameLowerCase() { + return name().toLowerCase(); } + } + + // Inner class for 
Column object + public static class Column { - public List getColumns() { - return unmodifiableList(columns); + private final String name; + private String alias; + private final Type type; + + private boolean identifiedByAlias; + + public Column(String name, String alias, Type type, boolean identifiedByAlias) { + this.name = name; + this.alias = alias; + this.type = type; + this.identifiedByAlias = identifiedByAlias; } - private static Set getTypes() { - HashSet types = new HashSet<>(); - for (Type type : Type.values()) { - types.add(type.name()); - } + public Column(String name, String alias, Type type) { + this(name, alias, type, false); + } - return types; + public String getName() { + return name; } - // A method for efficiently checking if a Type exists - public static boolean hasType(String type) { - return types.contains(type); + public String getAlias() { + return alias; } - // Iterator method for Schema - @Override - public Iterator iterator() { - return new Iterator() { - private final Iterator iter = columns.iterator(); - - @Override - public boolean hasNext() { - return iter.hasNext(); - } - - @Override - public Column next() { - return iter.next(); - } - - @Override - public void remove() { - throw new UnsupportedOperationException("No changes allowed to Schema columns"); - } - }; + public String getType() { + return type.nameLowerCase(); } - // Only core OpenSearch datatypes currently supported - public enum Type { - TEXT, KEYWORD, IP, // String types - LONG, INTEGER, SHORT, BYTE, DOUBLE, FLOAT, HALF_FLOAT, SCALED_FLOAT, // Numeric types - DATE, // Date types - BOOLEAN, // Boolean types - BINARY, // Binary types - OBJECT, - NESTED, - INTEGER_RANGE, FLOAT_RANGE, LONG_RANGE, DOUBLE_RANGE, DATE_RANGE; // Range types - - public String nameLowerCase() { - return name().toLowerCase(); - } + /* + * Some query types (like JOIN) label the data in SearchHit using alias instead of field name if it's given. + * + * This method returns the alias as the identifier if the identifiedByAlias flag is set for such cases so that + * the correct identifier is used to access related data in DataRows. + */ + public String getIdentifier() { + if (identifiedByAlias && alias != null) { + return alias; + } else { + return name; + } } - // Inner class for Column object - public static class Column { - - private final String name; - private String alias; - private final Type type; - - private boolean identifiedByAlias; - - public Column(String name, String alias, Type type, boolean identifiedByAlias) { - this.name = name; - this.alias = alias; - this.type = type; - this.identifiedByAlias = identifiedByAlias; - } - - public Column(String name, String alias, Type type) { - this(name, alias, type, false); - } - - public String getName() { - return name; - } - - public String getAlias() { - return alias; - } - - public String getType() { - return type.nameLowerCase(); - } - - /* - * Some query types (like JOIN) label the data in SearchHit using alias instead of field name if it's given. - * - * This method returns the alias as the identifier if the identifiedByAlias flag is set for such cases so that - * the correct identifier is used to access related data in DataRows. 
- */ - public String getIdentifier() { - if (identifiedByAlias && alias != null) { - return alias; - } else { - return name; - } - } - - public Type getEnumType() { - return type; - } + public Type getEnumType() { + return type; } + } } diff --git a/legacy/src/main/java/org/opensearch/sql/legacy/executor/format/SelectResultSet.java b/legacy/src/main/java/org/opensearch/sql/legacy/executor/format/SelectResultSet.java index a6f4cf815a..445bdd45a0 100644 --- a/legacy/src/main/java/org/opensearch/sql/legacy/executor/format/SelectResultSet.java +++ b/legacy/src/main/java/org/opensearch/sql/legacy/executor/format/SelectResultSet.java @@ -3,7 +3,6 @@ * SPDX-License-Identifier: Apache-2.0 */ - package org.opensearch.sql.legacy.executor.format; import static java.util.Collections.unmodifiableMap; @@ -59,676 +58,671 @@ public class SelectResultSet extends ResultSet { - private static final Logger LOG = LogManager.getLogger(SelectResultSet.class); - - public static final String SCORE = "_score"; - private final String formatType; - - private Query query; - private Object queryResult; - - private boolean selectAll; - private String indexName; - private List columns = new ArrayList<>(); - private ColumnTypeProvider outputColumnType; - - private List head; - private long size; - private long totalHits; - private long internalTotalHits; - private List rows; - private Cursor cursor; - - private DateFieldFormatter dateFieldFormatter; - // alias -> base field name - private Map fieldAliasMap = new HashMap<>(); - - public SelectResultSet(Client client, - Query query, - Object queryResult, - ColumnTypeProvider outputColumnType, - String formatType, - Cursor cursor) { - this.client = client; - this.query = query; - this.queryResult = queryResult; - this.selectAll = false; - this.formatType = formatType; - this.outputColumnType = outputColumnType; - this.cursor = cursor; - - if (isJoinQuery()) { - JoinSelect joinQuery = (JoinSelect) query; - loadFromEsState(joinQuery.getFirstTable()); - loadFromEsState(joinQuery.getSecondTable()); - } else { - loadFromEsState(query); - } - this.schema = new Schema(indexName, columns); - this.head = schema.getHeaders(); - this.dateFieldFormatter = new DateFieldFormatter(indexName, columns, fieldAliasMap); - - extractData(); - populateCursor(); - this.dataRows = new DataRows(size, totalHits, rows); - } - - public SelectResultSet(Client client, Object queryResult, String formatType, Cursor cursor) { - this.cursor = cursor; - this.client = client; - this.queryResult = queryResult; - this.selectAll = false; - this.formatType = formatType; - populateResultSetFromCursor(cursor); - } - - public String indexName(){ - return this.indexName; - } - - public Map fieldAliasMap() { - return unmodifiableMap(this.fieldAliasMap); - } - - public void populateResultSetFromCursor(Cursor cursor) { - switch (cursor.getType()) { - case DEFAULT: - populateResultSetFromDefaultCursor((DefaultCursor) cursor); - default: - return; - } - } - - private void populateResultSetFromDefaultCursor(DefaultCursor cursor) { - this.columns = cursor.getColumns(); - this.schema = new Schema(columns); - this.head = schema.getHeaders(); - this.dateFieldFormatter = new DateFieldFormatter( - cursor.getIndexPattern(), - columns, - cursor.getFieldAliasMap() - ); - extractData(); - this.dataRows = new DataRows(size, totalHits, rows); - } - - //*********************************************************** - // Logic for loading Columns to be stored in Schema - //*********************************************************** - - /** 
- * Makes a request to local node to receive meta data information and maps each field specified in SELECT to its - * type in the index mapping - */ - private void loadFromEsState(Query query) { - String indexName = fetchIndexName(query); - String[] fieldNames = fetchFieldsAsArray(query); - - // Reset boolean in the case of JOIN query where multiple calls to loadFromEsState() are made - selectAll = isSimpleQuerySelectAll(query) || isJoinQuerySelectAll(query, fieldNames); - - GetFieldMappingsRequest request = new GetFieldMappingsRequest() - .indices(indexName) - .fields(selectAllFieldsIfEmpty(fieldNames)) - .local(true); - GetFieldMappingsResponse response = client.admin().indices() - .getFieldMappings(request) - .actionGet(); - - Map> mappings = response.mappings(); - if (mappings.isEmpty() || !mappings.containsKey(indexName)) { - throw new IllegalArgumentException(String.format("Index type %s does not exist", query.getFrom())); - } - Map typeMappings = mappings.get(indexName); - - - this.indexName = this.indexName == null ? indexName : (this.indexName + "|" + indexName); - this.columns.addAll(renameColumnWithTableAlias(query, populateColumns(query, fieldNames, typeMappings))); - } - - /** - * Rename column name with table alias as prefix for join query - */ - private List renameColumnWithTableAlias(Query query, List columns) { - List renamedCols; - if ((query instanceof TableOnJoinSelect) - && !Strings.isNullOrEmpty(((TableOnJoinSelect) query).getAlias())) { - - TableOnJoinSelect joinQuery = (TableOnJoinSelect) query; - renamedCols = new ArrayList<>(); - - for (Schema.Column column : columns) { - renamedCols.add(new Schema.Column( - joinQuery.getAlias() + "." + column.getName(), - column.getAlias(), - Schema.Type.valueOf(column.getType().toUpperCase()), - true - )); - } - } else { - renamedCols = columns; - } - return renamedCols; - } - - private boolean isSelectAll() { - return selectAll; - } - - /** - * Is a simple (non-join/non-group-by) query with SELECT * explicitly - */ - private boolean isSimpleQuerySelectAll(Query query) { - return (query instanceof Select) && ((Select) query).isSelectAll(); - } - - /** - * Is a join query with SELECT * on either one of the tables some fields specified - */ - private boolean isJoinQuerySelectAll(Query query, String[] fieldNames) { - return fieldNames.length == 0 && !fieldsSelectedOnAnotherTable(query); - } - - /** - * In the case of a JOIN query, if no fields are SELECTed on for a particular table, the other table's fields are - * checked in SELECT to ensure a table is not incorrectly marked as a isSelectAll() case. - */ - private boolean fieldsSelectedOnAnotherTable(Query query) { - if (isJoinQuery()) { - TableOnJoinSelect otherTable = getOtherTable(query); - return otherTable.getSelectedFields().size() > 0; - } - - return false; - } - - private TableOnJoinSelect getOtherTable(Query currJoinSelect) { - JoinSelect joinQuery = (JoinSelect) query; - if (joinQuery.getFirstTable() == currJoinSelect) { - return joinQuery.getSecondTable(); - } else { - return joinQuery.getFirstTable(); - } - } - - private boolean containsWildcard(Query query) { - for (Field field : fetchFields(query)) { - if (!(field instanceof MethodField) && field.getName().contains("*")) { - return true; - } - } - - return false; - } - - private String fetchIndexName(Query query) { - return query.getFrom().get(0).getIndex(); - } - - /** - * queryResult is checked to see if it's of type Aggregation in which case the aggregation fields in GROUP BY - * are returned as well. 
This prevents returning a Schema of all fields when SELECT * is called with
-     * GROUP BY (since all fields will be retrieved from the typeMappings request when no fields are returned from
-     * fetchFields()).
-     *
- * After getting all of the fields from GROUP BY, the fields from SELECT are iterated and only the fields of type - * MethodField are added (to prevent duplicate field in Schema for queries like - * "SELECT age, COUNT(*) FROM bank GROUP BY age" where 'age' is mentioned in both SELECT and GROUP BY). + private static final Logger LOG = LogManager.getLogger(SelectResultSet.class); + + public static final String SCORE = "_score"; + private final String formatType; + + private Query query; + private Object queryResult; + + private boolean selectAll; + private String indexName; + private List columns = new ArrayList<>(); + private ColumnTypeProvider outputColumnType; + + private List head; + private long size; + private long totalHits; + private long internalTotalHits; + private List rows; + private Cursor cursor; + + private DateFieldFormatter dateFieldFormatter; + // alias -> base field name + private Map fieldAliasMap = new HashMap<>(); + + public SelectResultSet( + Client client, + Query query, + Object queryResult, + ColumnTypeProvider outputColumnType, + String formatType, + Cursor cursor) { + this.client = client; + this.query = query; + this.queryResult = queryResult; + this.selectAll = false; + this.formatType = formatType; + this.outputColumnType = outputColumnType; + this.cursor = cursor; + + if (isJoinQuery()) { + JoinSelect joinQuery = (JoinSelect) query; + loadFromEsState(joinQuery.getFirstTable()); + loadFromEsState(joinQuery.getSecondTable()); + } else { + loadFromEsState(query); + } + this.schema = new Schema(indexName, columns); + this.head = schema.getHeaders(); + this.dateFieldFormatter = new DateFieldFormatter(indexName, columns, fieldAliasMap); + + extractData(); + populateCursor(); + this.dataRows = new DataRows(size, totalHits, rows); + } + + public SelectResultSet(Client client, Object queryResult, String formatType, Cursor cursor) { + this.cursor = cursor; + this.client = client; + this.queryResult = queryResult; + this.selectAll = false; + this.formatType = formatType; + populateResultSetFromCursor(cursor); + } + + public String indexName() { + return this.indexName; + } + + public Map fieldAliasMap() { + return unmodifiableMap(this.fieldAliasMap); + } + + public void populateResultSetFromCursor(Cursor cursor) { + switch (cursor.getType()) { + case DEFAULT: + populateResultSetFromDefaultCursor((DefaultCursor) cursor); + default: + return; + } + } + + private void populateResultSetFromDefaultCursor(DefaultCursor cursor) { + this.columns = cursor.getColumns(); + this.schema = new Schema(columns); + this.head = schema.getHeaders(); + this.dateFieldFormatter = + new DateFieldFormatter(cursor.getIndexPattern(), columns, cursor.getFieldAliasMap()); + extractData(); + this.dataRows = new DataRows(size, totalHits, rows); + } + + // *********************************************************** + // Logic for loading Columns to be stored in Schema + // *********************************************************** + + /** + * Makes a request to local node to receive meta data information and maps each field specified in + * SELECT to its type in the index mapping + */ + private void loadFromEsState(Query query) { + String indexName = fetchIndexName(query); + String[] fieldNames = fetchFieldsAsArray(query); + + // Reset boolean in the case of JOIN query where multiple calls to loadFromEsState() are made + selectAll = isSimpleQuerySelectAll(query) || isJoinQuerySelectAll(query, fieldNames); + + GetFieldMappingsRequest request = + new GetFieldMappingsRequest() + .indices(indexName) + 
.fields(selectAllFieldsIfEmpty(fieldNames)) + .local(true); + GetFieldMappingsResponse response = + client.admin().indices().getFieldMappings(request).actionGet(); + + Map> mappings = response.mappings(); + if (mappings.isEmpty() || !mappings.containsKey(indexName)) { + throw new IllegalArgumentException( + String.format("Index type %s does not exist", query.getFrom())); + } + Map typeMappings = mappings.get(indexName); + + this.indexName = this.indexName == null ? indexName : (this.indexName + "|" + indexName); + this.columns.addAll( + renameColumnWithTableAlias(query, populateColumns(query, fieldNames, typeMappings))); + } + + /** Rename column name with table alias as prefix for join query */ + private List renameColumnWithTableAlias(Query query, List columns) { + List renamedCols; + if ((query instanceof TableOnJoinSelect) + && !Strings.isNullOrEmpty(((TableOnJoinSelect) query).getAlias())) { + + TableOnJoinSelect joinQuery = (TableOnJoinSelect) query; + renamedCols = new ArrayList<>(); + + for (Schema.Column column : columns) { + renamedCols.add( + new Schema.Column( + joinQuery.getAlias() + "." + column.getName(), + column.getAlias(), + Schema.Type.valueOf(column.getType().toUpperCase()), + true)); + } + } else { + renamedCols = columns; + } + return renamedCols; + } + + private boolean isSelectAll() { + return selectAll; + } + + /** Is a simple (non-join/non-group-by) query with SELECT * explicitly */ + private boolean isSimpleQuerySelectAll(Query query) { + return (query instanceof Select) && ((Select) query).isSelectAll(); + } + + /** Is a join query with SELECT * on either one of the tables some fields specified */ + private boolean isJoinQuerySelectAll(Query query, String[] fieldNames) { + return fieldNames.length == 0 && !fieldsSelectedOnAnotherTable(query); + } + + /** + * In the case of a JOIN query, if no fields are SELECTed on for a particular table, the other + * table's fields are checked in SELECT to ensure a table is not incorrectly marked as a + * isSelectAll() case. + */ + private boolean fieldsSelectedOnAnotherTable(Query query) { + if (isJoinQuery()) { + TableOnJoinSelect otherTable = getOtherTable(query); + return otherTable.getSelectedFields().size() > 0; + } + + return false; + } + + private TableOnJoinSelect getOtherTable(Query currJoinSelect) { + JoinSelect joinQuery = (JoinSelect) query; + if (joinQuery.getFirstTable() == currJoinSelect) { + return joinQuery.getSecondTable(); + } else { + return joinQuery.getFirstTable(); + } + } + + private boolean containsWildcard(Query query) { + for (Field field : fetchFields(query)) { + if (!(field instanceof MethodField) && field.getName().contains("*")) { + return true; + } + } + + return false; + } + + private String fetchIndexName(Query query) { + return query.getFrom().get(0).getIndex(); + } + + /** + * queryResult is checked to see if it's of type Aggregation in which case the aggregation fields + * in GROUP BY are returned as well. This prevents returning a Schema of all fields when SELECT * + * is called with GROUP BY (since all fields will be retrieved from the typeMappings request when + * no fields are returned from fetchFields()). + * + *

After getting all of the fields from GROUP BY, the fields from SELECT are iterated and only + * the fields of type MethodField are added (to prevent duplicate field in Schema for queries like + * "SELECT age, COUNT(*) FROM bank GROUP BY age" where 'age' is mentioned in both SELECT and GROUP + * BY). + */ + private List fetchFields(Query query) { + Select select = (Select) query; + + if (queryResult instanceof Aggregations) { + List groupByFields = + select.getGroupBys().isEmpty() ? new ArrayList<>() : select.getGroupBys().get(0); + + for (Field selectField : select.getFields()) { + if (selectField instanceof MethodField && !selectField.isScriptField()) { + groupByFields.add(selectField); + } else if (selectField.isScriptField() + && selectField.getAlias().equals(groupByFields.get(0).getName())) { + return select.getFields(); + } + } + return groupByFields; + } + + if (query instanceof TableOnJoinSelect) { + return ((TableOnJoinSelect) query).getSelectedFields(); + } + + return select.getFields(); + } + + private String[] fetchFieldsAsArray(Query query) { + List fields = fetchFields(query); + return fields.stream().map(this::getFieldName).toArray(String[]::new); + } + + private String getFieldName(Field field) { + if (field instanceof MethodField) { + return field.getAlias(); + } + + return field.getName(); + } + + private Map fetchFieldMap(Query query) { + Map fieldMap = new HashMap<>(); + + for (Field field : fetchFields(query)) { + fieldMap.put(getFieldName(field), field); + } + + return fieldMap; + } + + private String[] selectAllFieldsIfEmpty(String[] fields) { + if (isSelectAll()) { + return new String[] {"*"}; + } + + return fields; + } + + private String[] emptyArrayIfNull(String typeName) { + if (typeName != null) { + return new String[] {typeName}; + } else { + return Strings.EMPTY_ARRAY; + } + } + + private Schema.Type fetchMethodReturnType(int fieldIndex, MethodField field) { + switch (field.getName().toLowerCase()) { + case "count": + return Schema.Type.LONG; + case "sum": + case "avg": + case "min": + case "max": + case "percentiles": + return Schema.Type.DOUBLE; + case "script": + { + // TODO: return type information is disconnected from the function definitions in + // SQLFunctions. + // Refactor SQLFunctions to have functions self-explanatory (types, scripts) and pluggable + // (similar to Strategy pattern) + if (field.getExpression() instanceof SQLCaseExpr) { + return Schema.Type.TEXT; + } + Schema.Type resolvedType = outputColumnType.get(fieldIndex); + return SQLFunctions.getScriptFunctionReturnType(field, resolvedType); + } + default: + throw new UnsupportedOperationException( + String.format("The following method is not supported in Schema: %s", field.getName())); + } + } + + /** + * Returns a list of Column objects which contain names identifying the field as well as its type. + * + *

If all fields are being selected (SELECT *) then the order of fields returned will be
+   * random, otherwise the output will be in the same order as how they were selected.
+   *
+   *
If an alias was given for a field, that will be used to identify the field in Column, + * otherwise the field name will be used. + */ + private List populateColumns( + Query query, String[] fieldNames, Map typeMappings) { + List fieldNameList; + + if (isSelectAll() || containsWildcard(query)) { + fieldNameList = new ArrayList<>(typeMappings.keySet()); + } else { + fieldNameList = Arrays.asList(fieldNames); + } + + /* + * The reason the 'fieldMap' mapping is needed on top of 'fieldNameList' is because the map would be + * empty in cases like 'SELECT *' but List fieldNameList will always be set in either case. + * That way, 'fieldNameList' is used to access field names in order that they were selected, if given, + * and then 'fieldMap' is used to access the respective Field object to check for aliases. */ - private List fetchFields(Query query) { - Select select = (Select) query; - - if (queryResult instanceof Aggregations) { - List groupByFields = select.getGroupBys().isEmpty() ? new ArrayList<>() : - select.getGroupBys().get(0); - - - for (Field selectField : select.getFields()) { - if (selectField instanceof MethodField && !selectField.isScriptField()) { - groupByFields.add(selectField); - } else if (selectField.isScriptField() - && selectField.getAlias().equals(groupByFields.get(0).getName())) { - return select.getFields(); - } - } - return groupByFields; - } - - if (query instanceof TableOnJoinSelect) { - return ((TableOnJoinSelect) query).getSelectedFields(); - } - - return select.getFields(); - } - - private String[] fetchFieldsAsArray(Query query) { - List fields = fetchFields(query); - return fields.stream() - .map(this::getFieldName) - .toArray(String[]::new); - } - - private String getFieldName(Field field) { - if (field instanceof MethodField) { - return field.getAlias(); - } - - return field.getName(); - } - - private Map fetchFieldMap(Query query) { - Map fieldMap = new HashMap<>(); - - for (Field field : fetchFields(query)) { - fieldMap.put(getFieldName(field), field); - } - - return fieldMap; - } - - private String[] selectAllFieldsIfEmpty(String[] fields) { - if (isSelectAll()) { - return new String[]{"*"}; - } - - return fields; - } - - private String[] emptyArrayIfNull(String typeName) { - if (typeName != null) { - return new String[]{typeName}; - } else { - return Strings.EMPTY_ARRAY; - } - } - - private Schema.Type fetchMethodReturnType(int fieldIndex, MethodField field) { - switch (field.getName().toLowerCase()) { - case "count": - return Schema.Type.LONG; - case "sum": - case "avg": - case "min": - case "max": - case "percentiles": - return Schema.Type.DOUBLE; - case "script": { - // TODO: return type information is disconnected from the function definitions in SQLFunctions. - // Refactor SQLFunctions to have functions self-explanatory (types, scripts) and pluggable - // (similar to Strategy pattern) - if (field.getExpression() instanceof SQLCaseExpr) { - return Schema.Type.TEXT; - } - Schema.Type resolvedType = outputColumnType.get(fieldIndex); - return SQLFunctions.getScriptFunctionReturnType(field, resolvedType); - } - default: - throw new UnsupportedOperationException( - String.format("The following method is not supported in Schema: %s", field.getName())); - } - } - - /** - * Returns a list of Column objects which contain names identifying the field as well as its type. - *

- * If all fields are being selected (SELECT *) then the order of fields returned will be random, otherwise - * the output will be in the same order as how they were selected. - *

- * If an alias was given for a field, that will be used to identify the field in Column, otherwise the field name - * will be used. - */ - private List populateColumns(Query query, String[] fieldNames, Map typeMappings) { - List fieldNameList; - - if (isSelectAll() || containsWildcard(query)) { - fieldNameList = new ArrayList<>(typeMappings.keySet()); - } else { - fieldNameList = Arrays.asList(fieldNames); + Map fieldMap = fetchFieldMap(query); + List columns = new ArrayList<>(); + for (String fieldName : fieldNameList) { + // _score is a special case since it is not included in typeMappings, so it is checked for + // here + if (fieldName.equals(SCORE)) { + columns.add( + new Schema.Column(fieldName, fetchAlias(fieldName, fieldMap), Schema.Type.FLOAT)); + continue; + } + /* + * Methods are also a special case as their type cannot be determined from typeMappings, so it is checked + * for here. + * + * Note: When adding the Column for Method, alias is used in place of getName() because the default name + * is set as alias (ex. COUNT(*)) and overwritten if an alias is given. So alias is used as the + * name instead. + */ + if (fieldMap.get(fieldName) instanceof MethodField) { + MethodField methodField = (MethodField) fieldMap.get(fieldName); + int fieldIndex = fieldNameList.indexOf(fieldName); + + SQLExpr expr = methodField.getExpression(); + if (expr instanceof SQLCastExpr) { + // Since CAST expressions create an alias for a field, we need to save the original field + // name + // for this alias for formatting data later. + SQLIdentifierExpr castFieldIdentifier = + (SQLIdentifierExpr) ((SQLCastExpr) expr).getExpr(); + fieldAliasMap.put(methodField.getAlias(), castFieldIdentifier.getName()); + } + + columns.add( + new Schema.Column( + methodField.getAlias(), null, fetchMethodReturnType(fieldIndex, methodField))); + continue; + } + + /* + * Unnecessary fields (ex. _index, _parent) are ignored. + * Fields like field.keyword will be ignored when isSelectAll is true but will be returned if + * explicitly selected. + */ + FieldMapping field = new FieldMapping(fieldName, typeMappings, fieldMap); + if (!field.isMetaField()) { + + if (field.isMultiField() && !field.isSpecified()) { + continue; + } + if (field.isPropertyField() && !field.isSpecified() && !field.isWildcardSpecified()) { + continue; } /* - * The reason the 'fieldMap' mapping is needed on top of 'fieldNameList' is because the map would be - * empty in cases like 'SELECT *' but List fieldNameList will always be set in either case. - * That way, 'fieldNameList' is used to access field names in order that they were selected, if given, - * and then 'fieldMap' is used to access the respective Field object to check for aliases. + * Three cases regarding Type: + * 1. If Type exists, create Column + * 2. If Type doesn't exist and isSelectAll() is false, throw exception + * 3. If Type doesn't exist and isSelectAll() is true, Column creation for fieldName is skipped */ - Map fieldMap = fetchFieldMap(query); - List columns = new ArrayList<>(); - for (String fieldName : fieldNameList) { - // _score is a special case since it is not included in typeMappings, so it is checked for here - if (fieldName.equals(SCORE)) { - columns.add(new Schema.Column(fieldName, fetchAlias(fieldName, fieldMap), Schema.Type.FLOAT)); - continue; - } - /* - * Methods are also a special case as their type cannot be determined from typeMappings, so it is checked - * for here. 
- * - * Note: When adding the Column for Method, alias is used in place of getName() because the default name - * is set as alias (ex. COUNT(*)) and overwritten if an alias is given. So alias is used as the - * name instead. - */ - if (fieldMap.get(fieldName) instanceof MethodField) { - MethodField methodField = (MethodField) fieldMap.get(fieldName); - int fieldIndex = fieldNameList.indexOf(fieldName); - - SQLExpr expr = methodField.getExpression(); - if (expr instanceof SQLCastExpr) { - // Since CAST expressions create an alias for a field, we need to save the original field name - // for this alias for formatting data later. - SQLIdentifierExpr castFieldIdentifier = (SQLIdentifierExpr) ((SQLCastExpr) expr).getExpr(); - fieldAliasMap.put(methodField.getAlias(), castFieldIdentifier.getName()); - } - - columns.add( - new Schema.Column( - methodField.getAlias(), - null, - fetchMethodReturnType(fieldIndex, methodField) - ) - ); - continue; - } - - /* - * Unnecessary fields (ex. _index, _parent) are ignored. - * Fields like field.keyword will be ignored when isSelectAll is true but will be returned if - * explicitly selected. - */ - FieldMapping field = new FieldMapping(fieldName, typeMappings, fieldMap); - if (!field.isMetaField()) { - - if (field.isMultiField() && !field.isSpecified()) { - continue; - } - if (field.isPropertyField() && !field.isSpecified() && !field.isWildcardSpecified()) { - continue; - } - - /* - * Three cases regarding Type: - * 1. If Type exists, create Column - * 2. If Type doesn't exist and isSelectAll() is false, throw exception - * 3. If Type doesn't exist and isSelectAll() is true, Column creation for fieldName is skipped - */ - String type = field.type().toUpperCase(); - if (Schema.hasType(type)) { - - // If the current field is a group key, we should use alias as the identifier - boolean isGroupKey = false; - Select select = (Select) query; - if (null != select.getGroupBys() - && !select.getGroupBys().isEmpty() - && select.getGroupBys().get(0).contains(fieldMap.get(fieldName))) { - isGroupKey = true; - } - - columns.add( - new Schema.Column( - fieldName, - fetchAlias(fieldName, fieldMap), - Schema.Type.valueOf(type), - isGroupKey - ) - ); - } else if (!isSelectAll()) { - throw new IllegalArgumentException( - String.format("%s fieldName types are currently not supported.", type)); - } - } - } - - if (isSelectAllOnly(query)) { - populateAllNestedFields(columns, fieldNameList); - } - return columns; - } - - /** - * SELECT * only without other columns or wildcard pattern specified. - */ - private boolean isSelectAllOnly(Query query) { - return isSelectAll() && fetchFields(query).isEmpty(); - } - - /** - * Special case which trades off consistency of SELECT * meaning for more intuition from customer perspective. - * In other cases, * means all regular fields on the level. - * The only exception here is * picks all non-regular (nested) fields as JSON without flatten. - */ - private void populateAllNestedFields(List columns, List fields) { - Set nestedFieldPaths = fields.stream(). - map(FieldMapping::new). - filter(FieldMapping::isPropertyField). - filter(f -> !f.isMultiField()). - map(FieldMapping::path). 
- collect(toSet()); - - for (String nestedFieldPath : nestedFieldPaths) { - columns.add( - new Schema.Column(nestedFieldPath, "", Schema.Type.TEXT) - ); - } - } - + String type = field.type().toUpperCase(); + if (Schema.hasType(type)) { + + // If the current field is a group key, we should use alias as the identifier + boolean isGroupKey = false; + Select select = (Select) query; + if (null != select.getGroupBys() + && !select.getGroupBys().isEmpty() + && select.getGroupBys().get(0).contains(fieldMap.get(fieldName))) { + isGroupKey = true; + } + + columns.add( + new Schema.Column( + fieldName, + fetchAlias(fieldName, fieldMap), + Schema.Type.valueOf(type), + isGroupKey)); + } else if (!isSelectAll()) { + throw new IllegalArgumentException( + String.format("%s fieldName types are currently not supported.", type)); + } + } + } + + if (isSelectAllOnly(query)) { + populateAllNestedFields(columns, fieldNameList); + } + return columns; + } + + /** SELECT * only without other columns or wildcard pattern specified. */ + private boolean isSelectAllOnly(Query query) { + return isSelectAll() && fetchFields(query).isEmpty(); + } + + /** + * Special case which trades off consistency of SELECT * meaning for more intuition from customer + * perspective. In other cases, * means all regular fields on the level. The only exception here + * is * picks all non-regular (nested) fields as JSON without flatten. + */ + private void populateAllNestedFields(List columns, List fields) { + Set nestedFieldPaths = + fields.stream() + .map(FieldMapping::new) + .filter(FieldMapping::isPropertyField) + .filter(f -> !f.isMultiField()) + .map(FieldMapping::path) + .collect(toSet()); + + for (String nestedFieldPath : nestedFieldPaths) { + columns.add(new Schema.Column(nestedFieldPath, "", Schema.Type.TEXT)); + } + } + + /** + * Since this helper method is called within a check to see if the field exists in type mapping, + * it's already confirmed that the fieldName is valid. The check for fieldName in fieldMap has to + * be done in the case that 'SELECT *' was called since the map will be empty. + */ + private String fetchAlias(String fieldName, Map fieldMap) { + if (fieldMap.containsKey(fieldName)) { + return fieldMap.get(fieldName).getAlias(); + } + + return null; + } + + // *********************************************************** + // Logic for loading Rows to be stored in DataRows + // *********************************************************** + + /** + * Extract data from query results into Row objects Need to cover two cases: 1. queryResult is a + * SearchHits object 2. queryResult is an Aggregations object + * + *

Ignoring queryResult being ActionResponse (from executeDeleteAction), there should be no + * data in this case + */ + private void extractData() { + if (queryResult instanceof SearchHits) { + SearchHits searchHits = (SearchHits) queryResult; + + this.rows = populateRows(searchHits); + this.size = rows.size(); + this.internalTotalHits = + Optional.ofNullable(searchHits.getTotalHits()).map(th -> th.value).orElse(0L); + // size may be greater than totalHits after nested rows be flatten + this.totalHits = Math.max(size, internalTotalHits); + } else if (queryResult instanceof Aggregations) { + Aggregations aggregations = (Aggregations) queryResult; + + this.rows = populateRows(aggregations); + this.size = rows.size(); + this.internalTotalHits = size; + // Total hits is not available from Aggregations so 'size' is used + this.totalHits = size; + } + } + + private void populateCursor() { + switch (cursor.getType()) { + case DEFAULT: + populateDefaultCursor((DefaultCursor) cursor); + default: + return; + } + } + + private void populateDefaultCursor(DefaultCursor cursor) { /** - * Since this helper method is called within a check to see if the field exists in type mapping, it's - * already confirmed that the fieldName is valid. The check for fieldName in fieldMap has to be done in the case - * that 'SELECT *' was called since the map will be empty. + * Assumption: scrollId, fetchSize, limit already being set in + * + * @see PrettyFormatRestExecutor.buildProtocolForDefaultQuery() */ - private String fetchAlias(String fieldName, Map fieldMap) { - if (fieldMap.containsKey(fieldName)) { - return fieldMap.get(fieldName).getAlias(); - } - - return null; - } - - //*********************************************************** - // Logic for loading Rows to be stored in DataRows - //*********************************************************** - - /** - * Extract data from query results into Row objects - * Need to cover two cases: - * 1. queryResult is a SearchHits object - * 2. queryResult is an Aggregations object - *

- * Ignoring queryResult being ActionResponse (from executeDeleteAction), there should be no data in this case - */ - private void extractData() { - if (queryResult instanceof SearchHits) { - SearchHits searchHits = (SearchHits) queryResult; - - this.rows = populateRows(searchHits); - this.size = rows.size(); - this.internalTotalHits = Optional.ofNullable(searchHits.getTotalHits()).map(th -> th.value).orElse(0L); - // size may be greater than totalHits after nested rows be flatten - this.totalHits = Math.max(size, internalTotalHits); - } else if (queryResult instanceof Aggregations) { - Aggregations aggregations = (Aggregations) queryResult; - - this.rows = populateRows(aggregations); - this.size = rows.size(); - this.internalTotalHits = size; - // Total hits is not available from Aggregations so 'size' is used - this.totalHits = size; - } - } - - private void populateCursor() { - switch(cursor.getType()) { - case DEFAULT: - populateDefaultCursor((DefaultCursor) cursor); - default: - return; - } - } - - private void populateDefaultCursor(DefaultCursor cursor) { - /** - * Assumption: scrollId, fetchSize, limit already being set in - * @see PrettyFormatRestExecutor.buildProtocolForDefaultQuery() - */ - - Integer limit = cursor.getLimit(); - long rowsLeft = rowsLeft(cursor.getFetchSize(), cursor.getLimit()); - if (rowsLeft <= 0) { - // close the cursor - String scrollId = cursor.getScrollId(); - ClearScrollResponse clearScrollResponse = client.prepareClearScroll().addScrollId(scrollId).get(); - if (!clearScrollResponse.isSucceeded()) { - Metrics.getInstance().getNumericalMetric(MetricName.FAILED_REQ_COUNT_SYS).increment(); - LOG.error("Error closing the cursor context {} ", scrollId); - } - return; - } - - cursor.setRowsLeft(rowsLeft); - cursor.setIndexPattern(indexName); - cursor.setFieldAliasMap(fieldAliasMap()); - cursor.setColumns(columns); - this.totalHits = limit != null && limit < internalTotalHits ? limit : internalTotalHits; - } - - private long rowsLeft(Integer fetchSize, Integer limit) { - long rowsLeft = 0; - long totalHits = internalTotalHits; - if (limit != null && limit < totalHits) { - rowsLeft = limit - fetchSize; - } else { - rowsLeft = totalHits - fetchSize; - } - return rowsLeft; - } - - private List populateRows(SearchHits searchHits) { - List rows = new ArrayList<>(); - Set newKeys = new HashSet<>(head); - for (SearchHit hit : searchHits) { - Map rowSource = hit.getSourceAsMap(); - List result; - - if (!isJoinQuery()) { - // Row already flatten in source in join. And join doesn't support nested fields for now. 
- rowSource = flatRow(head, rowSource); - rowSource.put(SCORE, hit.getScore()); - - for (Map.Entry field : hit.getFields().entrySet()) { - rowSource.put(field.getKey(), field.getValue().getValue()); - } - if (formatType.equalsIgnoreCase(Format.JDBC.getFormatName())) { - dateFieldFormatter.applyJDBCDateFormat(rowSource); - } - result = flatNestedField(newKeys, rowSource, hit.getInnerHits()); - } else { - if (formatType.equalsIgnoreCase(Format.JDBC.getFormatName())) { - dateFieldFormatter.applyJDBCDateFormat(rowSource); - } - result = new ArrayList<>(); - result.add(new DataRows.Row(rowSource)); - } - - rows.addAll(result); - } - - return rows; - } - - private List populateRows(Aggregations aggregations) { - List rows = new ArrayList<>(); - List aggs = aggregations.asList(); - if (hasTermAggregations(aggs)) { - Terms terms = (Terms) aggs.get(0); - String field = terms.getName(); - - for (Terms.Bucket bucket : terms.getBuckets()) { - List aggRows = new ArrayList<>(); - getAggsData(bucket, aggRows, addMap(field, bucket.getKey())); - - rows.addAll(aggRows); - } - } else { - // This occurs for cases like "SELECT AVG(age) FROM bank" where we aggregate in SELECT with no GROUP BY - rows.add( - new DataRows.Row( - addNumericAggregation(aggs, new HashMap<>()) - ) - ); - } - return rows; - } - - /** - * This recursive method goes through the buckets iterated through populateRows() and flattens any inner - * aggregations and puts that data as a Map into a Row (this nested aggregation happens when we GROUP BY - * multiple fields) - */ - private void getAggsData(Terms.Bucket bucket, List aggRows, Map data) { - List aggs = bucket.getAggregations().asList(); - if (hasTermAggregations(aggs)) { - Terms terms = (Terms) aggs.get(0); - String field = terms.getName(); - - for (Terms.Bucket innerBucket : terms.getBuckets()) { - data.put(field, innerBucket.getKey()); - getAggsData(innerBucket, aggRows, data); - data.remove(field); - } - } else { - data = addNumericAggregation(aggs, data); - aggRows.add(new DataRows.Row(new HashMap<>(data))); - } - } - - /** - * hasTermAggregations() checks for specific type of aggregation, one that contains Terms. This is the case when the - * aggregations contains the contents of a GROUP BY field. - *

- * If the aggregation contains the data for an aggregation function (ex. COUNT(*)), the items in the list will - * be of instance InternalValueCount, InternalSum, etc. (depending on the aggregation function) and will be - * considered a base case of getAggsData() which will add that data to the Row (if it exists). - */ - private boolean hasTermAggregations(List aggs) { - return !aggs.isEmpty() && aggs.get(0) instanceof Terms; - } - - /** - * Adds the contents of Aggregation (specifically the NumericMetricsAggregation.SingleValue instance) from - * bucket.aggregations into the data map - */ - private Map addNumericAggregation(List aggs, Map data) { - for (Aggregation aggregation : aggs) { - if (aggregation instanceof NumericMetricsAggregation.SingleValue) { - NumericMetricsAggregation.SingleValue singleValueAggregation = - (NumericMetricsAggregation.SingleValue) aggregation; - data.put(singleValueAggregation.getName(), !Double.isInfinite(singleValueAggregation.value()) - ? singleValueAggregation.getValueAsString() : "null"); - } else if (aggregation instanceof Percentiles) { - Percentiles percentiles = (Percentiles) aggregation; - - data.put(percentiles.getName(), StreamSupport - .stream(percentiles.spliterator(), false) - .collect( - Collectors.toMap( - Percentile::getPercent, - Percentile::getValue, - (v1, v2) -> { - throw new IllegalArgumentException( - String.format("Duplicate key for values %s and %s", v1, v2)); - }, - TreeMap::new))); - } else { - throw new SqlFeatureNotImplementedException("Aggregation type " + aggregation.getType() - + " is not yet implemented"); - } - } - - return data; - } + Integer limit = cursor.getLimit(); + long rowsLeft = rowsLeft(cursor.getFetchSize(), cursor.getLimit()); + if (rowsLeft <= 0) { + // close the cursor + String scrollId = cursor.getScrollId(); + ClearScrollResponse clearScrollResponse = + client.prepareClearScroll().addScrollId(scrollId).get(); + if (!clearScrollResponse.isSucceeded()) { + Metrics.getInstance().getNumericalMetric(MetricName.FAILED_REQ_COUNT_SYS).increment(); + LOG.error("Error closing the cursor context {} ", scrollId); + } + return; + } + + cursor.setRowsLeft(rowsLeft); + cursor.setIndexPattern(indexName); + cursor.setFieldAliasMap(fieldAliasMap()); + cursor.setColumns(columns); + this.totalHits = limit != null && limit < internalTotalHits ? limit : internalTotalHits; + } + + private long rowsLeft(Integer fetchSize, Integer limit) { + long rowsLeft = 0; + long totalHits = internalTotalHits; + if (limit != null && limit < totalHits) { + rowsLeft = limit - fetchSize; + } else { + rowsLeft = totalHits - fetchSize; + } + return rowsLeft; + } + + private List populateRows(SearchHits searchHits) { + List rows = new ArrayList<>(); + Set newKeys = new HashSet<>(head); + for (SearchHit hit : searchHits) { + Map rowSource = hit.getSourceAsMap(); + List result; + + if (!isJoinQuery()) { + // Row already flatten in source in join. And join doesn't support nested fields for now. 
+ rowSource = flatRow(head, rowSource); + rowSource.put(SCORE, hit.getScore()); + + for (Map.Entry field : hit.getFields().entrySet()) { + rowSource.put(field.getKey(), field.getValue().getValue()); + } + if (formatType.equalsIgnoreCase(Format.JDBC.getFormatName())) { + dateFieldFormatter.applyJDBCDateFormat(rowSource); + } + result = flatNestedField(newKeys, rowSource, hit.getInnerHits()); + } else { + if (formatType.equalsIgnoreCase(Format.JDBC.getFormatName())) { + dateFieldFormatter.applyJDBCDateFormat(rowSource); + } + result = new ArrayList<>(); + result.add(new DataRows.Row(rowSource)); + } + + rows.addAll(result); + } + + return rows; + } + + private List populateRows(Aggregations aggregations) { + List rows = new ArrayList<>(); + List aggs = aggregations.asList(); + if (hasTermAggregations(aggs)) { + Terms terms = (Terms) aggs.get(0); + String field = terms.getName(); + + for (Terms.Bucket bucket : terms.getBuckets()) { + List aggRows = new ArrayList<>(); + getAggsData(bucket, aggRows, addMap(field, bucket.getKey())); + + rows.addAll(aggRows); + } + } else { + // This occurs for cases like "SELECT AVG(age) FROM bank" where we aggregate in SELECT with no + // GROUP BY + rows.add(new DataRows.Row(addNumericAggregation(aggs, new HashMap<>()))); + } + return rows; + } + + /** + * This recursive method goes through the buckets iterated through populateRows() and flattens any + * inner aggregations and puts that data as a Map into a Row (this nested aggregation happens when + * we GROUP BY multiple fields) + */ + private void getAggsData( + Terms.Bucket bucket, List aggRows, Map data) { + List aggs = bucket.getAggregations().asList(); + if (hasTermAggregations(aggs)) { + Terms terms = (Terms) aggs.get(0); + String field = terms.getName(); + + for (Terms.Bucket innerBucket : terms.getBuckets()) { + data.put(field, innerBucket.getKey()); + getAggsData(innerBucket, aggRows, data); + data.remove(field); + } + } else { + data = addNumericAggregation(aggs, data); + aggRows.add(new DataRows.Row(new HashMap<>(data))); + } + } + + /** + * hasTermAggregations() checks for specific type of aggregation, one that contains Terms. This is + * the case when the aggregations contains the contents of a GROUP BY field. + * + *

If the aggregation contains the data for an aggregation function (ex. COUNT(*)), the items + * in the list will be of instance InternalValueCount, InternalSum, etc. (depending on the + * aggregation function) and will be considered a base case of getAggsData() which will add that + * data to the Row (if it exists). + */ + private boolean hasTermAggregations(List aggs) { + return !aggs.isEmpty() && aggs.get(0) instanceof Terms; + } + + /** + * Adds the contents of Aggregation (specifically the NumericMetricsAggregation.SingleValue + * instance) from bucket.aggregations into the data map + */ + private Map addNumericAggregation( + List aggs, Map data) { + for (Aggregation aggregation : aggs) { + if (aggregation instanceof NumericMetricsAggregation.SingleValue) { + NumericMetricsAggregation.SingleValue singleValueAggregation = + (NumericMetricsAggregation.SingleValue) aggregation; + data.put( + singleValueAggregation.getName(), + !Double.isInfinite(singleValueAggregation.value()) + ? singleValueAggregation.getValueAsString() + : "null"); + } else if (aggregation instanceof Percentiles) { + Percentiles percentiles = (Percentiles) aggregation; + + data.put( + percentiles.getName(), + StreamSupport.stream(percentiles.spliterator(), false) + .collect( + Collectors.toMap( + Percentile::getPercent, + Percentile::getValue, + (v1, v2) -> { + throw new IllegalArgumentException( + String.format("Duplicate key for values %s and %s", v1, v2)); + }, + TreeMap::new))); + } else { + throw new SqlFeatureNotImplementedException( + "Aggregation type " + aggregation.getType() + " is not yet implemented"); + } + } + + return data; + } /** + *

      * Simplifies the structure of row's source Map by flattening it, making the full path of an object the key
      * and the Object it refers to the value. This handles the case of regular objects since nested objects will not
      * be in hit.source but rather in hit.innerHits
@@ -741,6 +735,7 @@ private Map addNumericAggregation(List aggs, Map
      * Return:
      * flattenedRow = {comment.likes: 2}
+     * 
*/ @SuppressWarnings("unchecked") private Map flatRow(List keys, Map row) { @@ -750,31 +745,33 @@ private Map flatRow(List keys, Map row) boolean found = true; Object currentObj = row; - for (String splitKey : splitKeys) { - // This check is made to prevent Cast Exception as an ArrayList of objects can be in the sourceMap - if (!(currentObj instanceof Map)) { - found = false; - break; - } - - Map currentMap = (Map) currentObj; - if (!currentMap.containsKey(splitKey)) { - found = false; - break; - } - - currentObj = currentMap.get(splitKey); - } - - if (found) { - flattenedRow.put(key, currentObj); - } + for (String splitKey : splitKeys) { + // This check is made to prevent Cast Exception as an ArrayList of objects can be in the + // sourceMap + if (!(currentObj instanceof Map)) { + found = false; + break; } - return flattenedRow; + Map currentMap = (Map) currentObj; + if (!currentMap.containsKey(splitKey)) { + found = false; + break; + } + + currentObj = currentMap.get(splitKey); + } + + if (found) { + flattenedRow.put(key, currentObj); + } } + return flattenedRow; + } + /** + *
      * If innerHits associated with the column name exist, flatten both the inner field name and the inner rows in it.
      * 

* Sample input: @@ -792,36 +789,38 @@ private Map flatRow(List keys, Map row) * } * }] * } + *

*/ private List flatNestedField(Set newKeys, Map row, Map innerHits) { List result = new ArrayList<>(); result.add(new DataRows.Row(row)); - if (innerHits == null) { - return result; - } - - for (String colName : innerHits.keySet()) { - SearchHit[] colValue = innerHits.get(colName).getHits(); - doFlatNestedFieldName(colName, colValue, newKeys); - result = doFlatNestedFieldValue(colName, colValue, result); - } + if (innerHits == null) { + return result; + } - return result; + for (String colName : innerHits.keySet()) { + SearchHit[] colValue = innerHits.get(colName).getHits(); + doFlatNestedFieldName(colName, colValue, newKeys); + result = doFlatNestedFieldValue(colName, colValue, result); } - private void doFlatNestedFieldName(String colName, SearchHit[] colValue, Set keys) { - Map innerRow = colValue[0].getSourceAsMap(); - for (String field : innerRow.keySet()) { - String innerName = colName + "." + field; - keys.add(innerName); - } + return result; + } - keys.remove(colName); + private void doFlatNestedFieldName(String colName, SearchHit[] colValue, Set keys) { + Map innerRow = colValue[0].getSourceAsMap(); + for (String field : innerRow.keySet()) { + String innerName = colName + "." + field; + keys.add(innerName); } + keys.remove(colName); + } + /** + *
      * Do a Cartesian product between the current outer row and the inner rows using a nested loop, and remove the original outer row.
      * 

* Sample input: @@ -843,6 +842,7 @@ private void doFlatNestedFieldName(String colName, SearchHit[] colValue, Set */ private List doFlatNestedFieldValue(String colName, SearchHit[] colValue, List rows) { List result = new ArrayList<>(); @@ -851,28 +851,28 @@ private List doFlatNestedFieldValue(String colName, SearchHit[] co Map innerRow = hit.getSourceAsMap(); Map copy = new HashMap<>(); - for (String field : row.getContents().keySet()) { - copy.put(field, row.getData(field)); - } - for (String field : innerRow.keySet()) { - copy.put(colName + "." + field, innerRow.get(field)); - } - - copy.remove(colName); - result.add(new DataRows.Row(copy)); - } + for (String field : row.getContents().keySet()) { + copy.put(field, row.getData(field)); + } + for (String field : innerRow.keySet()) { + copy.put(colName + "." + field, innerRow.get(field)); } - return result; + copy.remove(colName); + result.add(new DataRows.Row(copy)); + } } - private Map addMap(String field, Object term) { - Map data = new HashMap<>(); - data.put(field, term); - return data; - } + return result; + } - private boolean isJoinQuery() { - return query instanceof JoinSelect; - } + private Map addMap(String field, Object term) { + Map data = new HashMap<>(); + data.put(field, term); + return data; + } + + private boolean isJoinQuery() { + return query instanceof JoinSelect; + } } diff --git a/legacy/src/main/java/org/opensearch/sql/legacy/executor/join/NestedLoopsElasticExecutor.java b/legacy/src/main/java/org/opensearch/sql/legacy/executor/join/NestedLoopsElasticExecutor.java index 21a9a6054f..56c5f96af5 100644 --- a/legacy/src/main/java/org/opensearch/sql/legacy/executor/join/NestedLoopsElasticExecutor.java +++ b/legacy/src/main/java/org/opensearch/sql/legacy/executor/join/NestedLoopsElasticExecutor.java @@ -3,7 +3,6 @@ * SPDX-License-Identifier: Apache-2.0 */ - package org.opensearch.sql.legacy.executor.join; import com.alibaba.druid.sql.ast.statement.SQLJoinTableSource; @@ -34,301 +33,354 @@ import org.opensearch.sql.legacy.query.join.TableInJoinRequestBuilder; import org.opensearch.sql.legacy.query.maker.Maker; -/** - * Created by Eliran on 15/9/2015. - */ +/** Created by Eliran on 15/9/2015. 
*/ public class NestedLoopsElasticExecutor extends ElasticJoinExecutor { - private static final Logger LOG = LogManager.getLogger(); + private static final Logger LOG = LogManager.getLogger(); - private final NestedLoopsElasticRequestBuilder nestedLoopsRequest; - private final Client client; + private final NestedLoopsElasticRequestBuilder nestedLoopsRequest; + private final Client client; - public NestedLoopsElasticExecutor(Client client, NestedLoopsElasticRequestBuilder nestedLoops) { - super(nestedLoops); - this.client = client; - this.nestedLoopsRequest = nestedLoops; - } + public NestedLoopsElasticExecutor(Client client, NestedLoopsElasticRequestBuilder nestedLoops) { + super(nestedLoops); + this.client = client; + this.nestedLoopsRequest = nestedLoops; + } - @Override - protected List innerRun() throws SqlParseException { - List combinedResults = new ArrayList<>(); - int totalLimit = nestedLoopsRequest.getTotalLimit(); - int multiSearchMaxSize = nestedLoopsRequest.getMultiSearchMaxSize(); - Select secondTableSelect = nestedLoopsRequest.getSecondTable().getOriginalSelect(); - Where originalSecondTableWhere = secondTableSelect.getWhere(); + @Override + protected List innerRun() throws SqlParseException { + List combinedResults = new ArrayList<>(); + int totalLimit = nestedLoopsRequest.getTotalLimit(); + int multiSearchMaxSize = nestedLoopsRequest.getMultiSearchMaxSize(); + Select secondTableSelect = nestedLoopsRequest.getSecondTable().getOriginalSelect(); + Where originalSecondTableWhere = secondTableSelect.getWhere(); - orderConditions(nestedLoopsRequest.getFirstTable().getAlias(), nestedLoopsRequest.getSecondTable().getAlias()); + orderConditions( + nestedLoopsRequest.getFirstTable().getAlias(), + nestedLoopsRequest.getSecondTable().getAlias()); + if (!BackOffRetryStrategy.isHealthy()) { + throw new IllegalStateException("Memory circuit is broken"); + } + FetchWithScrollResponse fetchWithScrollResponse = + firstFetch(this.nestedLoopsRequest.getFirstTable()); + SearchResponse firstTableResponse = fetchWithScrollResponse.getResponse(); + boolean needScrollForFirstTable = fetchWithScrollResponse.isNeedScrollForFirstTable(); + + int currentCombinedResults = 0; + boolean finishedWithFirstTable = false; + + while (totalLimit > currentCombinedResults && !finishedWithFirstTable) { + + SearchHit[] hits = firstTableResponse.getHits().getHits(); + boolean finishedMultiSearches = hits.length == 0; + int currentHitsIndex = 0; + + while (!finishedMultiSearches) { + MultiSearchRequest multiSearchRequest = + createMultiSearchRequest( + multiSearchMaxSize, + nestedLoopsRequest.getConnectedWhere(), + hits, + secondTableSelect, + originalSecondTableWhere, + currentHitsIndex); + int multiSearchSize = multiSearchRequest.requests().size(); if (!BackOffRetryStrategy.isHealthy()) { - throw new IllegalStateException("Memory circuit is broken"); + throw new IllegalStateException("Memory circuit is broken"); } - FetchWithScrollResponse fetchWithScrollResponse = firstFetch(this.nestedLoopsRequest.getFirstTable()); - SearchResponse firstTableResponse = fetchWithScrollResponse.getResponse(); - boolean needScrollForFirstTable = fetchWithScrollResponse.isNeedScrollForFirstTable(); - - int currentCombinedResults = 0; - boolean finishedWithFirstTable = false; - - while (totalLimit > currentCombinedResults && !finishedWithFirstTable) { - - SearchHit[] hits = firstTableResponse.getHits().getHits(); - boolean finishedMultiSearches = hits.length == 0; - int currentHitsIndex = 0; - - while (!finishedMultiSearches) { - 
MultiSearchRequest multiSearchRequest = createMultiSearchRequest(multiSearchMaxSize, - nestedLoopsRequest.getConnectedWhere(), hits, secondTableSelect, - originalSecondTableWhere, currentHitsIndex); - int multiSearchSize = multiSearchRequest.requests().size(); - if (!BackOffRetryStrategy.isHealthy()) { - throw new IllegalStateException("Memory circuit is broken"); - } - currentCombinedResults = combineResultsFromMultiResponses(combinedResults, totalLimit, - currentCombinedResults, hits, currentHitsIndex, multiSearchRequest); - currentHitsIndex += multiSearchSize; - finishedMultiSearches = currentHitsIndex >= hits.length - 1 || currentCombinedResults >= totalLimit; - } - - if (hits.length < MAX_RESULTS_ON_ONE_FETCH) { - needScrollForFirstTable = false; - } - - if (!finishedWithFirstTable) { - if (needScrollForFirstTable) { - if (!BackOffRetryStrategy.isHealthy()) { - throw new IllegalStateException("Memory circuit is broken"); - } - firstTableResponse = client.prepareSearchScroll(firstTableResponse.getScrollId()) - .setScroll(new TimeValue(600000)).get(); - } else { - finishedWithFirstTable = true; - } - } - + currentCombinedResults = + combineResultsFromMultiResponses( + combinedResults, + totalLimit, + currentCombinedResults, + hits, + currentHitsIndex, + multiSearchRequest); + currentHitsIndex += multiSearchSize; + finishedMultiSearches = + currentHitsIndex >= hits.length - 1 || currentCombinedResults >= totalLimit; + } + + if (hits.length < MAX_RESULTS_ON_ONE_FETCH) { + needScrollForFirstTable = false; + } + + if (!finishedWithFirstTable) { + if (needScrollForFirstTable) { + if (!BackOffRetryStrategy.isHealthy()) { + throw new IllegalStateException("Memory circuit is broken"); + } + firstTableResponse = + client + .prepareSearchScroll(firstTableResponse.getScrollId()) + .setScroll(new TimeValue(600000)) + .get(); + } else { + finishedWithFirstTable = true; } - return combinedResults; + } } - - private int combineResultsFromMultiResponses(List combinedResults, int totalLimit, - int currentCombinedResults, SearchHit[] hits, int currentIndex, - MultiSearchRequest multiSearchRequest) { - MultiSearchResponse.Item[] responses = new OpenSearchClient(client).multiSearch(multiSearchRequest); - String t1Alias = nestedLoopsRequest.getFirstTable().getAlias(); - String t2Alias = nestedLoopsRequest.getSecondTable().getAlias(); - - for (int j = 0; j < responses.length && currentCombinedResults < totalLimit; j++) { - SearchHit hitFromFirstTable = hits[currentIndex + j]; - onlyReturnedFields(hitFromFirstTable.getSourceAsMap(), - nestedLoopsRequest.getFirstTable().getReturnedFields(), - nestedLoopsRequest.getFirstTable().getOriginalSelect().isSelectAll()); - - SearchResponse multiItemResponse = responses[j].getResponse(); - - if (multiItemResponse == null) { - continue; - } - - updateMetaSearchResults(multiItemResponse); - - //todo: if responseForHit.getHits.length < responseForHit.getTotalHits(). need to fetch more! 
- SearchHits responseForHit = multiItemResponse.getHits(); - - if (responseForHit.getHits().length == 0 && nestedLoopsRequest.getJoinType() - == SQLJoinTableSource.JoinType.LEFT_OUTER_JOIN) { - SearchHit unmachedResult = createUnmachedResult(nestedLoopsRequest.getSecondTable().getReturnedFields(), - currentCombinedResults, t1Alias, t2Alias, hitFromFirstTable); - combinedResults.add(unmachedResult); - currentCombinedResults++; - continue; - } - - for (SearchHit matchedHit : responseForHit.getHits()) { - SearchHit searchHit = getMergedHit(currentCombinedResults, t1Alias, t2Alias, hitFromFirstTable, - matchedHit); - combinedResults.add(searchHit); - currentCombinedResults++; - if (currentCombinedResults >= totalLimit) { - break; - } - } - if (currentCombinedResults >= totalLimit) { - break; - } - + return combinedResults; + } + + private int combineResultsFromMultiResponses( + List combinedResults, + int totalLimit, + int currentCombinedResults, + SearchHit[] hits, + int currentIndex, + MultiSearchRequest multiSearchRequest) { + MultiSearchResponse.Item[] responses = + new OpenSearchClient(client).multiSearch(multiSearchRequest); + String t1Alias = nestedLoopsRequest.getFirstTable().getAlias(); + String t2Alias = nestedLoopsRequest.getSecondTable().getAlias(); + + for (int j = 0; j < responses.length && currentCombinedResults < totalLimit; j++) { + SearchHit hitFromFirstTable = hits[currentIndex + j]; + onlyReturnedFields( + hitFromFirstTable.getSourceAsMap(), + nestedLoopsRequest.getFirstTable().getReturnedFields(), + nestedLoopsRequest.getFirstTable().getOriginalSelect().isSelectAll()); + + SearchResponse multiItemResponse = responses[j].getResponse(); + + if (multiItemResponse == null) { + continue; + } + + updateMetaSearchResults(multiItemResponse); + + // todo: if responseForHit.getHits.length < responseForHit.getTotalHits(). need to fetch more! + SearchHits responseForHit = multiItemResponse.getHits(); + + if (responseForHit.getHits().length == 0 + && nestedLoopsRequest.getJoinType() == SQLJoinTableSource.JoinType.LEFT_OUTER_JOIN) { + SearchHit unmachedResult = + createUnmachedResult( + nestedLoopsRequest.getSecondTable().getReturnedFields(), + currentCombinedResults, + t1Alias, + t2Alias, + hitFromFirstTable); + combinedResults.add(unmachedResult); + currentCombinedResults++; + continue; + } + + for (SearchHit matchedHit : responseForHit.getHits()) { + SearchHit searchHit = + getMergedHit(currentCombinedResults, t1Alias, t2Alias, hitFromFirstTable, matchedHit); + combinedResults.add(searchHit); + currentCombinedResults++; + if (currentCombinedResults >= totalLimit) { + break; } - return currentCombinedResults; - } - - private SearchHit getMergedHit(int currentCombinedResults, String t1Alias, String t2Alias, - SearchHit hitFromFirstTable, SearchHit matchedHit) { - onlyReturnedFields(matchedHit.getSourceAsMap(), nestedLoopsRequest.getSecondTable().getReturnedFields(), - nestedLoopsRequest.getSecondTable().getOriginalSelect().isSelectAll()); - Map documentFields = new HashMap<>(); - Map metaFields = new HashMap<>(); - matchedHit.getFields().forEach((fieldName, docField) -> - (MapperService.META_FIELDS_BEFORE_7DOT8.contains(fieldName) ? 
metaFields : documentFields).put(fieldName, docField)); - SearchHit searchHit = new SearchHit(currentCombinedResults, hitFromFirstTable.getId() + "|" - + matchedHit.getId(), documentFields, metaFields); - searchHit.sourceRef(hitFromFirstTable.getSourceRef()); - searchHit.getSourceAsMap().clear(); - searchHit.getSourceAsMap().putAll(hitFromFirstTable.getSourceAsMap()); - - mergeSourceAndAddAliases(matchedHit.getSourceAsMap(), searchHit, t1Alias, t2Alias); - return searchHit; + } + if (currentCombinedResults >= totalLimit) { + break; + } } - - private MultiSearchRequest createMultiSearchRequest(int multiSearchMaxSize, Where connectedWhere, SearchHit[] hits, - Select secondTableSelect, Where originalWhere, int currentIndex) - throws SqlParseException { - MultiSearchRequest multiSearchRequest = new MultiSearchRequest(); - for (int i = currentIndex; i < currentIndex + multiSearchMaxSize && i < hits.length; i++) { - Map hitFromFirstTableAsMap = hits[i].getSourceAsMap(); - Where newWhere = Where.newInstance(); - if (originalWhere != null) { - newWhere.addWhere(originalWhere); - } - if (connectedWhere != null) { - Where connectedWhereCloned = null; - try { - connectedWhereCloned = (Where) connectedWhere.clone(); - } catch (CloneNotSupportedException e) { - e.printStackTrace(); - } - updateValuesOnWhereConditions(hitFromFirstTableAsMap, connectedWhereCloned); - newWhere.addWhere(connectedWhereCloned); - } - - -// for(Condition c : conditions){ -// Object value = deepSearchInMap(hitFromFirstTableAsMap,c.getValue().toString()); -// Condition conditionWithValue = new Condition(Where.CONN.AND,c.getName(),c.getOpear(),value); -// newWhere.addWhere(conditionWithValue); -// } - //using the 2nd table select and DefaultAction because we can't just change query on request - // (need to create lot of requests) - if (newWhere.getWheres().size() != 0) { - secondTableSelect.setWhere(newWhere); - } - DefaultQueryAction action = new DefaultQueryAction(this.client, secondTableSelect); - action.explain(); - SearchRequestBuilder secondTableRequest = action.getRequestBuilder(); - Integer secondTableHintLimit = this.nestedLoopsRequest.getSecondTable().getHintLimit(); - if (secondTableHintLimit != null && secondTableHintLimit <= MAX_RESULTS_ON_ONE_FETCH) { - secondTableRequest.setSize(secondTableHintLimit); - } - multiSearchRequest.add(secondTableRequest); + return currentCombinedResults; + } + + private SearchHit getMergedHit( + int currentCombinedResults, + String t1Alias, + String t2Alias, + SearchHit hitFromFirstTable, + SearchHit matchedHit) { + onlyReturnedFields( + matchedHit.getSourceAsMap(), + nestedLoopsRequest.getSecondTable().getReturnedFields(), + nestedLoopsRequest.getSecondTable().getOriginalSelect().isSelectAll()); + Map documentFields = new HashMap<>(); + Map metaFields = new HashMap<>(); + matchedHit + .getFields() + .forEach( + (fieldName, docField) -> + (MapperService.META_FIELDS_BEFORE_7DOT8.contains(fieldName) + ? 
metaFields + : documentFields) + .put(fieldName, docField)); + SearchHit searchHit = + new SearchHit( + currentCombinedResults, + hitFromFirstTable.getId() + "|" + matchedHit.getId(), + documentFields, + metaFields); + searchHit.sourceRef(hitFromFirstTable.getSourceRef()); + searchHit.getSourceAsMap().clear(); + searchHit.getSourceAsMap().putAll(hitFromFirstTable.getSourceAsMap()); + + mergeSourceAndAddAliases(matchedHit.getSourceAsMap(), searchHit, t1Alias, t2Alias); + return searchHit; + } + + private MultiSearchRequest createMultiSearchRequest( + int multiSearchMaxSize, + Where connectedWhere, + SearchHit[] hits, + Select secondTableSelect, + Where originalWhere, + int currentIndex) + throws SqlParseException { + MultiSearchRequest multiSearchRequest = new MultiSearchRequest(); + for (int i = currentIndex; i < currentIndex + multiSearchMaxSize && i < hits.length; i++) { + Map hitFromFirstTableAsMap = hits[i].getSourceAsMap(); + Where newWhere = Where.newInstance(); + if (originalWhere != null) { + newWhere.addWhere(originalWhere); + } + if (connectedWhere != null) { + Where connectedWhereCloned = null; + try { + connectedWhereCloned = (Where) connectedWhere.clone(); + } catch (CloneNotSupportedException e) { + e.printStackTrace(); } - return multiSearchRequest; + updateValuesOnWhereConditions(hitFromFirstTableAsMap, connectedWhereCloned); + newWhere.addWhere(connectedWhereCloned); + } + + // for(Condition c : conditions){ + // Object value = + // deepSearchInMap(hitFromFirstTableAsMap,c.getValue().toString()); + // Condition conditionWithValue = new + // Condition(Where.CONN.AND,c.getName(),c.getOpear(),value); + // newWhere.addWhere(conditionWithValue); + // } + // using the 2nd table select and DefaultAction because we can't just change query on request + // (need to create lot of requests) + if (newWhere.getWheres().size() != 0) { + secondTableSelect.setWhere(newWhere); + } + DefaultQueryAction action = new DefaultQueryAction(this.client, secondTableSelect); + action.explain(); + SearchRequestBuilder secondTableRequest = action.getRequestBuilder(); + Integer secondTableHintLimit = this.nestedLoopsRequest.getSecondTable().getHintLimit(); + if (secondTableHintLimit != null && secondTableHintLimit <= MAX_RESULTS_ON_ONE_FETCH) { + secondTableRequest.setSize(secondTableHintLimit); + } + multiSearchRequest.add(secondTableRequest); } - - private void updateValuesOnWhereConditions(Map hit, Where where) { - if (where instanceof Condition) { - Condition c = (Condition) where; - Object value = deepSearchInMap(hit, c.getValue().toString()); - if (value == null) { - value = Maker.NONE; - } - c.setValue(value); - } - for (Where innerWhere : where.getWheres()) { - updateValuesOnWhereConditions(hit, innerWhere); - } + return multiSearchRequest; + } + + private void updateValuesOnWhereConditions(Map hit, Where where) { + if (where instanceof Condition) { + Condition c = (Condition) where; + Object value = deepSearchInMap(hit, c.getValue().toString()); + if (value == null) { + value = Maker.NONE; + } + c.setValue(value); } - - private FetchWithScrollResponse firstFetch(TableInJoinRequestBuilder tableRequest) { - Integer hintLimit = tableRequest.getHintLimit(); - boolean needScrollForFirstTable = false; - SearchResponse responseWithHits; - if (hintLimit != null && hintLimit < MAX_RESULTS_ON_ONE_FETCH) { - - responseWithHits = tableRequest.getRequestBuilder().setSize(hintLimit).get(); - needScrollForFirstTable = false; - } else { - //scroll request with max. 
- responseWithHits = scrollOneTimeWithMax(client, tableRequest); - if (responseWithHits.getHits().getTotalHits() != null - && responseWithHits.getHits().getTotalHits().value < MAX_RESULTS_ON_ONE_FETCH) { - needScrollForFirstTable = true; - } - } - - updateMetaSearchResults(responseWithHits); - return new FetchWithScrollResponse(responseWithHits, needScrollForFirstTable); + for (Where innerWhere : where.getWheres()) { + updateValuesOnWhereConditions(hit, innerWhere); } - - - private void orderConditions(String t1Alias, String t2Alias) { - orderConditionRecursive(t1Alias, t2Alias, nestedLoopsRequest.getConnectedWhere()); -// Collection conditions = nestedLoopsRequest.getT1FieldToCondition().values(); -// for(Condition c : conditions){ -// //TODO: support all orders and for each OPEAR find his related OPEAR (< is > , EQ is EQ ,etc..) -// if(!c.getName().startsWith(t2Alias+".") || !c.getValue().toString().startsWith(t1Alias +".")) -// throw new RuntimeException("On NestedLoops currently only supported Ordered conditions -// t2.field2 OPEAR t1.field1) , badCondition was:" + c); -// c.setName(c.getName().replaceFirst(t2Alias+".","")); -// c.setValue(c.getValue().toString().replaceFirst(t1Alias+ ".", "")); -// } + } + + private FetchWithScrollResponse firstFetch(TableInJoinRequestBuilder tableRequest) { + Integer hintLimit = tableRequest.getHintLimit(); + boolean needScrollForFirstTable = false; + SearchResponse responseWithHits; + if (hintLimit != null && hintLimit < MAX_RESULTS_ON_ONE_FETCH) { + + responseWithHits = tableRequest.getRequestBuilder().setSize(hintLimit).get(); + needScrollForFirstTable = false; + } else { + // scroll request with max. + responseWithHits = scrollOneTimeWithMax(client, tableRequest); + if (responseWithHits.getHits().getTotalHits() != null + && responseWithHits.getHits().getTotalHits().value < MAX_RESULTS_ON_ONE_FETCH) { + needScrollForFirstTable = true; + } } - private void orderConditionRecursive(String t1Alias, String t2Alias, Where where) { - if (where == null) { - return; - } - if (where instanceof Condition) { - Condition c = (Condition) where; - if (shouldReverse(c, t1Alias, t2Alias)) { - try { - reverseOrderOfCondition(c, t1Alias, t2Alias); - return; - } catch (SqlParseException e) { - //Do nothing here to continue using original logic below. - //The condition is not changed here. - } - } - if (!c.getName().startsWith(t2Alias + ".") || !c.getValue().toString().startsWith(t1Alias + ".")) { - throw new RuntimeException("On NestedLoops currently only supported Ordered conditions " - + "(t2.field2 OPEAR t1.field1) , badCondition was:" + c); - } - c.setName(c.getName().replaceFirst(t2Alias + ".", "")); - c.setValue(c.getValue().toString().replaceFirst(t1Alias + ".", "")); - return; - } else { - for (Where innerWhere : where.getWheres()) { - orderConditionRecursive(t1Alias, t2Alias, innerWhere); - } + updateMetaSearchResults(responseWithHits); + return new FetchWithScrollResponse(responseWithHits, needScrollForFirstTable); + } + + private void orderConditions(String t1Alias, String t2Alias) { + orderConditionRecursive(t1Alias, t2Alias, nestedLoopsRequest.getConnectedWhere()); + // Collection conditions = + // nestedLoopsRequest.getT1FieldToCondition().values(); + // for(Condition c : conditions){ + // //TODO: support all orders and for each OPEAR find his related OPEAR (< is > , EQ + // is EQ ,etc..) 
+ // if(!c.getName().startsWith(t2Alias+".") || + // !c.getValue().toString().startsWith(t1Alias +".")) + // throw new RuntimeException("On NestedLoops currently only supported Ordered + // conditions + // t2.field2 OPEAR t1.field1) , badCondition was:" + c); + // c.setName(c.getName().replaceFirst(t2Alias+".","")); + // c.setValue(c.getValue().toString().replaceFirst(t1Alias+ ".", "")); + // } + } + + private void orderConditionRecursive(String t1Alias, String t2Alias, Where where) { + if (where == null) { + return; + } + if (where instanceof Condition) { + Condition c = (Condition) where; + if (shouldReverse(c, t1Alias, t2Alias)) { + try { + reverseOrderOfCondition(c, t1Alias, t2Alias); + return; + } catch (SqlParseException e) { + // Do nothing here to continue using original logic below. + // The condition is not changed here. } + } + if (!c.getName().startsWith(t2Alias + ".") + || !c.getValue().toString().startsWith(t1Alias + ".")) { + throw new RuntimeException( + "On NestedLoops currently only supported Ordered conditions " + + "(t2.field2 OPEAR t1.field1) , badCondition was:" + + c); + } + c.setName(c.getName().replaceFirst(t2Alias + ".", "")); + c.setValue(c.getValue().toString().replaceFirst(t1Alias + ".", "")); + return; + } else { + for (Where innerWhere : where.getWheres()) { + orderConditionRecursive(t1Alias, t2Alias, innerWhere); + } } - - private Boolean shouldReverse(Condition cond, String t1Alias, String t2Alias) { - return cond.getName().startsWith(t1Alias + ".") && cond.getValue().toString().startsWith(t2Alias + ".") - && cond.getOPERATOR().isSimpleOperator(); + } + + private Boolean shouldReverse(Condition cond, String t1Alias, String t2Alias) { + return cond.getName().startsWith(t1Alias + ".") + && cond.getValue().toString().startsWith(t2Alias + ".") + && cond.getOPERATOR().isSimpleOperator(); + } + + private void reverseOrderOfCondition(Condition cond, String t1Alias, String t2Alias) + throws SqlParseException { + cond.setOPERATOR(cond.getOPERATOR().simpleReverse()); + String name = cond.getName(); + cond.setName(cond.getValue().toString().replaceFirst(t2Alias + ".", "")); + cond.setValue(name.replaceFirst(t1Alias + ".", "")); + } + + private class FetchWithScrollResponse { + private SearchResponse response; + private boolean needScrollForFirstTable; + + private FetchWithScrollResponse(SearchResponse response, boolean needScrollForFirstTable) { + this.response = response; + this.needScrollForFirstTable = needScrollForFirstTable; } - private void reverseOrderOfCondition(Condition cond, String t1Alias, String t2Alias) throws SqlParseException { - cond.setOPERATOR(cond.getOPERATOR().simpleReverse()); - String name = cond.getName(); - cond.setName(cond.getValue().toString().replaceFirst(t2Alias + ".", "")); - cond.setValue(name.replaceFirst(t1Alias + ".", "")); + public SearchResponse getResponse() { + return response; } - - private class FetchWithScrollResponse { - private SearchResponse response; - private boolean needScrollForFirstTable; - - private FetchWithScrollResponse(SearchResponse response, boolean needScrollForFirstTable) { - this.response = response; - this.needScrollForFirstTable = needScrollForFirstTable; - } - - public SearchResponse getResponse() { - return response; - } - - public boolean isNeedScrollForFirstTable() { - return needScrollForFirstTable; - } - + public boolean isNeedScrollForFirstTable() { + return needScrollForFirstTable; } + } } diff --git a/legacy/src/main/java/org/opensearch/sql/legacy/executor/join/QueryPlanElasticExecutor.java 
b/legacy/src/main/java/org/opensearch/sql/legacy/executor/join/QueryPlanElasticExecutor.java index 5702d397d5..f4b2f5421d 100644 --- a/legacy/src/main/java/org/opensearch/sql/legacy/executor/join/QueryPlanElasticExecutor.java +++ b/legacy/src/main/java/org/opensearch/sql/legacy/executor/join/QueryPlanElasticExecutor.java @@ -3,7 +3,6 @@ * SPDX-License-Identifier: Apache-2.0 */ - package org.opensearch.sql.legacy.executor.join; import java.util.List; @@ -12,31 +11,30 @@ import org.opensearch.sql.legacy.query.planner.core.QueryPlanner; /** - * Executor for generic QueryPlanner execution. This executor is just acting as adaptor to integrate with - * existing framework. In future, QueryPlanner should be executed by itself and leave the response sent back - * or other post-processing logic to ElasticDefaultRestExecutor. + * Executor for generic QueryPlanner execution. This executor is just acting as adaptor to integrate + * with existing framework. In future, QueryPlanner should be executed by itself and leave the + * response sent back or other post-processing logic to ElasticDefaultRestExecutor. */ class QueryPlanElasticExecutor extends ElasticJoinExecutor { - private final QueryPlanner queryPlanner; - - QueryPlanElasticExecutor(HashJoinQueryPlanRequestBuilder request) { - super(request); - this.queryPlanner = request.plan(); - } - - @Override - protected List innerRun() { - List result = queryPlanner.execute(); - populateMetaResult(); - return result; - } - - private void populateMetaResult() { - metaResults.addTotalNumOfShards(queryPlanner.getMetaResult().getTotalNumOfShards()); - metaResults.addSuccessfulShards(queryPlanner.getMetaResult().getSuccessfulShards()); - metaResults.addFailedShards(queryPlanner.getMetaResult().getFailedShards()); - metaResults.updateTimeOut(queryPlanner.getMetaResult().isTimedOut()); - } - + private final QueryPlanner queryPlanner; + + QueryPlanElasticExecutor(HashJoinQueryPlanRequestBuilder request) { + super(request); + this.queryPlanner = request.plan(); + } + + @Override + protected List innerRun() { + List result = queryPlanner.execute(); + populateMetaResult(); + return result; + } + + private void populateMetaResult() { + metaResults.addTotalNumOfShards(queryPlanner.getMetaResult().getTotalNumOfShards()); + metaResults.addSuccessfulShards(queryPlanner.getMetaResult().getSuccessfulShards()); + metaResults.addFailedShards(queryPlanner.getMetaResult().getFailedShards()); + metaResults.updateTimeOut(queryPlanner.getMetaResult().isTimedOut()); + } } diff --git a/legacy/src/main/java/org/opensearch/sql/legacy/executor/join/SearchHitsResult.java b/legacy/src/main/java/org/opensearch/sql/legacy/executor/join/SearchHitsResult.java index 0955de9b88..10a1555874 100644 --- a/legacy/src/main/java/org/opensearch/sql/legacy/executor/join/SearchHitsResult.java +++ b/legacy/src/main/java/org/opensearch/sql/legacy/executor/join/SearchHitsResult.java @@ -3,42 +3,39 @@ * SPDX-License-Identifier: Apache-2.0 */ - package org.opensearch.sql.legacy.executor.join; import java.util.ArrayList; import java.util.List; import org.opensearch.search.SearchHit; -/** - * Created by Eliran on 28/8/2015. - */ +/** Created by Eliran on 28/8/2015. 
*/ public class SearchHitsResult { - private List searchHits; - private boolean matchedWithOtherTable; + private List searchHits; + private boolean matchedWithOtherTable; - public SearchHitsResult() { - searchHits = new ArrayList<>(); - } + public SearchHitsResult() { + searchHits = new ArrayList<>(); + } - public SearchHitsResult(List searchHits, boolean matchedWithOtherTable) { - this.searchHits = searchHits; - this.matchedWithOtherTable = matchedWithOtherTable; - } + public SearchHitsResult(List searchHits, boolean matchedWithOtherTable) { + this.searchHits = searchHits; + this.matchedWithOtherTable = matchedWithOtherTable; + } - public List getSearchHits() { - return searchHits; - } + public List getSearchHits() { + return searchHits; + } - public void setSearchHits(List searchHits) { - this.searchHits = searchHits; - } + public void setSearchHits(List searchHits) { + this.searchHits = searchHits; + } - public boolean isMatchedWithOtherTable() { - return matchedWithOtherTable; - } + public boolean isMatchedWithOtherTable() { + return matchedWithOtherTable; + } - public void setMatchedWithOtherTable(boolean matchedWithOtherTable) { - this.matchedWithOtherTable = matchedWithOtherTable; - } + public void setMatchedWithOtherTable(boolean matchedWithOtherTable) { + this.matchedWithOtherTable = matchedWithOtherTable; + } } diff --git a/legacy/src/main/java/org/opensearch/sql/legacy/expression/core/operator/ScalarOperation.java b/legacy/src/main/java/org/opensearch/sql/legacy/expression/core/operator/ScalarOperation.java index 0be4dfa786..ea2a698921 100644 --- a/legacy/src/main/java/org/opensearch/sql/legacy/expression/core/operator/ScalarOperation.java +++ b/legacy/src/main/java/org/opensearch/sql/legacy/expression/core/operator/ScalarOperation.java @@ -3,39 +3,36 @@ * SPDX-License-Identifier: Apache-2.0 */ - package org.opensearch.sql.legacy.expression.core.operator; import lombok.Getter; import lombok.RequiredArgsConstructor; -/** - * The definition of the Scalar Operation. - */ +/** The definition of the Scalar Operation. 
*/ @Getter @RequiredArgsConstructor public enum ScalarOperation { - ADD("add"), - SUBTRACT("subtract"), - MULTIPLY("multiply"), - DIVIDE("divide"), - MODULES("modules"), - ABS("abs"), - ACOS("acos"), - ASIN("asin"), - ATAN("atan"), - ATAN2("atan2"), - TAN("tan"), - CBRT("cbrt"), - CEIL("ceil"), - COS("cos"), - COSH("cosh"), - EXP("exp"), - FLOOR("floor"), - LN("ln"), - LOG("log"), - LOG2("log2"), - LOG10("log10"); + ADD("add"), + SUBTRACT("subtract"), + MULTIPLY("multiply"), + DIVIDE("divide"), + MODULES("modules"), + ABS("abs"), + ACOS("acos"), + ASIN("asin"), + ATAN("atan"), + ATAN2("atan2"), + TAN("tan"), + CBRT("cbrt"), + CEIL("ceil"), + COS("cos"), + COSH("cosh"), + EXP("exp"), + FLOOR("floor"), + LN("ln"), + LOG("log"), + LOG2("log2"), + LOG10("log10"); - private final String name; + private final String name; } diff --git a/legacy/src/main/java/org/opensearch/sql/legacy/expression/core/operator/ScalarOperator.java b/legacy/src/main/java/org/opensearch/sql/legacy/expression/core/operator/ScalarOperator.java index bfb3a75afb..c0c3360afc 100644 --- a/legacy/src/main/java/org/opensearch/sql/legacy/expression/core/operator/ScalarOperator.java +++ b/legacy/src/main/java/org/opensearch/sql/legacy/expression/core/operator/ScalarOperator.java @@ -3,26 +3,25 @@ * SPDX-License-Identifier: Apache-2.0 */ - package org.opensearch.sql.legacy.expression.core.operator; import java.util.List; import org.opensearch.sql.legacy.expression.model.ExprValue; -/** - * Scalar Operator is a function has one or more arguments and return a single value. - */ +/** Scalar Operator is a function has one or more arguments and return a single value. */ public interface ScalarOperator { - /** - * Apply the operator to the input arguments. - * @param valueList argument list. - * @return result. - */ - ExprValue apply(List valueList); + /** + * Apply the operator to the input arguments. + * + * @param valueList argument list. + * @return result. + */ + ExprValue apply(List valueList); - /** - * The name of the operator. - * @return name. - */ - String name(); + /** + * The name of the operator. + * + * @return name. 
+ */ + String name(); } diff --git a/legacy/src/main/java/org/opensearch/sql/legacy/metrics/NumericMetric.java b/legacy/src/main/java/org/opensearch/sql/legacy/metrics/NumericMetric.java index 085034bcd2..ee6d373f8f 100644 --- a/legacy/src/main/java/org/opensearch/sql/legacy/metrics/NumericMetric.java +++ b/legacy/src/main/java/org/opensearch/sql/legacy/metrics/NumericMetric.java @@ -3,40 +3,38 @@ * SPDX-License-Identifier: Apache-2.0 */ - package org.opensearch.sql.legacy.metrics; public class NumericMetric extends Metric { - private Counter counter; - - public NumericMetric(String name, Counter counter) { - super(name); - this.counter = counter; - } + private Counter counter; - public String getName() { - return super.getName(); - } + public NumericMetric(String name, Counter counter) { + super(name); + this.counter = counter; + } - public Counter getCounter() { - return counter; - } + public String getName() { + return super.getName(); + } - public void increment() { - counter.increment(); - } + public Counter getCounter() { + return counter; + } - public void increment(long n) { - counter.add(n); - } + public void increment() { + counter.increment(); + } - public T getValue() { - return counter.getValue(); - } + public void increment(long n) { + counter.add(n); + } - public void clear() { - counter.reset(); - } + public T getValue() { + return counter.getValue(); + } + public void clear() { + counter.reset(); + } } diff --git a/legacy/src/main/java/org/opensearch/sql/legacy/metrics/RollingCounter.java b/legacy/src/main/java/org/opensearch/sql/legacy/metrics/RollingCounter.java index 1c624d7ffe..c7b9ec56ec 100644 --- a/legacy/src/main/java/org/opensearch/sql/legacy/metrics/RollingCounter.java +++ b/legacy/src/main/java/org/opensearch/sql/legacy/metrics/RollingCounter.java @@ -3,7 +3,6 @@ * SPDX-License-Identifier: Apache-2.0 */ - package org.opensearch.sql.legacy.metrics; import java.time.Clock; @@ -13,87 +12,85 @@ import org.opensearch.sql.legacy.esdomain.LocalClusterState; /** - * Rolling counter. The count is refreshed every interval. In every interval the count is cumulative. + * Rolling counter. The count is refreshed every interval. In every interval the count is + * cumulative. 
*/ public class RollingCounter implements Counter { - private final long capacity; - private final long window; - private final long interval; - private final Clock clock; - private final ConcurrentSkipListMap time2CountWin; - private final LongAdder count; - - public RollingCounter() { - this( - LocalClusterState.state().getSettingValue( - Settings.Key.METRICS_ROLLING_WINDOW), - LocalClusterState.state().getSettingValue( - Settings.Key.METRICS_ROLLING_INTERVAL)); - } - - public RollingCounter(long window, long interval, Clock clock) { - this.window = window; - this.interval = interval; - this.clock = clock; - time2CountWin = new ConcurrentSkipListMap<>(); - count = new LongAdder(); - capacity = window / interval * 2; - } - - public RollingCounter(long window, long interval) { - this(window, interval, Clock.systemDefaultZone()); + private final long capacity; + private final long window; + private final long interval; + private final Clock clock; + private final ConcurrentSkipListMap time2CountWin; + private final LongAdder count; + + public RollingCounter() { + this( + LocalClusterState.state().getSettingValue(Settings.Key.METRICS_ROLLING_WINDOW), + LocalClusterState.state().getSettingValue(Settings.Key.METRICS_ROLLING_INTERVAL)); + } + + public RollingCounter(long window, long interval, Clock clock) { + this.window = window; + this.interval = interval; + this.clock = clock; + time2CountWin = new ConcurrentSkipListMap<>(); + count = new LongAdder(); + capacity = window / interval * 2; + } + + public RollingCounter(long window, long interval) { + this(window, interval, Clock.systemDefaultZone()); + } + + @Override + public void increment() { + add(1L); + } + + @Override + public void add(long n) { + trim(); + time2CountWin.compute(getKey(clock.millis()), (k, v) -> (v == null) ? n : v + n); + } + + @Override + public Long getValue() { + return getValue(getPreKey(clock.millis())); + } + + public long getValue(long key) { + Long res = time2CountWin.get(key); + if (res == null) { + return 0; } - @Override - public void increment() { - add(1L); - } + return res; + } - @Override - public void add(long n) { - trim(); - time2CountWin.compute(getKey(clock.millis()), (k, v) -> (v == null) ? 
n : v + n); - } + public long getSum() { + return count.longValue(); + } - @Override - public Long getValue() { - return getValue(getPreKey(clock.millis())); + private void trim() { + if (time2CountWin.size() > capacity) { + time2CountWin.headMap(getKey(clock.millis() - window * 1000)).clear(); } + } - public long getValue(long key) { - Long res = time2CountWin.get(key); - if (res == null) { - return 0; - } + private long getKey(long millis) { + return millis / 1000 / this.interval; + } - return res; - } - - public long getSum() { - return count.longValue(); - } + private long getPreKey(long millis) { + return getKey(millis) - 1; + } - private void trim() { - if (time2CountWin.size() > capacity) { - time2CountWin.headMap(getKey(clock.millis() - window * 1000)).clear(); - } - } - - private long getKey(long millis) { - return millis / 1000 / this.interval; - } - - private long getPreKey(long millis) { - return getKey(millis) - 1; - } - - public int size() { - return time2CountWin.size(); - } - - public void reset() { - time2CountWin.clear(); - } + public int size() { + return time2CountWin.size(); + } + public void reset() { + time2CountWin.clear(); + } } diff --git a/legacy/src/main/java/org/opensearch/sql/legacy/parser/NestedType.java b/legacy/src/main/java/org/opensearch/sql/legacy/parser/NestedType.java index d9b7886310..4deeba1309 100644 --- a/legacy/src/main/java/org/opensearch/sql/legacy/parser/NestedType.java +++ b/legacy/src/main/java/org/opensearch/sql/legacy/parser/NestedType.java @@ -3,7 +3,6 @@ * SPDX-License-Identifier: Apache-2.0 */ - package org.opensearch.sql.legacy.parser; import com.alibaba.druid.sql.ast.SQLExpr; @@ -18,111 +17,107 @@ import org.opensearch.sql.legacy.exception.SqlParseException; import org.opensearch.sql.legacy.utils.Util; -/** - * Created by Eliran on 12/11/2015. - */ +/** Created by Eliran on 12/11/2015. */ public class NestedType { - public String field; - public String path; - public Where where; - private boolean reverse; - private boolean simple; - private final BucketPath bucketPath = new BucketPath(); - - public boolean tryFillFromExpr(SQLExpr expr) throws SqlParseException { - if (!(expr instanceof SQLMethodInvokeExpr)) { - return false; - } - SQLMethodInvokeExpr method = (SQLMethodInvokeExpr) expr; - String methodNameLower = method.getMethodName().toLowerCase(); - if (!(methodNameLower.equals("nested") || methodNameLower.equals("reverse_nested"))) { - return false; - } + public String field; + public String path; + public Where where; + private boolean reverse; + private boolean simple; + private final BucketPath bucketPath = new BucketPath(); + + public boolean tryFillFromExpr(SQLExpr expr) throws SqlParseException { + if (!(expr instanceof SQLMethodInvokeExpr)) { + return false; + } + SQLMethodInvokeExpr method = (SQLMethodInvokeExpr) expr; + String methodNameLower = method.getMethodName().toLowerCase(); + if (!(methodNameLower.equals("nested") || methodNameLower.equals("reverse_nested"))) { + return false; + } - reverse = methodNameLower.equals("reverse_nested"); + reverse = methodNameLower.equals("reverse_nested"); - List parameters = method.getParameters(); - if (parameters.size() != 2 && parameters.size() != 1) { - throw new IllegalArgumentException("on nested object only allowed 2 parameters " - + "(field,path)/(path,conditions..) 
or 1 parameter (field) "); - } + List parameters = method.getParameters(); + if (parameters.size() != 2 && parameters.size() != 1) { + throw new IllegalArgumentException( + "on nested object only allowed 2 parameters " + + "(field,path)/(path,conditions..) or 1 parameter (field) "); + } - String field = Util.extendedToString(parameters.get(0)); - this.field = field; - if (parameters.size() == 1) { - //calc path myself.. - if (!field.contains(".")) { - if (!reverse) { - throw new IllegalArgumentException("Illegal nested field name: " + field); - } else { - this.path = null; - this.simple = true; - } - } else { - int lastDot = field.lastIndexOf("."); - this.path = field.substring(0, lastDot); - this.simple = true; - - } - - } else if (parameters.size() == 2) { - SQLExpr secondParameter = parameters.get(1); - if (secondParameter instanceof SQLTextLiteralExpr || secondParameter instanceof SQLIdentifierExpr - || secondParameter instanceof SQLPropertyExpr) { - - String pathString = Util.extendedToString(secondParameter); - if (pathString.equals("")) { - this.path = null; - } else { - this.path = pathString; - } - this.simple = true; - } else { - this.path = field; - Where where = Where.newInstance(); - new WhereParser(new SqlParser()).parseWhere(secondParameter, where); - if (where.getWheres().size() == 0) { - throw new SqlParseException("Failed to parse filter condition"); - } - this.where = where; - simple = false; - } + String field = Util.extendedToString(parameters.get(0)); + this.field = field; + if (parameters.size() == 1) { + // calc path myself.. + if (!field.contains(".")) { + if (!reverse) { + throw new IllegalArgumentException("Illegal nested field name: " + field); + } else { + this.path = null; + this.simple = true; } - - return true; + } else { + int lastDot = field.lastIndexOf("."); + this.path = field.substring(0, lastDot); + this.simple = true; + } + + } else if (parameters.size() == 2) { + SQLExpr secondParameter = parameters.get(1); + if (secondParameter instanceof SQLTextLiteralExpr + || secondParameter instanceof SQLIdentifierExpr + || secondParameter instanceof SQLPropertyExpr) { + + String pathString = Util.extendedToString(secondParameter); + if (pathString.equals("")) { + this.path = null; + } else { + this.path = pathString; + } + this.simple = true; + } else { + this.path = field; + Where where = Where.newInstance(); + new WhereParser(new SqlParser()).parseWhere(secondParameter, where); + if (where.getWheres().size() == 0) { + throw new SqlParseException("Failed to parse filter condition"); + } + this.where = where; + simple = false; + } } - public boolean isSimple() { - return simple; - } + return true; + } - public boolean isReverse() { - return reverse; - } + public boolean isSimple() { + return simple; + } - /** - * Return the name of the Nested Aggregation. - */ - public String getNestedAggName() { - return field + "@NESTED"; - } + public boolean isReverse() { + return reverse; + } - /** - * Return the name of the Filter Aggregation - */ - public String getFilterAggName() { - return field + "@FILTER"; - } + /** Return the name of the Nested Aggregation. 
*/ + public String getNestedAggName() { + return field + "@NESTED"; + } - public void addBucketPath(Path path) { - bucketPath.add(path); - } + /** Return the name of the Filter Aggregation */ + public String getFilterAggName() { + return field + "@FILTER"; + } - public String getBucketPath() { - return bucketPath.getBucketPath(); - } + public void addBucketPath(Path path) { + bucketPath.add(path); + } + + public String getBucketPath() { + return bucketPath.getBucketPath(); + } /** + *

      * Return true if the field is the nested field.
      * For example, the mapping
      * {
@@ -138,6 +133,7 @@ public String getBucketPath() {
      * 

      * If the field is projects, return true.
      * If the field is projects.name, return false.
+     *

*/ public boolean isNestedField() { return !field.contains(".") && field.equalsIgnoreCase(path); diff --git a/legacy/src/main/java/org/opensearch/sql/legacy/parser/ScriptFilter.java b/legacy/src/main/java/org/opensearch/sql/legacy/parser/ScriptFilter.java index 3eb4fecf67..3f9b12ca84 100644 --- a/legacy/src/main/java/org/opensearch/sql/legacy/parser/ScriptFilter.java +++ b/legacy/src/main/java/org/opensearch/sql/legacy/parser/ScriptFilter.java @@ -3,7 +3,6 @@ * SPDX-License-Identifier: Apache-2.0 */ - package org.opensearch.sql.legacy.parser; import com.alibaba.druid.sql.ast.SQLExpr; @@ -16,96 +15,92 @@ import org.opensearch.sql.legacy.exception.SqlParseException; import org.opensearch.sql.legacy.utils.Util; -/** - * Created by Eliran on 11/12/2015. - */ +/** Created by Eliran on 11/12/2015. */ public class ScriptFilter { - private String script; - private Map args; - private ScriptType scriptType; + private String script; + private Map args; + private ScriptType scriptType; - public ScriptFilter() { + public ScriptFilter() { - args = null; - scriptType = ScriptType.INLINE; - } + args = null; + scriptType = ScriptType.INLINE; + } - public ScriptFilter(String script, Map args, ScriptType scriptType) { - this.script = script; - this.args = args; - this.scriptType = scriptType; - } + public ScriptFilter(String script, Map args, ScriptType scriptType) { + this.script = script; + this.args = args; + this.scriptType = scriptType; + } - public boolean tryParseFromMethodExpr(SQLMethodInvokeExpr expr) throws SqlParseException { - if (!expr.getMethodName().toLowerCase().equals("script")) { - return false; - } - List methodParameters = expr.getParameters(); - if (methodParameters.size() == 0) { - return false; - } - script = Util.extendedToString(methodParameters.get(0)); - - if (methodParameters.size() == 1) { - return true; - } - - args = new HashMap<>(); - for (int i = 1; i < methodParameters.size(); i++) { - - SQLExpr innerExpr = methodParameters.get(i); - if (!(innerExpr instanceof SQLBinaryOpExpr)) { - return false; - } - SQLBinaryOpExpr binaryOpExpr = (SQLBinaryOpExpr) innerExpr; - if (!binaryOpExpr.getOperator().getName().equals("=")) { - return false; - } - - SQLExpr right = binaryOpExpr.getRight(); - Object value = Util.expr2Object(right); - String key = Util.extendedToString(binaryOpExpr.getLeft()); - if (key.equals("script_type")) { - parseAndUpdateScriptType(value.toString()); - } else { - args.put(key, value); - } - - } - return true; + public boolean tryParseFromMethodExpr(SQLMethodInvokeExpr expr) throws SqlParseException { + if (!expr.getMethodName().toLowerCase().equals("script")) { + return false; } - - private void parseAndUpdateScriptType(String scriptType) { - String scriptTypeUpper = scriptType.toUpperCase(); - switch (scriptTypeUpper) { - case "INLINE": - this.scriptType = ScriptType.INLINE; - break; - case "INDEXED": - case "STORED": - this.scriptType = ScriptType.STORED; - break; - } + List methodParameters = expr.getParameters(); + if (methodParameters.size() == 0) { + return false; } + script = Util.extendedToString(methodParameters.get(0)); - public boolean containsParameters() { - return args != null && args.size() > 0; + if (methodParameters.size() == 1) { + return true; } - public String getScript() { - return script; + args = new HashMap<>(); + for (int i = 1; i < methodParameters.size(); i++) { + + SQLExpr innerExpr = methodParameters.get(i); + if (!(innerExpr instanceof SQLBinaryOpExpr)) { + return false; + } + SQLBinaryOpExpr binaryOpExpr = (SQLBinaryOpExpr) 
innerExpr; + if (!binaryOpExpr.getOperator().getName().equals("=")) { + return false; + } + + SQLExpr right = binaryOpExpr.getRight(); + Object value = Util.expr2Object(right); + String key = Util.extendedToString(binaryOpExpr.getLeft()); + if (key.equals("script_type")) { + parseAndUpdateScriptType(value.toString()); + } else { + args.put(key, value); + } } - - public ScriptType getScriptType() { - return scriptType; + return true; + } + + private void parseAndUpdateScriptType(String scriptType) { + String scriptTypeUpper = scriptType.toUpperCase(); + switch (scriptTypeUpper) { + case "INLINE": + this.scriptType = ScriptType.INLINE; + break; + case "INDEXED": + case "STORED": + this.scriptType = ScriptType.STORED; + break; } + } - public Map getArgs() { - return args; - } + public boolean containsParameters() { + return args != null && args.size() > 0; + } - public void setArgs(Map args) { - this.args = args; - } + public String getScript() { + return script; + } + + public ScriptType getScriptType() { + return scriptType; + } + + public Map getArgs() { + return args; + } + public void setArgs(Map args) { + this.args = args; + } } diff --git a/legacy/src/main/java/org/opensearch/sql/legacy/parser/SelectParser.java b/legacy/src/main/java/org/opensearch/sql/legacy/parser/SelectParser.java index 85becdaa53..62a63b320f 100644 --- a/legacy/src/main/java/org/opensearch/sql/legacy/parser/SelectParser.java +++ b/legacy/src/main/java/org/opensearch/sql/legacy/parser/SelectParser.java @@ -3,11 +3,7 @@ * SPDX-License-Identifier: Apache-2.0 */ - package org.opensearch.sql.legacy.parser; -/** - * Created by allwefantasy on 9/2/16. - */ -public class SelectParser { -} +/** Created by allwefantasy on 9/2/16. */ +public class SelectParser {} diff --git a/legacy/src/main/java/org/opensearch/sql/legacy/plugin/RestSQLQueryAction.java b/legacy/src/main/java/org/opensearch/sql/legacy/plugin/RestSQLQueryAction.java index cd8056aed1..12176d4fa7 100644 --- a/legacy/src/main/java/org/opensearch/sql/legacy/plugin/RestSQLQueryAction.java +++ b/legacy/src/main/java/org/opensearch/sql/legacy/plugin/RestSQLQueryAction.java @@ -3,7 +3,6 @@ * SPDX-License-Identifier: Apache-2.0 */ - package org.opensearch.sql.legacy.plugin; import static org.opensearch.core.rest.RestStatus.OK; @@ -42,8 +41,10 @@ /** * New SQL REST action handler. This will not be registered to OpenSearch unless: - * 1) we want to test new SQL engine; - * 2) all old functionalities migrated to new query engine and legacy REST handler removed. + *
    + *
+ *   1. we want to test new SQL engine;
+ *   2. all old functionalities migrated to new query engine and legacy REST handler removed.
*/ public class RestSQLQueryAction extends BaseRestHandler { @@ -53,9 +54,7 @@ public class RestSQLQueryAction extends BaseRestHandler { private final Injector injector; - /** - * Constructor of RestSQLQueryAction. - */ + /** Constructor of RestSQLQueryAction. */ public RestSQLQueryAction(Injector injector) { super(); this.injector = injector; @@ -105,7 +104,7 @@ public RestChannelConsumer prepareRequest( fallbackHandler)); } // If close request, sqlService.closeCursor - else { + else { return channel -> sqlService.execute( request, @@ -123,8 +122,7 @@ private ResponseListener fallBackListener( return new ResponseListener() { @Override public void onResponse(T response) { - LOG.info("[{}] Request is handled by new SQL query engine", - QueryContext.getRequestId()); + LOG.info("[{}] Request is handled by new SQL query engine", QueryContext.getRequestId()); next.onResponse(response); } @@ -144,12 +142,13 @@ private ResponseListener createExplainResponseListener( return new ResponseListener<>() { @Override public void onResponse(ExplainResponse response) { - JsonResponseFormatter formatter = new JsonResponseFormatter<>(PRETTY) { - @Override - protected Object buildJsonObject(ExplainResponse response) { - return response; - } - }; + JsonResponseFormatter formatter = + new JsonResponseFormatter<>(PRETTY) { + @Override + protected Object buildJsonObject(ExplainResponse response) { + return response; + } + }; sendResponse(channel, OK, formatter.format(response), formatter.contentType()); } @@ -179,9 +178,12 @@ private ResponseListener createQueryResponseListener( return new ResponseListener() { @Override public void onResponse(QueryResponse response) { - sendResponse(channel, OK, - formatter.format(new QueryResult(response.getSchema(), response.getResults(), - response.getCursor())), formatter.contentType()); + sendResponse( + channel, + OK, + formatter.format( + new QueryResult(response.getSchema(), response.getResults(), response.getCursor())), + formatter.contentType()); } @Override @@ -191,9 +193,9 @@ public void onFailure(Exception e) { }; } - private void sendResponse(RestChannel channel, RestStatus status, String content, String contentType) { - channel.sendResponse(new BytesRestResponse( - status, contentType, content)); + private void sendResponse( + RestChannel channel, RestStatus status, String content, String contentType) { + channel.sendResponse(new BytesRestResponse(status, contentType, content)); } private static void logAndPublishMetrics(Exception e) { diff --git a/legacy/src/main/java/org/opensearch/sql/legacy/plugin/RestSqlAction.java b/legacy/src/main/java/org/opensearch/sql/legacy/plugin/RestSqlAction.java index 69ed469fed..fc8934dd73 100644 --- a/legacy/src/main/java/org/opensearch/sql/legacy/plugin/RestSqlAction.java +++ b/legacy/src/main/java/org/opensearch/sql/legacy/plugin/RestSqlAction.java @@ -3,7 +3,6 @@ * SPDX-License-Identifier: Apache-2.0 */ - package org.opensearch.sql.legacy.plugin; import static org.opensearch.core.rest.RestStatus.BAD_REQUEST; @@ -67,233 +66,263 @@ public class RestSqlAction extends BaseRestHandler { - private static final Logger LOG = LogManager.getLogger(RestSqlAction.class); - - private final boolean allowExplicitIndex; - - private static final Predicate CONTAINS_SUBQUERY = Pattern.compile("\\(\\s*select ").asPredicate(); - - /** - * API endpoint path - */ - public static final String QUERY_API_ENDPOINT = "/_plugins/_sql"; - public static final String EXPLAIN_API_ENDPOINT = QUERY_API_ENDPOINT + "/_explain"; - public static final String 
CURSOR_CLOSE_ENDPOINT = QUERY_API_ENDPOINT + "/close"; - public static final String LEGACY_QUERY_API_ENDPOINT = "/_opendistro/_sql"; - public static final String LEGACY_EXPLAIN_API_ENDPOINT = LEGACY_QUERY_API_ENDPOINT + "/_explain"; - public static final String LEGACY_CURSOR_CLOSE_ENDPOINT = LEGACY_QUERY_API_ENDPOINT + "/close"; - - /** - * New SQL query request handler. - */ - private final RestSQLQueryAction newSqlQueryHandler; - - public RestSqlAction(Settings settings, Injector injector) { - super(); - this.allowExplicitIndex = MULTI_ALLOW_EXPLICIT_INDEX.get(settings); - this.newSqlQueryHandler = new RestSQLQueryAction(injector); - } - - @Override - public List routes() { - return ImmutableList.of(); - } - - @Override - public List replacedRoutes() { - return ImmutableList.of( - new ReplacedRoute( - RestRequest.Method.POST, QUERY_API_ENDPOINT, - RestRequest.Method.POST, LEGACY_QUERY_API_ENDPOINT), - new ReplacedRoute( - RestRequest.Method.POST, EXPLAIN_API_ENDPOINT, - RestRequest.Method.POST, LEGACY_EXPLAIN_API_ENDPOINT), - new ReplacedRoute( - RestRequest.Method.POST, CURSOR_CLOSE_ENDPOINT, - RestRequest.Method.POST, LEGACY_CURSOR_CLOSE_ENDPOINT)); - } - - @Override - public String getName() { - return "sql_action"; - } - - @Override - protected RestChannelConsumer prepareRequest(RestRequest request, NodeClient client) { - Metrics.getInstance().getNumericalMetric(MetricName.REQ_TOTAL).increment(); - Metrics.getInstance().getNumericalMetric(MetricName.REQ_COUNT_TOTAL).increment(); - - QueryContext.addRequestId(); - - try { - if (!isSQLFeatureEnabled()) { - throw new SQLFeatureDisabledException( - "Either plugins.sql.enabled or rest.action.multi.allow_explicit_index setting is false" - ); - } - - final SqlRequest sqlRequest = SqlRequestFactory.getSqlRequest(request); - if (isLegacyCursor(sqlRequest)) { - if (isExplainRequest(request)) { - throw new IllegalArgumentException("Invalid request. Cannot explain cursor"); - } else { - LOG.info("[{}] Cursor request {}: {}", QueryContext.getRequestId(), request.uri(), sqlRequest.cursor()); - return channel -> handleCursorRequest(request, sqlRequest.cursor(), client, channel); - } - } - - LOG.info("[{}] Incoming request {}", QueryContext.getRequestId(), request.uri()); - - Format format = SqlRequestParam.getFormat(request.params()); - - // Route request to new query engine if it's supported already - SQLQueryRequest newSqlRequest = new SQLQueryRequest(sqlRequest.getJsonContent(), - sqlRequest.getSql(), request.path(), request.params(), sqlRequest.cursor()); - return newSqlQueryHandler.prepareRequest(newSqlRequest, - (restChannel, exception) -> { - try{ - if (newSqlRequest.isExplainRequest()) { - LOG.info("Request is falling back to old SQL engine due to: " + exception.getMessage()); - } - LOG.info("[{}] Request {} is not supported and falling back to old SQL engine", - QueryContext.getRequestId(), newSqlRequest); - LOG.info("Request Query: {}", QueryDataAnonymizer.anonymizeData(sqlRequest.getSql())); - QueryAction queryAction = explainRequest(client, sqlRequest, format); - executeSqlRequest(request, queryAction, client, restChannel); - } catch (Exception e) { - logAndPublishMetrics(e); - reportError(restChannel, e, isClientError(e) ? BAD_REQUEST : SERVICE_UNAVAILABLE); - } - }, - (restChannel, exception) -> { - logAndPublishMetrics(exception); - reportError(restChannel, exception, isClientError(exception) ? 
- BAD_REQUEST : SERVICE_UNAVAILABLE); - }); - } catch (Exception e) { - logAndPublishMetrics(e); - return channel -> reportError(channel, e, isClientError(e) ? BAD_REQUEST : SERVICE_UNAVAILABLE); - } - } - - - /** - * @param sqlRequest client request - * @return true if this cursor was generated by the legacy engine, false otherwise. - */ - private static boolean isLegacyCursor(SqlRequest sqlRequest) { - String cursor = sqlRequest.cursor(); - return cursor != null - && CursorType.getById(cursor.substring(0, 1)) != CursorType.NULL; - } - - @Override - protected Set responseParams() { - Set responseParams = new HashSet<>(super.responseParams()); - responseParams.addAll(Arrays.asList("sql", "flat", "separator", "_score", "_type", "_id", "newLine", "format", "sanitize")); - return responseParams; - } - - private void handleCursorRequest(final RestRequest request, final String cursor, final Client client, - final RestChannel channel) throws Exception { - CursorAsyncRestExecutor cursorRestExecutor = CursorActionRequestRestExecutorFactory.createExecutor( - request, cursor, SqlRequestParam.getFormat(request.params())); - cursorRestExecutor.execute(client, request.params(), channel); - } - - private static void logAndPublishMetrics(final Exception e) { - if (isClientError(e)) { - LOG.error(QueryContext.getRequestId() + " Client side error during query execution", e); - Metrics.getInstance().getNumericalMetric(MetricName.FAILED_REQ_COUNT_CUS).increment(); - } else { - LOG.error(QueryContext.getRequestId() + " Server side error during query execution", e); - Metrics.getInstance().getNumericalMetric(MetricName.FAILED_REQ_COUNT_SYS).increment(); - } - } - - private static QueryAction explainRequest(final NodeClient client, final SqlRequest sqlRequest, Format format) - throws SQLFeatureNotSupportedException, SqlParseException, SQLFeatureDisabledException { - - ColumnTypeProvider typeProvider = performAnalysis(sqlRequest.getSql()); - - final QueryAction queryAction = new SearchDao(client) - .explain(new QueryActionRequest(sqlRequest.getSql(), typeProvider, format)); - queryAction.setSqlRequest(sqlRequest); - queryAction.setFormat(format); - queryAction.setColumnTypeProvider(typeProvider); - return queryAction; - } - - private void executeSqlRequest(final RestRequest request, final QueryAction queryAction, final Client client, - final RestChannel channel) throws Exception { - Map params = request.params(); + private static final Logger LOG = LogManager.getLogger(RestSqlAction.class); + + private final boolean allowExplicitIndex; + + private static final Predicate CONTAINS_SUBQUERY = + Pattern.compile("\\(\\s*select ").asPredicate(); + + /** API endpoint path */ + public static final String QUERY_API_ENDPOINT = "/_plugins/_sql"; + + public static final String EXPLAIN_API_ENDPOINT = QUERY_API_ENDPOINT + "/_explain"; + public static final String CURSOR_CLOSE_ENDPOINT = QUERY_API_ENDPOINT + "/close"; + public static final String LEGACY_QUERY_API_ENDPOINT = "/_opendistro/_sql"; + public static final String LEGACY_EXPLAIN_API_ENDPOINT = LEGACY_QUERY_API_ENDPOINT + "/_explain"; + public static final String LEGACY_CURSOR_CLOSE_ENDPOINT = LEGACY_QUERY_API_ENDPOINT + "/close"; + + /** New SQL query request handler. 
*/ + private final RestSQLQueryAction newSqlQueryHandler; + + public RestSqlAction(Settings settings, Injector injector) { + super(); + this.allowExplicitIndex = MULTI_ALLOW_EXPLICIT_INDEX.get(settings); + this.newSqlQueryHandler = new RestSQLQueryAction(injector); + } + + @Override + public List routes() { + return ImmutableList.of(); + } + + @Override + public List replacedRoutes() { + return ImmutableList.of( + new ReplacedRoute( + RestRequest.Method.POST, QUERY_API_ENDPOINT, + RestRequest.Method.POST, LEGACY_QUERY_API_ENDPOINT), + new ReplacedRoute( + RestRequest.Method.POST, EXPLAIN_API_ENDPOINT, + RestRequest.Method.POST, LEGACY_EXPLAIN_API_ENDPOINT), + new ReplacedRoute( + RestRequest.Method.POST, CURSOR_CLOSE_ENDPOINT, + RestRequest.Method.POST, LEGACY_CURSOR_CLOSE_ENDPOINT)); + } + + @Override + public String getName() { + return "sql_action"; + } + + @Override + protected RestChannelConsumer prepareRequest(RestRequest request, NodeClient client) { + Metrics.getInstance().getNumericalMetric(MetricName.REQ_TOTAL).increment(); + Metrics.getInstance().getNumericalMetric(MetricName.REQ_COUNT_TOTAL).increment(); + + QueryContext.addRequestId(); + + try { + if (!isSQLFeatureEnabled()) { + throw new SQLFeatureDisabledException( + "Either plugins.sql.enabled or rest.action.multi.allow_explicit_index setting is" + + " false"); + } + + final SqlRequest sqlRequest = SqlRequestFactory.getSqlRequest(request); + if (isLegacyCursor(sqlRequest)) { if (isExplainRequest(request)) { - final String jsonExplanation = queryAction.explain().explain(); - String result; - if (SqlRequestParam.isPrettyFormat(params)) { - result = JsonPrettyFormatter.format(jsonExplanation); - } else { - result = jsonExplanation; - } - channel.sendResponse(new BytesRestResponse(OK, "application/json; charset=UTF-8", result)); + throw new IllegalArgumentException("Invalid request. 
Cannot explain cursor"); } else { - RestExecutor restExecutor = ActionRequestRestExecutorFactory.createExecutor( - SqlRequestParam.getFormat(params), - queryAction); - //doing this hack because OpenSearch throws exception for un-consumed props - Map additionalParams = new HashMap<>(); - for (String paramName : responseParams()) { - if (request.hasParam(paramName)) { - additionalParams.put(paramName, request.param(paramName)); - } - } - restExecutor.execute(client, additionalParams, queryAction, channel); + LOG.info( + "[{}] Cursor request {}: {}", + QueryContext.getRequestId(), + request.uri(), + sqlRequest.cursor()); + return channel -> handleCursorRequest(request, sqlRequest.cursor(), client, channel); } + } + + LOG.info("[{}] Incoming request {}", QueryContext.getRequestId(), request.uri()); + + Format format = SqlRequestParam.getFormat(request.params()); + + // Route request to new query engine if it's supported already + SQLQueryRequest newSqlRequest = + new SQLQueryRequest( + sqlRequest.getJsonContent(), + sqlRequest.getSql(), + request.path(), + request.params(), + sqlRequest.cursor()); + return newSqlQueryHandler.prepareRequest( + newSqlRequest, + (restChannel, exception) -> { + try { + if (newSqlRequest.isExplainRequest()) { + LOG.info( + "Request is falling back to old SQL engine due to: " + exception.getMessage()); + } + LOG.info( + "[{}] Request {} is not supported and falling back to old SQL engine", + QueryContext.getRequestId(), + newSqlRequest); + LOG.info("Request Query: {}", QueryDataAnonymizer.anonymizeData(sqlRequest.getSql())); + QueryAction queryAction = explainRequest(client, sqlRequest, format); + executeSqlRequest(request, queryAction, client, restChannel); + } catch (Exception e) { + logAndPublishMetrics(e); + reportError(restChannel, e, isClientError(e) ? BAD_REQUEST : SERVICE_UNAVAILABLE); + } + }, + (restChannel, exception) -> { + logAndPublishMetrics(exception); + reportError( + restChannel, + exception, + isClientError(exception) ? BAD_REQUEST : SERVICE_UNAVAILABLE); + }); + } catch (Exception e) { + logAndPublishMetrics(e); + return channel -> + reportError(channel, e, isClientError(e) ? 
BAD_REQUEST : SERVICE_UNAVAILABLE); } - - private static boolean isExplainRequest(final RestRequest request) { - return request.path().endsWith("/_explain"); - } - - private static boolean isClientError(Exception e) { - return e instanceof NullPointerException // NPE is hard to differentiate but more likely caused by bad query - || e instanceof SqlParseException - || e instanceof ParserException - || e instanceof SQLFeatureNotSupportedException - || e instanceof SQLFeatureDisabledException - || e instanceof IllegalArgumentException - || e instanceof IndexNotFoundException - || e instanceof VerificationException - || e instanceof SqlAnalysisException - || e instanceof SyntaxCheckException - || e instanceof SemanticCheckException - || e instanceof ExpressionEvaluationException; - } - - private void sendResponse(final RestChannel channel, final String message, final RestStatus status) { - channel.sendResponse(new BytesRestResponse(status, message)); - } - - private void reportError(final RestChannel channel, final Exception e, final RestStatus status) { - sendResponse(channel, ErrorMessageFactory.createErrorMessage(e, status.getStatus()).toString(), status); - } - - private boolean isSQLFeatureEnabled() { - boolean isSqlEnabled = LocalClusterState.state().getSettingValue( - org.opensearch.sql.common.setting.Settings.Key.SQL_ENABLED); - return allowExplicitIndex && isSqlEnabled; + } + + /** + * @param sqlRequest client request + * @return true if this cursor was generated by the legacy engine, false otherwise. + */ + private static boolean isLegacyCursor(SqlRequest sqlRequest) { + String cursor = sqlRequest.cursor(); + return cursor != null && CursorType.getById(cursor.substring(0, 1)) != CursorType.NULL; + } + + @Override + protected Set responseParams() { + Set responseParams = new HashSet<>(super.responseParams()); + responseParams.addAll( + Arrays.asList( + "sql", "flat", "separator", "_score", "_type", "_id", "newLine", "format", "sanitize")); + return responseParams; + } + + private void handleCursorRequest( + final RestRequest request, + final String cursor, + final Client client, + final RestChannel channel) + throws Exception { + CursorAsyncRestExecutor cursorRestExecutor = + CursorActionRequestRestExecutorFactory.createExecutor( + request, cursor, SqlRequestParam.getFormat(request.params())); + cursorRestExecutor.execute(client, request.params(), channel); + } + + private static void logAndPublishMetrics(final Exception e) { + if (isClientError(e)) { + LOG.error(QueryContext.getRequestId() + " Client side error during query execution", e); + Metrics.getInstance().getNumericalMetric(MetricName.FAILED_REQ_COUNT_CUS).increment(); + } else { + LOG.error(QueryContext.getRequestId() + " Server side error during query execution", e); + Metrics.getInstance().getNumericalMetric(MetricName.FAILED_REQ_COUNT_SYS).increment(); } - - private static ColumnTypeProvider performAnalysis(String sql) { - LocalClusterState clusterState = LocalClusterState.state(); - SqlAnalysisConfig config = new SqlAnalysisConfig(false, false, 200); - - OpenSearchLegacySqlAnalyzer analyzer = new OpenSearchLegacySqlAnalyzer(config); - Optional outputColumnType = analyzer.analyze(sql, clusterState); - if (outputColumnType.isPresent()) { - return new ColumnTypeProvider(outputColumnType.get()); - } else { - return new ColumnTypeProvider(); + } + + private static QueryAction explainRequest( + final NodeClient client, final SqlRequest sqlRequest, Format format) + throws SQLFeatureNotSupportedException, SqlParseException, 
SQLFeatureDisabledException { + + ColumnTypeProvider typeProvider = performAnalysis(sqlRequest.getSql()); + + final QueryAction queryAction = + new SearchDao(client) + .explain(new QueryActionRequest(sqlRequest.getSql(), typeProvider, format)); + queryAction.setSqlRequest(sqlRequest); + queryAction.setFormat(format); + queryAction.setColumnTypeProvider(typeProvider); + return queryAction; + } + + private void executeSqlRequest( + final RestRequest request, + final QueryAction queryAction, + final Client client, + final RestChannel channel) + throws Exception { + Map params = request.params(); + if (isExplainRequest(request)) { + final String jsonExplanation = queryAction.explain().explain(); + String result; + if (SqlRequestParam.isPrettyFormat(params)) { + result = JsonPrettyFormatter.format(jsonExplanation); + } else { + result = jsonExplanation; + } + channel.sendResponse(new BytesRestResponse(OK, "application/json; charset=UTF-8", result)); + } else { + RestExecutor restExecutor = + ActionRequestRestExecutorFactory.createExecutor( + SqlRequestParam.getFormat(params), queryAction); + // doing this hack because OpenSearch throws exception for un-consumed props + Map additionalParams = new HashMap<>(); + for (String paramName : responseParams()) { + if (request.hasParam(paramName)) { + additionalParams.put(paramName, request.param(paramName)); } + } + restExecutor.execute(client, additionalParams, queryAction, channel); + } + } + + private static boolean isExplainRequest(final RestRequest request) { + return request.path().endsWith("/_explain"); + } + + private static boolean isClientError(Exception e) { + return e + instanceof + NullPointerException // NPE is hard to differentiate but more likely caused by bad query + || e instanceof SqlParseException + || e instanceof ParserException + || e instanceof SQLFeatureNotSupportedException + || e instanceof SQLFeatureDisabledException + || e instanceof IllegalArgumentException + || e instanceof IndexNotFoundException + || e instanceof VerificationException + || e instanceof SqlAnalysisException + || e instanceof SyntaxCheckException + || e instanceof SemanticCheckException + || e instanceof ExpressionEvaluationException; + } + + private void sendResponse( + final RestChannel channel, final String message, final RestStatus status) { + channel.sendResponse(new BytesRestResponse(status, message)); + } + + private void reportError(final RestChannel channel, final Exception e, final RestStatus status) { + sendResponse( + channel, ErrorMessageFactory.createErrorMessage(e, status.getStatus()).toString(), status); + } + + private boolean isSQLFeatureEnabled() { + boolean isSqlEnabled = + LocalClusterState.state() + .getSettingValue(org.opensearch.sql.common.setting.Settings.Key.SQL_ENABLED); + return allowExplicitIndex && isSqlEnabled; + } + + private static ColumnTypeProvider performAnalysis(String sql) { + LocalClusterState clusterState = LocalClusterState.state(); + SqlAnalysisConfig config = new SqlAnalysisConfig(false, false, 200); + + OpenSearchLegacySqlAnalyzer analyzer = new OpenSearchLegacySqlAnalyzer(config); + Optional outputColumnType = analyzer.analyze(sql, clusterState); + if (outputColumnType.isPresent()) { + return new ColumnTypeProvider(outputColumnType.get()); + } else { + return new ColumnTypeProvider(); } + } } diff --git a/legacy/src/main/java/org/opensearch/sql/legacy/plugin/RestSqlStatsAction.java b/legacy/src/main/java/org/opensearch/sql/legacy/plugin/RestSqlStatsAction.java index cf3a3e3f96..bc0f3c73b8 100644 --- 
a/legacy/src/main/java/org/opensearch/sql/legacy/plugin/RestSqlStatsAction.java +++ b/legacy/src/main/java/org/opensearch/sql/legacy/plugin/RestSqlStatsAction.java @@ -3,7 +3,6 @@ * SPDX-License-Identifier: Apache-2.0 */ - package org.opensearch.sql.legacy.plugin; import static org.opensearch.core.rest.RestStatus.SERVICE_UNAVAILABLE; @@ -27,64 +26,69 @@ import org.opensearch.sql.legacy.metrics.Metrics; /** - * Currently this interface is for node level. - * Cluster level is coming up soon. https://github.com/opendistro-for-elasticsearch/sql/issues/41 + * Currently this interface is for node level. Cluster level is coming up soon. + * https://github.com/opendistro-for-elasticsearch/sql/issues/41 */ public class RestSqlStatsAction extends BaseRestHandler { - private static final Logger LOG = LogManager.getLogger(RestSqlStatsAction.class); - - /** - * API endpoint path - */ - public static final String STATS_API_ENDPOINT = "/_plugins/_sql/stats"; - public static final String LEGACY_STATS_API_ENDPOINT = "/_opendistro/_sql/stats"; - - public RestSqlStatsAction(Settings settings, RestController restController) { - super(); - } - - @Override - public String getName() { - return "sql_stats_action"; - } - - @Override - public List routes() { - return ImmutableList.of(); - } - - @Override - public List replacedRoutes() { - return ImmutableList.of( - new ReplacedRoute( - RestRequest.Method.POST, STATS_API_ENDPOINT, - RestRequest.Method.POST, LEGACY_STATS_API_ENDPOINT), - new ReplacedRoute( - RestRequest.Method.GET, STATS_API_ENDPOINT, - RestRequest.Method.GET, LEGACY_STATS_API_ENDPOINT)); + private static final Logger LOG = LogManager.getLogger(RestSqlStatsAction.class); + + /** API endpoint path */ + public static final String STATS_API_ENDPOINT = "/_plugins/_sql/stats"; + + public static final String LEGACY_STATS_API_ENDPOINT = "/_opendistro/_sql/stats"; + + public RestSqlStatsAction(Settings settings, RestController restController) { + super(); + } + + @Override + public String getName() { + return "sql_stats_action"; + } + + @Override + public List routes() { + return ImmutableList.of(); + } + + @Override + public List replacedRoutes() { + return ImmutableList.of( + new ReplacedRoute( + RestRequest.Method.POST, STATS_API_ENDPOINT, + RestRequest.Method.POST, LEGACY_STATS_API_ENDPOINT), + new ReplacedRoute( + RestRequest.Method.GET, STATS_API_ENDPOINT, + RestRequest.Method.GET, LEGACY_STATS_API_ENDPOINT)); + } + + @Override + protected RestChannelConsumer prepareRequest(RestRequest request, NodeClient client) { + + QueryContext.addRequestId(); + + try { + return channel -> + channel.sendResponse( + new BytesRestResponse(RestStatus.OK, Metrics.getInstance().collectToJSON())); + } catch (Exception e) { + LOG.error("Failed during Query SQL STATS Action.", e); + + return channel -> + channel.sendResponse( + new BytesRestResponse( + SERVICE_UNAVAILABLE, + ErrorMessageFactory.createErrorMessage(e, SERVICE_UNAVAILABLE.getStatus()) + .toString())); } - - @Override - protected RestChannelConsumer prepareRequest(RestRequest request, NodeClient client) { - - QueryContext.addRequestId(); - - try { - return channel -> channel.sendResponse(new BytesRestResponse(RestStatus.OK, - Metrics.getInstance().collectToJSON())); - } catch (Exception e) { - LOG.error("Failed during Query SQL STATS Action.", e); - - return channel -> channel.sendResponse(new BytesRestResponse(SERVICE_UNAVAILABLE, - ErrorMessageFactory.createErrorMessage(e, SERVICE_UNAVAILABLE.getStatus()).toString())); - } - } - - @Override - protected Set 
responseParams() { - Set responseParams = new HashSet<>(super.responseParams()); - responseParams.addAll(Arrays.asList("sql", "flat", "separator", "_score", "_type", "_id", "newLine", "format", "sanitize")); - return responseParams; - } - + } + + @Override + protected Set responseParams() { + Set responseParams = new HashSet<>(super.responseParams()); + responseParams.addAll( + Arrays.asList( + "sql", "flat", "separator", "_score", "_type", "_id", "newLine", "format", "sanitize")); + return responseParams; + } } diff --git a/legacy/src/main/java/org/opensearch/sql/legacy/plugin/SearchDao.java b/legacy/src/main/java/org/opensearch/sql/legacy/plugin/SearchDao.java index a18895723c..ea4e08281c 100644 --- a/legacy/src/main/java/org/opensearch/sql/legacy/plugin/SearchDao.java +++ b/legacy/src/main/java/org/opensearch/sql/legacy/plugin/SearchDao.java @@ -3,7 +3,6 @@ * SPDX-License-Identifier: Apache-2.0 */ - package org.opensearch.sql.legacy.plugin; import java.sql.SQLFeatureNotSupportedException; @@ -16,39 +15,36 @@ import org.opensearch.sql.legacy.query.OpenSearchActionFactory; import org.opensearch.sql.legacy.query.QueryAction; - public class SearchDao { - private static final Set END_TABLE_MAP = new HashSet<>(); - - static { - END_TABLE_MAP.add("limit"); - END_TABLE_MAP.add("order"); - END_TABLE_MAP.add("where"); - END_TABLE_MAP.add("group"); - - } - - private Client client = null; - - public SearchDao(Client client) { - this.client = client; - } - - public Client getClient() { - return client; - } - - /** - * Prepare action And transform sql - * into OpenSearch ActionRequest - * - * @param queryActionRequest SQL query action request to execute. - * @return OpenSearch request - * @throws SqlParseException - */ - public QueryAction explain(QueryActionRequest queryActionRequest) - throws SqlParseException, SQLFeatureNotSupportedException, SQLFeatureDisabledException { - return OpenSearchActionFactory.create(client, queryActionRequest); - } + private static final Set END_TABLE_MAP = new HashSet<>(); + + static { + END_TABLE_MAP.add("limit"); + END_TABLE_MAP.add("order"); + END_TABLE_MAP.add("where"); + END_TABLE_MAP.add("group"); + } + + private Client client = null; + + public SearchDao(Client client) { + this.client = client; + } + + public Client getClient() { + return client; + } + + /** + * Prepare action And transform sql into OpenSearch ActionRequest + * + * @param queryActionRequest SQL query action request to execute. 
+ * @return OpenSearch request + * @throws SqlParseException + */ + public QueryAction explain(QueryActionRequest queryActionRequest) + throws SqlParseException, SQLFeatureNotSupportedException, SQLFeatureDisabledException { + return OpenSearchActionFactory.create(client, queryActionRequest); + } } diff --git a/legacy/src/main/java/org/opensearch/sql/legacy/query/OpenSearchActionFactory.java b/legacy/src/main/java/org/opensearch/sql/legacy/query/OpenSearchActionFactory.java index de7256d2cf..b9a7c9f218 100644 --- a/legacy/src/main/java/org/opensearch/sql/legacy/query/OpenSearchActionFactory.java +++ b/legacy/src/main/java/org/opensearch/sql/legacy/query/OpenSearchActionFactory.java @@ -3,7 +3,6 @@ * SPDX-License-Identifier: Apache-2.0 */ - package org.opensearch.sql.legacy.query; import static org.opensearch.sql.legacy.domain.IndexStatement.StatementType; @@ -65,188 +64,193 @@ public class OpenSearchActionFactory { - public static QueryAction create(Client client, String sql) - throws SqlParseException, SQLFeatureNotSupportedException, SQLFeatureDisabledException { - return create(client, new QueryActionRequest(sql, new ColumnTypeProvider(), Format.JSON)); - } - - /** - * Create the compatible Query object - * based on the SQL query. - * - * @param request The SQL query. - * @return Query object. - */ - public static QueryAction create(Client client, QueryActionRequest request) - throws SqlParseException, SQLFeatureNotSupportedException, SQLFeatureDisabledException { - String sql = request.getSql(); - // Remove line breaker anywhere and semicolon at the end - sql = sql.replaceAll("\\R", " ").trim(); - if (sql.endsWith(";")) { - sql = sql.substring(0, sql.length() - 1); - } - - switch (getFirstWord(sql)) { - case "SELECT": - SQLQueryExpr sqlExpr = (SQLQueryExpr) toSqlExpr(sql); - - RewriteRuleExecutor ruleExecutor = RewriteRuleExecutor.builder() - .withRule(new SQLExprParentSetterRule()) - .withRule(new OrdinalRewriterRule(sql)) - .withRule(new UnquoteIdentifierRule()) - .withRule(new TableAliasPrefixRemoveRule()) - .withRule(new SubQueryRewriteRule()) - .build(); - ruleExecutor.executeOn(sqlExpr); - sqlExpr.accept(new NestedFieldRewriter()); - - if (isMulti(sqlExpr)) { - sqlExpr.accept(new TermFieldRewriter(TermRewriterFilter.MULTI_QUERY)); - MultiQuerySelect multiSelect = - new SqlParser().parseMultiSelect((SQLUnionQuery) sqlExpr.getSubQuery().getQuery()); - return new MultiQueryAction(client, multiSelect); - } else if (isJoin(sqlExpr, sql)) { - new JoinRewriteRule(LocalClusterState.state()).rewrite(sqlExpr); - sqlExpr.accept(new TermFieldRewriter(TermRewriterFilter.JOIN)); - JoinSelect joinSelect = new SqlParser().parseJoinSelect(sqlExpr); - return OpenSearchJoinQueryActionFactory.createJoinAction(client, joinSelect); - } else { - sqlExpr.accept(new TermFieldRewriter()); - // migrate aggregation to query planner framework. 
- if (shouldMigrateToQueryPlan(sqlExpr, request.getFormat())) { - return new QueryPlanQueryAction(new QueryPlanRequestBuilder( - new BindingTupleQueryPlanner(client, sqlExpr, request.getTypeProvider()))); - } - Select select = new SqlParser().parseSelect(sqlExpr); - return handleSelect(client, select); - } - case "DELETE": - if (isSQLDeleteEnabled()) { - SQLStatementParser parser = createSqlStatementParser(sql); - SQLDeleteStatement deleteStatement = parser.parseDeleteStatement(); - Delete delete = new SqlParser().parseDelete(deleteStatement); - return new DeleteQueryAction(client, delete); - } else { - throw new SQLFeatureDisabledException( - StringUtils.format("DELETE clause is disabled by default and will be " - + "deprecated. Using the %s setting to enable it", - Settings.Key.SQL_DELETE_ENABLED.getKeyValue())); - } - case "SHOW": - IndexStatement showStatement = new IndexStatement(StatementType.SHOW, sql); - return new ShowQueryAction(client, showStatement); - case "DESCRIBE": - IndexStatement describeStatement = new IndexStatement(StatementType.DESCRIBE, sql); - return new DescribeQueryAction(client, describeStatement); - default: - throw new SQLFeatureNotSupportedException( - String.format("Query must start with SELECT, DELETE, SHOW or DESCRIBE: %s", sql)); - } - } + public static QueryAction create(Client client, String sql) + throws SqlParseException, SQLFeatureNotSupportedException, SQLFeatureDisabledException { + return create(client, new QueryActionRequest(sql, new ColumnTypeProvider(), Format.JSON)); + } - private static boolean isSQLDeleteEnabled() { - return LocalClusterState.state().getSettingValue(Settings.Key.SQL_DELETE_ENABLED); + /** + * Create the compatible Query object based on the SQL query. + * + * @param request The SQL query. + * @return Query object. + */ + public static QueryAction create(Client client, QueryActionRequest request) + throws SqlParseException, SQLFeatureNotSupportedException, SQLFeatureDisabledException { + String sql = request.getSql(); + // Remove line breaker anywhere and semicolon at the end + sql = sql.replaceAll("\\R", " ").trim(); + if (sql.endsWith(";")) { + sql = sql.substring(0, sql.length() - 1); } - private static String getFirstWord(String sql) { - int endOfFirstWord = sql.indexOf(' '); - return sql.substring(0, endOfFirstWord > 0 ? 
endOfFirstWord : sql.length()).toUpperCase(); - } + switch (getFirstWord(sql)) { + case "SELECT": + SQLQueryExpr sqlExpr = (SQLQueryExpr) toSqlExpr(sql); - private static boolean isMulti(SQLQueryExpr sqlExpr) { - return sqlExpr.getSubQuery().getQuery() instanceof SQLUnionQuery; - } + RewriteRuleExecutor ruleExecutor = + RewriteRuleExecutor.builder() + .withRule(new SQLExprParentSetterRule()) + .withRule(new OrdinalRewriterRule(sql)) + .withRule(new UnquoteIdentifierRule()) + .withRule(new TableAliasPrefixRemoveRule()) + .withRule(new SubQueryRewriteRule()) + .build(); + ruleExecutor.executeOn(sqlExpr); + sqlExpr.accept(new NestedFieldRewriter()); - private static void executeAndFillSubQuery(Client client, - SubQueryExpression subQueryExpression, - QueryAction queryAction) throws SqlParseException { - List values = new ArrayList<>(); - Object queryResult; - try { - queryResult = QueryActionElasticExecutor.executeAnyAction(client, queryAction); - } catch (Exception e) { - throw new SqlParseException("could not execute SubQuery: " + e.getMessage()); + if (isMulti(sqlExpr)) { + sqlExpr.accept(new TermFieldRewriter(TermRewriterFilter.MULTI_QUERY)); + MultiQuerySelect multiSelect = + new SqlParser().parseMultiSelect((SQLUnionQuery) sqlExpr.getSubQuery().getQuery()); + return new MultiQueryAction(client, multiSelect); + } else if (isJoin(sqlExpr, sql)) { + new JoinRewriteRule(LocalClusterState.state()).rewrite(sqlExpr); + sqlExpr.accept(new TermFieldRewriter(TermRewriterFilter.JOIN)); + JoinSelect joinSelect = new SqlParser().parseJoinSelect(sqlExpr); + return OpenSearchJoinQueryActionFactory.createJoinAction(client, joinSelect); + } else { + sqlExpr.accept(new TermFieldRewriter()); + // migrate aggregation to query planner framework. + if (shouldMigrateToQueryPlan(sqlExpr, request.getFormat())) { + return new QueryPlanQueryAction( + new QueryPlanRequestBuilder( + new BindingTupleQueryPlanner(client, sqlExpr, request.getTypeProvider()))); + } + Select select = new SqlParser().parseSelect(sqlExpr); + return handleSelect(client, select); } - - String returnField = subQueryExpression.getReturnField(); - if (queryResult instanceof SearchHits) { - SearchHits hits = (SearchHits) queryResult; - for (SearchHit hit : hits) { - values.add(ElasticResultHandler.getFieldValue(hit, returnField)); - } + case "DELETE": + if (isSQLDeleteEnabled()) { + SQLStatementParser parser = createSqlStatementParser(sql); + SQLDeleteStatement deleteStatement = parser.parseDeleteStatement(); + Delete delete = new SqlParser().parseDelete(deleteStatement); + return new DeleteQueryAction(client, delete); } else { - throw new SqlParseException("on sub queries only support queries that return Hits and not aggregations"); + throw new SQLFeatureDisabledException( + StringUtils.format( + "DELETE clause is disabled by default and will be " + + "deprecated. 
Using the %s setting to enable it", + Settings.Key.SQL_DELETE_ENABLED.getKeyValue())); } - subQueryExpression.setValues(values.toArray()); + case "SHOW": + IndexStatement showStatement = new IndexStatement(StatementType.SHOW, sql); + return new ShowQueryAction(client, showStatement); + case "DESCRIBE": + IndexStatement describeStatement = new IndexStatement(StatementType.DESCRIBE, sql); + return new DescribeQueryAction(client, describeStatement); + default: + throw new SQLFeatureNotSupportedException( + String.format("Query must start with SELECT, DELETE, SHOW or DESCRIBE: %s", sql)); } + } - private static QueryAction handleSelect(Client client, Select select) { - if (select.isAggregate) { - return new AggregationQueryAction(client, select); - } else { - return new DefaultQueryAction(client, select); - } + private static boolean isSQLDeleteEnabled() { + return LocalClusterState.state().getSettingValue(Settings.Key.SQL_DELETE_ENABLED); + } + + private static String getFirstWord(String sql) { + int endOfFirstWord = sql.indexOf(' '); + return sql.substring(0, endOfFirstWord > 0 ? endOfFirstWord : sql.length()).toUpperCase(); + } + + private static boolean isMulti(SQLQueryExpr sqlExpr) { + return sqlExpr.getSubQuery().getQuery() instanceof SQLUnionQuery; + } + + private static void executeAndFillSubQuery( + Client client, SubQueryExpression subQueryExpression, QueryAction queryAction) + throws SqlParseException { + List values = new ArrayList<>(); + Object queryResult; + try { + queryResult = QueryActionElasticExecutor.executeAnyAction(client, queryAction); + } catch (Exception e) { + throw new SqlParseException("could not execute SubQuery: " + e.getMessage()); } - private static SQLStatementParser createSqlStatementParser(String sql) { - ElasticLexer lexer = new ElasticLexer(sql); - lexer.nextToken(); - return new MySqlStatementParser(lexer); + String returnField = subQueryExpression.getReturnField(); + if (queryResult instanceof SearchHits) { + SearchHits hits = (SearchHits) queryResult; + for (SearchHit hit : hits) { + values.add(ElasticResultHandler.getFieldValue(hit, returnField)); + } + } else { + throw new SqlParseException( + "on sub queries only support queries that return Hits and not aggregations"); } + subQueryExpression.setValues(values.toArray()); + } - private static boolean isJoin(SQLQueryExpr sqlExpr, String sql) { - MySqlSelectQueryBlock query = (MySqlSelectQueryBlock) sqlExpr.getSubQuery().getQuery(); - return query.getFrom() instanceof SQLJoinTableSource - && ((SQLJoinTableSource) query.getFrom()).getJoinType() != SQLJoinTableSource.JoinType.COMMA; + private static QueryAction handleSelect(Client client, Select select) { + if (select.isAggregate) { + return new AggregationQueryAction(client, select); + } else { + return new DefaultQueryAction(client, select); } + } - @VisibleForTesting - public static boolean shouldMigrateToQueryPlan(SQLQueryExpr expr, Format format) { - // The JSON format will return the OpenSearch aggregation result, which is not supported by the QueryPlanner. 
- if (format == Format.JSON) { - return false; - } - QueryPlannerScopeDecider decider = new QueryPlannerScopeDecider(); - return decider.isInScope(expr); + private static SQLStatementParser createSqlStatementParser(String sql) { + ElasticLexer lexer = new ElasticLexer(sql); + lexer.nextToken(); + return new MySqlStatementParser(lexer); + } + + private static boolean isJoin(SQLQueryExpr sqlExpr, String sql) { + MySqlSelectQueryBlock query = (MySqlSelectQueryBlock) sqlExpr.getSubQuery().getQuery(); + return query.getFrom() instanceof SQLJoinTableSource + && ((SQLJoinTableSource) query.getFrom()).getJoinType() + != SQLJoinTableSource.JoinType.COMMA; + } + + @VisibleForTesting + public static boolean shouldMigrateToQueryPlan(SQLQueryExpr expr, Format format) { + // The JSON format will return the OpenSearch aggregation result, which is not supported by the + // QueryPlanner. + if (format == Format.JSON) { + return false; } + QueryPlannerScopeDecider decider = new QueryPlannerScopeDecider(); + return decider.isInScope(expr); + } - private static class QueryPlannerScopeDecider extends MySqlASTVisitorAdapter { - private boolean hasAggregationFunc = false; - private boolean hasNestedFunction = false; - private boolean hasGroupBy = false; - private boolean hasAllColumnExpr = false; + private static class QueryPlannerScopeDecider extends MySqlASTVisitorAdapter { + private boolean hasAggregationFunc = false; + private boolean hasNestedFunction = false; + private boolean hasGroupBy = false; + private boolean hasAllColumnExpr = false; - public boolean isInScope(SQLQueryExpr expr) { - expr.accept(this); - return !hasAllColumnExpr && !hasNestedFunction && (hasGroupBy || hasAggregationFunc); - } + public boolean isInScope(SQLQueryExpr expr) { + expr.accept(this); + return !hasAllColumnExpr && !hasNestedFunction && (hasGroupBy || hasAggregationFunc); + } - @Override - public boolean visit(SQLSelectItem expr) { - if (expr.getExpr() instanceof SQLAllColumnExpr) { - hasAllColumnExpr = true; - } - return super.visit(expr); - } + @Override + public boolean visit(SQLSelectItem expr) { + if (expr.getExpr() instanceof SQLAllColumnExpr) { + hasAllColumnExpr = true; + } + return super.visit(expr); + } - @Override - public boolean visit(SQLSelectGroupByClause expr) { - hasGroupBy = true; - return super.visit(expr); - } + @Override + public boolean visit(SQLSelectGroupByClause expr) { + hasGroupBy = true; + return super.visit(expr); + } - @Override - public boolean visit(SQLAggregateExpr expr) { - hasAggregationFunc = true; - return super.visit(expr); - } + @Override + public boolean visit(SQLAggregateExpr expr) { + hasAggregationFunc = true; + return super.visit(expr); + } - @Override - public boolean visit(SQLMethodInvokeExpr expr) { - if (expr.getMethodName().equalsIgnoreCase("nested")) { - hasNestedFunction = true; - } - return super.visit(expr); - } + @Override + public boolean visit(SQLMethodInvokeExpr expr) { + if (expr.getMethodName().equalsIgnoreCase("nested")) { + hasNestedFunction = true; + } + return super.visit(expr); } + } } diff --git a/legacy/src/main/java/org/opensearch/sql/legacy/query/QueryAction.java b/legacy/src/main/java/org/opensearch/sql/legacy/query/QueryAction.java index 7646639be4..c9b39d2f97 100644 --- a/legacy/src/main/java/org/opensearch/sql/legacy/query/QueryAction.java +++ b/legacy/src/main/java/org/opensearch/sql/legacy/query/QueryAction.java @@ -3,7 +3,6 @@ * SPDX-License-Identifier: Apache-2.0 */ - package org.opensearch.sql.legacy.query; import 
com.fasterxml.jackson.core.JsonFactory; @@ -32,199 +31,208 @@ import org.opensearch.sql.legacy.request.SqlRequest; /** - * Abstract class. used to transform Select object (Represents SQL query) to - * SearchRequestBuilder (Represents OpenSearch query) + * Abstract class. used to transform Select object (Represents SQL query) to SearchRequestBuilder + * (Represents OpenSearch query) */ public abstract class QueryAction { - protected Query query; - protected Client client; - protected SqlRequest sqlRequest = SqlRequest.NULL; - protected ColumnTypeProvider scriptColumnType; - protected Format format; - - public QueryAction(Client client, Query query) { - this.client = client; - this.query = query; - } - - public Client getClient() { - return client; - } - - public QueryStatement getQueryStatement() { - return query; - } - - public void setSqlRequest(SqlRequest sqlRequest) { - this.sqlRequest = sqlRequest; - } - - public void setColumnTypeProvider(ColumnTypeProvider scriptColumnType) { - this.scriptColumnType = scriptColumnType; - } - - public SqlRequest getSqlRequest() { - return sqlRequest; - } - - public void setFormat(Format format) { - this.format = format; - } - - public Format getFormat() { - return this.format; - } - - public ColumnTypeProvider getScriptColumnType() { - return scriptColumnType; - } - - /** - * @return List of field names produced by the query - */ - public Optional> getFieldNames() { - return Optional.empty(); - } - - protected void updateRequestWithCollapse(Select select, SearchRequestBuilder request) throws SqlParseException { - JsonFactory jsonFactory = new JsonFactory(); - for (Hint hint : select.getHints()) { - if (hint.getType() == HintType.COLLAPSE && hint.getParams() != null && 0 < hint.getParams().length) { - try (JsonXContentParser parser = new JsonXContentParser(NamedXContentRegistry.EMPTY, - LoggingDeprecationHandler.INSTANCE, jsonFactory.createParser(hint.getParams()[0].toString()))) { - request.setCollapse(CollapseBuilder.fromXContent(parser)); - } catch (IOException e) { - throw new SqlParseException("could not parse collapse hint: " + e.getMessage()); - } - } - } - } - - protected void updateRequestWithPostFilter(Select select, SearchRequestBuilder request) { - for (Hint hint : select.getHints()) { - if (hint.getType() == HintType.POST_FILTER && hint.getParams() != null && 0 < hint.getParams().length) { - request.setPostFilter(QueryBuilders.wrapperQuery(hint.getParams()[0].toString())); - } - } - } - - protected void updateRequestWithIndexAndRoutingOptions(Select select, SearchRequestBuilder request) { - for (Hint hint : select.getHints()) { - if (hint.getType() == HintType.IGNORE_UNAVAILABLE) { - //saving the defaults from TransportClient search - request.setIndicesOptions(IndicesOptions.fromOptions(true, false, true, false, - IndicesOptions.strictExpandOpenAndForbidClosed())); - } - if (hint.getType() == HintType.ROUTINGS) { - Object[] routings = hint.getParams(); - String[] routingsAsStringArray = new String[routings.length]; - for (int i = 0; i < routings.length; i++) { - routingsAsStringArray[i] = routings[i].toString(); - } - request.setRouting(routingsAsStringArray); - } - } - } - - protected void updateRequestWithHighlight(Select select, SearchRequestBuilder request) { - boolean foundAnyHighlights = false; - HighlightBuilder highlightBuilder = new HighlightBuilder(); - for (Hint hint : select.getHints()) { - if (hint.getType() == HintType.HIGHLIGHT) { - HighlightBuilder.Field highlightField = parseHighlightField(hint.getParams()); - if 
(highlightField != null) { - foundAnyHighlights = true; - highlightBuilder.field(highlightField); - } - } - } - if (foundAnyHighlights) { - request.highlighter(highlightBuilder); - } - } - - protected HighlightBuilder.Field parseHighlightField(Object[] params) { - if (params == null || params.length == 0 || params.length > 2) { - //todo: exception. + protected Query query; + protected Client client; + protected SqlRequest sqlRequest = SqlRequest.NULL; + protected ColumnTypeProvider scriptColumnType; + protected Format format; + + public QueryAction(Client client, Query query) { + this.client = client; + this.query = query; + } + + public Client getClient() { + return client; + } + + public QueryStatement getQueryStatement() { + return query; + } + + public void setSqlRequest(SqlRequest sqlRequest) { + this.sqlRequest = sqlRequest; + } + + public void setColumnTypeProvider(ColumnTypeProvider scriptColumnType) { + this.scriptColumnType = scriptColumnType; + } + + public SqlRequest getSqlRequest() { + return sqlRequest; + } + + public void setFormat(Format format) { + this.format = format; + } + + public Format getFormat() { + return this.format; + } + + public ColumnTypeProvider getScriptColumnType() { + return scriptColumnType; + } + + /** + * @return List of field names produced by the query + */ + public Optional> getFieldNames() { + return Optional.empty(); + } + + protected void updateRequestWithCollapse(Select select, SearchRequestBuilder request) + throws SqlParseException { + JsonFactory jsonFactory = new JsonFactory(); + for (Hint hint : select.getHints()) { + if (hint.getType() == HintType.COLLAPSE + && hint.getParams() != null + && 0 < hint.getParams().length) { + try (JsonXContentParser parser = + new JsonXContentParser( + NamedXContentRegistry.EMPTY, + LoggingDeprecationHandler.INSTANCE, + jsonFactory.createParser(hint.getParams()[0].toString()))) { + request.setCollapse(CollapseBuilder.fromXContent(parser)); + } catch (IOException e) { + throw new SqlParseException("could not parse collapse hint: " + e.getMessage()); } - HighlightBuilder.Field field = new HighlightBuilder.Field(params[0].toString()); - if (params.length == 1) { - return field; + } + } + } + + protected void updateRequestWithPostFilter(Select select, SearchRequestBuilder request) { + for (Hint hint : select.getHints()) { + if (hint.getType() == HintType.POST_FILTER + && hint.getParams() != null + && 0 < hint.getParams().length) { + request.setPostFilter(QueryBuilders.wrapperQuery(hint.getParams()[0].toString())); + } + } + } + + protected void updateRequestWithIndexAndRoutingOptions( + Select select, SearchRequestBuilder request) { + for (Hint hint : select.getHints()) { + if (hint.getType() == HintType.IGNORE_UNAVAILABLE) { + // saving the defaults from TransportClient search + request.setIndicesOptions( + IndicesOptions.fromOptions( + true, false, true, false, IndicesOptions.strictExpandOpenAndForbidClosed())); + } + if (hint.getType() == HintType.ROUTINGS) { + Object[] routings = hint.getParams(); + String[] routingsAsStringArray = new String[routings.length]; + for (int i = 0; i < routings.length; i++) { + routingsAsStringArray[i] = routings[i].toString(); } - Map highlightParams = (Map) params[1]; - - for (Map.Entry param : highlightParams.entrySet()) { - switch (param.getKey()) { - case "type": - field.highlighterType((String) param.getValue()); - break; - case "boundary_chars": - field.boundaryChars(fromArrayListToCharArray((ArrayList) param.getValue())); - break; - case "boundary_max_scan": - 
field.boundaryMaxScan((Integer) param.getValue()); - break; - case "force_source": - field.forceSource((Boolean) param.getValue()); - break; - case "fragmenter": - field.fragmenter((String) param.getValue()); - break; - case "fragment_offset": - field.fragmentOffset((Integer) param.getValue()); - break; - case "fragment_size": - field.fragmentSize((Integer) param.getValue()); - break; - case "highlight_filter": - field.highlightFilter((Boolean) param.getValue()); - break; - case "matched_fields": - field.matchedFields((String[]) ((ArrayList) param.getValue()).toArray(new String[0])); - break; - case "no_match_size": - field.noMatchSize((Integer) param.getValue()); - break; - case "num_of_fragments": - field.numOfFragments((Integer) param.getValue()); - break; - case "order": - field.order((String) param.getValue()); - break; - case "phrase_limit": - field.phraseLimit((Integer) param.getValue()); - break; - case "post_tags": - field.postTags((String[]) ((ArrayList) param.getValue()).toArray(new String[0])); - break; - case "pre_tags": - field.preTags((String[]) ((ArrayList) param.getValue()).toArray(new String[0])); - break; - case "require_field_match": - field.requireFieldMatch((Boolean) param.getValue()); - break; - - } + request.setRouting(routingsAsStringArray); + } + } + } + + protected void updateRequestWithHighlight(Select select, SearchRequestBuilder request) { + boolean foundAnyHighlights = false; + HighlightBuilder highlightBuilder = new HighlightBuilder(); + for (Hint hint : select.getHints()) { + if (hint.getType() == HintType.HIGHLIGHT) { + HighlightBuilder.Field highlightField = parseHighlightField(hint.getParams()); + if (highlightField != null) { + foundAnyHighlights = true; + highlightBuilder.field(highlightField); } - return field; - } - - private char[] fromArrayListToCharArray(ArrayList arrayList) { - char[] chars = new char[arrayList.size()]; - int i = 0; - for (Object item : arrayList) { - chars[i] = item.toString().charAt(0); - i++; - } - return chars; - } - - /** - * Prepare the request, and return OpenSearch request. - * - * @return ActionRequestBuilder (OpenSearch request) - * @throws SqlParseException - */ - public abstract SqlElasticRequestBuilder explain() throws SqlParseException; + } + } + if (foundAnyHighlights) { + request.highlighter(highlightBuilder); + } + } + + protected HighlightBuilder.Field parseHighlightField(Object[] params) { + if (params == null || params.length == 0 || params.length > 2) { + // todo: exception. 
+ } + HighlightBuilder.Field field = new HighlightBuilder.Field(params[0].toString()); + if (params.length == 1) { + return field; + } + Map highlightParams = (Map) params[1]; + + for (Map.Entry param : highlightParams.entrySet()) { + switch (param.getKey()) { + case "type": + field.highlighterType((String) param.getValue()); + break; + case "boundary_chars": + field.boundaryChars(fromArrayListToCharArray((ArrayList) param.getValue())); + break; + case "boundary_max_scan": + field.boundaryMaxScan((Integer) param.getValue()); + break; + case "force_source": + field.forceSource((Boolean) param.getValue()); + break; + case "fragmenter": + field.fragmenter((String) param.getValue()); + break; + case "fragment_offset": + field.fragmentOffset((Integer) param.getValue()); + break; + case "fragment_size": + field.fragmentSize((Integer) param.getValue()); + break; + case "highlight_filter": + field.highlightFilter((Boolean) param.getValue()); + break; + case "matched_fields": + field.matchedFields((String[]) ((ArrayList) param.getValue()).toArray(new String[0])); + break; + case "no_match_size": + field.noMatchSize((Integer) param.getValue()); + break; + case "num_of_fragments": + field.numOfFragments((Integer) param.getValue()); + break; + case "order": + field.order((String) param.getValue()); + break; + case "phrase_limit": + field.phraseLimit((Integer) param.getValue()); + break; + case "post_tags": + field.postTags((String[]) ((ArrayList) param.getValue()).toArray(new String[0])); + break; + case "pre_tags": + field.preTags((String[]) ((ArrayList) param.getValue()).toArray(new String[0])); + break; + case "require_field_match": + field.requireFieldMatch((Boolean) param.getValue()); + break; + } + } + return field; + } + + private char[] fromArrayListToCharArray(ArrayList arrayList) { + char[] chars = new char[arrayList.size()]; + int i = 0; + for (Object item : arrayList) { + chars[i] = item.toString().charAt(0); + i++; + } + return chars; + } + + /** + * Prepare the request, and return OpenSearch request. + * + * @return ActionRequestBuilder (OpenSearch request) + * @throws SqlParseException + */ + public abstract SqlElasticRequestBuilder explain() throws SqlParseException; } diff --git a/legacy/src/main/java/org/opensearch/sql/legacy/query/join/NestedLoopsElasticRequestBuilder.java b/legacy/src/main/java/org/opensearch/sql/legacy/query/join/NestedLoopsElasticRequestBuilder.java index c14d8f3012..9dd34c71b9 100644 --- a/legacy/src/main/java/org/opensearch/sql/legacy/query/join/NestedLoopsElasticRequestBuilder.java +++ b/legacy/src/main/java/org/opensearch/sql/legacy/query/join/NestedLoopsElasticRequestBuilder.java @@ -3,10 +3,8 @@ * SPDX-License-Identifier: Apache-2.0 */ - package org.opensearch.sql.legacy.query.join; - import java.io.IOException; import org.json.JSONObject; import org.json.JSONStringer; @@ -19,86 +17,96 @@ import org.opensearch.sql.legacy.exception.SqlParseException; import org.opensearch.sql.legacy.query.maker.QueryMaker; -/** - * Created by Eliran on 15/9/2015. - */ +/** Created by Eliran on 15/9/2015. 
*/ public class NestedLoopsElasticRequestBuilder extends JoinRequestBuilder { - private Where connectedWhere; - private int multiSearchMaxSize; + private Where connectedWhere; + private int multiSearchMaxSize; - public NestedLoopsElasticRequestBuilder() { + public NestedLoopsElasticRequestBuilder() { - multiSearchMaxSize = 100; - } + multiSearchMaxSize = 100; + } - @Override - public String explain() { - String conditions = ""; - - try { - Where where = (Where) this.connectedWhere.clone(); - setValueTypeConditionToStringRecursive(where); - if (where != null) { - conditions = QueryMaker.explain(where, false).toString(); - } - } catch (CloneNotSupportedException | SqlParseException e) { - conditions = "Could not parse conditions due to " + e.getMessage(); - } - - String desc = "Nested Loops run first query, and for each result run " - + "second query with additional conditions as following."; - String[] queries = explainNL(); - JSONStringer jsonStringer = new JSONStringer(); - jsonStringer.object().key("description").value(desc) - .key("conditions").value(new JSONObject(conditions)) - .key("first query").value(new JSONObject(queries[0])) - .key("second query").value(new JSONObject(queries[1])).endObject(); - return jsonStringer.toString(); - } + @Override + public String explain() { + String conditions = ""; - public int getMultiSearchMaxSize() { - return multiSearchMaxSize; + try { + Where where = (Where) this.connectedWhere.clone(); + setValueTypeConditionToStringRecursive(where); + if (where != null) { + conditions = QueryMaker.explain(where, false).toString(); + } + } catch (CloneNotSupportedException | SqlParseException e) { + conditions = "Could not parse conditions due to " + e.getMessage(); } - public void setMultiSearchMaxSize(int multiSearchMaxSize) { - this.multiSearchMaxSize = multiSearchMaxSize; + String desc = + "Nested Loops run first query, and for each result run " + + "second query with additional conditions as following."; + String[] queries = explainNL(); + JSONStringer jsonStringer = new JSONStringer(); + jsonStringer + .object() + .key("description") + .value(desc) + .key("conditions") + .value(new JSONObject(conditions)) + .key("first query") + .value(new JSONObject(queries[0])) + .key("second query") + .value(new JSONObject(queries[1])) + .endObject(); + return jsonStringer.toString(); + } + + public int getMultiSearchMaxSize() { + return multiSearchMaxSize; + } + + public void setMultiSearchMaxSize(int multiSearchMaxSize) { + this.multiSearchMaxSize = multiSearchMaxSize; + } + + public Where getConnectedWhere() { + return connectedWhere; + } + + public void setConnectedWhere(Where connectedWhere) { + this.connectedWhere = connectedWhere; + } + + private void setValueTypeConditionToStringRecursive(Where where) { + if (where == null) { + return; } - - public Where getConnectedWhere() { - return connectedWhere; + if (where instanceof Condition) { + Condition c = (Condition) where; + c.setValue(c.getValue().toString()); + return; + } else { + for (Where innerWhere : where.getWheres()) { + setValueTypeConditionToStringRecursive(innerWhere); + } } - - public void setConnectedWhere(Where connectedWhere) { - this.connectedWhere = connectedWhere; - } - - private void setValueTypeConditionToStringRecursive(Where where) { - if (where == null) { - return; - } - if (where instanceof Condition) { - Condition c = (Condition) where; - c.setValue(c.getValue().toString()); - return; - } else { - for (Where innerWhere : where.getWheres()) { - 
setValueTypeConditionToStringRecursive(innerWhere); - } - } - } - - private String[] explainNL() { - return new String[]{explainQuery(this.getFirstTable()), explainQuery(this.getSecondTable())}; - } - - private String explainQuery(TableInJoinRequestBuilder requestBuilder) { - try { - XContentBuilder xContentBuilder = XContentFactory.jsonBuilder().prettyPrint(); - requestBuilder.getRequestBuilder().request().source().toXContent(xContentBuilder, ToXContent.EMPTY_PARAMS); - return BytesReference.bytes(xContentBuilder).utf8ToString(); - } catch (IOException e) { - return e.getMessage(); - } + } + + private String[] explainNL() { + return new String[] {explainQuery(this.getFirstTable()), explainQuery(this.getSecondTable())}; + } + + private String explainQuery(TableInJoinRequestBuilder requestBuilder) { + try { + XContentBuilder xContentBuilder = XContentFactory.jsonBuilder().prettyPrint(); + requestBuilder + .getRequestBuilder() + .request() + .source() + .toXContent(xContentBuilder, ToXContent.EMPTY_PARAMS); + return BytesReference.bytes(xContentBuilder).utf8ToString(); + } catch (IOException e) { + return e.getMessage(); } + } } diff --git a/legacy/src/main/java/org/opensearch/sql/legacy/query/join/OpenSearchHashJoinQueryAction.java b/legacy/src/main/java/org/opensearch/sql/legacy/query/join/OpenSearchHashJoinQueryAction.java index 0a87c16067..078ed6bcce 100644 --- a/legacy/src/main/java/org/opensearch/sql/legacy/query/join/OpenSearchHashJoinQueryAction.java +++ b/legacy/src/main/java/org/opensearch/sql/legacy/query/join/OpenSearchHashJoinQueryAction.java @@ -3,7 +3,6 @@ * SPDX-License-Identifier: Apache-2.0 */ - package org.opensearch.sql.legacy.query.join; import java.util.AbstractMap; @@ -20,129 +19,126 @@ import org.opensearch.sql.legacy.exception.SqlParseException; import org.opensearch.sql.legacy.query.planner.HashJoinQueryPlanRequestBuilder; -/** - * Created by Eliran on 22/8/2015. - */ +/** Created by Eliran on 22/8/2015. 
*/ public class OpenSearchHashJoinQueryAction extends OpenSearchJoinQueryAction { - public OpenSearchHashJoinQueryAction(Client client, JoinSelect joinSelect) { - super(client, joinSelect); - } + public OpenSearchHashJoinQueryAction(Client client, JoinSelect joinSelect) { + super(client, joinSelect); + } - @Override - protected void fillSpecificRequestBuilder(JoinRequestBuilder requestBuilder) throws SqlParseException { - String t1Alias = joinSelect.getFirstTable().getAlias(); - String t2Alias = joinSelect.getSecondTable().getAlias(); + @Override + protected void fillSpecificRequestBuilder(JoinRequestBuilder requestBuilder) + throws SqlParseException { + String t1Alias = joinSelect.getFirstTable().getAlias(); + String t2Alias = joinSelect.getSecondTable().getAlias(); - List>> comparisonFields = getComparisonFields(t1Alias, t2Alias, - joinSelect.getConnectedWhere()); + List>> comparisonFields = + getComparisonFields(t1Alias, t2Alias, joinSelect.getConnectedWhere()); - ((HashJoinElasticRequestBuilder) requestBuilder).setT1ToT2FieldsComparison(comparisonFields); - } + ((HashJoinElasticRequestBuilder) requestBuilder).setT1ToT2FieldsComparison(comparisonFields); + } - @Override - protected JoinRequestBuilder createSpecificBuilder() { - if (isLegacy()) { - return new HashJoinElasticRequestBuilder(); - } - return new HashJoinQueryPlanRequestBuilder(client, sqlRequest); + @Override + protected JoinRequestBuilder createSpecificBuilder() { + if (isLegacy()) { + return new HashJoinElasticRequestBuilder(); } - - @Override - protected void updateRequestWithHints(JoinRequestBuilder requestBuilder) { - super.updateRequestWithHints(requestBuilder); - for (Hint hint : joinSelect.getHints()) { - if (hint.getType() == HintType.HASH_WITH_TERMS_FILTER) { - ((HashJoinElasticRequestBuilder) requestBuilder).setUseTermFiltersOptimization(true); - } - } + return new HashJoinQueryPlanRequestBuilder(client, sqlRequest); + } + + @Override + protected void updateRequestWithHints(JoinRequestBuilder requestBuilder) { + super.updateRequestWithHints(requestBuilder); + for (Hint hint : joinSelect.getHints()) { + if (hint.getType() == HintType.HASH_WITH_TERMS_FILTER) { + ((HashJoinElasticRequestBuilder) requestBuilder).setUseTermFiltersOptimization(true); + } } - - /** - * Keep the option to run legacy hash join algorithm mainly for the comparison - */ - private boolean isLegacy() { - for (Hint hint : joinSelect.getHints()) { - if (hint.getType() == HintType.JOIN_ALGORITHM_USE_LEGACY) { - return true; - } - } - return false; + } + + /** Keep the option to run legacy hash join algorithm mainly for the comparison */ + private boolean isLegacy() { + for (Hint hint : joinSelect.getHints()) { + if (hint.getType() == HintType.JOIN_ALGORITHM_USE_LEGACY) { + return true; + } } - - private List> getComparisonFields(String t1Alias, String t2Alias, - List connectedConditions) - throws SqlParseException { - List> comparisonFields = new ArrayList<>(); - for (Condition condition : connectedConditions) { - - if (condition.getOPERATOR() != Condition.OPERATOR.EQ) { - throw new SqlParseException( - String.format("HashJoin should only be with EQ conditions, got:%s on condition:%s", - condition.getOPERATOR().name(), condition.toString())); - } - - String firstField = condition.getName(); - String secondField = condition.getValue().toString(); - Field t1Field, t2Field; - if (firstField.startsWith(t1Alias)) { - t1Field = new Field(removeAlias(firstField, t1Alias), null); - t2Field = new Field(removeAlias(secondField, t2Alias), null); - } else { - 
t1Field = new Field(removeAlias(secondField, t1Alias), null); - t2Field = new Field(removeAlias(firstField, t2Alias), null); - } - comparisonFields.add(new AbstractMap.SimpleEntry<>(t1Field, t2Field)); - } - return comparisonFields; + return false; + } + + private List> getComparisonFields( + String t1Alias, String t2Alias, List connectedConditions) + throws SqlParseException { + List> comparisonFields = new ArrayList<>(); + for (Condition condition : connectedConditions) { + + if (condition.getOPERATOR() != Condition.OPERATOR.EQ) { + throw new SqlParseException( + String.format( + "HashJoin should only be with EQ conditions, got:%s on condition:%s", + condition.getOPERATOR().name(), condition.toString())); + } + + String firstField = condition.getName(); + String secondField = condition.getValue().toString(); + Field t1Field, t2Field; + if (firstField.startsWith(t1Alias)) { + t1Field = new Field(removeAlias(firstField, t1Alias), null); + t2Field = new Field(removeAlias(secondField, t2Alias), null); + } else { + t1Field = new Field(removeAlias(secondField, t1Alias), null); + t2Field = new Field(removeAlias(firstField, t2Alias), null); + } + comparisonFields.add(new AbstractMap.SimpleEntry<>(t1Field, t2Field)); } - - private List>> getComparisonFields(String t1Alias, String t2Alias, - Where connectedWhere) throws SqlParseException { - List>> comparisonFields = new ArrayList<>(); - //where is AND with lots of conditions. - if (connectedWhere == null) { - return comparisonFields; - } - boolean allAnds = true; - for (Where innerWhere : connectedWhere.getWheres()) { - if (innerWhere.getConn() == Where.CONN.OR) { - allAnds = false; - break; - } - } - if (allAnds) { - List> innerComparisonFields = - getComparisonFieldsFromWhere(t1Alias, t2Alias, connectedWhere); - comparisonFields.add(innerComparisonFields); - } else { - for (Where innerWhere : connectedWhere.getWheres()) { - comparisonFields.add(getComparisonFieldsFromWhere(t1Alias, t2Alias, innerWhere)); - } - } - - return comparisonFields; + return comparisonFields; + } + + private List>> getComparisonFields( + String t1Alias, String t2Alias, Where connectedWhere) throws SqlParseException { + List>> comparisonFields = new ArrayList<>(); + // where is AND with lots of conditions. 
+ if (connectedWhere == null) { + return comparisonFields; } - - private List> getComparisonFieldsFromWhere(String t1Alias, String t2Alias, Where where) - throws SqlParseException { - List conditions = new ArrayList<>(); - if (where instanceof Condition) { - conditions.add((Condition) where); - } else { - for (Where innerWhere : where.getWheres()) { - if (!(innerWhere instanceof Condition)) { - throw new SqlParseException( - "if connectedCondition is AND then all inner wheres should be Conditions"); - } - conditions.add((Condition) innerWhere); - } - } - return getComparisonFields(t1Alias, t2Alias, conditions); + boolean allAnds = true; + for (Where innerWhere : connectedWhere.getWheres()) { + if (innerWhere.getConn() == Where.CONN.OR) { + allAnds = false; + break; + } + } + if (allAnds) { + List> innerComparisonFields = + getComparisonFieldsFromWhere(t1Alias, t2Alias, connectedWhere); + comparisonFields.add(innerComparisonFields); + } else { + for (Where innerWhere : connectedWhere.getWheres()) { + comparisonFields.add(getComparisonFieldsFromWhere(t1Alias, t2Alias, innerWhere)); + } } - private String removeAlias(String field, String alias) { - return field.replace(alias + ".", ""); + return comparisonFields; + } + + private List> getComparisonFieldsFromWhere( + String t1Alias, String t2Alias, Where where) throws SqlParseException { + List conditions = new ArrayList<>(); + if (where instanceof Condition) { + conditions.add((Condition) where); + } else { + for (Where innerWhere : where.getWheres()) { + if (!(innerWhere instanceof Condition)) { + throw new SqlParseException( + "if connectedCondition is AND then all inner wheres should be Conditions"); + } + conditions.add((Condition) innerWhere); + } } + return getComparisonFields(t1Alias, t2Alias, conditions); + } + private String removeAlias(String field, String alias) { + return field.replace(alias + ".", ""); + } } diff --git a/legacy/src/main/java/org/opensearch/sql/legacy/query/join/OpenSearchJoinQueryAction.java b/legacy/src/main/java/org/opensearch/sql/legacy/query/join/OpenSearchJoinQueryAction.java index 35e718d985..7068ddf9a2 100644 --- a/legacy/src/main/java/org/opensearch/sql/legacy/query/join/OpenSearchJoinQueryAction.java +++ b/legacy/src/main/java/org/opensearch/sql/legacy/query/join/OpenSearchJoinQueryAction.java @@ -3,7 +3,6 @@ * SPDX-License-Identifier: Apache-2.0 */ - package org.opensearch.sql.legacy.query.join; import java.util.List; @@ -20,111 +19,107 @@ import org.opensearch.sql.legacy.query.planner.HashJoinQueryPlanRequestBuilder; import org.opensearch.sql.legacy.query.planner.core.Config; -/** - * Created by Eliran on 15/9/2015. - */ +/** Created by Eliran on 15/9/2015. 
*/ public abstract class OpenSearchJoinQueryAction extends QueryAction { - protected JoinSelect joinSelect; - - public OpenSearchJoinQueryAction(Client client, JoinSelect joinSelect) { - super(client, joinSelect); - this.joinSelect = joinSelect; - } - - @Override - public SqlElasticRequestBuilder explain() throws SqlParseException { - JoinRequestBuilder requestBuilder = createSpecificBuilder(); - fillBasicJoinRequestBuilder(requestBuilder); - fillSpecificRequestBuilder(requestBuilder); - return requestBuilder; - } - - protected abstract void fillSpecificRequestBuilder(JoinRequestBuilder requestBuilder) throws SqlParseException; - - protected abstract JoinRequestBuilder createSpecificBuilder(); - - - private void fillBasicJoinRequestBuilder(JoinRequestBuilder requestBuilder) throws SqlParseException { - - fillTableInJoinRequestBuilder(requestBuilder.getFirstTable(), joinSelect.getFirstTable()); - fillTableInJoinRequestBuilder(requestBuilder.getSecondTable(), joinSelect.getSecondTable()); - - requestBuilder.setJoinType(joinSelect.getJoinType()); - - requestBuilder.setTotalLimit(joinSelect.getTotalLimit()); - - updateRequestWithHints(requestBuilder); - - - } - - protected void updateRequestWithHints(JoinRequestBuilder requestBuilder) { - for (Hint hint : joinSelect.getHints()) { - Object[] params = hint.getParams(); - switch (hint.getType()) { - case JOIN_LIMIT: - requestBuilder.getFirstTable().setHintLimit((Integer) params[0]); - requestBuilder.getSecondTable().setHintLimit((Integer) params[1]); - break; - case JOIN_ALGORITHM_BLOCK_SIZE: - if (requestBuilder instanceof HashJoinQueryPlanRequestBuilder) { - queryPlannerConfig(requestBuilder).configureBlockSize(hint.getParams()); - } - break; - case JOIN_SCROLL_PAGE_SIZE: - if (requestBuilder instanceof HashJoinQueryPlanRequestBuilder) { - queryPlannerConfig(requestBuilder).configureScrollPageSize(hint.getParams()); - } - break; - case JOIN_CIRCUIT_BREAK_LIMIT: - if (requestBuilder instanceof HashJoinQueryPlanRequestBuilder) { - queryPlannerConfig(requestBuilder).configureCircuitBreakLimit(hint.getParams()); - } - break; - case JOIN_BACK_OFF_RETRY_INTERVALS: - if (requestBuilder instanceof HashJoinQueryPlanRequestBuilder) { - queryPlannerConfig(requestBuilder).configureBackOffRetryIntervals(hint.getParams()); - } - break; - case JOIN_TIME_OUT: - if (requestBuilder instanceof HashJoinQueryPlanRequestBuilder) { - queryPlannerConfig(requestBuilder).configureTimeOut(hint.getParams()); - } - break; - } - } + protected JoinSelect joinSelect; + + public OpenSearchJoinQueryAction(Client client, JoinSelect joinSelect) { + super(client, joinSelect); + this.joinSelect = joinSelect; + } + + @Override + public SqlElasticRequestBuilder explain() throws SqlParseException { + JoinRequestBuilder requestBuilder = createSpecificBuilder(); + fillBasicJoinRequestBuilder(requestBuilder); + fillSpecificRequestBuilder(requestBuilder); + return requestBuilder; + } + + protected abstract void fillSpecificRequestBuilder(JoinRequestBuilder requestBuilder) + throws SqlParseException; + + protected abstract JoinRequestBuilder createSpecificBuilder(); + + private void fillBasicJoinRequestBuilder(JoinRequestBuilder requestBuilder) + throws SqlParseException { + + fillTableInJoinRequestBuilder(requestBuilder.getFirstTable(), joinSelect.getFirstTable()); + fillTableInJoinRequestBuilder(requestBuilder.getSecondTable(), joinSelect.getSecondTable()); + + requestBuilder.setJoinType(joinSelect.getJoinType()); + + requestBuilder.setTotalLimit(joinSelect.getTotalLimit()); + + 
updateRequestWithHints(requestBuilder); + } + + protected void updateRequestWithHints(JoinRequestBuilder requestBuilder) { + for (Hint hint : joinSelect.getHints()) { + Object[] params = hint.getParams(); + switch (hint.getType()) { + case JOIN_LIMIT: + requestBuilder.getFirstTable().setHintLimit((Integer) params[0]); + requestBuilder.getSecondTable().setHintLimit((Integer) params[1]); + break; + case JOIN_ALGORITHM_BLOCK_SIZE: + if (requestBuilder instanceof HashJoinQueryPlanRequestBuilder) { + queryPlannerConfig(requestBuilder).configureBlockSize(hint.getParams()); + } + break; + case JOIN_SCROLL_PAGE_SIZE: + if (requestBuilder instanceof HashJoinQueryPlanRequestBuilder) { + queryPlannerConfig(requestBuilder).configureScrollPageSize(hint.getParams()); + } + break; + case JOIN_CIRCUIT_BREAK_LIMIT: + if (requestBuilder instanceof HashJoinQueryPlanRequestBuilder) { + queryPlannerConfig(requestBuilder).configureCircuitBreakLimit(hint.getParams()); + } + break; + case JOIN_BACK_OFF_RETRY_INTERVALS: + if (requestBuilder instanceof HashJoinQueryPlanRequestBuilder) { + queryPlannerConfig(requestBuilder).configureBackOffRetryIntervals(hint.getParams()); + } + break; + case JOIN_TIME_OUT: + if (requestBuilder instanceof HashJoinQueryPlanRequestBuilder) { + queryPlannerConfig(requestBuilder).configureTimeOut(hint.getParams()); + } + break; + } } - - private Config queryPlannerConfig(JoinRequestBuilder requestBuilder) { - return ((HashJoinQueryPlanRequestBuilder) requestBuilder).getConfig(); - } - - private void fillTableInJoinRequestBuilder(TableInJoinRequestBuilder requestBuilder, - TableOnJoinSelect tableOnJoinSelect) throws SqlParseException { - List connectedFields = tableOnJoinSelect.getConnectedFields(); - addFieldsToSelectIfMissing(tableOnJoinSelect, connectedFields); - requestBuilder.setOriginalSelect(tableOnJoinSelect); - DefaultQueryAction queryAction = new DefaultQueryAction(client, tableOnJoinSelect); - queryAction.explain(); - requestBuilder.setRequestBuilder(queryAction.getRequestBuilder()); - requestBuilder.setReturnedFields(tableOnJoinSelect.getSelectedFields()); - requestBuilder.setAlias(tableOnJoinSelect.getAlias()); + } + + private Config queryPlannerConfig(JoinRequestBuilder requestBuilder) { + return ((HashJoinQueryPlanRequestBuilder) requestBuilder).getConfig(); + } + + private void fillTableInJoinRequestBuilder( + TableInJoinRequestBuilder requestBuilder, TableOnJoinSelect tableOnJoinSelect) + throws SqlParseException { + List connectedFields = tableOnJoinSelect.getConnectedFields(); + addFieldsToSelectIfMissing(tableOnJoinSelect, connectedFields); + requestBuilder.setOriginalSelect(tableOnJoinSelect); + DefaultQueryAction queryAction = new DefaultQueryAction(client, tableOnJoinSelect); + queryAction.explain(); + requestBuilder.setRequestBuilder(queryAction.getRequestBuilder()); + requestBuilder.setReturnedFields(tableOnJoinSelect.getSelectedFields()); + requestBuilder.setAlias(tableOnJoinSelect.getAlias()); + } + + private void addFieldsToSelectIfMissing(Select select, List fields) { + // this means all fields + if (select.getFields() == null || select.getFields().size() == 0) { + return; } - private void addFieldsToSelectIfMissing(Select select, List fields) { - //this means all fields - if (select.getFields() == null || select.getFields().size() == 0) { - return; - } - - List selectedFields = select.getFields(); - for (Field field : fields) { - if (!selectedFields.contains(field)) { - selectedFields.add(field); - } - } - + List selectedFields = select.getFields(); + for 
(Field field : fields) { + if (!selectedFields.contains(field)) { + selectedFields.add(field); + } } - + } } diff --git a/legacy/src/main/java/org/opensearch/sql/legacy/query/join/OpenSearchJoinQueryActionFactory.java b/legacy/src/main/java/org/opensearch/sql/legacy/query/join/OpenSearchJoinQueryActionFactory.java index c96cb6120c..c638f43519 100644 --- a/legacy/src/main/java/org/opensearch/sql/legacy/query/join/OpenSearchJoinQueryActionFactory.java +++ b/legacy/src/main/java/org/opensearch/sql/legacy/query/join/OpenSearchJoinQueryActionFactory.java @@ -3,7 +3,6 @@ * SPDX-License-Identifier: Apache-2.0 */ - package org.opensearch.sql.legacy.query.join; import java.util.List; @@ -14,36 +13,32 @@ import org.opensearch.sql.legacy.domain.hints.HintType; import org.opensearch.sql.legacy.query.QueryAction; -/** - * Created by Eliran on 15/9/2015. - */ +/** Created by Eliran on 15/9/2015. */ public class OpenSearchJoinQueryActionFactory { - public static QueryAction createJoinAction(Client client, JoinSelect joinSelect) { - List connectedConditions = joinSelect.getConnectedConditions(); - boolean allEqual = true; - for (Condition condition : connectedConditions) { - if (condition.getOPERATOR() != Condition.OPERATOR.EQ) { - allEqual = false; - break; - } - - } - if (!allEqual) { - return new OpenSearchNestedLoopsQueryAction(client, joinSelect); - } - - boolean useNestedLoopsHintExist = false; - for (Hint hint : joinSelect.getHints()) { - if (hint.getType() == HintType.USE_NESTED_LOOPS) { - useNestedLoopsHintExist = true; - break; - } - } - if (useNestedLoopsHintExist) { - return new OpenSearchNestedLoopsQueryAction(client, joinSelect); - } - - return new OpenSearchHashJoinQueryAction(client, joinSelect); + public static QueryAction createJoinAction(Client client, JoinSelect joinSelect) { + List connectedConditions = joinSelect.getConnectedConditions(); + boolean allEqual = true; + for (Condition condition : connectedConditions) { + if (condition.getOPERATOR() != Condition.OPERATOR.EQ) { + allEqual = false; + break; + } + } + if (!allEqual) { + return new OpenSearchNestedLoopsQueryAction(client, joinSelect); + } + boolean useNestedLoopsHintExist = false; + for (Hint hint : joinSelect.getHints()) { + if (hint.getType() == HintType.USE_NESTED_LOOPS) { + useNestedLoopsHintExist = true; + break; + } } + if (useNestedLoopsHintExist) { + return new OpenSearchNestedLoopsQueryAction(client, joinSelect); + } + + return new OpenSearchHashJoinQueryAction(client, joinSelect); + } } diff --git a/legacy/src/main/java/org/opensearch/sql/legacy/query/join/OpenSearchNestedLoopsQueryAction.java b/legacy/src/main/java/org/opensearch/sql/legacy/query/join/OpenSearchNestedLoopsQueryAction.java index 8954106f8a..e9e9169605 100644 --- a/legacy/src/main/java/org/opensearch/sql/legacy/query/join/OpenSearchNestedLoopsQueryAction.java +++ b/legacy/src/main/java/org/opensearch/sql/legacy/query/join/OpenSearchNestedLoopsQueryAction.java @@ -3,7 +3,6 @@ * SPDX-License-Identifier: Apache-2.0 */ - package org.opensearch.sql.legacy.query.join; import org.opensearch.client.Client; @@ -13,45 +12,44 @@ import org.opensearch.sql.legacy.domain.hints.HintType; import org.opensearch.sql.legacy.exception.SqlParseException; -/** - * Created by Eliran on 15/9/2015. - */ +/** Created by Eliran on 15/9/2015. 
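+ *
+ * Selected by OpenSearchJoinQueryActionFactory (see above) when not every join condition is an
+ * equality, or when the USE_NESTED_LOOPS hint is present; otherwise the hash join action is used.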
*/ public class OpenSearchNestedLoopsQueryAction extends OpenSearchJoinQueryAction { - public OpenSearchNestedLoopsQueryAction(Client client, JoinSelect joinSelect) { - super(client, joinSelect); - } - - @Override - protected void fillSpecificRequestBuilder(JoinRequestBuilder requestBuilder) throws SqlParseException { - NestedLoopsElasticRequestBuilder nestedBuilder = (NestedLoopsElasticRequestBuilder) requestBuilder; - Where where = joinSelect.getConnectedWhere(); - nestedBuilder.setConnectedWhere(where); - + public OpenSearchNestedLoopsQueryAction(Client client, JoinSelect joinSelect) { + super(client, joinSelect); + } + + @Override + protected void fillSpecificRequestBuilder(JoinRequestBuilder requestBuilder) + throws SqlParseException { + NestedLoopsElasticRequestBuilder nestedBuilder = + (NestedLoopsElasticRequestBuilder) requestBuilder; + Where where = joinSelect.getConnectedWhere(); + nestedBuilder.setConnectedWhere(where); + } + + @Override + protected JoinRequestBuilder createSpecificBuilder() { + return new NestedLoopsElasticRequestBuilder(); + } + + @Override + protected void updateRequestWithHints(JoinRequestBuilder requestBuilder) { + super.updateRequestWithHints(requestBuilder); + for (Hint hint : this.joinSelect.getHints()) { + if (hint.getType() == HintType.NL_MULTISEARCH_SIZE) { + Integer multiSearchMaxSize = (Integer) hint.getParams()[0]; + ((NestedLoopsElasticRequestBuilder) requestBuilder) + .setMultiSearchMaxSize(multiSearchMaxSize); + } } + } - @Override - protected JoinRequestBuilder createSpecificBuilder() { - return new NestedLoopsElasticRequestBuilder(); + private String removeAlias(String field) { + String alias = joinSelect.getFirstTable().getAlias(); + if (!field.startsWith(alias + ".")) { + alias = joinSelect.getSecondTable().getAlias(); } - - @Override - protected void updateRequestWithHints(JoinRequestBuilder requestBuilder) { - super.updateRequestWithHints(requestBuilder); - for (Hint hint : this.joinSelect.getHints()) { - if (hint.getType() == HintType.NL_MULTISEARCH_SIZE) { - Integer multiSearchMaxSize = (Integer) hint.getParams()[0]; - ((NestedLoopsElasticRequestBuilder) requestBuilder).setMultiSearchMaxSize(multiSearchMaxSize); - } - } - } - - private String removeAlias(String field) { - String alias = joinSelect.getFirstTable().getAlias(); - if (!field.startsWith(alias + ".")) { - alias = joinSelect.getSecondTable().getAlias(); - } - return field.replace(alias + ".", ""); - } - + return field.replace(alias + ".", ""); + } } diff --git a/legacy/src/main/java/org/opensearch/sql/legacy/query/maker/QueryMaker.java b/legacy/src/main/java/org/opensearch/sql/legacy/query/maker/QueryMaker.java index f36bca2686..75f3538981 100644 --- a/legacy/src/main/java/org/opensearch/sql/legacy/query/maker/QueryMaker.java +++ b/legacy/src/main/java/org/opensearch/sql/legacy/query/maker/QueryMaker.java @@ -3,10 +3,8 @@ * SPDX-License-Identifier: Apache-2.0 */ - package org.opensearch.sql.legacy.query.maker; - import org.apache.lucene.search.join.ScoreMode; import org.opensearch.index.query.BoolQueryBuilder; import org.opensearch.index.query.QueryBuilder; @@ -18,76 +16,76 @@ public class QueryMaker extends Maker { - /** - * - * - * @param where - * @return - * @throws SqlParseException - */ - public static BoolQueryBuilder explain(Where where) throws SqlParseException { - return explain(where, true); - } + /** + * @param where + * @return + * @throws SqlParseException + */ + public static BoolQueryBuilder explain(Where where) throws SqlParseException { + return explain(where, 
true); + } - public static BoolQueryBuilder explain(Where where, boolean isQuery) throws SqlParseException { - BoolQueryBuilder boolQuery = QueryBuilders.boolQuery(); - while (where.getWheres().size() == 1) { - where = where.getWheres().getFirst(); - } - new QueryMaker().explanWhere(boolQuery, where); - if (isQuery) { - return boolQuery; - } - return QueryBuilders.boolQuery().filter(boolQuery); + public static BoolQueryBuilder explain(Where where, boolean isQuery) throws SqlParseException { + BoolQueryBuilder boolQuery = QueryBuilders.boolQuery(); + while (where.getWheres().size() == 1) { + where = where.getWheres().getFirst(); } - - private QueryMaker() { - super(true); + new QueryMaker().explanWhere(boolQuery, where); + if (isQuery) { + return boolQuery; } + return QueryBuilders.boolQuery().filter(boolQuery); + } - private void explanWhere(BoolQueryBuilder boolQuery, Where where) throws SqlParseException { - if (where instanceof Condition) { - addSubQuery(boolQuery, where, (QueryBuilder) make((Condition) where)); - } else { - BoolQueryBuilder subQuery = QueryBuilders.boolQuery(); - addSubQuery(boolQuery, where, subQuery); - for (Where subWhere : where.getWheres()) { - explanWhere(subQuery, subWhere); - } - } - } + private QueryMaker() { + super(true); + } - /** - * - * - * @param boolQuery - * @param where - * @param subQuery - */ - private void addSubQuery(BoolQueryBuilder boolQuery, Where where, QueryBuilder subQuery) { - if (where instanceof Condition) { - Condition condition = (Condition) where; + private void explanWhere(BoolQueryBuilder boolQuery, Where where) throws SqlParseException { + if (where instanceof Condition) { + addSubQuery(boolQuery, where, (QueryBuilder) make((Condition) where)); + } else { + BoolQueryBuilder subQuery = QueryBuilders.boolQuery(); + addSubQuery(boolQuery, where, subQuery); + for (Where subWhere : where.getWheres()) { + explanWhere(subQuery, subWhere); + } + } + } - if (condition.isNested()) { - // bugfix #628 - if ("missing".equalsIgnoreCase(String.valueOf(condition.getValue())) - && (condition.getOPERATOR() == Condition.OPERATOR.IS - || condition.getOPERATOR() == Condition.OPERATOR.EQ)) { - boolQuery.mustNot(QueryBuilders.nestedQuery(condition.getNestedPath(), - QueryBuilders.boolQuery().mustNot(subQuery), ScoreMode.None)); - return; - } + /** + * @param boolQuery + * @param where + * @param subQuery + */ + private void addSubQuery(BoolQueryBuilder boolQuery, Where where, QueryBuilder subQuery) { + if (where instanceof Condition) { + Condition condition = (Condition) where; - subQuery = QueryBuilders.nestedQuery(condition.getNestedPath(), subQuery, ScoreMode.None); - } else if (condition.isChildren()) { - subQuery = JoinQueryBuilders.hasChildQuery(condition.getChildType(), subQuery, ScoreMode.None); - } + if (condition.isNested()) { + // bugfix #628 + if ("missing".equalsIgnoreCase(String.valueOf(condition.getValue())) + && (condition.getOPERATOR() == Condition.OPERATOR.IS + || condition.getOPERATOR() == Condition.OPERATOR.EQ)) { + boolQuery.mustNot( + QueryBuilders.nestedQuery( + condition.getNestedPath(), + QueryBuilders.boolQuery().mustNot(subQuery), + ScoreMode.None)); + return; } - if (where.getConn() == Where.CONN.AND) { - boolQuery.must(subQuery); - } else { - boolQuery.should(subQuery); - } + subQuery = QueryBuilders.nestedQuery(condition.getNestedPath(), subQuery, ScoreMode.None); + } else if (condition.isChildren()) { + subQuery = + JoinQueryBuilders.hasChildQuery(condition.getChildType(), subQuery, ScoreMode.None); + } + } + + if 
(where.getConn() == Where.CONN.AND) { + boolQuery.must(subQuery); + } else { + boolQuery.should(subQuery); } + } } diff --git a/legacy/src/main/java/org/opensearch/sql/legacy/query/multi/OpenSearchMultiQueryActionFactory.java b/legacy/src/main/java/org/opensearch/sql/legacy/query/multi/OpenSearchMultiQueryActionFactory.java index be86fdef81..1f934e9a80 100644 --- a/legacy/src/main/java/org/opensearch/sql/legacy/query/multi/OpenSearchMultiQueryActionFactory.java +++ b/legacy/src/main/java/org/opensearch/sql/legacy/query/multi/OpenSearchMultiQueryActionFactory.java @@ -3,26 +3,23 @@ * SPDX-License-Identifier: Apache-2.0 */ - package org.opensearch.sql.legacy.query.multi; import org.opensearch.client.Client; import org.opensearch.sql.legacy.exception.SqlParseException; import org.opensearch.sql.legacy.query.QueryAction; -/** - * Created by Eliran on 19/8/2016. - */ +/** Created by Eliran on 19/8/2016. */ public class OpenSearchMultiQueryActionFactory { - public static QueryAction createMultiQueryAction(Client client, MultiQuerySelect multiSelect) - throws SqlParseException { - switch (multiSelect.getOperation()) { - case UNION_ALL: - case UNION: - return new MultiQueryAction(client, multiSelect); - default: - throw new SqlParseException("only supports union and union all"); - } + public static QueryAction createMultiQueryAction(Client client, MultiQuerySelect multiSelect) + throws SqlParseException { + switch (multiSelect.getOperation()) { + case UNION_ALL: + case UNION: + return new MultiQueryAction(client, multiSelect); + default: + throw new SqlParseException("only supports union and union all"); } + } } diff --git a/legacy/src/main/java/org/opensearch/sql/legacy/query/planner/core/Plan.java b/legacy/src/main/java/org/opensearch/sql/legacy/query/planner/core/Plan.java index f163e61f0e..328bb9451f 100644 --- a/legacy/src/main/java/org/opensearch/sql/legacy/query/planner/core/Plan.java +++ b/legacy/src/main/java/org/opensearch/sql/legacy/query/planner/core/Plan.java @@ -3,26 +3,20 @@ * SPDX-License-Identifier: Apache-2.0 */ - package org.opensearch.sql.legacy.query.planner.core; import org.opensearch.sql.legacy.query.planner.core.PlanNode.Visitor; -/** - * Query plan - */ +/** Query plan */ public interface Plan { - /** - * Explain current query plan by visitor - * - * @param explanation visitor to explain the plan - */ - void traverse(Visitor explanation); - - /** - * Optimize current query plan to get the optimal one - */ - void optimize(); + /** + * Explain current query plan by visitor + * + * @param explanation visitor to explain the plan + */ + void traverse(Visitor explanation); + /** Optimize current query plan to get the optimal one */ + void optimize(); } diff --git a/legacy/src/main/java/org/opensearch/sql/legacy/query/planner/core/PlanNode.java b/legacy/src/main/java/org/opensearch/sql/legacy/query/planner/core/PlanNode.java index ad421f82a4..b30ec9d3d9 100644 --- a/legacy/src/main/java/org/opensearch/sql/legacy/query/planner/core/PlanNode.java +++ b/legacy/src/main/java/org/opensearch/sql/legacy/query/planner/core/PlanNode.java @@ -3,54 +3,47 @@ * SPDX-License-Identifier: Apache-2.0 */ - package org.opensearch.sql.legacy.query.planner.core; -/** - * Abstract plan node in query plan. - */ +/** Abstract plan node in query plan. */ public interface PlanNode { - /** - * All child nodes of current node used for traversal. - * - * @return all children - */ - PlanNode[] children(); + /** + * All child nodes of current node used for traversal. 
+ * + * @return all children + */ + PlanNode[] children(); + + /** + * Accept a visitor and traverse the plan tree with it. + * + * @param visitor plan node visitor + */ + default void accept(Visitor visitor) { + if (visitor.visit(this)) { + for (PlanNode node : children()) { + node.accept(visitor); + } + } + visitor.endVisit(this); + } + + /** Plan node visitor. */ + interface Visitor { /** - * Accept a visitor and traverse the plan tree with it. + * To avoid listing all subclasses of PlanNode here, we dispatch manually in concrete visitor. * - * @param visitor plan node visitor + * @param op plan node being visited */ - default void accept(Visitor visitor) { - if (visitor.visit(this)) { - for (PlanNode node : children()) { - node.accept(visitor); - } - } - visitor.endVisit(this); - } + boolean visit(PlanNode op); /** - * Plan node visitor. + * Re-visit current node before return to parent node + * + * @param op plan node finished visit */ - interface Visitor { - - /** - * To avoid listing all subclasses of PlanNode here, we dispatch manually in concrete visitor. - * - * @param op plan node being visited - */ - boolean visit(PlanNode op); - - /** - * Re-visit current node before return to parent node - * - * @param op plan node finished visit - */ - default void endVisit(PlanNode op) { - } - } - + default void endVisit(PlanNode op) {} + } } diff --git a/legacy/src/main/java/org/opensearch/sql/legacy/query/planner/core/QueryParams.java b/legacy/src/main/java/org/opensearch/sql/legacy/query/planner/core/QueryParams.java index 2cb835da94..ae5f0fb9c8 100644 --- a/legacy/src/main/java/org/opensearch/sql/legacy/query/planner/core/QueryParams.java +++ b/legacy/src/main/java/org/opensearch/sql/legacy/query/planner/core/QueryParams.java @@ -3,7 +3,6 @@ * SPDX-License-Identifier: Apache-2.0 */ - package org.opensearch.sql.legacy.query.planner.core; import com.alibaba.druid.sql.ast.statement.SQLJoinTableSource; @@ -12,70 +11,68 @@ import org.opensearch.sql.legacy.domain.Field; import org.opensearch.sql.legacy.query.join.TableInJoinRequestBuilder; -/** - * All parameters required by QueryPlanner - */ +/** All parameters required by QueryPlanner */ public class QueryParams { - /** - * Request builder for first table - */ - private final TableInJoinRequestBuilder request1; + /** Request builder for first table */ + private final TableInJoinRequestBuilder request1; - /** - * Request builder for second table - */ - private final TableInJoinRequestBuilder request2; + /** Request builder for second table */ + private final TableInJoinRequestBuilder request2; - /** - * Join type, ex. inner join, left join - */ - private final SQLJoinTableSource.JoinType joinType; + /** Join type, ex. inner join, left join */ + private final SQLJoinTableSource.JoinType joinType; /** + *
      * Join conditions in ON clause grouped by OR.
      * For example, "ON (a.name = b.id AND a.age = b.age) OR a.location = b.address"
      *  => list: [
      *    [ (a.name, b.id), (a.age, b.age) ],
      *    [ (a.location, b.address) ]
      *  ]
+     * 
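+     *
+     * A hypothetical consumer (illustrative sketch, not part of this change) would read the outer
+     * list as OR groups and each inner list as AND-ed field pairs:
+     *   for (List<Map.Entry<Field, Field>> andGroup : joinConditions) {  // OR between groups
+     *     for (Map.Entry<Field, Field> pair : andGroup) {                // AND within a group
+     *       matchFields(pair.getKey(), pair.getValue());                 // matchFields is hypothetical
+     *     }
+     *   }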
*/ private final List>> joinConditions; + public QueryParams( + TableInJoinRequestBuilder request1, + TableInJoinRequestBuilder request2, + SQLJoinTableSource.JoinType joinType, + List>> t1ToT2FieldsComparison) { + this.request1 = request1; + this.request2 = request2; + this.joinType = joinType; + this.joinConditions = t1ToT2FieldsComparison; + } - public QueryParams(TableInJoinRequestBuilder request1, - TableInJoinRequestBuilder request2, - SQLJoinTableSource.JoinType joinType, - List>> t1ToT2FieldsComparison) { - this.request1 = request1; - this.request2 = request2; - this.joinType = joinType; - this.joinConditions = t1ToT2FieldsComparison; - } - - public TableInJoinRequestBuilder firstRequest() { - return request1; - } + public TableInJoinRequestBuilder firstRequest() { + return request1; + } - public TableInJoinRequestBuilder secondRequest() { - return request2; - } + public TableInJoinRequestBuilder secondRequest() { + return request2; + } - public SQLJoinTableSource.JoinType joinType() { - return joinType; - } + public SQLJoinTableSource.JoinType joinType() { + return joinType; + } - public List>> joinConditions() { - return joinConditions; - } + public List>> joinConditions() { + return joinConditions; + } - @Override - public String toString() { - return "QueryParams{" - + "request1=" + request1 - + ", request2=" + request2 - + ", joinType=" + joinType - + ", joinConditions=" + joinConditions - + '}'; - } + @Override + public String toString() { + return "QueryParams{" + + "request1=" + + request1 + + ", request2=" + + request2 + + ", joinType=" + + joinType + + ", joinConditions=" + + joinConditions + + '}'; + } } diff --git a/legacy/src/main/java/org/opensearch/sql/legacy/query/planner/core/QueryPlanner.java b/legacy/src/main/java/org/opensearch/sql/legacy/query/planner/core/QueryPlanner.java index 56acfa5d0c..0a1c2fd24b 100644 --- a/legacy/src/main/java/org/opensearch/sql/legacy/query/planner/core/QueryPlanner.java +++ b/legacy/src/main/java/org/opensearch/sql/legacy/query/planner/core/QueryPlanner.java @@ -3,7 +3,6 @@ * SPDX-License-Identifier: Apache-2.0 */ - package org.opensearch.sql.legacy.query.planner.core; import static org.opensearch.sql.legacy.query.planner.core.ExecuteParams.ExecuteParamType.CLIENT; @@ -21,89 +20,69 @@ import org.opensearch.sql.legacy.query.planner.resource.ResourceManager; import org.opensearch.sql.legacy.query.planner.resource.Stats; -/** - * Query planner that driver the logical planning, physical planning, execute and explain. - */ +/** Query planner that driver the logical planning, physical planning, execute and explain. 
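+ *
+ * A minimal usage sketch (illustrative only; the SearchHit element type is assumed from execute()
+ * below):
+ *   QueryPlanner planner = new QueryPlanner(client, config, params);
+ *   String explanation = planner.explain();      // JSON description of the logical/physical plan
+ *   List<SearchHit> hits = planner.execute();    // run the optimized physical plan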
*/ public class QueryPlanner { - /** - * Connection to ElasticSearch - */ - private final Client client; - - /** - * Query plan configuration - */ - private final Config config; - - /** - * Optimized logical plan - */ - private final LogicalPlan logicalPlan; - - /** - * Best physical plan to execute - */ - private final PhysicalPlan physicalPlan; - - /** - * Statistics collector - */ - private Stats stats; - - /** - * Resource monitor and statistics manager - */ - private ResourceManager resourceMgr; - - - public QueryPlanner(Client client, Config config, QueryParams params) { - this.client = client; - this.config = config; - this.stats = new Stats(client); - this.resourceMgr = new ResourceManager(stats, config); - - logicalPlan = new LogicalPlan(config, params); - logicalPlan.optimize(); - - physicalPlan = new PhysicalPlan(logicalPlan); - physicalPlan.optimize(); - } - - /** - * Execute query plan - * - * @return response of the execution - */ - public List execute() { - ExecuteParams params = new ExecuteParams(); - params.add(CLIENT, client); - params.add(TIMEOUT, config.timeout()); - params.add(RESOURCE_MANAGER, resourceMgr); - return physicalPlan.execute(params); - } - - /** - * Explain query plan - * - * @return explanation string of the plan - */ - public String explain() { - return new Explanation( - logicalPlan, physicalPlan, - new JsonExplanationFormat(4) - ).toString(); - } - - public MetaSearchResult getMetaResult() { - return resourceMgr.getMetaResult(); - } - - /** - * Setter for unit test - */ - public void setStats(Stats stats) { - this.stats = stats; - this.resourceMgr = new ResourceManager(stats, config); - } + /** Connection to ElasticSearch */ + private final Client client; + + /** Query plan configuration */ + private final Config config; + + /** Optimized logical plan */ + private final LogicalPlan logicalPlan; + + /** Best physical plan to execute */ + private final PhysicalPlan physicalPlan; + + /** Statistics collector */ + private Stats stats; + + /** Resource monitor and statistics manager */ + private ResourceManager resourceMgr; + + public QueryPlanner(Client client, Config config, QueryParams params) { + this.client = client; + this.config = config; + this.stats = new Stats(client); + this.resourceMgr = new ResourceManager(stats, config); + + logicalPlan = new LogicalPlan(config, params); + logicalPlan.optimize(); + + physicalPlan = new PhysicalPlan(logicalPlan); + physicalPlan.optimize(); + } + + /** + * Execute query plan + * + * @return response of the execution + */ + public List execute() { + ExecuteParams params = new ExecuteParams(); + params.add(CLIENT, client); + params.add(TIMEOUT, config.timeout()); + params.add(RESOURCE_MANAGER, resourceMgr); + return physicalPlan.execute(params); + } + + /** + * Explain query plan + * + * @return explanation string of the plan + */ + public String explain() { + return new Explanation(logicalPlan, physicalPlan, new JsonExplanationFormat(4)).toString(); + } + + public MetaSearchResult getMetaResult() { + return resourceMgr.getMetaResult(); + } + + /** Setter for unit test */ + public void setStats(Stats stats) { + this.stats = stats; + this.resourceMgr = new ResourceManager(stats, config); + } } diff --git a/legacy/src/main/java/org/opensearch/sql/legacy/query/planner/logical/node/Project.java b/legacy/src/main/java/org/opensearch/sql/legacy/query/planner/logical/node/Project.java index bd24564de2..4226744f1b 100644 --- a/legacy/src/main/java/org/opensearch/sql/legacy/query/planner/logical/node/Project.java +++ 
b/legacy/src/main/java/org/opensearch/sql/legacy/query/planner/logical/node/Project.java @@ -3,7 +3,6 @@ * SPDX-License-Identifier: Apache-2.0 */ - package org.opensearch.sql.legacy.query.planner.logical.node; import com.google.common.collect.HashMultimap; @@ -23,126 +22,116 @@ import org.opensearch.sql.legacy.query.planner.physical.Row; import org.opensearch.sql.legacy.query.planner.physical.estimation.Cost; -/** - * Projection expression - */ +/** Projection expression */ public class Project implements LogicalOperator, PhysicalOperator { - private static final Logger LOG = LogManager.getLogger(); - - private final PlanNode next; + private static final Logger LOG = LogManager.getLogger(); - /** - * All columns being projected in SELECT in each table - */ - private final Multimap tableAliasColumns; + private final PlanNode next; - /** - * All columns full name (tableAlias.colName) to alias mapping - */ - private final Map fullNameAlias; + /** All columns being projected in SELECT in each table */ + private final Multimap tableAliasColumns; + /** All columns full name (tableAlias.colName) to alias mapping */ + private final Map fullNameAlias; - @SuppressWarnings("unchecked") - public Project(PlanNode next) { - this(next, HashMultimap.create()); - } + @SuppressWarnings("unchecked") + public Project(PlanNode next) { + this(next, HashMultimap.create()); + } - @SuppressWarnings("unchecked") - public Project(PlanNode next, Multimap tableAliasToColumns) { - this.next = next; - this.tableAliasColumns = tableAliasToColumns; - this.fullNameAlias = fullNameAndAlias(); - } - - @Override - public boolean isNoOp() { - return tableAliasColumns.isEmpty(); - } - - @Override - public PlanNode[] children() { - return new PlanNode[]{next}; - } + @SuppressWarnings("unchecked") + public Project(PlanNode next, Multimap tableAliasToColumns) { + this.next = next; + this.tableAliasColumns = tableAliasToColumns; + this.fullNameAlias = fullNameAndAlias(); + } - @Override - public PhysicalOperator[] toPhysical(Map> optimalOps) { - if (!(next instanceof LogicalOperator)) { - throw new IllegalStateException("Only logical operator can perform this toPhysical() operation"); - } - return new PhysicalOperator[]{ - new Project(optimalOps.get(next), tableAliasColumns) // Create physical Project instance - }; - } + @Override + public boolean isNoOp() { + return tableAliasColumns.isEmpty(); + } - @Override - public Cost estimate() { - return new Cost(); - } + @Override + public PlanNode[] children() { + return new PlanNode[] {next}; + } - @Override - public boolean hasNext() { - return ((PhysicalOperator) next).hasNext(); + @Override + public PhysicalOperator[] toPhysical(Map> optimalOps) { + if (!(next instanceof LogicalOperator)) { + throw new IllegalStateException( + "Only logical operator can perform this toPhysical() operation"); } - - @SuppressWarnings("unchecked") - @Override - public Row next() { - Row row = ((PhysicalOperator) this.next).next(); - - /* - * Empty means SELECT * which means retain all fields from both tables - * Because push down is always applied, only limited support for this. 
- */ - if (!fullNameAlias.isEmpty()) { - row.retain(fullNameAlias); - } - - LOG.trace("Projected row by fields {}: {}", tableAliasColumns, row); - return row; + return new PhysicalOperator[] { + new Project(optimalOps.get(next), tableAliasColumns) // Create physical Project instance + }; + } + + @Override + public Cost estimate() { + return new Cost(); + } + + @Override + public boolean hasNext() { + return ((PhysicalOperator) next).hasNext(); + } + + @SuppressWarnings("unchecked") + @Override + public Row next() { + Row row = ((PhysicalOperator) this.next).next(); + + /* + * Empty means SELECT * which means retain all fields from both tables + * Because push down is always applied, only limited support for this. + */ + if (!fullNameAlias.isEmpty()) { + row.retain(fullNameAlias); } - public void project(String tableAlias, Collection columns) { - tableAliasColumns.putAll(tableAlias, columns); - } + LOG.trace("Projected row by fields {}: {}", tableAliasColumns, row); + return row; + } - public void projectAll(String tableAlias) { - tableAliasColumns.put(tableAlias, new Field("*", "")); - } + public void project(String tableAlias, Collection columns) { + tableAliasColumns.putAll(tableAlias, columns); + } - public void forEach(BiConsumer> action) { - tableAliasColumns.asMap().forEach(action); - } + public void projectAll(String tableAlias) { + tableAliasColumns.put(tableAlias, new Field("*", "")); + } - public void pushDown(String tableAlias, Project pushedDownProj) { - Collection columns = pushedDownProj.tableAliasColumns.get(tableAlias); - if (columns != null) { - tableAliasColumns.putAll(tableAlias, columns); - } - } + public void forEach(BiConsumer> action) { + tableAliasColumns.asMap().forEach(action); + } - /** - * Return mapping from column full name ("e.age") and alias ("a" in "SELECT e.age AS a") - */ - private Map fullNameAndAlias() { - Map fullNamesAlias = new HashMap<>(); - forEach( - (tableAlias, fields) -> { - for (Field field : fields) { - fullNamesAlias.put(tableAlias + "." + field.getName(), field.getAlias()); - } - } - ); - return fullNamesAlias; + public void pushDown(String tableAlias, Project pushedDownProj) { + Collection columns = pushedDownProj.tableAliasColumns.get(tableAlias); + if (columns != null) { + tableAliasColumns.putAll(tableAlias, columns); } - - @Override - public String toString() { - List colStrs = new ArrayList<>(); - for (Map.Entry entry : tableAliasColumns.entries()) { - colStrs.add(entry.getKey() + "." + entry.getValue().getName()); - } - return "Project [ columns=[" + String.join(", ", colStrs) + "] ]"; + } + + /** Return mapping from column full name ("e.age") and alias ("a" in "SELECT e.age AS a") */ + private Map fullNameAndAlias() { + Map fullNamesAlias = new HashMap<>(); + forEach( + (tableAlias, fields) -> { + for (Field field : fields) { + fullNamesAlias.put(tableAlias + "." + field.getName(), field.getAlias()); + } + }); + return fullNamesAlias; + } + + @Override + public String toString() { + List colStrs = new ArrayList<>(); + for (Map.Entry entry : tableAliasColumns.entries()) { + colStrs.add(entry.getKey() + "." 
+ entry.getValue().getName()); } - + return "Project [ columns=[" + String.join(", ", colStrs) + "] ]"; + } } diff --git a/legacy/src/main/java/org/opensearch/sql/legacy/query/planner/logical/rule/ProjectionPushDown.java b/legacy/src/main/java/org/opensearch/sql/legacy/query/planner/logical/rule/ProjectionPushDown.java index f5a3e28fce..5195894a75 100644 --- a/legacy/src/main/java/org/opensearch/sql/legacy/query/planner/logical/rule/ProjectionPushDown.java +++ b/legacy/src/main/java/org/opensearch/sql/legacy/query/planner/logical/rule/ProjectionPushDown.java @@ -3,7 +3,6 @@ * SPDX-License-Identifier: Apache-2.0 */ - package org.opensearch.sql.legacy.query.planner.logical.rule; import static java.util.stream.Collectors.toList; @@ -18,68 +17,54 @@ import org.opensearch.sql.legacy.query.planner.logical.node.Join; import org.opensearch.sql.legacy.query.planner.logical.node.Project; - -/** - * Projection push down optimization. - */ +/** Projection push down optimization. */ public class ProjectionPushDown implements LogicalPlanVisitor { - /** - * Project used to collect column names in SELECT, ON, ORDER BY... - */ - private final Project project = new Project(null); + /** Project used to collect column names in SELECT, ON, ORDER BY... */ + private final Project project = new Project(null); - @Override - public boolean visit(Project project) { - pushDown(project); - return true; - } + @Override + public boolean visit(Project project) { + pushDown(project); + return true; + } - @Override - public boolean visit(Join join) { - pushDown(join.conditions()); - return true; - } + @Override + public boolean visit(Join join) { + pushDown(join.conditions()); + return true; + } - @Override - public boolean visit(Group group) { - if (!project.isNoOp()) { - group.pushDown(project); - } - return false; // avoid iterating operators in virtual Group + @Override + public boolean visit(Group group) { + if (!project.isNoOp()) { + group.pushDown(project); } + return false; // avoid iterating operators in virtual Group + } - /** - * Note that raw type Project cause generic type of forEach be erased at compile time - */ - private void pushDown(Project project) { - project.forEach(this::project); - } - - private void pushDown(JoinCondition orCond) { - for (int i = 0; i < orCond.groupSize(); i++) { - project( - orCond.leftTableAlias(), - columnNamesToFields(orCond.leftColumnNames(i)) - ); - project( - orCond.rightTableAlias(), - columnNamesToFields(orCond.rightColumnNames(i)) - ); - } - } + /** Note that raw type Project cause generic type of forEach be erased at compile time */ + private void pushDown(Project project) { + project.forEach(this::project); + } - private void project(String tableAlias, Collection columns) { - project.project(tableAlias, columns); // Bug: Field doesn't implement hashCode() which leads to duplicate + private void pushDown(JoinCondition orCond) { + for (int i = 0; i < orCond.groupSize(); i++) { + project(orCond.leftTableAlias(), columnNamesToFields(orCond.leftColumnNames(i))); + project(orCond.rightTableAlias(), columnNamesToFields(orCond.rightColumnNames(i))); } + } - /** - * Convert column name string to Field object with empty alias - */ - private List columnNamesToFields(String[] colNames) { - return Arrays.stream(colNames). - map(name -> new Field(name, null)). 
// Alias is useless for pushed down project - collect(toList()); - } + private void project(String tableAlias, Collection columns) { + project.project( + tableAlias, columns); // Bug: Field doesn't implement hashCode() which leads to duplicate + } + /** Convert column name string to Field object with empty alias */ + private List columnNamesToFields(String[] colNames) { + return Arrays.stream(colNames) + .map(name -> new Field(name, null)) + . // Alias is useless for pushed down project + collect(toList()); + } } diff --git a/legacy/src/main/java/org/opensearch/sql/legacy/query/planner/logical/rule/SelectionPushDown.java b/legacy/src/main/java/org/opensearch/sql/legacy/query/planner/logical/rule/SelectionPushDown.java index 61578f91b7..deae266afc 100644 --- a/legacy/src/main/java/org/opensearch/sql/legacy/query/planner/logical/rule/SelectionPushDown.java +++ b/legacy/src/main/java/org/opensearch/sql/legacy/query/planner/logical/rule/SelectionPushDown.java @@ -3,36 +3,32 @@ * SPDX-License-Identifier: Apache-2.0 */ - package org.opensearch.sql.legacy.query.planner.logical.rule; import org.opensearch.sql.legacy.query.planner.logical.LogicalPlanVisitor; import org.opensearch.sql.legacy.query.planner.logical.node.Filter; import org.opensearch.sql.legacy.query.planner.logical.node.Group; -/** - * Push down selection (filter) - */ +/** Push down selection (filter) */ public class SelectionPushDown implements LogicalPlanVisitor { - /** - * Store the filter found in visit and reused to push down. - * It's not necessary to create a new one because no need to collect filter condition elsewhere - */ - private Filter filter; - - @Override - public boolean visit(Filter filter) { - this.filter = filter; - return true; + /** + * Store the filter found in visit and reused to push down. It's not necessary to create a new one + * because no need to collect filter condition elsewhere + */ + private Filter filter; + + @Override + public boolean visit(Filter filter) { + this.filter = filter; + return true; + } + + @Override + public boolean visit(Group group) { + if (filter != null && !filter.isNoOp()) { + group.pushDown(filter); } - - @Override - public boolean visit(Group group) { - if (filter != null && !filter.isNoOp()) { - group.pushDown(filter); - } - return false; // avoid iterating operators in virtual Group - } - + return false; // avoid iterating operators in virtual Group + } } diff --git a/legacy/src/main/java/org/opensearch/sql/legacy/query/planner/physical/PhysicalOperator.java b/legacy/src/main/java/org/opensearch/sql/legacy/query/planner/physical/PhysicalOperator.java index 9271bae0d7..897beee3e9 100644 --- a/legacy/src/main/java/org/opensearch/sql/legacy/query/planner/physical/PhysicalOperator.java +++ b/legacy/src/main/java/org/opensearch/sql/legacy/query/planner/physical/PhysicalOperator.java @@ -3,7 +3,6 @@ * SPDX-License-Identifier: Apache-2.0 */ - package org.opensearch.sql.legacy.query.planner.physical; import java.util.Iterator; @@ -11,40 +10,36 @@ import org.opensearch.sql.legacy.query.planner.core.PlanNode; import org.opensearch.sql.legacy.query.planner.physical.estimation.Cost; -/** - * Physical operator - */ +/** Physical operator */ public interface PhysicalOperator extends PlanNode, Iterator>, AutoCloseable { - /** - * Estimate the cost of current physical operator - * - * @return cost - */ - Cost estimate(); - - - /** - * Initialize operator. 
- * - * @param params exuecution parameters needed - */ - default void open(ExecuteParams params) throws Exception { - for (PlanNode node : children()) { - ((PhysicalOperator) node).open(params); - } + /** + * Estimate the cost of current physical operator + * + * @return cost + */ + Cost estimate(); + + /** + * Initialize operator. + * + * @param params exuecution parameters needed + */ + default void open(ExecuteParams params) throws Exception { + for (PlanNode node : children()) { + ((PhysicalOperator) node).open(params); } - - - /** - * Close resources related to the operator. - * - * @throws Exception potential exception raised - */ - @Override - default void close() { - for (PlanNode node : children()) { - ((PhysicalOperator) node).close(); - } + } + + /** + * Close resources related to the operator. + * + * @throws Exception potential exception raised + */ + @Override + default void close() { + for (PlanNode node : children()) { + ((PhysicalOperator) node).close(); } + } } diff --git a/legacy/src/main/java/org/opensearch/sql/legacy/query/planner/physical/PhysicalPlan.java b/legacy/src/main/java/org/opensearch/sql/legacy/query/planner/physical/PhysicalPlan.java index eac4e855b0..5a79c63838 100644 --- a/legacy/src/main/java/org/opensearch/sql/legacy/query/planner/physical/PhysicalPlan.java +++ b/legacy/src/main/java/org/opensearch/sql/legacy/query/planner/physical/PhysicalPlan.java @@ -3,7 +3,6 @@ * SPDX-License-Identifier: Apache-2.0 */ - package org.opensearch.sql.legacy.query.planner.physical; import java.util.ArrayList; @@ -18,81 +17,69 @@ import org.opensearch.sql.legacy.query.planner.physical.estimation.Estimation; import org.opensearch.sql.legacy.query.planner.resource.ResourceManager; -/** - * Physical plan - */ +/** Physical plan */ public class PhysicalPlan implements Plan { - private static final Logger LOG = LogManager.getLogger(); + private static final Logger LOG = LogManager.getLogger(); - /** - * Optimized logical plan that being ready for physical planning - */ - private final LogicalPlan logicalPlan; + /** Optimized logical plan that being ready for physical planning */ + private final LogicalPlan logicalPlan; - /** - * Root of physical plan tree - */ - private PhysicalOperator root; + /** Root of physical plan tree */ + private PhysicalOperator root; - public PhysicalPlan(LogicalPlan logicalPlan) { - this.logicalPlan = logicalPlan; - } + public PhysicalPlan(LogicalPlan logicalPlan) { + this.logicalPlan = logicalPlan; + } - @Override - public void traverse(Visitor visitor) { - if (root != null) { - root.accept(visitor); - } + @Override + public void traverse(Visitor visitor) { + if (root != null) { + root.accept(visitor); } - - @Override - public void optimize() { - Estimation estimation = new Estimation<>(); - logicalPlan.traverse(estimation); - root = estimation.optimalPlan(); + } + + @Override + public void optimize() { + Estimation estimation = new Estimation<>(); + logicalPlan.traverse(estimation); + root = estimation.optimalPlan(); + } + + /** Execute physical plan after verifying if system is healthy at the moment */ + public List execute(ExecuteParams params) { + if (shouldReject(params)) { + throw new IllegalStateException("Query request rejected due to insufficient resource"); } - /** - * Execute physical plan after verifying if system is healthy at the moment - */ - public List execute(ExecuteParams params) { - if (shouldReject(params)) { - throw new IllegalStateException("Query request rejected due to insufficient resource"); - } - - try (PhysicalOperator 
op = root) { - return doExecutePlan(op, params); - } catch (Exception e) { - LOG.error("Error happened during execution", e); - // Runtime error or circuit break. Should we return partial result to customer? - throw new IllegalStateException("Error happened during execution", e); - } + try (PhysicalOperator op = root) { + return doExecutePlan(op, params); + } catch (Exception e) { + LOG.error("Error happened during execution", e); + // Runtime error or circuit break. Should we return partial result to customer? + throw new IllegalStateException("Error happened during execution", e); } - - /** - * Reject physical plan execution of new query request if unhealthy - */ - private boolean shouldReject(ExecuteParams params) { - return !((ResourceManager) params.get(ExecuteParams.ExecuteParamType.RESOURCE_MANAGER)).isHealthy(); + } + + /** Reject physical plan execution of new query request if unhealthy */ + private boolean shouldReject(ExecuteParams params) { + return !((ResourceManager) params.get(ExecuteParams.ExecuteParamType.RESOURCE_MANAGER)) + .isHealthy(); + } + + /** Execute physical plan in order: open, fetch result, close */ + private List doExecutePlan(PhysicalOperator op, ExecuteParams params) + throws Exception { + List hits = new ArrayList<>(); + op.open(params); + + while (op.hasNext()) { + hits.add(op.next().data()); } - /** - * Execute physical plan in order: open, fetch result, close - */ - private List doExecutePlan(PhysicalOperator op, - ExecuteParams params) throws Exception { - List hits = new ArrayList<>(); - op.open(params); - - while (op.hasNext()) { - hits.add(op.next().data()); - } - - if (LOG.isTraceEnabled()) { - hits.forEach(hit -> LOG.trace("Final result row: {}", hit.getSourceAsMap())); - } - return hits; + if (LOG.isTraceEnabled()) { + hits.forEach(hit -> LOG.trace("Final result row: {}", hit.getSourceAsMap())); } - + return hits; + } } diff --git a/legacy/src/main/java/org/opensearch/sql/legacy/query/planner/physical/Row.java b/legacy/src/main/java/org/opensearch/sql/legacy/query/planner/physical/Row.java index 9e7d81a194..5ed074da6d 100644 --- a/legacy/src/main/java/org/opensearch/sql/legacy/query/planner/physical/Row.java +++ b/legacy/src/main/java/org/opensearch/sql/legacy/query/planner/physical/Row.java @@ -3,7 +3,6 @@ * SPDX-License-Identifier: Apache-2.0 */ - package org.opensearch.sql.legacy.query.planner.physical; import java.util.Arrays; @@ -17,106 +16,93 @@ */ public interface Row { - Row NULL = null; - - /** - * Generate key to represent identity of the row. - * - * @param colNames column names as keys - * @return row key - */ - RowKey key(String[] colNames); - - - /** - * Combine current row and another row together to generate a new combined row. - * - * @param otherRow another row - * @return combined row - */ - Row combine(Row otherRow); - - - /** - * Retain columns specified and rename to alias if any. - * - * @param colNameAlias column names to alias mapping - */ - void retain(Map colNameAlias); - + Row NULL = null; + + /** + * Generate key to represent identity of the row. + * + * @param colNames column names as keys + * @return row key + */ + RowKey key(String[] colNames); + + /** + * Combine current row and another row together to generate a new combined row. + * + * @param otherRow another row + * @return combined row + */ + Row combine(Row otherRow); + + /** + * Retain columns specified and rename to alias if any. 
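+   * For example (illustrative, not part of this change), a mapping {"e.age" -> "a"} keeps only the
+   * column "e.age" and exposes it under the alias "a" in the resulting row.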
+ * + * @param colNameAlias column names to alias mapping + */ + void retain(Map colNameAlias); + + /** + * @return raw data of row wrapped inside + */ + T data(); + + /** Key that help Row be sorted or hashed. */ + class RowKey implements Comparable { + + /** Represent null key if any joined column value is NULL */ + public static final RowKey NULL = null; + + /** Values of row key */ + private final Object[] keys; + + /** Cached hash code since this class is intended to be used by hash table */ + private final int hashCode; + + public RowKey(Object... keys) { + this.keys = keys; + this.hashCode = Objects.hash(keys); + } - /** - * @return raw data of row wrapped inside - */ - T data(); + public Object[] keys() { + return keys; + } + @Override + public int hashCode() { + return hashCode; + } - /** - * Key that help Row be sorted or hashed. - */ - class RowKey implements Comparable { + @Override + public boolean equals(Object other) { + return other instanceof RowKey && Arrays.deepEquals(this.keys, ((RowKey) other).keys); + } - /** - * Represent null key if any joined column value is NULL - */ - public static final RowKey NULL = null; + @SuppressWarnings("unchecked") + @Override + public int compareTo(RowKey other) { + for (int i = 0; i < keys.length; i++) { - /** - * Values of row key + /* + * Only one is null, otherwise (both null or non-null) go ahead. + * Always consider NULL is smaller value which means NULL comes last in ASC and first in DESC */ - private final Object[] keys; - - /** - * Cached hash code since this class is intended to be used by hash table - */ - private final int hashCode; - - public RowKey(Object... keys) { - this.keys = keys; - this.hashCode = Objects.hash(keys); - } - - public Object[] keys() { - return keys; - } - - @Override - public int hashCode() { - return hashCode; + if (keys[i] == null ^ other.keys[i] == null) { + return keys[i] == null ? 1 : -1; } - @Override - public boolean equals(Object other) { - return other instanceof RowKey && Arrays.deepEquals(this.keys, ((RowKey) other).keys); - } - - @SuppressWarnings("unchecked") - @Override - public int compareTo(RowKey other) { - for (int i = 0; i < keys.length; i++) { - - /* - * Only one is null, otherwise (both null or non-null) go ahead. - * Always consider NULL is smaller value which means NULL comes last in ASC and first in DESC - */ - if (keys[i] == null ^ other.keys[i] == null) { - return keys[i] == null ? 1 : -1; - } - - if (keys[i] instanceof Comparable) { - int result = ((Comparable) keys[i]).compareTo(other.keys[i]); - if (result != 0) { - return result; - } - } // Ignore incomparable field silently? - } - return 0; - } - - @Override - public String toString() { - return "RowKey: " + Arrays.toString(keys); - } + if (keys[i] instanceof Comparable) { + int result = ((Comparable) keys[i]).compareTo(other.keys[i]); + if (result != 0) { + return result; + } + } // Ignore incomparable field silently? 
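+        // Worked example (editorial note): comparing keys ["a", null] with ["a", 25] finds the
+        // first position equal and the second differing in nullness, so compareTo returns 1 and
+        // the null-containing key sorts after the non-null one in ascending order.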
+ } + return 0; + } + @Override + public String toString() { + return "RowKey: " + Arrays.toString(keys); } + } } diff --git a/legacy/src/main/java/org/opensearch/sql/legacy/query/planner/physical/node/project/PhysicalProject.java b/legacy/src/main/java/org/opensearch/sql/legacy/query/planner/physical/node/project/PhysicalProject.java index 9c4bdc5c9e..e09ef5c3fe 100644 --- a/legacy/src/main/java/org/opensearch/sql/legacy/query/planner/physical/node/project/PhysicalProject.java +++ b/legacy/src/main/java/org/opensearch/sql/legacy/query/planner/physical/node/project/PhysicalProject.java @@ -3,7 +3,6 @@ * SPDX-License-Identifier: Apache-2.0 */ - package org.opensearch.sql.legacy.query.planner.physical.node.project; import java.util.List; @@ -16,34 +15,34 @@ import org.opensearch.sql.legacy.query.planner.physical.estimation.Cost; import org.opensearch.sql.legacy.query.planner.physical.node.scroll.BindingTupleRow; -/** - * The definition of Project Operator. - */ +/** The definition of Project Operator. */ @RequiredArgsConstructor public class PhysicalProject implements PhysicalOperator { - private final PhysicalOperator next; - private final List fields; - - @Override - public Cost estimate() { - return null; - } - - @Override - public PlanNode[] children() { - return new PlanNode[]{next}; - } - - @Override - public boolean hasNext() { - return next.hasNext(); - } - - @Override - public Row next() { - BindingTuple input = next.next().data(); - BindingTuple.BindingTupleBuilder outputBindingTupleBuilder = BindingTuple.builder(); - fields.forEach(field -> outputBindingTupleBuilder.binding(field.getName(), field.getExpr().valueOf(input))); - return new BindingTupleRow(outputBindingTupleBuilder.build()); - } + private final PhysicalOperator next; + private final List fields; + + @Override + public Cost estimate() { + return null; + } + + @Override + public PlanNode[] children() { + return new PlanNode[] {next}; + } + + @Override + public boolean hasNext() { + return next.hasNext(); + } + + @Override + public Row next() { + BindingTuple input = next.next().data(); + BindingTuple.BindingTupleBuilder outputBindingTupleBuilder = BindingTuple.builder(); + fields.forEach( + field -> + outputBindingTupleBuilder.binding(field.getName(), field.getExpr().valueOf(input))); + return new BindingTupleRow(outputBindingTupleBuilder.build()); + } } diff --git a/legacy/src/main/java/org/opensearch/sql/legacy/query/planner/physical/node/scroll/PhysicalScroll.java b/legacy/src/main/java/org/opensearch/sql/legacy/query/planner/physical/node/scroll/PhysicalScroll.java index 8866420218..16ad327a87 100644 --- a/legacy/src/main/java/org/opensearch/sql/legacy/query/planner/physical/node/scroll/PhysicalScroll.java +++ b/legacy/src/main/java/org/opensearch/sql/legacy/query/planner/physical/node/scroll/PhysicalScroll.java @@ -3,7 +3,6 @@ * SPDX-License-Identifier: Apache-2.0 */ - package org.opensearch.sql.legacy.query.planner.physical.node.scroll; import java.util.Iterator; @@ -21,54 +20,53 @@ import org.opensearch.sql.legacy.query.planner.physical.Row; import org.opensearch.sql.legacy.query.planner.physical.estimation.Cost; -/** - * The definition of Scroll Operator. - */ +/** The definition of Scroll Operator. 
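+ * It wraps a QueryAction and, per open() below, currently supports only aggregation query
+ * responses; any other QueryAction type results in an IllegalStateException.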
*/ @RequiredArgsConstructor public class PhysicalScroll implements PhysicalOperator { - private final QueryAction queryAction; + private final QueryAction queryAction; - private Iterator rowIterator; + private Iterator rowIterator; - @Override - public Cost estimate() { - return null; - } + @Override + public Cost estimate() { + return null; + } - @Override - public PlanNode[] children() { - return new PlanNode[0]; - } + @Override + public PlanNode[] children() { + return new PlanNode[0]; + } - @Override - public boolean hasNext() { - return rowIterator.hasNext(); - } + @Override + public boolean hasNext() { + return rowIterator.hasNext(); + } - @Override - public Row next() { - return rowIterator.next(); - } + @Override + public Row next() { + return rowIterator.next(); + } - @Override - public void open(ExecuteParams params) { - try { - ActionResponse response = queryAction.explain().get(); - if (queryAction instanceof AggregationQueryAction) { - rowIterator = SearchAggregationResponseHelper - .populateSearchAggregationResponse(((SearchResponse) response).getAggregations()) - .iterator(); - } else { - throw new IllegalStateException("Not support QueryAction type: " + queryAction.getClass()); - } - } catch (SqlParseException e) { - throw new RuntimeException(e); - } + @Override + public void open(ExecuteParams params) { + try { + ActionResponse response = queryAction.explain().get(); + if (queryAction instanceof AggregationQueryAction) { + rowIterator = + SearchAggregationResponseHelper.populateSearchAggregationResponse( + ((SearchResponse) response).getAggregations()) + .iterator(); + } else { + throw new IllegalStateException("Not support QueryAction type: " + queryAction.getClass()); + } + } catch (SqlParseException e) { + throw new RuntimeException(e); } + } - @SneakyThrows - @Override - public String toString() { - return queryAction.explain().toString(); - } + @SneakyThrows + @Override + public String toString() { + return queryAction.explain().toString(); + } } diff --git a/legacy/src/main/java/org/opensearch/sql/legacy/query/planner/physical/node/scroll/Scroll.java b/legacy/src/main/java/org/opensearch/sql/legacy/query/planner/physical/node/scroll/Scroll.java index 2d781d7c3d..40e9860886 100644 --- a/legacy/src/main/java/org/opensearch/sql/legacy/query/planner/physical/node/scroll/Scroll.java +++ b/legacy/src/main/java/org/opensearch/sql/legacy/query/planner/physical/node/scroll/Scroll.java @@ -3,7 +3,6 @@ * SPDX-License-Identifier: Apache-2.0 */ - package org.opensearch.sql.legacy.query.planner.physical.node.scroll; import java.util.Arrays; @@ -31,170 +30,160 @@ import org.opensearch.sql.legacy.query.planner.physical.node.BatchPhysicalOperator; import org.opensearch.sql.legacy.query.planner.resource.ResourceManager; -/** - * OpenSearch Scroll API as physical implementation of TableScan - */ +/** OpenSearch Scroll API as physical implementation of TableScan */ public class Scroll extends BatchPhysicalOperator { - /** - * Request to submit to OpenSearch to scroll over - */ - private final TableInJoinRequestBuilder request; - - /** - * Page size to scroll over index - */ - private final int pageSize; - - /** - * Client connection to ElasticSearch - */ - private Client client; - - /** - * Currently undergoing Scroll - */ - private SearchResponse scrollResponse; - - /** - * Time out - */ - private Integer timeout; - - /** - * Resource monitor manager - */ - private ResourceManager resourceMgr; - - - public Scroll(TableInJoinRequestBuilder request, int pageSize) { - this.request = 
request; - this.pageSize = pageSize; + /** Request to submit to OpenSearch to scroll over */ + private final TableInJoinRequestBuilder request; + + /** Page size to scroll over index */ + private final int pageSize; + + /** Client connection to ElasticSearch */ + private Client client; + + /** Currently undergoing Scroll */ + private SearchResponse scrollResponse; + + /** Time out */ + private Integer timeout; + + /** Resource monitor manager */ + private ResourceManager resourceMgr; + + public Scroll(TableInJoinRequestBuilder request, int pageSize) { + this.request = request; + this.pageSize = pageSize; + } + + @Override + public PlanNode[] children() { + return new PlanNode[0]; + } + + @Override + public Cost estimate() { + return new Cost(); + } + + @Override + public void open(ExecuteParams params) throws Exception { + super.open(params); + client = params.get(ExecuteParams.ExecuteParamType.CLIENT); + timeout = params.get(ExecuteParams.ExecuteParamType.TIMEOUT); + resourceMgr = params.get(ExecuteParams.ExecuteParamType.RESOURCE_MANAGER); + + Object filter = params.get(ExecuteParams.ExecuteParamType.EXTRA_QUERY_FILTER); + if (filter instanceof BoolQueryBuilder) { + request + .getRequestBuilder() + .setQuery(generateNewQueryWithExtraFilter((BoolQueryBuilder) filter)); + + if (LOG.isDebugEnabled()) { + LOG.debug( + "Received extra query filter, re-build query: {}", + Strings.toString( + XContentType.JSON, request.getRequestBuilder().request().source(), true, true)); + } } - - @Override - public PlanNode[] children() { - return new PlanNode[0]; + } + + @Override + public void close() { + if (scrollResponse != null) { + LOG.debug("Closing all scroll resources"); + ClearScrollResponse clearScrollResponse = + client.prepareClearScroll().addScrollId(scrollResponse.getScrollId()).get(); + if (!clearScrollResponse.isSucceeded()) { + LOG.warn("Failed to close scroll: {}", clearScrollResponse.status()); + } + scrollResponse = null; + } else { + LOG.debug("Scroll already be closed"); } - - @Override - public Cost estimate() { - return new Cost(); + } + + @Override + protected Collection> prefetch() { + Objects.requireNonNull(client, "Client connection is not ready"); + Objects.requireNonNull(resourceMgr, "ResourceManager is not set"); + Objects.requireNonNull(timeout, "Time out is not set"); + + if (scrollResponse == null) { + loadFirstBatch(); + updateMetaResult(); + } else { + loadNextBatchByScrollId(); } - - @Override - public void open(ExecuteParams params) throws Exception { - super.open(params); - client = params.get(ExecuteParams.ExecuteParamType.CLIENT); - timeout = params.get(ExecuteParams.ExecuteParamType.TIMEOUT); - resourceMgr = params.get(ExecuteParams.ExecuteParamType.RESOURCE_MANAGER); - - Object filter = params.get(ExecuteParams.ExecuteParamType.EXTRA_QUERY_FILTER); - if (filter instanceof BoolQueryBuilder) { - request.getRequestBuilder().setQuery( - generateNewQueryWithExtraFilter((BoolQueryBuilder) filter)); - - if (LOG.isDebugEnabled()) { - LOG.debug("Received extra query filter, re-build query: {}", Strings.toString(XContentType.JSON, - request.getRequestBuilder().request().source(), true, true - )); - } - } + return wrapRowForCurrentBatch(); + } + + /** + * Extra filter pushed down from upstream. Re-parse WHERE clause with extra filter because + * OpenSearch RequestBuilder doesn't allow QueryBuilder inside be changed after added. 
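+   * Concretely (summary of the method body, added editorially): when the original select has a
+   * WHERE clause it is re-parsed and the pushed-down filter is added to it as a must clause;
+   * otherwise the pushed-down filter itself becomes the new query.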
+ */ + private QueryBuilder generateNewQueryWithExtraFilter(BoolQueryBuilder filter) + throws SqlParseException { + Where where = request.getOriginalSelect().getWhere(); + BoolQueryBuilder newQuery; + if (where != null) { + newQuery = QueryMaker.explain(where, false); + newQuery.must(filter); + } else { + newQuery = filter; } - - @Override - public void close() { - if (scrollResponse != null) { - LOG.debug("Closing all scroll resources"); - ClearScrollResponse clearScrollResponse = client.prepareClearScroll(). - addScrollId(scrollResponse.getScrollId()). - get(); - if (!clearScrollResponse.isSucceeded()) { - LOG.warn("Failed to close scroll: {}", clearScrollResponse.status()); - } - scrollResponse = null; - } else { - LOG.debug("Scroll already be closed"); - } + return newQuery; + } + + private void loadFirstBatch() { + scrollResponse = + request + .getRequestBuilder() + .addSort(FieldSortBuilder.DOC_FIELD_NAME, SortOrder.ASC) + .setSize(pageSize) + .setScroll(TimeValue.timeValueSeconds(timeout)) + .get(); + } + + private void updateMetaResult() { + resourceMgr.getMetaResult().addTotalNumOfShards(scrollResponse.getTotalShards()); + resourceMgr.getMetaResult().addSuccessfulShards(scrollResponse.getSuccessfulShards()); + resourceMgr.getMetaResult().addFailedShards(scrollResponse.getFailedShards()); + resourceMgr.getMetaResult().updateTimeOut(scrollResponse.isTimedOut()); + } + + private void loadNextBatchByScrollId() { + scrollResponse = + client + .prepareSearchScroll(scrollResponse.getScrollId()) + .setScroll(TimeValue.timeValueSeconds(timeout)) + .get(); + } + + @SuppressWarnings("unchecked") + private Collection> wrapRowForCurrentBatch() { + SearchHit[] hits = scrollResponse.getHits().getHits(); + Row[] rows = new Row[hits.length]; + for (int i = 0; i < hits.length; i++) { + rows[i] = new SearchHitRow(hits[i], request.getAlias()); } + return Arrays.asList(rows); + } - @Override - protected Collection> prefetch() { - Objects.requireNonNull(client, "Client connection is not ready"); - Objects.requireNonNull(resourceMgr, "ResourceManager is not set"); - Objects.requireNonNull(timeout, "Time out is not set"); - - if (scrollResponse == null) { - loadFirstBatch(); - updateMetaResult(); - } else { - loadNextBatchByScrollId(); - } - return wrapRowForCurrentBatch(); - } + @Override + public String toString() { + return "Scroll [ " + describeTable() + ", pageSize=" + pageSize + " ]"; + } - /** - * Extra filter pushed down from upstream. Re-parse WHERE clause with extra filter - * because OpenSearch RequestBuilder doesn't allow QueryBuilder inside be changed after added. - */ - private QueryBuilder generateNewQueryWithExtraFilter(BoolQueryBuilder filter) throws SqlParseException { - Where where = request.getOriginalSelect().getWhere(); - BoolQueryBuilder newQuery; - if (where != null) { - newQuery = QueryMaker.explain(where, false); - newQuery.must(filter); - } else { - newQuery = filter; - } - return newQuery; - } + private String describeTable() { + return request.getOriginalSelect().getFrom().get(0).getIndex() + " as " + request.getAlias(); + } - private void loadFirstBatch() { - scrollResponse = request.getRequestBuilder(). - addSort(FieldSortBuilder.DOC_FIELD_NAME, SortOrder.ASC). - setSize(pageSize). - setScroll(TimeValue.timeValueSeconds(timeout)). 
- get(); - } - - private void updateMetaResult() { - resourceMgr.getMetaResult().addTotalNumOfShards(scrollResponse.getTotalShards()); - resourceMgr.getMetaResult().addSuccessfulShards(scrollResponse.getSuccessfulShards()); - resourceMgr.getMetaResult().addFailedShards(scrollResponse.getFailedShards()); - resourceMgr.getMetaResult().updateTimeOut(scrollResponse.isTimedOut()); - } - - private void loadNextBatchByScrollId() { - scrollResponse = client.prepareSearchScroll(scrollResponse.getScrollId()). - setScroll(TimeValue.timeValueSeconds(timeout)). - get(); - } - - @SuppressWarnings("unchecked") - private Collection> wrapRowForCurrentBatch() { - SearchHit[] hits = scrollResponse.getHits().getHits(); - Row[] rows = new Row[hits.length]; - for (int i = 0; i < hits.length; i++) { - rows[i] = new SearchHitRow(hits[i], request.getAlias()); - } - return Arrays.asList(rows); - } + /********************************************* + * Getters for Explain + *********************************************/ - @Override - public String toString() { - return "Scroll [ " + describeTable() + ", pageSize=" + pageSize + " ]"; - } - - private String describeTable() { - return request.getOriginalSelect().getFrom().get(0).getIndex() + " as " + request.getAlias(); - } - - - /********************************************* - * Getters for Explain - *********************************************/ - - public String getRequest() { - return Strings.toString(XContentType.JSON, request.getRequestBuilder().request().source()); - } + public String getRequest() { + return Strings.toString(XContentType.JSON, request.getRequestBuilder().request().source()); + } } diff --git a/legacy/src/main/java/org/opensearch/sql/legacy/query/planner/physical/node/scroll/SearchAggregationResponseHelper.java b/legacy/src/main/java/org/opensearch/sql/legacy/query/planner/physical/node/scroll/SearchAggregationResponseHelper.java index 5e0ce1f2b4..ed0e0f2423 100644 --- a/legacy/src/main/java/org/opensearch/sql/legacy/query/planner/physical/node/scroll/SearchAggregationResponseHelper.java +++ b/legacy/src/main/java/org/opensearch/sql/legacy/query/planner/physical/node/scroll/SearchAggregationResponseHelper.java @@ -3,7 +3,6 @@ * SPDX-License-Identifier: Apache-2.0 */ - package org.opensearch.sql.legacy.query.planner.physical.node.scroll; import com.google.common.annotations.VisibleForTesting; @@ -22,70 +21,82 @@ import org.opensearch.search.aggregations.metrics.Percentiles; import org.opensearch.sql.legacy.expression.domain.BindingTuple; -/** - * The definition of Search {@link Aggregations} parser helper class. - */ +/** The definition of Search {@link Aggregations} parser helper class. 
*/ public class SearchAggregationResponseHelper { - public static List populateSearchAggregationResponse(Aggregations aggs) { - List> flatten = flatten(aggs); - List bindingTupleList = flatten.stream() - .map(BindingTuple::from) - .map(bindingTuple -> new BindingTupleRow(bindingTuple)) - .collect(Collectors.toList()); - return bindingTupleList; - } + public static List populateSearchAggregationResponse(Aggregations aggs) { + List> flatten = flatten(aggs); + List bindingTupleList = + flatten.stream() + .map(BindingTuple::from) + .map(bindingTuple -> new BindingTupleRow(bindingTuple)) + .collect(Collectors.toList()); + return bindingTupleList; + } - @VisibleForTesting - public static List> flatten(Aggregations aggregations) { - List aggregationList = aggregations.asList(); - List> resultList = new ArrayList<>(); - Map resultMap = new HashMap<>(); - for (Aggregation aggregation : aggregationList) { - if (aggregation instanceof Terms) { - for (Terms.Bucket bucket : ((Terms) aggregation).getBuckets()) { - List> internalBucketList = flatten(bucket.getAggregations()); - fillResultListWithInternalBucket(resultList, internalBucketList, aggregation.getName(), - bucket.getKey()); - } - } else if (aggregation instanceof NumericMetricsAggregation.SingleValue) { - resultMap.put(aggregation.getName(), ((NumericMetricsAggregation.SingleValue) aggregation).value()); - } else if (aggregation instanceof Percentiles) { - Percentiles percentiles = (Percentiles) aggregation; - resultMap.putAll((Map) StreamSupport.stream(percentiles.spliterator(), false) - .collect(Collectors.toMap( - (percentile) -> String.format("%s_%s", percentiles.getName(), percentile.getPercent()), - Percentile::getValue, (v1, v2) -> { - throw new IllegalArgumentException( - String.format("Duplicate key for values %s and %s", v1, v2)); - }, HashMap::new))); - } else if (aggregation instanceof Histogram) { - for (Histogram.Bucket bucket : ((Histogram) aggregation).getBuckets()) { - List> internalBucketList = flatten(bucket.getAggregations()); - fillResultListWithInternalBucket(resultList, internalBucketList, aggregation.getName(), - bucket.getKeyAsString()); - } - } else { - throw new RuntimeException("unsupported aggregation type " + aggregation.getType()); - } + @VisibleForTesting + public static List> flatten(Aggregations aggregations) { + List aggregationList = aggregations.asList(); + List> resultList = new ArrayList<>(); + Map resultMap = new HashMap<>(); + for (Aggregation aggregation : aggregationList) { + if (aggregation instanceof Terms) { + for (Terms.Bucket bucket : ((Terms) aggregation).getBuckets()) { + List> internalBucketList = flatten(bucket.getAggregations()); + fillResultListWithInternalBucket( + resultList, internalBucketList, aggregation.getName(), bucket.getKey()); } - if (!resultMap.isEmpty()) { - resultList.add(resultMap); + } else if (aggregation instanceof NumericMetricsAggregation.SingleValue) { + resultMap.put( + aggregation.getName(), ((NumericMetricsAggregation.SingleValue) aggregation).value()); + } else if (aggregation instanceof Percentiles) { + Percentiles percentiles = (Percentiles) aggregation; + resultMap.putAll( + (Map) + StreamSupport.stream(percentiles.spliterator(), false) + .collect( + Collectors.toMap( + (percentile) -> + String.format( + "%s_%s", percentiles.getName(), percentile.getPercent()), + Percentile::getValue, + (v1, v2) -> { + throw new IllegalArgumentException( + String.format("Duplicate key for values %s and %s", v1, v2)); + }, + HashMap::new))); + } else if (aggregation instanceof 
Histogram) { + for (Histogram.Bucket bucket : ((Histogram) aggregation).getBuckets()) { + List> internalBucketList = flatten(bucket.getAggregations()); + fillResultListWithInternalBucket( + resultList, internalBucketList, aggregation.getName(), bucket.getKeyAsString()); } - return resultList; + } else { + throw new RuntimeException("unsupported aggregation type " + aggregation.getType()); + } + } + if (!resultMap.isEmpty()) { + resultList.add(resultMap); } + return resultList; + } - private static void fillResultListWithInternalBucket(List> resultList, - List> internalBucketList, - String aggregationName, Object bucketKey) { - if (internalBucketList.isEmpty()) { - resultList.add(new HashMap() {{ - put(aggregationName, bucketKey); - }}); - } else { - for (Map map : internalBucketList) { - map.put(aggregationName, bucketKey); + private static void fillResultListWithInternalBucket( + List> resultList, + List> internalBucketList, + String aggregationName, + Object bucketKey) { + if (internalBucketList.isEmpty()) { + resultList.add( + new HashMap() { + { + put(aggregationName, bucketKey); } - resultList.addAll(internalBucketList); - } + }); + } else { + for (Map map : internalBucketList) { + map.put(aggregationName, bucketKey); + } + resultList.addAll(internalBucketList); } + } } diff --git a/legacy/src/main/java/org/opensearch/sql/legacy/query/planner/physical/node/scroll/SearchHitRow.java b/legacy/src/main/java/org/opensearch/sql/legacy/query/planner/physical/node/scroll/SearchHitRow.java index 27e3072bab..1750563e47 100644 --- a/legacy/src/main/java/org/opensearch/sql/legacy/query/planner/physical/node/scroll/SearchHitRow.java +++ b/legacy/src/main/java/org/opensearch/sql/legacy/query/planner/physical/node/scroll/SearchHitRow.java @@ -3,7 +3,6 @@ * SPDX-License-Identifier: Apache-2.0 */ - package org.opensearch.sql.legacy.query.planner.physical.node.scroll; import com.google.common.base.Strings; @@ -15,6 +14,7 @@ import org.opensearch.sql.legacy.query.planner.physical.Row; /** + *
  * Search hit row that implements basic accessor for SearchHit.
  * Encapsulate all OpenSearch specific knowledge: how to parse source including nested path.
  * 

@@ -32,164 +32,156 @@
  * ----------------------------------------------------------------------------------------------------------------------
  * retain() in Project | {"firstName": "Allen", "age": 30 } | "" | retain("e.name.first", "e.age")
  * ----------------------------------------------------------------------------------------------------------------------
+ *

*/ class SearchHitRow implements Row { - /** - * Native OpenSearch data object for each row - */ - private final SearchHit hit; - - /** - * Column and value pairs - */ - private final Map source; - - /** - * Table alias owned the row. Empty if this row comes from combination of two other rows - */ - private final String tableAlias; + /** Native OpenSearch data object for each row */ + private final SearchHit hit; - SearchHitRow(SearchHit hit, String tableAlias) { - this.hit = hit; - this.source = hit.getSourceAsMap(); - this.tableAlias = tableAlias; - } + /** Column and value pairs */ + private final Map source; - @Override - public RowKey key(String[] colNames) { - if (colNames.length == 0) { - return RowKey.NULL; - } - - Object[] keys = new Object[colNames.length]; - for (int i = 0; i < colNames.length; i++) { - keys[i] = getValueOfPath(colNames[i]); - - if (keys[i] == null) { - return RowKey.NULL; - } - } - return new RowKey(keys); - } + /** Table alias owned the row. Empty if this row comes from combination of two other rows */ + private final String tableAlias; - /** - * Replace column name by full name to avoid naming conflicts. - * For efficiency, this only happens here when matched rows found. - * Create a new one to avoid mutating the original ones in hash table which impact subsequent match. - */ - @Override - public Row combine(Row other) { - SearchHit combined = cloneHit(other); - - collectFullName(combined.getSourceAsMap(), this); - if (other != NULL) { - collectFullName(combined.getSourceAsMap(), (SearchHitRow) other); - } - return new SearchHitRow(combined, ""); - } + SearchHitRow(SearchHit hit, String tableAlias) { + this.hit = hit; + this.source = hit.getSourceAsMap(); + this.tableAlias = tableAlias; + } - @Override - public void retain(Map colNameAlias) { - Map aliasSource = new HashMap<>(); - colNameAlias.forEach((colName, alias) -> { - if (colName.endsWith(".*")) { - String tableAlias = colName.substring(0, colName.length() - 2) + "."; - retainAllFieldsFromTable(aliasSource, tableAlias); - } else { - retainOneField(aliasSource, colName, alias); - } - }); - resetSource(aliasSource); + @Override + public RowKey key(String[] colNames) { + if (colNames.length == 0) { + return RowKey.NULL; } - @Override - public SearchHit data() { - return hit; - } + Object[] keys = new Object[colNames.length]; + for (int i = 0; i < colNames.length; i++) { + keys[i] = getValueOfPath(colNames[i]); - @Override - public String toString() { - return "SearchHitRow{" + "hit=" + source + '}'; + if (keys[i] == null) { + return RowKey.NULL; + } } - - private Object getValueOfPath(String path) { - /* - * If table alias is missing which means the row was generated by combine(). - * In this case, table alias is present and the first dot should be ignored, ex. "e.name.first" - */ - return getValueOfPath(source, path, Strings.isNullOrEmpty(tableAlias)); + return new RowKey(keys); + } + + /** + * Replace column name by full name to avoid naming conflicts. For efficiency, this only happens + * here when matched rows found. Create a new one to avoid mutating the original ones in hash + * table which impact subsequent match. 
+ */ + @Override + public Row combine(Row other) { + SearchHit combined = cloneHit(other); + + collectFullName(combined.getSourceAsMap(), this); + if (other != NULL) { + collectFullName(combined.getSourceAsMap(), (SearchHitRow) other); } - - /** - * Recursively get value for field name path, such as object field a.b.c + return new SearchHitRow(combined, ""); + } + + @Override + public void retain(Map colNameAlias) { + Map aliasSource = new HashMap<>(); + colNameAlias.forEach( + (colName, alias) -> { + if (colName.endsWith(".*")) { + String tableAlias = colName.substring(0, colName.length() - 2) + "."; + retainAllFieldsFromTable(aliasSource, tableAlias); + } else { + retainOneField(aliasSource, colName, alias); + } + }); + resetSource(aliasSource); + } + + @Override + public SearchHit data() { + return hit; + } + + @Override + public String toString() { + return "SearchHitRow{" + "hit=" + source + '}'; + } + + private Object getValueOfPath(String path) { + /* + * If table alias is missing which means the row was generated by combine(). + * In this case, table alias is present and the first dot should be ignored, ex. "e.name.first" */ - private Object getValueOfPath(Object source, String path, boolean isIgnoreFirstDot) { - if (!(source instanceof Map) || path.isEmpty()) { - return source; - } - - int dot = path.indexOf('.', (isIgnoreFirstDot ? path.indexOf('.') + 1 : 0)); - if (dot == -1) { - return ((Map) source).get(path); - } - - // Object field name maybe unexpanded without recursive object structure - // ex. {"a.b.c": value} instead of {"a": {"b": {"c": value}}}} - if (((Map) source).containsKey(path)) { - return ((Map) source).get(path); - } - - return getValueOfPath( - ((Map) source).get(path.substring(0, dot)), - path.substring(dot + 1), - false - ); - } + return getValueOfPath(source, path, Strings.isNullOrEmpty(tableAlias)); + } - private SearchHit cloneHit(Row other) { - Map documentFields = new HashMap<>(); - Map metaFields = new HashMap<>(); - hit.getFields().forEach((fieldName, docField) -> - (MapperService.META_FIELDS_BEFORE_7DOT8.contains(fieldName) ? metaFields : documentFields).put(fieldName, docField)); - SearchHit combined = new SearchHit( - hit.docId(), - hit.getId() + "|" + (other == NULL ? "0" : ((SearchHitRow) other).hit.getId()), - documentFields, - metaFields - ); - combined.sourceRef(hit.getSourceRef()); - combined.getSourceAsMap().clear(); - return combined; + /** Recursively get value for field name path, such as object field a.b.c */ + private Object getValueOfPath(Object source, String path, boolean isIgnoreFirstDot) { + if (!(source instanceof Map) || path.isEmpty()) { + return source; } - private void collectFullName(Map newSource, SearchHitRow row) { - row.source.forEach((colName, value) -> newSource.put(row.tableAlias + "." + colName, value)); + int dot = path.indexOf('.', (isIgnoreFirstDot ? path.indexOf('.') + 1 : 0)); + if (dot == -1) { + return ((Map) source).get(path); } - private void retainAllFieldsFromTable(Map aliasSource, String tableAlias) { - source.entrySet(). - stream(). - filter(e -> e.getKey().startsWith(tableAlias)). - forEach(e -> aliasSource.put(e.getKey(), e.getValue())); + // Object field name maybe unexpanded without recursive object structure + // ex. {"a.b.c": value} instead of {"a": {"b": {"c": value}}}} + if (((Map) source).containsKey(path)) { + return ((Map) source).get(path); } - /** - * Note that column here is already prefixed by table alias after combine(). - *

- * Meanwhile check if column name with table alias prefix, ex. a.name, is property, namely a.name.lastname. - * In this case, split by first second dot and continue searching for the final value in nested map - * by getValueOfPath(source.get("a.name"), "lastname") - */ - private void retainOneField(Map aliasSource, String colName, String alias) { - aliasSource.put( - Strings.isNullOrEmpty(alias) ? colName : alias, - getValueOfPath(colName) - ); - } - - private void resetSource(Map newSource) { - source.clear(); - source.putAll(newSource); - } + return getValueOfPath( + ((Map) source).get(path.substring(0, dot)), path.substring(dot + 1), false); + } + + private SearchHit cloneHit(Row other) { + Map documentFields = new HashMap<>(); + Map metaFields = new HashMap<>(); + hit.getFields() + .forEach( + (fieldName, docField) -> + (MapperService.META_FIELDS_BEFORE_7DOT8.contains(fieldName) + ? metaFields + : documentFields) + .put(fieldName, docField)); + SearchHit combined = + new SearchHit( + hit.docId(), + hit.getId() + "|" + (other == NULL ? "0" : ((SearchHitRow) other).hit.getId()), + documentFields, + metaFields); + combined.sourceRef(hit.getSourceRef()); + combined.getSourceAsMap().clear(); + return combined; + } + + private void collectFullName(Map newSource, SearchHitRow row) { + row.source.forEach((colName, value) -> newSource.put(row.tableAlias + "." + colName, value)); + } + + private void retainAllFieldsFromTable(Map aliasSource, String tableAlias) { + source.entrySet().stream() + .filter(e -> e.getKey().startsWith(tableAlias)) + .forEach(e -> aliasSource.put(e.getKey(), e.getValue())); + } + + /** + * Note that column here is already prefixed by table alias after combine(). + * + *

Meanwhile check if column name with table alias prefix, ex. a.name, is property, namely + * a.name.lastname. In this case, split by first second dot and continue searching for the final + * value in nested map by getValueOfPath(source.get("a.name"), "lastname") + */ + private void retainOneField(Map aliasSource, String colName, String alias) { + aliasSource.put(Strings.isNullOrEmpty(alias) ? colName : alias, getValueOfPath(colName)); + } + + private void resetSource(Map newSource) { + source.clear(); + source.putAll(newSource); + } } diff --git a/legacy/src/main/java/org/opensearch/sql/legacy/query/planner/physical/node/sort/QuickSort.java b/legacy/src/main/java/org/opensearch/sql/legacy/query/planner/physical/node/sort/QuickSort.java index 90ae595d56..abfcf273ad 100644 --- a/legacy/src/main/java/org/opensearch/sql/legacy/query/planner/physical/node/sort/QuickSort.java +++ b/legacy/src/main/java/org/opensearch/sql/legacy/query/planner/physical/node/sort/QuickSort.java @@ -3,7 +3,6 @@ * SPDX-License-Identifier: Apache-2.0 */ - package org.opensearch.sql.legacy.query.planner.physical.node.sort; import static java.util.Collections.emptyList; @@ -23,83 +22,80 @@ import org.opensearch.sql.legacy.query.planner.physical.node.BatchPhysicalOperator; /** - * Physical operator to sort by quick sort implementation in JDK. - * Note that this is all in-memory operator which may be a problem for large index. + * Physical operator to sort by quick sort implementation in JDK. Note that this is all in-memory + * operator which may be a problem for large index. * * @param actual data type, ex.SearchHit */ public class QuickSort extends BatchPhysicalOperator { - private static final Logger LOG = LogManager.getLogger(); + private static final Logger LOG = LogManager.getLogger(); - private final PhysicalOperator next; + private final PhysicalOperator next; - /** - * Column name list in ORDER BY - */ - private final String[] orderByColNames; + /** Column name list in ORDER BY */ + private final String[] orderByColNames; - /** - * Order by type, ex. ASC, DESC - */ - private final String orderByType; + /** Order by type, ex. 
ASC, DESC */ + private final String orderByType; - private boolean isDone = false; + private boolean isDone = false; - public QuickSort(PhysicalOperator next, List orderByColNames, String orderByType) { - this.next = next; - this.orderByColNames = orderByColNames.toArray(new String[0]); - this.orderByType = orderByType; - } + public QuickSort(PhysicalOperator next, List orderByColNames, String orderByType) { + this.next = next; + this.orderByColNames = orderByColNames.toArray(new String[0]); + this.orderByType = orderByType; + } - @Override - public PlanNode[] children() { - return new PlanNode[]{next}; - } + @Override + public PlanNode[] children() { + return new PlanNode[] {next}; + } - @Override - public Cost estimate() { - return new Cost(); - } + @Override + public Cost estimate() { + return new Cost(); + } - @Override - public void open(ExecuteParams params) throws Exception { - super.open(params); - next.open(params); - } + @Override + public void open(ExecuteParams params) throws Exception { + super.open(params); + next.open(params); + } - /** - * Only load all data once and return one batch - */ - @Override - protected Collection> prefetch() { - if (isDone) { - return emptyList(); - } - - List> allRowsSorted = new ArrayList<>(); - next.forEachRemaining(allRowsSorted::add); - allRowsSorted.sort(createRowComparator()); - - if (LOG.isTraceEnabled()) { - LOG.trace("All rows being sorted in RB-Tree: {}", allRowsSorted); - } - - isDone = true; - return allRowsSorted; + /** Only load all data once and return one batch */ + @Override + protected Collection> prefetch() { + if (isDone) { + return emptyList(); } - private Comparator> createRowComparator() { - Comparator> comparator = Comparator.comparing(o -> o.key(orderByColNames)); - if ("DESC".equals(orderByType)) { - comparator = comparator.reversed(); - } - return comparator; - } + List> allRowsSorted = new ArrayList<>(); + next.forEachRemaining(allRowsSorted::add); + allRowsSorted.sort(createRowComparator()); - @Override - public String toString() { - return "QuickSort [ columns=" + Arrays.toString(orderByColNames) + ", order=" + orderByType + " ]"; + if (LOG.isTraceEnabled()) { + LOG.trace("All rows being sorted in RB-Tree: {}", allRowsSorted); } + isDone = true; + return allRowsSorted; + } + + private Comparator> createRowComparator() { + Comparator> comparator = Comparator.comparing(o -> o.key(orderByColNames)); + if ("DESC".equals(orderByType)) { + comparator = comparator.reversed(); + } + return comparator; + } + + @Override + public String toString() { + return "QuickSort [ columns=" + + Arrays.toString(orderByColNames) + + ", order=" + + orderByType + + " ]"; + } } diff --git a/legacy/src/main/java/org/opensearch/sql/legacy/query/planner/resource/ResourceManager.java b/legacy/src/main/java/org/opensearch/sql/legacy/query/planner/resource/ResourceManager.java index 32cc7f45e3..4818d0a3ee 100644 --- a/legacy/src/main/java/org/opensearch/sql/legacy/query/planner/resource/ResourceManager.java +++ b/legacy/src/main/java/org/opensearch/sql/legacy/query/planner/resource/ResourceManager.java @@ -3,7 +3,6 @@ * SPDX-License-Identifier: Apache-2.0 */ - package org.opensearch.sql.legacy.query.planner.resource; import java.time.Duration; @@ -18,55 +17,48 @@ import org.opensearch.sql.legacy.query.planner.resource.monitor.Monitor; import org.opensearch.sql.legacy.query.planner.resource.monitor.TotalMemoryMonitor; -/** - * Aggregated resource monitor - */ +/** Aggregated resource monitor */ public class ResourceManager { - private static final 
Logger LOG = LogManager.getLogger(); + private static final Logger LOG = LogManager.getLogger(); + + /** Actual resource monitor list */ + private final List monitors = new ArrayList<>(); - /** - * Actual resource monitor list - */ - private final List monitors = new ArrayList<>(); + /** Time out for the execution */ + private final int timeout; - /** - * Time out for the execution - */ - private final int timeout; - private final Instant startTime; + private final Instant startTime; - /** - * Meta result of the execution - */ - private final MetaSearchResult metaResult; + /** Meta result of the execution */ + private final MetaSearchResult metaResult; - public ResourceManager(Stats stats, Config config) { - this.monitors.add(new TotalMemoryMonitor(stats, config)); - this.timeout = config.timeout(); - this.startTime = Instant.now(); - this.metaResult = new MetaSearchResult(); - } + public ResourceManager(Stats stats, Config config) { + this.monitors.add(new TotalMemoryMonitor(stats, config)); + this.timeout = config.timeout(); + this.startTime = Instant.now(); + this.metaResult = new MetaSearchResult(); + } - /** - * Is all resource monitor healthy with strategy. - * - * @return true for yes - */ - public boolean isHealthy() { - return BackOffRetryStrategy.isHealthy(); - } + /** + * Is all resource monitor healthy with strategy. + * + * @return true for yes + */ + public boolean isHealthy() { + return BackOffRetryStrategy.isHealthy(); + } - /** - * Is current execution time out? - * - * @return true for yes - */ - public boolean isTimeout() { - return Duration.between(startTime, Instant.now()).getSeconds() >= timeout; - } + /** + * Is current execution time out? + * + * @return true for yes + */ + public boolean isTimeout() { + return Duration.between(startTime, Instant.now()).getSeconds() >= timeout; + } - public MetaSearchResult getMetaResult() { - return metaResult; - } + public MetaSearchResult getMetaResult() { + return metaResult; + } } diff --git a/legacy/src/main/java/org/opensearch/sql/legacy/request/PreparedStatementRequest.java b/legacy/src/main/java/org/opensearch/sql/legacy/request/PreparedStatementRequest.java index deff4e2393..c32e529157 100644 --- a/legacy/src/main/java/org/opensearch/sql/legacy/request/PreparedStatementRequest.java +++ b/legacy/src/main/java/org/opensearch/sql/legacy/request/PreparedStatementRequest.java @@ -3,7 +3,6 @@ * SPDX-License-Identifier: Apache-2.0 */ - package org.opensearch.sql.legacy.request; import java.util.List; @@ -11,174 +10,181 @@ public class PreparedStatementRequest extends SqlRequest { - private List parameters; - private String sqlTemplate; - - public PreparedStatementRequest(String sql, JSONObject payloadJson, List parameters) { - super(null, payloadJson); - this.sqlTemplate = sql; - this.parameters = parameters; - this.sql = this.substituteParameters(); - } - - public PreparedStatementRequest(String sql, final Integer fetchSize, - JSONObject payloadJson, List parameters) { - this(sql, payloadJson, parameters); - this.fetchSize = fetchSize; + private List parameters; + private String sqlTemplate; + + public PreparedStatementRequest( + String sql, JSONObject payloadJson, List parameters) { + super(null, payloadJson); + this.sqlTemplate = sql; + this.parameters = parameters; + this.sql = this.substituteParameters(); + } + + public PreparedStatementRequest( + String sql, + final Integer fetchSize, + JSONObject payloadJson, + List parameters) { + this(sql, payloadJson, parameters); + this.fetchSize = fetchSize; + } + + public List 
getParameters() { + return this.parameters; + } + + @Override + public String getSql() { + return this.sql; + } + + public String getPreparedStatement() { + return this.sqlTemplate; + } + + private String substituteParameters() { + if (this.sqlTemplate == null) { + return null; } - public List getParameters() { - return this.parameters; - } - - @Override - public String getSql() { - return this.sql; - } - - public String getPreparedStatement() { - return this.sqlTemplate; - } - - private String substituteParameters() { - if (this.sqlTemplate == null) { - return null; - } - - StringBuilder sb = new StringBuilder(); - int paramIndex = 0; - int i = 0; + StringBuilder sb = new StringBuilder(); + int paramIndex = 0; + int i = 0; + while (i < this.sqlTemplate.length()) { + char c = this.sqlTemplate.charAt(i); + if (c == '\'') { + // found string starting quote character, skip the string + sb.append(c); + i++; while (i < this.sqlTemplate.length()) { - char c = this.sqlTemplate.charAt(i); - if (c == '\'') { - // found string starting quote character, skip the string - sb.append(c); - i++; - while (i < this.sqlTemplate.length()) { - char s = this.sqlTemplate.charAt(i); - sb.append(s); - if (s == '\'') { - if (this.sqlTemplate.charAt(i - 1) == '\\') { - // this is an escaped single quote (\') still in the string - i++; - } else if ((i + 1) < this.sqlTemplate.length() && this.sqlTemplate.charAt(i + 1) == '\'') { - // found 2 single quote {''} in a string, which is escaped single quote {'} - // move to next character - sb.append('\''); - i += 2; - } else { - // found the string ending single quote char - break; - } - } else { - // not single quote character, move on - i++; - } - } - } else if (c == '?') { - // question mark "?" not in a string - if (paramIndex >= this.parameters.size()) { - throw new IllegalStateException("Placeholder count is greater than parameter number " - + parameters.size() + " . Cannot convert PreparedStatement to sql query"); - } - sb.append(this.parameters.get(paramIndex).getSqlSubstitutionValue()); - paramIndex++; + char s = this.sqlTemplate.charAt(i); + sb.append(s); + if (s == '\'') { + if (this.sqlTemplate.charAt(i - 1) == '\\') { + // this is an escaped single quote (\') still in the string + i++; + } else if ((i + 1) < this.sqlTemplate.length() + && this.sqlTemplate.charAt(i + 1) == '\'') { + // found 2 single quote {''} in a string, which is escaped single quote {'} + // move to next character + sb.append('\''); + i += 2; } else { - // other character, simply append - sb.append(c); + // found the string ending single quote char + break; } + } else { + // not single quote character, move on i++; + } } - - return sb.toString(); + } else if (c == '?') { + // question mark "?" not in a string + if (paramIndex >= this.parameters.size()) { + throw new IllegalStateException( + "Placeholder count is greater than parameter number " + + parameters.size() + + " . 
Cannot convert PreparedStatement to sql query"); + } + sb.append(this.parameters.get(paramIndex).getSqlSubstitutionValue()); + paramIndex++; + } else { + // other character, simply append + sb.append(c); + } + i++; } - ////////////////////////////////////////////////// - // Parameter related types below - ////////////////////////////////////////////////// - public enum ParameterType { - BYTE, - SHORT, - INTEGER, - LONG, - FLOAT, - DOUBLE, - BOOLEAN, - STRING, - KEYWORD, - DATE, - NULL + return sb.toString(); + } + + ////////////////////////////////////////////////// + // Parameter related types below + ////////////////////////////////////////////////// + public enum ParameterType { + BYTE, + SHORT, + INTEGER, + LONG, + FLOAT, + DOUBLE, + BOOLEAN, + STRING, + KEYWORD, + DATE, + NULL + } + + public static class PreparedStatementParameter { + protected T value; + + public PreparedStatementParameter(T value) { + this.value = value; } - public static class PreparedStatementParameter { - protected T value; - - public PreparedStatementParameter(T value) { - this.value = value; - } - - public String getSqlSubstitutionValue() { - return String.valueOf(this.value); - } + public String getSqlSubstitutionValue() { + return String.valueOf(this.value); + } - public T getValue() { - return this.value; - } + public T getValue() { + return this.value; } + } - public static class StringParameter extends PreparedStatementParameter { + public static class StringParameter extends PreparedStatementParameter { - public StringParameter(String value) { - super(value); - } + public StringParameter(String value) { + super(value); + } - @Override - public String getSqlSubstitutionValue() { - // TODO: investigate other injection prevention - if (this.value == null) { - return "null"; - } - StringBuilder sb = new StringBuilder(); - sb.append('\''); // starting quote - for (int i = 0; i < this.value.length(); i++) { - char c = this.value.charAt(i); - switch (c) { - case 0: - sb.append('\\').append(0); - break; - case '\n': - sb.append('\\').append('n'); - break; - case '\r': - sb.append('\\').append('r'); - break; - case '\\': - sb.append('\\').append('\\'); - break; - case '\'': - sb.append('\\').append('\''); - break; - case '\"': - sb.append('\\').append('\"'); - break; - default: - sb.append(c); - } - } - sb.append('\''); // ending quote - return sb.toString(); + @Override + public String getSqlSubstitutionValue() { + // TODO: investigate other injection prevention + if (this.value == null) { + return "null"; + } + StringBuilder sb = new StringBuilder(); + sb.append('\''); // starting quote + for (int i = 0; i < this.value.length(); i++) { + char c = this.value.charAt(i); + switch (c) { + case 0: + sb.append('\\').append(0); + break; + case '\n': + sb.append('\\').append('n'); + break; + case '\r': + sb.append('\\').append('r'); + break; + case '\\': + sb.append('\\').append('\\'); + break; + case '\'': + sb.append('\\').append('\''); + break; + case '\"': + sb.append('\\').append('\"'); + break; + default: + sb.append(c); } + } + sb.append('\''); // ending quote + return sb.toString(); } + } - public static class NullParameter extends PreparedStatementParameter { + public static class NullParameter extends PreparedStatementParameter { - public NullParameter() { - super(null); - } + public NullParameter() { + super(null); + } - @Override - public String getSqlSubstitutionValue() { - return "null"; - } + @Override + public String getSqlSubstitutionValue() { + return "null"; } + } } diff --git 
a/legacy/src/main/java/org/opensearch/sql/legacy/rewriter/RewriteRule.java b/legacy/src/main/java/org/opensearch/sql/legacy/rewriter/RewriteRule.java index 6744bfa3e5..cd6400ed88 100644 --- a/legacy/src/main/java/org/opensearch/sql/legacy/rewriter/RewriteRule.java +++ b/legacy/src/main/java/org/opensearch/sql/legacy/rewriter/RewriteRule.java @@ -3,29 +3,26 @@ * SPDX-License-Identifier: Apache-2.0 */ - package org.opensearch.sql.legacy.rewriter; import com.alibaba.druid.sql.ast.expr.SQLQueryExpr; import java.sql.SQLFeatureNotSupportedException; -/** - * Query Optimize Rule - */ +/** Query Optimize Rule */ public interface RewriteRule { - /** - * Checking whether the rule match the query? - * - * @return true if the rule match to the query. - * @throws SQLFeatureNotSupportedException - */ - boolean match(T expr) throws SQLFeatureNotSupportedException; + /** + * Checking whether the rule match the query? + * + * @return true if the rule match to the query. + * @throws SQLFeatureNotSupportedException + */ + boolean match(T expr) throws SQLFeatureNotSupportedException; - /** - * Optimize the query. - * - * @throws SQLFeatureNotSupportedException - */ - void rewrite(T expr) throws SQLFeatureNotSupportedException; + /** + * Optimize the query. + * + * @throws SQLFeatureNotSupportedException + */ + void rewrite(T expr) throws SQLFeatureNotSupportedException; } diff --git a/legacy/src/main/java/org/opensearch/sql/legacy/rewriter/RewriteRuleExecutor.java b/legacy/src/main/java/org/opensearch/sql/legacy/rewriter/RewriteRuleExecutor.java index 86aa3d0b20..20fd018ae8 100644 --- a/legacy/src/main/java/org/opensearch/sql/legacy/rewriter/RewriteRuleExecutor.java +++ b/legacy/src/main/java/org/opensearch/sql/legacy/rewriter/RewriteRuleExecutor.java @@ -3,7 +3,6 @@ * SPDX-License-Identifier: Apache-2.0 */ - package org.opensearch.sql.legacy.rewriter; import com.alibaba.druid.sql.ast.expr.SQLQueryExpr; @@ -11,50 +10,42 @@ import java.util.ArrayList; import java.util.List; -/** - * Query RewriteRuleExecutor which will execute the {@link RewriteRule} with registered order. - */ +/** Query RewriteRuleExecutor which will execute the {@link RewriteRule} with registered order. */ public class RewriteRuleExecutor { - private final List> rewriteRules; - - public RewriteRuleExecutor(List> rewriteRules) { - this.rewriteRules = rewriteRules; + private final List> rewriteRules; + + public RewriteRuleExecutor(List> rewriteRules) { + this.rewriteRules = rewriteRules; + } + + /** Execute the registered {@link RewriteRule} in order on the Query. */ + public void executeOn(T expr) throws SQLFeatureNotSupportedException { + for (RewriteRule rule : rewriteRules) { + if (rule.match(expr)) { + rule.rewrite(expr); + } } - - /** - * Execute the registered {@link RewriteRule} in order on the Query. 
- */ - public void executeOn(T expr) throws SQLFeatureNotSupportedException { - for (RewriteRule rule : rewriteRules) { - if (rule.match(expr)) { - rule.rewrite(expr); - } - } - } - - /** - * Build {@link RewriteRuleExecutor} - */ - public static BuilderOptimizer builder() { - return new BuilderOptimizer(); + } + + /** Build {@link RewriteRuleExecutor} */ + public static BuilderOptimizer builder() { + return new BuilderOptimizer(); + } + + /** Builder of {@link RewriteRuleExecutor} */ + public static class BuilderOptimizer { + private List> rewriteRules; + + public BuilderOptimizer withRule(RewriteRule rule) { + if (rewriteRules == null) { + rewriteRules = new ArrayList<>(); + } + rewriteRules.add(rule); + return this; } - /** - * Builder of {@link RewriteRuleExecutor} - */ - public static class BuilderOptimizer { - private List> rewriteRules; - - public BuilderOptimizer withRule(RewriteRule rule) { - if (rewriteRules == null) { - rewriteRules = new ArrayList<>(); - } - rewriteRules.add(rule); - return this; - } - - public RewriteRuleExecutor build() { - return new RewriteRuleExecutor(rewriteRules); - } + public RewriteRuleExecutor build() { + return new RewriteRuleExecutor(rewriteRules); } + } } diff --git a/legacy/src/main/java/org/opensearch/sql/legacy/rewriter/nestedfield/NestedFieldProjection.java b/legacy/src/main/java/org/opensearch/sql/legacy/rewriter/nestedfield/NestedFieldProjection.java index 4fa4611f9a..83a94b1e9b 100644 --- a/legacy/src/main/java/org/opensearch/sql/legacy/rewriter/nestedfield/NestedFieldProjection.java +++ b/legacy/src/main/java/org/opensearch/sql/legacy/rewriter/nestedfield/NestedFieldProjection.java @@ -3,7 +3,6 @@ * SPDX-License-Identifier: Apache-2.0 */ - package org.opensearch.sql.legacy.rewriter.nestedfield; import static com.alibaba.druid.sql.ast.statement.SQLJoinTableSource.JoinType; @@ -33,86 +32,90 @@ import org.opensearch.sql.legacy.rewriter.matchtoterm.VerificationException; import org.opensearch.sql.legacy.utils.StringUtils; -/** - * Nested field projection class to make OpenSearch return matched rows in nested field. - */ +/** Nested field projection class to make OpenSearch return matched rows in nested field. 
*/ public class NestedFieldProjection { - private final SearchRequestBuilder request; + private final SearchRequestBuilder request; + + public NestedFieldProjection(SearchRequestBuilder request) { + this.request = request; + } + + /** + * Project nested field in SELECT clause to InnerHit in NestedQueryBuilder + * + * @param fields list of field domain object + */ + public void project(List fields, JoinType nestedJoinType) { + if (isAnyNestedField(fields)) { + initBoolQueryFilterIfNull(); + List nestedQueries = extractNestedQueries(query()); + + if (nestedJoinType == JoinType.LEFT_OUTER_JOIN) { + // for LEFT JOIN on nested field as right table, the query will have only one nested field, + // so one path + Map> fieldNamesByPath = groupFieldNamesByPath(fields); + + if (fieldNamesByPath.size() > 1) { + String message = + StringUtils.format( + "only single nested field is allowed as right table for LEFT JOIN, found %s ", + fieldNamesByPath.keySet()); + + throw new VerificationException(message); + } - public NestedFieldProjection(SearchRequestBuilder request) { - this.request = request; + Map.Entry> pathToFields = + fieldNamesByPath.entrySet().iterator().next(); + String path = pathToFields.getKey(); + List fieldNames = pathToFields.getValue(); + buildNestedLeftJoinQuery(path, fieldNames); + } else { + + groupFieldNamesByPath(fields) + .forEach( + (path, fieldNames) -> + buildInnerHit(fieldNames, findNestedQueryWithSamePath(nestedQueries, path))); + } } - - /** - * Project nested field in SELECT clause to InnerHit in NestedQueryBuilder - * - * @param fields list of field domain object - */ - public void project(List fields, JoinType nestedJoinType) { - if (isAnyNestedField(fields)) { - initBoolQueryFilterIfNull(); - List nestedQueries = extractNestedQueries(query()); - - if (nestedJoinType == JoinType.LEFT_OUTER_JOIN) { - // for LEFT JOIN on nested field as right table, the query will have only one nested field, so one path - Map> fieldNamesByPath = groupFieldNamesByPath(fields); - - if (fieldNamesByPath.size() > 1) { - String message = StringUtils.format( - "only single nested field is allowed as right table for LEFT JOIN, found %s ", - fieldNamesByPath.keySet() - ); - - throw new VerificationException(message); - } - - Map.Entry> pathToFields = fieldNamesByPath.entrySet().iterator().next(); - String path = pathToFields.getKey(); - List fieldNames = pathToFields.getValue(); - buildNestedLeftJoinQuery(path, fieldNames); - } else { - - groupFieldNamesByPath(fields).forEach( - (path, fieldNames) -> buildInnerHit(fieldNames, findNestedQueryWithSamePath(nestedQueries, path)) - ); - } - } + } + + /** + * Check via traditional for loop first to avoid lambda performance impact on all queries even + * though those without nested field + */ + private boolean isAnyNestedField(List fields) { + for (Field field : fields) { + if (field.isNested() && !field.isReverseNested()) { + return true; + } } + return false; + } - /** - * Check via traditional for loop first to avoid lambda performance impact on all queries - * even though those without nested field - */ - private boolean isAnyNestedField(List fields) { - for (Field field : fields) { - if (field.isNested() && !field.isReverseNested()) { - return true; - } - } - return false; + private void initBoolQueryFilterIfNull() { + if (request.request().source() == null || query() == null) { + request.setQuery(boolQuery()); } - - private void initBoolQueryFilterIfNull() { - if (request.request().source() == null || query() == null) { - 
request.setQuery(boolQuery()); - } - if (query().filter().isEmpty()) { - query().filter(boolQuery()); - } + if (query().filter().isEmpty()) { + query().filter(boolQuery()); } + } - private Map> groupFieldNamesByPath(List fields) { - return fields.stream(). - filter(Field::isNested). - filter(not(Field::isReverseNested)). - collect(groupingBy(Field::getNestedPath, mapping(Field::getName, toList()))); - } + private Map> groupFieldNamesByPath(List fields) { + return fields.stream() + .filter(Field::isNested) + .filter(not(Field::isReverseNested)) + .collect(groupingBy(Field::getNestedPath, mapping(Field::getName, toList()))); + } /** * Why search for NestedQueryBuilder recursively? - * Because 1) it was added and wrapped by BoolQuery when WHERE explained (far from here) - * 2) InnerHit must be added to the NestedQueryBuilder related + * Because + *

+ * <ol>
+ *   <li>it was added and wrapped by BoolQuery when WHERE explained (far from here)
+ *   <li>InnerHit must be added to the NestedQueryBuilder related
+ * </ol>
 *
* Either we store it to global data structure (which requires to be thread-safe or ThreadLocal) * or we peel off BoolQuery to find it (the way we followed here because recursion tree should be very thin). @@ -130,55 +133,54 @@ private List extractNestedQueries(QueryBuilder query) { return result; } - private void buildInnerHit(List fieldNames, NestedQueryBuilder query) { - query.innerHit(new InnerHitBuilder().setFetchSourceContext( - new FetchSourceContext(true, fieldNames.toArray(new String[0]), null) - )); - } - - /** - * Why linear search? Because NestedQueryBuilder hides "path" field from any access. - * Assumption: collected NestedQueryBuilder list should be very small or mostly only one. - */ - private NestedQueryBuilder findNestedQueryWithSamePath(List nestedQueries, String path) { - return nestedQueries.stream(). - filter(query -> isSamePath(path, query)). - findAny(). - orElseGet(createEmptyNestedQuery(path)); - } - - private boolean isSamePath(String path, NestedQueryBuilder query) { - return nestedQuery(path, query.query(), query.scoreMode()).equals(query); - } - - /** - * Create a nested query with match all filter to place inner hits - */ - private Supplier createEmptyNestedQuery(String path) { - return () -> { - NestedQueryBuilder nestedQuery = nestedQuery(path, matchAllQuery(), ScoreMode.None); - ((BoolQueryBuilder) query().filter().get(0)).must(nestedQuery); - return nestedQuery; - }; - } - - private BoolQueryBuilder query() { - return (BoolQueryBuilder) request.request().source().query(); - } - - private Predicate not(Predicate predicate) { - return predicate.negate(); - } - - - private void buildNestedLeftJoinQuery(String path, List fieldNames) { - BoolQueryBuilder existsNestedQuery = boolQuery(); - existsNestedQuery.mustNot().add(nestedQuery(path, existsQuery(path), ScoreMode.None)); - - NestedQueryBuilder matchAllNestedQuery = nestedQuery(path, matchAllQuery(), ScoreMode.None); - buildInnerHit(fieldNames, matchAllNestedQuery); - - ((BoolQueryBuilder) query().filter().get(0)).should().add(existsNestedQuery); - ((BoolQueryBuilder) query().filter().get(0)).should().add(matchAllNestedQuery); - } + private void buildInnerHit(List fieldNames, NestedQueryBuilder query) { + query.innerHit( + new InnerHitBuilder() + .setFetchSourceContext( + new FetchSourceContext(true, fieldNames.toArray(new String[0]), null))); + } + + /** + * Why linear search? Because NestedQueryBuilder hides "path" field from any access. Assumption: + * collected NestedQueryBuilder list should be very small or mostly only one. 
+ */ + private NestedQueryBuilder findNestedQueryWithSamePath( + List nestedQueries, String path) { + return nestedQueries.stream() + .filter(query -> isSamePath(path, query)) + .findAny() + .orElseGet(createEmptyNestedQuery(path)); + } + + private boolean isSamePath(String path, NestedQueryBuilder query) { + return nestedQuery(path, query.query(), query.scoreMode()).equals(query); + } + + /** Create a nested query with match all filter to place inner hits */ + private Supplier createEmptyNestedQuery(String path) { + return () -> { + NestedQueryBuilder nestedQuery = nestedQuery(path, matchAllQuery(), ScoreMode.None); + ((BoolQueryBuilder) query().filter().get(0)).must(nestedQuery); + return nestedQuery; + }; + } + + private BoolQueryBuilder query() { + return (BoolQueryBuilder) request.request().source().query(); + } + + private Predicate not(Predicate predicate) { + return predicate.negate(); + } + + private void buildNestedLeftJoinQuery(String path, List fieldNames) { + BoolQueryBuilder existsNestedQuery = boolQuery(); + existsNestedQuery.mustNot().add(nestedQuery(path, existsQuery(path), ScoreMode.None)); + + NestedQueryBuilder matchAllNestedQuery = nestedQuery(path, matchAllQuery(), ScoreMode.None); + buildInnerHit(fieldNames, matchAllNestedQuery); + + ((BoolQueryBuilder) query().filter().get(0)).should().add(existsNestedQuery); + ((BoolQueryBuilder) query().filter().get(0)).should().add(matchAllNestedQuery); + } } diff --git a/legacy/src/main/java/org/opensearch/sql/legacy/rewriter/nestedfield/NestedFieldRewriter.java b/legacy/src/main/java/org/opensearch/sql/legacy/rewriter/nestedfield/NestedFieldRewriter.java index f93f5e344e..976075a72d 100644 --- a/legacy/src/main/java/org/opensearch/sql/legacy/rewriter/nestedfield/NestedFieldRewriter.java +++ b/legacy/src/main/java/org/opensearch/sql/legacy/rewriter/nestedfield/NestedFieldRewriter.java @@ -3,7 +3,6 @@ * SPDX-License-Identifier: Apache-2.0 */ - package org.opensearch.sql.legacy.rewriter.nestedfield; import static org.opensearch.sql.legacy.utils.Util.NESTED_JOIN_TYPE; @@ -16,6 +15,7 @@ import java.util.Deque; /** + *

  * Visitor to rewrite AST (abstract syntax tree) for nested type fields to support implicit nested() function call.
  * Intuitively, the approach is to implement SQLIdentifier.visit() and wrap nested() function for nested field.
  * The parsing result of FROM clause will be used to determine if an identifier is nested field.
@@ -47,66 +47,64 @@
  * 1) Manage environment in the case of subquery
  * 2) Add nested field to select for SELECT *
  * 3) Merge conditions of same nested field to single nested() call
+ * 
*/ public class NestedFieldRewriter extends MySqlASTVisitorAdapter { - /** - * Scope stack to record the state (nested field names etc) for current query. - * In the case of subquery, the active scope of current query is the top element of the stack. - */ - private Deque environment = new ArrayDeque<>(); - - /** - * Rewrite FROM here to make sure FROM statement always be visited before other statement in query. - * Note that return true anyway to continue visiting FROM in subquery if any. - */ - @Override - public boolean visit(MySqlSelectQueryBlock query) { - environment.push(new Scope()); - if (query.getFrom() == null) { - return false; - } - - query.getFrom().setParent(query); - new From(query.getFrom()).rewrite(curScope()); + /** + * Scope stack to record the state (nested field names etc) for current query. In the case of + * subquery, the active scope of current query is the top element of the stack. + */ + private Deque environment = new ArrayDeque<>(); + + /** + * Rewrite FROM here to make sure FROM statement always be visited before other statement in + * query. Note that return true anyway to continue visiting FROM in subquery if any. + */ + @Override + public boolean visit(MySqlSelectQueryBlock query) { + environment.push(new Scope()); + if (query.getFrom() == null) { + return false; + } - if (curScope().isAnyNestedField() && isNotGroupBy(query)) { - new Select(query.getSelectList()).rewrite(curScope()); - } + query.getFrom().setParent(query); + new From(query.getFrom()).rewrite(curScope()); - query.putAttribute(NESTED_JOIN_TYPE, curScope().getActualJoinType()); - return true; + if (curScope().isAnyNestedField() && isNotGroupBy(query)) { + new Select(query.getSelectList()).rewrite(curScope()); } - @Override - public boolean visit(SQLIdentifierExpr expr) { - if (curScope().isAnyNestedField()) { - new Identifier(expr).rewrite(curScope()); - } - return true; - } + query.putAttribute(NESTED_JOIN_TYPE, curScope().getActualJoinType()); + return true; + } - @Override - public void endVisit(SQLBinaryOpExpr expr) { - if (curScope().isAnyNestedField()) { - new Where(expr).rewrite(curScope()); - } + @Override + public boolean visit(SQLIdentifierExpr expr) { + if (curScope().isAnyNestedField()) { + new Identifier(expr).rewrite(curScope()); } + return true; + } - @Override - public void endVisit(MySqlSelectQueryBlock query) { - environment.pop(); + @Override + public void endVisit(SQLBinaryOpExpr expr) { + if (curScope().isAnyNestedField()) { + new Where(expr).rewrite(curScope()); } + } - /** - * Current scope which is top of the stack - */ - private Scope curScope() { - return environment.peek(); - } + @Override + public void endVisit(MySqlSelectQueryBlock query) { + environment.pop(); + } - private boolean isNotGroupBy(MySqlSelectQueryBlock query) { - return query.getGroupBy() == null; - } + /** Current scope which is top of the stack */ + private Scope curScope() { + return environment.peek(); + } + private boolean isNotGroupBy(MySqlSelectQueryBlock query) { + return query.getGroupBy() == null; + } } diff --git a/legacy/src/main/java/org/opensearch/sql/legacy/rewriter/nestedfield/Scope.java b/legacy/src/main/java/org/opensearch/sql/legacy/rewriter/nestedfield/Scope.java index 5f035bc725..f65d7f166b 100644 --- a/legacy/src/main/java/org/opensearch/sql/legacy/rewriter/nestedfield/Scope.java +++ b/legacy/src/main/java/org/opensearch/sql/legacy/rewriter/nestedfield/Scope.java @@ -3,7 +3,6 @@ * SPDX-License-Identifier: Apache-2.0 */ - package org.opensearch.sql.legacy.rewriter.nestedfield; 
import static com.alibaba.druid.sql.ast.statement.SQLJoinTableSource.JoinType; @@ -14,71 +13,68 @@ import java.util.Map; import java.util.Set; -/** - * Nested field information in current query being visited. - */ +/** Nested field information in current query being visited. */ class Scope { - /** Join Type as passed in the actual SQL subquery */ - private JoinType actualJoinType; - - /** Alias of parent such as alias "t" of parent table "team" in "FROM team t, t.employees e" */ - - private String parentAlias; - - /** - * Mapping from nested field path alias to path full name in FROM. - * eg. e in {e => employees} in "FROM t.employees e" - */ - private Map aliasFullPaths = new HashMap<>(); - - /** - * Mapping from binary operation condition (in WHERE) to nested - * field tag (full path for nested, EMPTY for non-nested field) - */ - private Map conditionTags = new IdentityHashMap<>(); - - String getParentAlias() { - return parentAlias; - } - - void setParentAlias(String parentAlias) { - this.parentAlias = parentAlias; + /** Join Type as passed in the actual SQL subquery */ + private JoinType actualJoinType; + + /** Alias of parent such as alias "t" of parent table "team" in "FROM team t, t.employees e" */ + private String parentAlias; + + /** + * Mapping from nested field path alias to path full name in FROM. eg. e in {e => employees} in + * "FROM t.employees e" + */ + private Map aliasFullPaths = new HashMap<>(); + + /** + * Mapping from binary operation condition (in WHERE) to nested field tag (full path for nested, + * EMPTY for non-nested field) + */ + private Map conditionTags = new IdentityHashMap<>(); + + String getParentAlias() { + return parentAlias; + } + + void setParentAlias(String parentAlias) { + this.parentAlias = parentAlias; + } + + void addAliasFullPath(String alias, String path) { + if (alias.isEmpty()) { + aliasFullPaths.put(path, path); + } else { + aliasFullPaths.put(alias, path); } + } - void addAliasFullPath(String alias, String path) { - if (alias.isEmpty()) { - aliasFullPaths.put(path, path); - } else { - aliasFullPaths.put(alias, path); - } - } + String getFullPath(String alias) { + return aliasFullPaths.getOrDefault(alias, ""); + } - String getFullPath(String alias) { - return aliasFullPaths.getOrDefault(alias, ""); - } + boolean isAnyNestedField() { + return !aliasFullPaths.isEmpty(); + } - boolean isAnyNestedField() { - return !aliasFullPaths.isEmpty(); - } + Set getAliases() { + return aliasFullPaths.keySet(); + } - Set getAliases() { - return aliasFullPaths.keySet(); - } + String getConditionTag(SQLBinaryOpExpr expr) { + return conditionTags.getOrDefault(expr, ""); + } - String getConditionTag(SQLBinaryOpExpr expr) { - return conditionTags.getOrDefault(expr, ""); - } + void addConditionTag(SQLBinaryOpExpr expr, String tag) { + conditionTags.put(expr, tag); + } - void addConditionTag(SQLBinaryOpExpr expr, String tag) { - conditionTags.put(expr, tag); - } - - JoinType getActualJoinType() { - return actualJoinType; - } + JoinType getActualJoinType() { + return actualJoinType; + } - void setActualJoinType(JoinType joinType) { - actualJoinType = joinType; - } + void setActualJoinType(JoinType joinType) { + actualJoinType = joinType; + } } diff --git a/legacy/src/main/java/org/opensearch/sql/legacy/rewriter/nestedfield/Select.java b/legacy/src/main/java/org/opensearch/sql/legacy/rewriter/nestedfield/Select.java index f514e6d081..8d2d6402e1 100644 --- a/legacy/src/main/java/org/opensearch/sql/legacy/rewriter/nestedfield/Select.java +++ 
b/legacy/src/main/java/org/opensearch/sql/legacy/rewriter/nestedfield/Select.java @@ -3,7 +3,6 @@ * SPDX-License-Identifier: Apache-2.0 */ - package org.opensearch.sql.legacy.rewriter.nestedfield; import com.alibaba.druid.sql.ast.expr.SQLAllColumnExpr; @@ -11,39 +10,37 @@ import com.alibaba.druid.sql.ast.statement.SQLSelectItem; import java.util.List; -/** - * Column list in SELECT statement. - */ +/** Column list in SELECT statement. */ class Select extends SQLClause> { - Select(List expr) { - super(expr); - } - - /** - * Rewrite by adding nested field to SELECT in the case of 'SELECT *'. - *

- * Ex. 'SELECT *' => 'SELECT *, employees.*' - * So that NestedFieldProjection will add 'employees.*' to includes list in inner_hits. - */ - @Override - void rewrite(Scope scope) { - if (isSelectAllOnly()) { - addSelectAllForNestedField(scope); - } + Select(List expr) { + super(expr); + } + + /** + * Rewrite by adding nested field to SELECT in the case of 'SELECT *'. + * + *

Ex. 'SELECT *' => 'SELECT *, employees.*' So that NestedFieldProjection will add + * 'employees.*' to includes list in inner_hits. + */ + @Override + void rewrite(Scope scope) { + if (isSelectAllOnly()) { + addSelectAllForNestedField(scope); } + } - private boolean isSelectAllOnly() { - return expr.size() == 1 && expr.get(0).getExpr() instanceof SQLAllColumnExpr; - } + private boolean isSelectAllOnly() { + return expr.size() == 1 && expr.get(0).getExpr() instanceof SQLAllColumnExpr; + } - private void addSelectAllForNestedField(Scope scope) { - for (String alias : scope.getAliases()) { - expr.add(createSelectItem(alias + ".*")); - } + private void addSelectAllForNestedField(Scope scope) { + for (String alias : scope.getAliases()) { + expr.add(createSelectItem(alias + ".*")); } + } - private SQLSelectItem createSelectItem(String name) { - return new SQLSelectItem(new SQLIdentifierExpr(name)); - } + private SQLSelectItem createSelectItem(String name) { + return new SQLSelectItem(new SQLIdentifierExpr(name)); + } } diff --git a/legacy/src/main/java/org/opensearch/sql/legacy/rewriter/ordinal/OrdinalRewriterRule.java b/legacy/src/main/java/org/opensearch/sql/legacy/rewriter/ordinal/OrdinalRewriterRule.java index 1d44ac8261..03ff07b1b8 100644 --- a/legacy/src/main/java/org/opensearch/sql/legacy/rewriter/ordinal/OrdinalRewriterRule.java +++ b/legacy/src/main/java/org/opensearch/sql/legacy/rewriter/ordinal/OrdinalRewriterRule.java @@ -3,7 +3,6 @@ * SPDX-License-Identifier: Apache-2.0 */ - package org.opensearch.sql.legacy.rewriter.ordinal; import com.alibaba.druid.sql.ast.SQLExpr; @@ -23,128 +22,131 @@ import org.opensearch.sql.legacy.rewriter.matchtoterm.VerificationException; /** - * Rewrite rule for changing ordinal alias in order by and group by to actual select field. - * Since we cannot clone or deepcopy the Druid SQL objects, we need to generate the - * two syntax tree from the original query to map Group By and Order By fields with ordinal alias - * to Select fields in newly generated syntax tree. + * Rewrite rule for changing ordinal alias in order by and group by to actual select field. Since we + * cannot clone or deepcopy the Druid SQL objects, we need to generate the two syntax tree from the + * original query to map Group By and Order By fields with ordinal alias to Select fields in newly + * generated syntax tree. * - * This rewriter assumes that all the backticks have been removed from identifiers. - * It also assumes that table alias have been removed from SELECT, WHERE, GROUP BY, ORDER BY fields. + *

This rewriter assumes that all the backticks have been removed from identifiers. It also + * assumes that table alias have been removed from SELECT, WHERE, GROUP BY, ORDER BY fields. */ - public class OrdinalRewriterRule implements RewriteRule { - private final String sql; + private final String sql; - public OrdinalRewriterRule(String sql) { - this.sql = sql; - } + public OrdinalRewriterRule(String sql) { + this.sql = sql; + } - @Override - public boolean match(SQLQueryExpr root) { - SQLSelectQuery sqlSelectQuery = root.getSubQuery().getQuery(); - if (!(sqlSelectQuery instanceof MySqlSelectQueryBlock)) { - // it could be SQLUnionQuery - return false; - } - - MySqlSelectQueryBlock query = (MySqlSelectQueryBlock) sqlSelectQuery; - if (!hasGroupByWithOrdinals(query) && !hasOrderByWithOrdinals(query)) { - return false; - } - return true; + @Override + public boolean match(SQLQueryExpr root) { + SQLSelectQuery sqlSelectQuery = root.getSubQuery().getQuery(); + if (!(sqlSelectQuery instanceof MySqlSelectQueryBlock)) { + // it could be SQLUnionQuery + return false; } - @Override - public void rewrite(SQLQueryExpr root) { - // we cannot clone SQLSelectItem, so we need similar objects to assign to GroupBy and OrderBy items - SQLQueryExpr sqlExprGroupCopy = toSqlExpr(); - SQLQueryExpr sqlExprOrderCopy = toSqlExpr(); - - changeOrdinalAliasInGroupAndOrderBy(root, sqlExprGroupCopy, sqlExprOrderCopy); + MySqlSelectQueryBlock query = (MySqlSelectQueryBlock) sqlSelectQuery; + if (!hasGroupByWithOrdinals(query) && !hasOrderByWithOrdinals(query)) { + return false; } - - private void changeOrdinalAliasInGroupAndOrderBy(SQLQueryExpr root, - SQLQueryExpr exprGroup, - SQLQueryExpr exprOrder) { - root.accept(new MySqlASTVisitorAdapter() { - - private String groupException = "Invalid ordinal [%s] specified in [GROUP BY %s]"; - private String orderException = "Invalid ordinal [%s] specified in [ORDER BY %s]"; - - private List groupSelectList = ((MySqlSelectQueryBlock) exprGroup.getSubQuery().getQuery()) - .getSelectList(); - - private List orderSelectList = ((MySqlSelectQueryBlock) exprOrder.getSubQuery().getQuery()) - .getSelectList(); - - @Override - public boolean visit(MySqlSelectGroupByExpr groupByExpr) { - SQLExpr expr = groupByExpr.getExpr(); - if (expr instanceof SQLIntegerExpr) { - Integer ordinalValue = ((SQLIntegerExpr) expr).getNumber().intValue(); - SQLExpr newExpr = checkAndGet(groupSelectList, ordinalValue, groupException); - groupByExpr.setExpr(newExpr); - newExpr.setParent(groupByExpr); - } - return false; + return true; + } + + @Override + public void rewrite(SQLQueryExpr root) { + // we cannot clone SQLSelectItem, so we need similar objects to assign to GroupBy and OrderBy + // items + SQLQueryExpr sqlExprGroupCopy = toSqlExpr(); + SQLQueryExpr sqlExprOrderCopy = toSqlExpr(); + + changeOrdinalAliasInGroupAndOrderBy(root, sqlExprGroupCopy, sqlExprOrderCopy); + } + + private void changeOrdinalAliasInGroupAndOrderBy( + SQLQueryExpr root, SQLQueryExpr exprGroup, SQLQueryExpr exprOrder) { + root.accept( + new MySqlASTVisitorAdapter() { + + private String groupException = "Invalid ordinal [%s] specified in [GROUP BY %s]"; + private String orderException = "Invalid ordinal [%s] specified in [ORDER BY %s]"; + + private List groupSelectList = + ((MySqlSelectQueryBlock) exprGroup.getSubQuery().getQuery()).getSelectList(); + + private List orderSelectList = + ((MySqlSelectQueryBlock) exprOrder.getSubQuery().getQuery()).getSelectList(); + + @Override + public boolean visit(MySqlSelectGroupByExpr 
groupByExpr) { + SQLExpr expr = groupByExpr.getExpr(); + if (expr instanceof SQLIntegerExpr) { + Integer ordinalValue = ((SQLIntegerExpr) expr).getNumber().intValue(); + SQLExpr newExpr = checkAndGet(groupSelectList, ordinalValue, groupException); + groupByExpr.setExpr(newExpr); + newExpr.setParent(groupByExpr); } - - @Override - public boolean visit(SQLSelectOrderByItem orderByItem) { - SQLExpr expr = orderByItem.getExpr(); - Integer ordinalValue; - - if (expr instanceof SQLIntegerExpr) { - ordinalValue = ((SQLIntegerExpr) expr).getNumber().intValue(); - SQLExpr newExpr = checkAndGet(orderSelectList, ordinalValue, orderException); - orderByItem.setExpr(newExpr); - newExpr.setParent(orderByItem); - } else if (expr instanceof SQLBinaryOpExpr - && ((SQLBinaryOpExpr) expr).getLeft() instanceof SQLIntegerExpr) { - // support ORDER BY IS NULL/NOT NULL - SQLBinaryOpExpr binaryOpExpr = (SQLBinaryOpExpr) expr; - SQLIntegerExpr integerExpr = (SQLIntegerExpr) binaryOpExpr.getLeft(); - - ordinalValue = integerExpr.getNumber().intValue(); - SQLExpr newExpr = checkAndGet(orderSelectList, ordinalValue, orderException); - binaryOpExpr.setLeft(newExpr); - newExpr.setParent(binaryOpExpr); - } - - return false; + return false; + } + + @Override + public boolean visit(SQLSelectOrderByItem orderByItem) { + SQLExpr expr = orderByItem.getExpr(); + Integer ordinalValue; + + if (expr instanceof SQLIntegerExpr) { + ordinalValue = ((SQLIntegerExpr) expr).getNumber().intValue(); + SQLExpr newExpr = checkAndGet(orderSelectList, ordinalValue, orderException); + orderByItem.setExpr(newExpr); + newExpr.setParent(orderByItem); + } else if (expr instanceof SQLBinaryOpExpr + && ((SQLBinaryOpExpr) expr).getLeft() instanceof SQLIntegerExpr) { + // support ORDER BY IS NULL/NOT NULL + SQLBinaryOpExpr binaryOpExpr = (SQLBinaryOpExpr) expr; + SQLIntegerExpr integerExpr = (SQLIntegerExpr) binaryOpExpr.getLeft(); + + ordinalValue = integerExpr.getNumber().intValue(); + SQLExpr newExpr = checkAndGet(orderSelectList, ordinalValue, orderException); + binaryOpExpr.setLeft(newExpr); + newExpr.setParent(binaryOpExpr); } - }); - } - private SQLExpr checkAndGet(List selectList, Integer ordinal, String exception) { - if (ordinal > selectList.size()) { - throw new VerificationException(String.format(exception, ordinal, ordinal)); - } + return false; + } + }); + } - return selectList.get(ordinal-1).getExpr(); + private SQLExpr checkAndGet(List selectList, Integer ordinal, String exception) { + if (ordinal > selectList.size()) { + throw new VerificationException(String.format(exception, ordinal, ordinal)); } - private boolean hasGroupByWithOrdinals(MySqlSelectQueryBlock query) { - if (query.getGroupBy() == null) { - return false; - } else if (query.getGroupBy().getItems().isEmpty()){ - return false; - } + return selectList.get(ordinal - 1).getExpr(); + } - return query.getGroupBy().getItems().stream().anyMatch(x -> - x instanceof MySqlSelectGroupByExpr && ((MySqlSelectGroupByExpr) x).getExpr() instanceof SQLIntegerExpr - ); + private boolean hasGroupByWithOrdinals(MySqlSelectQueryBlock query) { + if (query.getGroupBy() == null) { + return false; + } else if (query.getGroupBy().getItems().isEmpty()) { + return false; } - private boolean hasOrderByWithOrdinals(MySqlSelectQueryBlock query) { - if (query.getOrderBy() == null) { - return false; - } else if (query.getOrderBy().getItems().isEmpty()){ - return false; - } + return query.getGroupBy().getItems().stream() + .anyMatch( + x -> + x instanceof MySqlSelectGroupByExpr + && 
((MySqlSelectGroupByExpr) x).getExpr() instanceof SQLIntegerExpr); + } + + private boolean hasOrderByWithOrdinals(MySqlSelectQueryBlock query) { + if (query.getOrderBy() == null) { + return false; + } else if (query.getOrderBy().getItems().isEmpty()) { + return false; + } /** + *

          * The second condition checks for a valid AST that matches the ORDER BY IS NULL/NOT NULL pattern
          *
          *            SQLSelectOrderByItem
@@ -152,6 +154,7 @@ private boolean hasOrderByWithOrdinals(MySqlSelectQueryBlock query) {
          *             SQLBinaryOpExpr (Is || IsNot)
          *                    /  \
          *    SQLIdentifierExpr  SQLNullExpr
+         *  
*/ return query.getOrderBy().getItems().stream().anyMatch(x -> x.getExpr() instanceof SQLIntegerExpr @@ -162,9 +165,9 @@ private boolean hasOrderByWithOrdinals(MySqlSelectQueryBlock query) { ); } - private SQLQueryExpr toSqlExpr() { - SQLExprParser parser = new ElasticSqlExprParser(sql); - SQLExpr expr = parser.expr(); - return (SQLQueryExpr) expr; - } + private SQLQueryExpr toSqlExpr() { + SQLExprParser parser = new ElasticSqlExprParser(sql); + SQLExpr expr = parser.expr(); + return (SQLQueryExpr) expr; + } } diff --git a/legacy/src/main/java/org/opensearch/sql/legacy/rewriter/subquery/NestedQueryContext.java b/legacy/src/main/java/org/opensearch/sql/legacy/rewriter/subquery/NestedQueryContext.java index ce254e2103..b300015d49 100644 --- a/legacy/src/main/java/org/opensearch/sql/legacy/rewriter/subquery/NestedQueryContext.java +++ b/legacy/src/main/java/org/opensearch/sql/legacy/rewriter/subquery/NestedQueryContext.java @@ -3,7 +3,6 @@ * SPDX-License-Identifier: Apache-2.0 */ - package org.opensearch.sql.legacy.rewriter.subquery; import com.alibaba.druid.sql.ast.statement.SQLExprTableSource; @@ -14,53 +13,51 @@ import java.util.Map; /** - * {@link NestedQueryContext} build the context with Query to detected the specified table is nested or not. - * Todo current implementation doesn't rely on the index mapping which should be added after the semantics is builded. + * {@link NestedQueryContext} build the context with Query to detected the specified table is nested + * or not. + *
Todo current implementation doesn't rely on the index mapping which should be added after + * the semantics is built. */ public class NestedQueryContext { - private static final String SEPARATOR = "."; - private static final String EMPTY = ""; - // , if parentTable not exist, parentTableAlias = ""; - private final Map aliasParents = new HashMap<>(); + private static final String SEPARATOR = "."; + private static final String EMPTY = ""; + // , if parentTable not exist, parentTableAlias = ""; + private final Map aliasParents = new HashMap<>(); - /** - * Is the table refer to the nested field of the parent table. - */ - public boolean isNested(SQLExprTableSource table) { - String parent = parent(table); - if (Strings.isNullOrEmpty(parent)) { - return !Strings.isNullOrEmpty(aliasParents.get(alias(table))); - } else { - return aliasParents.containsKey(parent); - } + /** Is the table refer to the nested field of the parent table. */ + public boolean isNested(SQLExprTableSource table) { + String parent = parent(table); + if (Strings.isNullOrEmpty(parent)) { + return !Strings.isNullOrEmpty(aliasParents.get(alias(table))); + } else { + return aliasParents.containsKey(parent); } + } - /** - * add table to the context. - */ - public void add(SQLTableSource table) { - if (table instanceof SQLExprTableSource) { - process((SQLExprTableSource) table); - } else if (table instanceof SQLJoinTableSource) { - add(((SQLJoinTableSource) table).getLeft()); - add(((SQLJoinTableSource) table).getRight()); - } else { - throw new IllegalStateException("unsupported table source"); - } + /** add table to the context. */ + public void add(SQLTableSource table) { + if (table instanceof SQLExprTableSource) { + process((SQLExprTableSource) table); + } else if (table instanceof SQLJoinTableSource) { + add(((SQLJoinTableSource) table).getLeft()); + add(((SQLJoinTableSource) table).getRight()); + } else { + throw new IllegalStateException("unsupported table source"); } + } - private void process(SQLExprTableSource table) { - String alias = alias(table); - String parent = parent(table); - if (!Strings.isNullOrEmpty(alias)) { - aliasParents.putIfAbsent(alias, parent); - } + private void process(SQLExprTableSource table) { + String alias = alias(table); + String parent = parent(table); + if (!Strings.isNullOrEmpty(alias)) { + aliasParents.putIfAbsent(alias, parent); } + } /** - * Extract the parent alias from the tableName. For example - * SELECT * FROM employee e, e.project as p, - * For expr: employee, the parent alias is "". + * Extract the parent alias from the tableName. For example
+ * SELECT * FROM employee e, e.project as p,
+ * For expr: employee, the parent alias is "".
* For expr: e.project, the parent alias is e. */ private String parent(SQLExprTableSource table) { @@ -69,10 +66,10 @@ private String parent(SQLExprTableSource table) { return index == -1 ? EMPTY : tableName.substring(0, index); } - private String alias(SQLExprTableSource table) { - if (Strings.isNullOrEmpty(table.getAlias())) { - return table.getExpr().toString(); - } - return table.getAlias(); + private String alias(SQLExprTableSource table) { + if (Strings.isNullOrEmpty(table.getAlias())) { + return table.getExpr().toString(); } + return table.getAlias(); + } } diff --git a/legacy/src/main/java/org/opensearch/sql/legacy/rewriter/subquery/RewriterContext.java b/legacy/src/main/java/org/opensearch/sql/legacy/rewriter/subquery/RewriterContext.java index 09698095e6..54cba6547b 100644 --- a/legacy/src/main/java/org/opensearch/sql/legacy/rewriter/subquery/RewriterContext.java +++ b/legacy/src/main/java/org/opensearch/sql/legacy/rewriter/subquery/RewriterContext.java @@ -3,7 +3,6 @@ * SPDX-License-Identifier: Apache-2.0 */ - package org.opensearch.sql.legacy.rewriter.subquery; import com.alibaba.druid.sql.ast.SQLExpr; @@ -19,68 +18,66 @@ import java.util.Deque; import java.util.List; -/** - * Environment for rewriting the SQL. - */ +/** Environment for rewriting the SQL. */ public class RewriterContext { - private final Deque tableStack = new ArrayDeque<>(); - private final Deque conditionStack = new ArrayDeque<>(); - private final List sqlInSubQueryExprs = new ArrayList<>(); - private final List sqlExistsExprs = new ArrayList<>(); - private final NestedQueryContext nestedQueryDetector = new NestedQueryContext(); + private final Deque tableStack = new ArrayDeque<>(); + private final Deque conditionStack = new ArrayDeque<>(); + private final List sqlInSubQueryExprs = new ArrayList<>(); + private final List sqlExistsExprs = new ArrayList<>(); + private final NestedQueryContext nestedQueryDetector = new NestedQueryContext(); - public SQLTableSource popJoin() { - return tableStack.pop(); - } + public SQLTableSource popJoin() { + return tableStack.pop(); + } - public SQLExpr popWhere() { - return conditionStack.pop(); - } + public SQLExpr popWhere() { + return conditionStack.pop(); + } - public void addWhere(SQLExpr expr) { - conditionStack.push(expr); - } + public void addWhere(SQLExpr expr) { + conditionStack.push(expr); + } - /** - * Add the Join right table and {@link JoinType} and {@link SQLBinaryOpExpr} which will - * merge the left table in the tableStack. - */ - public void addJoin(SQLTableSource right, JoinType joinType, SQLBinaryOpExpr condition) { - SQLTableSource left = tableStack.pop(); - SQLJoinTableSource joinTableSource = new SQLJoinTableSource(); - joinTableSource.setLeft(left); - joinTableSource.setRight(right); - joinTableSource.setJoinType(joinType); - joinTableSource.setCondition(condition); - tableStack.push(joinTableSource); - } + /** + * Add the Join right table and {@link JoinType} and {@link SQLBinaryOpExpr} which will merge the + * left table in the tableStack. 
+ */ + public void addJoin(SQLTableSource right, JoinType joinType, SQLBinaryOpExpr condition) { + SQLTableSource left = tableStack.pop(); + SQLJoinTableSource joinTableSource = new SQLJoinTableSource(); + joinTableSource.setLeft(left); + joinTableSource.setRight(right); + joinTableSource.setJoinType(joinType); + joinTableSource.setCondition(condition); + tableStack.push(joinTableSource); + } - public void addJoin(SQLTableSource right, JoinType joinType) { - addJoin(right, joinType, null); - } + public void addJoin(SQLTableSource right, JoinType joinType) { + addJoin(right, joinType, null); + } - public void addTable(SQLTableSource table) { - tableStack.push(table); - nestedQueryDetector.add(table); - } + public void addTable(SQLTableSource table) { + tableStack.push(table); + nestedQueryDetector.add(table); + } - public boolean isNestedQuery(SQLExprTableSource table) { - return nestedQueryDetector.isNested(table); - } + public boolean isNestedQuery(SQLExprTableSource table) { + return nestedQueryDetector.isNested(table); + } - public void setInSubQuery(SQLInSubQueryExpr expr) { - sqlInSubQueryExprs.add(expr); - } + public void setInSubQuery(SQLInSubQueryExpr expr) { + sqlInSubQueryExprs.add(expr); + } - public void setExistsSubQuery(SQLExistsExpr expr) { - sqlExistsExprs.add(expr); - } + public void setExistsSubQuery(SQLExistsExpr expr) { + sqlExistsExprs.add(expr); + } - public List getSqlInSubQueryExprs() { - return sqlInSubQueryExprs; - } + public List getSqlInSubQueryExprs() { + return sqlInSubQueryExprs; + } - public List getSqlExistsExprs() { - return sqlExistsExprs; - } + public List getSqlExistsExprs() { + return sqlExistsExprs; + } } diff --git a/legacy/src/main/java/org/opensearch/sql/legacy/rewriter/subquery/rewriter/Rewriter.java b/legacy/src/main/java/org/opensearch/sql/legacy/rewriter/subquery/rewriter/Rewriter.java index 5ca0a38d7f..a23eaaf514 100644 --- a/legacy/src/main/java/org/opensearch/sql/legacy/rewriter/subquery/rewriter/Rewriter.java +++ b/legacy/src/main/java/org/opensearch/sql/legacy/rewriter/subquery/rewriter/Rewriter.java @@ -3,28 +3,21 @@ * SPDX-License-Identifier: Apache-2.0 */ - package org.opensearch.sql.legacy.rewriter.subquery.rewriter; import com.alibaba.druid.sql.ast.expr.SQLBinaryOpExpr; import com.alibaba.druid.sql.ast.expr.SQLBinaryOperator; -/** - * Interface of SQL Rewriter - */ +/** Interface of SQL Rewriter */ public interface Rewriter { - /** - * Whether the Rewriter can rewrite the SQL? - */ - boolean canRewrite(); + /** Whether the Rewriter can rewrite the SQL? */ + boolean canRewrite(); - /** - * Rewrite the SQL. - */ - void rewrite(); + /** Rewrite the SQL. 
*/ + void rewrite(); - default SQLBinaryOpExpr and(SQLBinaryOpExpr left, SQLBinaryOpExpr right) { - return new SQLBinaryOpExpr(left, SQLBinaryOperator.BooleanAnd, right); - } + default SQLBinaryOpExpr and(SQLBinaryOpExpr left, SQLBinaryOpExpr right) { + return new SQLBinaryOpExpr(left, SQLBinaryOperator.BooleanAnd, right); + } } diff --git a/legacy/src/main/java/org/opensearch/sql/legacy/rewriter/subquery/rewriter/RewriterFactory.java b/legacy/src/main/java/org/opensearch/sql/legacy/rewriter/subquery/rewriter/RewriterFactory.java index ace333e981..6e6656ec37 100644 --- a/legacy/src/main/java/org/opensearch/sql/legacy/rewriter/subquery/rewriter/RewriterFactory.java +++ b/legacy/src/main/java/org/opensearch/sql/legacy/rewriter/subquery/rewriter/RewriterFactory.java @@ -3,7 +3,6 @@ * SPDX-License-Identifier: Apache-2.0 */ - package org.opensearch.sql.legacy.rewriter.subquery.rewriter; import com.alibaba.druid.sql.ast.SQLExpr; @@ -13,32 +12,26 @@ import java.util.List; import org.opensearch.sql.legacy.rewriter.subquery.RewriterContext; -/** - * Factory for generating the {@link Rewriter}. - */ +/** Factory for generating the {@link Rewriter}. */ public class RewriterFactory { - /** - * Create list of {@link Rewriter}. - */ - public static List createRewriterList(SQLExpr expr, RewriterContext bb) { - if (expr instanceof SQLExistsExpr) { - return existRewriterList((SQLExistsExpr) expr, bb); - } else if (expr instanceof SQLInSubQueryExpr) { - return inRewriterList((SQLInSubQueryExpr) expr, bb); - } - return ImmutableList.of(); + /** Create list of {@link Rewriter}. */ + public static List createRewriterList(SQLExpr expr, RewriterContext bb) { + if (expr instanceof SQLExistsExpr) { + return existRewriterList((SQLExistsExpr) expr, bb); + } else if (expr instanceof SQLInSubQueryExpr) { + return inRewriterList((SQLInSubQueryExpr) expr, bb); } + return ImmutableList.of(); + } - private static List existRewriterList(SQLExistsExpr existsExpr, RewriterContext bb) { - return new ImmutableList.Builder() - .add(new NestedExistsRewriter(existsExpr, bb)) - .build(); - } + private static List existRewriterList(SQLExistsExpr existsExpr, RewriterContext bb) { + return new ImmutableList.Builder() + .add(new NestedExistsRewriter(existsExpr, bb)) + .build(); + } - private static List inRewriterList(SQLInSubQueryExpr inExpr, RewriterContext bb) { - return new ImmutableList.Builder() - .add(new InRewriter(inExpr, bb)) - .build(); - } + private static List inRewriterList(SQLInSubQueryExpr inExpr, RewriterContext bb) { + return new ImmutableList.Builder().add(new InRewriter(inExpr, bb)).build(); + } } diff --git a/legacy/src/main/java/org/opensearch/sql/legacy/spatial/Point.java b/legacy/src/main/java/org/opensearch/sql/legacy/spatial/Point.java index c449ef1364..f3f8639a1c 100644 --- a/legacy/src/main/java/org/opensearch/sql/legacy/spatial/Point.java +++ b/legacy/src/main/java/org/opensearch/sql/legacy/spatial/Point.java @@ -3,26 +3,23 @@ * SPDX-License-Identifier: Apache-2.0 */ - package org.opensearch.sql.legacy.spatial; -/** - * Created by Eliran on 1/8/2015. - */ +/** Created by Eliran on 1/8/2015. 
*/ public class Point { - private double lon; - private double lat; + private double lon; + private double lat; - public Point(double lon, double lat) { - this.lon = lon; - this.lat = lat; - } + public Point(double lon, double lat) { + this.lon = lon; + this.lat = lat; + } - public double getLon() { - return lon; - } + public double getLon() { + return lon; + } - public double getLat() { - return lat; - } + public double getLat() { + return lat; + } } diff --git a/legacy/src/main/java/org/opensearch/sql/legacy/spatial/PolygonFilterParams.java b/legacy/src/main/java/org/opensearch/sql/legacy/spatial/PolygonFilterParams.java index 0d0592f519..1aeddb24a4 100644 --- a/legacy/src/main/java/org/opensearch/sql/legacy/spatial/PolygonFilterParams.java +++ b/legacy/src/main/java/org/opensearch/sql/legacy/spatial/PolygonFilterParams.java @@ -3,22 +3,19 @@ * SPDX-License-Identifier: Apache-2.0 */ - package org.opensearch.sql.legacy.spatial; import java.util.List; -/** - * Created by Eliran on 15/8/2015. - */ +/** Created by Eliran on 15/8/2015. */ public class PolygonFilterParams { - private List polygon; + private List polygon; - public PolygonFilterParams(List polygon) { - this.polygon = polygon; - } + public PolygonFilterParams(List polygon) { + this.polygon = polygon; + } - public List getPolygon() { - return polygon; - } + public List getPolygon() { + return polygon; + } } diff --git a/legacy/src/main/java/org/opensearch/sql/legacy/spatial/RangeDistanceFilterParams.java b/legacy/src/main/java/org/opensearch/sql/legacy/spatial/RangeDistanceFilterParams.java index 91962332bf..0bdb01c3ce 100644 --- a/legacy/src/main/java/org/opensearch/sql/legacy/spatial/RangeDistanceFilterParams.java +++ b/legacy/src/main/java/org/opensearch/sql/legacy/spatial/RangeDistanceFilterParams.java @@ -3,25 +3,22 @@ * SPDX-License-Identifier: Apache-2.0 */ - package org.opensearch.sql.legacy.spatial; -/** - * Created by Eliran on 15/8/2015. - */ +/** Created by Eliran on 15/8/2015. 
*/ public class RangeDistanceFilterParams extends DistanceFilterParams { - private String distanceTo; + private String distanceTo; - public RangeDistanceFilterParams(String distanceFrom, String distanceTo, Point from) { - super(distanceFrom, from); - this.distanceTo = distanceTo; - } + public RangeDistanceFilterParams(String distanceFrom, String distanceTo, Point from) { + super(distanceFrom, from); + this.distanceTo = distanceTo; + } - public String getDistanceTo() { - return distanceTo; - } + public String getDistanceTo() { + return distanceTo; + } - public String getDistanceFrom() { - return this.getDistance(); - } + public String getDistanceFrom() { + return this.getDistance(); + } } diff --git a/legacy/src/main/java/org/opensearch/sql/legacy/utils/QueryDataAnonymizer.java b/legacy/src/main/java/org/opensearch/sql/legacy/utils/QueryDataAnonymizer.java index b58691c022..acf7a73ba5 100644 --- a/legacy/src/main/java/org/opensearch/sql/legacy/utils/QueryDataAnonymizer.java +++ b/legacy/src/main/java/org/opensearch/sql/legacy/utils/QueryDataAnonymizer.java @@ -3,7 +3,6 @@ * SPDX-License-Identifier: Apache-2.0 */ - package org.opensearch.sql.legacy.utils; import static org.opensearch.sql.legacy.utils.Util.toSqlExpr; @@ -14,35 +13,35 @@ import org.apache.logging.log4j.Logger; import org.opensearch.sql.legacy.rewriter.identifier.AnonymizeSensitiveDataRule; -/** - * Utility class to mask sensitive information in incoming SQL queries - */ +/** Utility class to mask sensitive information in incoming SQL queries */ public class QueryDataAnonymizer { - private static final Logger LOG = LogManager.getLogger(QueryDataAnonymizer.class); + private static final Logger LOG = LogManager.getLogger(QueryDataAnonymizer.class); - /** - * This method is used to anonymize sensitive data in SQL query. - * Sensitive data includes index names, column names etc., - * which in druid parser are parsed to SQLIdentifierExpr instances - * @param query entire sql query string - * @return sql query string with all identifiers replaced with "***" on success - * and failure string otherwise to ensure no non-anonymized data is logged in production. - */ - public static String anonymizeData(String query) { - String resultQuery; - try { - AnonymizeSensitiveDataRule rule = new AnonymizeSensitiveDataRule(); - SQLQueryExpr sqlExpr = (SQLQueryExpr) toSqlExpr(query); - rule.rewrite(sqlExpr); - resultQuery = SQLUtils.toMySqlString(sqlExpr).replaceAll("0", "number") - .replaceAll("false", "boolean_literal") - .replaceAll("[\\n][\\t]+", " "); - } catch (Exception e) { - LOG.warn("Caught an exception when anonymizing sensitive data."); - LOG.debug("String {} failed anonymization.", query); - resultQuery = "Failed to anonymize data."; - } - return resultQuery; + /** + * This method is used to anonymize sensitive data in SQL query. Sensitive data includes index + * names, column names etc., which in druid parser are parsed to SQLIdentifierExpr instances + * + * @param query entire sql query string + * @return sql query string with all identifiers replaced with "***" on success and failure string + * otherwise to ensure no non-anonymized data is logged in production. 
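A minimal usage sketch, assuming an illustrative query string (the masking itself is done by AnonymizeSensitiveDataRule plus the literal replacements visible in the method body below):

    // Identifiers are masked and the literals 0 / false are normalized to
    // "number" / "boolean_literal", so the result is safe to log.
    String masked = QueryDataAnonymizer.anonymizeData("SELECT firstname FROM accounts WHERE age = 0");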
+ */ + public static String anonymizeData(String query) { + String resultQuery; + try { + AnonymizeSensitiveDataRule rule = new AnonymizeSensitiveDataRule(); + SQLQueryExpr sqlExpr = (SQLQueryExpr) toSqlExpr(query); + rule.rewrite(sqlExpr); + resultQuery = + SQLUtils.toMySqlString(sqlExpr) + .replaceAll("0", "number") + .replaceAll("false", "boolean_literal") + .replaceAll("[\\n][\\t]+", " "); + } catch (Exception e) { + LOG.warn("Caught an exception when anonymizing sensitive data."); + LOG.debug("String {} failed anonymization.", query); + resultQuery = "Failed to anonymize data."; } + return resultQuery; + } } diff --git a/legacy/src/test/java/org/opensearch/sql/legacy/antlr/semantic/types/ProductTypeTest.java b/legacy/src/test/java/org/opensearch/sql/legacy/antlr/semantic/types/ProductTypeTest.java index 326dd6ce06..5c87aabdee 100644 --- a/legacy/src/test/java/org/opensearch/sql/legacy/antlr/semantic/types/ProductTypeTest.java +++ b/legacy/src/test/java/org/opensearch/sql/legacy/antlr/semantic/types/ProductTypeTest.java @@ -3,7 +3,6 @@ * SPDX-License-Identifier: Apache-2.0 */ - package org.opensearch.sql.legacy.antlr.semantic.types; import static java.util.Collections.singletonList; @@ -18,56 +17,53 @@ import org.junit.Test; import org.opensearch.sql.legacy.antlr.semantic.types.special.Product; -/** - * Test cases fro product type - */ +/** Test cases fro product type */ public class ProductTypeTest { - @Test - public void singleSameTypeInTwoProductsShouldPass() { - Product product1 = new Product(singletonList(INTEGER)); - Product product2 = new Product(singletonList(INTEGER)); - Assert.assertTrue(product1.isCompatible(product2)); - Assert.assertTrue(product2.isCompatible(product1)); - } - - @Test - public void singleCompatibleTypeInTwoProductsShouldPass() { - Product product1 = new Product(singletonList(NUMBER)); - Product product2 = new Product(singletonList(INTEGER)); - Assert.assertTrue(product1.isCompatible(product2)); - Assert.assertTrue(product2.isCompatible(product1)); - } + @Test + public void singleSameTypeInTwoProductsShouldPass() { + Product product1 = new Product(singletonList(INTEGER)); + Product product2 = new Product(singletonList(INTEGER)); + Assert.assertTrue(product1.isCompatible(product2)); + Assert.assertTrue(product2.isCompatible(product1)); + } - @Test - public void twoCompatibleTypesInTwoProductsShouldPass() { - Product product1 = new Product(Arrays.asList(NUMBER, KEYWORD)); - Product product2 = new Product(Arrays.asList(INTEGER, STRING)); - Assert.assertTrue(product1.isCompatible(product2)); - Assert.assertTrue(product2.isCompatible(product1)); - } + @Test + public void singleCompatibleTypeInTwoProductsShouldPass() { + Product product1 = new Product(singletonList(NUMBER)); + Product product2 = new Product(singletonList(INTEGER)); + Assert.assertTrue(product1.isCompatible(product2)); + Assert.assertTrue(product2.isCompatible(product1)); + } - @Test - public void incompatibleTypesInTwoProductsShouldFail() { - Product product1 = new Product(singletonList(BOOLEAN)); - Product product2 = new Product(singletonList(STRING)); - Assert.assertFalse(product1.isCompatible(product2)); - Assert.assertFalse(product2.isCompatible(product1)); - } + @Test + public void twoCompatibleTypesInTwoProductsShouldPass() { + Product product1 = new Product(Arrays.asList(NUMBER, KEYWORD)); + Product product2 = new Product(Arrays.asList(INTEGER, STRING)); + Assert.assertTrue(product1.isCompatible(product2)); + Assert.assertTrue(product2.isCompatible(product1)); + } - @Test - public void 
compatibleButDifferentTypeNumberInTwoProductsShouldFail() { - Product product1 = new Product(Arrays.asList(KEYWORD, INTEGER)); - Product product2 = new Product(singletonList(STRING)); - Assert.assertFalse(product1.isCompatible(product2)); - Assert.assertFalse(product2.isCompatible(product1)); - } + @Test + public void incompatibleTypesInTwoProductsShouldFail() { + Product product1 = new Product(singletonList(BOOLEAN)); + Product product2 = new Product(singletonList(STRING)); + Assert.assertFalse(product1.isCompatible(product2)); + Assert.assertFalse(product2.isCompatible(product1)); + } - @Test - public void baseTypeShouldBeIncompatibleWithProductType() { - Product product = new Product(singletonList(INTEGER)); - Assert.assertFalse(INTEGER.isCompatible(product)); - Assert.assertFalse(product.isCompatible(INTEGER)); - } + @Test + public void compatibleButDifferentTypeNumberInTwoProductsShouldFail() { + Product product1 = new Product(Arrays.asList(KEYWORD, INTEGER)); + Product product2 = new Product(singletonList(STRING)); + Assert.assertFalse(product1.isCompatible(product2)); + Assert.assertFalse(product2.isCompatible(product1)); + } + @Test + public void baseTypeShouldBeIncompatibleWithProductType() { + Product product = new Product(singletonList(INTEGER)); + Assert.assertFalse(INTEGER.isCompatible(product)); + Assert.assertFalse(product.isCompatible(INTEGER)); + } } diff --git a/legacy/src/test/java/org/opensearch/sql/legacy/executor/format/ResultSetTest.java b/legacy/src/test/java/org/opensearch/sql/legacy/executor/format/ResultSetTest.java index 69da4ca475..7cfada0b78 100644 --- a/legacy/src/test/java/org/opensearch/sql/legacy/executor/format/ResultSetTest.java +++ b/legacy/src/test/java/org/opensearch/sql/legacy/executor/format/ResultSetTest.java @@ -3,7 +3,6 @@ * SPDX-License-Identifier: Apache-2.0 */ - package org.opensearch.sql.legacy.executor.format; import static org.junit.Assert.assertFalse; @@ -13,18 +12,21 @@ public class ResultSetTest { - private final ResultSet resultSet = new ResultSet() { - @Override - public Schema getSchema() { - return super.getSchema(); - } - }; + private final ResultSet resultSet = + new ResultSet() { + @Override + public Schema getSchema() { + return super.getSchema(); + } + }; /** * Case #1: * LIKE 'test%' is converted to: - * 1. Regex pattern: test.* - * 2. OpenSearch search pattern: test* + *
    + *
  1. Regex pattern: test.* + *
  2. OpenSearch search pattern: test* + *
* In this case, what OpenSearch returns is the final result. */ @Test @@ -35,8 +37,10 @@ public void testWildcardForZeroOrMoreCharacters() { /** * Case #2: * LIKE 'test_123' is converted to: - * 1. Regex pattern: test.123 - * 2. OpenSearch search pattern: (all) + *
     + *
  1. Regex pattern: test.123 + *
  2. OpenSearch search pattern: (all) + *
* Because OpenSearch doesn't support single wildcard character, in this case, none is passed * as OpenSearch search pattern. So all index names are returned and need to be filtered by * regex pattern again. @@ -49,12 +53,10 @@ public void testWildcardForSingleCharacter() { } /** - * Case #3: - * LIKE 'acc' has same regex and OpenSearch pattern. - * In this case, only index name(s) aliased by 'acc' is returned. - * So regex match is skipped to avoid wrong empty result. - * The assumption here is OpenSearch won't return unrelated index names if - * LIKE pattern doesn't include any wildcard. + * Case #3: LIKE 'acc' has same regex and OpenSearch pattern. In this case, only index name(s) + * aliased by 'acc' is returned. So regex match is skipped to avoid wrong empty result. The + * assumption here is OpenSearch won't return unrelated index names if LIKE pattern doesn't + * include any wildcard. */ @Test public void testIndexAlias() { @@ -62,11 +64,9 @@ public void testIndexAlias() { } /** - * Case #4: - * LIKE 'test.2020.10' has same regex pattern. Because it includes dot (wildcard), - * OpenSearch search pattern is all. - * In this case, all index names are returned. Because the pattern includes dot, - * it's treated as regex and regex match won't be skipped. + * Case #4: LIKE 'test.2020.10' has same regex pattern. Because it includes dot (wildcard), + * OpenSearch search pattern is all. In this case, all index names are returned. Because the + * pattern includes dot, it's treated as regex and regex match won't be skipped. */ @Test public void testIndexNameWithDot() { @@ -74,5 +74,4 @@ public void testIndexNameWithDot() { assertFalse(resultSet.matchesPatternIfRegex(".opensearch_dashboards", "test.2020.10")); assertTrue(resultSet.matchesPatternIfRegex("test.2020.10", "test.2020.10")); } - } diff --git a/legacy/src/test/java/org/opensearch/sql/legacy/plugin/RestSQLQueryActionCursorFallbackTest.java b/legacy/src/test/java/org/opensearch/sql/legacy/plugin/RestSQLQueryActionCursorFallbackTest.java index 64e5d161b7..30d8c9d27d 100644 --- a/legacy/src/test/java/org/opensearch/sql/legacy/plugin/RestSQLQueryActionCursorFallbackTest.java +++ b/legacy/src/test/java/org/opensearch/sql/legacy/plugin/RestSQLQueryActionCursorFallbackTest.java @@ -34,25 +34,19 @@ import org.opensearch.sql.sql.domain.SQLQueryRequest; import org.opensearch.threadpool.ThreadPool; -/** - * A test suite that verifies fallback behaviour of cursor queries. - */ +/** A test suite that verifies fallback behaviour of cursor queries. 
*/ @RunWith(MockitoJUnitRunner.class) public class RestSQLQueryActionCursorFallbackTest extends BaseRestHandler { private NodeClient nodeClient; - @Mock - private ThreadPool threadPool; + @Mock private ThreadPool threadPool; - @Mock - private QueryManager queryManager; + @Mock private QueryManager queryManager; - @Mock - private QueryPlanFactory factory; + @Mock private QueryPlanFactory factory; - @Mock - private RestChannel restChannel; + @Mock private RestChannel restChannel; private Injector injector; @@ -60,11 +54,14 @@ public class RestSQLQueryActionCursorFallbackTest extends BaseRestHandler { public void setup() { nodeClient = new NodeClient(org.opensearch.common.settings.Settings.EMPTY, threadPool); ModulesBuilder modules = new ModulesBuilder(); - modules.add(b -> { - b.bind(SQLService.class).toInstance(new SQLService(new SQLSyntaxParser(), queryManager, factory)); - }); + modules.add( + b -> { + b.bind(SQLService.class) + .toInstance(new SQLService(new SQLSyntaxParser(), queryManager, factory)); + }); injector = modules.createInjector(); - Mockito.lenient().when(threadPool.getThreadContext()) + Mockito.lenient() + .when(threadPool.getThreadContext()) .thenReturn(new ThreadContext(org.opensearch.common.settings.Settings.EMPTY)); } @@ -73,17 +70,14 @@ public void setup() { @Test public void no_fallback_with_column_reference() throws Exception { String query = "SELECT name FROM test1"; - SQLQueryRequest request = createSqlQueryRequest(query, Optional.empty(), - Optional.of(5)); + SQLQueryRequest request = createSqlQueryRequest(query, Optional.empty(), Optional.of(5)); assertFalse(doesQueryFallback(request)); } - private static SQLQueryRequest createSqlQueryRequest(String query, Optional cursorId, - Optional fetchSize) throws IOException { - var builder = XContentFactory.jsonBuilder() - .startObject() - .field("query").value(query); + private static SQLQueryRequest createSqlQueryRequest( + String query, Optional cursorId, Optional fetchSize) throws IOException { + var builder = XContentFactory.jsonBuilder().startObject().field("query").value(query); if (cursorId.isPresent()) { builder.field("cursor").value(cursorId.get()); } @@ -94,17 +88,21 @@ private static SQLQueryRequest createSqlQueryRequest(String query, Optional { - fallback.set(true); - }, (channel, exception) -> { - }).accept(restChannel); + queryAction + .prepareRequest( + request, + (channel, exception) -> { + fallback.set(true); + }, + (channel, exception) -> {}) + .accept(restChannel); return fallback.get(); } @@ -115,8 +113,8 @@ public String getName() { } @Override - protected BaseRestHandler.RestChannelConsumer prepareRequest(RestRequest restRequest, NodeClient nodeClient) - { + protected BaseRestHandler.RestChannelConsumer prepareRequest( + RestRequest restRequest, NodeClient nodeClient) { // do nothing, RestChannelConsumer is protected which required to extend BaseRestHandler return null; } diff --git a/legacy/src/test/java/org/opensearch/sql/legacy/plugin/RestSQLQueryActionTest.java b/legacy/src/test/java/org/opensearch/sql/legacy/plugin/RestSQLQueryActionTest.java index be572f3dfb..b14b2c09cb 100644 --- a/legacy/src/test/java/org/opensearch/sql/legacy/plugin/RestSQLQueryActionTest.java +++ b/legacy/src/test/java/org/opensearch/sql/legacy/plugin/RestSQLQueryActionTest.java @@ -3,7 +3,6 @@ * SPDX-License-Identifier: Apache-2.0 */ - package org.opensearch.sql.legacy.plugin; import static org.junit.Assert.assertTrue; @@ -42,17 +41,13 @@ public class RestSQLQueryActionTest extends BaseRestHandler { private NodeClient 
nodeClient; - @Mock - private ThreadPool threadPool; + @Mock private ThreadPool threadPool; - @Mock - private QueryManager queryManager; + @Mock private QueryManager queryManager; - @Mock - private QueryPlanFactory factory; + @Mock private QueryPlanFactory factory; - @Mock - private RestChannel restChannel; + @Mock private RestChannel restChannel; private Injector injector; @@ -60,88 +55,112 @@ public class RestSQLQueryActionTest extends BaseRestHandler { public void setup() { nodeClient = new NodeClient(org.opensearch.common.settings.Settings.EMPTY, threadPool); ModulesBuilder modules = new ModulesBuilder(); - modules.add(b -> { - b.bind(SQLService.class).toInstance(new SQLService(new SQLSyntaxParser(), queryManager, factory)); - }); + modules.add( + b -> { + b.bind(SQLService.class) + .toInstance(new SQLService(new SQLSyntaxParser(), queryManager, factory)); + }); injector = modules.createInjector(); - Mockito.lenient().when(threadPool.getThreadContext()) + Mockito.lenient() + .when(threadPool.getThreadContext()) .thenReturn(new ThreadContext(org.opensearch.common.settings.Settings.EMPTY)); } @Test public void handleQueryThatCanSupport() throws Exception { - SQLQueryRequest request = new SQLQueryRequest( - new JSONObject("{\"query\": \"SELECT -123\"}"), - "SELECT -123", - QUERY_API_ENDPOINT, - "jdbc"); + SQLQueryRequest request = + new SQLQueryRequest( + new JSONObject("{\"query\": \"SELECT -123\"}"), + "SELECT -123", + QUERY_API_ENDPOINT, + "jdbc"); RestSQLQueryAction queryAction = new RestSQLQueryAction(injector); - queryAction.prepareRequest(request, (channel, exception) -> { - fail(); - }, (channel, exception) -> { - fail(); - }).accept(restChannel); + queryAction + .prepareRequest( + request, + (channel, exception) -> { + fail(); + }, + (channel, exception) -> { + fail(); + }) + .accept(restChannel); } @Test public void handleExplainThatCanSupport() throws Exception { - SQLQueryRequest request = new SQLQueryRequest( - new JSONObject("{\"query\": \"SELECT -123\"}"), - "SELECT -123", - EXPLAIN_API_ENDPOINT, - "jdbc"); + SQLQueryRequest request = + new SQLQueryRequest( + new JSONObject("{\"query\": \"SELECT -123\"}"), + "SELECT -123", + EXPLAIN_API_ENDPOINT, + "jdbc"); RestSQLQueryAction queryAction = new RestSQLQueryAction(injector); - queryAction.prepareRequest(request, (channel, exception) -> { - fail(); - }, (channel, exception) -> { - fail(); - }).accept(restChannel); + queryAction + .prepareRequest( + request, + (channel, exception) -> { + fail(); + }, + (channel, exception) -> { + fail(); + }) + .accept(restChannel); } @Test public void queryThatNotSupportIsHandledByFallbackHandler() throws Exception { - SQLQueryRequest request = new SQLQueryRequest( - new JSONObject( - "{\"query\": \"SELECT name FROM test1 JOIN test2 ON test1.name = test2.name\"}"), - "SELECT name FROM test1 JOIN test2 ON test1.name = test2.name", - QUERY_API_ENDPOINT, - "jdbc"); + SQLQueryRequest request = + new SQLQueryRequest( + new JSONObject( + "{\"query\": \"SELECT name FROM test1 JOIN test2 ON test1.name = test2.name\"}"), + "SELECT name FROM test1 JOIN test2 ON test1.name = test2.name", + QUERY_API_ENDPOINT, + "jdbc"); AtomicBoolean fallback = new AtomicBoolean(false); RestSQLQueryAction queryAction = new RestSQLQueryAction(injector); - queryAction.prepareRequest(request, (channel, exception) -> { - fallback.set(true); - assertTrue(exception instanceof SyntaxCheckException); - }, (channel, exception) -> { - fail(); - }).accept(restChannel); + queryAction + .prepareRequest( + request, + (channel, 
exception) -> { + fallback.set(true); + assertTrue(exception instanceof SyntaxCheckException); + }, + (channel, exception) -> { + fail(); + }) + .accept(restChannel); assertTrue(fallback.get()); } @Test public void queryExecutionFailedIsHandledByExecutionErrorHandler() throws Exception { - SQLQueryRequest request = new SQLQueryRequest( - new JSONObject( - "{\"query\": \"SELECT -123\"}"), - "SELECT -123", - QUERY_API_ENDPOINT, - "jdbc"); + SQLQueryRequest request = + new SQLQueryRequest( + new JSONObject("{\"query\": \"SELECT -123\"}"), + "SELECT -123", + QUERY_API_ENDPOINT, + "jdbc"); - doThrow(new IllegalStateException("execution exception")) - .when(queryManager) - .submit(any()); + doThrow(new IllegalStateException("execution exception")).when(queryManager).submit(any()); AtomicBoolean executionErrorHandler = new AtomicBoolean(false); RestSQLQueryAction queryAction = new RestSQLQueryAction(injector); - queryAction.prepareRequest(request, (channel, exception) -> { - assertTrue(exception instanceof SyntaxCheckException); - }, (channel, exception) -> { - executionErrorHandler.set(true); - assertTrue(exception instanceof IllegalStateException); - }).accept(restChannel); + queryAction + .prepareRequest( + request, + (channel, exception) -> { + assertTrue(exception instanceof SyntaxCheckException); + }, + (channel, exception) -> { + executionErrorHandler.set(true); + assertTrue(exception instanceof IllegalStateException); + }) + .accept(restChannel); assertTrue(executionErrorHandler.get()); } diff --git a/legacy/src/test/java/org/opensearch/sql/legacy/query/planner/physical/node/scroll/SearchHitRowTest.java b/legacy/src/test/java/org/opensearch/sql/legacy/query/planner/physical/node/scroll/SearchHitRowTest.java index fe5c641009..dd0fc626c0 100644 --- a/legacy/src/test/java/org/opensearch/sql/legacy/query/planner/physical/node/scroll/SearchHitRowTest.java +++ b/legacy/src/test/java/org/opensearch/sql/legacy/query/planner/physical/node/scroll/SearchHitRowTest.java @@ -20,7 +20,7 @@ public void testKeyWithObjectField() { SearchHit hit = new SearchHit(1); hit.sourceRef(new BytesArray("{\"id\": {\"serial\": 3}}")); SearchHitRow row = new SearchHitRow(hit, "a"); - RowKey key = row.key(new String[]{"id.serial"}); + RowKey key = row.key(new String[] {"id.serial"}); Object[] data = key.keys(); assertEquals(1, data.length); @@ -32,7 +32,7 @@ public void testKeyWithUnexpandedObjectField() { SearchHit hit = new SearchHit(1); hit.sourceRef(new BytesArray("{\"attributes.hardware.correlate_id\": 10}")); SearchHitRow row = new SearchHitRow(hit, "a"); - RowKey key = row.key(new String[]{"attributes.hardware.correlate_id"}); + RowKey key = row.key(new String[] {"attributes.hardware.correlate_id"}); Object[] data = key.keys(); assertEquals(1, data.length); diff --git a/legacy/src/test/java/org/opensearch/sql/legacy/unittest/NestedFieldProjectionTest.java b/legacy/src/test/java/org/opensearch/sql/legacy/unittest/NestedFieldProjectionTest.java index 63af01caaa..859259756f 100644 --- a/legacy/src/test/java/org/opensearch/sql/legacy/unittest/NestedFieldProjectionTest.java +++ b/legacy/src/test/java/org/opensearch/sql/legacy/unittest/NestedFieldProjectionTest.java @@ -3,7 +3,6 @@ * SPDX-License-Identifier: Apache-2.0 */ - package org.opensearch.sql.legacy.unittest; import static org.hamcrest.MatcherAssert.assertThat; @@ -52,372 +51,284 @@ public class NestedFieldProjectionTest { - @Test - public void regression() { - assertThat(query("SELECT region FROM team"), is(anything())); - assertThat(query("SELECT region FROM 
team WHERE nested(employees.age) = 30"), is(anything())); - assertThat(query("SELECT * FROM team WHERE region = 'US'"), is(anything())); - } - - @Test - public void nestedFieldSelectAll() { - assertThat( - query("SELECT nested(employees.*) FROM team"), - source( - boolQuery( - filter( - boolQuery( - must( - nestedQuery( - path("employees"), - innerHits("employees.*") - ) - ) - ) - ) - ) - ) - ); - } - - @Test - public void nestedFieldInSelect() { - assertThat( - query("SELECT nested(employees.firstname) FROM team"), - source( - boolQuery( - filter( - boolQuery( - must( - nestedQuery( - path("employees"), - innerHits("employees.firstname") - ) - ) - ) - ) - ) - ) - ); - } - - @Test - public void regularAndNestedFieldInSelect() { - assertThat( - query("SELECT region, nested(employees.firstname) FROM team"), - source( - boolQuery( - filter( - boolQuery( - must( - nestedQuery( - path("employees"), - innerHits("employees.firstname") - ) - ) - ) - ) - ), - fetchSource("region") - ) - ); - } - - /* - // Should be integration test - @Test - public void nestedFieldInWhereSelectAll() {} - */ - - @Test - public void nestedFieldInSelectAndWhere() { - assertThat( - query("SELECT nested(employees.firstname) " + - " FROM team " + - " WHERE nested(employees.age) = 30"), - source( - boolQuery( - filter( - boolQuery( - must( - nestedQuery( - path("employees"), - innerHits("employees.firstname") - ) - ) - ) - ) - ) - ) - ); - } - - @Test - public void regularAndNestedFieldInSelectAndWhere() { - assertThat( - query("SELECT region, nested(employees.firstname) " + - " FROM team " + - " WHERE nested(employees.age) = 30"), - source( - boolQuery( - filter( - boolQuery( - must( - nestedQuery( - innerHits("employees.firstname") - ) - ) - ) - ) - ), - fetchSource("region") - ) - ); - } - - @Test - public void multipleSameNestedFields() { - assertThat( - query("SELECT nested(employees.firstname), nested(employees.lastname) " + - " FROM team " + - " WHERE nested(\"employees\", employees.age = 30 AND employees.firstname LIKE 'John')"), - source( - boolQuery( - filter( - boolQuery( - must( - nestedQuery( - path("employees"), - innerHits("employees.firstname", "employees.lastname") - ) - ) - ) - ) - ) - ) - ); + @Test + public void regression() { + assertThat(query("SELECT region FROM team"), is(anything())); + assertThat(query("SELECT region FROM team WHERE nested(employees.age) = 30"), is(anything())); + assertThat(query("SELECT * FROM team WHERE region = 'US'"), is(anything())); + } + + @Test + public void nestedFieldSelectAll() { + assertThat( + query("SELECT nested(employees.*) FROM team"), + source( + boolQuery( + filter( + boolQuery(must(nestedQuery(path("employees"), innerHits("employees.*")))))))); + } + + @Test + public void nestedFieldInSelect() { + assertThat( + query("SELECT nested(employees.firstname) FROM team"), + source( + boolQuery( + filter( + boolQuery( + must(nestedQuery(path("employees"), innerHits("employees.firstname")))))))); + } + + @Test + public void regularAndNestedFieldInSelect() { + assertThat( + query("SELECT region, nested(employees.firstname) FROM team"), + source( + boolQuery( + filter( + boolQuery( + must(nestedQuery(path("employees"), innerHits("employees.firstname")))))), + fetchSource("region"))); + } + + /* + // Should be integration test + @Test + public void nestedFieldInWhereSelectAll() {} + */ + + @Test + public void nestedFieldInSelectAndWhere() { + assertThat( + query( + "SELECT nested(employees.firstname) " + + " FROM team " + + " WHERE nested(employees.age) = 30"), + source( 
+ boolQuery( + filter( + boolQuery( + must(nestedQuery(path("employees"), innerHits("employees.firstname")))))))); + } + + @Test + public void regularAndNestedFieldInSelectAndWhere() { + assertThat( + query( + "SELECT region, nested(employees.firstname) " + + " FROM team " + + " WHERE nested(employees.age) = 30"), + source( + boolQuery(filter(boolQuery(must(nestedQuery(innerHits("employees.firstname")))))), + fetchSource("region"))); + } + + @Test + public void multipleSameNestedFields() { + assertThat( + query( + "SELECT nested(employees.firstname), nested(employees.lastname) FROM team WHERE" + + " nested(\"employees\", employees.age = 30 AND employees.firstname LIKE 'John')"), + source( + boolQuery( + filter( + boolQuery( + must( + nestedQuery( + path("employees"), + innerHits("employees.firstname", "employees.lastname")))))))); + } + + @Test + public void multipleDifferentNestedFields() { + assertThat( + query( + "SELECT region, nested(employees.firstname), nested(manager.name) " + + " FROM team " + + " WHERE nested(employees.age) = 30 AND nested(manager.age) = 50"), + source( + boolQuery( + filter( + boolQuery( + must( + boolQuery( + must( + nestedQuery( + path("employees"), innerHits("employees.firstname")), + nestedQuery(path("manager"), innerHits("manager.name")))))))), + fetchSource("region"))); + } + + @Test + public void leftJoinWithSelectAll() { + assertThat( + query("SELECT * FROM team AS t LEFT JOIN t.projects AS p "), + source( + boolQuery( + filter( + boolQuery( + should( + boolQuery(mustNot(nestedQuery(path("projects")))), + nestedQuery(path("projects"), innerHits("projects.*")))))))); + } + + @Test + public void leftJoinWithSpecificFields() { + assertThat( + query("SELECT t.name, p.name, p.started_year FROM team AS t LEFT JOIN t.projects AS p "), + source( + boolQuery( + filter( + boolQuery( + should( + boolQuery(mustNot(nestedQuery(path("projects")))), + nestedQuery( + path("projects"), + innerHits("projects.name", "projects.started_year")))))), + fetchSource("name"))); + } + + private Matcher source(Matcher queryMatcher) { + return featureValueOf("query", queryMatcher, SearchSourceBuilder::query); + } + + private Matcher source( + Matcher queryMatcher, Matcher fetchSourceMatcher) { + return allOf( + featureValueOf("query", queryMatcher, SearchSourceBuilder::query), + featureValueOf("fetchSource", fetchSourceMatcher, SearchSourceBuilder::fetchSource)); + } + + /** + * Asserting instanceOf and continue other chained matchers of subclass requires explicity cast + */ + @SuppressWarnings("unchecked") + private Matcher boolQuery(Matcher matcher) { + return (Matcher) allOf(instanceOf(BoolQueryBuilder.class), matcher); + } + + @SafeVarargs + @SuppressWarnings("unchecked") + private final Matcher nestedQuery(Matcher... matchers) { + return (Matcher) + both(is(Matchers.instanceOf(NestedQueryBuilder.class))) + .and(allOf(matchers)); + } + + @SafeVarargs + private final FeatureMatcher> filter( + Matcher... matchers) { + return hasClauses("filter", BoolQueryBuilder::filter, matchers); + } + + @SafeVarargs + private final FeatureMatcher> must( + Matcher... matchers) { + return hasClauses("must", BoolQueryBuilder::must, matchers); + } + + @SafeVarargs + private final FeatureMatcher> mustNot( + Matcher... matchers) { + return hasClauses("must_not", BoolQueryBuilder::mustNot, matchers); + } + + @SafeVarargs + private final FeatureMatcher> should( + Matcher... 
matchers) { + return hasClauses("should", BoolQueryBuilder::should, matchers); + } + + /** Hide contains() assertion to simplify */ + @SafeVarargs + private final FeatureMatcher> hasClauses( + String name, + Function> func, + Matcher... matchers) { + return new FeatureMatcher>( + contains(matchers), name, name) { + @Override + protected List featureValueOf(BoolQueryBuilder query) { + return func.apply(query); + } + }; + } + + private Matcher path(String expected) { + return HasFieldWithValue.hasFieldWithValue("path", "path", is(equalTo(expected))); + } + + /** Skip intermediate property along the path. Hide arrayContaining assertion to simplify. */ + private FeatureMatcher innerHits(String... expected) { + return featureValueOf( + "innerHits", + arrayContaining(expected), + (nestedQuery -> nestedQuery.innerHit().getFetchSourceContext().includes())); + } + + @SuppressWarnings("unchecked") + private Matcher fetchSource(String... expected) { + if (expected.length == 0) { + return anyOf( + is(nullValue()), + featureValueOf("includes", is(nullValue()), FetchSourceContext::includes), + featureValueOf("includes", is(emptyArray()), FetchSourceContext::includes)); } - - @Test - public void multipleDifferentNestedFields() { - assertThat( - query("SELECT region, nested(employees.firstname), nested(manager.name) " + - " FROM team " + - " WHERE nested(employees.age) = 30 AND nested(manager.age) = 50"), - source( - boolQuery( - filter( - boolQuery( - must( - boolQuery( - must( - nestedQuery( - path("employees"), - innerHits("employees.firstname") - ), - nestedQuery( - path("manager"), - innerHits("manager.name") - ) - ) - ) - ) - ) - ) - ), - fetchSource("region") - ) - ); - } - - - @Test - public void leftJoinWithSelectAll() { - assertThat( - query("SELECT * FROM team AS t LEFT JOIN t.projects AS p "), - source( - boolQuery( - filter( - boolQuery( - should( - boolQuery( - mustNot( - nestedQuery( - path("projects") - ) - ) - ), - nestedQuery( - path("projects"), - innerHits("projects.*") - ) - ) - ) - ) - ) - ) - ); - } - - @Test - public void leftJoinWithSpecificFields() { - assertThat( - query("SELECT t.name, p.name, p.started_year FROM team AS t LEFT JOIN t.projects AS p "), - source( - boolQuery( - filter( - boolQuery( - should( - boolQuery( - mustNot( - nestedQuery( - path("projects") - ) - ) - ), - nestedQuery( - path("projects"), - innerHits("projects.name", "projects.started_year") - ) - ) - ) - ) - ), - fetchSource("name") - ) - ); - } - - private Matcher source(Matcher queryMatcher) { - return featureValueOf("query", queryMatcher, SearchSourceBuilder::query); - } - - private Matcher source(Matcher queryMatcher, - Matcher fetchSourceMatcher) { - return allOf( - featureValueOf("query", queryMatcher, SearchSourceBuilder::query), - featureValueOf("fetchSource", fetchSourceMatcher, SearchSourceBuilder::fetchSource) - ); - } - - /** Asserting instanceOf and continue other chained matchers of subclass requires explicity cast */ - @SuppressWarnings("unchecked") - private Matcher boolQuery(Matcher matcher) { - return (Matcher) allOf(instanceOf(BoolQueryBuilder.class), matcher); + return featureValueOf( + "includes", contains(expected), fetchSource -> Arrays.asList(fetchSource.includes())); + } + + private FeatureMatcher featureValueOf( + String name, Matcher subMatcher, Function getter) { + return new FeatureMatcher(subMatcher, name, name) { + @Override + protected U featureValueOf(T actual) { + return getter.apply(actual); + } + }; + } + + private SearchSourceBuilder query(String sql) { + SQLQueryExpr 
expr = parseSql(sql); + if (sql.contains("nested")) { + return translate(expr).source(); } - @SafeVarargs - @SuppressWarnings("unchecked") - private final Matcher nestedQuery(Matcher... matchers) { - return (Matcher) both(is(Matchers.instanceOf(NestedQueryBuilder.class))). - and(allOf(matchers)); + expr = rewrite(expr); + return translate(expr).source(); + } + + private SearchRequest translate(SQLQueryExpr expr) { + try { + Client mockClient = Mockito.mock(Client.class); + SearchRequestBuilder request = new SearchRequestBuilder(mockClient, SearchAction.INSTANCE); + Select select = new SqlParser().parseSelect(expr); + + DefaultQueryAction action = new DefaultQueryAction(mockClient, select); + action.initialize(request); + action.setFields(select.getFields()); + + if (select.getWhere() != null) { + request.setQuery(QueryMaker.explain(select.getWhere(), select.isQuery)); + } + new NestedFieldProjection(request).project(select.getFields(), select.getNestedJoinType()); + return request.request(); + } catch (SqlParseException e) { + throw new ParserException("Illegal sql expr: " + expr.toString()); } + } - @SafeVarargs - private final FeatureMatcher> filter(Matcher... matchers) { - return hasClauses("filter", BoolQueryBuilder::filter, matchers); + private SQLQueryExpr parseSql(String sql) { + ElasticSqlExprParser parser = new ElasticSqlExprParser(sql); + SQLExpr expr = parser.expr(); + if (parser.getLexer().token() != Token.EOF) { + throw new ParserException("Illegal sql: " + sql); } + return (SQLQueryExpr) expr; + } - @SafeVarargs - private final FeatureMatcher> must(Matcher... matchers) { - return hasClauses("must", BoolQueryBuilder::must, matchers); - } - - @SafeVarargs - private final FeatureMatcher> mustNot(Matcher... matchers) { - return hasClauses("must_not", BoolQueryBuilder::mustNot, matchers); - } - - @SafeVarargs - private final FeatureMatcher> should(Matcher... matchers) { - return hasClauses("should", BoolQueryBuilder::should, matchers); - } - - /** Hide contains() assertion to simplify */ - @SafeVarargs - private final FeatureMatcher> hasClauses(String name, - Function> func, - Matcher... matchers) { - return new FeatureMatcher>(contains(matchers), name, name) { - @Override - protected List featureValueOf(BoolQueryBuilder query) { - return func.apply(query); - } - }; - } - - private Matcher path(String expected) { - return HasFieldWithValue.hasFieldWithValue("path", "path", is(equalTo(expected))); - } - - /** Skip intermediate property along the path. Hide arrayContaining assertion to simplify. */ - private FeatureMatcher innerHits(String... expected) { - return featureValueOf("innerHits", - arrayContaining(expected), - (nestedQuery -> nestedQuery.innerHit().getFetchSourceContext().includes())); - } - - @SuppressWarnings("unchecked") - private Matcher fetchSource(String... 
expected) { - if (expected.length == 0) { - return anyOf(is(nullValue()), - featureValueOf("includes", is(nullValue()), FetchSourceContext::includes), - featureValueOf("includes", is(emptyArray()), FetchSourceContext::includes)); - } - return featureValueOf("includes", contains(expected), fetchSource -> Arrays.asList(fetchSource.includes())); - } - - private FeatureMatcher featureValueOf(String name, Matcher subMatcher, Function getter) { - return new FeatureMatcher(subMatcher, name, name) { - @Override - protected U featureValueOf(T actual) { - return getter.apply(actual); - } - }; - } - - private SearchSourceBuilder query(String sql) { - SQLQueryExpr expr = parseSql(sql); - if (sql.contains("nested")) { - return translate(expr).source(); - } - - expr = rewrite(expr); - return translate(expr).source(); - } - - private SearchRequest translate(SQLQueryExpr expr) { - try { - Client mockClient = Mockito.mock(Client.class); - SearchRequestBuilder request = new SearchRequestBuilder(mockClient, SearchAction.INSTANCE); - Select select = new SqlParser().parseSelect(expr); - - DefaultQueryAction action = new DefaultQueryAction(mockClient, select); - action.initialize(request); - action.setFields(select.getFields()); - - if (select.getWhere() != null) { - request.setQuery(QueryMaker.explain(select.getWhere(), select.isQuery)); - } - new NestedFieldProjection(request).project(select.getFields(), select.getNestedJoinType()); - return request.request(); - } - catch (SqlParseException e) { - throw new ParserException("Illegal sql expr: " + expr.toString()); - } - } - - private SQLQueryExpr parseSql(String sql) { - ElasticSqlExprParser parser = new ElasticSqlExprParser(sql); - SQLExpr expr = parser.expr(); - if (parser.getLexer().token() != Token.EOF) { - throw new ParserException("Illegal sql: " + sql); - } - return (SQLQueryExpr) expr; - } - - private SQLQueryExpr rewrite(SQLQueryExpr expr) { - expr.accept(new NestedFieldRewriter()); - return expr; - } + private SQLQueryExpr rewrite(SQLQueryExpr expr) { + expr.accept(new NestedFieldRewriter()); + return expr; + } } diff --git a/legacy/src/test/java/org/opensearch/sql/legacy/unittest/NestedFieldRewriterTest.java b/legacy/src/test/java/org/opensearch/sql/legacy/unittest/NestedFieldRewriterTest.java index 58a6f7e244..2593f25379 100644 --- a/legacy/src/test/java/org/opensearch/sql/legacy/unittest/NestedFieldRewriterTest.java +++ b/legacy/src/test/java/org/opensearch/sql/legacy/unittest/NestedFieldRewriterTest.java @@ -3,7 +3,6 @@ * SPDX-License-Identifier: Apache-2.0 */ - package org.opensearch.sql.legacy.unittest; import static java.util.stream.IntStream.range; @@ -29,630 +28,608 @@ public class NestedFieldRewriterTest { - @Test - public void regression() { - noImpact("SELECT * FROM team"); - noImpact("SELECT region FROM team/test, employees/test"); - noImpact("SELECT manager.name FROM team WHERE region = 'US' ORDER BY COUNT(*)"); - noImpact("SELECT COUNT(*) FROM team GROUP BY region"); - } - - @Test - public void selectWithoutFrom() { - // Expect no exception thrown - query("SELECT now()"); - } - - @Test - public void selectAll() { - same( - query("SELECT * FROM team t, t.employees"), - query("SELECT *, nested(employees.*, 'employees') FROM team") - ); - } - - @Test - public void selectAllWithGroupBy() { - same( - query("SELECT * FROM team t, t.employees e GROUP BY e.firstname"), - query("SELECT * FROM team GROUP BY nested(employees.firstname, 'employees')") - ); - } - - @Test - public void selectAllWithCondition() { - same( - query("SELECT * FROM team 
t, t.employees e WHERE e.age = 26"), - query("SELECT *, nested(employees.*, 'employees') FROM team WHERE nested(employees.age, 'employees') = 26") - ); - } - - @Test - public void singleCondition() { - same( - query("SELECT region FROM team t, t.employees e WHERE e.age = 26"), - query("SELECT region FROM team WHERE nested(employees.age, 'employees') = 26") - ); - } - - @Test - public void mixedWithObjectType() { - same( - query("SELECT region FROM team t, t.employees e WHERE e.age > 30 OR manager.age = 50"), - query("SELECT region FROM team WHERE nested(employees.age, 'employees') > 30 OR manager.age = 50") - ); - } - - @Test - public void noAlias() { - same( - query("SELECT region FROM team t, t.employees WHERE employees.age = 26"), - query("SELECT region FROM team WHERE nested(employees.age, 'employees') = 26") - ); - } - - @Test(expected = AssertionError.class) - public void multipleRegularTables() { - same( - query("SELECT region FROM team t, t.employees e, company c WHERE e.age = 26"), - query("SELECT region FROM team, company WHERE nested(employees.age) = 26") - ); - } - - @Test - public void eraseParentAlias() { - same( - query("SELECT t.age FROM team t, t.employees e WHERE t.region = 'US' AND age > 26"), - query("SELECT age FROM team WHERE region = 'US' AND age > 26") - ); - noImpact("SELECT t.age FROM team t WHERE t.region = 'US'"); - } - - @Test - public void select() { - same( - query("SELECT e.age FROM team t, t.employees e"), - query("SELECT nested(employees.age, 'employees' ) FROM team") - ); - } - - @Test - public void aggregationInSelect() { - same( - query("SELECT AVG(e.age) FROM team t, t.employees e"), - query("SELECT AVG(nested(employees.age, 'employees')) FROM team") - ); - } - - @Test - public void multipleAggregationsInSelect() { - same( - query("SELECT COUNT(*), AVG(e.age) FROM team t, t.employees e"), - query("SELECT COUNT(*), AVG(nested(employees.age, 'employees')) FROM team") - ); - } - - @Test - public void groupBy() { - same( - query("SELECT e.firstname, COUNT(*) FROM team t, t.employees e GROUP BY e.firstname"), - query("SELECT nested(employees.firstname, 'employees'), COUNT(*) FROM team GROUP BY nested(employees.firstname, 'employees')") - ); - } - - @Test - public void multipleFieldsInGroupBy() { - same( - query("SELECT COUNT(*) FROM team t, t.employees e GROUP BY t.manager, e.age"), - query("SELECT COUNT(*) FROM team GROUP BY manager, nested(employees.age, 'employees')") - ); - } - - @Test - public void orderBy() { - same( - query("SELECT region FROM team t, t.employees e ORDER BY e.age"), - query("SELECT region FROM team ORDER BY nested(employees.age)") - ); - } - - @Test - public void multipleConditions() { - same( - query("SELECT region " + - "FROM team t, t.manager m, t.employees e " + - "WHERE t.department = 'IT' AND " + - " (e.age = 26 OR (e.firstname = 'John' AND e.lastname = 'Smith')) AND " + - " t.region = 'US' AND " + - " (m.name = 'Alice' AND m.age = 50)"), - query("SELECT region " + - "FROM team " + - "WHERE department = 'IT' AND " + - " nested(\"employees\", employees.age = 26 OR (employees.firstname = 'John' AND employees.lastname = 'Smith')) AND " + - " region = 'US' AND " + - " nested(\"manager\", manager.name = 'Alice' AND manager.age = 50)") - ); - } - - @Test - public void multipleFieldsInFrom() { - same( - query("SELECT region FROM team/test t, t.manager m, t.employees e WHERE m.age = 30 AND e.age = 26"), - query("SELECT region FROM team/test WHERE nested(manager.age, 'manager') = 30 " + - "AND nested(employees.age, 'employees') = 26") - 
); - } - - @Test - public void unionAll() { - // NLPchina doesn't support UNION (intersection) - same( - query("SELECT region FROM team t, t.employees e WHERE e.age = 26 " + - "UNION ALL " + - "SELECT region FROM team t, t.employees e WHERE e.firstname = 'John'"), - query("SELECT region FROM team WHERE nested(employees.age, 'employees') = 26 " + - "UNION ALL " + - "SELECT region FROM team WHERE nested(employees.firstname, 'employees') = 'John'") - ); - } - - @Test - public void minus() { - same( - query("SELECT region FROM team t, t.employees e WHERE e.age = 26 " + - "MINUS " + - "SELECT region FROM team t, t.employees e WHERE e.firstname = 'John'"), - query("SELECT region FROM team WHERE nested(employees.age, 'employees') = 26 " + - "MINUS " + - "SELECT region FROM team WHERE nested(employees.firstname, 'employees') = 'John'") - ); - } - - public void join() { - // TODO - } - - @Test - public void subQuery() { - // Subquery only support IN and TERMS - same( - query("SELECT region FROM team t, t.employees e " + - " WHERE e.age IN " + - " (SELECT t1.manager.age FROM team t1, t1.employees e1 WHERE e1.age > 0)"), - query("SELECT region FROM team " + - " WHERE nested(employees.age, 'employees') IN " + - " (SELECT manager.age FROM team WHERE nested(employees.age, 'employees') > 0)") - ); - } - - @Test - public void subQueryWitSameAlias() { - // Inner alias e shadow outer alias e of nested field - same( - query("SELECT name FROM team t, t.employees e " + - " WHERE e.age IN " + - " (SELECT e.age FROM team e, e.manager m WHERE e.age > 0 OR m.name = 'Alice')"), - query("SELECT name FROM team " + - " WHERE nested(employees.age, 'employees') IN " + - " (SELECT age FROM team WHERE age > 0 OR nested(manager.name, 'manager') = 'Alice')") - ); - } - - @Test - public void isNotNull() { - same( - query("SELECT e.name " + - "FROM employee as e, e.projects as p " + - "WHERE p IS NOT MISSING"), - query("SELECT name " + - "FROM employee " + - "WHERE nested(projects, 'projects') IS NOT MISSING") - ); - } - - @Test - public void isNotNullAndCondition() { - same( - query("SELECT e.name " + - "FROM employee as e, e.projects as p " + - "WHERE p IS NOT MISSING AND p.name LIKE 'security'"), - query("SELECT name " + - "FROM employee " + - "WHERE nested('projects', projects IS NOT MISSING AND projects.name LIKE 'security')") - ); - } - - @Test - public void multiCondition() { - same( - query("SELECT e.name FROM employee as e, e.projects as p WHERE p.year = 2016 and p.name LIKE 'security'"), - query("SELECT name FROM employee WHERE nested('projects', projects.year = 2016 AND projects.name LIKE 'security')") - ); - } - - @Test - public void nestedAndParentCondition() { - same( - query("SELECT name " + - "FROM employee " + - "WHERE nested(projects, 'projects') IS NOT MISSING AND name LIKE 'security'"), - query("SELECT e.name " + - "FROM employee e, e.projects p " + - "WHERE p IS NOT MISSING AND e.name LIKE 'security'") - ); - } - - @Test - public void aggWithWhereOnParent() { - same( - query("SELECT e.name, COUNT(p) as c " + - "FROM employee AS e, e.projects AS p " + - "WHERE e.name like '%smith%' " + - "GROUP BY e.name " + - "HAVING c > 1"), - query("SELECT name, COUNT(nested(projects, 'projects')) AS c " + - "FROM employee " + - "WHERE name LIKE '%smith%' " + - "GROUP BY name " + - "HAVING c > 1") - ); - - } - - @Test - public void aggWithWhereOnNested() { - same( - query("SELECT e.name, COUNT(p) as c " + - "FROM employee AS e, e.projects AS p " + - "WHERE p.name LIKE '%security%' " + - "GROUP BY e.name " + - "HAVING c 
> 1"), - query("SELECT name, COUNT(nested(projects, 'projects')) AS c " + - "FROM employee " + - "WHERE nested(projects.name, 'projects') LIKE '%security%' " + - "GROUP BY name " + - "HAVING c > 1") - ); - } - - @Test - public void aggWithWhereOnParentOrNested() { - same( - query("SELECT e.name, COUNT(p) as c " + - "FROM employee AS e, e.projects AS p " + - "WHERE e.name like '%smith%' or p.name LIKE '%security%' " + - "GROUP BY e.name " + - "HAVING c > 1"), - query("SELECT name, COUNT(nested(projects, 'projects')) AS c " + - "FROM employee " + - "WHERE name LIKE '%smith%' OR nested(projects.name, 'projects') LIKE '%security%' " + - "GROUP BY name " + - "HAVING c > 1") - ); - } - - @Test - public void aggWithWhereOnParentAndNested() { - same( - query("SELECT e.name, COUNT(p) as c " + - "FROM employee AS e, e.projects AS p " + - "WHERE e.name like '%smith%' AND p.name LIKE '%security%' " + - "GROUP BY e.name " + - "HAVING c > 1"), - query("SELECT name, COUNT(nested(projects, 'projects')) AS c " + - "FROM employee " + - "WHERE name LIKE '%smith%' AND nested(projects.name, 'projects') LIKE '%security%' " + - "GROUP BY name " + - "HAVING c > 1") - ); - } - - @Test - public void aggWithWhereOnNestedAndNested() { - same( - query("SELECT e.name, COUNT(p) as c " + - "FROM employee AS e, e.projects AS p " + - "WHERE p.started_year > 1990 AND p.name LIKE '%security%' " + - "GROUP BY e.name " + - "HAVING c > 1"), - query("SELECT name, COUNT(nested(projects, 'projects')) AS c " + - "FROM employee " + - "WHERE nested('projects', projects.started_year > 1990 AND projects.name LIKE '%security%') " + - "GROUP BY name " + - "HAVING c > 1") - ); - } - - @Test - public void aggWithWhereOnNestedOrNested() { - same( - query("SELECT e.name, COUNT(p) as c " + - "FROM employee AS e, e.projects AS p " + - "WHERE p.started_year > 1990 OR p.name LIKE '%security%' " + - "GROUP BY e.name " + - "HAVING c > 1"), - query("SELECT name, COUNT(nested(projects, 'projects')) AS c " + - "FROM employee " + - "WHERE nested('projects', projects.started_year > 1990 OR projects.name LIKE '%security%') " + - "GROUP BY name " + - "HAVING c > 1") - ); - } - - @Test - public void aggInHavingWithWhereOnParent() { - same( - query("SELECT e.name " + - "FROM employee AS e, e.projects AS p " + - "WHERE e.name like '%smith%' " + - "GROUP BY e.name " + - "HAVING COUNT(p) > 1"), - query("SELECT name " + - "FROM employee " + - "WHERE name LIKE '%smith%' " + - "GROUP BY name " + - "HAVING COUNT(nested(projects, 'projects')) > 1") - ); - - } - - @Test - public void aggInHavingWithWhereOnNested() { - same( - query("SELECT e.name " + - "FROM employee AS e, e.projects AS p " + - "WHERE p.name LIKE '%security%' " + - "GROUP BY e.name " + - "HAVING COUNT(p) > 1"), - query("SELECT name " + - "FROM employee " + - "WHERE nested(projects.name, 'projects') LIKE '%security%' " + - "GROUP BY name " + - "HAVING COUNT(nested(projects, 'projects')) > 1") - ); - } - - @Test - public void aggInHavingWithWhereOnParentOrNested() { - same( - query("SELECT e.name " + - "FROM employee AS e, e.projects AS p " + - "WHERE e.name like '%smith%' or p.name LIKE '%security%' " + - "GROUP BY e.name " + - "HAVING COUNT(p) > 1"), - query("SELECT name " + - "FROM employee " + - "WHERE name LIKE '%smith%' OR nested(projects.name, 'projects') LIKE '%security%' " + - "GROUP BY name " + - "HAVING COUNT(nested(projects, 'projects')) > 1") - ); - } - - @Test - public void aggInHavingWithWhereOnParentAndNested() { - same( - query("SELECT e.name " + - "FROM employee AS e, e.projects AS 
p " + - "WHERE e.name like '%smith%' AND p.name LIKE '%security%' " + - "GROUP BY e.name " + - "HAVING COUNT(p) > 1"), - query("SELECT name " + - "FROM employee " + - "WHERE name LIKE '%smith%' AND nested(projects.name, 'projects') LIKE '%security%' " + - "GROUP BY name " + - "HAVING COUNT(nested(projects, 'projects')) > 1") - ); - } - - @Test - public void aggInHavingWithWhereOnNestedAndNested() { - same( - query("SELECT e.name " + - "FROM employee AS e, e.projects AS p " + - "WHERE p.started_year > 1990 AND p.name LIKE '%security%' " + - "GROUP BY e.name " + - "HAVING COUNT(p) > 1"), - query("SELECT name " + - "FROM employee " + - "WHERE nested('projects', projects.started_year > 1990 AND projects.name LIKE '%security%') " + - "GROUP BY name " + - "HAVING COUNT(nested(projects, 'projects')) > 1") - ); - } - - @Test - public void aggInHavingWithWhereOnNestedOrNested() { - same( - query("SELECT e.name " + - "FROM employee AS e, e.projects AS p " + - "WHERE p.started_year > 1990 OR p.name LIKE '%security%' " + - "GROUP BY e.name " + - "HAVING COUNT(p) > 1"), - query("SELECT name " + - "FROM employee " + - "WHERE nested('projects', projects.started_year > 1990 OR projects.name LIKE '%security%') " + - "GROUP BY name " + - "HAVING COUNT(nested(projects, 'projects')) > 1") - ); - } - - @Test - public void notIsNotNull() { - same( - query("SELECT name " + - "FROM employee " + - "WHERE not (nested(projects, 'projects') IS NOT MISSING)"), - query("SELECT e.name " + - "FROM employee as e, e.projects as p " + - "WHERE not (p IS NOT MISSING)") - ); - } - - @Test - public void notIsNotNullAndCondition() { - same( - query("SELECT e.name " + - "FROM employee as e, e.projects as p " + - "WHERE not (p IS NOT MISSING AND p.name LIKE 'security')"), - query("SELECT name " + - "FROM employee " + - "WHERE not nested('projects', projects IS NOT MISSING AND projects.name LIKE 'security')") - ); - } - - @Test - public void notMultiCondition() { - same( - query("SELECT name " + - "FROM employee " + - "WHERE not nested('projects', projects.year = 2016 AND projects.name LIKE 'security')"), - query("SELECT e.name " + - "FROM employee as e, e.projects as p " + - "WHERE not (p.year = 2016 and p.name LIKE 'security')") - ); - } - - @Test - public void notNestedAndParentCondition() { - same( - query("SELECT name " + - "FROM employee " + - "WHERE (not nested(projects, 'projects') IS NOT MISSING) AND name LIKE 'security'"), - query("SELECT e.name " + - "FROM employee e, e.projects p " + - "WHERE not (p IS NOT MISSING) AND e.name LIKE 'security'") - ); - } - - private void noImpact(String sql) { - same(parse(sql), rewrite(parse(sql))); - } - - /** - * The intention for this assert method is: - * - * 1) MySqlSelectQueryBlock.equals() doesn't call super.equals(). - * But select items, from, where and group by are all held by parent class SQLSelectQueryBlock. - * - * 2) SQLSelectGroupByClause doesn't implement equals() at all.. - * MySqlSelectGroupByExpr compares identity of expression.. 
- * - * 3) MySqlUnionQuery doesn't implement equals() at all - */ - private void same(SQLQueryExpr actual, SQLQueryExpr expected) { - assertEquals(expected.getClass(), actual.getClass()); - - SQLSelect expectedQuery = expected.getSubQuery(); - SQLSelect actualQuery = actual.getSubQuery(); - assertEquals(expectedQuery.getOrderBy(), actualQuery.getOrderBy()); - assertQuery(expectedQuery, actualQuery); - } - - private void assertQuery(SQLSelect expected, SQLSelect actual) { - SQLSelectQuery expectedQuery = expected.getQuery(); - SQLSelectQuery actualQuery = actual.getQuery(); - if (actualQuery instanceof SQLSelectQueryBlock) { - assertQueryBlock( - (SQLSelectQueryBlock) expectedQuery, - (SQLSelectQueryBlock) actualQuery - ); - } - else if (actualQuery instanceof SQLUnionQuery) { - assertQueryBlock( - (SQLSelectQueryBlock) ((SQLUnionQuery) expectedQuery).getLeft(), - (SQLSelectQueryBlock) ((SQLUnionQuery) actualQuery).getLeft() - ); - assertQueryBlock( - (SQLSelectQueryBlock) ((SQLUnionQuery) expectedQuery).getRight(), - (SQLSelectQueryBlock) ((SQLUnionQuery) actualQuery).getRight() - ); - assertEquals( - ((SQLUnionQuery) expectedQuery).getOperator(), - ((SQLUnionQuery) actualQuery).getOperator() - ); - } - else { - throw new IllegalStateException("Unsupported test SQL"); - } - } - - private void assertQueryBlock(SQLSelectQueryBlock expected, SQLSelectQueryBlock actual) { - assertEquals("SELECT", expected.getSelectList(), actual.getSelectList()); - assertEquals("INTO", expected.getInto(), actual.getInto()); - assertEquals("WHERE", expected.getWhere(), actual.getWhere()); - if (actual.getWhere() instanceof SQLInSubQueryExpr) { - assertQuery( - ((SQLInSubQueryExpr) expected.getWhere()).getSubQuery(), - ((SQLInSubQueryExpr) actual.getWhere()).getSubQuery() - ); - } - assertEquals("PARENTHESIZED", expected.isParenthesized(), actual.isParenthesized()); - assertEquals("DISTION", expected.getDistionOption(), actual.getDistionOption()); - assertFrom(expected, actual); - if (!(expected.getGroupBy() == null && actual.getGroupBy() == null)) { - assertGroupBy(expected.getGroupBy(), actual.getGroupBy()); - } - } - - private void assertFrom(SQLSelectQueryBlock expected, SQLSelectQueryBlock actual) { - // Only 2 tables JOIN at most is supported - if (expected.getFrom() instanceof SQLExprTableSource) { - assertTable(expected.getFrom(), actual.getFrom()); - } else { - assertEquals(actual.getFrom().getClass(), SQLJoinTableSource.class); - assertTable( - ((SQLJoinTableSource) expected.getFrom()).getLeft(), - ((SQLJoinTableSource) actual.getFrom()).getLeft() - ); - assertTable( - ((SQLJoinTableSource) expected.getFrom()).getRight(), - ((SQLJoinTableSource) actual.getFrom()).getRight() - ); - assertEquals( - ((SQLJoinTableSource) expected.getFrom()).getJoinType(), - ((SQLJoinTableSource) actual.getFrom()).getJoinType() - ); - } - } - - private void assertGroupBy(SQLSelectGroupByClause expected, SQLSelectGroupByClause actual) { - assertEquals("HAVING", expected.getHaving(), actual.getHaving()); - - List expectedGroupby = expected.getItems(); - List actualGroupby = actual.getItems(); - assertEquals(expectedGroupby.size(), actualGroupby.size()); - range(0, expectedGroupby.size()). 
- forEach(i -> assertEquals( - ((MySqlSelectGroupByExpr) expectedGroupby.get(i)).getExpr(), - ((MySqlSelectGroupByExpr) actualGroupby.get(i)).getExpr()) - ); - } - - private void assertTable(SQLTableSource expect, SQLTableSource actual) { - assertEquals(SQLExprTableSource.class, expect.getClass()); - assertEquals(SQLExprTableSource.class, actual.getClass()); - assertEquals(((SQLExprTableSource) expect).getExpr(), ((SQLExprTableSource) actual).getExpr()); - assertEquals(expect.getAlias(), actual.getAlias()); - } - - /** - * Walk through extra rewrite logic if NOT found "nested" in SQL query statement. - * Otherwise return as before so that original logic be compared with result of rewrite. - * - * @param sql Test sql - * @return Node parsed out of sql - */ - private SQLQueryExpr query(String sql) { - SQLQueryExpr expr = SqlParserUtils.parse(sql); - if (sql.contains("nested")) { - return expr; - } - return rewrite(expr); - } - - private SQLQueryExpr rewrite(SQLQueryExpr expr) { - expr.accept(new NestedFieldRewriter()); - return expr; - } - + @Test + public void regression() { + noImpact("SELECT * FROM team"); + noImpact("SELECT region FROM team/test, employees/test"); + noImpact("SELECT manager.name FROM team WHERE region = 'US' ORDER BY COUNT(*)"); + noImpact("SELECT COUNT(*) FROM team GROUP BY region"); + } + + @Test + public void selectWithoutFrom() { + // Expect no exception thrown + query("SELECT now()"); + } + + @Test + public void selectAll() { + same( + query("SELECT * FROM team t, t.employees"), + query("SELECT *, nested(employees.*, 'employees') FROM team")); + } + + @Test + public void selectAllWithGroupBy() { + same( + query("SELECT * FROM team t, t.employees e GROUP BY e.firstname"), + query("SELECT * FROM team GROUP BY nested(employees.firstname, 'employees')")); + } + + @Test + public void selectAllWithCondition() { + same( + query("SELECT * FROM team t, t.employees e WHERE e.age = 26"), + query( + "SELECT *, nested(employees.*, 'employees') FROM team WHERE nested(employees.age," + + " 'employees') = 26")); + } + + @Test + public void singleCondition() { + same( + query("SELECT region FROM team t, t.employees e WHERE e.age = 26"), + query("SELECT region FROM team WHERE nested(employees.age, 'employees') = 26")); + } + + @Test + public void mixedWithObjectType() { + same( + query("SELECT region FROM team t, t.employees e WHERE e.age > 30 OR manager.age = 50"), + query( + "SELECT region FROM team WHERE nested(employees.age, 'employees') > 30 OR manager.age =" + + " 50")); + } + + @Test + public void noAlias() { + same( + query("SELECT region FROM team t, t.employees WHERE employees.age = 26"), + query("SELECT region FROM team WHERE nested(employees.age, 'employees') = 26")); + } + + @Test(expected = AssertionError.class) + public void multipleRegularTables() { + same( + query("SELECT region FROM team t, t.employees e, company c WHERE e.age = 26"), + query("SELECT region FROM team, company WHERE nested(employees.age) = 26")); + } + + @Test + public void eraseParentAlias() { + same( + query("SELECT t.age FROM team t, t.employees e WHERE t.region = 'US' AND age > 26"), + query("SELECT age FROM team WHERE region = 'US' AND age > 26")); + noImpact("SELECT t.age FROM team t WHERE t.region = 'US'"); + } + + @Test + public void select() { + same( + query("SELECT e.age FROM team t, t.employees e"), + query("SELECT nested(employees.age, 'employees' ) FROM team")); + } + + @Test + public void aggregationInSelect() { + same( + query("SELECT AVG(e.age) FROM team t, t.employees e"), + 
query("SELECT AVG(nested(employees.age, 'employees')) FROM team")); + } + + @Test + public void multipleAggregationsInSelect() { + same( + query("SELECT COUNT(*), AVG(e.age) FROM team t, t.employees e"), + query("SELECT COUNT(*), AVG(nested(employees.age, 'employees')) FROM team")); + } + + @Test + public void groupBy() { + same( + query("SELECT e.firstname, COUNT(*) FROM team t, t.employees e GROUP BY e.firstname"), + query( + "SELECT nested(employees.firstname, 'employees'), COUNT(*) FROM team GROUP BY" + + " nested(employees.firstname, 'employees')")); + } + + @Test + public void multipleFieldsInGroupBy() { + same( + query("SELECT COUNT(*) FROM team t, t.employees e GROUP BY t.manager, e.age"), + query("SELECT COUNT(*) FROM team GROUP BY manager, nested(employees.age, 'employees')")); + } + + @Test + public void orderBy() { + same( + query("SELECT region FROM team t, t.employees e ORDER BY e.age"), + query("SELECT region FROM team ORDER BY nested(employees.age)")); + } + + @Test + public void multipleConditions() { + same( + query( + "SELECT region " + + "FROM team t, t.manager m, t.employees e " + + "WHERE t.department = 'IT' AND " + + " (e.age = 26 OR (e.firstname = 'John' AND e.lastname = 'Smith')) AND " + + " t.region = 'US' AND " + + " (m.name = 'Alice' AND m.age = 50)"), + query( + "SELECT region FROM team WHERE department = 'IT' AND nested(\"employees\"," + + " employees.age = 26 OR (employees.firstname = 'John' AND employees.lastname =" + + " 'Smith')) AND region = 'US' AND nested(\"manager\", manager.name =" + + " 'Alice' AND manager.age = 50)")); + } + + @Test + public void multipleFieldsInFrom() { + same( + query( + "SELECT region FROM team/test t, t.manager m, t.employees e WHERE m.age = 30 AND e.age" + + " = 26"), + query( + "SELECT region FROM team/test WHERE nested(manager.age, 'manager') = 30 " + + "AND nested(employees.age, 'employees') = 26")); + } + + @Test + public void unionAll() { + // NLPchina doesn't support UNION (intersection) + same( + query( + "SELECT region FROM team t, t.employees e WHERE e.age = 26 " + + "UNION ALL " + + "SELECT region FROM team t, t.employees e WHERE e.firstname = 'John'"), + query( + "SELECT region FROM team WHERE nested(employees.age, 'employees') = 26 UNION ALL SELECT" + + " region FROM team WHERE nested(employees.firstname, 'employees') = 'John'")); + } + + @Test + public void minus() { + same( + query( + "SELECT region FROM team t, t.employees e WHERE e.age = 26 " + + "MINUS " + + "SELECT region FROM team t, t.employees e WHERE e.firstname = 'John'"), + query( + "SELECT region FROM team WHERE nested(employees.age, 'employees') = 26 MINUS SELECT" + + " region FROM team WHERE nested(employees.firstname, 'employees') = 'John'")); + } + + public void join() { + // TODO + } + + @Test + public void subQuery() { + // Subquery only support IN and TERMS + same( + query( + "SELECT region FROM team t, t.employees e " + + " WHERE e.age IN " + + " (SELECT t1.manager.age FROM team t1, t1.employees e1 WHERE e1.age > 0)"), + query( + "SELECT region FROM team WHERE nested(employees.age, 'employees') IN (SELECT" + + " manager.age FROM team WHERE nested(employees.age, 'employees') > 0)")); + } + + @Test + public void subQueryWitSameAlias() { + // Inner alias e shadow outer alias e of nested field + same( + query( + "SELECT name FROM team t, t.employees e WHERE e.age IN (SELECT e.age FROM team e," + + " e.manager m WHERE e.age > 0 OR m.name = 'Alice')"), + query( + "SELECT name FROM team WHERE nested(employees.age, 'employees') IN (SELECT age" + + " 
FROM team WHERE age > 0 OR nested(manager.name, 'manager') = 'Alice')")); + } + + @Test + public void isNotNull() { + same( + query("SELECT e.name " + "FROM employee as e, e.projects as p " + "WHERE p IS NOT MISSING"), + query( + "SELECT name " + + "FROM employee " + + "WHERE nested(projects, 'projects') IS NOT MISSING")); + } + + @Test + public void isNotNullAndCondition() { + same( + query( + "SELECT e.name " + + "FROM employee as e, e.projects as p " + + "WHERE p IS NOT MISSING AND p.name LIKE 'security'"), + query( + "SELECT name FROM employee WHERE nested('projects', projects IS NOT MISSING AND" + + " projects.name LIKE 'security')")); + } + + @Test + public void multiCondition() { + same( + query( + "SELECT e.name FROM employee as e, e.projects as p WHERE p.year = 2016 and p.name LIKE" + + " 'security'"), + query( + "SELECT name FROM employee WHERE nested('projects', projects.year = 2016 AND" + + " projects.name LIKE 'security')")); + } + + @Test + public void nestedAndParentCondition() { + same( + query( + "SELECT name " + + "FROM employee " + + "WHERE nested(projects, 'projects') IS NOT MISSING AND name LIKE 'security'"), + query( + "SELECT e.name " + + "FROM employee e, e.projects p " + + "WHERE p IS NOT MISSING AND e.name LIKE 'security'")); + } + + @Test + public void aggWithWhereOnParent() { + same( + query( + "SELECT e.name, COUNT(p) as c " + + "FROM employee AS e, e.projects AS p " + + "WHERE e.name like '%smith%' " + + "GROUP BY e.name " + + "HAVING c > 1"), + query( + "SELECT name, COUNT(nested(projects, 'projects')) AS c " + + "FROM employee " + + "WHERE name LIKE '%smith%' " + + "GROUP BY name " + + "HAVING c > 1")); + } + + @Test + public void aggWithWhereOnNested() { + same( + query( + "SELECT e.name, COUNT(p) as c " + + "FROM employee AS e, e.projects AS p " + + "WHERE p.name LIKE '%security%' " + + "GROUP BY e.name " + + "HAVING c > 1"), + query( + "SELECT name, COUNT(nested(projects, 'projects')) AS c " + + "FROM employee " + + "WHERE nested(projects.name, 'projects') LIKE '%security%' " + + "GROUP BY name " + + "HAVING c > 1")); + } + + @Test + public void aggWithWhereOnParentOrNested() { + same( + query( + "SELECT e.name, COUNT(p) as c " + + "FROM employee AS e, e.projects AS p " + + "WHERE e.name like '%smith%' or p.name LIKE '%security%' " + + "GROUP BY e.name " + + "HAVING c > 1"), + query( + "SELECT name, COUNT(nested(projects, 'projects')) AS c FROM employee WHERE name LIKE" + + " '%smith%' OR nested(projects.name, 'projects') LIKE '%security%' GROUP BY name" + + " HAVING c > 1")); + } + + @Test + public void aggWithWhereOnParentAndNested() { + same( + query( + "SELECT e.name, COUNT(p) as c " + + "FROM employee AS e, e.projects AS p " + + "WHERE e.name like '%smith%' AND p.name LIKE '%security%' " + + "GROUP BY e.name " + + "HAVING c > 1"), + query( + "SELECT name, COUNT(nested(projects, 'projects')) AS c FROM employee WHERE name LIKE" + + " '%smith%' AND nested(projects.name, 'projects') LIKE '%security%' GROUP BY name" + + " HAVING c > 1")); + } + + @Test + public void aggWithWhereOnNestedAndNested() { + same( + query( + "SELECT e.name, COUNT(p) as c " + + "FROM employee AS e, e.projects AS p " + + "WHERE p.started_year > 1990 AND p.name LIKE '%security%' " + + "GROUP BY e.name " + + "HAVING c > 1"), + query( + "SELECT name, COUNT(nested(projects, 'projects')) AS c FROM employee WHERE" + + " nested('projects', projects.started_year > 1990 AND projects.name LIKE" + + " '%security%') GROUP BY name HAVING c > 1")); + } + + @Test + public void 
aggWithWhereOnNestedOrNested() { + same( + query( + "SELECT e.name, COUNT(p) as c " + + "FROM employee AS e, e.projects AS p " + + "WHERE p.started_year > 1990 OR p.name LIKE '%security%' " + + "GROUP BY e.name " + + "HAVING c > 1"), + query( + "SELECT name, COUNT(nested(projects, 'projects')) AS c FROM employee WHERE" + + " nested('projects', projects.started_year > 1990 OR projects.name LIKE" + + " '%security%') GROUP BY name HAVING c > 1")); + } + + @Test + public void aggInHavingWithWhereOnParent() { + same( + query( + "SELECT e.name " + + "FROM employee AS e, e.projects AS p " + + "WHERE e.name like '%smith%' " + + "GROUP BY e.name " + + "HAVING COUNT(p) > 1"), + query( + "SELECT name " + + "FROM employee " + + "WHERE name LIKE '%smith%' " + + "GROUP BY name " + + "HAVING COUNT(nested(projects, 'projects')) > 1")); + } + + @Test + public void aggInHavingWithWhereOnNested() { + same( + query( + "SELECT e.name " + + "FROM employee AS e, e.projects AS p " + + "WHERE p.name LIKE '%security%' " + + "GROUP BY e.name " + + "HAVING COUNT(p) > 1"), + query( + "SELECT name " + + "FROM employee " + + "WHERE nested(projects.name, 'projects') LIKE '%security%' " + + "GROUP BY name " + + "HAVING COUNT(nested(projects, 'projects')) > 1")); + } + + @Test + public void aggInHavingWithWhereOnParentOrNested() { + same( + query( + "SELECT e.name " + + "FROM employee AS e, e.projects AS p " + + "WHERE e.name like '%smith%' or p.name LIKE '%security%' " + + "GROUP BY e.name " + + "HAVING COUNT(p) > 1"), + query( + "SELECT name FROM employee WHERE name LIKE '%smith%' OR nested(projects.name," + + " 'projects') LIKE '%security%' GROUP BY name HAVING COUNT(nested(projects," + + " 'projects')) > 1")); + } + + @Test + public void aggInHavingWithWhereOnParentAndNested() { + same( + query( + "SELECT e.name " + + "FROM employee AS e, e.projects AS p " + + "WHERE e.name like '%smith%' AND p.name LIKE '%security%' " + + "GROUP BY e.name " + + "HAVING COUNT(p) > 1"), + query( + "SELECT name FROM employee WHERE name LIKE '%smith%' AND nested(projects.name," + + " 'projects') LIKE '%security%' GROUP BY name HAVING COUNT(nested(projects," + + " 'projects')) > 1")); + } + + @Test + public void aggInHavingWithWhereOnNestedAndNested() { + same( + query( + "SELECT e.name " + + "FROM employee AS e, e.projects AS p " + + "WHERE p.started_year > 1990 AND p.name LIKE '%security%' " + + "GROUP BY e.name " + + "HAVING COUNT(p) > 1"), + query( + "SELECT name FROM employee WHERE nested('projects', projects.started_year > 1990 AND" + + " projects.name LIKE '%security%') GROUP BY name HAVING COUNT(nested(projects," + + " 'projects')) > 1")); + } + + @Test + public void aggInHavingWithWhereOnNestedOrNested() { + same( + query( + "SELECT e.name " + + "FROM employee AS e, e.projects AS p " + + "WHERE p.started_year > 1990 OR p.name LIKE '%security%' " + + "GROUP BY e.name " + + "HAVING COUNT(p) > 1"), + query( + "SELECT name FROM employee WHERE nested('projects', projects.started_year > 1990 OR" + + " projects.name LIKE '%security%') GROUP BY name HAVING COUNT(nested(projects," + + " 'projects')) > 1")); + } + + @Test + public void notIsNotNull() { + same( + query( + "SELECT name " + + "FROM employee " + + "WHERE not (nested(projects, 'projects') IS NOT MISSING)"), + query( + "SELECT e.name " + + "FROM employee as e, e.projects as p " + + "WHERE not (p IS NOT MISSING)")); + } + + @Test + public void notIsNotNullAndCondition() { + same( + query( + "SELECT e.name " + + "FROM employee as e, e.projects as p " + + "WHERE not (p IS NOT 
MISSING AND p.name LIKE 'security')"), + query( + "SELECT name FROM employee WHERE not nested('projects', projects IS NOT MISSING AND" + + " projects.name LIKE 'security')")); + } + + @Test + public void notMultiCondition() { + same( + query( + "SELECT name FROM employee WHERE not nested('projects', projects.year = 2016 AND" + + " projects.name LIKE 'security')"), + query( + "SELECT e.name " + + "FROM employee as e, e.projects as p " + + "WHERE not (p.year = 2016 and p.name LIKE 'security')")); + } + + @Test + public void notNestedAndParentCondition() { + same( + query( + "SELECT name FROM employee WHERE (not nested(projects, 'projects') IS NOT MISSING) AND" + + " name LIKE 'security'"), + query( + "SELECT e.name " + + "FROM employee e, e.projects p " + + "WHERE not (p IS NOT MISSING) AND e.name LIKE 'security'")); + } + + private void noImpact(String sql) { + same(parse(sql), rewrite(parse(sql))); + } + + /** + * The intention for this assert method is: + *
+ *   1. MySqlSelectQueryBlock.equals() doesn't call super.equals(). But select items, from, where
+ *      and group by are all held by parent class SQLSelectQueryBlock.
+ *
+ *   2. SQLSelectGroupByClause doesn't implement equals() at all.. MySqlSelectGroupByExpr
+ *      compares identity of expression..
+ *
+ *   3. MySqlUnionQuery doesn't implement equals() at all
+ *
+ */ + private void same(SQLQueryExpr actual, SQLQueryExpr expected) { + assertEquals(expected.getClass(), actual.getClass()); + + SQLSelect expectedQuery = expected.getSubQuery(); + SQLSelect actualQuery = actual.getSubQuery(); + assertEquals(expectedQuery.getOrderBy(), actualQuery.getOrderBy()); + assertQuery(expectedQuery, actualQuery); + } + + private void assertQuery(SQLSelect expected, SQLSelect actual) { + SQLSelectQuery expectedQuery = expected.getQuery(); + SQLSelectQuery actualQuery = actual.getQuery(); + if (actualQuery instanceof SQLSelectQueryBlock) { + assertQueryBlock((SQLSelectQueryBlock) expectedQuery, (SQLSelectQueryBlock) actualQuery); + } else if (actualQuery instanceof SQLUnionQuery) { + assertQueryBlock( + (SQLSelectQueryBlock) ((SQLUnionQuery) expectedQuery).getLeft(), + (SQLSelectQueryBlock) ((SQLUnionQuery) actualQuery).getLeft()); + assertQueryBlock( + (SQLSelectQueryBlock) ((SQLUnionQuery) expectedQuery).getRight(), + (SQLSelectQueryBlock) ((SQLUnionQuery) actualQuery).getRight()); + assertEquals( + ((SQLUnionQuery) expectedQuery).getOperator(), + ((SQLUnionQuery) actualQuery).getOperator()); + } else { + throw new IllegalStateException("Unsupported test SQL"); + } + } + + private void assertQueryBlock(SQLSelectQueryBlock expected, SQLSelectQueryBlock actual) { + assertEquals("SELECT", expected.getSelectList(), actual.getSelectList()); + assertEquals("INTO", expected.getInto(), actual.getInto()); + assertEquals("WHERE", expected.getWhere(), actual.getWhere()); + if (actual.getWhere() instanceof SQLInSubQueryExpr) { + assertQuery( + ((SQLInSubQueryExpr) expected.getWhere()).getSubQuery(), + ((SQLInSubQueryExpr) actual.getWhere()).getSubQuery()); + } + assertEquals("PARENTHESIZED", expected.isParenthesized(), actual.isParenthesized()); + assertEquals("DISTION", expected.getDistionOption(), actual.getDistionOption()); + assertFrom(expected, actual); + if (!(expected.getGroupBy() == null && actual.getGroupBy() == null)) { + assertGroupBy(expected.getGroupBy(), actual.getGroupBy()); + } + } + + private void assertFrom(SQLSelectQueryBlock expected, SQLSelectQueryBlock actual) { + // Only 2 tables JOIN at most is supported + if (expected.getFrom() instanceof SQLExprTableSource) { + assertTable(expected.getFrom(), actual.getFrom()); + } else { + assertEquals(actual.getFrom().getClass(), SQLJoinTableSource.class); + assertTable( + ((SQLJoinTableSource) expected.getFrom()).getLeft(), + ((SQLJoinTableSource) actual.getFrom()).getLeft()); + assertTable( + ((SQLJoinTableSource) expected.getFrom()).getRight(), + ((SQLJoinTableSource) actual.getFrom()).getRight()); + assertEquals( + ((SQLJoinTableSource) expected.getFrom()).getJoinType(), + ((SQLJoinTableSource) actual.getFrom()).getJoinType()); + } + } + + private void assertGroupBy(SQLSelectGroupByClause expected, SQLSelectGroupByClause actual) { + assertEquals("HAVING", expected.getHaving(), actual.getHaving()); + + List expectedGroupby = expected.getItems(); + List actualGroupby = actual.getItems(); + assertEquals(expectedGroupby.size(), actualGroupby.size()); + range(0, expectedGroupby.size()) + .forEach( + i -> + assertEquals( + ((MySqlSelectGroupByExpr) expectedGroupby.get(i)).getExpr(), + ((MySqlSelectGroupByExpr) actualGroupby.get(i)).getExpr())); + } + + private void assertTable(SQLTableSource expect, SQLTableSource actual) { + assertEquals(SQLExprTableSource.class, expect.getClass()); + assertEquals(SQLExprTableSource.class, actual.getClass()); + assertEquals(((SQLExprTableSource) expect).getExpr(), 
((SQLExprTableSource) actual).getExpr()); + assertEquals(expect.getAlias(), actual.getAlias()); + } + + /** + * Walk through extra rewrite logic if NOT found "nested" in SQL query statement. Otherwise return + * as before so that original logic be compared with result of rewrite. + * + * @param sql Test sql + * @return Node parsed out of sql + */ + private SQLQueryExpr query(String sql) { + SQLQueryExpr expr = SqlParserUtils.parse(sql); + if (sql.contains("nested")) { + return expr; + } + return rewrite(expr); + } + + private SQLQueryExpr rewrite(SQLQueryExpr expr) { + expr.accept(new NestedFieldRewriter()); + return expr; + } } diff --git a/legacy/src/test/java/org/opensearch/sql/legacy/unittest/OpenSearchClientTest.java b/legacy/src/test/java/org/opensearch/sql/legacy/unittest/OpenSearchClientTest.java index 2a654774d4..2dd5cc16ac 100644 --- a/legacy/src/test/java/org/opensearch/sql/legacy/unittest/OpenSearchClientTest.java +++ b/legacy/src/test/java/org/opensearch/sql/legacy/unittest/OpenSearchClientTest.java @@ -3,7 +3,6 @@ * SPDX-License-Identifier: Apache-2.0 */ - package org.opensearch.sql.legacy.unittest; import static org.mockito.Matchers.any; @@ -27,39 +26,42 @@ public class OpenSearchClientTest { - @Mock - protected Client client; + @Mock protected Client client; - @Before - public void init() { - MockitoAnnotations.initMocks(this); - ActionFuture mockFuture = mock(ActionFuture.class); - when(client.multiSearch(any())).thenReturn(mockFuture); + @Before + public void init() { + MockitoAnnotations.initMocks(this); + ActionFuture mockFuture = mock(ActionFuture.class); + when(client.multiSearch(any())).thenReturn(mockFuture); - MultiSearchResponse response = mock(MultiSearchResponse.class); - when(mockFuture.actionGet()).thenReturn(response); + MultiSearchResponse response = mock(MultiSearchResponse.class); + when(mockFuture.actionGet()).thenReturn(response); - MultiSearchResponse.Item item0 = new MultiSearchResponse.Item(mock(SearchResponse.class), null); - MultiSearchResponse.Item item1 = new MultiSearchResponse.Item(mock(SearchResponse.class), new Exception()); - MultiSearchResponse.Item[] itemsRetry0 = new MultiSearchResponse.Item[]{item0, item1}; - MultiSearchResponse.Item[] itemsRetry1 = new MultiSearchResponse.Item[]{item0}; - when(response.getResponses()).thenAnswer(new Answer() { - private int callCnt; + MultiSearchResponse.Item item0 = new MultiSearchResponse.Item(mock(SearchResponse.class), null); + MultiSearchResponse.Item item1 = + new MultiSearchResponse.Item(mock(SearchResponse.class), new Exception()); + MultiSearchResponse.Item[] itemsRetry0 = new MultiSearchResponse.Item[] {item0, item1}; + MultiSearchResponse.Item[] itemsRetry1 = new MultiSearchResponse.Item[] {item0}; + when(response.getResponses()) + .thenAnswer( + new Answer() { + private int callCnt; - @Override - public MultiSearchResponse.Item[] answer(InvocationOnMock invocation) { + @Override + public MultiSearchResponse.Item[] answer(InvocationOnMock invocation) { return callCnt++ == 0 ? 
itemsRetry0 : itemsRetry1; - } - }); - } - - @Test - public void multiSearchRetryOneTime() { - OpenSearchClient openSearchClient = new OpenSearchClient(client); - MultiSearchResponse.Item[] res = openSearchClient.multiSearch(new MultiSearchRequest().add(new SearchRequest()).add(new SearchRequest())); - Assert.assertEquals(res.length, 2); - Assert.assertFalse(res[0].isFailure()); - Assert.assertFalse(res[1].isFailure()); - } + } + }); + } + @Test + public void multiSearchRetryOneTime() { + OpenSearchClient openSearchClient = new OpenSearchClient(client); + MultiSearchResponse.Item[] res = + openSearchClient.multiSearch( + new MultiSearchRequest().add(new SearchRequest()).add(new SearchRequest())); + Assert.assertEquals(res.length, 2); + Assert.assertFalse(res[0].isFailure()); + Assert.assertFalse(res[1].isFailure()); + } } diff --git a/legacy/src/test/java/org/opensearch/sql/legacy/unittest/PreparedStatementRequestTest.java b/legacy/src/test/java/org/opensearch/sql/legacy/unittest/PreparedStatementRequestTest.java index 0b714ed41c..8a31c530e3 100644 --- a/legacy/src/test/java/org/opensearch/sql/legacy/unittest/PreparedStatementRequestTest.java +++ b/legacy/src/test/java/org/opensearch/sql/legacy/unittest/PreparedStatementRequestTest.java @@ -3,7 +3,6 @@ * SPDX-License-Identifier: Apache-2.0 */ - package org.opensearch.sql.legacy.unittest; import java.util.ArrayList; @@ -15,64 +14,68 @@ public class PreparedStatementRequestTest { - @Test - public void testSubstitute() { - String sqlTemplate = "select * from table_name where number_param > ? and string_param = 'Amazon.com' " + - "and test_str = '''test escape? \\'' and state in (?,?) and null_param = ? and double_param = ? " + - "and question_mark = '?'"; - List params = new ArrayList<>(); - params.add(new PreparedStatementRequest.PreparedStatementParameter(10)); - params.add(new PreparedStatementRequest.StringParameter("WA")); - params.add(new PreparedStatementRequest.StringParameter("")); - params.add(new PreparedStatementRequest.NullParameter()); - params.add(new PreparedStatementRequest.PreparedStatementParameter(2.0)); - PreparedStatementRequest psr = new PreparedStatementRequest(sqlTemplate, new JSONObject(), params); - String generatedSql = psr.getSql(); - - String expectedSql = "select * from table_name where number_param > 10 and string_param = 'Amazon.com' " + - "and test_str = '''test escape? 
\\'' and state in ('WA','') and null_param = null " + - "and double_param = 2.0 and question_mark = '?'"; - Assert.assertEquals(expectedSql, generatedSql); - } - - @Test - public void testStringParameter() { - PreparedStatementRequest.StringParameter param; - param = new PreparedStatementRequest.StringParameter("test string"); - Assert.assertEquals("'test string'", param.getSqlSubstitutionValue()); - - param = new PreparedStatementRequest.StringParameter("test ' single ' quote '"); - Assert.assertEquals("'test \\' single \\' quote \\''", param.getSqlSubstitutionValue()); - - param = new PreparedStatementRequest.StringParameter("test line \n break \n char"); - Assert.assertEquals("'test line \\n break \\n char'", param.getSqlSubstitutionValue()); - - param = new PreparedStatementRequest.StringParameter("test carriage \r return \r char"); - Assert.assertEquals("'test carriage \\r return \\r char'", param.getSqlSubstitutionValue()); - - param = new PreparedStatementRequest.StringParameter("test \\ backslash \\ char"); - Assert.assertEquals("'test \\\\ backslash \\\\ char'", param.getSqlSubstitutionValue()); - - param = new PreparedStatementRequest.StringParameter("test single ' quote ' char"); - Assert.assertEquals("'test single \\' quote \\' char'", param.getSqlSubstitutionValue()); - - param = new PreparedStatementRequest.StringParameter("test double \" quote \" char"); - Assert.assertEquals("'test double \\\" quote \\\" char'", param.getSqlSubstitutionValue()); - } - - @Test(expected = IllegalStateException.class) - public void testSubstitute_parameterNumberNotMatch() { - String sqlTemplate = "select * from table_name where param1 = ? and param2 = ?"; - List params = new ArrayList<>(); - params.add(new PreparedStatementRequest.StringParameter("value")); - - PreparedStatementRequest psr = new PreparedStatementRequest(sqlTemplate, new JSONObject(), params); - } - - @Test - public void testSubstitute_nullSql() { - PreparedStatementRequest psr = new PreparedStatementRequest(null, new JSONObject(), null); - - Assert.assertNull(psr.getSql()); - } + @Test + public void testSubstitute() { + String sqlTemplate = + "select * from table_name where number_param > ? and string_param = 'Amazon.com' and" + + " test_str = '''test escape? \\'' and state in (?,?) and null_param = ? and" + + " double_param = ? and question_mark = '?'"; + List params = new ArrayList<>(); + params.add(new PreparedStatementRequest.PreparedStatementParameter(10)); + params.add(new PreparedStatementRequest.StringParameter("WA")); + params.add(new PreparedStatementRequest.StringParameter("")); + params.add(new PreparedStatementRequest.NullParameter()); + params.add(new PreparedStatementRequest.PreparedStatementParameter(2.0)); + PreparedStatementRequest psr = + new PreparedStatementRequest(sqlTemplate, new JSONObject(), params); + String generatedSql = psr.getSql(); + + String expectedSql = + "select * from table_name where number_param > 10 and string_param = 'Amazon.com' " + + "and test_str = '''test escape? 
\\'' and state in ('WA','') and null_param = null " + + "and double_param = 2.0 and question_mark = '?'"; + Assert.assertEquals(expectedSql, generatedSql); + } + + @Test + public void testStringParameter() { + PreparedStatementRequest.StringParameter param; + param = new PreparedStatementRequest.StringParameter("test string"); + Assert.assertEquals("'test string'", param.getSqlSubstitutionValue()); + + param = new PreparedStatementRequest.StringParameter("test ' single ' quote '"); + Assert.assertEquals("'test \\' single \\' quote \\''", param.getSqlSubstitutionValue()); + + param = new PreparedStatementRequest.StringParameter("test line \n break \n char"); + Assert.assertEquals("'test line \\n break \\n char'", param.getSqlSubstitutionValue()); + + param = new PreparedStatementRequest.StringParameter("test carriage \r return \r char"); + Assert.assertEquals("'test carriage \\r return \\r char'", param.getSqlSubstitutionValue()); + + param = new PreparedStatementRequest.StringParameter("test \\ backslash \\ char"); + Assert.assertEquals("'test \\\\ backslash \\\\ char'", param.getSqlSubstitutionValue()); + + param = new PreparedStatementRequest.StringParameter("test single ' quote ' char"); + Assert.assertEquals("'test single \\' quote \\' char'", param.getSqlSubstitutionValue()); + + param = new PreparedStatementRequest.StringParameter("test double \" quote \" char"); + Assert.assertEquals("'test double \\\" quote \\\" char'", param.getSqlSubstitutionValue()); + } + + @Test(expected = IllegalStateException.class) + public void testSubstitute_parameterNumberNotMatch() { + String sqlTemplate = "select * from table_name where param1 = ? and param2 = ?"; + List params = new ArrayList<>(); + params.add(new PreparedStatementRequest.StringParameter("value")); + + PreparedStatementRequest psr = + new PreparedStatementRequest(sqlTemplate, new JSONObject(), params); + } + + @Test + public void testSubstitute_nullSql() { + PreparedStatementRequest psr = new PreparedStatementRequest(null, new JSONObject(), null); + + Assert.assertNull(psr.getSql()); + } } diff --git a/legacy/src/test/java/org/opensearch/sql/legacy/unittest/QueryFunctionsTest.java b/legacy/src/test/java/org/opensearch/sql/legacy/unittest/QueryFunctionsTest.java index 0ebf89e296..b5a82f6737 100644 --- a/legacy/src/test/java/org/opensearch/sql/legacy/unittest/QueryFunctionsTest.java +++ b/legacy/src/test/java/org/opensearch/sql/legacy/unittest/QueryFunctionsTest.java @@ -3,7 +3,6 @@ * SPDX-License-Identifier: Apache-2.0 */ - package org.opensearch.sql.legacy.unittest; import static org.hamcrest.MatcherAssert.assertThat; @@ -37,272 +36,178 @@ public class QueryFunctionsTest { - private static final String SELECT_ALL = "SELECT *"; - private static final String FROM_ACCOUNTS = "FROM " + TestsConstants.TEST_INDEX_ACCOUNT + "/account"; - private static final String FROM_NESTED = "FROM " + TestsConstants.TEST_INDEX_NESTED_TYPE + "/nestedType"; - private static final String FROM_PHRASE = "FROM " + TestsConstants.TEST_INDEX_PHRASE + "/phrase"; - - @Test - public void query() { - assertThat( - query( - FROM_ACCOUNTS, - "WHERE QUERY('CA')" - ), - contains( - queryStringQuery("CA") - ) - ); - } - - @Test - public void matchQueryRegularField() { - assertThat( - query( - FROM_ACCOUNTS, - "WHERE MATCH_QUERY(firstname, 'Ayers')" - ), - contains( - matchQuery("firstname", "Ayers") - ) - ); - } - - @Test - public void matchQueryNestedField() { - assertThat( - query( - FROM_NESTED, - "WHERE MATCH_QUERY(NESTED(comment.data), 'aa')" - ), - contains( - 
nestedQuery("comment", matchQuery("comment.data", "aa"), ScoreMode.None) - ) - ); - } - - @Test - public void scoreQuery() { - assertThat( - query( - FROM_ACCOUNTS, - "WHERE SCORE(MATCH_QUERY(firstname, 'Ayers'), 10)" - ), - contains( - constantScoreQuery( - matchQuery("firstname", "Ayers") - ).boost(10) - ) - ); - } - - @Test - public void scoreQueryWithNestedField() { - assertThat( - query( - FROM_NESTED, - "WHERE SCORE(MATCH_QUERY(NESTED(comment.data), 'ab'), 10)" - ), - contains( - constantScoreQuery( - nestedQuery("comment", matchQuery("comment.data", "ab"), ScoreMode.None) - ).boost(10) - ) - ); - } - - @Test - public void wildcardQueryRegularField() { - assertThat( - query( - FROM_ACCOUNTS, - "WHERE WILDCARD_QUERY(city.keyword, 'B*')" - ), - contains( - wildcardQuery("city.keyword", "B*") - ) - ); - } - - @Test - public void wildcardQueryNestedField() { - assertThat( - query( - FROM_NESTED, - "WHERE WILDCARD_QUERY(nested(comment.data), 'a*')" - ), - contains( - nestedQuery("comment", wildcardQuery("comment.data", "a*"), ScoreMode.None) - ) - ); - } - - @Test - public void matchPhraseQueryDefault() { - assertThat( - query( - FROM_PHRASE, - "WHERE MATCH_PHRASE(phrase, 'brown fox')" - ), - contains( - matchPhraseQuery("phrase", "brown fox") - ) - ); - } - - @Test - public void matchPhraseQueryWithSlop() { - assertThat( - query( - FROM_PHRASE, - "WHERE MATCH_PHRASE(phrase, 'brown fox', slop=2)" - ), - contains( - matchPhraseQuery("phrase", "brown fox").slop(2) - ) - ); - } - - @Test - public void multiMatchQuerySingleField() { - assertThat( - query( - FROM_ACCOUNTS, - "WHERE MULTI_MATCH(query='Ayers', fields='firstname')" - ), - contains( - multiMatchQuery("Ayers").field("firstname") - ) - ); - } - - @Test - public void multiMatchQueryWildcardField() { - assertThat( - query( - FROM_ACCOUNTS, - "WHERE MULTI_MATCH(query='Ay', fields='*name', type='phrase_prefix')" - ), - contains( - multiMatchQuery("Ay"). - field("*name"). 
- type(MultiMatchQueryBuilder.Type.PHRASE_PREFIX) - ) - ); - } - - @Test - public void numberLiteralInSelectField() { - String query = "SELECT 2 AS number FROM bank WHERE age > 20"; - ScriptField scriptField = CheckScriptContents.getScriptFieldFromQuery(query); - assertTrue( - CheckScriptContents.scriptContainsString( - scriptField, - "def assign" - ) - ); - } - - @Test - public void ifFunctionWithConditionStatement() { - String query = "SELECT IF(age > 35, 'elastic', 'search') AS Ages FROM accounts"; - ScriptField scriptField = CheckScriptContents.getScriptFieldFromQuery(query); - assertTrue( - CheckScriptContents.scriptContainsString( - scriptField, - "boolean cond = doc['age'].value > 35;" - ) - ); - } - - @Test - public void ifFunctionWithEquationConditionStatement() { - String query = "SELECT IF(age = 35, 'elastic', 'search') AS Ages FROM accounts"; - ScriptField scriptField = CheckScriptContents.getScriptFieldFromQuery(query); - assertTrue( - CheckScriptContents.scriptContainsString( - scriptField, - "boolean cond = doc['age'].value == 35;" - ) - ); - } - - @Test - public void ifFunctionWithConstantConditionStatement() { - String query = "SELECT IF(1 = 2, 'elastic', 'search') FROM accounts"; - ScriptField scriptField = CheckScriptContents.getScriptFieldFromQuery(query); - assertTrue( - CheckScriptContents.scriptContainsString( - scriptField, - "boolean cond = 1 == 2;" - ) - ); - } - - @Test - public void ifNull() { - String query = "SELECT IFNULL(lastname, 'Unknown') FROM accounts"; - ScriptField scriptField = CheckScriptContents.getScriptFieldFromQuery(query); - assertTrue( - CheckScriptContents.scriptContainsString( - scriptField, - "doc['lastname'].size()==0" - ) - ); - } - - @Test - public void isNullWithMathExpr() { - String query = "SELECT ISNULL(1+1) FROM accounts"; - ScriptField scriptField = CheckScriptContents.getScriptFieldFromQuery(query); - assertTrue( - CheckScriptContents.scriptContainsString( - scriptField, - "catch(ArithmeticException e)" - ) - ); - - } - - @Test(expected = SQLFeatureNotSupportedException.class) - public void emptyQueryShouldThrowSQLFeatureNotSupportedException() - throws SQLFeatureNotSupportedException, SqlParseException, SQLFeatureDisabledException { - OpenSearchActionFactory.create(Mockito.mock(Client.class), ""); - } - - @Test(expected = SQLFeatureNotSupportedException.class) - public void emptyNewLineQueryShouldThrowSQLFeatureNotSupportedException() - throws SQLFeatureNotSupportedException, SqlParseException, SQLFeatureDisabledException { - OpenSearchActionFactory.create(Mockito.mock(Client.class), "\n"); - } - - @Test(expected = SQLFeatureNotSupportedException.class) - public void emptyNewLineQueryShouldThrowSQLFeatureNotSupportedException2() - throws SQLFeatureNotSupportedException, SqlParseException, SQLFeatureDisabledException { - OpenSearchActionFactory.create(Mockito.mock(Client.class), "\r\n"); - } - - @Test(expected = SQLFeatureNotSupportedException.class) - public void queryWithoutSpaceShouldSQLFeatureNotSupportedException() - throws SQLFeatureNotSupportedException, SqlParseException, SQLFeatureDisabledException { - OpenSearchActionFactory.create(Mockito.mock(Client.class), "SELE"); - } - - @Test(expected = SQLFeatureNotSupportedException.class) - public void spacesOnlyQueryShouldThrowSQLFeatureNotSupportedException() - throws SQLFeatureNotSupportedException, SqlParseException, SQLFeatureDisabledException { - OpenSearchActionFactory.create(Mockito.mock(Client.class), " "); - } - - private String query(String from, String... 
statements) { - return explain(SELECT_ALL + " " + from + " " + String.join(" ", statements)); - } - - private String query(String sql) { - return explain(sql); - } - - private Matcher contains(AbstractQueryBuilder queryBuilder) { - return containsString(Strings.toString(XContentType.JSON, queryBuilder, false, false)); - } + private static final String SELECT_ALL = "SELECT *"; + private static final String FROM_ACCOUNTS = + "FROM " + TestsConstants.TEST_INDEX_ACCOUNT + "/account"; + private static final String FROM_NESTED = + "FROM " + TestsConstants.TEST_INDEX_NESTED_TYPE + "/nestedType"; + private static final String FROM_PHRASE = "FROM " + TestsConstants.TEST_INDEX_PHRASE + "/phrase"; + + @Test + public void query() { + assertThat(query(FROM_ACCOUNTS, "WHERE QUERY('CA')"), contains(queryStringQuery("CA"))); + } + + @Test + public void matchQueryRegularField() { + assertThat( + query(FROM_ACCOUNTS, "WHERE MATCH_QUERY(firstname, 'Ayers')"), + contains(matchQuery("firstname", "Ayers"))); + } + + @Test + public void matchQueryNestedField() { + assertThat( + query(FROM_NESTED, "WHERE MATCH_QUERY(NESTED(comment.data), 'aa')"), + contains(nestedQuery("comment", matchQuery("comment.data", "aa"), ScoreMode.None))); + } + + @Test + public void scoreQuery() { + assertThat( + query(FROM_ACCOUNTS, "WHERE SCORE(MATCH_QUERY(firstname, 'Ayers'), 10)"), + contains(constantScoreQuery(matchQuery("firstname", "Ayers")).boost(10))); + } + + @Test + public void scoreQueryWithNestedField() { + assertThat( + query(FROM_NESTED, "WHERE SCORE(MATCH_QUERY(NESTED(comment.data), 'ab'), 10)"), + contains( + constantScoreQuery( + nestedQuery("comment", matchQuery("comment.data", "ab"), ScoreMode.None)) + .boost(10))); + } + + @Test + public void wildcardQueryRegularField() { + assertThat( + query(FROM_ACCOUNTS, "WHERE WILDCARD_QUERY(city.keyword, 'B*')"), + contains(wildcardQuery("city.keyword", "B*"))); + } + + @Test + public void wildcardQueryNestedField() { + assertThat( + query(FROM_NESTED, "WHERE WILDCARD_QUERY(nested(comment.data), 'a*')"), + contains(nestedQuery("comment", wildcardQuery("comment.data", "a*"), ScoreMode.None))); + } + + @Test + public void matchPhraseQueryDefault() { + assertThat( + query(FROM_PHRASE, "WHERE MATCH_PHRASE(phrase, 'brown fox')"), + contains(matchPhraseQuery("phrase", "brown fox"))); + } + + @Test + public void matchPhraseQueryWithSlop() { + assertThat( + query(FROM_PHRASE, "WHERE MATCH_PHRASE(phrase, 'brown fox', slop=2)"), + contains(matchPhraseQuery("phrase", "brown fox").slop(2))); + } + + @Test + public void multiMatchQuerySingleField() { + assertThat( + query(FROM_ACCOUNTS, "WHERE MULTI_MATCH(query='Ayers', fields='firstname')"), + contains(multiMatchQuery("Ayers").field("firstname"))); + } + + @Test + public void multiMatchQueryWildcardField() { + assertThat( + query(FROM_ACCOUNTS, "WHERE MULTI_MATCH(query='Ay', fields='*name', type='phrase_prefix')"), + contains( + multiMatchQuery("Ay").field("*name").type(MultiMatchQueryBuilder.Type.PHRASE_PREFIX))); + } + + @Test + public void numberLiteralInSelectField() { + String query = "SELECT 2 AS number FROM bank WHERE age > 20"; + ScriptField scriptField = CheckScriptContents.getScriptFieldFromQuery(query); + assertTrue(CheckScriptContents.scriptContainsString(scriptField, "def assign")); + } + + @Test + public void ifFunctionWithConditionStatement() { + String query = "SELECT IF(age > 35, 'elastic', 'search') AS Ages FROM accounts"; + ScriptField scriptField = CheckScriptContents.getScriptFieldFromQuery(query); + assertTrue( + 
CheckScriptContents.scriptContainsString( + scriptField, "boolean cond = doc['age'].value > 35;")); + } + + @Test + public void ifFunctionWithEquationConditionStatement() { + String query = "SELECT IF(age = 35, 'elastic', 'search') AS Ages FROM accounts"; + ScriptField scriptField = CheckScriptContents.getScriptFieldFromQuery(query); + assertTrue( + CheckScriptContents.scriptContainsString( + scriptField, "boolean cond = doc['age'].value == 35;")); + } + + @Test + public void ifFunctionWithConstantConditionStatement() { + String query = "SELECT IF(1 = 2, 'elastic', 'search') FROM accounts"; + ScriptField scriptField = CheckScriptContents.getScriptFieldFromQuery(query); + assertTrue(CheckScriptContents.scriptContainsString(scriptField, "boolean cond = 1 == 2;")); + } + + @Test + public void ifNull() { + String query = "SELECT IFNULL(lastname, 'Unknown') FROM accounts"; + ScriptField scriptField = CheckScriptContents.getScriptFieldFromQuery(query); + assertTrue(CheckScriptContents.scriptContainsString(scriptField, "doc['lastname'].size()==0")); + } + + @Test + public void isNullWithMathExpr() { + String query = "SELECT ISNULL(1+1) FROM accounts"; + ScriptField scriptField = CheckScriptContents.getScriptFieldFromQuery(query); + assertTrue( + CheckScriptContents.scriptContainsString(scriptField, "catch(ArithmeticException e)")); + } + + @Test(expected = SQLFeatureNotSupportedException.class) + public void emptyQueryShouldThrowSQLFeatureNotSupportedException() + throws SQLFeatureNotSupportedException, SqlParseException, SQLFeatureDisabledException { + OpenSearchActionFactory.create(Mockito.mock(Client.class), ""); + } + + @Test(expected = SQLFeatureNotSupportedException.class) + public void emptyNewLineQueryShouldThrowSQLFeatureNotSupportedException() + throws SQLFeatureNotSupportedException, SqlParseException, SQLFeatureDisabledException { + OpenSearchActionFactory.create(Mockito.mock(Client.class), "\n"); + } + + @Test(expected = SQLFeatureNotSupportedException.class) + public void emptyNewLineQueryShouldThrowSQLFeatureNotSupportedException2() + throws SQLFeatureNotSupportedException, SqlParseException, SQLFeatureDisabledException { + OpenSearchActionFactory.create(Mockito.mock(Client.class), "\r\n"); + } + + @Test(expected = SQLFeatureNotSupportedException.class) + public void queryWithoutSpaceShouldSQLFeatureNotSupportedException() + throws SQLFeatureNotSupportedException, SqlParseException, SQLFeatureDisabledException { + OpenSearchActionFactory.create(Mockito.mock(Client.class), "SELE"); + } + + @Test(expected = SQLFeatureNotSupportedException.class) + public void spacesOnlyQueryShouldThrowSQLFeatureNotSupportedException() + throws SQLFeatureNotSupportedException, SqlParseException, SQLFeatureDisabledException { + OpenSearchActionFactory.create(Mockito.mock(Client.class), " "); + } + + private String query(String from, String... 
statements) { + return explain(SELECT_ALL + " " + from + " " + String.join(" ", statements)); + } + + private String query(String sql) { + return explain(sql); + } + + private Matcher contains(AbstractQueryBuilder queryBuilder) { + return containsString(Strings.toString(XContentType.JSON, queryBuilder, false, false)); + } } diff --git a/legacy/src/test/java/org/opensearch/sql/legacy/unittest/expression/core/RefExpressionTest.java b/legacy/src/test/java/org/opensearch/sql/legacy/unittest/expression/core/RefExpressionTest.java index f8607ca889..faefa6d2c1 100644 --- a/legacy/src/test/java/org/opensearch/sql/legacy/unittest/expression/core/RefExpressionTest.java +++ b/legacy/src/test/java/org/opensearch/sql/legacy/unittest/expression/core/RefExpressionTest.java @@ -3,7 +3,6 @@ * SPDX-License-Identifier: Apache-2.0 */ - package org.opensearch.sql.legacy.unittest.expression.core; import static org.hamcrest.MatcherAssert.assertThat; @@ -25,36 +24,40 @@ import org.junit.Test; public class RefExpressionTest extends ExpressionTest { - @Test - public void refIntegerValueShouldPass() { - assertEquals(Integer.valueOf(1), getIntegerValue(ref("intValue").valueOf(bindingTuple()))); - } - - @Test - public void refDoubleValueShouldPass() { - assertEquals(Double.valueOf(2d), getDoubleValue(ref("doubleValue").valueOf(bindingTuple()))); - } - - @Test - public void refStringValueShouldPass() { - assertEquals("string", getStringValue(ref("stringValue").valueOf(bindingTuple()))); - } - - @Test - public void refBooleanValueShouldPass() { - assertEquals(true, getBooleanValue(ref("booleanValue").valueOf(bindingTuple()))); - } - - @Test - public void refTupleValueShouldPass() { - assertThat(getTupleValue(ref("tupleValue").valueOf(bindingTuple())), - allOf(hasEntry("intValue", integerValue(1)), hasEntry("doubleValue", doubleValue(2d)), - hasEntry("stringValue", stringValue("string")))); - } - - @Test - public void refCollectValueShouldPass() { - assertThat(getCollectionValue(ref("collectValue").valueOf(bindingTuple())), - contains(integerValue(1), integerValue(2), integerValue(3))); - } + @Test + public void refIntegerValueShouldPass() { + assertEquals(Integer.valueOf(1), getIntegerValue(ref("intValue").valueOf(bindingTuple()))); + } + + @Test + public void refDoubleValueShouldPass() { + assertEquals(Double.valueOf(2d), getDoubleValue(ref("doubleValue").valueOf(bindingTuple()))); + } + + @Test + public void refStringValueShouldPass() { + assertEquals("string", getStringValue(ref("stringValue").valueOf(bindingTuple()))); + } + + @Test + public void refBooleanValueShouldPass() { + assertEquals(true, getBooleanValue(ref("booleanValue").valueOf(bindingTuple()))); + } + + @Test + public void refTupleValueShouldPass() { + assertThat( + getTupleValue(ref("tupleValue").valueOf(bindingTuple())), + allOf( + hasEntry("intValue", integerValue(1)), + hasEntry("doubleValue", doubleValue(2d)), + hasEntry("stringValue", stringValue("string")))); + } + + @Test + public void refCollectValueShouldPass() { + assertThat( + getCollectionValue(ref("collectValue").valueOf(bindingTuple())), + contains(integerValue(1), integerValue(2), integerValue(3))); + } } diff --git a/legacy/src/test/java/org/opensearch/sql/legacy/unittest/metrics/NumericMetricTest.java b/legacy/src/test/java/org/opensearch/sql/legacy/unittest/metrics/NumericMetricTest.java index f2c2c25fab..d76241056f 100644 --- a/legacy/src/test/java/org/opensearch/sql/legacy/unittest/metrics/NumericMetricTest.java +++ 
b/legacy/src/test/java/org/opensearch/sql/legacy/unittest/metrics/NumericMetricTest.java @@ -3,7 +3,6 @@ * SPDX-License-Identifier: Apache-2.0 */ - package org.opensearch.sql.legacy.unittest.metrics; import static org.hamcrest.MatcherAssert.assertThat; @@ -15,22 +14,21 @@ public class NumericMetricTest { - @Test - public void increment() { - NumericMetric metric = new NumericMetric("test", new BasicCounter()); - for (int i=0; i<5; ++i) { - metric.increment(); - } - - assertThat(metric.getValue(), equalTo(5L)); + @Test + public void increment() { + NumericMetric metric = new NumericMetric("test", new BasicCounter()); + for (int i = 0; i < 5; ++i) { + metric.increment(); } - @Test - public void add() { - NumericMetric metric = new NumericMetric("test", new BasicCounter()); - metric.increment(5); + assertThat(metric.getValue(), equalTo(5L)); + } - assertThat(metric.getValue(), equalTo(5L)); - } + @Test + public void add() { + NumericMetric metric = new NumericMetric("test", new BasicCounter()); + metric.increment(5); + assertThat(metric.getValue(), equalTo(5L)); + } } diff --git a/legacy/src/test/java/org/opensearch/sql/legacy/unittest/metrics/RollingCounterTest.java b/legacy/src/test/java/org/opensearch/sql/legacy/unittest/metrics/RollingCounterTest.java index a1651aad6b..0ad333a6e2 100644 --- a/legacy/src/test/java/org/opensearch/sql/legacy/unittest/metrics/RollingCounterTest.java +++ b/legacy/src/test/java/org/opensearch/sql/legacy/unittest/metrics/RollingCounterTest.java @@ -3,7 +3,6 @@ * SPDX-License-Identifier: Apache-2.0 */ - package org.opensearch.sql.legacy.unittest.metrics; import static org.hamcrest.MatcherAssert.assertThat; @@ -20,61 +19,58 @@ @RunWith(MockitoJUnitRunner.class) public class RollingCounterTest { - @Mock - Clock clock; + @Mock Clock clock; - @Test - public void increment() { - RollingCounter counter = new RollingCounter(3, 1, clock); - for (int i=0; i<5; ++i) { - counter.increment(); - } + @Test + public void increment() { + RollingCounter counter = new RollingCounter(3, 1, clock); + for (int i = 0; i < 5; ++i) { + counter.increment(); + } - assertThat(counter.getValue(), equalTo(0L)); + assertThat(counter.getValue(), equalTo(0L)); - when(clock.millis()).thenReturn(1000L); // 1 second passed - assertThat(counter.getValue(), equalTo(5L)); + when(clock.millis()).thenReturn(1000L); // 1 second passed + assertThat(counter.getValue(), equalTo(5L)); - counter.increment(); - counter.increment(); + counter.increment(); + counter.increment(); - when(clock.millis()).thenReturn(2000L); // 1 second passed - assertThat(counter.getValue(), lessThanOrEqualTo(3L)); + when(clock.millis()).thenReturn(2000L); // 1 second passed + assertThat(counter.getValue(), lessThanOrEqualTo(3L)); - when(clock.millis()).thenReturn(3000L); // 1 second passed - assertThat(counter.getValue(), equalTo(0L)); + when(clock.millis()).thenReturn(3000L); // 1 second passed + assertThat(counter.getValue(), equalTo(0L)); + } - } + @Test + public void add() { + RollingCounter counter = new RollingCounter(3, 1, clock); - @Test - public void add() { - RollingCounter counter = new RollingCounter(3, 1, clock); + counter.add(6); + assertThat(counter.getValue(), equalTo(0L)); - counter.add(6); - assertThat(counter.getValue(), equalTo(0L)); + when(clock.millis()).thenReturn(1000L); // 1 second passed + assertThat(counter.getValue(), equalTo(6L)); - when(clock.millis()).thenReturn(1000L); // 1 second passed - assertThat(counter.getValue(), equalTo(6L)); + counter.add(4); + when(clock.millis()).thenReturn(2000L); // 1 
second passed + assertThat(counter.getValue(), equalTo(4L)); - counter.add(4); - when(clock.millis()).thenReturn(2000L); // 1 second passed - assertThat(counter.getValue(), equalTo(4L)); + when(clock.millis()).thenReturn(3000L); // 1 second passed + assertThat(counter.getValue(), equalTo(0L)); + } - when(clock.millis()).thenReturn(3000L); // 1 second passed - assertThat(counter.getValue(), equalTo(0L)); - } + @Test + public void trim() { + RollingCounter counter = new RollingCounter(2, 1, clock); - @Test - public void trim() { - RollingCounter counter = new RollingCounter(2, 1, clock); - - for (int i=1; i<6; ++i) { - counter.increment(); - assertThat(counter.size(), equalTo(i)); - when(clock.millis()).thenReturn(i * 1000L); // i seconds passed - } - counter.increment(); - assertThat(counter.size(), lessThanOrEqualTo(3)); + for (int i = 1; i < 6; ++i) { + counter.increment(); + assertThat(counter.size(), equalTo(i)); + when(clock.millis()).thenReturn(i * 1000L); // i seconds passed } - + counter.increment(); + assertThat(counter.size(), lessThanOrEqualTo(3)); + } } diff --git a/legacy/src/test/java/org/opensearch/sql/legacy/unittest/planner/OpenSearchActionFactoryTest.java b/legacy/src/test/java/org/opensearch/sql/legacy/unittest/planner/OpenSearchActionFactoryTest.java index 0b7c7f6740..3443c2decd 100644 --- a/legacy/src/test/java/org/opensearch/sql/legacy/unittest/planner/OpenSearchActionFactoryTest.java +++ b/legacy/src/test/java/org/opensearch/sql/legacy/unittest/planner/OpenSearchActionFactoryTest.java @@ -3,7 +3,6 @@ * SPDX-License-Identifier: Apache-2.0 */ - package org.opensearch.sql.legacy.unittest.planner; import static org.junit.Assert.assertFalse; @@ -15,60 +14,51 @@ import org.opensearch.sql.legacy.util.SqlParserUtils; public class OpenSearchActionFactoryTest { - @Test - public void josnOutputRequestShouldNotMigrateToQueryPlan() { - String sql = "SELECT age, MAX(balance) " + - "FROM account " + - "GROUP BY age"; + @Test + public void josnOutputRequestShouldNotMigrateToQueryPlan() { + String sql = "SELECT age, MAX(balance) FROM account GROUP BY age"; - assertFalse( - OpenSearchActionFactory.shouldMigrateToQueryPlan(SqlParserUtils.parse(sql), Format.JSON)); - } + assertFalse( + OpenSearchActionFactory.shouldMigrateToQueryPlan(SqlParserUtils.parse(sql), Format.JSON)); + } - @Test - public void nestQueryShouldNotMigrateToQueryPlan() { - String sql = "SELECT age, nested(balance) " + - "FROM account " + - "GROUP BY age"; + @Test + public void nestQueryShouldNotMigrateToQueryPlan() { + String sql = "SELECT age, nested(balance) FROM account GROUP BY age"; - assertFalse( - OpenSearchActionFactory.shouldMigrateToQueryPlan(SqlParserUtils.parse(sql), Format.JDBC)); - } + assertFalse( + OpenSearchActionFactory.shouldMigrateToQueryPlan(SqlParserUtils.parse(sql), Format.JDBC)); + } - @Test - public void nonAggregationQueryShouldNotMigrateToQueryPlan() { - String sql = "SELECT age " + - "FROM account "; + @Test + public void nonAggregationQueryShouldNotMigrateToQueryPlan() { + String sql = "SELECT age FROM account "; - assertFalse( - OpenSearchActionFactory.shouldMigrateToQueryPlan(SqlParserUtils.parse(sql), Format.JDBC)); - } + assertFalse( + OpenSearchActionFactory.shouldMigrateToQueryPlan(SqlParserUtils.parse(sql), Format.JDBC)); + } - @Test - public void aggregationQueryWithoutGroupByShouldMigrateToQueryPlan() { - String sql = "SELECT age, COUNT(balance) " + - "FROM account "; + @Test + public void aggregationQueryWithoutGroupByShouldMigrateToQueryPlan() { + String sql = "SELECT age, 
COUNT(balance) FROM account "; - assertTrue( - OpenSearchActionFactory.shouldMigrateToQueryPlan(SqlParserUtils.parse(sql), Format.JDBC)); - } + assertTrue( + OpenSearchActionFactory.shouldMigrateToQueryPlan(SqlParserUtils.parse(sql), Format.JDBC)); + } - @Test - public void aggregationQueryWithExpressionByShouldMigrateToQueryPlan() { - String sql = "SELECT age, MAX(balance) - MIN(balance) " + - "FROM account "; + @Test + public void aggregationQueryWithExpressionByShouldMigrateToQueryPlan() { + String sql = "SELECT age, MAX(balance) - MIN(balance) FROM account "; - assertTrue( - OpenSearchActionFactory.shouldMigrateToQueryPlan(SqlParserUtils.parse(sql), Format.JDBC)); - } + assertTrue( + OpenSearchActionFactory.shouldMigrateToQueryPlan(SqlParserUtils.parse(sql), Format.JDBC)); + } - @Test - public void queryOnlyHasGroupByShouldMigrateToQueryPlan() { - String sql = "SELECT CAST(age AS DOUBLE) as alias " + - "FROM account " + - "GROUP BY alias"; + @Test + public void queryOnlyHasGroupByShouldMigrateToQueryPlan() { + String sql = "SELECT CAST(age AS DOUBLE) as alias FROM account GROUP BY alias"; - assertTrue( - OpenSearchActionFactory.shouldMigrateToQueryPlan(SqlParserUtils.parse(sql), Format.JDBC)); - } + assertTrue( + OpenSearchActionFactory.shouldMigrateToQueryPlan(SqlParserUtils.parse(sql), Format.JDBC)); + } } diff --git a/legacy/src/test/java/org/opensearch/sql/legacy/unittest/planner/QueryPlannerBatchTest.java b/legacy/src/test/java/org/opensearch/sql/legacy/unittest/planner/QueryPlannerBatchTest.java index 545710e343..0c77550a2f 100644 --- a/legacy/src/test/java/org/opensearch/sql/legacy/unittest/planner/QueryPlannerBatchTest.java +++ b/legacy/src/test/java/org/opensearch/sql/legacy/unittest/planner/QueryPlannerBatchTest.java @@ -3,7 +3,6 @@ * SPDX-License-Identifier: Apache-2.0 */ - package org.opensearch.sql.legacy.unittest.planner; import static org.hamcrest.MatcherAssert.assertThat; @@ -24,221 +23,153 @@ import org.opensearch.search.SearchHits; /** - * Batch prefetch testing. Test against different combination of algorithm block size and scroll page size. + * Batch prefetch testing. Test against different combination of algorithm block size and scroll + * page size. */ @SuppressWarnings("unchecked") @RunWith(Parameterized.class) public class QueryPlannerBatchTest extends QueryPlannerTest { - private static final String TEST_SQL1 = - "SELECT " + - " /*! JOIN_CIRCUIT_BREAK_LIMIT(100) */ " + - " /*! JOIN_ALGORITHM_BLOCK_SIZE(%d) */ " + - " /*! 
JOIN_SCROLL_PAGE_SIZE(%d) */ " + - " e.lastname AS name, d.id AS id, d.name AS dep "; - - private static final String TEST_SQL2_JOIN1 = - "FROM department d " + - " %s employee e "; - - private static final String TEST_SQL2_JOIN2 = - "FROM employee e " + - " %s department d "; - - private static final String TEST_SQL3 = - "ON d.id = e.departmentId " + - " WHERE e.age <= 50"; - - private SearchHit[] employees = { - employee(1, "People 1", "A"), - employee(2, "People 2", "A"), - employee(3, "People 3", "A"), - employee(4, "People 4", "B"), - employee(5, "People 5", "B"), - employee(6, "People 6", "C"), - employee(7, "People 7", "D"), - employee(8, "People 8", "D"), - employee(9, "People 9", "E"), - employee(10, "People 10", "F") - }; - - private SearchHit[] departments = { - department(1, "A", "AWS"), - department(2, "C", "Capital One"), - department(3, "D", "Dell"), - department(4, "F", "Facebook"), - department(5, "G", "Google"), - department(6, "M", "Microsoft"), - department(7, "U", "Uber"), - }; - - private Matcher[] matched = { - hit( - kv("name", "People 1"), - kv("id", "A"), - kv("dep", "AWS") - ), - hit( - kv("name", "People 2"), - kv("id", "A"), - kv("dep", "AWS") - ), - hit( - kv("name", "People 3"), - kv("id", "A"), - kv("dep", "AWS") - ), - hit( - kv("name", "People 6"), - kv("id", "C"), - kv("dep", "Capital One") - ), - hit( - kv("name", "People 7"), - kv("id", "D"), - kv("dep", "Dell") - ), - hit( - kv("name", "People 8"), - kv("id", "D"), - kv("dep", "Dell") - ), - hit( - kv("name", "People 10"), - kv("id", "F"), - kv("dep", "Facebook") - ) - }; - - private Matcher[] mismatched1 = { - hit( - kv("name", null), - kv("id", "G"), - kv("dep", "Google") - ), - hit( - kv("name", null), - kv("id", "M"), - kv("dep", "Microsoft") - ), - hit( - kv("name", null), - kv("id", "U"), - kv("dep", "Uber") - ) - }; - - private Matcher[] mismatched2 = { - hit( - kv("name", "People 4"), - kv("id", null), - kv("dep", null) - ), - hit( - kv("name", "People 5"), - kv("id", null), - kv("dep", null) - ), - hit( - kv("name", "People 9"), - kv("id", null), - kv("dep", null) - ) - }; - - private Matcher expectedInnerJoinResult = hits(matched); - - /** Department left join Employee */ - private Matcher expectedLeftOuterJoinResult1 = hits(concat(matched, mismatched1)); - - /** Employee left join Department */ - private Matcher expectedLeftOuterJoinResult2 = hits(concat(matched, mismatched2)); - - /** Parameterized test cases */ - private final int blockSize; - private final int pageSize; - - public QueryPlannerBatchTest(int blockSize, int pageSize) { - this.blockSize = blockSize; - this.pageSize = pageSize; - } - - @Parameters - public static Collection data() { - List params = new ArrayList<>(); - for (int blockSize = 1; blockSize <= 11; blockSize++) { - for (int pageSize = 1; pageSize <= 11; pageSize++) { - params.add(new Object[]{ blockSize, pageSize }); - } - } - return params; - } - - @Test - public void departmentInnerJoinEmployee() { - assertThat( - query( - String.format( - TEST_SQL1 + TEST_SQL2_JOIN1 + TEST_SQL3, - blockSize, pageSize, "INNER JOIN"), - departments(pageSize, departments), - employees(pageSize, employees) - ), - expectedInnerJoinResult - ); - } - - @Test - public void employeeInnerJoinDepartment() { - assertThat( - query( - String.format( - TEST_SQL1 + TEST_SQL2_JOIN2 + TEST_SQL3, - blockSize, pageSize, "INNER JOIN"), - employees(pageSize, employees), - departments(pageSize, departments) - ), - expectedInnerJoinResult - ); - } - - @Test - public void departmentLeftJoinEmployee() { - 
assertThat( - query( - String.format( - TEST_SQL1 + TEST_SQL2_JOIN1 + TEST_SQL3, - blockSize, pageSize, "LEFT JOIN"), - departments(pageSize, departments), - employees(pageSize, employees) - ), - expectedLeftOuterJoinResult1 - ); - } - - @Test - public void employeeLeftJoinDepartment() { - assertThat( - query( - String.format( - TEST_SQL1 + TEST_SQL2_JOIN2 + TEST_SQL3, - blockSize, pageSize, "LEFT JOIN"), - employees(pageSize, employees), - departments(pageSize, departments) - ), - expectedLeftOuterJoinResult2 - ); - } - - private static Matcher[] concat(Matcher[] one, Matcher[] other) { - return concat(one, other, Matcher.class); - } - - /** Copy from OpenSearch ArrayUtils */ - private static T[] concat(T[] one, T[] other, Class clazz) { - T[] target = (T[]) Array.newInstance(clazz, one.length + other.length); - System.arraycopy(one, 0, target, 0, one.length); - System.arraycopy(other, 0, target, one.length, other.length); - return target; + private static final String TEST_SQL1 = + "SELECT " + + " /*! JOIN_CIRCUIT_BREAK_LIMIT(100) */ " + + " /*! JOIN_ALGORITHM_BLOCK_SIZE(%d) */ " + + " /*! JOIN_SCROLL_PAGE_SIZE(%d) */ " + + " e.lastname AS name, d.id AS id, d.name AS dep "; + + private static final String TEST_SQL2_JOIN1 = "FROM department d " + " %s employee e "; + + private static final String TEST_SQL2_JOIN2 = "FROM employee e " + " %s department d "; + + private static final String TEST_SQL3 = "ON d.id = e.departmentId " + " WHERE e.age <= 50"; + + private SearchHit[] employees = { + employee(1, "People 1", "A"), + employee(2, "People 2", "A"), + employee(3, "People 3", "A"), + employee(4, "People 4", "B"), + employee(5, "People 5", "B"), + employee(6, "People 6", "C"), + employee(7, "People 7", "D"), + employee(8, "People 8", "D"), + employee(9, "People 9", "E"), + employee(10, "People 10", "F") + }; + + private SearchHit[] departments = { + department(1, "A", "AWS"), + department(2, "C", "Capital One"), + department(3, "D", "Dell"), + department(4, "F", "Facebook"), + department(5, "G", "Google"), + department(6, "M", "Microsoft"), + department(7, "U", "Uber"), + }; + + private Matcher[] matched = { + hit(kv("name", "People 1"), kv("id", "A"), kv("dep", "AWS")), + hit(kv("name", "People 2"), kv("id", "A"), kv("dep", "AWS")), + hit(kv("name", "People 3"), kv("id", "A"), kv("dep", "AWS")), + hit(kv("name", "People 6"), kv("id", "C"), kv("dep", "Capital One")), + hit(kv("name", "People 7"), kv("id", "D"), kv("dep", "Dell")), + hit(kv("name", "People 8"), kv("id", "D"), kv("dep", "Dell")), + hit(kv("name", "People 10"), kv("id", "F"), kv("dep", "Facebook")) + }; + + private Matcher[] mismatched1 = { + hit(kv("name", null), kv("id", "G"), kv("dep", "Google")), + hit(kv("name", null), kv("id", "M"), kv("dep", "Microsoft")), + hit(kv("name", null), kv("id", "U"), kv("dep", "Uber")) + }; + + private Matcher[] mismatched2 = { + hit(kv("name", "People 4"), kv("id", null), kv("dep", null)), + hit(kv("name", "People 5"), kv("id", null), kv("dep", null)), + hit(kv("name", "People 9"), kv("id", null), kv("dep", null)) + }; + + private Matcher expectedInnerJoinResult = hits(matched); + + /** Department left join Employee */ + private Matcher expectedLeftOuterJoinResult1 = hits(concat(matched, mismatched1)); + + /** Employee left join Department */ + private Matcher expectedLeftOuterJoinResult2 = hits(concat(matched, mismatched2)); + + /** Parameterized test cases */ + private final int blockSize; + + private final int pageSize; + + public QueryPlannerBatchTest(int blockSize, int pageSize) { + 
this.blockSize = blockSize; + this.pageSize = pageSize; + } + + @Parameters + public static Collection data() { + List params = new ArrayList<>(); + for (int blockSize = 1; blockSize <= 11; blockSize++) { + for (int pageSize = 1; pageSize <= 11; pageSize++) { + params.add(new Object[] {blockSize, pageSize}); + } } + return params; + } + + @Test + public void departmentInnerJoinEmployee() { + assertThat( + query( + String.format( + TEST_SQL1 + TEST_SQL2_JOIN1 + TEST_SQL3, blockSize, pageSize, "INNER JOIN"), + departments(pageSize, departments), + employees(pageSize, employees)), + expectedInnerJoinResult); + } + + @Test + public void employeeInnerJoinDepartment() { + assertThat( + query( + String.format( + TEST_SQL1 + TEST_SQL2_JOIN2 + TEST_SQL3, blockSize, pageSize, "INNER JOIN"), + employees(pageSize, employees), + departments(pageSize, departments)), + expectedInnerJoinResult); + } + + @Test + public void departmentLeftJoinEmployee() { + assertThat( + query( + String.format( + TEST_SQL1 + TEST_SQL2_JOIN1 + TEST_SQL3, blockSize, pageSize, "LEFT JOIN"), + departments(pageSize, departments), + employees(pageSize, employees)), + expectedLeftOuterJoinResult1); + } + + @Test + public void employeeLeftJoinDepartment() { + assertThat( + query( + String.format( + TEST_SQL1 + TEST_SQL2_JOIN2 + TEST_SQL3, blockSize, pageSize, "LEFT JOIN"), + employees(pageSize, employees), + departments(pageSize, departments)), + expectedLeftOuterJoinResult2); + } + + private static Matcher[] concat(Matcher[] one, Matcher[] other) { + return concat(one, other, Matcher.class); + } + + /** Copy from OpenSearch ArrayUtils */ + private static T[] concat(T[] one, T[] other, Class clazz) { + T[] target = (T[]) Array.newInstance(clazz, one.length + other.length); + System.arraycopy(one, 0, target, 0, one.length); + System.arraycopy(other, 0, target, one.length, other.length); + return target; + } } diff --git a/legacy/src/test/java/org/opensearch/sql/legacy/unittest/planner/QueryPlannerConfigTest.java b/legacy/src/test/java/org/opensearch/sql/legacy/unittest/planner/QueryPlannerConfigTest.java index 07a84683ce..81d6d718b9 100644 --- a/legacy/src/test/java/org/opensearch/sql/legacy/unittest/planner/QueryPlannerConfigTest.java +++ b/legacy/src/test/java/org/opensearch/sql/legacy/unittest/planner/QueryPlannerConfigTest.java @@ -3,7 +3,6 @@ * SPDX-License-Identifier: Apache-2.0 */ - package org.opensearch.sql.legacy.unittest.planner; import static org.hamcrest.MatcherAssert.assertThat; @@ -23,291 +22,252 @@ import org.opensearch.sql.legacy.query.planner.HashJoinQueryPlanRequestBuilder; import org.opensearch.sql.legacy.query.planner.core.Config; -/** - * Hint & Configuring Ability Test Cases - */ +/** Hint & Configuring Ability Test Cases */ public class QueryPlannerConfigTest extends QueryPlannerTest { - private static final Matcher DEFAULT_TOTAL_AND_TABLE_LIMIT_MATCHER = totalAndTableLimit(200, 0, 0); - - @Test - public void algorithmBlockSizeHint() { - assertThat( - parseHint("! JOIN_ALGORITHM_BLOCK_SIZE(100000)"), - hint( - hintType(HintType.JOIN_ALGORITHM_BLOCK_SIZE), - hintValues(100000) - ) - ); - } - - @Test - public void algorithmUseLegacy() { - assertThat( - parseHint("! JOIN_ALGORITHM_USE_LEGACY"), - hint( - hintType(HintType.JOIN_ALGORITHM_USE_LEGACY), - hintValues() - ) - ); - } - - @Test - public void algorithmBlockSizeHintWithSpaces() { - assertThat( - parseHint("! 
JOIN_ALGORITHM_BLOCK_SIZE ( 200000 ) "), - hint( - hintType(HintType.JOIN_ALGORITHM_BLOCK_SIZE), - hintValues(200000) - ) - ); - } - - @Test - public void scrollPageSizeHint() { - assertThat( - parseHint("! JOIN_SCROLL_PAGE_SIZE(1000) "), - hint( - hintType(HintType.JOIN_SCROLL_PAGE_SIZE), - hintValues(1000) - ) - ); - } - - @Test - public void scrollPageSizeHintWithTwoSizes() { - assertThat( - parseHint("! JOIN_SCROLL_PAGE_SIZE(1000, 2000) "), - hint( - hintType(HintType.JOIN_SCROLL_PAGE_SIZE), - hintValues(1000, 2000) - ) - ); - } - - @Test - public void circuitBreakLimitHint() { - assertThat( - parseHint("! JOIN_CIRCUIT_BREAK_LIMIT(80)"), - hint( - hintType(HintType.JOIN_CIRCUIT_BREAK_LIMIT), - hintValues(80) - ) - ); - } - - @Test - public void backOffRetryIntervalsHint() { - assertThat( - parseHint("! JOIN_BACK_OFF_RETRY_INTERVALS(1, 5)"), - hint( - hintType(HintType.JOIN_BACK_OFF_RETRY_INTERVALS), - hintValues(1, 5) - ) - ); - } - - @Test - public void timeOutHint() { - assertThat( - parseHint("! JOIN_TIME_OUT(120)"), - hint( - hintType(HintType.JOIN_TIME_OUT), - hintValues(120) - ) - ); - } - - @Test - public void blockSizeConfig() { - assertThat(queryPlannerConfig( - "SELECT /*! JOIN_ALGORITHM_BLOCK_SIZE(200000) */ " + - " d.name FROM employee e JOIN department d ON d.id = e.departmentId "), - config( - blockSize(200000), - scrollPageSize(Config.DEFAULT_SCROLL_PAGE_SIZE, Config.DEFAULT_SCROLL_PAGE_SIZE), - circuitBreakLimit(Config.DEFAULT_CIRCUIT_BREAK_LIMIT), - backOffRetryIntervals(Config.DEFAULT_BACK_OFF_RETRY_INTERVALS), - DEFAULT_TOTAL_AND_TABLE_LIMIT_MATCHER, - timeOut(Config.DEFAULT_TIME_OUT) - ) - ); - } - - @Test - public void scrollPageSizeConfig() { - assertThat(queryPlannerConfig( - "SELECT /*! JOIN_SCROLL_PAGE_SIZE(50, 20) */ " + - " d.name FROM employee e JOIN department d ON d.id = e.departmentId "), - config( - blockSize(Config.DEFAULT_BLOCK_SIZE), - scrollPageSize(50, 20), - circuitBreakLimit(Config.DEFAULT_CIRCUIT_BREAK_LIMIT), - backOffRetryIntervals(Config.DEFAULT_BACK_OFF_RETRY_INTERVALS), - DEFAULT_TOTAL_AND_TABLE_LIMIT_MATCHER, - timeOut(Config.DEFAULT_TIME_OUT) - ) - ); - } - - @Test - public void circuitBreakLimitConfig() { - assertThat(queryPlannerConfig( - "SELECT /*! JOIN_CIRCUIT_BREAK_LIMIT(60) */ " + - " d.name FROM employee e JOIN department d ON d.id = e.departmentId "), - config( - blockSize(Config.DEFAULT_BLOCK_SIZE), - scrollPageSize(Config.DEFAULT_SCROLL_PAGE_SIZE, Config.DEFAULT_SCROLL_PAGE_SIZE), - circuitBreakLimit(60), - backOffRetryIntervals(Config.DEFAULT_BACK_OFF_RETRY_INTERVALS), - DEFAULT_TOTAL_AND_TABLE_LIMIT_MATCHER, - timeOut(Config.DEFAULT_TIME_OUT) - ) - ); - } - - @Test - public void backOffRetryIntervalsConfig() { - assertThat(queryPlannerConfig( - "SELECT /*! JOIN_BACK_OFF_RETRY_INTERVALS(1, 3, 5, 10) */ " + - " d.name FROM employee e JOIN department d ON d.id = e.departmentId "), - config( - blockSize(Config.DEFAULT_BLOCK_SIZE), - scrollPageSize(Config.DEFAULT_SCROLL_PAGE_SIZE, Config.DEFAULT_SCROLL_PAGE_SIZE), - circuitBreakLimit(Config.DEFAULT_CIRCUIT_BREAK_LIMIT), - backOffRetryIntervals(new double[]{1, 3, 5, 10}), - DEFAULT_TOTAL_AND_TABLE_LIMIT_MATCHER, - timeOut(Config.DEFAULT_TIME_OUT) - ) - ); - } - - @Test - public void totalAndTableLimitConfig() { - assertThat(queryPlannerConfig( - "SELECT /*! 
JOIN_TABLES_LIMIT(10, 20) */ " + - " d.name FROM employee e JOIN department d ON d.id = e.departmentId LIMIT 50"), - config( - blockSize(Config.DEFAULT_BLOCK_SIZE), - scrollPageSize(Config.DEFAULT_SCROLL_PAGE_SIZE, Config.DEFAULT_SCROLL_PAGE_SIZE), - circuitBreakLimit(Config.DEFAULT_CIRCUIT_BREAK_LIMIT), - backOffRetryIntervals(Config.DEFAULT_BACK_OFF_RETRY_INTERVALS), - totalAndTableLimit(50, 10, 20), - timeOut(Config.DEFAULT_TIME_OUT) - ) - ); - } - - @Test - public void timeOutConfig() { - assertThat(queryPlannerConfig( - "SELECT /*! JOIN_TIME_OUT(120) */ " + - " d.name FROM employee e JOIN department d ON d.id = e.departmentId"), - config( - blockSize(Config.DEFAULT_BLOCK_SIZE), - scrollPageSize(Config.DEFAULT_SCROLL_PAGE_SIZE, Config.DEFAULT_SCROLL_PAGE_SIZE), - circuitBreakLimit(Config.DEFAULT_CIRCUIT_BREAK_LIMIT), - backOffRetryIntervals(Config.DEFAULT_BACK_OFF_RETRY_INTERVALS), - DEFAULT_TOTAL_AND_TABLE_LIMIT_MATCHER, - timeOut(120) - ) - ); - } - - @Test - public void multipleConfigCombined() { - assertThat(queryPlannerConfig( - "SELECT " + - " /*! JOIN_ALGORITHM_BLOCK_SIZE(100) */ " + - " /*! JOIN_SCROLL_PAGE_SIZE(50, 20) */ " + - " /*! JOIN_CIRCUIT_BREAK_LIMIT(10) */ " + - " d.name FROM employee e JOIN department d ON d.id = e.departmentId "), - config( - blockSize(100), - scrollPageSize(50, 20), - circuitBreakLimit(10), - backOffRetryIntervals(Config.DEFAULT_BACK_OFF_RETRY_INTERVALS), - DEFAULT_TOTAL_AND_TABLE_LIMIT_MATCHER, - timeOut(Config.DEFAULT_TIME_OUT) - ) - ); - } - - private Hint parseHint(String hintStr) { - try { - return HintFactory.getHintFromString(hintStr); - } - catch (SqlParseException e) { - throw new IllegalArgumentException(e); - } - } - - private Config queryPlannerConfig(String sql) { - HashJoinQueryPlanRequestBuilder request = ((HashJoinQueryPlanRequestBuilder) createRequestBuilder(sql)); - request.plan(); - return request.getConfig(); + private static final Matcher DEFAULT_TOTAL_AND_TABLE_LIMIT_MATCHER = + totalAndTableLimit(200, 0, 0); + + @Test + public void algorithmBlockSizeHint() { + assertThat( + parseHint("! JOIN_ALGORITHM_BLOCK_SIZE(100000)"), + hint(hintType(HintType.JOIN_ALGORITHM_BLOCK_SIZE), hintValues(100000))); + } + + @Test + public void algorithmUseLegacy() { + assertThat( + parseHint("! JOIN_ALGORITHM_USE_LEGACY"), + hint(hintType(HintType.JOIN_ALGORITHM_USE_LEGACY), hintValues())); + } + + @Test + public void algorithmBlockSizeHintWithSpaces() { + assertThat( + parseHint("! JOIN_ALGORITHM_BLOCK_SIZE ( 200000 ) "), + hint(hintType(HintType.JOIN_ALGORITHM_BLOCK_SIZE), hintValues(200000))); + } + + @Test + public void scrollPageSizeHint() { + assertThat( + parseHint("! JOIN_SCROLL_PAGE_SIZE(1000) "), + hint(hintType(HintType.JOIN_SCROLL_PAGE_SIZE), hintValues(1000))); + } + + @Test + public void scrollPageSizeHintWithTwoSizes() { + assertThat( + parseHint("! JOIN_SCROLL_PAGE_SIZE(1000, 2000) "), + hint(hintType(HintType.JOIN_SCROLL_PAGE_SIZE), hintValues(1000, 2000))); + } + + @Test + public void circuitBreakLimitHint() { + assertThat( + parseHint("! JOIN_CIRCUIT_BREAK_LIMIT(80)"), + hint(hintType(HintType.JOIN_CIRCUIT_BREAK_LIMIT), hintValues(80))); + } + + @Test + public void backOffRetryIntervalsHint() { + assertThat( + parseHint("! JOIN_BACK_OFF_RETRY_INTERVALS(1, 5)"), + hint(hintType(HintType.JOIN_BACK_OFF_RETRY_INTERVALS), hintValues(1, 5))); + } + + @Test + public void timeOutHint() { + assertThat( + parseHint("! 
JOIN_TIME_OUT(120)"), hint(hintType(HintType.JOIN_TIME_OUT), hintValues(120))); + } + + @Test + public void blockSizeConfig() { + assertThat( + queryPlannerConfig( + "SELECT /*! JOIN_ALGORITHM_BLOCK_SIZE(200000) */ " + + " d.name FROM employee e JOIN department d ON d.id = e.departmentId "), + config( + blockSize(200000), + scrollPageSize(Config.DEFAULT_SCROLL_PAGE_SIZE, Config.DEFAULT_SCROLL_PAGE_SIZE), + circuitBreakLimit(Config.DEFAULT_CIRCUIT_BREAK_LIMIT), + backOffRetryIntervals(Config.DEFAULT_BACK_OFF_RETRY_INTERVALS), + DEFAULT_TOTAL_AND_TABLE_LIMIT_MATCHER, + timeOut(Config.DEFAULT_TIME_OUT))); + } + + @Test + public void scrollPageSizeConfig() { + assertThat( + queryPlannerConfig( + "SELECT /*! JOIN_SCROLL_PAGE_SIZE(50, 20) */ " + + " d.name FROM employee e JOIN department d ON d.id = e.departmentId "), + config( + blockSize(Config.DEFAULT_BLOCK_SIZE), + scrollPageSize(50, 20), + circuitBreakLimit(Config.DEFAULT_CIRCUIT_BREAK_LIMIT), + backOffRetryIntervals(Config.DEFAULT_BACK_OFF_RETRY_INTERVALS), + DEFAULT_TOTAL_AND_TABLE_LIMIT_MATCHER, + timeOut(Config.DEFAULT_TIME_OUT))); + } + + @Test + public void circuitBreakLimitConfig() { + assertThat( + queryPlannerConfig( + "SELECT /*! JOIN_CIRCUIT_BREAK_LIMIT(60) */ " + + " d.name FROM employee e JOIN department d ON d.id = e.departmentId "), + config( + blockSize(Config.DEFAULT_BLOCK_SIZE), + scrollPageSize(Config.DEFAULT_SCROLL_PAGE_SIZE, Config.DEFAULT_SCROLL_PAGE_SIZE), + circuitBreakLimit(60), + backOffRetryIntervals(Config.DEFAULT_BACK_OFF_RETRY_INTERVALS), + DEFAULT_TOTAL_AND_TABLE_LIMIT_MATCHER, + timeOut(Config.DEFAULT_TIME_OUT))); + } + + @Test + public void backOffRetryIntervalsConfig() { + assertThat( + queryPlannerConfig( + "SELECT /*! JOIN_BACK_OFF_RETRY_INTERVALS(1, 3, 5, 10) */ " + + " d.name FROM employee e JOIN department d ON d.id = e.departmentId "), + config( + blockSize(Config.DEFAULT_BLOCK_SIZE), + scrollPageSize(Config.DEFAULT_SCROLL_PAGE_SIZE, Config.DEFAULT_SCROLL_PAGE_SIZE), + circuitBreakLimit(Config.DEFAULT_CIRCUIT_BREAK_LIMIT), + backOffRetryIntervals(new double[] {1, 3, 5, 10}), + DEFAULT_TOTAL_AND_TABLE_LIMIT_MATCHER, + timeOut(Config.DEFAULT_TIME_OUT))); + } + + @Test + public void totalAndTableLimitConfig() { + assertThat( + queryPlannerConfig( + "SELECT /*! JOIN_TABLES_LIMIT(10, 20) */ " + + " d.name FROM employee e JOIN department d ON d.id = e.departmentId LIMIT 50"), + config( + blockSize(Config.DEFAULT_BLOCK_SIZE), + scrollPageSize(Config.DEFAULT_SCROLL_PAGE_SIZE, Config.DEFAULT_SCROLL_PAGE_SIZE), + circuitBreakLimit(Config.DEFAULT_CIRCUIT_BREAK_LIMIT), + backOffRetryIntervals(Config.DEFAULT_BACK_OFF_RETRY_INTERVALS), + totalAndTableLimit(50, 10, 20), + timeOut(Config.DEFAULT_TIME_OUT))); + } + + @Test + public void timeOutConfig() { + assertThat( + queryPlannerConfig( + "SELECT /*! JOIN_TIME_OUT(120) */ " + + " d.name FROM employee e JOIN department d ON d.id = e.departmentId"), + config( + blockSize(Config.DEFAULT_BLOCK_SIZE), + scrollPageSize(Config.DEFAULT_SCROLL_PAGE_SIZE, Config.DEFAULT_SCROLL_PAGE_SIZE), + circuitBreakLimit(Config.DEFAULT_CIRCUIT_BREAK_LIMIT), + backOffRetryIntervals(Config.DEFAULT_BACK_OFF_RETRY_INTERVALS), + DEFAULT_TOTAL_AND_TABLE_LIMIT_MATCHER, + timeOut(120))); + } + + @Test + public void multipleConfigCombined() { + assertThat( + queryPlannerConfig( + "SELECT " + + " /*! JOIN_ALGORITHM_BLOCK_SIZE(100) */ " + + " /*! JOIN_SCROLL_PAGE_SIZE(50, 20) */ " + + " /*! 
JOIN_CIRCUIT_BREAK_LIMIT(10) */ " + + " d.name FROM employee e JOIN department d ON d.id = e.departmentId "), + config( + blockSize(100), + scrollPageSize(50, 20), + circuitBreakLimit(10), + backOffRetryIntervals(Config.DEFAULT_BACK_OFF_RETRY_INTERVALS), + DEFAULT_TOTAL_AND_TABLE_LIMIT_MATCHER, + timeOut(Config.DEFAULT_TIME_OUT))); + } + + private Hint parseHint(String hintStr) { + try { + return HintFactory.getHintFromString(hintStr); + } catch (SqlParseException e) { + throw new IllegalArgumentException(e); } - - private Matcher hint(Matcher typeMatcher, Matcher valuesMatcher) { - return both( - featureValueOf("HintType", typeMatcher, Hint::getType) - ).and( - featureValueOf("HintValue", valuesMatcher, Hint::getParams) - ); - } - - private Matcher hintType(HintType type) { - return is(type); - } - - private Matcher hintValues(Object... values) { - if (values.length == 0) { - return emptyArray(); - } - return arrayContaining(values); - } - - private Matcher config(Matcher blockSizeMatcher, - Matcher scrollPageSizeMatcher, - Matcher circuitBreakLimitMatcher, - Matcher backOffRetryIntervalsMatcher, - Matcher totalAndTableLimitMatcher, - Matcher timeOutMatcher) { - return allOf( - featureValueOf("Block size", blockSizeMatcher, (cfg -> cfg.blockSize().size())), - featureValueOf("Scroll page size", scrollPageSizeMatcher, Config::scrollPageSize), - featureValueOf("Circuit break limit", circuitBreakLimitMatcher, Config::circuitBreakLimit), - featureValueOf("Back off retry intervals", backOffRetryIntervalsMatcher, Config::backOffRetryIntervals), - featureValueOf("Total and table limit", totalAndTableLimitMatcher, - (cfg -> new Integer[]{cfg.totalLimit(), cfg.tableLimit1(), cfg.tableLimit2()})), - featureValueOf("Time out", timeOutMatcher, Config::timeout) - ); + } + + private Config queryPlannerConfig(String sql) { + HashJoinQueryPlanRequestBuilder request = + ((HashJoinQueryPlanRequestBuilder) createRequestBuilder(sql)); + request.plan(); + return request.getConfig(); + } + + private Matcher hint(Matcher typeMatcher, Matcher valuesMatcher) { + return both(featureValueOf("HintType", typeMatcher, Hint::getType)) + .and(featureValueOf("HintValue", valuesMatcher, Hint::getParams)); + } + + private Matcher hintType(HintType type) { + return is(type); + } + + private Matcher hintValues(Object... 
values) { + if (values.length == 0) { + return emptyArray(); } - - private Matcher blockSize(int size) { - return is(size); - } - - @SuppressWarnings("unchecked") - private Matcher scrollPageSize(int size1, int size2) { - return arrayContaining(is(size1), is(size2)); - } - - private Matcher circuitBreakLimit(int limit) { - return is(limit); - } - - private Matcher backOffRetryIntervals(double[] intervals) { - return is(intervals); - } - - @SuppressWarnings("unchecked") - private static Matcher totalAndTableLimit(int totalLimit, int tableLimit1, int tableLimit2) { - return arrayContaining(is(totalLimit), is(tableLimit1), is(tableLimit2)); - } - - private static Matcher timeOut(int timeout) { - return is(timeout); - } - + return arrayContaining(values); + } + + private Matcher config( + Matcher blockSizeMatcher, + Matcher scrollPageSizeMatcher, + Matcher circuitBreakLimitMatcher, + Matcher backOffRetryIntervalsMatcher, + Matcher totalAndTableLimitMatcher, + Matcher timeOutMatcher) { + return allOf( + featureValueOf("Block size", blockSizeMatcher, (cfg -> cfg.blockSize().size())), + featureValueOf("Scroll page size", scrollPageSizeMatcher, Config::scrollPageSize), + featureValueOf("Circuit break limit", circuitBreakLimitMatcher, Config::circuitBreakLimit), + featureValueOf( + "Back off retry intervals", + backOffRetryIntervalsMatcher, + Config::backOffRetryIntervals), + featureValueOf( + "Total and table limit", + totalAndTableLimitMatcher, + (cfg -> new Integer[] {cfg.totalLimit(), cfg.tableLimit1(), cfg.tableLimit2()})), + featureValueOf("Time out", timeOutMatcher, Config::timeout)); + } + + private Matcher blockSize(int size) { + return is(size); + } + + @SuppressWarnings("unchecked") + private Matcher scrollPageSize(int size1, int size2) { + return arrayContaining(is(size1), is(size2)); + } + + private Matcher circuitBreakLimit(int limit) { + return is(limit); + } + + private Matcher backOffRetryIntervals(double[] intervals) { + return is(intervals); + } + + @SuppressWarnings("unchecked") + private static Matcher totalAndTableLimit( + int totalLimit, int tableLimit1, int tableLimit2) { + return arrayContaining(is(totalLimit), is(tableLimit1), is(tableLimit2)); + } + + private static Matcher timeOut(int timeout) { + return is(timeout); + } } diff --git a/legacy/src/test/java/org/opensearch/sql/legacy/unittest/planner/QueryPlannerExecuteTest.java b/legacy/src/test/java/org/opensearch/sql/legacy/unittest/planner/QueryPlannerExecuteTest.java index 55ea8c390b..dc8e094e2d 100644 --- a/legacy/src/test/java/org/opensearch/sql/legacy/unittest/planner/QueryPlannerExecuteTest.java +++ b/legacy/src/test/java/org/opensearch/sql/legacy/unittest/planner/QueryPlannerExecuteTest.java @@ -3,7 +3,6 @@ * SPDX-License-Identifier: Apache-2.0 */ - package org.opensearch.sql.legacy.unittest.planner; import static org.opensearch.sql.legacy.util.MatcherUtils.hit; @@ -14,767 +13,420 @@ import org.opensearch.search.SearchHit; import org.opensearch.sql.legacy.util.MatcherUtils; -/** - * Query planner execution unit test - */ +/** Query planner execution unit test */ public class QueryPlannerExecuteTest extends QueryPlannerTest { - @Test - public void simpleJoin() { - MatcherAssert.assertThat( - query( - "SELECT d.name, e.lastname FROM employee e " + - " JOIN department d ON d.id = e.departmentId " + - " WHERE d.region = 'US' AND e.age > 30", - employees( - employee(1, "Alice", "1"), - employee(2, "Hank", "1") - ), - departments( - department(1, "1", "AWS") - ) - ), - hits( - hit( - MatcherUtils.kv("d.name", "AWS"), - 
MatcherUtils.kv("e.lastname", "Alice") - ), - hit( - MatcherUtils.kv("d.name", "AWS"), - MatcherUtils.kv("e.lastname", "Hank") - ) - ) - ); - } - - @Test - public void simpleJoinWithSelectAll() { - MatcherAssert.assertThat( - query( - "SELECT * FROM employee e " + - " JOIN department d ON d.id = e.departmentId ", - employees( - employee(1, "Alice", "1"), - employee(2, "Hank", "1") - ), - departments( - department(1, "1", "AWS") - ) - ), - hits( - hit( - MatcherUtils.kv("d.name", "AWS"), - MatcherUtils.kv("d.id", "1"), - MatcherUtils.kv("e.lastname", "Alice"), - MatcherUtils.kv("e.departmentId", "1") - ), - hit( - MatcherUtils.kv("d.name", "AWS"), - MatcherUtils.kv("d.id", "1"), - MatcherUtils.kv("e.lastname", "Hank"), - MatcherUtils.kv("e.departmentId", "1") - ) - ) - ); - } - - @Test - public void simpleLeftJoinWithSelectAllFromOneTable() { - MatcherAssert.assertThat( - query( - "SELECT e.lastname, d.* FROM employee e " + - " LEFT JOIN department d ON d.id = e.departmentId ", - employees( - employee(1, "Alice", "1"), - employee(2, "Hank", "1"), - employee(3, "Allen", "3") - ), - departments( - department(1, "1", "AWS"), - department(2, "2", "Retail") - ) - ), - hits( - hit( - MatcherUtils.kv("e.lastname", "Alice"), - MatcherUtils.kv("d.name", "AWS"), - MatcherUtils.kv("d.id", "1") - ), - hit( - MatcherUtils.kv("e.lastname", "Hank"), - MatcherUtils.kv("d.name", "AWS"), - MatcherUtils.kv("d.id", "1") - ), - hit( - MatcherUtils.kv("e.lastname", "Allen") - /* - * Not easy to figure out all column names for d.* without reading metadata - * or look into other rows from d. But in the extreme case, d could be empty table - * which requires metadata read anyway. - */ - //kv("d.name", null), - //kv("d.id", null) - ) - ) - ); - } - - @Test - public void simpleJoinWithSelectAllFromBothTables() { - MatcherAssert.assertThat( - query( - "SELECT e.*, d.* FROM employee e " + - " JOIN department d ON d.id = e.departmentId ", - employees( - employee(1, "Alice", "1"), - employee(2, "Hank", "1") - ), - departments( - department(1, "1", "AWS") - ) - ), - hits( - hit( - MatcherUtils.kv("d.name", "AWS"), - MatcherUtils.kv("d.id", "1"), - MatcherUtils.kv("e.lastname", "Alice"), - MatcherUtils.kv("e.departmentId", "1") - ), - hit( - MatcherUtils.kv("d.name", "AWS"), - MatcherUtils.kv("d.id", "1"), - MatcherUtils.kv("e.lastname", "Hank"), - MatcherUtils.kv("e.departmentId", "1") - ) - ) - ); - } - - @Test - public void simpleJoinWithoutMatch() { - MatcherAssert.assertThat( - query( - "SELECT d.name, e.lastname FROM employee e " + - " JOIN department d ON d.id = e.departmentId " + - " WHERE d.region = 'US' AND e.age > 30", - employees( - employee(1, "Alice", "2"), - employee(2, "Hank", "3") - ), - departments( - department(1, "1", "AWS") - ) - ), - hits() - ); - } - - @Test - public void simpleJoinWithSomeMatches() { - MatcherAssert.assertThat( - query( - "SELECT d.name, e.lastname FROM employee e " + - " JOIN department d ON d.id = e.departmentId " + - " WHERE d.region = 'US' AND e.age > 30", - employees( - employee(1, "Alice", "2"), - employee(2, "Hank", "3") - ), - departments( - department(1, "1", "AWS"), - department(2, "2", "Retail") - ) - ), - hits( - hit( - MatcherUtils.kv("d.name", "Retail"), - MatcherUtils.kv("e.lastname", "Alice") - ) - ) - ); - } - - @Test - public void simpleJoinWithAllMatches() { - MatcherAssert.assertThat( - query( - "SELECT d.name, e.lastname FROM employee e " + - " JOIN department d ON d.id = e.departmentId " + - " WHERE d.region = 'US' AND e.age > 30", - employees( - employee(1, 
"Alice", "1"), - employee(2, "Hank", "1"), - employee(3, "Mike", "2") - ), - departments( - department(1, "1", "AWS"), - department(2, "2", "Retail") - ) - ), - hits( - hit( - MatcherUtils.kv("d.name", "AWS"), - MatcherUtils.kv("e.lastname", "Alice") - ), - hit( - MatcherUtils.kv("d.name", "AWS"), - MatcherUtils.kv("e.lastname", "Hank") - ), - hit( - MatcherUtils.kv("d.name", "Retail"), - MatcherUtils.kv("e.lastname", "Mike") - ) - ) - ); - } - - @Test - public void simpleJoinWithNull() { - MatcherAssert.assertThat( - query( - "SELECT d.name, e.lastname FROM employee e " + - " JOIN department d ON d.id = e.departmentId " + - " WHERE d.region = 'US' AND e.age > 30", - employees( - employee(1, "Alice", "1"), - employee(2, "Hank", null), - employee(3, "Mike", "2") - ), - departments( - department(1, "1", "AWS"), - department(2, null, "Retail") - ) - ), - hits( - hit( - MatcherUtils.kv("d.name", "AWS"), - MatcherUtils.kv("e.lastname", "Alice") - ) - ) - ); - } - - @Test - public void simpleJoinWithColumnNameConflict() { - // Add a same column 'name' as in department on purpose - SearchHit alice = employee(1, "Alice", "1"); - alice.getSourceAsMap().put("name", "Alice Alice"); - SearchHit hank = employee(2, "Hank", "2"); - hank.getSourceAsMap().put("name", "Hank Hank"); - - MatcherAssert.assertThat( - query( - "SELECT d.name, e.name FROM employee e " + - " JOIN department d ON d.id = e.departmentId " + - " WHERE d.region = 'US' AND e.age > 30", - employees( - alice, hank - ), - departments( - department(1, "1", "AWS") - ) - ), - hits( - hit( - MatcherUtils.kv("d.name", "AWS"), - MatcherUtils.kv("e.name", "Alice Alice") - ) - ) - ); - } - - @Test - public void simpleJoinWithAliasInSelect() { - MatcherAssert.assertThat( - query( - "SELECT d.name AS dname, e.lastname AS ename FROM employee e " + - " JOIN department d ON d.id = e.departmentId " + - " WHERE d.region = 'US' AND e.age > 30", - employees( - employee(1, "Alice", "2"), - employee(2, "Hank", "3") - ), - departments( - department(1, "1", "AWS"), - department(2, "2", "Retail") - ) - ), - hits( - hit( - MatcherUtils.kv("dname", "Retail"), - MatcherUtils.kv("ename", "Alice") - ) - ) - ); - } - - @Test - public void simpleLeftJoinWithoutMatchInLeft() { - MatcherAssert.assertThat( - query( - "SELECT d.name, e.lastname FROM employee e " + - " LEFT JOIN department d ON d.id = e.departmentId " + - " WHERE d.region = 'US' AND e.age > 30", - employees( - employee(1, "Alice", "2"), - employee(2, "Hank", "3") - ), - departments( - department(1, "1", "AWS") - ) - ), - hits( - hit( - MatcherUtils.kv("d.name", null), - MatcherUtils.kv("e.lastname", "Alice") - ), - hit( - MatcherUtils.kv("d.name", null), - MatcherUtils.kv("e.lastname", "Hank") - ) - ) - ); - } - - @Test - public void simpleLeftJoinWithSomeMismatchesInLeft() { - MatcherAssert.assertThat( - query( - "SELECT d.name, e.lastname FROM employee e " + - " LEFT JOIN department d ON d.id = e.departmentId " + - " WHERE d.region = 'US' AND e.age > 30", - employees( - employee(1, "Alice", "1"), - employee(2, "Hank", "2") - ), - departments( - department(1, "1", "AWS") - ) - ), - hits( - hit( - MatcherUtils.kv("d.name", "AWS"), - MatcherUtils.kv("e.lastname", "Alice") - ), - hit( - MatcherUtils.kv("d.name", null), - MatcherUtils.kv("e.lastname", "Hank") - ) - ) - ); - } - - @Test - public void simpleLeftJoinWithSomeMismatchesInRight() { - MatcherAssert.assertThat( - query( - "SELECT d.name, e.lastname FROM employee e " + - " LEFT JOIN department d ON d.id = e.departmentId " + - " WHERE d.region = 'US' 
AND e.age > 30", - employees( - employee(1, "Alice", "1"), - employee(2, "Hank", "1") - ), - departments( - department(1, "1", "AWS"), - department(2, "2", "Retail") - ) - ), - hits( - hit( - MatcherUtils.kv("d.name", "AWS"), - MatcherUtils.kv("e.lastname", "Alice") - ), - hit( - MatcherUtils.kv("d.name", "AWS"), - MatcherUtils.kv("e.lastname", "Hank") - ) - ) - ); - } - - @Test - public void simpleQueryWithTotalLimit() { - MatcherAssert.assertThat( - query( - "SELECT d.name, e.lastname FROM employee e JOIN department d ON d.id = e.departmentId LIMIT 1", - employees( - employee(1, "Alice", "1"), - employee(2, "Hank", "2") - ), - departments( - department(1, "1", "AWS"), - department(1, "2", "Retail") - ) - ), - hits( - hit( - MatcherUtils.kv("d.name", "AWS"), - MatcherUtils.kv("e.lastname", "Alice") - ) - ) - ); - } - - @Test - public void simpleQueryWithTableLimit() { - MatcherAssert.assertThat( - query( - "SELECT /*! JOIN_TABLES_LIMIT(1, 5) */ d.name, e.lastname FROM employee e JOIN department d ON d.id = e.departmentId", - employees( - employee(1, "Alice", "1"), - employee(2, "Hank", "1") - ), - departments( - department(1, "1", "AWS"), - department(1, "2", "Retail") - ) - ), - hits( - hit( - MatcherUtils.kv("d.name", "AWS"), - MatcherUtils.kv("e.lastname", "Alice") - ) - ) - ); - } - - @Test - public void simpleQueryWithOrderBy() { - MatcherAssert.assertThat( - query( - "SELECT d.name, e.lastname FROM employee e JOIN department d ON d.id = e.departmentId ORDER BY e.lastname", - employees( - employee(1, "Hank", "1"), - employee(2, "Alice", "2"), - employee(3, "Allen", "1"), - employee(4, "Ellis", "2"), - employee(5, "Frank", "2") - ), - departments( - department(1, "1", "AWS"), - department(2, "2", "Retail") - ) - ), - MatcherUtils.hitsInOrder( - hit( - MatcherUtils.kv("d.name", "Retail"), - MatcherUtils.kv("e.lastname", "Alice") - ), - hit( - MatcherUtils.kv("d.name", "AWS"), - MatcherUtils.kv("e.lastname", "Allen") - ), - hit( - MatcherUtils.kv("d.name", "Retail"), - MatcherUtils.kv("e.lastname", "Ellis") - ), - hit( - MatcherUtils.kv("d.name", "Retail"), - MatcherUtils.kv("e.lastname", "Frank") - ), - hit( - MatcherUtils.kv("d.name", "AWS"), - MatcherUtils.kv("e.lastname", "Hank") - ) - ) - ); - } - - /** Doesn't support muliple columns from both tables (order is missing) */ - @Test - public void simpleQueryWithLeftJoinAndOrderByMultipleColumnsFromOneTableInDesc() { - MatcherAssert.assertThat( - query( - "SELECT d.id AS id, e.lastname AS lastname FROM employee e " + - " LEFT JOIN department d ON d.id = e.departmentId " + - " ORDER BY e.departmentId, e.lastname DESC", - employees( - employee(1, "Hank", "1"), - employee(2, "Alice", "2"), - employee(3, "Allen", "1"), - employee(4, "Ellis", "2"), - employee(5, "Gary", "3"), - employee(5, "Frank", "3") - ), - departments( - department(1, "1", "AWS"), - department(2, "2", "Retail") - ) - ), - MatcherUtils.hitsInOrder( - hit( - MatcherUtils.kv("id", null), - MatcherUtils.kv("lastname", "Gary") - ), - hit( - MatcherUtils.kv("id", null), - MatcherUtils.kv("lastname", "Frank") - ), - hit( - MatcherUtils.kv("id", "2"), - MatcherUtils.kv("lastname", "Ellis") - ), - hit( - MatcherUtils.kv("id", "2"), - MatcherUtils.kv("lastname", "Alice") - ), - hit( - MatcherUtils.kv("id", "1"), - MatcherUtils.kv("lastname", "Hank") - ), - hit( - MatcherUtils.kv("id", "1"), - MatcherUtils.kv("lastname", "Allen") - ) - ) - ); - } - - @Test - public void simpleCrossJoin() { - MatcherAssert.assertThat( - query( - "SELECT d.name AS dname, e.lastname AS ename FROM 
employee e JOIN department d", - employees( - employee(1, "Alice", "2"), - employee(2, "Hank", "3") - ), - departments( - department(1, "1", "AWS"), - department(2, "2", "Retail") - ) - ), - hits( - hit( - MatcherUtils.kv("dname", "AWS"), - MatcherUtils.kv("ename", "Alice") - ), - hit( - MatcherUtils.kv("dname", "AWS"), - MatcherUtils.kv("ename", "Hank") - ), - hit( - MatcherUtils.kv("dname", "Retail"), - MatcherUtils.kv("ename", "Alice") - ), - hit( - MatcherUtils.kv("dname", "Retail"), - MatcherUtils.kv("ename", "Hank") - ) - ) - ); - } - - @Test - public void simpleQueryWithTermsFilterOptimization() { - MatcherAssert.assertThat( - query( - "SELECT /*! HASH_WITH_TERMS_FILTER*/ " + // Be careful that no space between ...FILTER and */ - " e.lastname, d.id FROM employee e " + - " JOIN department d ON d.id = e.departmentId AND d.name = e.lastname", - employees( - employee(1, "Johnson", "1"), - employee(2, "Allen", "4"), - employee(3, "Ellis", "2"), - employee(4, "Dell", "1"), - employee(5, "Dell", "4") - ), - departments( - department(1, "1", "Johnson"), - department(1, "4", "Dell") - ) - ), - hits( - hit( - MatcherUtils.kv("e.lastname", "Johnson"), - MatcherUtils.kv("d.id", "1") - ), - hit( - MatcherUtils.kv("e.lastname", "Dell"), - MatcherUtils.kv("d.id", "4") - ) - ) - ); - } - - @Test - public void complexJoinWithMultipleConditions() { - MatcherAssert.assertThat( - query( - "SELECT d.name, e.lastname, d.id " + - " FROM employee e " + - " JOIN department d " + - " ON d.id = e.departmentId AND d.name = e.lastname" + - " WHERE d.region = 'US' AND e.age > 30", - employees( - employee(1, "Dell", "1"), - employee(2, "Hank", "1") - ), - departments( - department(1, "1", "Dell") - ) - ), - hits( - hit( - MatcherUtils.kv("d.name", "Dell"), - MatcherUtils.kv("e.lastname", "Dell"), - MatcherUtils.kv("d.id", "1") - ) - ) - ); - } - - @Test - public void complexJoinWithOrConditions() { - MatcherAssert.assertThat( - query( - "SELECT d.name, e.lastname " + - " FROM employee e " + - " JOIN department d " + - " ON d.id = e.departmentId OR d.name = e.lastname", - employees( - employee(1, "Alice", "1"), - employee(2, "Dell", "2"), - employee(3, "Hank", "3") - ), - departments( - department(1, "1", "Dell"), - department(2, "4", "AWS") - ) - ), - hits( - hit( - MatcherUtils.kv("d.name", "Dell"), - MatcherUtils.kv("e.lastname", "Alice") - ), - hit( - MatcherUtils.kv("d.name", "Dell"), - MatcherUtils.kv("e.lastname", "Dell") - ) - ) - ); - } - - @Test - public void complexJoinWithOrConditionsDuplicate() { - MatcherAssert.assertThat( - query( - "SELECT d.name, e.departmentId " + - " FROM employee e " + - " JOIN department d " + - " ON d.id = e.departmentId OR d.name = e.lastname", - employees( - employee(1, "Dell", "1") // Match both condition but should only show once in result - ), - departments( - department(1, "1", "Dell"), - department(2, "4", "AWS") - ) - ), - hits( - hit( - MatcherUtils.kv("d.name", "Dell"), - MatcherUtils.kv("e.departmentId", "1") - ) - ) - ); - } - - @Test - public void complexJoinWithOrConditionsAndTermsFilterOptimization() { - MatcherAssert.assertThat( - query( - "SELECT /*! 
HASH_WITH_TERMS_FILTER*/ " + - " d.name, e.lastname " + - " FROM employee e " + - " JOIN department d " + - " ON d.id = e.departmentId OR d.name = e.lastname", - employees( - employee(1, "Alice", "1"), - employee(2, "Dell", "2"), - employee(3, "Hank", "3") - ), - departments( - department(1, "1", "Dell"), - department(2, "4", "AWS") - ) - ), - hits( - hit( - MatcherUtils.kv("d.name", "Dell"), - MatcherUtils.kv("e.lastname", "Alice") - ), - hit( - MatcherUtils.kv("d.name", "Dell"), - MatcherUtils.kv("e.lastname", "Dell") - ) - ) - ); - } - - @Test - public void complexLeftJoinWithOrConditions() { - MatcherAssert.assertThat( - query( - "SELECT d.name, e.lastname " + - " FROM employee e " + - " LEFT JOIN department d " + - " ON d.id = e.departmentId OR d.name = e.lastname", - employees( - employee(1, "Alice", "1"), - employee(2, "Dell", "2"), - employee(3, "Hank", "3") - ), - departments( - department(1, "1", "Dell"), - department(2, "4", "AWS") - ) - ), - hits( - hit( - MatcherUtils.kv("d.name", "Dell"), - MatcherUtils.kv("e.lastname", "Alice") - ), - hit( - MatcherUtils.kv("d.name", "Dell"), - MatcherUtils.kv("e.lastname", "Dell") - ), - hit( - MatcherUtils.kv("d.name", null), - MatcherUtils.kv("e.lastname", "Hank") - ) - ) - ); - } - - @Test - public void complexJoinWithTableLimitHint() { - MatcherAssert.assertThat( - query( - "SELECT " + - " /*! JOIN_TABLES_LIMIT(2, 1)*/" + - " d.name, e.lastname " + - " FROM employee e " + - " JOIN department d " + - " ON d.id = e.departmentId", - employees( - employee(1, "Alice", "1"), // Only this and the second row will be pulled out - employee(2, "Dell", "4"), - employee(3, "Hank", "1") - ), - departments( - department(1, "1", "Dell"), // Only this row will be pulled out - department(2, "4", "AWS") - ) - ), - hits( - hit( - MatcherUtils.kv("d.name", "Dell"), - MatcherUtils.kv("e.lastname", "Alice") - ) - ) - ); - } - + @Test + public void simpleJoin() { + MatcherAssert.assertThat( + query( + "SELECT d.name, e.lastname FROM employee e " + + " JOIN department d ON d.id = e.departmentId " + + " WHERE d.region = 'US' AND e.age > 30", + employees(employee(1, "Alice", "1"), employee(2, "Hank", "1")), + departments(department(1, "1", "AWS"))), + hits( + hit(MatcherUtils.kv("d.name", "AWS"), MatcherUtils.kv("e.lastname", "Alice")), + hit(MatcherUtils.kv("d.name", "AWS"), MatcherUtils.kv("e.lastname", "Hank")))); + } + + @Test + public void simpleJoinWithSelectAll() { + MatcherAssert.assertThat( + query( + "SELECT * FROM employee e " + " JOIN department d ON d.id = e.departmentId ", + employees(employee(1, "Alice", "1"), employee(2, "Hank", "1")), + departments(department(1, "1", "AWS"))), + hits( + hit( + MatcherUtils.kv("d.name", "AWS"), + MatcherUtils.kv("d.id", "1"), + MatcherUtils.kv("e.lastname", "Alice"), + MatcherUtils.kv("e.departmentId", "1")), + hit( + MatcherUtils.kv("d.name", "AWS"), + MatcherUtils.kv("d.id", "1"), + MatcherUtils.kv("e.lastname", "Hank"), + MatcherUtils.kv("e.departmentId", "1")))); + } + + @Test + public void simpleLeftJoinWithSelectAllFromOneTable() { + MatcherAssert.assertThat( + query( + "SELECT e.lastname, d.* FROM employee e " + + " LEFT JOIN department d ON d.id = e.departmentId ", + employees( + employee(1, "Alice", "1"), employee(2, "Hank", "1"), employee(3, "Allen", "3")), + departments(department(1, "1", "AWS"), department(2, "2", "Retail"))), + hits( + hit( + MatcherUtils.kv("e.lastname", "Alice"), + MatcherUtils.kv("d.name", "AWS"), + MatcherUtils.kv("d.id", "1")), + hit( + MatcherUtils.kv("e.lastname", "Hank"), + 
MatcherUtils.kv("d.name", "AWS"), + MatcherUtils.kv("d.id", "1")), + hit( + MatcherUtils.kv("e.lastname", "Allen") + /* + * Not easy to figure out all column names for d.* without reading metadata + * or look into other rows from d. But in the extreme case, d could be empty table + * which requires metadata read anyway. + */ + // kv("d.name", null), + // kv("d.id", null) + ))); + } + + @Test + public void simpleJoinWithSelectAllFromBothTables() { + MatcherAssert.assertThat( + query( + "SELECT e.*, d.* FROM employee e " + " JOIN department d ON d.id = e.departmentId ", + employees(employee(1, "Alice", "1"), employee(2, "Hank", "1")), + departments(department(1, "1", "AWS"))), + hits( + hit( + MatcherUtils.kv("d.name", "AWS"), + MatcherUtils.kv("d.id", "1"), + MatcherUtils.kv("e.lastname", "Alice"), + MatcherUtils.kv("e.departmentId", "1")), + hit( + MatcherUtils.kv("d.name", "AWS"), + MatcherUtils.kv("d.id", "1"), + MatcherUtils.kv("e.lastname", "Hank"), + MatcherUtils.kv("e.departmentId", "1")))); + } + + @Test + public void simpleJoinWithoutMatch() { + MatcherAssert.assertThat( + query( + "SELECT d.name, e.lastname FROM employee e " + + " JOIN department d ON d.id = e.departmentId " + + " WHERE d.region = 'US' AND e.age > 30", + employees(employee(1, "Alice", "2"), employee(2, "Hank", "3")), + departments(department(1, "1", "AWS"))), + hits()); + } + + @Test + public void simpleJoinWithSomeMatches() { + MatcherAssert.assertThat( + query( + "SELECT d.name, e.lastname FROM employee e " + + " JOIN department d ON d.id = e.departmentId " + + " WHERE d.region = 'US' AND e.age > 30", + employees(employee(1, "Alice", "2"), employee(2, "Hank", "3")), + departments(department(1, "1", "AWS"), department(2, "2", "Retail"))), + hits(hit(MatcherUtils.kv("d.name", "Retail"), MatcherUtils.kv("e.lastname", "Alice")))); + } + + @Test + public void simpleJoinWithAllMatches() { + MatcherAssert.assertThat( + query( + "SELECT d.name, e.lastname FROM employee e " + + " JOIN department d ON d.id = e.departmentId " + + " WHERE d.region = 'US' AND e.age > 30", + employees( + employee(1, "Alice", "1"), employee(2, "Hank", "1"), employee(3, "Mike", "2")), + departments(department(1, "1", "AWS"), department(2, "2", "Retail"))), + hits( + hit(MatcherUtils.kv("d.name", "AWS"), MatcherUtils.kv("e.lastname", "Alice")), + hit(MatcherUtils.kv("d.name", "AWS"), MatcherUtils.kv("e.lastname", "Hank")), + hit(MatcherUtils.kv("d.name", "Retail"), MatcherUtils.kv("e.lastname", "Mike")))); + } + + @Test + public void simpleJoinWithNull() { + MatcherAssert.assertThat( + query( + "SELECT d.name, e.lastname FROM employee e " + + " JOIN department d ON d.id = e.departmentId " + + " WHERE d.region = 'US' AND e.age > 30", + employees( + employee(1, "Alice", "1"), employee(2, "Hank", null), employee(3, "Mike", "2")), + departments(department(1, "1", "AWS"), department(2, null, "Retail"))), + hits(hit(MatcherUtils.kv("d.name", "AWS"), MatcherUtils.kv("e.lastname", "Alice")))); + } + + @Test + public void simpleJoinWithColumnNameConflict() { + // Add a same column 'name' as in department on purpose + SearchHit alice = employee(1, "Alice", "1"); + alice.getSourceAsMap().put("name", "Alice Alice"); + SearchHit hank = employee(2, "Hank", "2"); + hank.getSourceAsMap().put("name", "Hank Hank"); + + MatcherAssert.assertThat( + query( + "SELECT d.name, e.name FROM employee e " + + " JOIN department d ON d.id = e.departmentId " + + " WHERE d.region = 'US' AND e.age > 30", + employees(alice, hank), + departments(department(1, "1", "AWS"))), + 
hits(hit(MatcherUtils.kv("d.name", "AWS"), MatcherUtils.kv("e.name", "Alice Alice")))); + } + + @Test + public void simpleJoinWithAliasInSelect() { + MatcherAssert.assertThat( + query( + "SELECT d.name AS dname, e.lastname AS ename FROM employee e " + + " JOIN department d ON d.id = e.departmentId " + + " WHERE d.region = 'US' AND e.age > 30", + employees(employee(1, "Alice", "2"), employee(2, "Hank", "3")), + departments(department(1, "1", "AWS"), department(2, "2", "Retail"))), + hits(hit(MatcherUtils.kv("dname", "Retail"), MatcherUtils.kv("ename", "Alice")))); + } + + @Test + public void simpleLeftJoinWithoutMatchInLeft() { + MatcherAssert.assertThat( + query( + "SELECT d.name, e.lastname FROM employee e " + + " LEFT JOIN department d ON d.id = e.departmentId " + + " WHERE d.region = 'US' AND e.age > 30", + employees(employee(1, "Alice", "2"), employee(2, "Hank", "3")), + departments(department(1, "1", "AWS"))), + hits( + hit(MatcherUtils.kv("d.name", null), MatcherUtils.kv("e.lastname", "Alice")), + hit(MatcherUtils.kv("d.name", null), MatcherUtils.kv("e.lastname", "Hank")))); + } + + @Test + public void simpleLeftJoinWithSomeMismatchesInLeft() { + MatcherAssert.assertThat( + query( + "SELECT d.name, e.lastname FROM employee e " + + " LEFT JOIN department d ON d.id = e.departmentId " + + " WHERE d.region = 'US' AND e.age > 30", + employees(employee(1, "Alice", "1"), employee(2, "Hank", "2")), + departments(department(1, "1", "AWS"))), + hits( + hit(MatcherUtils.kv("d.name", "AWS"), MatcherUtils.kv("e.lastname", "Alice")), + hit(MatcherUtils.kv("d.name", null), MatcherUtils.kv("e.lastname", "Hank")))); + } + + @Test + public void simpleLeftJoinWithSomeMismatchesInRight() { + MatcherAssert.assertThat( + query( + "SELECT d.name, e.lastname FROM employee e " + + " LEFT JOIN department d ON d.id = e.departmentId " + + " WHERE d.region = 'US' AND e.age > 30", + employees(employee(1, "Alice", "1"), employee(2, "Hank", "1")), + departments(department(1, "1", "AWS"), department(2, "2", "Retail"))), + hits( + hit(MatcherUtils.kv("d.name", "AWS"), MatcherUtils.kv("e.lastname", "Alice")), + hit(MatcherUtils.kv("d.name", "AWS"), MatcherUtils.kv("e.lastname", "Hank")))); + } + + @Test + public void simpleQueryWithTotalLimit() { + MatcherAssert.assertThat( + query( + "SELECT d.name, e.lastname FROM employee e JOIN department d ON d.id = e.departmentId" + + " LIMIT 1", + employees(employee(1, "Alice", "1"), employee(2, "Hank", "2")), + departments(department(1, "1", "AWS"), department(1, "2", "Retail"))), + hits(hit(MatcherUtils.kv("d.name", "AWS"), MatcherUtils.kv("e.lastname", "Alice")))); + } + + @Test + public void simpleQueryWithTableLimit() { + MatcherAssert.assertThat( + query( + "SELECT /*! 
JOIN_TABLES_LIMIT(1, 5) */ d.name, e.lastname FROM employee e JOIN" + + " department d ON d.id = e.departmentId", + employees(employee(1, "Alice", "1"), employee(2, "Hank", "1")), + departments(department(1, "1", "AWS"), department(1, "2", "Retail"))), + hits(hit(MatcherUtils.kv("d.name", "AWS"), MatcherUtils.kv("e.lastname", "Alice")))); + } + + @Test + public void simpleQueryWithOrderBy() { + MatcherAssert.assertThat( + query( + "SELECT d.name, e.lastname FROM employee e JOIN department d ON d.id = e.departmentId" + + " ORDER BY e.lastname", + employees( + employee(1, "Hank", "1"), + employee(2, "Alice", "2"), + employee(3, "Allen", "1"), + employee(4, "Ellis", "2"), + employee(5, "Frank", "2")), + departments(department(1, "1", "AWS"), department(2, "2", "Retail"))), + MatcherUtils.hitsInOrder( + hit(MatcherUtils.kv("d.name", "Retail"), MatcherUtils.kv("e.lastname", "Alice")), + hit(MatcherUtils.kv("d.name", "AWS"), MatcherUtils.kv("e.lastname", "Allen")), + hit(MatcherUtils.kv("d.name", "Retail"), MatcherUtils.kv("e.lastname", "Ellis")), + hit(MatcherUtils.kv("d.name", "Retail"), MatcherUtils.kv("e.lastname", "Frank")), + hit(MatcherUtils.kv("d.name", "AWS"), MatcherUtils.kv("e.lastname", "Hank")))); + } + + /** Doesn't support multiple columns from both tables (order is missing) */ + @Test + public void simpleQueryWithLeftJoinAndOrderByMultipleColumnsFromOneTableInDesc() { + MatcherAssert.assertThat( + query( + "SELECT d.id AS id, e.lastname AS lastname FROM employee e " + + " LEFT JOIN department d ON d.id = e.departmentId " + + " ORDER BY e.departmentId, e.lastname DESC", + employees( + employee(1, "Hank", "1"), + employee(2, "Alice", "2"), + employee(3, "Allen", "1"), + employee(4, "Ellis", "2"), + employee(5, "Gary", "3"), + employee(5, "Frank", "3")), + departments(department(1, "1", "AWS"), department(2, "2", "Retail"))), + MatcherUtils.hitsInOrder( + hit(MatcherUtils.kv("id", null), MatcherUtils.kv("lastname", "Gary")), + hit(MatcherUtils.kv("id", null), MatcherUtils.kv("lastname", "Frank")), + hit(MatcherUtils.kv("id", "2"), MatcherUtils.kv("lastname", "Ellis")), + hit(MatcherUtils.kv("id", "2"), MatcherUtils.kv("lastname", "Alice")), + hit(MatcherUtils.kv("id", "1"), MatcherUtils.kv("lastname", "Hank")), + hit(MatcherUtils.kv("id", "1"), MatcherUtils.kv("lastname", "Allen")))); + } + + @Test + public void simpleCrossJoin() { + MatcherAssert.assertThat( + query( + "SELECT d.name AS dname, e.lastname AS ename FROM employee e JOIN department d", + employees(employee(1, "Alice", "2"), employee(2, "Hank", "3")), + departments(department(1, "1", "AWS"), department(2, "2", "Retail"))), + hits( + hit(MatcherUtils.kv("dname", "AWS"), MatcherUtils.kv("ename", "Alice")), + hit(MatcherUtils.kv("dname", "AWS"), MatcherUtils.kv("ename", "Hank")), + hit(MatcherUtils.kv("dname", "Retail"), MatcherUtils.kv("ename", "Alice")), + hit(MatcherUtils.kv("dname", "Retail"), MatcherUtils.kv("ename", "Hank")))); + } + + @Test + public void simpleQueryWithTermsFilterOptimization() { + MatcherAssert.assertThat( + query( + "SELECT /*! 
HASH_WITH_TERMS_FILTER*/ " + + // Be careful that there is no space between ...FILTER and */ + " e.lastname, d.id FROM employee e " + + " JOIN department d ON d.id = e.departmentId AND d.name = e.lastname", + employees( + employee(1, "Johnson", "1"), + employee(2, "Allen", "4"), + employee(3, "Ellis", "2"), + employee(4, "Dell", "1"), + employee(5, "Dell", "4")), + departments(department(1, "1", "Johnson"), department(1, "4", "Dell"))), + hits( + hit(MatcherUtils.kv("e.lastname", "Johnson"), MatcherUtils.kv("d.id", "1")), + hit(MatcherUtils.kv("e.lastname", "Dell"), MatcherUtils.kv("d.id", "4")))); + } + + @Test + public void complexJoinWithMultipleConditions() { + MatcherAssert.assertThat( + query( + "SELECT d.name, e.lastname, d.id " + + " FROM employee e " + + " JOIN department d " + + " ON d.id = e.departmentId AND d.name = e.lastname" + + " WHERE d.region = 'US' AND e.age > 30", + employees(employee(1, "Dell", "1"), employee(2, "Hank", "1")), + departments(department(1, "1", "Dell"))), + hits( + hit( + MatcherUtils.kv("d.name", "Dell"), + MatcherUtils.kv("e.lastname", "Dell"), + MatcherUtils.kv("d.id", "1")))); + } + + @Test + public void complexJoinWithOrConditions() { + MatcherAssert.assertThat( + query( + "SELECT d.name, e.lastname " + + " FROM employee e " + + " JOIN department d " + + " ON d.id = e.departmentId OR d.name = e.lastname", + employees( + employee(1, "Alice", "1"), employee(2, "Dell", "2"), employee(3, "Hank", "3")), + departments(department(1, "1", "Dell"), department(2, "4", "AWS"))), + hits( + hit(MatcherUtils.kv("d.name", "Dell"), MatcherUtils.kv("e.lastname", "Alice")), + hit(MatcherUtils.kv("d.name", "Dell"), MatcherUtils.kv("e.lastname", "Dell")))); + } + + @Test + public void complexJoinWithOrConditionsDuplicate() { + MatcherAssert.assertThat( + query( + "SELECT d.name, e.departmentId " + + " FROM employee e " + + " JOIN department d " + + " ON d.id = e.departmentId OR d.name = e.lastname", + employees( + employee(1, "Dell", "1") // Matches both conditions but should only appear once in the result + ), + departments(department(1, "1", "Dell"), department(2, "4", "AWS"))), + hits(hit(MatcherUtils.kv("d.name", "Dell"), MatcherUtils.kv("e.departmentId", "1")))); + } + + @Test + public void complexJoinWithOrConditionsAndTermsFilterOptimization() { + MatcherAssert.assertThat( + query( + "SELECT /*! 
HASH_WITH_TERMS_FILTER*/ " + + " d.name, e.lastname " + + " FROM employee e " + + " JOIN department d " + + " ON d.id = e.departmentId OR d.name = e.lastname", + employees( + employee(1, "Alice", "1"), employee(2, "Dell", "2"), employee(3, "Hank", "3")), + departments(department(1, "1", "Dell"), department(2, "4", "AWS"))), + hits( + hit(MatcherUtils.kv("d.name", "Dell"), MatcherUtils.kv("e.lastname", "Alice")), + hit(MatcherUtils.kv("d.name", "Dell"), MatcherUtils.kv("e.lastname", "Dell")))); + } + + @Test + public void complexLeftJoinWithOrConditions() { + MatcherAssert.assertThat( + query( + "SELECT d.name, e.lastname " + + " FROM employee e " + + " LEFT JOIN department d " + + " ON d.id = e.departmentId OR d.name = e.lastname", + employees( + employee(1, "Alice", "1"), employee(2, "Dell", "2"), employee(3, "Hank", "3")), + departments(department(1, "1", "Dell"), department(2, "4", "AWS"))), + hits( + hit(MatcherUtils.kv("d.name", "Dell"), MatcherUtils.kv("e.lastname", "Alice")), + hit(MatcherUtils.kv("d.name", "Dell"), MatcherUtils.kv("e.lastname", "Dell")), + hit(MatcherUtils.kv("d.name", null), MatcherUtils.kv("e.lastname", "Hank")))); + } + + @Test + public void complexJoinWithTableLimitHint() { + MatcherAssert.assertThat( + query( + "SELECT " + + " /*! JOIN_TABLES_LIMIT(2, 1)*/" + + " d.name, e.lastname " + + " FROM employee e " + + " JOIN department d " + + " ON d.id = e.departmentId", + employees( + employee(1, "Alice", "1"), // Only this and the second row will be pulled out + employee(2, "Dell", "4"), + employee(3, "Hank", "1")), + departments( + department(1, "1", "Dell"), // Only this row will be pulled out + department(2, "4", "AWS"))), + hits(hit(MatcherUtils.kv("d.name", "Dell"), MatcherUtils.kv("e.lastname", "Alice")))); + } } diff --git a/legacy/src/test/java/org/opensearch/sql/legacy/unittest/planner/QueryPlannerExplainTest.java b/legacy/src/test/java/org/opensearch/sql/legacy/unittest/planner/QueryPlannerExplainTest.java index 2c92c91666..7f495935ca 100644 --- a/legacy/src/test/java/org/opensearch/sql/legacy/unittest/planner/QueryPlannerExplainTest.java +++ b/legacy/src/test/java/org/opensearch/sql/legacy/unittest/planner/QueryPlannerExplainTest.java @@ -3,45 +3,41 @@ * SPDX-License-Identifier: Apache-2.0 */ - package org.opensearch.sql.legacy.unittest.planner; import org.junit.Test; import org.opensearch.sql.legacy.query.planner.core.QueryPlanner; -/** - * Query planner explanation unit test - */ +/** Query planner explanation unit test */ public class QueryPlannerExplainTest extends QueryPlannerTest { - @Test - public void explainInJson() { - QueryPlanner planner = plan( - "SELECT d.name, e.lastname FROM employee e " + - " JOIN department d ON d.id = e.departmentId " + - " WHERE d.region = 'US' AND e.age > 30" - ); - planner.explain(); - } - - @Test - public void explainInJsonWithComplicatedOn() { - QueryPlanner planner = plan( - "SELECT d.name, e.lastname FROM employee e " + - " JOIN department d ON d.id = e.departmentId AND d.location = e.region " + - " WHERE d.region = 'US' AND e.age > 30" - ); - planner.explain(); - } - - @Test - public void explainInJsonWithDuplicateColumnsPushedDown() { - QueryPlanner planner = plan( - "SELECT d.id, e.departmentId FROM employee e " + - " JOIN department d ON d.id = e.departmentId AND d.location = e.region " + - " WHERE d.region = 'US' AND e.age > 30" - ); - planner.explain(); - } - + @Test + public void explainInJson() { + QueryPlanner planner = + plan( + "SELECT d.name, e.lastname FROM employee e " + + " JOIN department d ON 
d.id = e.departmentId " + + " WHERE d.region = 'US' AND e.age > 30"); + planner.explain(); + } + + @Test + public void explainInJsonWithComplicatedOn() { + QueryPlanner planner = + plan( + "SELECT d.name, e.lastname FROM employee e " + + " JOIN department d ON d.id = e.departmentId AND d.location = e.region " + + " WHERE d.region = 'US' AND e.age > 30"); + planner.explain(); + } + + @Test + public void explainInJsonWithDuplicateColumnsPushedDown() { + QueryPlanner planner = + plan( + "SELECT d.id, e.departmentId FROM employee e " + + " JOIN department d ON d.id = e.departmentId AND d.location = e.region " + + " WHERE d.region = 'US' AND e.age > 30"); + planner.explain(); + } } diff --git a/legacy/src/test/java/org/opensearch/sql/legacy/unittest/planner/QueryPlannerMonitorTest.java b/legacy/src/test/java/org/opensearch/sql/legacy/unittest/planner/QueryPlannerMonitorTest.java index 66ce2411f4..9b1d307ebc 100644 --- a/legacy/src/test/java/org/opensearch/sql/legacy/unittest/planner/QueryPlannerMonitorTest.java +++ b/legacy/src/test/java/org/opensearch/sql/legacy/unittest/planner/QueryPlannerMonitorTest.java @@ -3,7 +3,6 @@ * SPDX-License-Identifier: Apache-2.0 */ - package org.opensearch.sql.legacy.unittest.planner; import static org.mockito.Mockito.doAnswer; @@ -18,109 +17,95 @@ import org.opensearch.sql.legacy.query.planner.resource.Stats; import org.opensearch.sql.legacy.query.planner.resource.Stats.MemStats; -/** - * Circuit breaker component test - */ +/** Circuit breaker component test */ @Ignore public class QueryPlannerMonitorTest extends QueryPlannerTest { - /** Configure back off strategy 1s, 1s and 1s - retry 4 times at most */ - private static final String TEST_SQL1 = - "SELECT /*! JOIN_BACK_OFF_RETRY_INTERVALS(1, 1, 1) */ " + - " /*! JOIN_CIRCUIT_BREAK_LIMIT("; - - private static final String TEST_SQL2 = - ") */ d.name, e.lastname FROM employee e " + - " JOIN department d ON d.id = e.departmentId " + - " WHERE d.region = 'US' AND e.age > 30"; - - private static final long[] PERCENT_USAGE_15 = freeAndTotalMem(85, 100); - private static final long[] PERCENT_USAGE_24 = freeAndTotalMem(76, 100); - private static final long[] PERCENT_USAGE_50 = freeAndTotalMem(50, 100); - - @Spy - private Stats stats = new Stats(client); - - @Test - public void reachedLimitAndRecoverAt1stAttempt() { - mockMemUsage(PERCENT_USAGE_15, PERCENT_USAGE_50, PERCENT_USAGE_24); - queryWithLimit(25); // TODO: assert if final result set is correct after recovery - } - - @Test - public void reachedLimitAndRecoverAt2ndAttempt() { - mockMemUsage(PERCENT_USAGE_15, PERCENT_USAGE_50, PERCENT_USAGE_50, PERCENT_USAGE_15); - queryWithLimit(25); - } - - @Test - public void reachedLimitAndRecoverAt3rdAttempt() { - mockMemUsage(PERCENT_USAGE_15, PERCENT_USAGE_50, PERCENT_USAGE_50, PERCENT_USAGE_50, PERCENT_USAGE_15); - queryWithLimit(25); - } - - @Test(expected = IllegalStateException.class) - public void reachedLimitAndFailFinally() { - mockMemUsage(PERCENT_USAGE_15, PERCENT_USAGE_50); - queryWithLimit(25); - } - - @Test(expected = IllegalStateException.class) - public void reachedLimitAndRejectNewRequest() { - mockMemUsage(PERCENT_USAGE_50); - queryWithLimit(25); - } - - @Test(expected = IllegalStateException.class) - public void timeOut() { - query( - "SELECT /*! JOIN_TIME_OUT(0) */ " + - " d.name FROM employee e JOIN department d ON d.id = e.departmentId", - employees( - employee(1, "Dell", "1") - ), - departments( - department(1, "1", "Dell") - ) - ); - } - - private void mockMemUsage(long[]... 
memUsages) { - doAnswer(new Answer() { - private int callCnt = -1; - - @Override - public MemStats answer(InvocationOnMock invocation) { + /** Configure back off strategy 1s, 1s and 1s - retry 4 times at most */ + private static final String TEST_SQL1 = + "SELECT /*! JOIN_BACK_OFF_RETRY_INTERVALS(1, 1, 1) */ " + " /*! JOIN_CIRCUIT_BREAK_LIMIT("; + + private static final String TEST_SQL2 = + ") */ d.name, e.lastname FROM employee e " + + " JOIN department d ON d.id = e.departmentId " + + " WHERE d.region = 'US' AND e.age > 30"; + + private static final long[] PERCENT_USAGE_15 = freeAndTotalMem(85, 100); + private static final long[] PERCENT_USAGE_24 = freeAndTotalMem(76, 100); + private static final long[] PERCENT_USAGE_50 = freeAndTotalMem(50, 100); + + @Spy private Stats stats = new Stats(client); + + @Test + public void reachedLimitAndRecoverAt1stAttempt() { + mockMemUsage(PERCENT_USAGE_15, PERCENT_USAGE_50, PERCENT_USAGE_24); + queryWithLimit(25); // TODO: assert if final result set is correct after recovery + } + + @Test + public void reachedLimitAndRecoverAt2ndAttempt() { + mockMemUsage(PERCENT_USAGE_15, PERCENT_USAGE_50, PERCENT_USAGE_50, PERCENT_USAGE_15); + queryWithLimit(25); + } + + @Test + public void reachedLimitAndRecoverAt3rdAttempt() { + mockMemUsage( + PERCENT_USAGE_15, PERCENT_USAGE_50, PERCENT_USAGE_50, PERCENT_USAGE_50, PERCENT_USAGE_15); + queryWithLimit(25); + } + + @Test(expected = IllegalStateException.class) + public void reachedLimitAndFailFinally() { + mockMemUsage(PERCENT_USAGE_15, PERCENT_USAGE_50); + queryWithLimit(25); + } + + @Test(expected = IllegalStateException.class) + public void reachedLimitAndRejectNewRequest() { + mockMemUsage(PERCENT_USAGE_50); + queryWithLimit(25); + } + + @Test(expected = IllegalStateException.class) + public void timeOut() { + query( + "SELECT /*! JOIN_TIME_OUT(0) */ " + + " d.name FROM employee e JOIN department d ON d.id = e.departmentId", + employees(employee(1, "Dell", "1")), + departments(department(1, "1", "Dell"))); + } + + private void mockMemUsage(long[]... 
memUsages) { + doAnswer( + new Answer() { + private int callCnt = -1; + + @Override + public MemStats answer(InvocationOnMock invocation) { callCnt = Math.min(callCnt + 1, memUsages.length - 1); - return new MemStats( - memUsages[callCnt][0], memUsages[callCnt][1] - ); - } - }).when(stats).collectMemStats(); - } - - private static long[] freeAndTotalMem(long free, long total) { - return new long[]{ free, total }; - } - - private SearchHits queryWithLimit(int limit) { - return query( - TEST_SQL1 + limit + TEST_SQL2, - employees( - employee(1, "Dell", "1"), - employee(2, "Hank", "1") - ), - departments( - department(1, "1", "Dell") - ) - ); - } - - @Override - protected QueryPlanner plan(String sql) { - QueryPlanner planner = super.plan(sql); - planner.setStats(stats); - return planner; - } - + return new MemStats(memUsages[callCnt][0], memUsages[callCnt][1]); + } + }) + .when(stats) + .collectMemStats(); + } + + private static long[] freeAndTotalMem(long free, long total) { + return new long[] {free, total}; + } + + private SearchHits queryWithLimit(int limit) { + return query( + TEST_SQL1 + limit + TEST_SQL2, + employees(employee(1, "Dell", "1"), employee(2, "Hank", "1")), + departments(department(1, "1", "Dell"))); + } + + @Override + protected QueryPlanner plan(String sql) { + QueryPlanner planner = super.plan(sql); + planner.setStats(stats); + return planner; + } } diff --git a/legacy/src/test/java/org/opensearch/sql/legacy/unittest/planner/QueryPlannerTest.java b/legacy/src/test/java/org/opensearch/sql/legacy/unittest/planner/QueryPlannerTest.java index 66380c108d..4cda101ae4 100644 --- a/legacy/src/test/java/org/opensearch/sql/legacy/unittest/planner/QueryPlannerTest.java +++ b/legacy/src/test/java/org/opensearch/sql/legacy/unittest/planner/QueryPlannerTest.java @@ -3,7 +3,6 @@ * SPDX-License-Identifier: Apache-2.0 */ - package org.opensearch.sql.legacy.unittest.planner; import static java.util.Collections.emptyList; @@ -58,246 +57,240 @@ import org.opensearch.sql.legacy.request.SqlRequest; import org.opensearch.sql.opensearch.setting.OpenSearchSettings; -/** - * Test base class for all query planner tests. - */ +/** Test base class for all query planner tests. 
*/ @Ignore public abstract class QueryPlannerTest { - @Mock - protected Client client; + @Mock protected Client client; - @Mock - private SearchResponse response1; - private static final String SCROLL_ID1 = "1"; + @Mock private SearchResponse response1; + private static final String SCROLL_ID1 = "1"; - @Mock - private SearchResponse response2; - private static final String SCROLL_ID2 = "2"; + @Mock private SearchResponse response2; + private static final String SCROLL_ID2 = "2"; - @Mock - private ClusterSettings clusterSettings; + @Mock private ClusterSettings clusterSettings; - /* - @BeforeClass - public static void initLogger() { - ConfigurationBuilder builder = newConfigurationBuilder(); - AppenderComponentBuilder appender = builder.newAppender("stdout", "Console"); + /* + @BeforeClass + public static void initLogger() { + ConfigurationBuilder builder = newConfigurationBuilder(); + AppenderComponentBuilder appender = builder.newAppender("stdout", "Console"); - LayoutComponentBuilder standard = builder.newLayout("PatternLayout"); - standard.addAttribute("pattern", "%d [%t] %-5level: %msg%n%throwable"); - appender.add(standard); + LayoutComponentBuilder standard = builder.newLayout("PatternLayout"); + standard.addAttribute("pattern", "%d [%t] %-5level: %msg%n%throwable"); + appender.add(standard); - RootLoggerComponentBuilder rootLogger = builder.newRootLogger(Level.ERROR); - rootLogger.add(builder.newAppenderRef("stdout")); + RootLoggerComponentBuilder rootLogger = builder.newRootLogger(Level.ERROR); + rootLogger.add(builder.newAppenderRef("stdout")); - LoggerComponentBuilder logger = builder.newLogger("org.nlpcn.es4sql.query.planner", Level.TRACE); - logger.add(builder.newAppenderRef("stdout")); - //logger.addAttribute("additivity", false); + LoggerComponentBuilder logger = builder.newLogger("org.nlpcn.es4sql.query.planner", Level.TRACE); + logger.add(builder.newAppenderRef("stdout")); + //logger.addAttribute("additivity", false); - builder.add(logger); + builder.add(logger); - Configurator.initialize(builder.build()); - } - */ + Configurator.initialize(builder.build()); + } + */ - @Before - public void init() { - MockitoAnnotations.initMocks(this); - when(clusterSettings.get(ClusterName.CLUSTER_NAME_SETTING)).thenReturn(ClusterName.DEFAULT); - OpenSearchSettings settings = spy(new OpenSearchSettings(clusterSettings)); + @Before + public void init() { + MockitoAnnotations.initMocks(this); + when(clusterSettings.get(ClusterName.CLUSTER_NAME_SETTING)).thenReturn(ClusterName.DEFAULT); + OpenSearchSettings settings = spy(new OpenSearchSettings(clusterSettings)); - // Force return empty list to avoid ClusterSettings be invoked which is a final class and hard to mock. - // In this case, default value in Setting will be returned all the time. - doReturn(emptyList()).when(settings).getSettings(); - LocalClusterState.state().setPluginSettings(settings); + // Force return empty list to avoid ClusterSettings be invoked which is a final class and hard + // to mock. + // In this case, default value in Setting will be returned all the time. + doReturn(emptyList()).when(settings).getSettings(); + LocalClusterState.state().setPluginSettings(settings); - ActionFuture mockFuture = mock(ActionFuture.class); - when(client.execute(any(), any())).thenReturn(mockFuture); + ActionFuture mockFuture = mock(ActionFuture.class); + when(client.execute(any(), any())).thenReturn(mockFuture); - // Differentiate response for Scroll-1/2 by call count and scroll ID. 
- when(mockFuture.actionGet()).thenAnswer(new Answer() { - private int callCnt; + // Differentiate response for Scroll-1/2 by call count and scroll ID. + when(mockFuture.actionGet()) + .thenAnswer( + new Answer() { + private int callCnt; - @Override - public SearchResponse answer(InvocationOnMock invocation) { + @Override + public SearchResponse answer(InvocationOnMock invocation) { /* * This works based on assumption that first call comes from Scroll-1, all the following calls come from Scroll-2. * Because Scroll-1 only open scroll once and must be ahead of Scroll-2 which opens multiple times later. */ return callCnt++ == 0 ? response1 : response2; - } - }); - - doReturn(SCROLL_ID1).when(response1).getScrollId(); - doReturn(SCROLL_ID2).when(response2).getScrollId(); - - // Avoid NPE in empty SearchResponse - doReturn(0).when(response1).getFailedShards(); - doReturn(0).when(response2).getFailedShards(); - doReturn(false).when(response1).isTimedOut(); - doReturn(false).when(response2).isTimedOut(); - - returnMockResponse(SCROLL_ID1, response1); - returnMockResponse(SCROLL_ID2, response2); - - Metrics.getInstance().registerDefaultMetrics(); - } - - private void returnMockResponse(String scrollId, SearchResponse response) { - SearchScrollRequestBuilder mockReqBuilder = mock(SearchScrollRequestBuilder.class); - when(client.prepareSearchScroll(scrollId)).thenReturn(mockReqBuilder); - when(mockReqBuilder.setScroll(any(TimeValue.class))).thenReturn(mockReqBuilder); - when(mockReqBuilder.get()).thenReturn(response); - } - - protected SearchHits query(String sql, MockSearchHits mockHits1, MockSearchHits mockHits2) { - doAnswer(mockHits1).when(response1).getHits(); - doAnswer(mockHits2).when(response2).getHits(); - - try (MockedStatic backOffRetryStrategyMocked = - Mockito.mockStatic(BackOffRetryStrategy.class)) { - backOffRetryStrategyMocked.when(BackOffRetryStrategy::isHealthy).thenReturn(true); + } + }); - ClearScrollRequestBuilder mockReqBuilder = mock(ClearScrollRequestBuilder.class); - when(client.prepareClearScroll()).thenReturn(mockReqBuilder); - when(mockReqBuilder.addScrollId(any())).thenReturn(mockReqBuilder); - when(mockReqBuilder.get()).thenAnswer(new Answer() { + doReturn(SCROLL_ID1).when(response1).getScrollId(); + doReturn(SCROLL_ID2).when(response2).getScrollId(); + + // Avoid NPE in empty SearchResponse + doReturn(0).when(response1).getFailedShards(); + doReturn(0).when(response2).getFailedShards(); + doReturn(false).when(response1).isTimedOut(); + doReturn(false).when(response2).isTimedOut(); + + returnMockResponse(SCROLL_ID1, response1); + returnMockResponse(SCROLL_ID2, response2); + + Metrics.getInstance().registerDefaultMetrics(); + } + + private void returnMockResponse(String scrollId, SearchResponse response) { + SearchScrollRequestBuilder mockReqBuilder = mock(SearchScrollRequestBuilder.class); + when(client.prepareSearchScroll(scrollId)).thenReturn(mockReqBuilder); + when(mockReqBuilder.setScroll(any(TimeValue.class))).thenReturn(mockReqBuilder); + when(mockReqBuilder.get()).thenReturn(response); + } + + protected SearchHits query(String sql, MockSearchHits mockHits1, MockSearchHits mockHits2) { + doAnswer(mockHits1).when(response1).getHits(); + doAnswer(mockHits2).when(response2).getHits(); + + try (MockedStatic backOffRetryStrategyMocked = + Mockito.mockStatic(BackOffRetryStrategy.class)) { + backOffRetryStrategyMocked.when(BackOffRetryStrategy::isHealthy).thenReturn(true); + + ClearScrollRequestBuilder mockReqBuilder = mock(ClearScrollRequestBuilder.class); + 
when(client.prepareClearScroll()).thenReturn(mockReqBuilder); + when(mockReqBuilder.addScrollId(any())).thenReturn(mockReqBuilder); + when(mockReqBuilder.get()) + .thenAnswer( + new Answer() { @Override public ClearScrollResponse answer(InvocationOnMock invocation) throws Throwable { - mockHits2.reset(); - return new ClearScrollResponse(true, 0); + mockHits2.reset(); + return new ClearScrollResponse(true, 0); } - }); + }); - List hits = plan(sql).execute(); - return new SearchHits(hits.toArray(new SearchHit[0]), new TotalHits(hits.size(), Relation.EQUAL_TO), 0); - } + List hits = plan(sql).execute(); + return new SearchHits( + hits.toArray(new SearchHit[0]), new TotalHits(hits.size(), Relation.EQUAL_TO), 0); } + } - protected QueryPlanner plan(String sql) { - SqlElasticRequestBuilder request = createRequestBuilder(sql); - if (request instanceof HashJoinQueryPlanRequestBuilder) { - return ((HashJoinQueryPlanRequestBuilder) request).plan(); - } - throw new IllegalStateException("Not a JOIN query: " + sql); + protected QueryPlanner plan(String sql) { + SqlElasticRequestBuilder request = createRequestBuilder(sql); + if (request instanceof HashJoinQueryPlanRequestBuilder) { + return ((HashJoinQueryPlanRequestBuilder) request).plan(); } - - protected SqlElasticRequestBuilder createRequestBuilder(String sql) { - try { - SQLQueryExpr sqlExpr = (SQLQueryExpr) toSqlExpr(sql); - JoinSelect joinSelect = new SqlParser().parseJoinSelect(sqlExpr); // Ignore handleSubquery() - QueryAction queryAction = OpenSearchJoinQueryActionFactory - .createJoinAction(client, joinSelect); - queryAction.setSqlRequest(new SqlRequest(sql, null)); - return queryAction.explain(); - } - catch (SqlParseException e) { - throw new IllegalStateException("Invalid query: " + sql, e); - } + throw new IllegalStateException("Not a JOIN query: " + sql); + } + + protected SqlElasticRequestBuilder createRequestBuilder(String sql) { + try { + SQLQueryExpr sqlExpr = (SQLQueryExpr) toSqlExpr(sql); + JoinSelect joinSelect = new SqlParser().parseJoinSelect(sqlExpr); // Ignore handleSubquery() + QueryAction queryAction = + OpenSearchJoinQueryActionFactory.createJoinAction(client, joinSelect); + queryAction.setSqlRequest(new SqlRequest(sql, null)); + return queryAction.explain(); + } catch (SqlParseException e) { + throw new IllegalStateException("Invalid query: " + sql, e); } + } - private SQLExpr toSqlExpr(String sql) { - SQLExprParser parser = new ElasticSqlExprParser(sql); - SQLExpr expr = parser.expr(); + private SQLExpr toSqlExpr(String sql) { + SQLExprParser parser = new ElasticSqlExprParser(sql); + SQLExpr expr = parser.expr(); - if (parser.getLexer().token() != Token.EOF) { - throw new ParserException("illegal sql expr : " + sql); - } - return expr; + if (parser.getLexer().token() != Token.EOF) { + throw new ParserException("illegal sql expr : " + sql); } + return expr; + } - /** - * Mock SearchHits and slice and return in batch. 
- */ - protected static class MockSearchHits implements Answer { - - private final SearchHit[] allHits; - - private final int batchSize; //TODO: should be inferred from mock object dynamically - - private int callCnt; - - MockSearchHits(SearchHit[] allHits, int batchSize) { - this.allHits = allHits; - this.batchSize = batchSize; - } - - @Override - public SearchHits answer(InvocationOnMock invocation) { - SearchHit[] curBatch; - if (isNoMoreBatch()) { - curBatch = new SearchHit[0]; - } else { - curBatch = currentBatch(); - callCnt++; - } - return new SearchHits(curBatch, new TotalHits(allHits.length, Relation.EQUAL_TO), 0); - } - - private boolean isNoMoreBatch() { - return callCnt > allHits.length / batchSize; - } - - private SearchHit[] currentBatch() { - return Arrays.copyOfRange(allHits, startIndex(), endIndex()); - } - - private int startIndex() { - return callCnt * batchSize; - } - - private int endIndex() { - return Math.min(startIndex() + batchSize, allHits.length); - } - - private void reset() { - callCnt = 0; - } - } + /** Mock SearchHits and slice and return in batch. */ + protected static class MockSearchHits implements Answer { + + private final SearchHit[] allHits; - protected MockSearchHits employees(SearchHit... mockHits) { - return employees(5, mockHits); + private final int batchSize; // TODO: should be inferred from mock object dynamically + + private int callCnt; + + MockSearchHits(SearchHit[] allHits, int batchSize) { + this.allHits = allHits; + this.batchSize = batchSize; } - protected MockSearchHits employees(int pageSize, SearchHit... mockHits) { - return new MockSearchHits(mockHits, pageSize); + @Override + public SearchHits answer(InvocationOnMock invocation) { + SearchHit[] curBatch; + if (isNoMoreBatch()) { + curBatch = new SearchHit[0]; + } else { + curBatch = currentBatch(); + callCnt++; + } + return new SearchHits(curBatch, new TotalHits(allHits.length, Relation.EQUAL_TO), 0); } - protected MockSearchHits departments(SearchHit... mockHits) { - return departments(5, mockHits); + private boolean isNoMoreBatch() { + return callCnt > allHits.length / batchSize; } - protected MockSearchHits departments(int pageSize, SearchHit... mockHits) { - return new MockSearchHits(mockHits, pageSize); + private SearchHit[] currentBatch() { + return Arrays.copyOfRange(allHits, startIndex(), endIndex()); } - protected SearchHit employee(int docId, String lastname, String departmentId) { - SearchHit hit = new SearchHit(docId); - if (lastname == null) { - hit.sourceRef(new BytesArray("{\"departmentId\":\"" + departmentId + "\"}")); - } - else if (departmentId == null) { - hit.sourceRef(new BytesArray("{\"lastname\":\"" + lastname + "\"}")); - } - else { - hit.sourceRef(new BytesArray("{\"lastname\":\"" + lastname + "\",\"departmentId\":\"" + departmentId + "\"}")); - } - return hit; + private int startIndex() { + return callCnt * batchSize; } - protected SearchHit department(int docId, String id, String name) { - SearchHit hit = new SearchHit(docId); - if (id == null) { - hit.sourceRef(new BytesArray("{\"name\":\"" + name + "\"}")); - } - else if (name == null) { - hit.sourceRef(new BytesArray("{\"id\":\"" + id + "\"}")); - } - else { - hit.sourceRef(new BytesArray("{\"id\":\"" + id + "\",\"name\":\"" + name + "\"}")); - } - return hit; + private int endIndex() { + return Math.min(startIndex() + batchSize, allHits.length); } + private void reset() { + callCnt = 0; + } + } + + protected MockSearchHits employees(SearchHit... 
mockHits) { + return employees(5, mockHits); + } + + protected MockSearchHits employees(int pageSize, SearchHit... mockHits) { + return new MockSearchHits(mockHits, pageSize); + } + + protected MockSearchHits departments(SearchHit... mockHits) { + return departments(5, mockHits); + } + + protected MockSearchHits departments(int pageSize, SearchHit... mockHits) { + return new MockSearchHits(mockHits, pageSize); + } + + protected SearchHit employee(int docId, String lastname, String departmentId) { + SearchHit hit = new SearchHit(docId); + if (lastname == null) { + hit.sourceRef(new BytesArray("{\"departmentId\":\"" + departmentId + "\"}")); + } else if (departmentId == null) { + hit.sourceRef(new BytesArray("{\"lastname\":\"" + lastname + "\"}")); + } else { + hit.sourceRef( + new BytesArray( + "{\"lastname\":\"" + lastname + "\",\"departmentId\":\"" + departmentId + "\"}")); + } + return hit; + } + + protected SearchHit department(int docId, String id, String name) { + SearchHit hit = new SearchHit(docId); + if (id == null) { + hit.sourceRef(new BytesArray("{\"name\":\"" + name + "\"}")); + } else if (name == null) { + hit.sourceRef(new BytesArray("{\"id\":\"" + id + "\"}")); + } else { + hit.sourceRef(new BytesArray("{\"id\":\"" + id + "\",\"name\":\"" + name + "\"}")); + } + return hit; + } } diff --git a/legacy/src/test/java/org/opensearch/sql/legacy/unittest/planner/physical/SearchAggregationResponseHelperTest.java b/legacy/src/test/java/org/opensearch/sql/legacy/unittest/planner/physical/SearchAggregationResponseHelperTest.java index 589dab8905..cca5f745ee 100644 --- a/legacy/src/test/java/org/opensearch/sql/legacy/unittest/planner/physical/SearchAggregationResponseHelperTest.java +++ b/legacy/src/test/java/org/opensearch/sql/legacy/unittest/planner/physical/SearchAggregationResponseHelperTest.java @@ -3,7 +3,6 @@ * SPDX-License-Identifier: Apache-2.0 */ - package org.opensearch.sql.legacy.unittest.planner.physical; import static org.hamcrest.MatcherAssert.assertThat; @@ -29,305 +28,318 @@ @RunWith(MockitoJUnitRunner.class) public class SearchAggregationResponseHelperTest { - /** - * SELECT MAX(age) as max - * FROM accounts - */ - @Test - public void noBucketOneMetricShouldPass() { - String json = "{\n" - + " \"max#max\": {\n" - + " \"value\": 40\n" - + " }\n" - + "}"; - List> result = SearchAggregationResponseHelper.flatten(AggregationUtils.fromJson(json)); - assertThat(result, contains(allOf(hasEntry("max", 40d)))); - } + /** SELECT MAX(age) as max FROM accounts */ + @Test + public void noBucketOneMetricShouldPass() { + String json = "{\n \"max#max\": {\n \"value\": 40\n }\n}"; + List> result = + SearchAggregationResponseHelper.flatten(AggregationUtils.fromJson(json)); + assertThat(result, contains(allOf(hasEntry("max", 40d)))); + } - /** - * SELECT MAX(age) as max, MIN(age) as min - * FROM accounts - */ - @Test - public void noBucketMultipleMetricShouldPass() { - String json = "{\n" - + " \"max#max\": {\n" - + " \"value\": 40\n" - + " },\n" - + " \"min#min\": {\n" - + " \"value\": 20\n" - + " }\n" - + "}"; - List> result = SearchAggregationResponseHelper.flatten(AggregationUtils.fromJson(json)); - assertThat(result, contains(allOf(hasEntry("max", 40d), hasEntry("min", 20d)))); - } + /** SELECT MAX(age) as max, MIN(age) as min FROM accounts */ + @Test + public void noBucketMultipleMetricShouldPass() { + String json = + "{\n" + + " \"max#max\": {\n" + + " \"value\": 40\n" + + " },\n" + + " \"min#min\": {\n" + + " \"value\": 20\n" + + " }\n" + + "}"; + List> result = + 
SearchAggregationResponseHelper.flatten(AggregationUtils.fromJson(json)); + assertThat(result, contains(allOf(hasEntry("max", 40d), hasEntry("min", 20d)))); + } - /** - * SELECT gender, MAX(age) as max, MIN(age) as min - * FROM accounts - * GROUP BY gender - */ - @Test - public void oneBucketMultipleMetricShouldPass() { - String json = "{\n" - + " \"sterms#gender\": {\n" - + " \"buckets\": [\n" - + " {\n" - + " \"key\": \"m\",\n" - + " \"doc_count\": 507,\n" - + " \"min#min\": {\n" - + " \"value\": 10\n" - + " },\n" - + " \"max#max\": {\n" - + " \"value\": 20\n" - + " }\n" - + " },\n" - + " {\n" - + " \"key\": \"f\",\n" - + " \"doc_count\": 493,\n" - + " \"min#min\": {\n" - + " \"value\": 20\n" - + " },\n" - + " \"max#max\": {\n" - + " \"value\": 40\n" - + " }\n" - + " }\n" - + " ]\n" - + " }\n" - + "}"; - List> result = SearchAggregationResponseHelper.flatten(AggregationUtils.fromJson(json)); - assertThat(result, contains(allOf(hasEntry("gender", (Object) "m"), hasEntry("min", 10d), hasEntry("max", 20d)), - allOf(hasEntry("gender", (Object) "f"), hasEntry("min", 20d), - hasEntry("max", 40d)))); - } + /** SELECT gender, MAX(age) as max, MIN(age) as min FROM accounts GROUP BY gender */ + @Test + public void oneBucketMultipleMetricShouldPass() { + String json = + "{\n" + + " \"sterms#gender\": {\n" + + " \"buckets\": [\n" + + " {\n" + + " \"key\": \"m\",\n" + + " \"doc_count\": 507,\n" + + " \"min#min\": {\n" + + " \"value\": 10\n" + + " },\n" + + " \"max#max\": {\n" + + " \"value\": 20\n" + + " }\n" + + " },\n" + + " {\n" + + " \"key\": \"f\",\n" + + " \"doc_count\": 493,\n" + + " \"min#min\": {\n" + + " \"value\": 20\n" + + " },\n" + + " \"max#max\": {\n" + + " \"value\": 40\n" + + " }\n" + + " }\n" + + " ]\n" + + " }\n" + + "}"; + List> result = + SearchAggregationResponseHelper.flatten(AggregationUtils.fromJson(json)); + assertThat( + result, + contains( + allOf(hasEntry("gender", (Object) "m"), hasEntry("min", 10d), hasEntry("max", 20d)), + allOf(hasEntry("gender", (Object) "f"), hasEntry("min", 20d), hasEntry("max", 40d)))); + } - /** - * SELECT gender, state, MAX(age) as max, MIN(age) as min - * FROM accounts - * GROUP BY gender, state - */ - @Test - public void multipleBucketMultipleMetricShouldPass() { - String json = "{\n" - + " \"sterms#gender\": {\n" - + " \"buckets\": [\n" - + " {\n" - + " \"key\": \"m\",\n" - + " \"sterms#state\": {\n" - + " \"buckets\": [\n" - + " {\n" - + " \"key\": \"MD\",\n" - + " \"min#min\": {\n" - + " \"value\": 22\n" - + " },\n" - + " \"max#max\": {\n" - + " \"value\": 39\n" - + " }\n" - + " },\n" - + " {\n" - + " \"key\": \"ID\",\n" - + " \"min#min\": {\n" - + " \"value\": 23\n" - + " },\n" - + " \"max#max\": {\n" - + " \"value\": 40\n" - + " }\n" - + " }\n" - + " ]\n" - + " }\n" - + " },\n" - + " {\n" - + " \"key\": \"f\",\n" - + " \"sterms#state\": {\n" - + " \"buckets\": [\n" - + " {\n" - + " \"key\": \"TX\",\n" - + " \"min#min\": {\n" - + " \"value\": 20\n" - + " },\n" - + " \"max#max\": {\n" - + " \"value\": 38\n" - + " }\n" - + " },\n" - + " {\n" - + " \"key\": \"MI\",\n" - + " \"min#min\": {\n" - + " \"value\": 22\n" - + " },\n" - + " \"max#max\": {\n" - + " \"value\": 40\n" - + " }\n" - + " }\n" - + " ]\n" - + " }\n" - + " }\n" - + " ]\n" - + " }\n" - + "}"; - List> result = SearchAggregationResponseHelper.flatten(AggregationUtils.fromJson(json)); - assertThat(result, contains( - allOf(hasEntry("gender", (Object) "m"), hasEntry("state", (Object) "MD"), hasEntry("min", 22d), - hasEntry("max", 39d)), - allOf(hasEntry("gender", (Object) "m"), 
hasEntry("state", (Object) "ID"), hasEntry("min", 23d), - hasEntry("max", 40d)), - allOf(hasEntry("gender", (Object) "f"), hasEntry("state", (Object) "TX"), hasEntry("min", 20d), - hasEntry("max", 38d)), - allOf(hasEntry("gender", (Object) "f"), hasEntry("state", (Object) "MI"), hasEntry("min", 22d), - hasEntry("max", 40d)))); - } + /** SELECT gender, state, MAX(age) as max, MIN(age) as min FROM accounts GROUP BY gender, state */ + @Test + public void multipleBucketMultipleMetricShouldPass() { + String json = + "{\n" + + " \"sterms#gender\": {\n" + + " \"buckets\": [\n" + + " {\n" + + " \"key\": \"m\",\n" + + " \"sterms#state\": {\n" + + " \"buckets\": [\n" + + " {\n" + + " \"key\": \"MD\",\n" + + " \"min#min\": {\n" + + " \"value\": 22\n" + + " },\n" + + " \"max#max\": {\n" + + " \"value\": 39\n" + + " }\n" + + " },\n" + + " {\n" + + " \"key\": \"ID\",\n" + + " \"min#min\": {\n" + + " \"value\": 23\n" + + " },\n" + + " \"max#max\": {\n" + + " \"value\": 40\n" + + " }\n" + + " }\n" + + " ]\n" + + " }\n" + + " },\n" + + " {\n" + + " \"key\": \"f\",\n" + + " \"sterms#state\": {\n" + + " \"buckets\": [\n" + + " {\n" + + " \"key\": \"TX\",\n" + + " \"min#min\": {\n" + + " \"value\": 20\n" + + " },\n" + + " \"max#max\": {\n" + + " \"value\": 38\n" + + " }\n" + + " },\n" + + " {\n" + + " \"key\": \"MI\",\n" + + " \"min#min\": {\n" + + " \"value\": 22\n" + + " },\n" + + " \"max#max\": {\n" + + " \"value\": 40\n" + + " }\n" + + " }\n" + + " ]\n" + + " }\n" + + " }\n" + + " ]\n" + + " }\n" + + "}"; + List> result = + SearchAggregationResponseHelper.flatten(AggregationUtils.fromJson(json)); + assertThat( + result, + contains( + allOf( + hasEntry("gender", (Object) "m"), + hasEntry("state", (Object) "MD"), + hasEntry("min", 22d), + hasEntry("max", 39d)), + allOf( + hasEntry("gender", (Object) "m"), + hasEntry("state", (Object) "ID"), + hasEntry("min", 23d), + hasEntry("max", 40d)), + allOf( + hasEntry("gender", (Object) "f"), + hasEntry("state", (Object) "TX"), + hasEntry("min", 20d), + hasEntry("max", 38d)), + allOf( + hasEntry("gender", (Object) "f"), + hasEntry("state", (Object) "MI"), + hasEntry("min", 22d), + hasEntry("max", 40d)))); + } - /** - * SELECT age, gender FROM accounts GROUP BY age, gender - */ - @Test - public void multipleBucketWithoutMetricShouldPass() { - String json = "{\n" - + " \"lterms#age\": {\n" - + " \"buckets\": [\n" - + " {\n" - + " \"key\": 31,\n" - + " \"doc_count\": 61,\n" - + " \"sterms#gender\": {\n" - + " \"buckets\": [\n" - + " {\n" - + " \"key\": \"m\",\n" - + " \"doc_count\": 35\n" - + " },\n" - + " {\n" - + " \"key\": \"f\",\n" - + " \"doc_count\": 26\n" - + " }\n" - + " ]\n" - + " }\n" - + " },\n" - + " {\n" - + " \"key\": 39,\n" - + " \"doc_count\": 60,\n" - + " \"sterms#gender\": {\n" - + " \"buckets\": [\n" - + " {\n" - + " \"key\": \"f\",\n" - + " \"doc_count\": 38\n" - + " },\n" - + " {\n" - + " \"key\": \"m\",\n" - + " \"doc_count\": 22\n" - + " }\n" - + " ]\n" - + " }\n" - + " }\n" - + " ]\n" - + " }\n" - + "}"; - List> result = SearchAggregationResponseHelper.flatten(AggregationUtils.fromJson(json)); - assertThat(result, containsInAnyOrder( - allOf(hasEntry("age", (Object) 31L), hasEntry("gender","m")), - allOf(hasEntry("age", (Object) 31L), hasEntry("gender","f")), - allOf(hasEntry("age", (Object) 39L), hasEntry("gender","m")), - allOf(hasEntry("age", (Object) 39L), hasEntry("gender","f")))); - } + /** SELECT age, gender FROM accounts GROUP BY age, gender */ + @Test + public void multipleBucketWithoutMetricShouldPass() { + String json = + "{\n" + + " 
\"lterms#age\": {\n" + + " \"buckets\": [\n" + + " {\n" + + " \"key\": 31,\n" + + " \"doc_count\": 61,\n" + + " \"sterms#gender\": {\n" + + " \"buckets\": [\n" + + " {\n" + + " \"key\": \"m\",\n" + + " \"doc_count\": 35\n" + + " },\n" + + " {\n" + + " \"key\": \"f\",\n" + + " \"doc_count\": 26\n" + + " }\n" + + " ]\n" + + " }\n" + + " },\n" + + " {\n" + + " \"key\": 39,\n" + + " \"doc_count\": 60,\n" + + " \"sterms#gender\": {\n" + + " \"buckets\": [\n" + + " {\n" + + " \"key\": \"f\",\n" + + " \"doc_count\": 38\n" + + " },\n" + + " {\n" + + " \"key\": \"m\",\n" + + " \"doc_count\": 22\n" + + " }\n" + + " ]\n" + + " }\n" + + " }\n" + + " ]\n" + + " }\n" + + "}"; + List> result = + SearchAggregationResponseHelper.flatten(AggregationUtils.fromJson(json)); + assertThat( + result, + containsInAnyOrder( + allOf(hasEntry("age", (Object) 31L), hasEntry("gender", "m")), + allOf(hasEntry("age", (Object) 31L), hasEntry("gender", "f")), + allOf(hasEntry("age", (Object) 39L), hasEntry("gender", "m")), + allOf(hasEntry("age", (Object) 39L), hasEntry("gender", "f")))); + } - /** - * SELECT PERCENTILES(age) FROM accounts - */ - @Test - public void noBucketPercentilesShouldPass() { - String json = "{\n" - + " \"percentiles_bucket#age\": {\n" - + " \"values\": {\n" - + " \"1.0\": 20,\n" - + " \"5.0\": 21,\n" - + " \"25.0\": 25,\n" - + " \"50.0\": 30.90909090909091,\n" - + " \"75.0\": 35,\n" - + " \"95.0\": 39,\n" - + " \"99.0\": 40\n" - + " }\n" - + " }\n" - + "}"; - List> result = SearchAggregationResponseHelper.flatten(AggregationUtils.fromJson(json)); - assertThat(result, contains(allOf(hasEntry("age_1.0", 20d)))); - } + /** SELECT PERCENTILES(age) FROM accounts */ + @Test + public void noBucketPercentilesShouldPass() { + String json = + "{\n" + + " \"percentiles_bucket#age\": {\n" + + " \"values\": {\n" + + " \"1.0\": 20,\n" + + " \"5.0\": 21,\n" + + " \"25.0\": 25,\n" + + " \"50.0\": 30.90909090909091,\n" + + " \"75.0\": 35,\n" + + " \"95.0\": 39,\n" + + " \"99.0\": 40\n" + + " }\n" + + " }\n" + + "}"; + List> result = + SearchAggregationResponseHelper.flatten(AggregationUtils.fromJson(json)); + assertThat(result, contains(allOf(hasEntry("age_1.0", 20d)))); + } - /** - * SELECT count(*) from online - * GROUP BY date_histogram('field'='insert_time','interval'='4d','alias'='days') - */ - @Test - public void populateShouldPass() { - String json = "{\n" - + " \"date_histogram#days\": {\n" - + " \"buckets\": [\n" - + " {\n" - + " \"key_as_string\": \"2014-08-14 00:00:00\",\n" - + " \"key\": 1407974400000,\n" - + " \"doc_count\": 477,\n" - + " \"value_count#COUNT_0\": {\n" - + " \"value\": 477\n" - + " }\n" - + " }\n" - + " ]\n" - + " }\n" - + "}"; - List> result = SearchAggregationResponseHelper.flatten(AggregationUtils.fromJson(json)); - assertThat(result, containsInAnyOrder( - allOf(hasEntry("days", (Object) "2014-08-14 00:00:00"), hasEntry("COUNT_0",477d)))); - } + /** + * SELECT count(*) from online GROUP BY + * date_histogram('field'='insert_time','interval'='4d','alias'='days') + */ + @Test + public void populateShouldPass() { + String json = + "{\n" + + " \"date_histogram#days\": {\n" + + " \"buckets\": [\n" + + " {\n" + + " \"key_as_string\": \"2014-08-14 00:00:00\",\n" + + " \"key\": 1407974400000,\n" + + " \"doc_count\": 477,\n" + + " \"value_count#COUNT_0\": {\n" + + " \"value\": 477\n" + + " }\n" + + " }\n" + + " ]\n" + + " }\n" + + "}"; + List> result = + SearchAggregationResponseHelper.flatten(AggregationUtils.fromJson(json)); + assertThat( + result, + containsInAnyOrder( + 
allOf(hasEntry("days", (Object) "2014-08-14 00:00:00"), hasEntry("COUNT_0", 477d)))); + } - /** - * SELECT s - */ - @Test - public void populateSearchAggregationResponeShouldPass() { - String json = "{\n" - + " \"lterms#age\": {\n" - + " \"buckets\": [\n" - + " {\n" - + " \"key\": 31,\n" - + " \"doc_count\": 61,\n" - + " \"sterms#gender\": {\n" - + " \"buckets\": [\n" - + " {\n" - + " \"key\": \"m\",\n" - + " \"doc_count\": 35\n" - + " },\n" - + " {\n" - + " \"key\": \"f\",\n" - + " \"doc_count\": 26\n" - + " }\n" - + " ]\n" - + " }\n" - + " },\n" - + " {\n" - + " \"key\": 39,\n" - + " \"doc_count\": 60,\n" - + " \"sterms#gender\": {\n" - + " \"buckets\": [\n" - + " {\n" - + " \"key\": \"f\",\n" - + " \"doc_count\": 38\n" - + " },\n" - + " {\n" - + " \"key\": \"m\",\n" - + " \"doc_count\": 22\n" - + " }\n" - + " ]\n" - + " }\n" - + " }\n" - + " ]\n" - + " }\n" - + "}"; - List bindingTupleRows = - SearchAggregationResponseHelper.populateSearchAggregationResponse(AggregationUtils.fromJson(json)); - assertEquals(4, bindingTupleRows.size()); - assertThat(bindingTupleRows, containsInAnyOrder( - bindingTupleRow(BindingTuple.from(ImmutableMap.of("age", 31L, "gender", "m"))), - bindingTupleRow(BindingTuple.from(ImmutableMap.of("age", 31L, "gender", "f"))), - bindingTupleRow(BindingTuple.from(ImmutableMap.of("age", 39L, "gender", "m"))), - bindingTupleRow(BindingTuple.from(ImmutableMap.of("age", 39L, "gender", "f"))))); - } + /** SELECT s */ + @Test + public void populateSearchAggregationResponeShouldPass() { + String json = + "{\n" + + " \"lterms#age\": {\n" + + " \"buckets\": [\n" + + " {\n" + + " \"key\": 31,\n" + + " \"doc_count\": 61,\n" + + " \"sterms#gender\": {\n" + + " \"buckets\": [\n" + + " {\n" + + " \"key\": \"m\",\n" + + " \"doc_count\": 35\n" + + " },\n" + + " {\n" + + " \"key\": \"f\",\n" + + " \"doc_count\": 26\n" + + " }\n" + + " ]\n" + + " }\n" + + " },\n" + + " {\n" + + " \"key\": 39,\n" + + " \"doc_count\": 60,\n" + + " \"sterms#gender\": {\n" + + " \"buckets\": [\n" + + " {\n" + + " \"key\": \"f\",\n" + + " \"doc_count\": 38\n" + + " },\n" + + " {\n" + + " \"key\": \"m\",\n" + + " \"doc_count\": 22\n" + + " }\n" + + " ]\n" + + " }\n" + + " }\n" + + " ]\n" + + " }\n" + + "}"; + List bindingTupleRows = + SearchAggregationResponseHelper.populateSearchAggregationResponse( + AggregationUtils.fromJson(json)); + assertEquals(4, bindingTupleRows.size()); + assertThat( + bindingTupleRows, + containsInAnyOrder( + bindingTupleRow(BindingTuple.from(ImmutableMap.of("age", 31L, "gender", "m"))), + bindingTupleRow(BindingTuple.from(ImmutableMap.of("age", 31L, "gender", "f"))), + bindingTupleRow(BindingTuple.from(ImmutableMap.of("age", 39L, "gender", "m"))), + bindingTupleRow(BindingTuple.from(ImmutableMap.of("age", 39L, "gender", "f"))))); + } - private static Matcher bindingTupleRow(BindingTuple bindingTuple) { - return featureValueOf("BindingTuple", equalTo(bindingTuple), BindingTupleRow::data); - } + private static Matcher bindingTupleRow(BindingTuple bindingTuple) { + return featureValueOf("BindingTuple", equalTo(bindingTuple), BindingTupleRow::data); + } } diff --git a/legacy/src/test/java/org/opensearch/sql/legacy/unittest/rewriter/RewriteRuleExecutorTest.java b/legacy/src/test/java/org/opensearch/sql/legacy/unittest/rewriter/RewriteRuleExecutorTest.java index 632cd2d7ea..9c13e1fc71 100644 --- a/legacy/src/test/java/org/opensearch/sql/legacy/unittest/rewriter/RewriteRuleExecutorTest.java +++ b/legacy/src/test/java/org/opensearch/sql/legacy/unittest/rewriter/RewriteRuleExecutorTest.java 
@@ -3,7 +3,6 @@ * SPDX-License-Identifier: Apache-2.0 */ - package org.opensearch.sql.legacy.unittest.rewriter; import static org.mockito.Mockito.never; @@ -23,31 +22,29 @@ @RunWith(MockitoJUnitRunner.class) public class RewriteRuleExecutorTest { - @Mock - private RewriteRule rewriter; - @Mock - private SQLQueryExpr expr; + @Mock private RewriteRule rewriter; + @Mock private SQLQueryExpr expr; - private RewriteRuleExecutor ruleExecutor; + private RewriteRuleExecutor ruleExecutor; - @Before - public void setup() { - ruleExecutor = RewriteRuleExecutor.builder().withRule(rewriter).build(); - } + @Before + public void setup() { + ruleExecutor = RewriteRuleExecutor.builder().withRule(rewriter).build(); + } - @Test - public void optimize() throws SQLFeatureNotSupportedException { - when(rewriter.match(expr)).thenReturn(true); + @Test + public void optimize() throws SQLFeatureNotSupportedException { + when(rewriter.match(expr)).thenReturn(true); - ruleExecutor.executeOn(expr); - verify(rewriter, times(1)).rewrite(expr); - } + ruleExecutor.executeOn(expr); + verify(rewriter, times(1)).rewrite(expr); + } - @Test - public void noOptimize() throws SQLFeatureNotSupportedException { - when(rewriter.match(expr)).thenReturn(false); + @Test + public void noOptimize() throws SQLFeatureNotSupportedException { + when(rewriter.match(expr)).thenReturn(false); - ruleExecutor.executeOn(expr); - verify(rewriter, never()).rewrite(expr); - } + ruleExecutor.executeOn(expr); + verify(rewriter, never()).rewrite(expr); + } } diff --git a/legacy/src/test/java/org/opensearch/sql/legacy/unittest/rewriter/ordinal/OrdinalRewriterRuleTest.java b/legacy/src/test/java/org/opensearch/sql/legacy/unittest/rewriter/ordinal/OrdinalRewriterRuleTest.java index 3f4f799d66..d27967e361 100644 --- a/legacy/src/test/java/org/opensearch/sql/legacy/unittest/rewriter/ordinal/OrdinalRewriterRuleTest.java +++ b/legacy/src/test/java/org/opensearch/sql/legacy/unittest/rewriter/ordinal/OrdinalRewriterRuleTest.java @@ -3,7 +3,6 @@ * SPDX-License-Identifier: Apache-2.0 */ - package org.opensearch.sql.legacy.unittest.rewriter.ordinal; import com.alibaba.druid.sql.SQLUtils; @@ -16,141 +15,139 @@ import org.opensearch.sql.legacy.rewriter.ordinal.OrdinalRewriterRule; import org.opensearch.sql.legacy.util.SqlParserUtils; -/** - * Test cases for ordinal aliases in GROUP BY and ORDER BY - */ - +/** Test cases for ordinal aliases in GROUP BY and ORDER BY */ public class OrdinalRewriterRuleTest { - @Rule - public ExpectedException exception = ExpectedException.none(); - - @Test - public void ordinalInGroupByShouldMatch() { - query("SELECT lastname FROM bank GROUP BY 1").shouldMatchRule(); - } - - @Test - public void ordinalInOrderByShouldMatch() { - query("SELECT lastname FROM bank ORDER BY 1").shouldMatchRule(); - } - - @Test - public void ordinalInGroupAndOrderByShouldMatch() { - query("SELECT lastname, age FROM bank GROUP BY 2, 1 ORDER BY 1").shouldMatchRule(); - } - - @Test - public void noOrdinalInGroupByShouldNotMatch() { - query("SELECT lastname FROM bank GROUP BY lastname").shouldNotMatchRule(); + @Rule public ExpectedException exception = ExpectedException.none(); + + @Test + public void ordinalInGroupByShouldMatch() { + query("SELECT lastname FROM bank GROUP BY 1").shouldMatchRule(); + } + + @Test + public void ordinalInOrderByShouldMatch() { + query("SELECT lastname FROM bank ORDER BY 1").shouldMatchRule(); + } + + @Test + public void ordinalInGroupAndOrderByShouldMatch() { + query("SELECT lastname, age FROM bank GROUP BY 2, 1 ORDER BY 
1").shouldMatchRule(); + } + + @Test + public void noOrdinalInGroupByShouldNotMatch() { + query("SELECT lastname FROM bank GROUP BY lastname").shouldNotMatchRule(); + } + + @Test + public void noOrdinalInOrderByShouldNotMatch() { + query("SELECT lastname, age FROM bank ORDER BY age").shouldNotMatchRule(); + } + + @Test + public void noOrdinalInGroupAndOrderByShouldNotMatch() { + query("SELECT lastname, age FROM bank GROUP BY lastname, age ORDER BY age") + .shouldNotMatchRule(); + } + + @Test + public void simpleGroupByOrdinal() { + query("SELECT lastname FROM bank GROUP BY 1") + .shouldBeAfterRewrite("SELECT lastname FROM bank GROUP BY lastname"); + } + + @Test + public void multipleGroupByOrdinal() { + query("SELECT lastname, age FROM bank GROUP BY 1, 2 ") + .shouldBeAfterRewrite("SELECT lastname, age FROM bank GROUP BY lastname, age"); + + query("SELECT lastname, age FROM bank GROUP BY 2, 1") + .shouldBeAfterRewrite("SELECT lastname, age FROM bank GROUP BY age, lastname"); + + query("SELECT lastname, age, firstname FROM bank GROUP BY 2, firstname, 1") + .shouldBeAfterRewrite( + "SELECT lastname, age, firstname FROM bank GROUP BY age, firstname, lastname"); + + query("SELECT lastname, age, firstname FROM bank GROUP BY 2, something, 1") + .shouldBeAfterRewrite( + "SELECT lastname, age, firstname FROM bank GROUP BY age, something, lastname"); + } + + @Test + public void simpleOrderByOrdinal() { + query("SELECT lastname FROM bank ORDER BY 1") + .shouldBeAfterRewrite("SELECT lastname FROM bank ORDER BY lastname"); + } + + @Test + public void multipleOrderByOrdinal() { + query("SELECT lastname, age FROM bank ORDER BY 1, 2 ") + .shouldBeAfterRewrite("SELECT lastname, age FROM bank ORDER BY lastname, age"); + + query("SELECT lastname, age FROM bank ORDER BY 2, 1") + .shouldBeAfterRewrite("SELECT lastname, age FROM bank ORDER BY age, lastname"); + + query("SELECT lastname, age, firstname FROM bank ORDER BY 2, firstname, 1") + .shouldBeAfterRewrite( + "SELECT lastname, age, firstname FROM bank ORDER BY age, firstname, lastname"); + + query("SELECT lastname, age, firstname FROM bank ORDER BY 2, department, 1") + .shouldBeAfterRewrite( + "SELECT lastname, age, firstname FROM bank ORDER BY age, department, lastname"); + } + + // Tests invalid Ordinals, non-positive ordinal values are already validated by semantic analyzer + @Test + public void invalidGroupByOrdinalShouldThrowException() { + exception.expect(VerificationException.class); + exception.expectMessage("Invalid ordinal [3] specified in [GROUP BY 3]"); + query("SELECT lastname, MAX(lastname) FROM bank GROUP BY 3 ").rewrite(); + } + + @Test + public void invalidOrderByOrdinalShouldThrowException() { + exception.expect(VerificationException.class); + exception.expectMessage("Invalid ordinal [4] specified in [ORDER BY 4]"); + query("SELECT `lastname`, `age`, `firstname` FROM bank ORDER BY 4 IS NOT NULL").rewrite(); + } + + private QueryAssertion query(String sql) { + return new QueryAssertion(sql); + } + + private static class QueryAssertion { + + private OrdinalRewriterRule rule; + private SQLQueryExpr expr; + + QueryAssertion(String sql) { + this.expr = SqlParserUtils.parse(sql); + this.rule = new OrdinalRewriterRule(sql); } - @Test - public void noOrdinalInOrderByShouldNotMatch() { - query("SELECT lastname, age FROM bank ORDER BY age").shouldNotMatchRule(); + void shouldBeAfterRewrite(String expected) { + shouldMatchRule(); + rule.rewrite(expr); + Assert.assertEquals( + SQLUtils.toMySqlString(SqlParserUtils.parse(expected)), 
SQLUtils.toMySqlString(expr)); } - @Test - public void noOrdinalInGroupAndOrderByShouldNotMatch() { - query("SELECT lastname, age FROM bank GROUP BY lastname, age ORDER BY age").shouldNotMatchRule(); + void shouldMatchRule() { + Assert.assertTrue(match()); } - @Test - public void simpleGroupByOrdinal() { - query("SELECT lastname FROM bank GROUP BY 1" - ).shouldBeAfterRewrite("SELECT lastname FROM bank GROUP BY lastname"); + void shouldNotMatchRule() { + Assert.assertFalse(match()); } - @Test - public void multipleGroupByOrdinal() { - query("SELECT lastname, age FROM bank GROUP BY 1, 2 " - ).shouldBeAfterRewrite("SELECT lastname, age FROM bank GROUP BY lastname, age"); - - query("SELECT lastname, age FROM bank GROUP BY 2, 1" - ).shouldBeAfterRewrite("SELECT lastname, age FROM bank GROUP BY age, lastname"); - - query("SELECT lastname, age, firstname FROM bank GROUP BY 2, firstname, 1" - ).shouldBeAfterRewrite("SELECT lastname, age, firstname FROM bank GROUP BY age, firstname, lastname"); - - query("SELECT lastname, age, firstname FROM bank GROUP BY 2, something, 1" - ).shouldBeAfterRewrite("SELECT lastname, age, firstname FROM bank GROUP BY age, something, lastname"); - + void rewrite() { + shouldMatchRule(); + rule.rewrite(expr); } - @Test - public void simpleOrderByOrdinal() { - query("SELECT lastname FROM bank ORDER BY 1" - ).shouldBeAfterRewrite("SELECT lastname FROM bank ORDER BY lastname"); - } - - @Test - public void multipleOrderByOrdinal() { - query("SELECT lastname, age FROM bank ORDER BY 1, 2 " - ).shouldBeAfterRewrite("SELECT lastname, age FROM bank ORDER BY lastname, age"); - - query("SELECT lastname, age FROM bank ORDER BY 2, 1" - ).shouldBeAfterRewrite("SELECT lastname, age FROM bank ORDER BY age, lastname"); - - query("SELECT lastname, age, firstname FROM bank ORDER BY 2, firstname, 1" - ).shouldBeAfterRewrite("SELECT lastname, age, firstname FROM bank ORDER BY age, firstname, lastname"); - - query("SELECT lastname, age, firstname FROM bank ORDER BY 2, department, 1" - ).shouldBeAfterRewrite("SELECT lastname, age, firstname FROM bank ORDER BY age, department, lastname"); - } - - // Tests invalid Ordinals, non-positive ordinal values are already validated by semantic analyzer - @Test - public void invalidGroupByOrdinalShouldThrowException() { - exception.expect(VerificationException.class); - exception.expectMessage("Invalid ordinal [3] specified in [GROUP BY 3]"); - query("SELECT lastname, MAX(lastname) FROM bank GROUP BY 3 ").rewrite(); - } - - @Test - public void invalidOrderByOrdinalShouldThrowException() { - exception.expect(VerificationException.class); - exception.expectMessage("Invalid ordinal [4] specified in [ORDER BY 4]"); - query("SELECT `lastname`, `age`, `firstname` FROM bank ORDER BY 4 IS NOT NULL").rewrite(); - } - - - private QueryAssertion query(String sql) { - return new QueryAssertion(sql); - } - private static class QueryAssertion { - - private OrdinalRewriterRule rule; - private SQLQueryExpr expr; - - QueryAssertion(String sql) { - this.expr = SqlParserUtils.parse(sql); - this.rule = new OrdinalRewriterRule(sql); - } - - void shouldBeAfterRewrite(String expected) { - shouldMatchRule(); - rule.rewrite(expr); - Assert.assertEquals( - SQLUtils.toMySqlString(SqlParserUtils.parse(expected)), - SQLUtils.toMySqlString(expr) - ); - } - - void shouldMatchRule() { - Assert.assertTrue(match()); - } - - void shouldNotMatchRule() { - Assert.assertFalse(match()); - } - - void rewrite() { - shouldMatchRule(); - rule.rewrite(expr); - } - - private boolean match() { - 
return rule.match(expr); - } + private boolean match() { + return rule.match(expr); } + } } diff --git a/legacy/src/test/java/org/opensearch/sql/legacy/unittest/rewriter/subquery/NestedQueryContextTest.java b/legacy/src/test/java/org/opensearch/sql/legacy/unittest/rewriter/subquery/NestedQueryContextTest.java index a94b3e6112..3e20e8edf6 100644 --- a/legacy/src/test/java/org/opensearch/sql/legacy/unittest/rewriter/subquery/NestedQueryContextTest.java +++ b/legacy/src/test/java/org/opensearch/sql/legacy/unittest/rewriter/subquery/NestedQueryContextTest.java @@ -3,7 +3,6 @@ * SPDX-License-Identifier: Apache-2.0 */ - package org.opensearch.sql.legacy.unittest.rewriter.subquery; import static org.junit.Assert.assertFalse; @@ -16,42 +15,48 @@ import org.junit.Test; import org.opensearch.sql.legacy.rewriter.subquery.NestedQueryContext; - public class NestedQueryContextTest { - @Test - public void isNested() { - NestedQueryContext nestedQueryDetector = new NestedQueryContext(); - nestedQueryDetector.add(new SQLExprTableSource(new SQLIdentifierExpr("employee"), "e")); - - assertFalse(nestedQueryDetector.isNested(new SQLExprTableSource(new SQLIdentifierExpr("e"), "e1"))); - assertTrue(nestedQueryDetector.isNested(new SQLExprTableSource(new SQLIdentifierExpr("e.projects"), "p"))); - - nestedQueryDetector.add(new SQLExprTableSource(new SQLIdentifierExpr("e.projects"), "p")); - assertTrue(nestedQueryDetector.isNested(new SQLExprTableSource(new SQLIdentifierExpr("p")))); - } - - @Test - public void isNestedJoin() { - NestedQueryContext nestedQueryDetector = new NestedQueryContext(); - SQLJoinTableSource joinTableSource = new SQLJoinTableSource(); - joinTableSource.setLeft(new SQLExprTableSource(new SQLIdentifierExpr("employee"), "e")); - joinTableSource.setRight(new SQLExprTableSource(new SQLIdentifierExpr("e.projects"), "p")); - joinTableSource.setJoinType(JoinType.COMMA); - nestedQueryDetector.add(joinTableSource); - - assertFalse(nestedQueryDetector.isNested(new SQLExprTableSource(new SQLIdentifierExpr("e"), "e1"))); - assertTrue(nestedQueryDetector.isNested(new SQLExprTableSource(new SQLIdentifierExpr("e.projects"), "p"))); - assertTrue(nestedQueryDetector.isNested(new SQLExprTableSource(new SQLIdentifierExpr("p")))); - } - - @Test - public void notNested() { - NestedQueryContext nestedQueryDetector = new NestedQueryContext(); - nestedQueryDetector.add(new SQLExprTableSource(new SQLIdentifierExpr("employee"), "e")); - nestedQueryDetector.add(new SQLExprTableSource(new SQLIdentifierExpr("projects"), "p")); - - assertFalse(nestedQueryDetector.isNested(new SQLExprTableSource(new SQLIdentifierExpr("e"), "e1"))); - assertFalse(nestedQueryDetector.isNested(new SQLExprTableSource(new SQLIdentifierExpr("p")))); - } + @Test + public void isNested() { + NestedQueryContext nestedQueryDetector = new NestedQueryContext(); + nestedQueryDetector.add(new SQLExprTableSource(new SQLIdentifierExpr("employee"), "e")); + + assertFalse( + nestedQueryDetector.isNested(new SQLExprTableSource(new SQLIdentifierExpr("e"), "e1"))); + assertTrue( + nestedQueryDetector.isNested( + new SQLExprTableSource(new SQLIdentifierExpr("e.projects"), "p"))); + + nestedQueryDetector.add(new SQLExprTableSource(new SQLIdentifierExpr("e.projects"), "p")); + assertTrue(nestedQueryDetector.isNested(new SQLExprTableSource(new SQLIdentifierExpr("p")))); + } + + @Test + public void isNestedJoin() { + NestedQueryContext nestedQueryDetector = new NestedQueryContext(); + SQLJoinTableSource joinTableSource = new SQLJoinTableSource(); + 
joinTableSource.setLeft(new SQLExprTableSource(new SQLIdentifierExpr("employee"), "e")); + joinTableSource.setRight(new SQLExprTableSource(new SQLIdentifierExpr("e.projects"), "p")); + joinTableSource.setJoinType(JoinType.COMMA); + nestedQueryDetector.add(joinTableSource); + + assertFalse( + nestedQueryDetector.isNested(new SQLExprTableSource(new SQLIdentifierExpr("e"), "e1"))); + assertTrue( + nestedQueryDetector.isNested( + new SQLExprTableSource(new SQLIdentifierExpr("e.projects"), "p"))); + assertTrue(nestedQueryDetector.isNested(new SQLExprTableSource(new SQLIdentifierExpr("p")))); + } + + @Test + public void notNested() { + NestedQueryContext nestedQueryDetector = new NestedQueryContext(); + nestedQueryDetector.add(new SQLExprTableSource(new SQLIdentifierExpr("employee"), "e")); + nestedQueryDetector.add(new SQLExprTableSource(new SQLIdentifierExpr("projects"), "p")); + + assertFalse( + nestedQueryDetector.isNested(new SQLExprTableSource(new SQLIdentifierExpr("e"), "e1"))); + assertFalse(nestedQueryDetector.isNested(new SQLExprTableSource(new SQLIdentifierExpr("p")))); + } } diff --git a/legacy/src/test/java/org/opensearch/sql/legacy/unittest/utils/PrettyFormatterTest.java b/legacy/src/test/java/org/opensearch/sql/legacy/unittest/utils/PrettyFormatterTest.java index f876b14110..68ad891020 100644 --- a/legacy/src/test/java/org/opensearch/sql/legacy/unittest/utils/PrettyFormatterTest.java +++ b/legacy/src/test/java/org/opensearch/sql/legacy/unittest/utils/PrettyFormatterTest.java @@ -3,7 +3,6 @@ * SPDX-License-Identifier: Apache-2.0 */ - package org.opensearch.sql.legacy.unittest.utils; import static org.hamcrest.MatcherAssert.assertThat; @@ -19,42 +18,45 @@ public class PrettyFormatterTest { - @Test - public void assertFormatterWithoutContentInside() throws IOException { - String noContentInput = "{ }"; - String expectedOutput = "{ }"; - String result = JsonPrettyFormatter.format(noContentInput); - assertThat(result, equalTo(expectedOutput)); - } - - @Test - public void assertFormatterOutputsPrettyJson() throws IOException { - String explainFormattedPrettyFilePath = TestUtils.getResourceFilePath( - "/src/test/resources/expectedOutput/explain_format_pretty.json"); - String explainFormattedPretty = Files.toString(new File(explainFormattedPrettyFilePath), StandardCharsets.UTF_8) - .replaceAll("\r", ""); - - String explainFormattedOnelineFilePath = TestUtils.getResourceFilePath( - "/src/test/resources/explain_format_oneline.json"); - String explainFormattedOneline = Files.toString(new File(explainFormattedOnelineFilePath), StandardCharsets.UTF_8) - .replaceAll("\r", ""); - String result = JsonPrettyFormatter.format(explainFormattedOneline); - - assertThat(result, equalTo(explainFormattedPretty)); - } - - @Test(expected = IOException.class) - public void illegalInputOfNull() throws IOException { - JsonPrettyFormatter.format(""); - } - - @Test(expected = IOException.class) - public void illegalInputOfUnpairedBrace() throws IOException { - JsonPrettyFormatter.format("{\"key\" : \"value\""); - } - - @Test(expected = IOException.class) - public void illegalInputOfWrongBraces() throws IOException { - JsonPrettyFormatter.format("<\"key\" : \"value\">"); - } + @Test + public void assertFormatterWithoutContentInside() throws IOException { + String noContentInput = "{ }"; + String expectedOutput = "{ }"; + String result = JsonPrettyFormatter.format(noContentInput); + assertThat(result, equalTo(expectedOutput)); + } + + @Test + public void assertFormatterOutputsPrettyJson() throws IOException { + 
String explainFormattedPrettyFilePath = + TestUtils.getResourceFilePath( + "/src/test/resources/expectedOutput/explain_format_pretty.json"); + String explainFormattedPretty = + Files.toString(new File(explainFormattedPrettyFilePath), StandardCharsets.UTF_8) + .replaceAll("\r", ""); + + String explainFormattedOnelineFilePath = + TestUtils.getResourceFilePath("/src/test/resources/explain_format_oneline.json"); + String explainFormattedOneline = + Files.toString(new File(explainFormattedOnelineFilePath), StandardCharsets.UTF_8) + .replaceAll("\r", ""); + String result = JsonPrettyFormatter.format(explainFormattedOneline); + + assertThat(result, equalTo(explainFormattedPretty)); + } + + @Test(expected = IOException.class) + public void illegalInputOfNull() throws IOException { + JsonPrettyFormatter.format(""); + } + + @Test(expected = IOException.class) + public void illegalInputOfUnpairedBrace() throws IOException { + JsonPrettyFormatter.format("{\"key\" : \"value\""); + } + + @Test(expected = IOException.class) + public void illegalInputOfWrongBraces() throws IOException { + JsonPrettyFormatter.format("<\"key\" : \"value\">"); + } } diff --git a/legacy/src/test/java/org/opensearch/sql/legacy/unittest/utils/QueryContextTest.java b/legacy/src/test/java/org/opensearch/sql/legacy/unittest/utils/QueryContextTest.java index 55b78af0d7..5dbda8cb92 100644 --- a/legacy/src/test/java/org/opensearch/sql/legacy/unittest/utils/QueryContextTest.java +++ b/legacy/src/test/java/org/opensearch/sql/legacy/unittest/utils/QueryContextTest.java @@ -3,7 +3,6 @@ * SPDX-License-Identifier: Apache-2.0 */ - package org.opensearch.sql.legacy.unittest.utils; import static org.hamcrest.Matchers.equalTo; @@ -18,56 +17,57 @@ public class QueryContextTest { - private static final String REQUEST_ID_KEY = "request_id"; + private static final String REQUEST_ID_KEY = "request_id"; - @After - public void cleanUpContext() { + @After + public void cleanUpContext() { - ThreadContext.clearMap(); - } + ThreadContext.clearMap(); + } - @Test - public void addRequestId() { + @Test + public void addRequestId() { - Assert.assertNull(ThreadContext.get(REQUEST_ID_KEY)); - QueryContext.addRequestId(); - final String requestId = ThreadContext.get(REQUEST_ID_KEY); - Assert.assertNotNull(requestId); - } + Assert.assertNull(ThreadContext.get(REQUEST_ID_KEY)); + QueryContext.addRequestId(); + final String requestId = ThreadContext.get(REQUEST_ID_KEY); + Assert.assertNotNull(requestId); + } - @Test - public void addRequestId_alreadyExists() { + @Test + public void addRequestId_alreadyExists() { - QueryContext.addRequestId(); - final String requestId = ThreadContext.get(REQUEST_ID_KEY); - QueryContext.addRequestId(); - final String requestId2 = ThreadContext.get(REQUEST_ID_KEY); - Assert.assertThat(requestId2, not(equalTo(requestId))); - } + QueryContext.addRequestId(); + final String requestId = ThreadContext.get(REQUEST_ID_KEY); + QueryContext.addRequestId(); + final String requestId2 = ThreadContext.get(REQUEST_ID_KEY); + Assert.assertThat(requestId2, not(equalTo(requestId))); + } - @Test - public void getRequestId_doesNotExist() { - assertNotNull(QueryContext.getRequestId()); - } + @Test + public void getRequestId_doesNotExist() { + assertNotNull(QueryContext.getRequestId()); + } - @Test - public void getRequestId() { + @Test + public void getRequestId() { - final String test_request_id = "test_id_111"; - ThreadContext.put(REQUEST_ID_KEY, test_request_id); - final String requestId = QueryContext.getRequestId(); - Assert.assertThat(requestId, 
equalTo(test_request_id)); - } + final String test_request_id = "test_id_111"; + ThreadContext.put(REQUEST_ID_KEY, test_request_id); + final String requestId = QueryContext.getRequestId(); + Assert.assertThat(requestId, equalTo(test_request_id)); + } - @Test - public void withCurrentContext() throws InterruptedException { + @Test + public void withCurrentContext() throws InterruptedException { - Runnable task = () -> { - Assert.assertTrue(ThreadContext.containsKey("test11")); - Assert.assertTrue(ThreadContext.containsKey("test22")); + Runnable task = + () -> { + Assert.assertTrue(ThreadContext.containsKey("test11")); + Assert.assertTrue(ThreadContext.containsKey("test22")); }; - ThreadContext.put("test11", "value11"); - ThreadContext.put("test22", "value11"); - new Thread(QueryContext.withCurrentContext(task)).join(); - } + ThreadContext.put("test11", "value11"); + ThreadContext.put("test22", "value11"); + new Thread(QueryContext.withCurrentContext(task)).join(); + } } diff --git a/legacy/src/test/java/org/opensearch/sql/legacy/unittest/utils/QueryDataAnonymizerTest.java b/legacy/src/test/java/org/opensearch/sql/legacy/unittest/utils/QueryDataAnonymizerTest.java index ca95b547a9..073fec61e7 100644 --- a/legacy/src/test/java/org/opensearch/sql/legacy/unittest/utils/QueryDataAnonymizerTest.java +++ b/legacy/src/test/java/org/opensearch/sql/legacy/unittest/utils/QueryDataAnonymizerTest.java @@ -3,7 +3,6 @@ * SPDX-License-Identifier: Apache-2.0 */ - package org.opensearch.sql.legacy.unittest.utils; import org.junit.Assert; @@ -12,78 +11,84 @@ public class QueryDataAnonymizerTest { - @Test - public void queriesShouldHaveAnonymousFieldAndIndex() { - String query = "SELECT ABS(balance) FROM accounts WHERE age > 30 GROUP BY ABS(balance)"; - String expectedQuery = "( SELECT ABS(identifier) FROM table WHERE identifier > number GROUP BY ABS(identifier) )"; - Assert.assertEquals(expectedQuery, QueryDataAnonymizer.anonymizeData(query)); - } + @Test + public void queriesShouldHaveAnonymousFieldAndIndex() { + String query = "SELECT ABS(balance) FROM accounts WHERE age > 30 GROUP BY ABS(balance)"; + String expectedQuery = + "( SELECT ABS(identifier) FROM table WHERE identifier > number GROUP BY ABS(identifier) )"; + Assert.assertEquals(expectedQuery, QueryDataAnonymizer.anonymizeData(query)); + } - @Test - public void queriesShouldAnonymousNumbers() { - String query = "SELECT ABS(20), LOG(20.20) FROM accounts"; - String expectedQuery = "( SELECT ABS(number), LOG(number) FROM table )"; - Assert.assertEquals(expectedQuery, QueryDataAnonymizer.anonymizeData(query)); - } + @Test + public void queriesShouldAnonymousNumbers() { + String query = "SELECT ABS(20), LOG(20.20) FROM accounts"; + String expectedQuery = "( SELECT ABS(number), LOG(number) FROM table )"; + Assert.assertEquals(expectedQuery, QueryDataAnonymizer.anonymizeData(query)); + } - @Test - public void queriesShouldHaveAnonymousBooleanLiterals() { - String query = "SELECT TRUE FROM accounts"; - String expectedQuery = "( SELECT boolean_literal FROM table )"; - Assert.assertEquals(expectedQuery, QueryDataAnonymizer.anonymizeData(query)); - } + @Test + public void queriesShouldHaveAnonymousBooleanLiterals() { + String query = "SELECT TRUE FROM accounts"; + String expectedQuery = "( SELECT boolean_literal FROM table )"; + Assert.assertEquals(expectedQuery, QueryDataAnonymizer.anonymizeData(query)); + } - @Test - public void queriesShouldHaveAnonymousInputStrings() { - String query = "SELECT * FROM accounts WHERE name = 'Oliver'"; - String expectedQuery 
= "( SELECT * FROM table WHERE identifier = 'string_literal' )"; - Assert.assertEquals(expectedQuery, QueryDataAnonymizer.anonymizeData(query)); - } + @Test + public void queriesShouldHaveAnonymousInputStrings() { + String query = "SELECT * FROM accounts WHERE name = 'Oliver'"; + String expectedQuery = "( SELECT * FROM table WHERE identifier = 'string_literal' )"; + Assert.assertEquals(expectedQuery, QueryDataAnonymizer.anonymizeData(query)); + } - @Test - public void queriesWithAliasesShouldAnonymizeSensitiveData() { - String query = "SELECT balance AS b FROM accounts AS a"; - String expectedQuery = "( SELECT identifier AS b FROM table a )"; - Assert.assertEquals(expectedQuery, QueryDataAnonymizer.anonymizeData(query)); - } + @Test + public void queriesWithAliasesShouldAnonymizeSensitiveData() { + String query = "SELECT balance AS b FROM accounts AS a"; + String expectedQuery = "( SELECT identifier AS b FROM table a )"; + Assert.assertEquals(expectedQuery, QueryDataAnonymizer.anonymizeData(query)); + } - @Test - public void queriesWithFunctionsShouldAnonymizeSensitiveData() { - String query = "SELECT LTRIM(firstname) FROM accounts"; - String expectedQuery = "( SELECT LTRIM(identifier) FROM table )"; - Assert.assertEquals(expectedQuery, QueryDataAnonymizer.anonymizeData(query)); - } + @Test + public void queriesWithFunctionsShouldAnonymizeSensitiveData() { + String query = "SELECT LTRIM(firstname) FROM accounts"; + String expectedQuery = "( SELECT LTRIM(identifier) FROM table )"; + Assert.assertEquals(expectedQuery, QueryDataAnonymizer.anonymizeData(query)); + } - @Test - public void queriesWithAggregatesShouldAnonymizeSensitiveData() { - String query = "SELECT MAX(price) - MIN(price) from tickets"; - String expectedQuery = "( SELECT MAX(identifier) - MIN(identifier) FROM table )"; - Assert.assertEquals(expectedQuery, QueryDataAnonymizer.anonymizeData(query)); - } + @Test + public void queriesWithAggregatesShouldAnonymizeSensitiveData() { + String query = "SELECT MAX(price) - MIN(price) from tickets"; + String expectedQuery = "( SELECT MAX(identifier) - MIN(identifier) FROM table )"; + Assert.assertEquals(expectedQuery, QueryDataAnonymizer.anonymizeData(query)); + } - @Test - public void queriesWithSubqueriesShouldAnonymizeSensitiveData() { - String query = "SELECT a.f, a.l, a.a FROM " + - "(SELECT firstname AS f, lastname AS l, age AS a FROM accounts WHERE age > 30) a"; - String expectedQuery = "( SELECT identifier, identifier, identifier FROM (SELECT identifier AS f, " + - "identifier AS l, identifier AS a FROM table WHERE identifier > number ) a )"; - Assert.assertEquals(expectedQuery, QueryDataAnonymizer.anonymizeData(query)); - } + @Test + public void queriesWithSubqueriesShouldAnonymizeSensitiveData() { + String query = + "SELECT a.f, a.l, a.a FROM " + + "(SELECT firstname AS f, lastname AS l, age AS a FROM accounts WHERE age > 30) a"; + String expectedQuery = + "( SELECT identifier, identifier, identifier FROM (SELECT identifier AS f, " + + "identifier AS l, identifier AS a FROM table WHERE identifier > number ) a )"; + Assert.assertEquals(expectedQuery, QueryDataAnonymizer.anonymizeData(query)); + } - @Test - public void joinQueriesShouldAnonymizeSensitiveData() { - String query = "SELECT a.account_number, a.firstname, a.lastname, e.id, e.name " + - "FROM accounts a JOIN employees e"; - String expectedQuery = "( SELECT identifier, identifier, identifier, identifier, identifier " + - "FROM table a JOIN table e )"; - Assert.assertEquals(expectedQuery, 
QueryDataAnonymizer.anonymizeData(query)); - } + @Test + public void joinQueriesShouldAnonymizeSensitiveData() { + String query = + "SELECT a.account_number, a.firstname, a.lastname, e.id, e.name " + + "FROM accounts a JOIN employees e"; + String expectedQuery = + "( SELECT identifier, identifier, identifier, identifier, identifier " + + "FROM table a JOIN table e )"; + Assert.assertEquals(expectedQuery, QueryDataAnonymizer.anonymizeData(query)); + } - @Test - public void unionQueriesShouldAnonymizeSensitiveData() { - String query = "SELECT name, age FROM accounts UNION SELECT name, age FROM employees"; - String expectedQuery = "( SELECT identifier, identifier FROM table " + - "UNION SELECT identifier, identifier FROM table )"; - Assert.assertEquals(expectedQuery, QueryDataAnonymizer.anonymizeData(query)); - } + @Test + public void unionQueriesShouldAnonymizeSensitiveData() { + String query = "SELECT name, age FROM accounts UNION SELECT name, age FROM employees"; + String expectedQuery = + "( SELECT identifier, identifier FROM table " + + "UNION SELECT identifier, identifier FROM table )"; + Assert.assertEquals(expectedQuery, QueryDataAnonymizer.anonymizeData(query)); + } } diff --git a/opensearch/src/test/java/org/opensearch/sql/opensearch/response/OpenSearchAggregationResponseParserTest.java b/opensearch/src/test/java/org/opensearch/sql/opensearch/response/OpenSearchAggregationResponseParserTest.java index cd915cf5e5..1a15e57c55 100644 --- a/opensearch/src/test/java/org/opensearch/sql/opensearch/response/OpenSearchAggregationResponseParserTest.java +++ b/opensearch/src/test/java/org/opensearch/sql/opensearch/response/OpenSearchAggregationResponseParserTest.java @@ -36,7 +36,7 @@ class OpenSearchAggregationResponseParserTest { /** SELECT MAX(age) as max FROM accounts. */ @Test void no_bucket_one_metric_should_pass() { - String response = "{\n \"max#max\": {\n \"value\": 40\n }\n}"; + String response = "{\n" + " \"max#max\": {\n" + " \"value\": 40\n" + " }\n" + "}"; NoBucketAggregationParser parser = new NoBucketAggregationParser(new SingleValueParser("max")); assertThat(parse(parser, response), contains(entry("max", 40d))); } @@ -140,7 +140,8 @@ void two_bucket_one_metric_should_pass() { @Test void unsupported_aggregation_should_fail() { - String response = "{\n \"date_histogram#date_histogram\": {\n \"value\": 40\n }\n}"; + String response = + "{\n" + " \"date_histogram#date_histogram\": {\n" + " \"value\": 40\n" + " }\n" + "}"; NoBucketAggregationParser parser = new NoBucketAggregationParser(new SingleValueParser("max")); RuntimeException exception = assertThrows(RuntimeException.class, () -> parse(parser, response));