From cb3bf80876cb01e0233aa238a158fda46296bd9e Mon Sep 17 00:00:00 2001 From: Ilya Kaznacheev Date: Mon, 30 Oct 2023 17:32:09 +0300 Subject: [PATCH] Only deep reset if default object is of the same type as current and is not null Change API to remove redundant defaults param Revert "Disable deep resetting for object fields #745" This reverts commit 96bcb58e9ef4cca56845430488da47dbe780dabf. --- .../wire/AbstractMarshallableCfg.java | 4 +- .../openhft/chronicle/wire/Marshallable.java | 2 +- .../chronicle/wire/WireMarshaller.java | 406 ++++++++++++------ .../WireMarshallerForUnexpectedFields.java | 14 +- .../net/openhft/chronicle/wire/Wires.java | 8 +- .../openhft/chronicle/wire/WireResetTest.java | 6 +- 6 files changed, 287 insertions(+), 153 deletions(-) diff --git a/src/main/java/net/openhft/chronicle/wire/AbstractMarshallableCfg.java b/src/main/java/net/openhft/chronicle/wire/AbstractMarshallableCfg.java index e524f73def..a4e637cb45 100644 --- a/src/main/java/net/openhft/chronicle/wire/AbstractMarshallableCfg.java +++ b/src/main/java/net/openhft/chronicle/wire/AbstractMarshallableCfg.java @@ -49,7 +49,7 @@ public void readMarshallable(@NotNull WireIn wire) throws IORuntimeException, In // Use the WireMarshaller to read the configuration // Field that are not present in the input are not touched. - wm.readMarshallable(this, wire, wm.defaultValue(), false); + wm.readMarshallable(this, wire, false); } /** @@ -69,7 +69,7 @@ public void writeMarshallable(@NotNull WireOut wire) throws InvalidMarshallableE // Use the WireMarshaller to write the configuration // Fields with a default value are not written - marshaller.writeMarshallable(this, wire, marshaller.defaultValue(), false); + marshaller.writeMarshallable(this, wire, false); } /** diff --git a/src/main/java/net/openhft/chronicle/wire/Marshallable.java b/src/main/java/net/openhft/chronicle/wire/Marshallable.java index c2c612aa63..6359f744d0 100644 --- a/src/main/java/net/openhft/chronicle/wire/Marshallable.java +++ b/src/main/java/net/openhft/chronicle/wire/Marshallable.java @@ -136,7 +136,7 @@ default void readMarshallable(@NotNull WireIn wire) throws IORuntimeException, I WireMarshaller wm = WIRE_MARSHALLER_CL.get(this.getClass()); // Delegate the reading process to the obtained WireMarshaller - wm.readMarshallable(this, wire, wm.defaultValue(), true); + wm.readMarshallable(this, wire, true); } /** diff --git a/src/main/java/net/openhft/chronicle/wire/WireMarshaller.java b/src/main/java/net/openhft/chronicle/wire/WireMarshaller.java index 4a08886463..360deea796 100644 --- a/src/main/java/net/openhft/chronicle/wire/WireMarshaller.java +++ b/src/main/java/net/openhft/chronicle/wire/WireMarshaller.java @@ -22,10 +22,10 @@ import net.openhft.chronicle.core.*; import net.openhft.chronicle.core.io.*; import net.openhft.chronicle.core.scoped.ScopedResource; -import net.openhft.chronicle.core.util.ClassLocal; import net.openhft.chronicle.core.util.ClassNotFoundRuntimeException; import net.openhft.chronicle.core.util.ObjectUtils; import net.openhft.chronicle.core.util.StringUtils; +import net.openhft.chronicle.core.util.ClassLocal; import net.openhft.chronicle.core.values.IntValue; import net.openhft.chronicle.core.values.LongValue; import org.jetbrains.annotations.NotNull; @@ -62,7 +62,7 @@ protected WireMarshaller(@NotNull Class tClass, @NotNull FieldAccess[] fields : WireMarshaller.of(tClass) ); - private WireMarshaller(@NotNull FieldAccess[] fields, boolean isLeaf, @Nullable T defaultValue) { + WireMarshaller(@NotNull FieldAccess[] fields, boolean isLeaf, @Nullable T defaultValue) { this.fields = fields; this.isLeaf = isLeaf; this.defaultValue = defaultValue; @@ -76,12 +76,14 @@ public static WireMarshaller of(@NotNull Class tClass) { if (tClass.isInterface() || (tClass.isEnum() && !DynamicEnum.class.isAssignableFrom(tClass))) return new WireMarshaller<>(tClass, NO_FIELDS, true); + T defaultObject = defaultValueForType(tClass); + @NotNull Map map = new LinkedHashMap<>(); getAllField(tClass, map); final FieldAccess[] fields = map.values().stream() // for Java 15+ strip "hidden" fields that can't be accessed in Java 15+ this way. .filter(field -> !(Jvm.isJava15Plus() && field.getName().matches("^.*\\$\\d+$"))) - .map(FieldAccess::create) + .map(field -> FieldAccess.create(field, defaultObject)) .toArray(FieldAccess[]::new); Map fieldCount = Stream.of(fields).collect(Collectors.groupingBy(f -> f.field.getName(), Collectors.counting())); fieldCount.forEach((n, c) -> { @@ -92,8 +94,8 @@ public static WireMarshaller of(@NotNull Class tClass) { .collect(Collectors.toList()); boolean isLeaf = collect.isEmpty(); return overridesUnexpectedFields(tClass) - ? new WireMarshallerForUnexpectedFields<>(tClass, fields, isLeaf) - : new WireMarshaller<>(tClass, fields, isLeaf); + ? new WireMarshallerForUnexpectedFields<>(fields, isLeaf, defaultObject) + : new WireMarshaller<>(fields, isLeaf, defaultObject); } protected static boolean leafable(FieldAccess c) { @@ -116,12 +118,14 @@ private static boolean overridesUnexpectedFields(Class tClass) { @NotNull private static WireMarshaller ofThrowable(@NotNull Class tClass) { + T defaultObject = defaultValueForType(tClass); + @NotNull Map map = new LinkedHashMap<>(); getAllField(tClass, map); final FieldAccess[] fields = map.values().stream() - .map(FieldAccess::create).toArray(FieldAccess[]::new); + .map(field -> FieldAccess.create(field, defaultObject)).toArray(FieldAccess[]::new); boolean isLeaf = false; - return new WireMarshaller<>(tClass, fields, isLeaf); + return new WireMarshaller<>(fields, isLeaf, defaultObject); } private static boolean isCollection(@NotNull Class c) { @@ -149,7 +153,7 @@ public static void getAllField(@NotNull Class clazz, @NotNull Map } } - private static T defaultValueForType(@NotNull Class tClass) { + static T defaultValueForType(@NotNull Class tClass) { // tClass = ObjectUtils.implementationToUse(tClass); if (ObjectUtils.isConcreteClass(tClass) && !tClass.getName().startsWith("java") @@ -279,65 +283,89 @@ public void writeMarshallable(T t, Bytes bytes) { } } + /** + * @see #writeMarshallable(Object, WireOut, boolean) + * @deprecated To be removed in x.26 + */ + @Deprecated(/* To be removed in x.26 */) + public void writeMarshallable(T t, @NotNull WireOut out, T ignored, boolean copy) throws InvalidMarshallableException { + writeMarshallable(t, out, copy); + } + /** * Writes the values of the fields from the provided object (DTO) to the output. Before writing, * the object is validated. The method also supports optional copying of the values * from the source object to a previous instance. * - * @param t Object whose field values are to be written. - * @param out Output destination where the field values are written to. - * @param previous Previous object to compare for optional copying. - * @param copy Flag indicating whether to copy values from the source object to the previous object. + * @param t Object whose field values are to be written. + * @param out Output destination where the field values are written to. + * @param copy Flag indicating whether to copy values from the source object to the previous object. * @throws InvalidMarshallableException If there's an error during marshalling. */ - public void writeMarshallable(T t, @NotNull WireOut out, T previous, boolean copy) throws InvalidMarshallableException { + public void writeMarshallable(T t, @NotNull WireOut out, boolean copy) throws InvalidMarshallableException { // Validate the object before writing ValidatableUtil.validate(t); try { // Iterate through all fields and write their values to the output for (@NotNull FieldAccess field : fields) { - field.write(t, out, previous, copy); + field.write(t, out, defaultValue, copy); } } catch (IllegalAccessException e) { throw new AssertionError(e); } } + /** + * @see #readMarshallable(Object, WireIn, boolean) + * @deprecated To be removed in x.26 + */ + @Deprecated(/* To be removed in x.26 */) + public void readMarshallable(T t, @NotNull WireIn in, T ignored, boolean overwrite) throws InvalidMarshallableException { + readMarshallable(t, in, overwrite); + } + /** * Reads and populates the DTO based on the provided input. The input order can be hinted. * After reading, the object is validated. * * @param t Object to populate with read values. * @param in Input source from which values are read. - * @param defaults Default values to use if a value isn't provided in the input. * @param overwrite Flag indicating whether to overwrite the existing value in the target object. * @throws InvalidMarshallableException If there is an error during marshalling. */ - public void readMarshallable(T t, @NotNull WireIn in, T defaults, boolean overwrite) throws InvalidMarshallableException { + public void readMarshallable(T t, @NotNull WireIn in, boolean overwrite) throws InvalidMarshallableException { // Choose the reading method based on the hint if (in.hintReadInputOrder()) - readMarshallableInputOrder(t, in, defaults, overwrite); + readMarshallableInputOrder(t, in, overwrite); else - readMarshallableDTOOrder(t, in, defaults, overwrite); + readMarshallableDTOOrder(t, in, overwrite); // Validate the object after reading ValidatableUtil.validate(t); } + /** + * @see #readMarshallableDTOOrder(Object, WireIn, boolean) + * @deprecated To be removed in x.26 + */ + @Deprecated(/* To be removed in x.26 */) + public void readMarshallableDTOOrder(T t, @NotNull WireIn in, T ignored, boolean overwrite) throws InvalidMarshallableException { + readMarshallableDTOOrder(t, in, overwrite); + } + /** * Reads and populates the DTO based on the provided order. * * @param t Target object to populate with read values. * @param in Input source from which values are read. - * @param defaults Default values to use if a value isn't provided in the input. * @param overwrite Flag indicating whether to overwrite the existing value in the target object. * @throws InvalidMarshallableException If there is an error during marshalling. */ - public void readMarshallableDTOOrder(T t, @NotNull WireIn in, T defaults, boolean overwrite) throws InvalidMarshallableException { + public void readMarshallableDTOOrder(T t, @NotNull WireIn in, boolean overwrite) throws InvalidMarshallableException { try { for (@NotNull FieldAccess field : fields) { ValueIn vin = in.read(field.key); - field.readValue(t, defaults, vin, overwrite); + field.readValue(t, defaultValue, vin, overwrite); } ValidatableUtil.validate(t); } catch (IllegalAccessException e) { @@ -345,16 +373,24 @@ public void readMarshallableDTOOrder(T t, @NotNull WireIn in, T defaults, boolea } } + /** + * @see #readMarshallableInputOrder(Object, WireIn, boolean) + * @deprecated To be removed in x.26 + */ + @Deprecated(/* To be removed in x.26 */) + public void readMarshallableInputOrder(T t, @NotNull WireIn in, T ignored, boolean overwrite) throws InvalidMarshallableException { + readMarshallableDTOOrder(t, in, overwrite); + } + /** * Reads and populates the DTO based on the input's order. * * @param t Target object to populate with read values. * @param in Input source from which values are read. - * @param defaults Default values to use if a value isn't provided in the input. * @param overwrite Flag indicating whether to overwrite the existing value in the target object. * @throws InvalidMarshallableException If there is an error during marshalling. */ - public void readMarshallableInputOrder(T t, @NotNull WireIn in, T defaults, boolean overwrite) throws InvalidMarshallableException { + public void readMarshallableInputOrder(T t, @NotNull WireIn in, boolean overwrite) throws InvalidMarshallableException { try (ScopedResource stlSb = Wires.acquireStringBuilderScoped()) { StringBuilder sb = stlSb.get(); @@ -367,13 +403,13 @@ public void readMarshallableInputOrder(T t, @NotNull WireIn in, T defaults, bool // Check if fields are present and in order if (more && matchesFieldName(sb, field)) { - field.readValue(t, defaults, in.getValueIn(), overwrite); + field.readValue(t, defaultValue, in.getValueIn(), overwrite); } else { // If not, copy default values for (; i < fields.length; i++) { FieldAccess field2 = fields[i]; - field2.copy(defaults, t); + field2.setDefaultValue(defaultValue, t); } if (vin == null || sb.length() <= 0) @@ -385,7 +421,7 @@ public void readMarshallableInputOrder(T t, @NotNull WireIn in, T defaults, bool if (fieldAccess == null) vin.skipValue(); else - fieldAccess.readValue(t, defaults, vin, overwrite); + fieldAccess.readValue(t, defaultValue, vin, overwrite); vin = in.read(sb); } while (in.hasMore()); @@ -485,15 +521,8 @@ public T defaultValue() { public void reset(T o) { try { - for (FieldAccess field : fields) { - if (field.isResettable()) { - Object value = field.field.get(o); - if (value != null) - Wires.reset(value); - } else { - field.copy(defaultValue, o); - } - } + for (FieldAccess field : fields) + field.setDefaultValue(defaultValue, o); } catch (IllegalAccessException e) { // should never happen as the types should match. throw new AssertionError(e); @@ -712,87 +741,91 @@ abstract static class FieldAccess { * @return FieldAccess object specific to the field type. */ @Nullable - public static Object create(@NotNull Field field) { + public static Object create(@NotNull Field field, @Nullable Object defaultObject) { Class type = field.getType(); - if (type.isArray()) { - if (type.getComponentType() == byte.class) - return new ByteArrayFieldAccess(field); - return new ArrayFieldAccess(field); - } - if (EnumSet.class.isAssignableFrom(type)) { - final Class componentType = extractClass(computeActualTypeArguments(EnumSet.class, field)[0]); - if (componentType == Object.class || Modifier.isAbstract(componentType.getModifiers())) - throw new RuntimeException("Could not get enum constant directory"); - - boolean isLeaf = !Throwable.class.isAssignableFrom(componentType) - && WIRE_MARSHALLER_CL.get(componentType).isLeaf; - try { - Object[] values = (Object[]) Jvm.getMethod(componentType, "values").invoke(componentType, null); - return new EnumSetFieldAccess(field, isLeaf, values, componentType); - } catch (IllegalAccessException | InvocationTargetException e) { - throw Jvm.rethrow(e); - } - } - if (Collection.class.isAssignableFrom(type)) - return CollectionFieldAccess.of(field); - if (Map.class.isAssignableFrom(type)) - return new MapFieldAccess(field); - - switch (type.getName()) { - case "boolean": - return new BooleanFieldAccess(field); - case "byte": { - LongConverter longConverter = acquireLongConverter(field); - if (longConverter != null) - return new ByteLongConverterFieldAccess(field, longConverter); - return new ByteFieldAccess(field); - } - case "char": { - LongConverter longConverter = acquireLongConverter(field); - if (longConverter != null) - return new CharLongConverterFieldAccess(field, longConverter); - return new CharFieldAccess(field); - } - case "short": { - LongConverter longConverter = acquireLongConverter(field); - if (longConverter != null) - return new ShortLongConverterFieldAccess(field, longConverter); - return new ShortFieldAccess(field); + try { + if (type.isArray()) { + if (type.getComponentType() == byte.class) + return new ByteArrayFieldAccess(field); + return new ArrayFieldAccess(field); } - case "int": { - LongConverter longConverter = acquireLongConverter(field); - if (longConverter != null) - return new IntLongConverterFieldAccess(field, longConverter); - return new IntegerFieldAccess(field); + if (EnumSet.class.isAssignableFrom(type)) { + final Class componentType = extractClass(computeActualTypeArguments(EnumSet.class, field)[0]); + if (componentType == Object.class || Modifier.isAbstract(componentType.getModifiers())) + throw new RuntimeException("Could not get enum constant directory"); + + boolean isLeaf = !Throwable.class.isAssignableFrom(componentType) + && WIRE_MARSHALLER_CL.get(componentType).isLeaf; + try { + Object[] values = (Object[]) Jvm.getMethod(componentType, "values").invoke(componentType, null); + return new EnumSetFieldAccess(field, isLeaf, values, componentType); + } catch (IllegalAccessException | InvocationTargetException e) { + throw Jvm.rethrow(e); + } } - case "float": - return new FloatFieldAccess(field); - case "long": { - LongConverter longConverter = acquireLongConverter(field); - - return longConverter == null - ? new LongFieldAccess(field) - : new LongConverterFieldAccess(field, longConverter); + if (Collection.class.isAssignableFrom(type)) + return CollectionFieldAccess.of(field, defaultObject); + if (Map.class.isAssignableFrom(type)) + return new MapFieldAccess(field, defaultObject); + + switch (type.getName()) { + case "boolean": + return new BooleanFieldAccess(field); + case "byte": { + LongConverter longConverter = acquireLongConverter(field); + if (longConverter != null) + return new ByteLongConverterFieldAccess(field, longConverter); + return new ByteFieldAccess(field); + } + case "char": { + LongConverter longConverter = acquireLongConverter(field); + if (longConverter != null) + return new CharLongConverterFieldAccess(field, longConverter); + return new CharFieldAccess(field); + } + case "short": { + LongConverter longConverter = acquireLongConverter(field); + if (longConverter != null) + return new ShortLongConverterFieldAccess(field, longConverter); + return new ShortFieldAccess(field); + } + case "int": { + LongConverter longConverter = acquireLongConverter(field); + if (longConverter != null) + return new IntLongConverterFieldAccess(field, longConverter); + return new IntegerFieldAccess(field); + } + case "float": + return new FloatFieldAccess(field); + case "long": { + LongConverter longConverter = acquireLongConverter(field); + + return longConverter == null + ? new LongFieldAccess(field) + : new LongConverterFieldAccess(field, longConverter); + } + case "double": + return new DoubleFieldAccess(field); + case "java.lang.String": + return new StringFieldAccess(field); + case "java.lang.StringBuilder": + return new StringBuilderFieldAccess(field, defaultObject); + case "net.openhft.chronicle.bytes.Bytes": + return new BytesFieldAccess(field); + default: + @Nullable Boolean isLeaf = null; + if (IntValue.class.isAssignableFrom(type)) + return new IntValueAccess(field); + if (LongValue.class.isAssignableFrom(type)) + return new LongValueAccess(field); + if (WireMarshaller.class.isAssignableFrom(type)) + isLeaf = WIRE_MARSHALLER_CL.get(type).isLeaf; + else if (isCollection(type)) + isLeaf = false; + return new ObjectFieldAccess(field, isLeaf, defaultObject); } - case "double": - return new DoubleFieldAccess(field); - case "java.lang.String": - return new StringFieldAccess(field); - case "java.lang.StringBuilder": - return new StringBuilderFieldAccess(field); - case "net.openhft.chronicle.bytes.Bytes": - return new BytesFieldAccess(field); - default: - @Nullable Boolean isLeaf = null; - if (IntValue.class.isAssignableFrom(type)) - return new IntValueAccess(field); - if (LongValue.class.isAssignableFrom(type)) - return new LongValueAccess(field); - if (WireMarshaller.class.isAssignableFrom(type)) - isLeaf = WIRE_MARSHALLER_CL.get(type).isLeaf; - else if (isCollection(type)) - isLeaf = false; - return new ObjectFieldAccess(field, isLeaf); + } catch (IllegalAccessException ex) { + throw Jvm.rethrow(ex); } } @@ -957,6 +990,18 @@ protected void readValue(Object o, Object defaults, ValueIn read, boolean overwr */ protected abstract void setValue(Object o, ValueIn read, boolean overwrite) throws IllegalAccessException; + /** + * Abstract method to reset the value of a field in an object to default value. + * The default value is the one present in objects of that class after no-argument constructor. + * Where possible, existing data structures should be preserved without reallocation to avoid garbage. + * + * @param defaultObject A reference unmodified instance of this class. + * @param o Object to reset the value in. + */ + protected void setDefaultValue(Object defaultObject, Object o) throws IllegalAccessException { + copy(defaultObject, o); + } + /** * Abstract method to convert the value of a field in an object to bytes. * @@ -980,11 +1025,6 @@ public boolean isEqual(Object o1, Object o2) { return false; } } - - protected boolean isResettable() { - return false; - } - } static class IntValueAccess extends FieldAccess { @@ -1046,13 +1086,13 @@ public void getAsBytes(Object o, Bytes bytes) { static class ObjectFieldAccess extends FieldAccess { private final Class type; private final AsMarshallable asMarshallable; - private final boolean resettable; + private final Object defaultValue; - ObjectFieldAccess(@NotNull Field field, Boolean isLeaf) { + ObjectFieldAccess(@NotNull Field field, Boolean isLeaf, Object defaultObject) throws IllegalAccessException { super(field, isLeaf); asMarshallable = Jvm.findAnnotation(field, AsMarshallable.class); type = field.getType(); - resettable = false; + defaultValue = defaultObject == null ? null : field.get(defaultObject); } @Override @@ -1109,15 +1149,27 @@ protected void setValue(Object o, @NotNull ValueIn read, boolean overwrite) thro } @Override - public void getAsBytes(Object o, @NotNull Bytes bytes) throws IllegalAccessException { - bytes.writeUtf8(String.valueOf(field.get(o))); + protected void setDefaultValue(Object ignored, Object o) { + if (defaultValue != null && defaultValue instanceof Resettable && !(defaultValue instanceof DynamicEnum)) { + Object existingValue = unsafeGetObject(o, offset); + + if (existingValue == defaultValue) + return; + + if (existingValue != null && existingValue.getClass() == defaultValue.getClass()) { + ((Marshallable) existingValue).reset(); + + return; + } + } + + unsafePutObject(o, offset, defaultValue); } @Override - protected boolean isResettable() { - return resettable; + public void getAsBytes(Object o, @NotNull Bytes bytes) throws IllegalAccessException { + bytes.writeUtf8(String.valueOf(field.get(o))); } - } static class StringFieldAccess extends FieldAccess { @@ -1143,9 +1195,11 @@ public void getAsBytes(Object o, @NotNull Bytes bytes) { } static class StringBuilderFieldAccess extends FieldAccess { + private StringBuilder defaultValue; - public StringBuilderFieldAccess(@NotNull Field field) { + public StringBuilderFieldAccess(@NotNull Field field, @Nullable Object defaultObject) throws IllegalAccessException { super(field, true); + this.defaultValue = defaultObject == null ? null : (StringBuilder) field.get(defaultObject); } @Override @@ -1156,7 +1210,7 @@ protected void getValue(Object o, @NotNull ValueOut write, Object previous) { @Override protected void setValue(Object o, @NotNull ValueIn read, boolean overwrite) { - @NotNull StringBuilder sb = unsafeGetObject(o, offset); + StringBuilder sb = unsafeGetObject(o, offset); if (sb == null) { sb = new StringBuilder(); unsafePutObject(o, offset, sb); @@ -1165,6 +1219,24 @@ protected void setValue(Object o, @NotNull ValueIn read, boolean overwrite) { unsafePutObject(o, offset, null); } + @Override + protected void setDefaultValue(Object ignored, Object o) { + if (defaultValue == null) { + unsafePutObject(o, offset, null); + return; + } + + StringBuilder sb = unsafeGetObject(o, offset); + if (sb == defaultValue) + return; + if (sb == null) { + sb = new StringBuilder(); + unsafePutObject(o, offset, sb); + } + sb.setLength(0); + sb.append(defaultValue); + } + @Override public void getAsBytes(Object o, @NotNull Bytes bytes) { bytes.writeUtf8((CharSequence) unsafeGetObject(o, offset)); @@ -1245,6 +1317,9 @@ public void getAsBytes(Object o, @NotNull Bytes bytes) throws IllegalAccessEx protected void copy(Object from, Object to) { Bytes fromBytes = unsafeGetObject(from, offset); Bytes toBytes = unsafeGetObject(to, offset); + if (fromBytes == toBytes) + return; + if (fromBytes == null) { unsafePutObject(to, offset, null); return; @@ -1499,12 +1574,14 @@ static class CollectionFieldAccess extends FieldAccess { private final Class componentType; private final Class type; private final BiConsumer sequenceGetter; + private final Collection defaultValue; - public CollectionFieldAccess(@NotNull Field field, Boolean isLeaf, @Nullable Supplier collectionSupplier, Class componentType, Class type) { + public CollectionFieldAccess(@NotNull Field field, Boolean isLeaf, @Nullable Supplier collectionSupplier, Class componentType, Class type, @Nullable Object defaultObject) throws IllegalAccessException { super(field, isLeaf); this.collectionSupplier = collectionSupplier == null ? newInstance() : collectionSupplier; this.componentType = componentType; this.type = type; + this.defaultValue = defaultObject == null ? null : (Collection) field.get(defaultObject); sequenceGetter = (o, out) -> { Collection coll; try { @@ -1533,7 +1610,7 @@ public CollectionFieldAccess(@NotNull Field field, Boolean isLeaf, @Nullable Sup } @NotNull - static FieldAccess of(@NotNull Field field) { + static FieldAccess of(@NotNull Field field, @Nullable Object defaultObject) throws IllegalAccessException { @Nullable final Supplier collectionSupplier; @NotNull final Class componentType; final Class type; @@ -1555,8 +1632,8 @@ else if (type == Set.class) } return componentType == String.class - ? new StringCollectionFieldAccess(field, true, collectionSupplier, type) - : new CollectionFieldAccess(field, isLeaf, collectionSupplier, componentType, type); + ? new StringCollectionFieldAccess(field, true, collectionSupplier, type, defaultObject) + : new CollectionFieldAccess(field, isLeaf, collectionSupplier, componentType, type, defaultObject); } private Supplier newInstance() { @@ -1618,6 +1695,22 @@ protected void setValue(Object o, ValueIn read, boolean overwrite) { throw new UnsupportedOperationException(); } + @Override + protected void setDefaultValue(Object ignored, Object o) { + // TODO very limited form of deep reset, check for immutability + if (defaultValue != null && defaultValue.isEmpty()) { + Collection coll = unsafeGetObject(o, offset); + if (coll == null) { + coll = collectionSupplier.get(); + unsafePutObject(o, offset, coll); + } + coll.clear(); + return; + } + + unsafePutObject(o, offset, defaultValue); + } + @Override public void getAsBytes(Object o, Bytes bytes) { throw new UnsupportedOperationException(); @@ -1633,6 +1726,7 @@ static class StringCollectionFieldAccess extends FieldAccess { @NotNull final Supplier collectionSupplier; private final Class type; + private final Collection defaultValue; @NotNull private final BiConsumer seqConsumer = (c, in2) -> { Bytes bytes = in2.wireIn().bytes(); @@ -1647,10 +1741,11 @@ static class StringCollectionFieldAccess extends FieldAccess { } }; - public StringCollectionFieldAccess(@NotNull Field field, Boolean isLeaf, @Nullable Supplier collectionSupplier, Class type) { + public StringCollectionFieldAccess(@NotNull Field field, Boolean isLeaf, @Nullable Supplier collectionSupplier, Class type, @Nullable Object defaultObject) throws IllegalAccessException { super(field, isLeaf); this.collectionSupplier = collectionSupplier == null ? newInstance() : collectionSupplier; this.type = type; + this.defaultValue = defaultObject == null ? null : (Collection) field.get(defaultObject); } private Supplier newInstance() { @@ -1679,16 +1774,20 @@ protected void getValue(Object o, @NotNull ValueOut write, Object previous) thro } @Override - protected void copy(Object from, Object to) throws IllegalAccessException { - Collection fromColl = (Collection) field.get(from); + protected void copy(Object from, Object to) { + Collection fromColl = unsafeGetObject(from, offset); if (fromColl == null) { - field.set(to, null); + unsafePutObject(to, offset, null); return; } - Collection coll = (Collection) field.get(to); + setValue(to, fromColl); + } + + private void setValue(Object to, Collection fromColl) { + Collection coll = unsafeGetObject(to, offset); if (coll == null) { coll = collectionSupplier.get(); - field.set(to, coll); + unsafePutObject(to, offset, coll); } coll.clear(); coll.addAll(fromColl); @@ -1714,6 +1813,19 @@ protected void setValue(Object o, ValueIn read, boolean overwrite) { throw new UnsupportedOperationException(); } + @Override + protected void setDefaultValue(Object ignored, Object o) { + if (defaultValue == null) { + unsafePutObject(o, offset, null); + return; + } + + Collection coll = unsafeGetObject(o, offset); + if (coll == defaultValue) + return; + setValue(o, coll); + } + @Override public void getAsBytes(Object o, Bytes bytes) { throw new UnsupportedOperationException(); @@ -1728,8 +1840,9 @@ static class MapFieldAccess extends FieldAccess { private final Class keyType; @NotNull private final Class valueType; + private final Map defaultValue; - MapFieldAccess(@NotNull Field field) { + MapFieldAccess(@NotNull Field field, @Nullable Object defaultObject) throws IllegalAccessException { super(field); type = field.getType(); if (type == Map.class) @@ -1742,6 +1855,7 @@ else if (type == SortedMap.class || type == NavigableMap.class) Type[] actualTypeArguments = computeActualTypeArguments(Map.class, field); keyType = extractClass(actualTypeArguments[0]); valueType = extractClass(actualTypeArguments[1]); + defaultValue = defaultObject == null ? null : (Map) field.get(defaultObject); } @NotNull @@ -1792,6 +1906,22 @@ protected void setValue(Object o, ValueIn read, boolean overwrite) { throw new UnsupportedOperationException(); } + @Override + protected void setDefaultValue(Object ignored, Object o) { + // TODO very limited form of deep reset, check for immutability + if (defaultValue != null && defaultValue.isEmpty()) { + Map map = unsafeGetObject(o, offset); + if (map == null) { + map = collectionSupplier.get(); + unsafePutObject(o, offset, map); + } + map.clear(); + return; + } + + unsafePutObject(o, offset, defaultValue); + } + @Override public void getAsBytes(Object o, Bytes bytes) { throw new UnsupportedOperationException(); diff --git a/src/main/java/net/openhft/chronicle/wire/WireMarshallerForUnexpectedFields.java b/src/main/java/net/openhft/chronicle/wire/WireMarshallerForUnexpectedFields.java index e143aa2c7f..3314e569f9 100644 --- a/src/main/java/net/openhft/chronicle/wire/WireMarshallerForUnexpectedFields.java +++ b/src/main/java/net/openhft/chronicle/wire/WireMarshallerForUnexpectedFields.java @@ -25,8 +25,14 @@ public class WireMarshallerForUnexpectedFields extends WireMarshaller { final CharSequenceObjectMap fieldMap; + /** @deprecated To be removed in x.26 */ + @Deprecated public WireMarshallerForUnexpectedFields(@NotNull Class tClass, @NotNull FieldAccess[] fields, boolean isLeaf) { - super(tClass, fields, isLeaf); + this(fields, isLeaf, defaultValueForType(tClass)); + } + + public WireMarshallerForUnexpectedFields(@NotNull FieldAccess[] fields, boolean isLeaf, T defaultValue) { + super(fields, isLeaf, defaultValue); fieldMap = new CharSequenceObjectMap<>(fields.length * 3); for (FieldAccess field : fields) { fieldMap.put(field.key.name().toString(), field); @@ -35,14 +41,14 @@ public WireMarshallerForUnexpectedFields(@NotNull Class tClass, @NotNull Fiel } @Override - public void readMarshallable(T t, @NotNull WireIn in, T defaults, boolean overwrite) throws InvalidMarshallableException { + public void readMarshallable(T t, @NotNull WireIn in, boolean overwrite) throws InvalidMarshallableException { try (ScopedResource stlSb = Wires.acquireStringBuilderScoped()) { ReadMarshallable rm = t instanceof ReadMarshallable ? (ReadMarshallable) t : null; StringBuilder sb = stlSb.get(); int next = 0; if (overwrite) { for (FieldAccess field : fields) { - field.copy(defaults, t); + field.copy(defaultValue(), t); } } while (in.hasMore()) { @@ -76,7 +82,7 @@ public void readMarshallable(T t, @NotNull WireIn in, T defaults, boolean overwr } } } else { - field.readValue(t, defaults, vin, overwrite); + field.readValue(t, defaultValue(), vin, overwrite); } if (pos >= in.bytes().readPosition()) { Jvm.warn().on(getClass(), "Failed to parse " + in.bytes()); diff --git a/src/main/java/net/openhft/chronicle/wire/Wires.java b/src/main/java/net/openhft/chronicle/wire/Wires.java index 3718f9f4a0..8d1b8eb9ef 100644 --- a/src/main/java/net/openhft/chronicle/wire/Wires.java +++ b/src/main/java/net/openhft/chronicle/wire/Wires.java @@ -486,7 +486,7 @@ public static void readMarshallable(@NotNull Object marshallable, @NotNull WireI public static void readMarshallable(Class clazz, @NotNull Object marshallable, @NotNull WireIn wire, boolean overwrite) throws InvalidMarshallableException { WireMarshaller wm = WireMarshaller.WIRE_MARSHALLER_CL.get(clazz == null ? marshallable.getClass() : clazz); - wm.readMarshallable(marshallable, wire, wm.defaultValue(), overwrite); + wm.readMarshallable(marshallable, wire, overwrite); } public static void writeMarshallable(@NotNull Object marshallable, @NotNull WireOut wire) throws InvalidMarshallableException { @@ -499,13 +499,13 @@ public static void writeMarshallable(@NotNull Object marshallable, @NotNull Wire if (writeDefault) marshaller.writeMarshallable(marshallable, wire); else - marshaller.writeMarshallable(marshallable, wire, marshaller.defaultValue(), false); + marshaller.writeMarshallable(marshallable, wire, false); } public static void writeMarshallable(@NotNull Object marshallable, @NotNull WireOut wire, @NotNull Object previous, boolean copy) throws InvalidMarshallableException { assert marshallable.getClass() == previous.getClass(); WireMarshaller wm = WireMarshaller.WIRE_MARSHALLER_CL.get(marshallable.getClass()); - wm.writeMarshallable(marshallable, wire, previous, copy); + wm.writeMarshallable(marshallable, wire, copy); } public static void writeKey(@NotNull Object marshallable, Bytes bytes) { @@ -640,7 +640,7 @@ public static E objectMap(ValueIn in, @Nullable E using, @Nullable Class cla Wire wire = wireSR.get(); WireMarshaller wm = WireMarshaller.WIRE_MARSHALLER_CL.get(aClass); wm.writeMarshallable(e, wire); - wm.readMarshallable(e2, wire, null, false); + wm.readMarshallable(e2, wire, false); } return e2; } diff --git a/src/test/java/net/openhft/chronicle/wire/WireResetTest.java b/src/test/java/net/openhft/chronicle/wire/WireResetTest.java index bb5741ab90..560c25756c 100644 --- a/src/test/java/net/openhft/chronicle/wire/WireResetTest.java +++ b/src/test/java/net/openhft/chronicle/wire/WireResetTest.java @@ -20,7 +20,6 @@ import net.openhft.chronicle.core.io.AbstractCloseable; import net.openhft.chronicle.core.io.Closeable; -import org.junit.Ignore; import org.junit.Test; import java.time.LocalDate; @@ -53,7 +52,6 @@ public void testEventAbstractCloseable() { } } - @Ignore("https://github.com/OpenHFT/Chronicle-Wire/issues/745") @Test //https://github.com/OpenHFT/Chronicle-Wire/issues/732 public void testDeepReset() { @@ -72,7 +70,7 @@ public void testDeepReset() { assertSame(identifier1, event1.identifier); assertNull(event1.identifier.id); assertTrue(event1.identifier.permissions.isEmpty()); - assertNull(event1.identifier.parent.id); + assertNull(event1.identifier.parent); assertTrue(event1.ids.isEmpty()); assertNull(event1.payload); @@ -94,7 +92,7 @@ public void testDeepReset() { assertSame(identifier1, event1.identifier); assertNull(event1.identifier.id); assertTrue(event1.identifier.permissions.isEmpty()); - assertNull(event1.identifier.parent.id); + assertNull(event1.identifier.parent); assertTrue(event1.ids.isEmpty()); assertNull(event1.payload);