diff --git a/src/jmh/java/io/r2dbc/postgresql/PostgresqlSqlParserBenchmarks.java b/src/jmh/java/io/r2dbc/postgresql/PostgresqlSqlParserBenchmarks.java new file mode 100644 index 000000000..286ba905a --- /dev/null +++ b/src/jmh/java/io/r2dbc/postgresql/PostgresqlSqlParserBenchmarks.java @@ -0,0 +1,56 @@ +/* + * Copyright 2022 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.r2dbc.postgresql; + +import org.junit.platform.commons.annotation.Testable; +import org.openjdk.jmh.annotations.Benchmark; +import org.openjdk.jmh.annotations.BenchmarkMode; +import org.openjdk.jmh.annotations.Mode; +import org.openjdk.jmh.annotations.OutputTimeUnit; +import org.openjdk.jmh.infra.Blackhole; + +import java.util.concurrent.TimeUnit; + +/** + * Benchmarks for {@link PostgresqlSqlParser}. + */ +@BenchmarkMode(Mode.Throughput) +@OutputTimeUnit(TimeUnit.SECONDS) +@Testable +public class PostgresqlSqlParserBenchmarks extends BenchmarkSettings { + + @Benchmark + public void simpleStatement(Blackhole blackhole) { + blackhole.consume(PostgresqlSqlParser.parse("SELECT * FROM FOO")); + } + + @Benchmark + public void parametrizedStatement(Blackhole blackhole) { + blackhole.consume(PostgresqlSqlParser.parse("SELECT * FROM FOO WHERE $2 = $1")); + } + + @Benchmark + public void createOrReplaceFunction(Blackhole blackhole) { + blackhole.consume(PostgresqlSqlParser.parse("CREATE OR REPLACE FUNCTION asterisks(n int)\n" + + " RETURNS SETOF text\n" + + " LANGUAGE sql IMMUTABLE STRICT PARALLEL SAFE\n" + + "BEGIN ATOMIC\n" + + "SELECT repeat('*', g) FROM generate_series (1, n) g; -- <-- Note this semicolon\n" + + "END;")); + } + +} diff --git a/src/main/java/io/r2dbc/postgresql/ParsedSql.java b/src/main/java/io/r2dbc/postgresql/ParsedSql.java index 3480d6b97..b7bec1603 100644 --- a/src/main/java/io/r2dbc/postgresql/ParsedSql.java +++ b/src/main/java/io/r2dbc/postgresql/ParsedSql.java @@ -50,12 +50,12 @@ public int getParameterCount() { } public String getSql() { - return sql; + return this.sql; } private static int getParameterCount(List statements) { int sum = 0; - for (Statement statement : statements){ + for (Statement statement : statements) { sum += statement.getParameterCount(); } return sum; diff --git a/src/main/java/io/r2dbc/postgresql/PostgresqlSqlParser.java b/src/main/java/io/r2dbc/postgresql/PostgresqlSqlParser.java index b9884a370..ef487ac73 100644 --- a/src/main/java/io/r2dbc/postgresql/PostgresqlSqlParser.java +++ b/src/main/java/io/r2dbc/postgresql/PostgresqlSqlParser.java @@ -16,8 +16,11 @@ package io.r2dbc.postgresql; +import io.netty.util.collection.CharObjectHashMap; +import io.netty.util.collection.CharObjectMap; + import java.util.ArrayList; -import java.util.Arrays; +import java.util.LinkedList; import java.util.List; import static java.lang.Character.isWhitespace; @@ -29,13 +32,79 @@ */ class PostgresqlSqlParser { - private static final char[] SPECIAL_AND_OPERATOR_CHARS = { - '+', '-', '*', '/', '<', '>', '=', '~', '!', '@', '#', '%', '^', '&', '|', '`', '?', - '(', ')', '[', ']', ',', ';', ':', '*', '.', '\'', '"' - }; + private static final CharObjectMap SPECIAL_AND_OPERATOR_CHARS = new CharObjectHashMap<>(); static { - Arrays.sort(SPECIAL_AND_OPERATOR_CHARS); + char[] specialCharsAndOperators = {'+', '-', '*', '/', '<', '>', '=', '~', '!', '@', '#', '%', '^', '&', '|', '`', '?', + '(', ')', '[', ']', ',', ';', ':', '*', '.', '\'', '"'}; + + for (char c : specialCharsAndOperators) { + SPECIAL_AND_OPERATOR_CHARS.put(c, new Object()); + } + } + + public static ParsedSql parse(String sql) { + List tokens = tokenize(sql); + List statements = new ArrayList<>(); + LinkedList functionBodyList = null; + + List currentStatementTokens = new ArrayList<>(tokens.size()); + + for (int i = 0; i < tokens.size(); i++) { + ParsedSql.Token current = tokens.get(i); + currentStatementTokens.add(current); + + if (current.getType() == ParsedSql.TokenType.DEFAULT) { + String currentValue = current.getValue(); + + if (currentValue.equalsIgnoreCase("BEGIN")) { + if (functionBodyList == null) { + functionBodyList = new LinkedList<>(); + } + if (hasNextToken(tokens, i) && peekNext(tokens, i).getValue().equalsIgnoreCase("ATOMIC")) { + functionBodyList.add(true); + } else { + functionBodyList.add(false); + } + } else if (currentValue.equalsIgnoreCase("END") && functionBodyList != null && !functionBodyList.isEmpty()) { + functionBodyList.removeLast(); + } + } else if (current.getType().equals(ParsedSql.TokenType.STATEMENT_END)) { + boolean inFunctionBody = false; + + if (functionBodyList != null) { + for (boolean b : functionBodyList) { + inFunctionBody |= b; + } + } + if (!inFunctionBody) { + statements.add(new ParsedSql.Statement(currentStatementTokens)); + currentStatementTokens = new ArrayList<>(); + } + } + } + + if (!currentStatementTokens.isEmpty()) { + statements.add(new ParsedSql.Statement(currentStatementTokens)); + } + + return new ParsedSql(sql, statements); + } + + private static ParsedSql.Token peekNext(List tokens, int index) { + return tokens.get(index + 1); + } + + private static boolean hasNextToken(List tokens, int index) { + return tokens.size() > index + 1; + } + + private static char peekNext(CharSequence sequence, int index) { + return sequence.charAt(index + 1); + } + + private static boolean hasNextToken(CharSequence sequence, int index) { + return sequence.length() > index + 1; } private static List tokenize(String sql) { @@ -57,12 +126,12 @@ private static List tokenize(String sql) { token = getQuotedIdentifierToken(sql, i); break; case '-': // Possible start of double-dash comment - if ((i + 1) < sql.length() && sql.charAt(i + 1) == '-') { + if (hasNextToken(sql, i) && peekNext(sql, i) == '-') { token = getCommentToLineEndToken(sql, i); } break; case '/': // Possible start of c-style comment - if ((i + 1) < sql.length() && sql.charAt(i + 1) == '*') { + if (hasNextToken(sql, i) && peekNext(sql, i) == '*') { token = getBlockCommentToken(sql, i); } break; @@ -89,48 +158,6 @@ private static List tokenize(String sql) { return tokens; } - public static ParsedSql parse(String sql) { - List tokens = tokenize(sql); - List statements = new ArrayList<>(); - List functionBodyList = new ArrayList<>(); - - List currentStatementTokens = new ArrayList<>(); - for (int i = 0; i < tokens.size(); i++) { - ParsedSql.Token current = tokens.get(i); - currentStatementTokens.add(current); - - if (current.getType() == ParsedSql.TokenType.DEFAULT) { - String currentValue = current.getValue(); - - if (currentValue.equalsIgnoreCase("BEGIN")) { - if (i + 1 < tokens.size() && tokens.get(i + 1).getValue().equalsIgnoreCase("ATOMIC")) { - functionBodyList.add(true); - } else { - functionBodyList.add(false); - } - } else if (currentValue.equalsIgnoreCase("END") && !functionBodyList.isEmpty()) { - functionBodyList.remove(functionBodyList.size() - 1); - } - } else if (current.getType().equals(ParsedSql.TokenType.STATEMENT_END)) { - boolean inFunctionBody = false; - - for (boolean b : functionBodyList) { - inFunctionBody |= b; - } - if (!inFunctionBody) { - statements.add(new ParsedSql.Statement(currentStatementTokens)); - currentStatementTokens = new ArrayList<>(); - } - } - } - - if (!currentStatementTokens.isEmpty()) { - statements.add(new ParsedSql.Statement(currentStatementTokens)); - } - - return new ParsedSql(sql, statements); - } - private static ParsedSql.Token getDefaultToken(String sql, int beginIndex) { for (int i = beginIndex + 1; i < sql.length(); i++) { char c = sql.charAt(i); @@ -142,7 +169,7 @@ private static ParsedSql.Token getDefaultToken(String sql, int beginIndex) { } private static boolean isSpecialOrOperatorChar(char c) { - return Arrays.binarySearch(SPECIAL_AND_OPERATOR_CHARS, c) >= 0; + return SPECIAL_AND_OPERATOR_CHARS.containsKey(c); } private static ParsedSql.Token getBlockCommentToken(String sql, int beginIndex) { diff --git a/src/test/java/io/r2dbc/postgresql/PostgresqlSqlParserTest.java b/src/test/java/io/r2dbc/postgresql/PostgresqlSqlParserTests.java similarity index 74% rename from src/test/java/io/r2dbc/postgresql/PostgresqlSqlParserTest.java rename to src/test/java/io/r2dbc/postgresql/PostgresqlSqlParserTests.java index ad1008b98..d0e3376f1 100644 --- a/src/test/java/io/r2dbc/postgresql/PostgresqlSqlParserTest.java +++ b/src/test/java/io/r2dbc/postgresql/PostgresqlSqlParserTests.java @@ -22,11 +22,13 @@ import java.util.Arrays; import java.util.List; -import static org.junit.jupiter.api.Assertions.assertEquals; -import static org.junit.jupiter.api.Assertions.assertIterableEquals; -import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException; -class PostgresqlSqlParserTest { +/** + * Unit tests for {@link PostgresqlSqlParser}. + */ +class PostgresqlSqlParserTests { @Nested class SingleStatementTests { @@ -36,77 +38,77 @@ class SingleTokenTests { @Test void singleQuotedStringIsTokenized() { - assertSingleStatementEqualsCompleteToken("'Test'", ParsedSql.TokenType.STRING_CONSTANT); + assertTokenEquals("'Test'", ParsedSql.TokenType.STRING_CONSTANT); } @Test void dollarQuotedStringIsTokenized() { - assertSingleStatementEqualsCompleteToken("$$test$$", ParsedSql.TokenType.STRING_CONSTANT); + assertTokenEquals("$$test$$", ParsedSql.TokenType.STRING_CONSTANT); } @Test void dollarQuotedTaggedStringIsTokenized() { - assertSingleStatementEqualsCompleteToken("$a$test$a$", ParsedSql.TokenType.STRING_CONSTANT); + assertTokenEquals("$a$test$a$", ParsedSql.TokenType.STRING_CONSTANT); } @Test void quotedIdentifierIsTokenized() { - assertSingleStatementEqualsCompleteToken("\"test\"", ParsedSql.TokenType.QUOTED_IDENTIFIER); + assertTokenEquals("\"test\"", ParsedSql.TokenType.QUOTED_IDENTIFIER); } @Test void lineCommentIsTokenized() { - assertSingleStatementEqualsCompleteToken("--test", ParsedSql.TokenType.COMMENT); + assertTokenEquals("--test", ParsedSql.TokenType.COMMENT); } @Test void cStyleCommentIsTokenized() { - assertSingleStatementEqualsCompleteToken("/*Test*/", ParsedSql.TokenType.COMMENT); - assertSingleStatementEqualsCompleteToken("/**/", ParsedSql.TokenType.COMMENT); - assertSingleStatementEqualsCompleteToken("/*T*/", ParsedSql.TokenType.COMMENT); + assertTokenEquals("/*Test*/", ParsedSql.TokenType.COMMENT); + assertTokenEquals("/**/", ParsedSql.TokenType.COMMENT); + assertTokenEquals("/*T*/", ParsedSql.TokenType.COMMENT); } @Test void nestedCStyleCommentIsTokenizedAsSingleToken() { - assertSingleStatementEqualsCompleteToken("/*/*Test*/*/", ParsedSql.TokenType.COMMENT); + assertTokenEquals("/*/*Test*/*/", ParsedSql.TokenType.COMMENT); } @Test void windowsMultiLineCStyleCommentIsTokenizedAsSingleToken() { - assertSingleStatementEqualsCompleteToken("/*Test\r\n Test*/", ParsedSql.TokenType.COMMENT); + assertTokenEquals("/*Test\r\n Test*/", ParsedSql.TokenType.COMMENT); } @Test void unixMultiLineCStyleCommentIsTokenizedAsSingleToken() { - assertSingleStatementEqualsCompleteToken("/*Test\n Test*/", ParsedSql.TokenType.COMMENT); + assertTokenEquals("/*Test\n Test*/", ParsedSql.TokenType.COMMENT); } @Test void digitIsTokenizedAsDefaultToken() { - assertSingleStatementEqualsCompleteToken("1", ParsedSql.TokenType.DEFAULT); + assertTokenEquals("1", ParsedSql.TokenType.DEFAULT); } @Test void alphaIsTokenizedAsDefaultToken() { - assertSingleStatementEqualsCompleteToken("a", ParsedSql.TokenType.DEFAULT); + assertTokenEquals("a", ParsedSql.TokenType.DEFAULT); } @Test void multipleDefaultTokensAreTokenizedAsSingleDefaultToken() { - assertSingleStatementEqualsCompleteToken("atest123", ParsedSql.TokenType.DEFAULT); + assertTokenEquals("atest123", ParsedSql.TokenType.DEFAULT); } @Test void parameterIsTokenized() { - assertSingleStatementEqualsCompleteToken("$1", ParsedSql.TokenType.PARAMETER); + assertTokenEquals("$1", ParsedSql.TokenType.PARAMETER); } @Test void statementEndIsTokenized() { - assertSingleStatementEqualsCompleteToken(";", ParsedSql.TokenType.STATEMENT_END); + assertTokenEquals(";", ParsedSql.TokenType.STATEMENT_END); } - void assertSingleStatementEqualsCompleteToken(String sql, ParsedSql.TokenType token) { + void assertTokenEquals(String sql, ParsedSql.TokenType token) { assertSingleStatementEquals(sql, new ParsedSql.Token(token, sql)); } @@ -117,47 +119,47 @@ class SingleTokenExceptionTests { @Test void unclosedSingleQuotedStringThrowsIllegalArgumentException() { - assertThrows(IllegalArgumentException.class, () -> PostgresqlSqlParser.parse("'test")); + assertThatIllegalArgumentException().isThrownBy(() -> PostgresqlSqlParser.parse("'test")); } @Test void unclosedDollarQuotedStringThrowsIllegalArgumentException() { - assertThrows(IllegalArgumentException.class, () -> PostgresqlSqlParser.parse("$$test")); + assertThatIllegalArgumentException().isThrownBy(() -> PostgresqlSqlParser.parse("$$test")); } @Test void unclosedTaggedDollarQuotedStringThrowsIllegalArgumentException() { - assertThrows(IllegalArgumentException.class, () -> PostgresqlSqlParser.parse("$abc$test")); + assertThatIllegalArgumentException().isThrownBy(() -> PostgresqlSqlParser.parse("$abc$test")); } @Test void unclosedQuotedIdentifierThrowsIllegalArgumentException() { - assertThrows(IllegalArgumentException.class, () -> PostgresqlSqlParser.parse("\"test")); + assertThatIllegalArgumentException().isThrownBy(() -> PostgresqlSqlParser.parse("\"test")); } @Test void unclosedBlockCommentThrowsIllegalArgumentException() { - assertThrows(IllegalArgumentException.class, () -> PostgresqlSqlParser.parse("/*test")); + assertThatIllegalArgumentException().isThrownBy(() -> PostgresqlSqlParser.parse("/*test")); } @Test void unclosedNestedBlockCommentThrowsIllegalArgumentException() { - assertThrows(IllegalArgumentException.class, () -> PostgresqlSqlParser.parse("/*/*test*/")); + assertThatIllegalArgumentException().isThrownBy(() -> PostgresqlSqlParser.parse("/*/*test*/")); } @Test void invalidParameterCharacterThrowsIllegalArgumentException() { - assertThrows(IllegalArgumentException.class, () -> PostgresqlSqlParser.parse("$1test")); + assertThatIllegalArgumentException().isThrownBy(() -> PostgresqlSqlParser.parse("$1test")); } @Test void invalidTaggedDollarQuoteThrowsIllegalArgumentException() { - assertThrows(IllegalArgumentException.class, () -> PostgresqlSqlParser.parse("$a b$test$a b$")); + assertThatIllegalArgumentException().isThrownBy(() -> PostgresqlSqlParser.parse("$a b$test$a b$")); } @Test void unclosedTaggedDollarQuoteThrowsIllegalArgumentException() { - assertThrows(IllegalArgumentException.class, () -> PostgresqlSqlParser.parse("$abc")); + assertThatIllegalArgumentException().isThrownBy(() -> PostgresqlSqlParser.parse("$abc")); } } @@ -268,9 +270,9 @@ void simpleSelectStatementWithFunctionBodyIsTokenized() { void assertSingleStatementEquals(String sql, ParsedSql.Token... tokens) { ParsedSql parsedSql = PostgresqlSqlParser.parse(sql); - assertEquals(1, parsedSql.getStatements().size(), "Parse returned zero or more than 2 statements"); + assertThat(parsedSql.getStatements()).hasSize(1); ParsedSql.Statement statement = parsedSql.getStatements().get(0); - assertIterableEquals(Arrays.asList(tokens), statement.getTokens()); + assertThat(statement.getTokens()).containsExactly(tokens); } } @@ -282,30 +284,27 @@ class MultipleStatementTests { void simpleMultipleStatementIsTokenized() { ParsedSql parsedSql = PostgresqlSqlParser.parse("DELETE * FROM X; SELECT 1;"); List statements = parsedSql.getStatements(); - assertEquals(2, statements.size()); + assertThat(parsedSql.getStatements()).hasSize(2); ParsedSql.Statement statementA = statements.get(0); ParsedSql.Statement statementB = statements.get(1); - assertIterableEquals( + assertThat( Arrays.asList( new ParsedSql.Token(ParsedSql.TokenType.DEFAULT, "DELETE"), new ParsedSql.Token(ParsedSql.TokenType.SPECIAL_OR_OPERATOR, "*"), new ParsedSql.Token(ParsedSql.TokenType.DEFAULT, "FROM"), new ParsedSql.Token(ParsedSql.TokenType.DEFAULT, "X"), new ParsedSql.Token(ParsedSql.TokenType.STATEMENT_END, ";") - ), - statementA.getTokens() - ); + ) + ).containsExactlyElementsOf(statementA.getTokens()); - assertIterableEquals( + assertThat( Arrays.asList( new ParsedSql.Token(ParsedSql.TokenType.DEFAULT, "SELECT"), new ParsedSql.Token(ParsedSql.TokenType.DEFAULT, "1"), new ParsedSql.Token(ParsedSql.TokenType.STATEMENT_END, ";") - ), - statementB.getTokens() - ); - + ) + ).containsExactlyElementsOf(statementB.getTokens()); } }