Skip to content

Commit

Permalink
Simplify ListSerializer to use a single hook
Browse files Browse the repository at this point in the history
ListSerializer really only needs a single hook that is invoked before
each value is written, and by making this hook an IntComsumer that
can inspect if the first value or subsequent values are being written,
there's no need for a between-values hook. To know if any values were
written, the ListSerializer now exposes its position which can be
queried after values are serialized.
  • Loading branch information
mtdowling committed Apr 10, 2024
1 parent c05a39a commit 2787a1c
Show file tree
Hide file tree
Showing 7 changed files with 85 additions and 87 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,12 @@
package software.amazon.smithy.java.runtime.core.serde;

import java.io.IOException;
import java.io.UncheckedIOException;
import java.math.BigDecimal;
import java.math.BigInteger;
import java.time.Instant;
import java.util.Objects;
import java.util.function.Consumer;
import java.util.function.IntConsumer;
import software.amazon.smithy.java.runtime.core.schema.SdkSchema;
import software.amazon.smithy.java.runtime.core.serde.document.Document;

Expand All @@ -22,28 +23,16 @@
public final class ListSerializer implements ShapeSerializer {

private final ShapeSerializer delegate;
private final ThrowableRunnable betweenValues;
private final ThrowableRunnable afterValues;
private boolean wroteValue = false;
private final IntConsumer beforeEachValue;
private int position = 0;

/**
* @param delegate Delegate that does the actual value serialization.
* @param betweenValues Method to invoke between each value.
* @param delegate Delegate that does the actual value serialization.
* @param beforeEachValue Invoked before each value and given the current position in the list.
*/
public ListSerializer(ShapeSerializer delegate, ThrowableRunnable betweenValues) {
this(delegate, betweenValues, () -> {
});
}

/**
* @param delegate Delegate that does the actual value serialization.
* @param betweenValues Method to invoke between each value.
* @param afterValues Method to invoke after each value.
*/
public ListSerializer(ShapeSerializer delegate, ThrowableRunnable betweenValues, ThrowableRunnable afterValues) {
this.delegate = delegate;
this.betweenValues = betweenValues;
this.afterValues = afterValues;
public ListSerializer(ShapeSerializer delegate, IntConsumer beforeEachValue) {
this.delegate = Objects.requireNonNull(delegate, "delegate is null");
this.beforeEachValue = Objects.requireNonNull(beforeEachValue, "beforeEachValue is null");
}

@Override
Expand All @@ -52,158 +41,111 @@ public void flush() throws IOException {
}

private void beforeWrite() {
try {
if (wroteValue) {
betweenValues.run();
} else {
wroteValue = true;
}
} catch (IOException e) {
throw new UncheckedIOException(e);
}
beforeEachValue.accept(position++);
}

private <T> T afterWrite(T result) {
if (afterValues != null) {
try {
afterValues.run();
} catch (IOException e) {
throw new UncheckedIOException(e);
}
}
return result;
/**
* Get the current index of the serializer where the next element would be written.
*
* @return the current index.
*/
public int position() {
return position;
}

@Override
public StructSerializer beginStruct(SdkSchema schema) {
beforeWrite();

if (afterValues == null) {
return delegate.beginStruct(schema);
} else {
// Wrap the structure serializer so that afterWrite callback can be invoked.
StructSerializer delegateSerializer = delegate.beginStruct(schema);
return afterWrite(new StructSerializer() {
@Override
public void endStruct() {
delegateSerializer.endStruct();
try {
afterValues.run();
} catch (IOException e) {
throw new UncheckedIOException(e);
}
}

@Override
public void member(SdkSchema member, Consumer<ShapeSerializer> memberWriter) {
delegateSerializer.member(member, memberWriter);
}
});
}
return delegate.beginStruct(schema);
}

@Override
public void beginList(SdkSchema schema, Consumer<ShapeSerializer> consumer) {
beforeWrite();
delegate.beginList(schema, consumer);
afterWrite(null);
}

@Override
public void beginMap(SdkSchema schema, Consumer<MapSerializer> consumer) {
beforeWrite();
delegate.beginMap(schema, consumer);
afterWrite(null);
}

@Override
public void writeBoolean(SdkSchema schema, boolean value) {
beforeWrite();
delegate.writeBoolean(schema, value);
afterWrite(null);
}

@Override
public void writeShort(SdkSchema schema, short value) {
beforeWrite();
delegate.writeShort(schema, value);
afterWrite(null);
}

@Override
public void writeByte(SdkSchema schema, byte value) {
beforeWrite();
delegate.writeByte(schema, value);
afterWrite(null);
}

@Override
public void writeInteger(SdkSchema schema, int value) {
beforeWrite();
delegate.writeInteger(schema, value);
afterWrite(null);
}

@Override
public void writeLong(SdkSchema schema, long value) {
beforeWrite();
delegate.writeLong(schema, value);
afterWrite(null);
}

@Override
public void writeFloat(SdkSchema schema, float value) {
beforeWrite();
delegate.writeFloat(schema, value);
afterWrite(null);
}

@Override
public void writeDouble(SdkSchema schema, double value) {
beforeWrite();
delegate.writeDouble(schema, value);
afterWrite(null);
}

@Override
public void writeBigInteger(SdkSchema schema, BigInteger value) {
beforeWrite();
delegate.writeBigInteger(schema, value);
afterWrite(null);
}

@Override
public void writeBigDecimal(SdkSchema schema, BigDecimal value) {
beforeWrite();
delegate.writeBigDecimal(schema, value);
afterWrite(null);
}

@Override
public void writeString(SdkSchema schema, String value) {
beforeWrite();
delegate.writeString(schema, value);
afterWrite(null);
}

@Override
public void writeBlob(SdkSchema schema, byte[] value) {
beforeWrite();
delegate.writeBlob(schema, value);
afterWrite(null);
}

@Override
public void writeTimestamp(SdkSchema schema, Instant value) {
beforeWrite();
delegate.writeTimestamp(schema, value);
afterWrite(null);
}

@Override
public void writeDocument(Document value) {
beforeWrite();
delegate.writeDocument(value);
afterWrite(null);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -111,10 +111,16 @@ public void member(SdkSchema member, Consumer<ShapeSerializer> memberWriter) {
@Override
public void beginList(SdkSchema schema, Consumer<ShapeSerializer> consumer) {
indent();
consumer.accept(new ListSerializer(this, () -> append(',').append(System.lineSeparator())));
consumer.accept(new ListSerializer(this, this::writeComma));
dedent();
}

private void writeComma(int position) {
if (position > 0) {
append(',').append(System.lineSeparator());
}
}

@Override
public void beginMap(SdkSchema schema, Consumer<MapSerializer> consumer) {
indent();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -53,10 +53,12 @@ public void member(SdkSchema member, Consumer<ShapeSerializer> memberWriter) {
public void beginList(SdkSchema schema, Consumer<ShapeSerializer> consumer) {
List<Document> elements = new ArrayList<>();
var elementParser = new DocumentParser();
ListSerializer serializer = new ListSerializer(elementParser, () -> {
elements.add(elementParser.result);
elementParser.result = null;
}, () -> {});
ListSerializer serializer = new ListSerializer(elementParser, position -> {
if (position > 0) {
elements.add(elementParser.result);
elementParser.result = null;
}
});
consumer.accept(serializer);
if (elementParser.result != null) {
elements.add(elementParser.result);
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
/*
* Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
* SPDX-License-Identifier: Apache-2.0
*/

package software.amazon.smithy.java.runtime.core.serde;

import static org.hamcrest.MatcherAssert.assertThat;
import static org.hamcrest.Matchers.contains;
import static org.hamcrest.Matchers.is;

import java.util.ArrayList;
import java.util.List;
import org.junit.jupiter.api.Test;
import software.amazon.smithy.java.runtime.core.schema.PreludeSchemas;
import software.amazon.smithy.java.runtime.core.schema.SdkSchema;

public class ListSerializerTest {
@Test
public void incrementsPosition() {
List<Integer> positions = new ArrayList<>();
List<String> strings = new ArrayList<>();

var delegate = new SpecificShapeSerializer() {
@Override
public void writeString(SdkSchema schema, String value) {
strings.add(value);
}
};

ListSerializer listSerializer = new ListSerializer(delegate, positions::add);

listSerializer.writeString(PreludeSchemas.STRING, "1");
listSerializer.writeString(PreludeSchemas.STRING, "2");

assertThat(positions, contains(0, 1));
assertThat(strings, contains("1", "2"));
assertThat(listSerializer.position(), is(2));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,7 @@ protected RuntimeException throwForInvalidState(SdkSchema schema) {

@Override
public void beginList(SdkSchema schema, Consumer<ShapeSerializer> consumer) {
consumer.accept(new ListSerializer(this, () -> {
}));
consumer.accept(new ListSerializer(this, position -> {}));
}

void writeHeader(SdkSchema schema, Supplier<String> supplier) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,7 @@ protected RuntimeException throwForInvalidState(SdkSchema schema) {

@Override
public void beginList(SdkSchema schema, Consumer<ShapeSerializer> consumer) {
consumer.accept(new ListSerializer(this, () -> {
}));
consumer.accept(new ListSerializer(this, position -> {}));
}

void writeQuery(SdkSchema schema, Supplier<String> supplier) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -147,7 +147,7 @@ public void writeString(SdkSchema schema, String value) {

@Override
public void writeTimestamp(SdkSchema schema, Instant value) {
var formatter = schema.hasTrait(TimestampFormatTrait.class)
var formatter = useTimestampFormat && schema.hasTrait(TimestampFormatTrait.class)
? TimestampFormatter.of(schema.getTrait(TimestampFormatTrait.class))
: defaultTimestampFormat;
formatter.serializeToUnderlyingFormat(schema, value, this);
Expand All @@ -167,13 +167,23 @@ public StructSerializer beginStruct(SdkSchema schema) {
public void beginList(SdkSchema schema, Consumer<ShapeSerializer> consumer) {
try {
stream.writeArrayStart();
consumer.accept(new ListSerializer(this, stream::writeMore));
consumer.accept(new ListSerializer(this, this::writeComma));
stream.writeArrayEnd();
} catch (IOException e) {
throw new UncheckedIOException(e);
}
}

private void writeComma(int position) {
if (position > 0) {
try {
stream.writeMore();
} catch (IOException e) {
throw new RuntimeException(e);
}
}
}

@Override
public void beginMap(SdkSchema schema, Consumer<MapSerializer> consumer) {
try {
Expand Down

0 comments on commit 2787a1c

Please sign in to comment.