From 8f27eca7ca34f172a95680d706c8c9427a5176b2 Mon Sep 17 00:00:00 2001 From: Piotr Findeisen Date: Sat, 15 Jan 2022 11:19:52 +0100 Subject: [PATCH] Fix Statement leak in Trino JDBC DatabaseMetaData Before the change, Trino JDBC's `DatabaseMetaData` implementation (`TrinoDatabaseMetaData`) would create `Statement` objects that are never closed. Since `Connection` (`TrinoConnection`) tracks open statements to be able to close them upon `Connection.close()` (per JDBC requirements), this created a memory leak where `Statement` objects are leaked in `Connection.statements` collection. The commit fixing this, under the condition that `ResultSet` returned from `TrinoDatabaseMetaData` is correctly closed. --- .../java/io/trino/jdbc/TrinoConnection.java | 6 +++ .../io/trino/jdbc/TrinoDatabaseMetaData.java | 20 ++++++++- .../java/io/trino/jdbc/TrinoResultSet.java | 45 ++++++++++++++++--- .../trino/jdbc/TestTrinoDatabaseMetaData.java | 37 +++++++++++++++ 4 files changed, 100 insertions(+), 8 deletions(-) diff --git a/client/trino-jdbc/src/main/java/io/trino/jdbc/TrinoConnection.java b/client/trino-jdbc/src/main/java/io/trino/jdbc/TrinoConnection.java index fad337bacc26..b0e61088dcd4 100644 --- a/client/trino-jdbc/src/main/java/io/trino/jdbc/TrinoConnection.java +++ b/client/trino-jdbc/src/main/java/io/trino/jdbc/TrinoConnection.java @@ -782,6 +782,12 @@ private void unregisterStatement(TrinoStatement statement) checkState(statements.remove(statement), "Statement is not registered"); } + @VisibleForTesting + int activeStatements() + { + return statements.size(); + } + private void checkOpen() throws SQLException { diff --git a/client/trino-jdbc/src/main/java/io/trino/jdbc/TrinoDatabaseMetaData.java b/client/trino-jdbc/src/main/java/io/trino/jdbc/TrinoDatabaseMetaData.java index 48a69e1d6c2f..a3cbfb173628 100644 --- a/client/trino-jdbc/src/main/java/io/trino/jdbc/TrinoDatabaseMetaData.java +++ b/client/trino-jdbc/src/main/java/io/trino/jdbc/TrinoDatabaseMetaData.java @@ -28,6 +28,7 @@ import java.sql.RowIdLifetime; import java.sql.SQLException; import java.sql.SQLFeatureNotSupportedException; +import java.sql.Statement; import java.util.ArrayList; import java.util.Comparator; import java.util.List; @@ -1488,7 +1489,24 @@ private ResultSet selectEmpty(String sql) private ResultSet select(String sql) throws SQLException { - return getConnection().createStatement().executeQuery(sql); + Statement statement = getConnection().createStatement(); + TrinoResultSet resultSet; + try { + resultSet = (TrinoResultSet) statement.executeQuery(sql); + resultSet.setCloseStatementOnClose(); + } + catch (Throwable e) { + try { + statement.close(); + } + catch (Throwable closeException) { + if (closeException != e) { + e.addSuppressed(closeException); + } + } + throw e; + } + return resultSet; } private static void buildFilters(StringBuilder out, List filters) diff --git a/client/trino-jdbc/src/main/java/io/trino/jdbc/TrinoResultSet.java b/client/trino-jdbc/src/main/java/io/trino/jdbc/TrinoResultSet.java index 06b8632cd45b..35fac6a3e48d 100644 --- a/client/trino-jdbc/src/main/java/io/trino/jdbc/TrinoResultSet.java +++ b/client/trino-jdbc/src/main/java/io/trino/jdbc/TrinoResultSet.java @@ -20,6 +20,8 @@ import io.trino.client.QueryStatusInfo; import io.trino.client.StatementClient; +import javax.annotation.concurrent.GuardedBy; + import java.sql.SQLException; import java.sql.Statement; import java.util.Iterator; @@ -31,7 +33,6 @@ import java.util.concurrent.ExecutionException; import java.util.concurrent.ExecutorService; import java.util.concurrent.Semaphore; -import java.util.concurrent.atomic.AtomicBoolean; import java.util.function.Consumer; import java.util.stream.Stream; @@ -44,10 +45,14 @@ public class TrinoResultSet extends AbstractTrinoResultSet { + private final Statement statement; private final StatementClient client; private final String queryId; - private final AtomicBoolean closed = new AtomicBoolean(); + @GuardedBy("this") + private boolean closed; + @GuardedBy("this") + private boolean closeStatementOnClose; static TrinoResultSet create(Statement statement, StatementClient client, long maxRows, Consumer progressCallback, WarningsManager warningsManager) throws SQLException @@ -65,6 +70,7 @@ private TrinoResultSet(Statement statement, StatementClient client, List columns, new AsyncIterator<>(flatten(new ResultsPageIterator(requireNonNull(client, "client is null"), progressCallback, warningsManager), maxRows), client)); + this.statement = statement; this.client = requireNonNull(client, "client is null"); requireNonNull(progressCallback, "progressCallback is null"); @@ -81,21 +87,46 @@ public QueryStats getStats() return QueryStats.create(queryId, client.getStats()); } + void setCloseStatementOnClose() + throws SQLException + { + boolean alreadyClosed; + synchronized (this) { + alreadyClosed = closed; + if (!alreadyClosed) { + closeStatementOnClose = true; + } + } + if (alreadyClosed) { + statement.close(); + } + } + @Override public void close() throws SQLException { - if (closed.compareAndSet(false, true)) { - ((AsyncIterator) results).cancel(); - client.close(); + boolean closeStatement; + synchronized (this) { + if (closed) { + return; + } + closed = true; + closeStatement = closeStatementOnClose; + } + + ((AsyncIterator) results).cancel(); + client.close(); + if (closeStatement) { + statement.close(); } } @Override - public boolean isClosed() + public synchronized boolean isClosed() throws SQLException { - return closed.get(); + return closed; } void partialCancel() diff --git a/client/trino-jdbc/src/test/java/io/trino/jdbc/TestTrinoDatabaseMetaData.java b/client/trino-jdbc/src/test/java/io/trino/jdbc/TestTrinoDatabaseMetaData.java index 76a4d4f70d72..a4a76a10e474 100644 --- a/client/trino-jdbc/src/test/java/io/trino/jdbc/TestTrinoDatabaseMetaData.java +++ b/client/trino-jdbc/src/test/java/io/trino/jdbc/TestTrinoDatabaseMetaData.java @@ -1496,6 +1496,33 @@ public void testEscapeIfNecessary() assertEquals(TrinoDatabaseMetaData.escapeIfNecessary(true, "abc\\_def"), "abc\\\\\\_def"); } + @Test + public void testStatementsDoNotLeak() + throws Exception + { + TrinoConnection connection = (TrinoConnection) this.connection; + DatabaseMetaData metaData = connection.getMetaData(); + + // consumed + try (ResultSet resultSet = metaData.getCatalogs()) { + assertThat(countRows(resultSet)).isEqualTo(5); + } + try (ResultSet resultSet = metaData.getSchemas(TEST_CATALOG, null)) { + assertThat(countRows(resultSet)).isEqualTo(10); + } + try (ResultSet resultSet = metaData.getTables(TEST_CATALOG, "sf%", null, null)) { + assertThat(countRows(resultSet)).isEqualTo(64); + } + + // not consumed + metaData.getCatalogs().close(); + metaData.getSchemas(TEST_CATALOG, null).close(); + metaData.getTables(TEST_CATALOG, "sf%", null, null).close(); + + assertThat(connection.activeStatements()).as("activeStatements") + .isEqualTo(0); + } + private static void assertColumnSpec(ResultSet rs, int dataType, Long precision, Long numPrecRadix, String typeName) throws SQLException { @@ -1585,6 +1612,16 @@ private MetaDataCallback>> readMetaData(MetaDataCallback