Skip to content

Commit

Permalink
Properly convert primitive array arguments.
Browse files Browse the repository at this point in the history
Closes #945
Original pull request: #949.
  • Loading branch information
schauder authored and mp911de committed Mar 29, 2021
1 parent 84ab063 commit 53b6d68
Show file tree
Hide file tree
Showing 5 changed files with 148 additions and 10 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -42,4 +42,97 @@ static byte[] toPrimitiveByteArray(Byte[] byteArray) {
}
return bytes;
}

static Byte[] toObjectArray(byte[] primitiveArray) {

Byte[] objects = new Byte[primitiveArray.length];
for (int i = 0; i < primitiveArray.length; i++) {
objects[i] = primitiveArray[i];
}
return objects;
}

static Short[] toObjectArray(short[] primitiveArray) {

Short[] objects = new Short[primitiveArray.length];
for (int i = 0; i < primitiveArray.length; i++) {
objects[i] = primitiveArray[i];
}
return objects;
}

static Character[] toObjectArray(char[] primitiveArray) {

Character[] objects = new Character[primitiveArray.length];
for (int i = 0; i < primitiveArray.length; i++) {
objects[i] = primitiveArray[i];
}
return objects;
}

static Integer[] toObjectArray(int[] primitiveArray) {

Integer[] objects = new Integer[primitiveArray.length];
for (int i = 0; i < primitiveArray.length; i++) {
objects[i] = primitiveArray[i];
}
return objects;
}

static Long[] toObjectArray(long[] primitiveArray) {

Long[] objects = new Long[primitiveArray.length];
for (int i = 0; i < primitiveArray.length; i++) {
objects[i] = primitiveArray[i];
}
return objects;
}

static Float[] toObjectArray(float[] primitiveArray) {

Float[] objects = new Float[primitiveArray.length];
for (int i = 0; i < primitiveArray.length; i++) {
objects[i] = primitiveArray[i];
}
return objects;
}

static Double[] toObjectArray(double[] primitiveArray) {

Double[] objects = new Double[primitiveArray.length];
for (int i = 0; i < primitiveArray.length; i++) {
objects[i] = primitiveArray[i];
}
return objects;
}

