Skip to content

Commit

Permalink
for #1205, Get connection sync to prevent dead lock for sharding-jdbc
Browse files Browse the repository at this point in the history
  • Loading branch information
terrymanu committed Sep 19, 2018
1 parent 7c4eadc commit f38351a
Show file tree
Hide file tree
Showing 7 changed files with 51 additions and 34 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@

import java.sql.Connection;
import java.sql.SQLException;
import java.util.List;

/**
* SQL execute prepare callback.
Expand All @@ -36,11 +37,11 @@ public interface SQLExecutePrepareCallback {
* Get connection.
*
* @param dataSourceName data source name
* @param index index of connection
* @param connectionSize connection size
* @return connection
* @throws SQLException SQL exception
*/
Connection getConnection(String dataSourceName, int index) throws SQLException;
List<Connection> getConnections(String dataSourceName, int connectionSize) throws SQLException;

/**
* Create SQL execute unit.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -78,10 +78,11 @@ private List<ShardingExecuteGroup<StatementExecuteUnit>> getSQLExecuteGroups(
final String dataSourceName, final List<SQLUnit> sqlUnits, final SQLExecutePrepareCallback callback) throws SQLException {
List<ShardingExecuteGroup<StatementExecuteUnit>> result = new LinkedList<>();
int desiredPartitionSize = Math.max(sqlUnits.size() / maxConnectionsSizePerQuery, 1);
int index = 0;
for (List<SQLUnit> each : Lists.partition(sqlUnits, desiredPartitionSize)) {
// TODO get connection sync to prevent dead lock
result.add(getSQLExecuteGroup(callback.getConnection(dataSourceName, index++), dataSourceName, each, callback));
List<List<SQLUnit>> sqlUnitGroups = Lists.partition(sqlUnits, desiredPartitionSize);
List<Connection> connections = callback.getConnections(dataSourceName, sqlUnitGroups.size());
int count = 0;
for (List<SQLUnit> each : sqlUnitGroups) {
result.add(getSQLExecuteGroup(connections.get(count++), dataSourceName, each, callback));
}
return result;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -86,10 +86,8 @@ public RouteUnit apply(final BatchRouteUnit input) {
}), new SQLExecutePrepareCallback() {

@Override
public Connection getConnection(final String dataSourceName, final int index) throws SQLException {
Connection conn = BatchPreparedStatementExecutor.super.getConnection().getConnection(dataSourceName, index);
getConnections().add(conn);
return conn;
public List<Connection> getConnections(final String dataSourceName, final int connectionSize) throws SQLException {
return BatchPreparedStatementExecutor.super.getConnection().getConnections(dataSourceName, connectionSize);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -73,10 +73,8 @@ private Collection<ShardingExecuteGroup<StatementExecuteUnit>> obtainExecuteGrou
return getSqlExecutePrepareTemplate().getExecuteUnitGroups(routeUnits, new SQLExecutePrepareCallback() {

@Override
public Connection getConnection(final String dataSourceName, final int index) throws SQLException {
Connection conn = PreparedStatementExecutor.super.getConnection().getConnection(dataSourceName, index);
getConnections().add(conn);
return conn;
public List<Connection> getConnections(final String dataSourceName, final int connectionSize) throws SQLException {
return PreparedStatementExecutor.super.getConnection().getConnections(dataSourceName, connectionSize);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -67,10 +67,8 @@ private Collection<ShardingExecuteGroup<StatementExecuteUnit>> obtainExecuteGrou
return getSqlExecutePrepareTemplate().getExecuteUnitGroups(routeUnits, new SQLExecutePrepareCallback() {

@Override
public Connection getConnection(final String dataSourceName, final int index) throws SQLException {
Connection conn = StatementExecutor.super.getConnection().getConnection(dataSourceName, index);
getConnections().add(conn);
return conn;
public List<Connection> getConnections(final String dataSourceName, final int connectionSize) throws SQLException {
return StatementExecutor.super.getConnection().getConnections(dataSourceName, connectionSize);
}

@SuppressWarnings("MagicConstant")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,9 @@
import java.sql.ResultSet;
import java.sql.SQLException;
import java.sql.SQLWarning;
import java.util.ArrayList;
import java.util.Collection;
import java.util.List;
import java.util.Map;

/**
Expand Down Expand Up @@ -82,36 +84,54 @@ public abstract class AbstractConnectionAdapter extends AbstractUnsupportedOpera
* @throws SQLException SQL exception
*/
public final Connection getConnection(final String dataSourceName) throws SQLException {
return getConnection(dataSourceName, 0);
return getConnections(dataSourceName, 1).get(0);
}

