Skip to content

Commit

Permalink
Add other functions to SQL query validator (#3304)
Browse files Browse the repository at this point in the history
* Add uncategorized functions to SQL query validator

Signed-off-by: Tomoyuki Morita <moritato@amazon.com>

* Fix variable name

Signed-off-by: Tomoyuki Morita <moritato@amazon.com>

* Fix name from uncategorized to other

Signed-off-by: Tomoyuki Morita <moritato@amazon.com>

---------

Signed-off-by: Tomoyuki Morita <moritato@amazon.com>
(cherry picked from commit 4b60ab6)
Signed-off-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
  • Loading branch information
github-actions[bot] committed Feb 7, 2025
1 parent d6e08cc commit 8491ac4
Show file tree
Hide file tree
Showing 6 changed files with 98 additions and 9 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ public enum FunctionType {
CSV("CSV"),
MISC("Misc"),
GENERATOR("Generator"),
OTHER("Other"),
UDF("User Defined Function");

private final String name;
Expand Down Expand Up @@ -422,6 +423,51 @@ public enum FunctionType {
"posexplode",
"posexplode_outer",
"stack"))
.put(
OTHER,
Set.of(
"aggregate",
"array_size",
"array_sort",
"cardinality",
"crc32",
"exists",
"filter",
"forall",
"hash",
"ilike",
"in",
"like",
"map_filter",
"map_zip_with",
"md5",
"mod",
"named_struct",
"parse_url",
"raise_error",
"reduce",
"reverse",
"sha",
"sha1",
"sha2",
"size",
"struct",
"transform",
"transform_keys",
"transform_values",
"url_decode",
"url_encode",
"xpath",
"xpath_boolean",
"xpath_double",
"xpath_float",
"xpath_int",
"xpath_long",
"xpath_number",
"xpath_short",
"xpath_string",
"xxhash64",
"zip_with"))
.build();

private static final Map<String, FunctionType> FUNCTION_NAME_TO_FUNCTION_TYPE_MAP =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,7 @@ public enum SQLGrammarElement implements GrammarElement {
CSV_FUNCTIONS("CSV functions"),
GENERATOR_FUNCTIONS("Generator functions"),
MISC_FUNCTIONS("Misc functions"),
OTHER_FUNCTIONS("Other functions"),

// UDF
UDF("User Defined functions");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -560,26 +560,30 @@ public Void visitFunctionName(FunctionNameContext ctx) {
return super.visitFunctionName(ctx);
}

private void validateFunctionAllowed(String function) {
FunctionType type = FunctionType.fromFunctionName(function.toLowerCase());
private void validateFunctionAllowed(String functionName) {
String lowerCaseFunctionName = functionName.toLowerCase();
FunctionType type = FunctionType.fromFunctionName(lowerCaseFunctionName);
switch (type) {
case MAP:
validateAllowed(SQLGrammarElement.MAP_FUNCTIONS);
validateAllowed(SQLGrammarElement.MAP_FUNCTIONS, lowerCaseFunctionName);
break;
case BITWISE:
validateAllowed(SQLGrammarElement.BITWISE_FUNCTIONS);
validateAllowed(SQLGrammarElement.BITWISE_FUNCTIONS, lowerCaseFunctionName);
break;
case CSV:
validateAllowed(SQLGrammarElement.CSV_FUNCTIONS);
validateAllowed(SQLGrammarElement.CSV_FUNCTIONS, lowerCaseFunctionName);
break;
case MISC:
validateAllowed(SQLGrammarElement.MISC_FUNCTIONS);
validateAllowed(SQLGrammarElement.MISC_FUNCTIONS, lowerCaseFunctionName);
break;
case GENERATOR:
validateAllowed(SQLGrammarElement.GENERATOR_FUNCTIONS);
validateAllowed(SQLGrammarElement.GENERATOR_FUNCTIONS, lowerCaseFunctionName);
break;
case OTHER:
validateAllowed(SQLGrammarElement.OTHER_FUNCTIONS, lowerCaseFunctionName);
break;
case UDF:
validateAllowed(SQLGrammarElement.UDF);
validateAllowed(SQLGrammarElement.UDF, lowerCaseFunctionName);
break;
}
}
Expand All @@ -590,6 +594,12 @@ private void validateAllowed(SQLGrammarElement element) {
}
}

private void validateAllowed(SQLGrammarElement element, String detail) {
if (!grammarElementValidator.isValid(element)) {
throw new IllegalArgumentException(String.format("%s (%s) is not allowed.", element, detail));
}
}

@Override
public Void visitErrorCapturingIdentifier(ErrorCapturingIdentifierContext ctx) {
ErrorCapturingIdentifierExtraContext extra = ctx.errorCapturingIdentifierExtra();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,7 @@ void testExtractionFromFlintSkippingIndexQueries() {
+ " WHERE elb_status_code = 500 "
+ " WITH (auto_refresh = true)",
"DROP SKIPPING INDEX ON myS3.default.alb_logs",
"ALTER SKIPPING INDEX ON myS3.default.alb_logs WITH (auto_refresh = false)",
"ALTER SKIPPING INDEX ON myS3.default.alb_logs WITH (auto_refresh = false)"
};

for (String query : createSkippingIndexQueries) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,8 @@ public void test() {
assertEquals(FunctionType.MISC, FunctionType.fromFunctionName("version"));
assertEquals(FunctionType.GENERATOR, FunctionType.fromFunctionName("explode"));
assertEquals(FunctionType.GENERATOR, FunctionType.fromFunctionName("stack"));
assertEquals(FunctionType.OTHER, FunctionType.fromFunctionName("aggregate"));
assertEquals(FunctionType.OTHER, FunctionType.fromFunctionName("forall"));
assertEquals(FunctionType.UDF, FunctionType.fromFunctionName("unknown"));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
package org.opensearch.sql.spark.validator;

import static org.junit.jupiter.api.Assertions.assertDoesNotThrow;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertThrows;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.when;
Expand Down Expand Up @@ -192,6 +193,10 @@ private enum TestElement {
// Generator Functions
GENERATOR_FUNCTIONS("SELECT explode(array(1, 2, 3));"),

// Other functions
NAMED_STRUCT("SELECT named_struct('a', 1);"),
PARSE_URL("SELECT parse_url(url) FROM my_table;"),

// UDFs (User-Defined Functions)
SCALAR_USER_DEFINED_FUNCTIONS("SELECT my_udf(name) FROM my_table;"),
USER_DEFINED_AGGREGATE_FUNCTIONS("SELECT my_udaf(age) FROM my_table GROUP BY name;"),
Expand Down Expand Up @@ -323,6 +328,10 @@ void testDenyAllValidator() {
// Generator Functions
v.ng(TestElement.GENERATOR_FUNCTIONS);

// Other Functions
v.ng(TestElement.NAMED_STRUCT);
v.ng(TestElement.PARSE_URL);

// UDFs
v.ng(TestElement.SCALAR_USER_DEFINED_FUNCTIONS);
v.ng(TestElement.USER_DEFINED_AGGREGATE_FUNCTIONS);
Expand Down Expand Up @@ -440,6 +449,10 @@ void testS3glueQueries() {
// Generator Functions
v.ok(TestElement.GENERATOR_FUNCTIONS);

// Other Functions
v.ok(TestElement.NAMED_STRUCT);
v.ok(TestElement.PARSE_URL);

// UDFs
v.ng(TestElement.SCALAR_USER_DEFINED_FUNCTIONS);
v.ng(TestElement.USER_DEFINED_AGGREGATE_FUNCTIONS);
Expand Down Expand Up @@ -621,6 +634,14 @@ void testUnsupportedHiveNativeCommand() {
v.ng("DFS");
}

@Test
void testException() {
when(mockedProvider.getValidatorForDatasource(any())).thenReturn(element -> false);
VerifyValidator v = new VerifyValidator(sqlQueryValidator, DataSourceType.S3GLUE);

v.ng("SELECT named_struct('a', 1);", "Other functions (named_struct) is not allowed.");
}

@AllArgsConstructor
private static class VerifyValidator {
private final SQLQueryValidator validator;
Expand All @@ -645,6 +666,15 @@ public void ng(String query) {
"The query should throw: query=`" + query.toString() + "`");
}

public void ng(String query, String expectedMessage) {
Exception e =
assertThrows(
IllegalArgumentException.class,
() -> runValidate(query),
"The query should throw: query=`" + query.toString() + "`");
assertEquals(expectedMessage, e.getMessage());
}

void runValidate(String[] queries) {
Arrays.stream(queries).forEach(query -> validator.validate(query, dataSourceType));
}
Expand Down

0 comments on commit 8491ac4

Please sign in to comment.