static Object[] convertToObjectArray(Object unknownArray) {

Class<?> componentType = unknownArray.getClass().getComponentType();

if (componentType.isPrimitive()) {
if (componentType == byte.class) {
return toObjectArray((byte[]) unknownArray);
}
if (componentType == short.class) {
return toObjectArray((short[]) unknownArray);
}
if (componentType == char.class) {
return toObjectArray((char[]) unknownArray);
}
if (componentType == int.class) {
return toObjectArray((int[]) unknownArray);
}
if (componentType == long.class) {
return toObjectArray((long[]) unknownArray);
}
if (componentType == float.class) {
return toObjectArray((float[]) unknownArray);
}
if (componentType == double.class) {
return toObjectArray((double[]) unknownArray);
}
}
return (Object[]) unknownArray;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import org.springframework.context.ApplicationContext;
import org.springframework.context.ApplicationContextAware;
import org.springframework.core.convert.ConverterNotFoundException;
Expand Down Expand Up @@ -317,7 +316,9 @@ public JdbcValue writeJdbcValue(@Nullable Object value, Class<?> columnType, int

Class<?> componentType = convertedValue.getClass().getComponentType();
if (componentType != byte.class && componentType != Byte.class) {
return JdbcValue.of(typeFactory.createArray((Object[]) convertedValue), JDBCType.ARRAY);

Object[] objectArray = ArrayUtil.convertToObjectArray(convertedValue);
return JdbcValue.of(typeFactory.createArray(objectArray), JDBCType.ARRAY);
}

if (componentType == Byte.class) {
Expand All @@ -333,7 +334,7 @@ private JdbcValue tryToConvertToJdbcValue(@Nullable Object value) {
if (canWriteAsJdbcValue(value)) {

Object converted = writeValue(value, ClassTypeInformation.from(JdbcValue.class));
if(converted instanceof JdbcValue) {
if (converted instanceof JdbcValue) {
return (JdbcValue) converted;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,11 @@
package org.springframework.data.jdbc.core.convert;

import static org.assertj.core.api.Assertions.*;
import static org.mockito.Mockito.*;

import lombok.Data;

import java.sql.Array;
import java.sql.Timestamp;
import java.time.Instant;
import java.time.LocalDate;
Expand All @@ -35,8 +37,10 @@
import org.springframework.data.annotation.Id;
import org.springframework.data.jdbc.core.mapping.AggregateReference;
import org.springframework.data.jdbc.core.mapping.JdbcMappingContext;
import org.springframework.data.jdbc.support.JdbcUtil;
import org.springframework.data.relational.core.mapping.RelationalPersistentEntity;
import org.springframework.data.relational.core.mapping.RelationalPersistentProperty;
import org.springframework.data.relational.core.sql.IdentifierProcessing;
import org.springframework.data.util.ClassTypeInformation;

/**
Expand All @@ -47,9 +51,15 @@
public class BasicJdbcConverterUnitTests {

JdbcMappingContext context = new JdbcMappingContext();
BasicJdbcConverter converter = new BasicJdbcConverter(context, (identifier, path) -> {
throw new UnsupportedOperationException();
});
StubbedJdbcTypeFactory typeFactory = new StubbedJdbcTypeFactory();
BasicJdbcConverter converter = new BasicJdbcConverter( //
context, //
(identifier, path) -> {
throw new UnsupportedOperationException();
}, //
new JdbcCustomConversions(), //
typeFactory, IdentifierProcessing.ANSI //
);

@Test // DATAJDBC-104, DATAJDBC-1384
public void testTargetTypesForPropertyType() {
Expand Down Expand Up @@ -110,14 +120,25 @@ void conversionOfDateLikeValueAndBackYieldsOriginalValue() {
LocalDateTime testLocalDateTime = LocalDateTime.of(2001, 2, 3, 4, 5, 6, 123456789);
checkConversionToTimestampAndBack(softly, persistentEntity, "localDateTime", testLocalDateTime);
checkConversionToTimestampAndBack(softly, persistentEntity, "localDate", LocalDate.of(2001, 2, 3));
checkConversionToTimestampAndBack(softly, persistentEntity, "localTime", LocalTime.of(1, 2, 3,123456789));
checkConversionToTimestampAndBack(softly, persistentEntity, "instant", testLocalDateTime.toInstant(ZoneOffset.UTC));
checkConversionToTimestampAndBack(softly, persistentEntity, "localTime", LocalTime.of(1, 2, 3, 123456789));
checkConversionToTimestampAndBack(softly, persistentEntity, "instant",
testLocalDateTime.toInstant(ZoneOffset.UTC));
});

}

private void checkConversionToTimestampAndBack(SoftAssertions softly, RelationalPersistentEntity<?> persistentEntity, String propertyName,
Object value) {
@Test // #945
void conversionOfPrimitiveArrays() {

int[] ints = { 1, 2, 3, 4, 5 };
JdbcValue converted = converter.writeJdbcValue(ints, ints.getClass(), JdbcUtil.sqlTypeFor(ints.getClass()));

assertThat(converted.getValue()).isInstanceOf(Array.class);
assertThat(typeFactory.arraySource).containsExactly(1, 2, 3, 4, 5);
}

private void checkConversionToTimestampAndBack(SoftAssertions softly, RelationalPersistentEntity<?> persistentEntity,
String propertyName, Object value) {

RelationalPersistentProperty property = persistentEntity.getRequiredPersistentProperty(propertyName);

Expand Down Expand Up @@ -165,4 +186,14 @@ private enum SomeEnum {

@SuppressWarnings("unused")
private static class OtherEntity {}

private static class StubbedJdbcTypeFactory implements JdbcTypeFactory {
public Object[] arraySource;

@Override
public Array createArray(Object[] value) {
arraySource = value;
return mock(Array.class);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
import java.util.ArrayList;
import java.util.List;

import lombok.ToString;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith;
Expand All @@ -42,7 +43,9 @@
import org.springframework.data.jdbc.repository.query.Query;
import org.springframework.data.jdbc.repository.support.JdbcRepositoryFactory;
import org.springframework.data.jdbc.testing.AssumeFeatureTestExecutionListener;
import org.springframework.data.jdbc.testing.EnabledOnFeature;
import org.springframework.data.jdbc.testing.TestConfiguration;
import org.springframework.data.jdbc.testing.TestDatabaseFeatures;
import org.springframework.data.relational.core.mapping.event.AbstractRelationalEvent;
import org.springframework.data.relational.core.mapping.event.AfterLoadEvent;
import org.springframework.data.repository.CrudRepository;
Expand Down Expand Up @@ -381,6 +384,12 @@ public void countByQueryDerivation() {
assertThat(repository.countByName(one.getName())).isEqualTo(2);
}

@Test // #945
@EnabledOnFeature(TestDatabaseFeatures.Feature.IS_POSTGRES)
public void usePrimitiveArrayAsArgument() {
assertThat(repository.unnestPrimitive(new int[]{1, 2, 3})).containsExactly(1,2,3);
}

interface DummyEntityRepository extends CrudRepository<DummyEntity, Long> {

List<DummyEntity> findAllByNamedQuery();
Expand All @@ -406,6 +415,9 @@ interface DummyEntityRepository extends CrudRepository<DummyEntity, Long> {
boolean existsByName(String name);

int countByName(String name);

@Query("select unnest( :ids )")
List<Integer> unnestPrimitive(@Param("ids") int[] ids);
}

@Configuration
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,7 @@ public enum Feature {
SUPPORTS_ARRAYS(TestDatabaseFeatures::supportsArrays), //
SUPPORTS_GENERATED_IDS_IN_REFERENCED_ENTITIES(TestDatabaseFeatures::supportsGeneratedIdsInReferencedEntities), //
SUPPORTS_NANOSECOND_PRECISION(TestDatabaseFeatures::supportsNanosecondPrecision), //
IS_POSTGRES(f -> f.databaseIs(Database.PostgreSql)), //
IS_HSQL(f -> f.databaseIs(Database.Hsql));

private final Consumer<TestDatabaseFeatures> featureMethod;
Expand Down

0 comments on commit 53b6d68

Please sign in to comment.