/**
* Get database connection.
* Get database connections.
*
* @param dataSourceName data source name
* @param index index of connection
* @return database connection
* @param connectionSize size of connection list to be get
* @return database connections
* @throws SQLException SQL exception
*/
public final Connection getConnection(final String dataSourceName, final int index) throws SQLException {
public final List<Connection> getConnections(final String dataSourceName, final int connectionSize) throws SQLException {
ShardingEventBusInstance.getInstance().post(new GetConnectionStartEvent(dataSourceName));
DataSource dataSource = getDataSourceMap().get(dataSourceName);
Preconditions.checkState(null != dataSource, "Missing the data source name: '%s'", dataSourceName);
Collection<Connection> connections = cachedConnections.get(dataSourceName);
if (cachedConnections.get(dataSourceName).size() > index) {
Connection result = cachedConnections.get(dataSourceName).toArray(new Connection[connections.size()])[index];
postGetConnectionEvent(result);
return result;
List<Connection> result;
if (connections.size() >= connectionSize) {
result = new ArrayList<>(cachedConnections.get(dataSourceName)).subList(0, connectionSize);
} else if (!connections.isEmpty()) {
result = new ArrayList<>(connectionSize);
result.addAll(connections);
List<Connection> newConnections = createConnections(dataSource, connectionSize - connections.size());
result.addAll(newConnections);
cachedConnections.putAll(dataSourceName, newConnections);
} else {
result = new ArrayList<>(createConnections(dataSource, connectionSize));
cachedConnections.putAll(dataSourceName, result);
}
Connection result = dataSource.getConnection();
cachedConnections.put(dataSourceName, result);
replayMethodsInvocation(result);
postGetConnectionEvent(result);
return result;
}

private void postGetConnectionEvent(final Connection connection) throws SQLException {
GetConnectionEvent finishEvent = new GetConnectionFinishEvent(DataSourceMetaDataFactory.newInstance(databaseType, connection.getMetaData().getURL()));
@SuppressWarnings("SynchronizationOnLocalVariableOrMethodParameter")
private synchronized List<Connection> createConnections(final DataSource dataSource, final int connectionSize) throws SQLException {
List<Connection> result = new ArrayList<>(connectionSize);
synchronized (dataSource) {
for (int i = 0; i < connectionSize; i++) {
Connection connection = dataSource.getConnection();
replayMethodsInvocation(connection);
result.add(connection);
}
}
return result;
}

private void postGetConnectionEvent(final List<Connection> connections) throws SQLException {
GetConnectionEvent finishEvent = new GetConnectionFinishEvent(DataSourceMetaDataFactory.newInstance(databaseType, connections.get(0).getMetaData().getURL()));
finishEvent.setExecuteSuccess();
ShardingEventBusInstance.getInstance().post(finishEvent);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@
import java.sql.SQLException;
import java.sql.Statement;
import java.util.Collection;
import java.util.Collections;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
Expand Down Expand Up @@ -175,8 +176,8 @@ private final class ConnectionStrictlySQLExecutePrepareCallback implements SQLEx
private final boolean isReturnGeneratedKeys;

@Override
public Connection getConnection(final String dataSourceName, final int index) throws SQLException {
return getBackendConnection().getConnection(dataSourceName);
public List<Connection> getConnections(final String dataSourceName, final int connectionSize) throws SQLException {
return Collections.singletonList(getBackendConnection().getConnection(dataSourceName));
}

@Override
Expand Down

0 comments on commit f38351a

Please sign in to comment.