diff --git a/spring-data-jdbc/src/main/java/org/springframework/data/jdbc/core/JdbcAggregateTemplate.java b/spring-data-jdbc/src/main/java/org/springframework/data/jdbc/core/JdbcAggregateTemplate.java index dbef6d1e1a..6d9fb662df 100644 --- a/spring-data-jdbc/src/main/java/org/springframework/data/jdbc/core/JdbcAggregateTemplate.java +++ b/spring-data-jdbc/src/main/java/org/springframework/data/jdbc/core/JdbcAggregateTemplate.java @@ -16,6 +16,7 @@ package org.springframework.data.jdbc.core; import java.util.ArrayList; +import java.util.Collection; import java.util.Collections; import java.util.HashMap; import java.util.Iterator; @@ -23,6 +24,7 @@ import java.util.List; import java.util.Map; import java.util.Optional; +import java.util.function.Consumer; import java.util.function.Function; import java.util.stream.Collectors; import java.util.stream.StreamSupport; @@ -56,6 +58,7 @@ import org.springframework.lang.Nullable; import org.springframework.util.Assert; import org.springframework.util.ClassUtils; +import org.springframework.util.ObjectUtils; /** * {@link JdbcAggregateOperations} implementation, storing aggregates in and obtaining them from a JDBC data store. @@ -173,19 +176,8 @@ public T save(T instance) { @Override public List saveAll(Iterable instances) { - - Assert.notNull(instances, "Aggregate instances must not be null"); - - if (!instances.iterator().hasNext()) { - return Collections.emptyList(); - } - - List> entityAndChangeCreators = new ArrayList<>(); - for (T instance : instances) { - verifyIdProperty(instance); - entityAndChangeCreators.add(new EntityAndChangeCreator<>(instance, changeCreatorSelectorForSave(instance))); - } - return performSaveAll(entityAndChangeCreators); + return doWithBatch(instances, entity -> changeCreatorSelectorForSave(entity).apply(entity), this::verifyIdProperty, + this::performSaveAll); } /** @@ -206,21 +198,7 @@ public T insert(T instance) { @Override public List insertAll(Iterable instances) { - - Assert.notNull(instances, "Aggregate instances must not be null"); - - if (!instances.iterator().hasNext()) { - return Collections.emptyList(); - } - - List> entityAndChangeCreators = new ArrayList<>(); - for (T instance : instances) { - - Function> changeCreator = entity -> createInsertChange(prepareVersionForInsert(entity)); - EntityAndChangeCreator entityChange = new EntityAndChangeCreator<>(instance, changeCreator); - entityAndChangeCreators.add(entityChange); - } - return performSaveAll(entityAndChangeCreators); + return doWithBatch(instances, entity -> createInsertChange(prepareVersionForInsert(entity)), this::performSaveAll); } /** @@ -241,21 +219,35 @@ public T update(T instance) { @Override public List updateAll(Iterable instances) { + return doWithBatch(instances, entity -> createUpdateChange(prepareVersionForUpdate(entity)), this::performSaveAll); + } + + private List doWithBatch(Iterable iterable, Function> changeCreator, + Function>, List> performFunction) { + return doWithBatch(iterable, changeCreator, entity -> {}, performFunction); + } - Assert.notNull(instances, "Aggregate instances must not be null"); + private List doWithBatch(Iterable iterable, Function> changeCreator, + Consumer beforeEntityChange, Function>, List> performFunction) { - if (!instances.iterator().hasNext()) { + Assert.notNull(iterable, "Aggregate instances must not be null"); + + if (ObjectUtils.isEmpty(iterable)) { return Collections.emptyList(); } - List> entityAndChangeCreators = new ArrayList<>(); - for (T instance : instances) { + List> entityAndChangeCreators = new ArrayList<>( + iterable instanceof Collection c ? c.size() : 16); + + for (T instance : iterable) { + + beforeEntityChange.accept(instance); - Function> changeCreator = entity -> createUpdateChange(prepareVersionForUpdate(entity)); EntityAndChangeCreator entityChange = new EntityAndChangeCreator<>(instance, changeCreator); entityAndChangeCreators.add(entityChange); } - return performSaveAll(entityAndChangeCreators); + + return performFunction.apply(entityAndChangeCreators); } @Override diff --git a/spring-data-jdbc/src/main/java/org/springframework/data/jdbc/core/convert/DefaultDataAccessStrategy.java b/spring-data-jdbc/src/main/java/org/springframework/data/jdbc/core/convert/DefaultDataAccessStrategy.java index d3004c61a0..4d210d516d 100644 --- a/spring-data-jdbc/src/main/java/org/springframework/data/jdbc/core/convert/DefaultDataAccessStrategy.java +++ b/spring-data-jdbc/src/main/java/org/springframework/data/jdbc/core/convert/DefaultDataAccessStrategy.java @@ -272,12 +272,12 @@ public T findById(Object id, Class domainType) { } @Override - public Iterable findAll(Class domainType) { + public List findAll(Class domainType) { return operations.query(sql(domainType).getFindAll(), getEntityRowMapper(domainType)); } @Override - public Iterable findAllById(Iterable ids, Class domainType) { + public List findAllById(Iterable ids, Class domainType) { if (!ids.iterator().hasNext()) { return Collections.emptyList(); @@ -290,7 +290,7 @@ public Iterable findAllById(Iterable ids, Class domainType) { @Override @SuppressWarnings("unchecked") - public Iterable findAllByPath(Identifier identifier, + public List findAllByPath(Identifier identifier, PersistentPropertyPath propertyPath) { Assert.notNull(identifier, "identifier must not be null"); @@ -338,12 +338,12 @@ public boolean existsById(Object id, Class domainType) { } @Override - public Iterable findAll(Class domainType, Sort sort) { + public List findAll(Class domainType, Sort sort) { return operations.query(sql(domainType).getFindAll(sort), getEntityRowMapper(domainType)); } @Override - public Iterable findAll(Class domainType, Pageable pageable) { + public List findAll(Class domainType, Pageable pageable) { return operations.query(sql(domainType).getFindAll(pageable), getEntityRowMapper(domainType)); } @@ -361,7 +361,7 @@ public Optional findOne(Query query, Class domainType) { } @Override - public Iterable findAll(Query query, Class domainType) { + public List findAll(Query query, Class domainType) { MapSqlParameterSource parameterSource = new MapSqlParameterSource(); String sqlQuery = sql(domainType).selectByQuery(query, parameterSource); @@ -370,7 +370,7 @@ public Iterable findAll(Query query, Class domainType) { } @Override - public Iterable findAll(Query query, Class domainType, Pageable pageable) { + public List findAll(Query query, Class domainType, Pageable pageable) { MapSqlParameterSource parameterSource = new MapSqlParameterSource(); String sqlQuery = sql(domainType).selectByQuery(query, parameterSource, pageable); diff --git a/spring-data-jdbc/src/main/java/org/springframework/data/jdbc/mybatis/MyBatisDataAccessStrategy.java b/spring-data-jdbc/src/main/java/org/springframework/data/jdbc/mybatis/MyBatisDataAccessStrategy.java index 13dd732f42..3b8b8efd34 100644 --- a/spring-data-jdbc/src/main/java/org/springframework/data/jdbc/mybatis/MyBatisDataAccessStrategy.java +++ b/spring-data-jdbc/src/main/java/org/springframework/data/jdbc/mybatis/MyBatisDataAccessStrategy.java @@ -256,7 +256,7 @@ public T findById(Object id, Class domainType) { } @Override - public Iterable findAll(Class domainType) { + public List findAll(Class domainType) { String statement = namespace(domainType) + ".findAll"; MyBatisContext parameter = new MyBatisContext(null, null, domainType, Collections.emptyMap()); @@ -264,13 +264,13 @@ public Iterable findAll(Class domainType) { } @Override - public Iterable findAllById(Iterable ids, Class domainType) { + public List findAllById(Iterable ids, Class domainType) { return sqlSession().selectList(namespace(domainType) + ".findAllById", new MyBatisContext(ids, null, domainType, Collections.emptyMap())); } @Override - public Iterable findAllByPath(Identifier identifier, + public List findAllByPath(Identifier identifier, PersistentPropertyPath path) { String statementName = namespace(getOwnerTyp(path)) + ".findAllByPath-" + path.toDotPath(); @@ -288,7 +288,7 @@ public boolean existsById(Object id, Class domainType) { } @Override - public Iterable findAll(Class domainType, Sort sort) { + public List findAll(Class domainType, Sort sort) { Map additionalContext = new HashMap<>(); additionalContext.put("sort", sort); @@ -297,7 +297,7 @@ public Iterable findAll(Class domainType, Sort sort) { } @Override - public Iterable findAll(Class domainType, Pageable pageable) { + public List findAll(Class domainType, Pageable pageable) { Map additionalContext = new HashMap<>(); additionalContext.put("pageable", pageable); @@ -311,12 +311,12 @@ public Optional findOne(Query query, Class probeType) { } @Override - public Iterable findAll(Query query, Class probeType) { + public List findAll(Query query, Class probeType) { throw new UnsupportedOperationException("Not implemented"); } @Override - public Iterable findAll(Query query, Class probeType, Pageable pageable) { + public List findAll(Query query, Class probeType, Pageable pageable) { throw new UnsupportedOperationException("Not implemented"); }