diff --git a/async-query-core/src/main/java/org/opensearch/sql/spark/validator/FunctionType.java b/async-query-core/src/main/java/org/opensearch/sql/spark/validator/FunctionType.java index da3760efd6..71db63dbcc 100644 --- a/async-query-core/src/main/java/org/opensearch/sql/spark/validator/FunctionType.java +++ b/async-query-core/src/main/java/org/opensearch/sql/spark/validator/FunctionType.java @@ -32,6 +32,7 @@ public enum FunctionType { CSV("CSV"), MISC("Misc"), GENERATOR("Generator"), + OTHER("Other"), UDF("User Defined Function"); private final String name; @@ -422,6 +423,51 @@ public enum FunctionType { "posexplode", "posexplode_outer", "stack")) + .put( + OTHER, + Set.of( + "aggregate", + "array_size", + "array_sort", + "cardinality", + "crc32", + "exists", + "filter", + "forall", + "hash", + "ilike", + "in", + "like", + "map_filter", + "map_zip_with", + "md5", + "mod", + "named_struct", + "parse_url", + "raise_error", + "reduce", + "reverse", + "sha", + "sha1", + "sha2", + "size", + "struct", + "transform", + "transform_keys", + "transform_values", + "url_decode", + "url_encode", + "xpath", + "xpath_boolean", + "xpath_double", + "xpath_float", + "xpath_int", + "xpath_long", + "xpath_number", + "xpath_short", + "xpath_string", + "xxhash64", + "zip_with")) .build(); private static final Map FUNCTION_NAME_TO_FUNCTION_TYPE_MAP = diff --git a/async-query-core/src/main/java/org/opensearch/sql/spark/validator/SQLGrammarElement.java b/async-query-core/src/main/java/org/opensearch/sql/spark/validator/SQLGrammarElement.java index ef3e1f2c8c..9cabfd7d9e 100644 --- a/async-query-core/src/main/java/org/opensearch/sql/spark/validator/SQLGrammarElement.java +++ b/async-query-core/src/main/java/org/opensearch/sql/spark/validator/SQLGrammarElement.java @@ -78,6 +78,7 @@ public enum SQLGrammarElement implements GrammarElement { CSV_FUNCTIONS("CSV functions"), GENERATOR_FUNCTIONS("Generator functions"), MISC_FUNCTIONS("Misc functions"), + OTHER_FUNCTIONS("Other functions"), // UDF UDF("User Defined functions"); diff --git a/async-query-core/src/main/java/org/opensearch/sql/spark/validator/SQLQueryValidationVisitor.java b/async-query-core/src/main/java/org/opensearch/sql/spark/validator/SQLQueryValidationVisitor.java index 10fc48727a..b336ef4605 100644 --- a/async-query-core/src/main/java/org/opensearch/sql/spark/validator/SQLQueryValidationVisitor.java +++ b/async-query-core/src/main/java/org/opensearch/sql/spark/validator/SQLQueryValidationVisitor.java @@ -560,26 +560,30 @@ public Void visitFunctionName(FunctionNameContext ctx) { return super.visitFunctionName(ctx); } - private void validateFunctionAllowed(String function) { - FunctionType type = FunctionType.fromFunctionName(function.toLowerCase()); + private void validateFunctionAllowed(String functionName) { + String lowerCaseFunctionName = functionName.toLowerCase(); + FunctionType type = FunctionType.fromFunctionName(lowerCaseFunctionName); switch (type) { case MAP: - validateAllowed(SQLGrammarElement.MAP_FUNCTIONS); + validateAllowed(SQLGrammarElement.MAP_FUNCTIONS, lowerCaseFunctionName); break; case BITWISE: - validateAllowed(SQLGrammarElement.BITWISE_FUNCTIONS); + validateAllowed(SQLGrammarElement.BITWISE_FUNCTIONS, lowerCaseFunctionName); break; case CSV: - validateAllowed(SQLGrammarElement.CSV_FUNCTIONS); + validateAllowed(SQLGrammarElement.CSV_FUNCTIONS, lowerCaseFunctionName); break; case MISC: - validateAllowed(SQLGrammarElement.MISC_FUNCTIONS); + validateAllowed(SQLGrammarElement.MISC_FUNCTIONS, lowerCaseFunctionName); break; case GENERATOR: - validateAllowed(SQLGrammarElement.GENERATOR_FUNCTIONS); + validateAllowed(SQLGrammarElement.GENERATOR_FUNCTIONS, lowerCaseFunctionName); + break; + case OTHER: + validateAllowed(SQLGrammarElement.OTHER_FUNCTIONS, lowerCaseFunctionName); break; case UDF: - validateAllowed(SQLGrammarElement.UDF); + validateAllowed(SQLGrammarElement.UDF, lowerCaseFunctionName); break; } } @@ -590,6 +594,12 @@ private void validateAllowed(SQLGrammarElement element) { } } + private void validateAllowed(SQLGrammarElement element, String detail) { + if (!grammarElementValidator.isValid(element)) { + throw new IllegalArgumentException(String.format("%s (%s) is not allowed.", element, detail)); + } + } + @Override public Void visitErrorCapturingIdentifier(ErrorCapturingIdentifierContext ctx) { ErrorCapturingIdentifierExtraContext extra = ctx.errorCapturingIdentifierExtra(); diff --git a/async-query-core/src/test/java/org/opensearch/sql/spark/utils/SQLQueryUtilsTest.java b/async-query-core/src/test/java/org/opensearch/sql/spark/utils/SQLQueryUtilsTest.java index 881ad0e56a..276bce3b91 100644 --- a/async-query-core/src/test/java/org/opensearch/sql/spark/utils/SQLQueryUtilsTest.java +++ b/async-query-core/src/test/java/org/opensearch/sql/spark/utils/SQLQueryUtilsTest.java @@ -140,7 +140,7 @@ void testExtractionFromFlintSkippingIndexQueries() { + " WHERE elb_status_code = 500 " + " WITH (auto_refresh = true)", "DROP SKIPPING INDEX ON myS3.default.alb_logs", - "ALTER SKIPPING INDEX ON myS3.default.alb_logs WITH (auto_refresh = false)", + "ALTER SKIPPING INDEX ON myS3.default.alb_logs WITH (auto_refresh = false)" }; for (String query : createSkippingIndexQueries) { diff --git a/async-query-core/src/test/java/org/opensearch/sql/spark/validator/FunctionTypeTest.java b/async-query-core/src/test/java/org/opensearch/sql/spark/validator/FunctionTypeTest.java index a5f868421c..f661f5feac 100644 --- a/async-query-core/src/test/java/org/opensearch/sql/spark/validator/FunctionTypeTest.java +++ b/async-query-core/src/test/java/org/opensearch/sql/spark/validator/FunctionTypeTest.java @@ -42,6 +42,8 @@ public void test() { assertEquals(FunctionType.MISC, FunctionType.fromFunctionName("version")); assertEquals(FunctionType.GENERATOR, FunctionType.fromFunctionName("explode")); assertEquals(FunctionType.GENERATOR, FunctionType.fromFunctionName("stack")); + assertEquals(FunctionType.OTHER, FunctionType.fromFunctionName("aggregate")); + assertEquals(FunctionType.OTHER, FunctionType.fromFunctionName("forall")); assertEquals(FunctionType.UDF, FunctionType.fromFunctionName("unknown")); } } diff --git a/async-query-core/src/test/java/org/opensearch/sql/spark/validator/SQLQueryValidatorTest.java b/async-query-core/src/test/java/org/opensearch/sql/spark/validator/SQLQueryValidatorTest.java index ad73daa37f..3cfc33a5b1 100644 --- a/async-query-core/src/test/java/org/opensearch/sql/spark/validator/SQLQueryValidatorTest.java +++ b/async-query-core/src/test/java/org/opensearch/sql/spark/validator/SQLQueryValidatorTest.java @@ -6,6 +6,7 @@ package org.opensearch.sql.spark.validator; import static org.junit.jupiter.api.Assertions.assertDoesNotThrow; +import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertThrows; import static org.mockito.ArgumentMatchers.any; import static org.mockito.Mockito.when; @@ -192,6 +193,10 @@ private enum TestElement { // Generator Functions GENERATOR_FUNCTIONS("SELECT explode(array(1, 2, 3));"), + // Other functions + NAMED_STRUCT("SELECT named_struct('a', 1);"), + PARSE_URL("SELECT parse_url(url) FROM my_table;"), + // UDFs (User-Defined Functions) SCALAR_USER_DEFINED_FUNCTIONS("SELECT my_udf(name) FROM my_table;"), USER_DEFINED_AGGREGATE_FUNCTIONS("SELECT my_udaf(age) FROM my_table GROUP BY name;"), @@ -323,6 +328,10 @@ void testDenyAllValidator() { // Generator Functions v.ng(TestElement.GENERATOR_FUNCTIONS); + // Other Functions + v.ng(TestElement.NAMED_STRUCT); + v.ng(TestElement.PARSE_URL); + // UDFs v.ng(TestElement.SCALAR_USER_DEFINED_FUNCTIONS); v.ng(TestElement.USER_DEFINED_AGGREGATE_FUNCTIONS); @@ -440,6 +449,10 @@ void testS3glueQueries() { // Generator Functions v.ok(TestElement.GENERATOR_FUNCTIONS); + // Other Functions + v.ok(TestElement.NAMED_STRUCT); + v.ok(TestElement.PARSE_URL); + // UDFs v.ng(TestElement.SCALAR_USER_DEFINED_FUNCTIONS); v.ng(TestElement.USER_DEFINED_AGGREGATE_FUNCTIONS); @@ -621,6 +634,14 @@ void testUnsupportedHiveNativeCommand() { v.ng("DFS"); } + @Test + void testException() { + when(mockedProvider.getValidatorForDatasource(any())).thenReturn(element -> false); + VerifyValidator v = new VerifyValidator(sqlQueryValidator, DataSourceType.S3GLUE); + + v.ng("SELECT named_struct('a', 1);", "Other functions (named_struct) is not allowed."); + } + @AllArgsConstructor private static class VerifyValidator { private final SQLQueryValidator validator; @@ -645,6 +666,15 @@ public void ng(String query) { "The query should throw: query=`" + query.toString() + "`"); } + public void ng(String query, String expectedMessage) { + Exception e = + assertThrows( + IllegalArgumentException.class, + () -> runValidate(query), + "The query should throw: query=`" + query.toString() + "`"); + assertEquals(expectedMessage, e.getMessage()); + } + void runValidate(String[] queries) { Arrays.stream(queries).forEach(query -> validator.validate(query, dataSourceType)); }