diff --git a/.gitignore b/.gitignore index d163d47d..7183e3df 100644 --- a/.gitignore +++ b/.gitignore @@ -2,3 +2,5 @@ target .project .classpath .settings +.idea +*.iml \ No newline at end of file diff --git a/src/main/java/com/github/davidmoten/rx/jdbc/QuerySelectOnSubscribe.java b/src/main/java/com/github/davidmoten/rx/jdbc/QuerySelectOnSubscribe.java index 1b123cd9..f0fb1b60 100644 --- a/src/main/java/com/github/davidmoten/rx/jdbc/QuerySelectOnSubscribe.java +++ b/src/main/java/com/github/davidmoten/rx/jdbc/QuerySelectOnSubscribe.java @@ -10,6 +10,7 @@ import rx.Observable; import rx.Observable.OnSubscribe; import rx.Subscriber; +import rx.Subscription; import rx.functions.Action0; import rx.subscriptions.Subscriptions; @@ -107,7 +108,10 @@ private void connectAndPrepareStatement(Subscriber subscriber, State state.ps = state.con.prepareStatement(query.sql(), ResultSet.TYPE_FORWARD_ONLY, ResultSet.CONCUR_READ_ONLY); log.debug("setting parameters"); - Util.setParameters(state.ps, parameters, query.names()); + List subscriptions = Util.setParameters(state.ps, parameters, query.names()); + for (Subscription subscription : subscriptions) { + subscriber.add(subscription); + } } } diff --git a/src/main/java/com/github/davidmoten/rx/jdbc/QueryUpdateOnSubscribe.java b/src/main/java/com/github/davidmoten/rx/jdbc/QueryUpdateOnSubscribe.java index 0d6bd7d4..77d377d4 100644 --- a/src/main/java/com/github/davidmoten/rx/jdbc/QueryUpdateOnSubscribe.java +++ b/src/main/java/com/github/davidmoten/rx/jdbc/QueryUpdateOnSubscribe.java @@ -214,7 +214,10 @@ private void performUpdate(final Subscriber subscriber, State state) keysOption = Statement.NO_GENERATED_KEYS; } state.ps = state.con.prepareStatement(query.sql(), keysOption); - Util.setParameters(state.ps, parameters, query.names()); + List subscriptions = Util.setParameters(state.ps, parameters, query.names()); + for (Subscription subscription : subscriptions) { + subscriber.add(subscription); + } if (subscriber.isUnsubscribed()) return; diff --git a/src/main/java/com/github/davidmoten/rx/jdbc/Util.java b/src/main/java/com/github/davidmoten/rx/jdbc/Util.java index 533b9ddc..89aad334 100644 --- a/src/main/java/com/github/davidmoten/rx/jdbc/Util.java +++ b/src/main/java/com/github/davidmoten/rx/jdbc/Util.java @@ -12,6 +12,7 @@ import java.lang.reflect.Method; import java.math.BigDecimal; import java.math.BigInteger; +import java.sql.Array; import java.sql.Blob; import java.sql.Clob; import java.sql.Connection; @@ -33,10 +34,13 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import rx.Subscription; +import rx.functions.Action0; import rx.functions.Func1; import com.github.davidmoten.rx.jdbc.QuerySelect.Builder; import com.github.davidmoten.rx.jdbc.exceptions.SQLRuntimeException; +import rx.subscriptions.Subscriptions; /** * Utility methods. @@ -751,8 +755,9 @@ public int read(char[] cbuf, int off, int len) throws IOException { * @param params * @throws SQLException */ - static void setParameters(PreparedStatement ps, List params, boolean namesAllowed) + static List setParameters(PreparedStatement ps, List params, boolean namesAllowed) throws SQLException { + final List subscriptions = new ArrayList<>(); for (int i = 1; i <= params.size(); i++) { if (params.get(i - 1).hasName() && !namesAllowed) throw new SQLException("named parameter found but sql does not contain names"); @@ -787,16 +792,49 @@ else if (o == Database.NULL_BLOB) Calendar cal = Calendar.getInstance(); java.util.Date date = (java.util.Date) o; ps.setTimestamp(i, new java.sql.Timestamp(date.getTime()), cal); + } else if (cls.isArray() && !cls.getComponentType().isPrimitive()) { + Subscription subscription = configureArray(ps, i, (Object[]) o); + if (subscription != null) { + subscriptions.add(subscription); + } } else ps.setObject(i, o); } } catch (SQLException e) { log.debug("{} when setting ps.setObject({},{})", e.getMessage(), i, o); + for (Subscription subscription : subscriptions) { + subscription.unsubscribe(); + } + throw e; + } + } + + return subscriptions; + } + + private static Subscription configureArray(PreparedStatement ps, int i, Object[] o) throws SQLException { + if (String[].class.isAssignableFrom(o.getClass()) && !getDatabaseProductName(ps).equals("H2")) { + final Array array = ps.getConnection().createArrayOf("varchar", o); + Subscription subscription = Subscriptions.create(new ArrayFreeAction(array)); + + try { + ps.setArray(i, array); + } catch (SQLException e) { + array.free(); throw e; } + + return subscription; + } else { + ps.setObject(i, o); + return null; } } + private static String getDatabaseProductName(PreparedStatement ps) throws SQLException { + return ps.getConnection().getMetaData().getDatabaseProductName(); + } + /** * Sets a blob parameter for the prepared statement. * @@ -905,8 +943,8 @@ static ResultSetMapper toOne() { return ResultSetMapperToOne.INSTANCE; } - public static void setNamedParameters(PreparedStatement ps, List parameters, - List names) throws SQLException { + public static List setNamedParameters(PreparedStatement ps, List parameters, + List names) throws SQLException { Map map = new HashMap(); for (Parameter p : parameters) { if (p.hasName()) { @@ -924,15 +962,32 @@ public static void setNamedParameters(PreparedStatement ps, List para Parameter p = map.get(name); list.add(p); } - Util.setParameters(ps, list, true); + return Util.setParameters(ps, list, true); } - static void setParameters(PreparedStatement ps, List parameters, List names) + static List setParameters(PreparedStatement ps, List parameters, List names) throws SQLException { if (names.isEmpty()) { - Util.setParameters(ps, parameters, false); + return Util.setParameters(ps, parameters, false); } else { - Util.setNamedParameters(ps, parameters, names); + return Util.setNamedParameters(ps, parameters, names); + } + } + + private static class ArrayFreeAction implements Action0 { + private final Array array; + + public ArrayFreeAction(Array array) { + this.array = array; + } + + @Override + public void call() { + try { + array.free(); + } catch (Exception e) { + // ignore + } } } } diff --git a/src/test/java/com/github/davidmoten/rx/jdbc/DatabaseCreator.java b/src/test/java/com/github/davidmoten/rx/jdbc/DatabaseCreator.java index 4aa739d7..ee8564cb 100644 --- a/src/test/java/com/github/davidmoten/rx/jdbc/DatabaseCreator.java +++ b/src/test/java/com/github/davidmoten/rx/jdbc/DatabaseCreator.java @@ -84,6 +84,10 @@ public static void createDatabase(Connection c) { c.prepareStatement( "create table note(id bigint auto_increment primary key, text varchar(255))") .execute(); + + c.prepareStatement( + "create table person_lines (name varchar(50) not null, lines array)") + .execute(); } catch (SQLException e) { throw new SQLRuntimeException(e); } diff --git a/src/test/java/com/github/davidmoten/rx/jdbc/DatabaseTestBase.java b/src/test/java/com/github/davidmoten/rx/jdbc/DatabaseTestBase.java index 90b8b9d7..4abe529c 100644 --- a/src/test/java/com/github/davidmoten/rx/jdbc/DatabaseTestBase.java +++ b/src/test/java/com/github/davidmoten/rx/jdbc/DatabaseTestBase.java @@ -775,6 +775,16 @@ public void testCalendarParameter() throws SQLException { assertEquals(0, t.getTime()); } + @Test + public void testStringArray() throws Exception { + Database db = db(); + String[] lines = new String[] {"123 Main St.", "Nowhere, USA"}; + + int actual = db.update("INSERT INTO person_lines (name, lines) VALUES (?, ?)") + .parameters("fred", lines).count().first().toBlocking().single(); + assertEquals(1, actual); + } + @Test public void testDatabaseBuilder() { Database.builder().connectionProvider(connectionProvider())