From 53b6d684cf88e3da9590ae5321e0a8a5d08bc463 Mon Sep 17 00:00:00 2001 From: Jens Schauder Date: Tue, 23 Mar 2021 13:59:34 +0100 Subject: [PATCH] Properly convert primitive array arguments. Closes #945 Original pull request: #949. --- .../data/jdbc/core/convert/ArrayUtil.java | 93 +++++++++++++++++++ .../jdbc/core/convert/BasicJdbcConverter.java | 7 +- .../convert/BasicJdbcConverterUnitTests.java | 45 +++++++-- .../JdbcRepositoryIntegrationTests.java | 12 +++ .../jdbc/testing/TestDatabaseFeatures.java | 1 + 5 files changed, 148 insertions(+), 10 deletions(-) diff --git a/spring-data-jdbc/src/main/java/org/springframework/data/jdbc/core/convert/ArrayUtil.java b/spring-data-jdbc/src/main/java/org/springframework/data/jdbc/core/convert/ArrayUtil.java index e84cbad2ee..c206d302d4 100644 --- a/spring-data-jdbc/src/main/java/org/springframework/data/jdbc/core/convert/ArrayUtil.java +++ b/spring-data-jdbc/src/main/java/org/springframework/data/jdbc/core/convert/ArrayUtil.java @@ -42,4 +42,97 @@ static byte[] toPrimitiveByteArray(Byte[] byteArray) { } return bytes; } + + static Byte[] toObjectArray(byte[] primitiveArray) { + + Byte[] objects = new Byte[primitiveArray.length]; + for (int i = 0; i < primitiveArray.length; i++) { + objects[i] = primitiveArray[i]; + } + return objects; + } + + static Short[] toObjectArray(short[] primitiveArray) { + + Short[] objects = new Short[primitiveArray.length]; + for (int i = 0; i < primitiveArray.length; i++) { + objects[i] = primitiveArray[i]; + } + return objects; + } + + static Character[] toObjectArray(char[] primitiveArray) { + + Character[] objects = new Character[primitiveArray.length]; + for (int i = 0; i < primitiveArray.length; i++) { + objects[i] = primitiveArray[i]; + } + return objects; + } + + static Integer[] toObjectArray(int[] primitiveArray) { + + Integer[] objects = new Integer[primitiveArray.length]; + for (int i = 0; i < primitiveArray.length; i++) { + objects[i] = primitiveArray[i]; + } + return objects; + } + + static Long[] toObjectArray(long[] primitiveArray) { + + Long[] objects = new Long[primitiveArray.length]; + for (int i = 0; i < primitiveArray.length; i++) { + objects[i] = primitiveArray[i]; + } + return objects; + } + + static Float[] toObjectArray(float[] primitiveArray) { + + Float[] objects = new Float[primitiveArray.length]; + for (int i = 0; i < primitiveArray.length; i++) { + objects[i] = primitiveArray[i]; + } + return objects; + } + + static Double[] toObjectArray(double[] primitiveArray) { + + Double[] objects = new Double[primitiveArray.length]; + for (int i = 0; i < primitiveArray.length; i++) { + objects[i] = primitiveArray[i]; + } + return objects; + } + + static Object[] convertToObjectArray(Object unknownArray) { + + Class componentType = unknownArray.getClass().getComponentType(); + + if (componentType.isPrimitive()) { + if (componentType == byte.class) { + return toObjectArray((byte[]) unknownArray); + } + if (componentType == short.class) { + return toObjectArray((short[]) unknownArray); + } + if (componentType == char.class) { + return toObjectArray((char[]) unknownArray); + } + if (componentType == int.class) { + return toObjectArray((int[]) unknownArray); + } + if (componentType == long.class) { + return toObjectArray((long[]) unknownArray); + } + if (componentType == float.class) { + return toObjectArray((float[]) unknownArray); + } + if (componentType == double.class) { + return toObjectArray((double[]) unknownArray); + } + } + return (Object[]) unknownArray; + } } diff --git a/spring-data-jdbc/src/main/java/org/springframework/data/jdbc/core/convert/BasicJdbcConverter.java b/spring-data-jdbc/src/main/java/org/springframework/data/jdbc/core/convert/BasicJdbcConverter.java index 0a8bce9dc7..d8a5582350 100644 --- a/spring-data-jdbc/src/main/java/org/springframework/data/jdbc/core/convert/BasicJdbcConverter.java +++ b/spring-data-jdbc/src/main/java/org/springframework/data/jdbc/core/convert/BasicJdbcConverter.java @@ -24,7 +24,6 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; - import org.springframework.context.ApplicationContext; import org.springframework.context.ApplicationContextAware; import org.springframework.core.convert.ConverterNotFoundException; @@ -317,7 +316,9 @@ public JdbcValue writeJdbcValue(@Nullable Object value, Class columnType, int Class componentType = convertedValue.getClass().getComponentType(); if (componentType != byte.class && componentType != Byte.class) { - return JdbcValue.of(typeFactory.createArray((Object[]) convertedValue), JDBCType.ARRAY); + + Object[] objectArray = ArrayUtil.convertToObjectArray(convertedValue); + return JdbcValue.of(typeFactory.createArray(objectArray), JDBCType.ARRAY); } if (componentType == Byte.class) { @@ -333,7 +334,7 @@ private JdbcValue tryToConvertToJdbcValue(@Nullable Object value) { if (canWriteAsJdbcValue(value)) { Object converted = writeValue(value, ClassTypeInformation.from(JdbcValue.class)); - if(converted instanceof JdbcValue) { + if (converted instanceof JdbcValue) { return (JdbcValue) converted; } diff --git a/spring-data-jdbc/src/test/java/org/springframework/data/jdbc/core/convert/BasicJdbcConverterUnitTests.java b/spring-data-jdbc/src/test/java/org/springframework/data/jdbc/core/convert/BasicJdbcConverterUnitTests.java index 94f2f2a4a0..1eed43cfa8 100644 --- a/spring-data-jdbc/src/test/java/org/springframework/data/jdbc/core/convert/BasicJdbcConverterUnitTests.java +++ b/spring-data-jdbc/src/test/java/org/springframework/data/jdbc/core/convert/BasicJdbcConverterUnitTests.java @@ -16,9 +16,11 @@ package org.springframework.data.jdbc.core.convert; import static org.assertj.core.api.Assertions.*; +import static org.mockito.Mockito.*; import lombok.Data; +import java.sql.Array; import java.sql.Timestamp; import java.time.Instant; import java.time.LocalDate; @@ -35,8 +37,10 @@ import org.springframework.data.annotation.Id; import org.springframework.data.jdbc.core.mapping.AggregateReference; import org.springframework.data.jdbc.core.mapping.JdbcMappingContext; +import org.springframework.data.jdbc.support.JdbcUtil; import org.springframework.data.relational.core.mapping.RelationalPersistentEntity; import org.springframework.data.relational.core.mapping.RelationalPersistentProperty; +import org.springframework.data.relational.core.sql.IdentifierProcessing; import org.springframework.data.util.ClassTypeInformation; /** @@ -47,9 +51,15 @@ public class BasicJdbcConverterUnitTests { JdbcMappingContext context = new JdbcMappingContext(); - BasicJdbcConverter converter = new BasicJdbcConverter(context, (identifier, path) -> { - throw new UnsupportedOperationException(); - }); + StubbedJdbcTypeFactory typeFactory = new StubbedJdbcTypeFactory(); + BasicJdbcConverter converter = new BasicJdbcConverter( // + context, // + (identifier, path) -> { + throw new UnsupportedOperationException(); + }, // + new JdbcCustomConversions(), // + typeFactory, IdentifierProcessing.ANSI // + ); @Test // DATAJDBC-104, DATAJDBC-1384 public void testTargetTypesForPropertyType() { @@ -110,14 +120,25 @@ void conversionOfDateLikeValueAndBackYieldsOriginalValue() { LocalDateTime testLocalDateTime = LocalDateTime.of(2001, 2, 3, 4, 5, 6, 123456789); checkConversionToTimestampAndBack(softly, persistentEntity, "localDateTime", testLocalDateTime); checkConversionToTimestampAndBack(softly, persistentEntity, "localDate", LocalDate.of(2001, 2, 3)); - checkConversionToTimestampAndBack(softly, persistentEntity, "localTime", LocalTime.of(1, 2, 3,123456789)); - checkConversionToTimestampAndBack(softly, persistentEntity, "instant", testLocalDateTime.toInstant(ZoneOffset.UTC)); + checkConversionToTimestampAndBack(softly, persistentEntity, "localTime", LocalTime.of(1, 2, 3, 123456789)); + checkConversionToTimestampAndBack(softly, persistentEntity, "instant", + testLocalDateTime.toInstant(ZoneOffset.UTC)); }); } - private void checkConversionToTimestampAndBack(SoftAssertions softly, RelationalPersistentEntity persistentEntity, String propertyName, - Object value) { + @Test // #945 + void conversionOfPrimitiveArrays() { + + int[] ints = { 1, 2, 3, 4, 5 }; + JdbcValue converted = converter.writeJdbcValue(ints, ints.getClass(), JdbcUtil.sqlTypeFor(ints.getClass())); + + assertThat(converted.getValue()).isInstanceOf(Array.class); + assertThat(typeFactory.arraySource).containsExactly(1, 2, 3, 4, 5); + } + + private void checkConversionToTimestampAndBack(SoftAssertions softly, RelationalPersistentEntity persistentEntity, + String propertyName, Object value) { RelationalPersistentProperty property = persistentEntity.getRequiredPersistentProperty(propertyName); @@ -165,4 +186,14 @@ private enum SomeEnum { @SuppressWarnings("unused") private static class OtherEntity {} + + private static class StubbedJdbcTypeFactory implements JdbcTypeFactory { + public Object[] arraySource; + + @Override + public Array createArray(Object[] value) { + arraySource = value; + return mock(Array.class); + } + } } diff --git a/spring-data-jdbc/src/test/java/org/springframework/data/jdbc/repository/JdbcRepositoryIntegrationTests.java b/spring-data-jdbc/src/test/java/org/springframework/data/jdbc/repository/JdbcRepositoryIntegrationTests.java index 67d7a6b9a7..872e3e46fe 100644 --- a/spring-data-jdbc/src/test/java/org/springframework/data/jdbc/repository/JdbcRepositoryIntegrationTests.java +++ b/spring-data-jdbc/src/test/java/org/springframework/data/jdbc/repository/JdbcRepositoryIntegrationTests.java @@ -28,6 +28,7 @@ import java.util.ArrayList; import java.util.List; +import lombok.ToString; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.extension.ExtendWith; @@ -42,7 +43,9 @@ import org.springframework.data.jdbc.repository.query.Query; import org.springframework.data.jdbc.repository.support.JdbcRepositoryFactory; import org.springframework.data.jdbc.testing.AssumeFeatureTestExecutionListener; +import org.springframework.data.jdbc.testing.EnabledOnFeature; import org.springframework.data.jdbc.testing.TestConfiguration; +import org.springframework.data.jdbc.testing.TestDatabaseFeatures; import org.springframework.data.relational.core.mapping.event.AbstractRelationalEvent; import org.springframework.data.relational.core.mapping.event.AfterLoadEvent; import org.springframework.data.repository.CrudRepository; @@ -381,6 +384,12 @@ public void countByQueryDerivation() { assertThat(repository.countByName(one.getName())).isEqualTo(2); } + @Test // #945 + @EnabledOnFeature(TestDatabaseFeatures.Feature.IS_POSTGRES) + public void usePrimitiveArrayAsArgument() { + assertThat(repository.unnestPrimitive(new int[]{1, 2, 3})).containsExactly(1,2,3); + } + interface DummyEntityRepository extends CrudRepository { List findAllByNamedQuery(); @@ -406,6 +415,9 @@ interface DummyEntityRepository extends CrudRepository { boolean existsByName(String name); int countByName(String name); + + @Query("select unnest( :ids )") + List unnestPrimitive(@Param("ids") int[] ids); } @Configuration diff --git a/spring-data-jdbc/src/test/java/org/springframework/data/jdbc/testing/TestDatabaseFeatures.java b/spring-data-jdbc/src/test/java/org/springframework/data/jdbc/testing/TestDatabaseFeatures.java index 939abb991f..6e3e65a484 100644 --- a/spring-data-jdbc/src/test/java/org/springframework/data/jdbc/testing/TestDatabaseFeatures.java +++ b/spring-data-jdbc/src/test/java/org/springframework/data/jdbc/testing/TestDatabaseFeatures.java @@ -115,6 +115,7 @@ public enum Feature { SUPPORTS_ARRAYS(TestDatabaseFeatures::supportsArrays), // SUPPORTS_GENERATED_IDS_IN_REFERENCED_ENTITIES(TestDatabaseFeatures::supportsGeneratedIdsInReferencedEntities), // SUPPORTS_NANOSECOND_PRECISION(TestDatabaseFeatures::supportsNanosecondPrecision), // + IS_POSTGRES(f -> f.databaseIs(Database.PostgreSql)), // IS_HSQL(f -> f.databaseIs(Database.Hsql)); private final Consumer featureMethod;