diff --git a/src/main/java/net/openhft/chronicle/wire/AbstractMarshallableCfg.java b/src/main/java/net/openhft/chronicle/wire/AbstractMarshallableCfg.java index e524f73def..a4e637cb45 100644 --- a/src/main/java/net/openhft/chronicle/wire/AbstractMarshallableCfg.java +++ b/src/main/java/net/openhft/chronicle/wire/AbstractMarshallableCfg.java @@ -49,7 +49,7 @@ public void readMarshallable(@NotNull WireIn wire) throws IORuntimeException, In // Use the WireMarshaller to read the configuration // Field that are not present in the input are not touched. - wm.readMarshallable(this, wire, wm.defaultValue(), false); + wm.readMarshallable(this, wire, false); } /** @@ -69,7 +69,7 @@ public void writeMarshallable(@NotNull WireOut wire) throws InvalidMarshallableE // Use the WireMarshaller to write the configuration // Fields with a default value are not written - marshaller.writeMarshallable(this, wire, marshaller.defaultValue(), false); + marshaller.writeMarshallable(this, wire, false); } /** diff --git a/src/main/java/net/openhft/chronicle/wire/Marshallable.java b/src/main/java/net/openhft/chronicle/wire/Marshallable.java index c2c612aa63..6359f744d0 100644 --- a/src/main/java/net/openhft/chronicle/wire/Marshallable.java +++ b/src/main/java/net/openhft/chronicle/wire/Marshallable.java @@ -136,7 +136,7 @@ default void readMarshallable(@NotNull WireIn wire) throws IORuntimeException, I WireMarshaller wm = WIRE_MARSHALLER_CL.get(this.getClass()); // Delegate the reading process to the obtained WireMarshaller - wm.readMarshallable(this, wire, wm.defaultValue(), true); + wm.readMarshallable(this, wire, true); } /** diff --git a/src/main/java/net/openhft/chronicle/wire/WireMarshaller.java b/src/main/java/net/openhft/chronicle/wire/WireMarshaller.java index 4a08886463..360deea796 100644 --- a/src/main/java/net/openhft/chronicle/wire/WireMarshaller.java +++ b/src/main/java/net/openhft/chronicle/wire/WireMarshaller.java @@ -22,10 +22,10 @@ import net.openhft.chronicle.core.*; import net.openhft.chronicle.core.io.*; import net.openhft.chronicle.core.scoped.ScopedResource; -import net.openhft.chronicle.core.util.ClassLocal; import net.openhft.chronicle.core.util.ClassNotFoundRuntimeException; import net.openhft.chronicle.core.util.ObjectUtils; import net.openhft.chronicle.core.util.StringUtils; +import net.openhft.chronicle.core.util.ClassLocal; import net.openhft.chronicle.core.values.IntValue; import net.openhft.chronicle.core.values.LongValue; import org.jetbrains.annotations.NotNull; @@ -62,7 +62,7 @@ protected WireMarshaller(@NotNull Class tClass, @NotNull FieldAccess[] fields : WireMarshaller.of(tClass) ); - private WireMarshaller(@NotNull FieldAccess[] fields, boolean isLeaf, @Nullable T defaultValue) { + WireMarshaller(@NotNull FieldAccess[] fields, boolean isLeaf, @Nullable T defaultValue) { this.fields = fields; this.isLeaf = isLeaf; this.defaultValue = defaultValue; @@ -76,12 +76,14 @@ public static WireMarshaller of(@NotNull Class tClass) { if (tClass.isInterface() || (tClass.isEnum() && !DynamicEnum.class.isAssignableFrom(tClass))) return new WireMarshaller<>(tClass, NO_FIELDS, true); + T defaultObject = defaultValueForType(tClass); + @NotNull Map map = new LinkedHashMap<>(); getAllField(tClass, map); final FieldAccess[] fields = map.values().stream() // for Java 15+ strip "hidden" fields that can't be accessed in Java 15+ this way. .filter(field -> !(Jvm.isJava15Plus() && field.getName().matches("^.*\\$\\d+$"))) - .map(FieldAccess::create) + .map(field -> FieldAccess.create(field, defaultObject)) .toArray(FieldAccess[]::new); Map fieldCount = Stream.of(fields).collect(Collectors.groupingBy(f -> f.field.getName(), Collectors.counting())); fieldCount.forEach((n, c) -> { @@ -92,8 +94,8 @@ public static WireMarshaller of(@NotNull Class tClass) { .collect(Collectors.toList()); boolean isLeaf = collect.isEmpty(); return overridesUnexpectedFields(tClass) - ? new WireMarshallerForUnexpectedFields<>(tClass, fields, isLeaf) - : new WireMarshaller<>(tClass, fields, isLeaf); + ? new WireMarshallerForUnexpectedFields<>(fields, isLeaf, defaultObject) + : new WireMarshaller<>(fields, isLeaf, defaultObject); } protected static boolean leafable(FieldAccess c) { @@ -116,12 +118,14 @@ private static boolean overridesUnexpectedFields(Class tClass) { @NotNull private static WireMarshaller ofThrowable(@NotNull Class tClass) { + T defaultObject = defaultValueForType(tClass); + @NotNull Map map = new LinkedHashMap<>(); getAllField(tClass, map); final FieldAccess[] fields = map.values().stream() - .map(FieldAccess::create).toArray(FieldAccess[]::new); + .map(field -> FieldAccess.create(field, defaultObject)).toArray(FieldAccess[]::new); boolean isLeaf = false; - return new WireMarshaller<>(tClass, fields, isLeaf); + return new WireMarshaller<>(fields, isLeaf, defaultObject); } private static boolean isCollection(@NotNull Class c) { @@ -149,7 +153,7 @@ public static void getAllField(@NotNull Class clazz, @NotNull Map } } - private static T defaultValueForType(@NotNull Class tClass) { + static T defaultValueForType(@NotNull Class tClass) { // tClass = ObjectUtils.implementationToUse(tClass); if (ObjectUtils.isConcreteClass(tClass) && !tClass.getName().startsWith("java") @@ -279,65 +283,89 @@ public void writeMarshallable(T t, Bytes bytes) { } } + /** + * @see #writeMarshallable(Object, WireOut, boolean) + * @deprecated To be removed in x.26 + */ + @Deprecated(/* To be removed in x.26 */) + public void writeMarshallable(T t, @NotNull WireOut out, T ignored, boolean copy) throws InvalidMarshallableException { + writeMarshallable(t, out, copy); + } + /** * Writes the values of the fields from the provided object (DTO) to the output. Before writing, * the object is validated. The method also supports optional copying of the values * from the source object to a previous instance. * - * @param t Object whose field values are to be written. - * @param out Output destination where the field values are written to. - * @param previous Previous object to compare for optional copying. - * @param copy Flag indicating whether to copy values from the source object to the previous object. + * @param t Object whose field values are to be written. + * @param out Output destination where the field values are written to. + * @param copy Flag indicating whether to copy values from the source object to the previous object. * @throws InvalidMarshallableException If there's an error during marshalling. */ - public void writeMarshallable(T t, @NotNull WireOut out, T previous, boolean copy) throws InvalidMarshallableException { + public void writeMarshallable(T t, @NotNull WireOut out, boolean copy) throws InvalidMarshallableException { // Validate the object before writing ValidatableUtil.validate(t); try { // Iterate through all fields and write their values to the output for (@NotNull FieldAccess field : fields) { - field.write(t, out, previous, copy); + field.write(t, out, defaultValue, copy); } } catch (IllegalAccessException e) { throw new AssertionError(e); } } + /** + * @see #readMarshallable(Object, WireIn, boolean) + * @deprecated To be removed in x.26 + */ + @Deprecated(/* To be removed in x.26 */) + public void readMarshallable(T t, @NotNull WireIn in, T ignored, boolean overwrite) throws InvalidMarshallableException { + readMarshallable(t, in, overwrite); + } + /** * Reads and populates the DTO based on the provided input. The input order can be hinted. * After reading, the object is validated. * * @param t Object to populate with read values. * @param in Input source from which values are read. - * @param defaults Default values to use if a value isn't provided in the input. * @param overwrite Flag indicating whether to overwrite the existing value in the target object. * @throws InvalidMarshallableException If there is an error during marshalling. */ - public void readMarshallable(T t, @NotNull WireIn in, T defaults, boolean overwrite) throws InvalidMarshallableException { + public void readMarshallable(T t, @NotNull WireIn in, boolean overwrite) throws InvalidMarshallableException { // Choose the reading method based on the hint if (in.hintReadInputOrder()) - readMarshallableInputOrder(t, in, defaults, overwrite); + readMarshallableInputOrder(t, in, overwrite); else - readMarshallableDTOOrder(t, in, defaults, overwrite); + readMarshallableDTOOrder(t, in, overwrite); // Validate the object after reading ValidatableUtil.validate(t); } + /** + * @see #readMarshallableDTOOrder(Object, WireIn, boolean) + * @deprecated To be removed in x.26 + */ + @Deprecated(/* To be removed in x.26 */) + public void readMarshallableDTOOrder(T t, @NotNull WireIn in, T ignored, boolean overwrite) throws InvalidMarshallableException { + readMarshallableDTOOrder(t, in, overwrite); + } + /** * Reads and populates the DTO based on the provided order. * * @param t Target object to populate with read values. * @param in Input source from which values are read. - * @param defaults Default values to use if a value isn't provided in the input. * @param overwrite Flag indicating whether to overwrite the existing value in the target object. * @throws InvalidMarshallableException If there is an error during marshalling. */ - public void readMarshallableDTOOrder(T t, @NotNull WireIn in, T defaults, boolean overwrite) throws InvalidMarshallableException { + public void readMarshallableDTOOrder(T t, @NotNull WireIn in, boolean overwrite) throws InvalidMarshallableException { try { for (@NotNull FieldAccess field : fields) { ValueIn vin = in.read(field.key); - field.readValue(t, defaults, vin, overwrite); + field.readValue(t, defaultValue, vin, overwrite); } ValidatableUtil.validate(t); } catch (IllegalAccessException e) { @@ -345,16 +373,24 @@ public void readMarshallableDTOOrder(T t, @NotNull WireIn in, T defaults, boolea } } + /** + * @see #readMarshallableInputOrder(Object, WireIn, boolean) + * @deprecated To be removed in x.26 + */ + @Deprecated(/* To be removed in x.26 */) + public void readMarshallableInputOrder(T t, @NotNull WireIn in, T ignored, boolean overwrite) throws InvalidMarshallableException { + readMarshallableDTOOrder(t, in, overwrite); + } + /** * Reads and populates the DTO based on the input's order. * * @param t Target object to populate with read values. * @param in Input source from which values are read. - * @param defaults Default values to use if a value isn't provided in the input. * @param overwrite Flag indicating whether to overwrite the existing value in the target object. * @throws InvalidMarshallableException If there is an error during marshalling. */ - public void readMarshallableInputOrder(T t, @NotNull WireIn in, T defaults, boolean overwrite) throws InvalidMarshallableException { + public void readMarshallableInputOrder(T t, @NotNull WireIn in, boolean overwrite) throws InvalidMarshallableException { try (ScopedResource stlSb = Wires.acquireStringBuilderScoped()) { StringBuilder sb = stlSb.get(); @@ -367,13 +403,13 @@ public void readMarshallableInputOrder(T t, @NotNull WireIn in, T defaults, bool // Check if fields are present and in order if (more && matchesFieldName(sb, field)) { - field.readValue(t, defaults, in.getValueIn(), overwrite); + field.readValue(t, defaultValue, in.getValueIn(), overwrite); } else { // If not, copy default values for (; i < fields.length; i++) { FieldAccess field2 = fields[i]; - field2.copy(defaults, t); + field2.setDefaultValue(defaultValue, t); } if (vin == null || sb.length() <= 0) @@ -385,7 +421,7 @@ public void readMarshallableInputOrder(T t, @NotNull WireIn in, T defaults, bool if (fieldAccess == null) vin.skipValue(); else - fieldAccess.readValue(t, defaults, vin, overwrite); + fieldAccess.readValue(t, defaultValue, vin, overwrite); vin = in.read(sb); } while (in.hasMore()); @@ -485,15 +521,8 @@ public T defaultValue() { public void reset(T o) { try { - for (FieldAccess field : fields) { - if (field.isResettable()) { - Object value = field.field.get(o); - if (value != null) - Wires.reset(value); - } else { - field.copy(defaultValue, o); - } - } + for (FieldAccess field : fields) + field.setDefaultValue(defaultValue, o); } catch (IllegalAccessException e) { // should never happen as the types should match. throw new AssertionError(e); @@ -712,87 +741,91 @@ abstract static class FieldAccess { * @return FieldAccess object specific to the field type. */ @Nullable - public static Object create(@NotNull Field field) { + public static Object create(@NotNull Field field, @Nullable Object defaultObject) { Class type = field.getType(); - if (type.isArray()) { - if (type.getComponentType() == byte.class) - return new ByteArrayFieldAccess(field); - return new ArrayFieldAccess(field); - } - if (EnumSet.class.isAssignableFrom(type)) { - final Class componentType = extractClass(computeActualTypeArguments(EnumSet.class, field)[0]); - if (componentType == Object.class || Modifier.isAbstract(componentType.getModifiers())) - throw new RuntimeException("Could not get enum constant directory"); - - boolean isLeaf = !Throwable.class.isAssignableFrom(componentType) - && WIRE_MARSHALLER_CL.get(componentType).isLeaf; - try { - Object[] values = (Object[]) Jvm.getMethod(componentType, "values").invoke(componentType, null); - return new EnumSetFieldAccess(field, isLeaf, values, componentType); - } catch (IllegalAccessException | InvocationTargetException e) { - throw Jvm.rethrow(e); - } - } - if (Collection.class.isAssignableFrom(type)) - return CollectionFieldAccess.of(field); - if (Map.class.isAssignableFrom(type)) - return new MapFieldAccess(field); - - switch (type.getName()) { - case "boolean": - return new BooleanFieldAccess(field); - case "byte": { - LongConverter longConverter = acquireLongConverter(field); - if (longConverter != null) - return new ByteLongConverterFieldAccess(field, longConverter); - return new ByteFieldAccess(field); - } - case "char": { - LongConverter longConverter = acquireLongConverter(field); - if (longConverter != null) - return new CharLongConverterFieldAccess(field, longConverter); - return new CharFieldAccess(field); - } - case "short": { - LongConverter longConverter = acquireLongConverter(field); - if (longConverter != null) - return new ShortLongConverterFieldAccess(field, longConverter); - return new ShortFieldAccess(field); + try { + if (type.isArray()) { + if (type.getComponentType() == byte.class) + return new ByteArrayFieldAccess(field); + return new ArrayFieldAccess(field); } - case "int": { - LongConverter longConverter = acquireLongConverter(field); - if (longConverter != null) - return new IntLongConverterFieldAccess(field, longConverter); - return new IntegerFieldAccess(field); + if (EnumSet.class.isAssignableFrom(type)) { + final Class componentType = extractClass(computeActualTypeArguments(EnumSet.class, field)[0]); + if (componentType == Object.class || Modifier.isAbstract(componentType.getModifiers())) + throw new RuntimeException("Could not get enum constant directory"); + + boolean isLeaf = !Throwable.class.isAssignableFrom(componentType) + && WIRE_MARSHALLER_CL.get(componentType).isLeaf; + try { + Object[] values = (Object[]) Jvm.getMethod(componentType, "values").invoke(componentType, null); + return new EnumSetFieldAccess(field, isLeaf, values, componentType); + } catch (IllegalAccessException | InvocationTargetException e) { + throw Jvm.rethrow(e); + } } - case "float": - return new FloatFieldAccess(field); - case "long": { - LongConverter longConverter = acquireLongConverter(field); - - return longConverter == null - ? new LongFieldAccess(field) - : new LongConverterFieldAccess(field, longConverter); + if (Collection.class.isAssignableFrom(type)) + return CollectionFieldAccess.of(field, defaultObject); + if (Map.class.isAssignableFrom(type)) + return new MapFieldAccess(field, defaultObject); + + switch (type.getName()) { + case "boolean": + return new BooleanFieldAccess(field); + case "byte": { + LongConverter longConverter = acquireLongConverter(field); + if (longConverter != null) + return new ByteLongConverterFieldAccess(field, longConverter); + return new ByteFieldAccess(field); + } + case "char": { + LongConverter longConverter = acquireLongConverter(field); + if (longConverter != null) + return new CharLongConverterFieldAccess(field, longConverter); + return new CharFieldAccess(field); + } + case "short": { + LongConverter longConverter = acquireLongConverter(field); + if (longConverter != null) + return new ShortLongConverterFieldAccess(field, longConverter); + return new ShortFieldAccess(field); + } + case "int": { + LongConverter longConverter = acquireLongConverter(field); + if (longConverter != null) + return new IntLongConverterFieldAccess(field, longConverter); + return new IntegerFieldAccess(field); + } + case "float": + return new FloatFieldAccess(field); + case "long": { + LongConverter longConverter = acquireLongConverter(field); + + return longConverter == null + ? new LongFieldAccess(field) + : new LongConverterFieldAccess(field, longConverter); + } + case "double": + return new DoubleFieldAccess(field); + case "java.lang.String": + return new StringFieldAccess(field); + case "java.lang.StringBuilder": + return new StringBuilderFieldAccess(field, defaultObject); + case "net.openhft.chronicle.bytes.Bytes": + return new BytesFieldAccess(field); + default: + @Nullable Boolean isLeaf = null; + if (IntValue.class.isAssignableFrom(type)) + return new IntValueAccess(field); + if (LongValue.class.isAssignableFrom(type)) + return new LongValueAccess(field); + if (WireMarshaller.class.isAssignableFrom(type)) + isLeaf = WIRE_MARSHALLER_CL.get(type).isLeaf; + else if (isCollection(type)) + isLeaf = false; + return new ObjectFieldAccess(field, isLeaf, defaultObject); } - case "double": - return new DoubleFieldAccess(field); - case "java.lang.String": - return new StringFieldAccess(field); - case "java.lang.StringBuilder": - return new StringBuilderFieldAccess(field); - case "net.openhft.chronicle.bytes.Bytes": - return new BytesFieldAccess(field); - default: - @Nullable Boolean isLeaf = null; - if (IntValue.class.isAssignableFrom(type)) - return new IntValueAccess(field); - if (LongValue.class.isAssignableFrom(type)) - return new LongValueAccess(field); - if (WireMarshaller.class.isAssignableFrom(type)) - isLeaf = WIRE_MARSHALLER_CL.get(type).isLeaf; - else if (isCollection(type)) - isLeaf = false; - return new ObjectFieldAccess(field, isLeaf); + } catch (IllegalAccessException ex) { + throw Jvm.rethrow(ex); } } @@ -957,6 +990,18 @@ protected void readValue(Object o, Object defaults, ValueIn read, boolean overwr */ protected abstract void setValue(Object o, ValueIn read, boolean overwrite) throws IllegalAccessException; + /** + * Abstract method to reset the value of a field in an object to default value. + * The default value is the one present in objects of that class after no-argument constructor. + * Where possible, existing data structures should be preserved without reallocation to avoid garbage. + * + * @param defaultObject A reference unmodified instance of this class. + * @param o Object to reset the value in. + */ + protected void setDefaultValue(Object defaultObject, Object o) throws IllegalAccessException { + copy(defaultObject, o); + } + /** * Abstract method to convert the value of a field in an object to bytes. * @@ -980,11 +1025,6 @@ public boolean isEqual(Object o1, Object o2) { return false; } } - - protected boolean isResettable() { - return false; - } - } static class IntValueAccess extends FieldAccess { @@ -1046,13 +1086,13 @@ public void getAsBytes(Object o, Bytes bytes) { static class ObjectFieldAccess extends FieldAccess { private final Class type; private final AsMarshallable asMarshallable; - private final boolean resettable; + private final Object defaultValue; - ObjectFieldAccess(@NotNull Field field, Boolean isLeaf) { + ObjectFieldAccess(@NotNull Field field, Boolean isLeaf, Object defaultObject) throws IllegalAccessException { super(field, isLeaf); asMarshallable = Jvm.findAnnotation(field, AsMarshallable.class); type = field.getType(); - resettable = false; + defaultValue = defaultObject == null ? null : field.get(defaultObject); } @Override @@ -1109,15 +1149,27 @@ protected void setValue(Object o, @NotNull ValueIn read, boolean overwrite) thro } @Override - public void getAsBytes(Object o, @NotNull Bytes bytes) throws IllegalAccessException { - bytes.writeUtf8(String.valueOf(field.get(o))); + protected void setDefaultValue(Object ignored, Object o) { + if (defaultValue != null && defaultValue instanceof Resettable && !(defaultValue instanceof DynamicEnum)) { + Object existingValue = unsafeGetObject(o, offset); + + if (existingValue == defaultValue) + return; + + if (existingValue != null && existingValue.getClass() == defaultValue.getClass()) { + ((Marshallable) existingValue).reset(); + + return; + } + } + + unsafePutObject(o, offset, defaultValue); } @Override - protected boolean isResettable() { - return resettable; + public void getAsBytes(Object o, @NotNull Bytes bytes) throws IllegalAccessException { + bytes.writeUtf8(String.valueOf(field.get(o))); } - } static class StringFieldAccess extends FieldAccess { @@ -1143,9 +1195,11 @@ public void getAsBytes(Object o, @NotNull Bytes bytes) { } static class StringBuilderFieldAccess extends FieldAccess { + private StringBuilder defaultValue; - public StringBuilderFieldAccess(@NotNull Field field) { + public StringBuilderFieldAccess(@NotNull Field field, @Nullable Object defaultObject) throws IllegalAccessException { super(field, true); + this.defaultValue = defaultObject == null ? null : (StringBuilder) field.get(defaultObject); } @Override @@ -1156,7 +1210,7 @@ protected void getValue(Object o, @NotNull ValueOut write, Object previous) { @Override protected void setValue(Object o, @NotNull ValueIn read, boolean overwrite) { - @NotNull StringBuilder sb = unsafeGetObject(o, offset); + StringBuilder sb = unsafeGetObject(o, offset); if (sb == null) { sb = new StringBuilder(); unsafePutObject(o, offset, sb); @@ -1165,6 +1219,24 @@ protected void setValue(Object o, @NotNull ValueIn read, boolean overwrite) { unsafePutObject(o, offset, null); } + @Override + protected void setDefaultValue(Object ignored, Object o) { + if (defaultValue == null) { + unsafePutObject(o, offset, null); + return; + } + + StringBuilder sb = unsafeGetObject(o, offset); + if (sb == defaultValue) + return; + if (sb == null) { + sb = new StringBuilder(); + unsafePutObject(o, offset, sb); + } + sb.setLength(0); + sb.append(defaultValue); + } + @Override public void getAsBytes(Object o, @NotNull Bytes bytes) { bytes.writeUtf8((CharSequence) unsafeGetObject(o, offset)); @@ -1245,6 +1317,9 @@ public void getAsBytes(Object o, @NotNull Bytes bytes) throws IllegalAccessEx protected void copy(Object from, Object to) { Bytes fromBytes = unsafeGetObject(from, offset); Bytes toBytes = unsafeGetObject(to, offset); + if (fromBytes == toBytes) + return; + if (fromBytes == null) { unsafePutObject(to, offset, null); return; @@ -1499,12 +1574,14 @@ static class CollectionFieldAccess extends FieldAccess { private final Class componentType; private final Class type; private final BiConsumer sequenceGetter; + private final Collection defaultValue; - public CollectionFieldAccess(@NotNull Field field, Boolean isLeaf, @Nullable Supplier collectionSupplier, Class componentType, Class type) { + public CollectionFieldAccess(@NotNull Field field, Boolean isLeaf, @Nullable Supplier collectionSupplier, Class componentType, Class type, @Nullable Object defaultObject) throws IllegalAccessException { super(field, isLeaf); this.collectionSupplier = collectionSupplier == null ? newInstance() : collectionSupplier; this.componentType = componentType; this.type = type; + this.defaultValue = defaultObject == null ? null : (Collection) field.get(defaultObject); sequenceGetter = (o, out) -> { Collection coll; try { @@ -1533,7 +1610,7 @@ public CollectionFieldAccess(@NotNull Field field, Boolean isLeaf, @Nullable Sup } @NotNull - static FieldAccess of(@NotNull Field field) { + static FieldAccess of(@NotNull Field field, @Nullable Object defaultObject) throws IllegalAccessException { @Nullable final Supplier collectionSupplier; @NotNull final Class componentType; final Class type; @@ -1555,8 +1632,8 @@ else if (type == Set.class) } return componentType == String.class - ? new StringCollectionFieldAccess(field, true, collectionSupplier, type) - : new CollectionFieldAccess(field, isLeaf, collectionSupplier, componentType, type); + ? new StringCollectionFieldAccess(field, true, collectionSupplier, type, defaultObject) + : new CollectionFieldAccess(field, isLeaf, collectionSupplier, componentType, type, defaultObject); } private Supplier newInstance() { @@ -1618,6 +1695,22 @@ protected void setValue(Object o, ValueIn read, boolean overwrite) { throw new UnsupportedOperationException(); } + @Override + protected void setDefaultValue(Object ignored, Object o) { + // TODO very limited form of deep reset, check for immutability + if (defaultValue != null && defaultValue.isEmpty()) { + Collection coll = unsafeGetObject(o, offset); + if (coll == null) { + coll = collectionSupplier.get(); + unsafePutObject(o, offset, coll); + } + coll.clear(); + return; + } + + unsafePutObject(o, offset, defaultValue); + } + @Override public void getAsBytes(Object o, Bytes bytes) { throw new UnsupportedOperationException(); @@ -1633,6 +1726,7 @@ static class StringCollectionFieldAccess extends FieldAccess { @NotNull final Supplier collectionSupplier; private final Class type; + private final Collection defaultValue; @NotNull private final BiConsumer seqConsumer = (c, in2) -> { Bytes bytes = in2.wireIn().bytes(); @@ -1647,10 +1741,11 @@ static class StringCollectionFieldAccess extends FieldAccess { } }; - public StringCollectionFieldAccess(@NotNull Field field, Boolean isLeaf, @Nullable Supplier collectionSupplier, Class type) { + public StringCollectionFieldAccess(@NotNull Field field, Boolean isLeaf, @Nullable Supplier collectionSupplier, Class type, @Nullable Object defaultObject) throws IllegalAccessException { super(field, isLeaf); this.collectionSupplier = collectionSupplier == null ? newInstance() : collectionSupplier; this.type = type; + this.defaultValue = defaultObject == null ? null : (Collection) field.get(defaultObject); } private Supplier newInstance() { @@ -1679,16 +1774,20 @@ protected void getValue(Object o, @NotNull ValueOut write, Object previous) thro } @Override - protected void copy(Object from, Object to) throws IllegalAccessException { - Collection fromColl = (Collection) field.get(from); + protected void copy(Object from, Object to) { + Collection fromColl = unsafeGetObject(from, offset); if (fromColl == null) { - field.set(to, null); + unsafePutObject(to, offset, null); return; } - Collection coll = (Collection) field.get(to); + setValue(to, fromColl); + } + + private void setValue(Object to, Collection fromColl) { + Collection coll = unsafeGetObject(to, offset); if (coll == null) { coll = collectionSupplier.get(); - field.set(to, coll); + unsafePutObject(to, offset, coll); } coll.clear(); coll.addAll(fromColl); @@ -1714,6 +1813,19 @@ protected void setValue(Object o, ValueIn read, boolean overwrite) { throw new UnsupportedOperationException(); } + @Override + protected void setDefaultValue(Object ignored, Object o) { + if (defaultValue == null) { + unsafePutObject(o, offset, null); + return; + } + + Collection coll = unsafeGetObject(o, offset); + if (coll == defaultValue) + return; + setValue(o, coll); + } + @Override public void getAsBytes(Object o, Bytes bytes) { throw new UnsupportedOperationException(); @@ -1728,8 +1840,9 @@ static class MapFieldAccess extends FieldAccess { private final Class keyType; @NotNull private final Class valueType; + private final Map defaultValue; - MapFieldAccess(@NotNull Field field) { + MapFieldAccess(@NotNull Field field, @Nullable Object defaultObject) throws IllegalAccessException { super(field); type = field.getType(); if (type == Map.class) @@ -1742,6 +1855,7 @@ else if (type == SortedMap.class || type == NavigableMap.class) Type[] actualTypeArguments = computeActualTypeArguments(Map.class, field); keyType = extractClass(actualTypeArguments[0]); valueType = extractClass(actualTypeArguments[1]); + defaultValue = defaultObject == null ? null : (Map) field.get(defaultObject); } @NotNull @@ -1792,6 +1906,22 @@ protected void setValue(Object o, ValueIn read, boolean overwrite) { throw new UnsupportedOperationException(); } + @Override + protected void setDefaultValue(Object ignored, Object o) { + // TODO very limited form of deep reset, check for immutability + if (defaultValue != null && defaultValue.isEmpty()) { + Map map = unsafeGetObject(o, offset); + if (map == null) { + map = collectionSupplier.get(); + unsafePutObject(o, offset, map); + } + map.clear(); + return; + } + + unsafePutObject(o, offset, defaultValue); + } + @Override public void getAsBytes(Object o, Bytes bytes) { throw new UnsupportedOperationException(); diff --git a/src/main/java/net/openhft/chronicle/wire/WireMarshallerForUnexpectedFields.java b/src/main/java/net/openhft/chronicle/wire/WireMarshallerForUnexpectedFields.java index e143aa2c7f..3314e569f9 100644 --- a/src/main/java/net/openhft/chronicle/wire/WireMarshallerForUnexpectedFields.java +++ b/src/main/java/net/openhft/chronicle/wire/WireMarshallerForUnexpectedFields.java @@ -25,8 +25,14 @@ public class WireMarshallerForUnexpectedFields extends WireMarshaller { final CharSequenceObjectMap fieldMap; + /** @deprecated To be removed in x.26 */ + @Deprecated public WireMarshallerForUnexpectedFields(@NotNull Class tClass, @NotNull FieldAccess[] fields, boolean isLeaf) { - super(tClass, fields, isLeaf); + this(fields, isLeaf, defaultValueForType(tClass)); + } + + public WireMarshallerForUnexpectedFields(@NotNull FieldAccess[] fields, boolean isLeaf, T defaultValue) { + super(fields, isLeaf, defaultValue); fieldMap = new CharSequenceObjectMap<>(fields.length * 3); for (FieldAccess field : fields) { fieldMap.put(field.key.name().toString(), field); @@ -35,14 +41,14 @@ public WireMarshallerForUnexpectedFields(@NotNull Class tClass, @NotNull Fiel } @Override - public void readMarshallable(T t, @NotNull WireIn in, T defaults, boolean overwrite) throws InvalidMarshallableException { + public void readMarshallable(T t, @NotNull WireIn in, boolean overwrite) throws InvalidMarshallableException { try (ScopedResource stlSb = Wires.acquireStringBuilderScoped()) { ReadMarshallable rm = t instanceof ReadMarshallable ? (ReadMarshallable) t : null; StringBuilder sb = stlSb.get(); int next = 0; if (overwrite) { for (FieldAccess field : fields) { - field.copy(defaults, t); + field.copy(defaultValue(), t); } } while (in.hasMore()) { @@ -76,7 +82,7 @@ public void readMarshallable(T t, @NotNull WireIn in, T defaults, boolean overwr } } } else { - field.readValue(t, defaults, vin, overwrite); + field.readValue(t, defaultValue(), vin, overwrite); } if (pos >= in.bytes().readPosition()) { Jvm.warn().on(getClass(), "Failed to parse " + in.bytes()); diff --git a/src/main/java/net/openhft/chronicle/wire/Wires.java b/src/main/java/net/openhft/chronicle/wire/Wires.java index 3718f9f4a0..8d1b8eb9ef 100644 --- a/src/main/java/net/openhft/chronicle/wire/Wires.java +++ b/src/main/java/net/openhft/chronicle/wire/Wires.java @@ -486,7 +486,7 @@ public static void readMarshallable(@NotNull Object marshallable, @NotNull WireI public static void readMarshallable(Class clazz, @NotNull Object marshallable, @NotNull WireIn wire, boolean overwrite) throws InvalidMarshallableException { WireMarshaller wm = WireMarshaller.WIRE_MARSHALLER_CL.get(clazz == null ? marshallable.getClass() : clazz); - wm.readMarshallable(marshallable, wire, wm.defaultValue(), overwrite); + wm.readMarshallable(marshallable, wire, overwrite); } public static void writeMarshallable(@NotNull Object marshallable, @NotNull WireOut wire) throws InvalidMarshallableException { @@ -499,13 +499,13 @@ public static void writeMarshallable(@NotNull Object marshallable, @NotNull Wire if (writeDefault) marshaller.writeMarshallable(marshallable, wire); else - marshaller.writeMarshallable(marshallable, wire, marshaller.defaultValue(), false); + marshaller.writeMarshallable(marshallable, wire, false); } public static void writeMarshallable(@NotNull Object marshallable, @NotNull WireOut wire, @NotNull Object previous, boolean copy) throws InvalidMarshallableException { assert marshallable.getClass() == previous.getClass(); WireMarshaller wm = WireMarshaller.WIRE_MARSHALLER_CL.get(marshallable.getClass()); - wm.writeMarshallable(marshallable, wire, previous, copy); + wm.writeMarshallable(marshallable, wire, copy); } public static void writeKey(@NotNull Object marshallable, Bytes bytes) { @@ -640,7 +640,7 @@ public static E objectMap(ValueIn in, @Nullable E using, @Nullable Class cla Wire wire = wireSR.get(); WireMarshaller wm = WireMarshaller.WIRE_MARSHALLER_CL.get(aClass); wm.writeMarshallable(e, wire); - wm.readMarshallable(e2, wire, null, false); + wm.readMarshallable(e2, wire, false); } return e2; } diff --git a/src/test/java/net/openhft/chronicle/wire/WireResetTest.java b/src/test/java/net/openhft/chronicle/wire/WireResetTest.java index bb5741ab90..560c25756c 100644 --- a/src/test/java/net/openhft/chronicle/wire/WireResetTest.java +++ b/src/test/java/net/openhft/chronicle/wire/WireResetTest.java @@ -20,7 +20,6 @@ import net.openhft.chronicle.core.io.AbstractCloseable; import net.openhft.chronicle.core.io.Closeable; -import org.junit.Ignore; import org.junit.Test; import java.time.LocalDate; @@ -53,7 +52,6 @@ public void testEventAbstractCloseable() { } } - @Ignore("https://github.com/OpenHFT/Chronicle-Wire/issues/745") @Test //https://github.com/OpenHFT/Chronicle-Wire/issues/732 public void testDeepReset() { @@ -72,7 +70,7 @@ public void testDeepReset() { assertSame(identifier1, event1.identifier); assertNull(event1.identifier.id); assertTrue(event1.identifier.permissions.isEmpty()); - assertNull(event1.identifier.parent.id); + assertNull(event1.identifier.parent); assertTrue(event1.ids.isEmpty()); assertNull(event1.payload); @@ -94,7 +92,7 @@ public void testDeepReset() { assertSame(identifier1, event1.identifier); assertNull(event1.identifier.id); assertTrue(event1.identifier.permissions.isEmpty()); - assertNull(event1.identifier.parent.id); + assertNull(event1.identifier.parent); assertTrue(event1.ids.isEmpty()); assertNull(event1.payload);