Skip to content

Commit

Permalink
[Improve][API & Zeta] Using connector custom serializer encode/decode…
Browse files Browse the repository at this point in the history
… states

* API: Using DefaultSerializer as connector sink default serializer
* Zeta: Using connector custom serializer encode/decode states
  • Loading branch information
hailin0 committed Aug 7, 2023
1 parent 4f89c1d commit f14a1ee
Show file tree
Hide file tree
Showing 12 changed files with 129 additions and 75 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,9 @@ public byte[] serialize(T obj) throws IOException {

@Override
public T deserialize(byte[] serialized) throws IOException {
if (serialized == null) {
return null;
}
return SerializationUtils.deserialize(serialized);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@

import org.apache.seatunnel.api.common.PluginIdentifierInterface;
import org.apache.seatunnel.api.common.SeaTunnelPluginLifeCycle;
import org.apache.seatunnel.api.serialization.DefaultSerializer;
import org.apache.seatunnel.api.serialization.Serializer;
import org.apache.seatunnel.api.source.SeaTunnelJobAware;
import org.apache.seatunnel.api.table.type.SeaTunnelDataType;
Expand Down Expand Up @@ -84,7 +85,7 @@ default SinkWriter<IN, CommitInfoT, StateT> restoreWriter(
* @return Serializer of {@link StateT}
*/
default Optional<Serializer<StateT>> getWriterStateSerializer() {
return Optional.empty();
return Optional.of(new DefaultSerializer());
}

/**
Expand All @@ -104,7 +105,7 @@ default Optional<SinkCommitter<CommitInfoT>> createCommitter() throws IOExceptio
* @return Serializer of {@link CommitInfoT}
*/
default Optional<Serializer<CommitInfoT>> getCommitInfoSerializer() {
return Optional.empty();
return Optional.of(new DefaultSerializer());
}

/**
Expand All @@ -125,6 +126,6 @@ default Optional<Serializer<CommitInfoT>> getCommitInfoSerializer() {
* @return Serializer of {@link AggregatedCommitInfoT}
*/
default Optional<Serializer<AggregatedCommitInfoT>> getAggregatedCommitInfoSerializer() {
return Optional.empty();
return Optional.of(new DefaultSerializer());
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -552,7 +552,10 @@ private List<PhysicalVertex> getSourceTask(
.getJobId(),
taskLocation,
finalParallelismIndex,
f);
(PhysicalExecutionFlow<
SourceAction,
SourceConfig>)
f);
} else {
return new TransformSeaTunnelTask(
jobImmutableInformation
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
import org.apache.commons.collections4.CollectionUtils;

import com.hazelcast.cluster.Address;
import lombok.Getter;
import lombok.NonNull;
import lombok.extern.slf4j.Slf4j;

Expand All @@ -45,6 +46,7 @@
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ConcurrentHashMap;
Expand Down Expand Up @@ -75,6 +77,8 @@ public class SinkAggregatedCommitterTask<CommandInfoT, AggregatedCommitInfoT>
private final SinkAggregatedCommitter<CommandInfoT, AggregatedCommitInfoT> aggregatedCommitter;

private transient Serializer<AggregatedCommitInfoT> aggregatedCommitInfoSerializer;
@Getter private transient Serializer<CommandInfoT> commitInfoSerializer;

private Map<Long, Address> writerAddressMap;

private ConcurrentMap<Long, List<CommandInfoT>> commitInfoCache;
Expand Down Expand Up @@ -107,6 +111,7 @@ public void init() throws Exception {
this.writerAddressMap = new ConcurrentHashMap<>();
this.checkpointCommitInfoMap = new ConcurrentHashMap<>();
this.completableFuture = new CompletableFuture<>();
this.commitInfoSerializer = sink.getSink().getCommitInfoSerializer().get();
this.aggregatedCommitInfoSerializer =
sink.getSink().getAggregatedCommitInfoSerializer().get();
log.debug(
Expand Down Expand Up @@ -250,6 +255,7 @@ public void restoreState(List<ActionSubtaskState> actionStateList) throws Except
actionStateList.stream()
.map(ActionSubtaskState::getState)
.flatMap(Collection::stream)
.filter(Objects::nonNull)
.map(
bytes ->
sneaky(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,17 +18,19 @@
package org.apache.seatunnel.engine.server.task;

import org.apache.seatunnel.api.common.metrics.MetricsContext;
import org.apache.seatunnel.api.serialization.Serializer;
import org.apache.seatunnel.api.source.SourceSplit;
import org.apache.seatunnel.engine.core.dag.actions.SourceAction;
import org.apache.seatunnel.engine.server.dag.physical.config.SourceConfig;
import org.apache.seatunnel.engine.server.dag.physical.flow.Flow;
import org.apache.seatunnel.engine.server.dag.physical.flow.PhysicalExecutionFlow;
import org.apache.seatunnel.engine.server.execution.ProgressState;
import org.apache.seatunnel.engine.server.execution.TaskLocation;
import org.apache.seatunnel.engine.server.task.flow.SourceFlowLifeCycle;
import org.apache.seatunnel.engine.server.task.record.Barrier;

import com.hazelcast.logging.ILogger;
import com.hazelcast.logging.Logger;
import lombok.Getter;
import lombok.NonNull;

import java.util.List;
Expand All @@ -41,15 +43,24 @@ public class SourceSeaTunnelTask<T, SplitT extends SourceSplit> extends SeaTunne
private transient SeaTunnelSourceCollector<T> collector;

private transient Object checkpointLock;
@Getter private transient Serializer<SplitT> splitSerializer;
private final PhysicalExecutionFlow<SourceAction, SourceConfig> sourceFlow;

public SourceSeaTunnelTask(long jobID, TaskLocation taskID, int indexID, Flow executionFlow) {
public SourceSeaTunnelTask(
long jobID,
TaskLocation taskID,
int indexID,
PhysicalExecutionFlow<SourceAction, SourceConfig> executionFlow) {
super(jobID, taskID, indexID, executionFlow);
this.sourceFlow = executionFlow;
}

@Override
public void init() throws Exception {
super.init();
this.checkpointLock = new Object();
this.splitSerializer = sourceFlow.getAction().getSource().getSplitSerializer();

LOGGER.info("starting seatunnel source task, index " + indexID);
if (!(startFlowLifeCycle instanceof SourceFlowLifeCycle)) {
throw new TaskRuntimeException(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
import com.hazelcast.cluster.Address;
import com.hazelcast.spi.impl.operationservice.Operation;
import com.hazelcast.spi.impl.operationservice.impl.InvocationFuture;
import lombok.Getter;
import lombok.NonNull;
import lombok.extern.slf4j.Slf4j;

Expand Down Expand Up @@ -77,6 +78,7 @@ public class SourceSplitEnumeratorTask<SplitT extends SourceSplit> extends Coord
private SeaTunnelSplitEnumeratorContext<SplitT> enumeratorContext;

private Serializer<Serializable> enumeratorStateSerializer;
@Getter private Serializer<SplitT> splitSerializer;

private int maxReaderSize;
private Set<Long> unfinishedReaders;
Expand All @@ -102,6 +104,7 @@ public void init() throws Exception {
new SeaTunnelSplitEnumeratorContext<>(
this.source.getParallelism(), this, getMetricsContext());
enumeratorStateSerializer = this.source.getSource().getEnumeratorStateSerializer();
splitSerializer = this.source.getSource().getSplitSerializer();
taskMemberMapping = new ConcurrentHashMap<>();
taskIDToTaskLocationMapping = new ConcurrentHashMap<>();
taskIndexToTaskLocationMapping = new ConcurrentHashMap<>();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@
import org.apache.seatunnel.api.source.SourceEvent;
import org.apache.seatunnel.api.source.SourceSplit;
import org.apache.seatunnel.api.source.SourceSplitEnumerator;
import org.apache.seatunnel.common.utils.SerializationUtils;
import org.apache.seatunnel.engine.server.task.SourceSplitEnumeratorTask;
import org.apache.seatunnel.engine.server.task.operation.source.AssignSplitOperation;

Expand All @@ -31,6 +30,9 @@
import java.util.HashSet;
import java.util.List;
import java.util.Set;
import java.util.stream.Collectors;

import static org.apache.seatunnel.engine.common.utils.ExceptionUtil.sneaky;

@Slf4j
public class SeaTunnelSplitEnumeratorContext<SplitT extends SourceSplit>
Expand Down Expand Up @@ -67,22 +69,26 @@ public void assignSplit(int subtaskIndex, List<SplitT> splits) {
log.warn("No reader is obtained, skip this assign!");
return;
}

List<byte[]> splitBytes =
splits.stream()
.map(split -> sneaky(() -> task.getSplitSerializer().serialize(split)))
.collect(Collectors.toList());
task.getExecutionContext()
.sendToMember(
new AssignSplitOperation<>(
task.getTaskMemberLocationByIndex(subtaskIndex),
SerializationUtils.serialize(splits.toArray())),
task.getTaskMemberLocationByIndex(subtaskIndex), splitBytes),
task.getTaskMemberAddressByIndex(subtaskIndex))
.join();
}

@Override
public void signalNoMoreSplits(int subtaskIndex) {
List<byte[]> emptySplits = Collections.emptyList();
task.getExecutionContext()
.sendToMember(
new AssignSplitOperation<>(
task.getTaskMemberLocationByIndex(subtaskIndex),
SerializationUtils.serialize(Collections.emptyList().toArray())),
task.getTaskMemberLocationByIndex(subtaskIndex), emptySplits),
task.getTaskMemberAddressByIndex(subtaskIndex))
.join();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@
import org.apache.seatunnel.api.sink.SinkWriter;
import org.apache.seatunnel.api.table.event.SchemaChangeEvent;
import org.apache.seatunnel.api.table.type.Record;
import org.apache.seatunnel.common.utils.SerializationUtils;
import org.apache.seatunnel.engine.core.checkpoint.InternalCheckpointListener;
import org.apache.seatunnel.engine.core.dag.actions.SinkAction;
import org.apache.seatunnel.engine.server.checkpoint.ActionStateKey;
Expand All @@ -44,10 +43,10 @@

import java.io.IOException;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.List;
import java.util.Objects;
import java.util.Optional;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ExecutionException;
Expand All @@ -66,6 +65,7 @@ public class SinkFlowLifeCycle<T, CommitInfoT extends Serializable, AggregatedCo
private final SinkAction<T, StateT, CommitInfoT, AggregatedCommitInfoT> sinkAction;
private SinkWriter<T, CommitInfoT, StateT> writer;

private transient Optional<Serializer<CommitInfoT>> commitInfoSerializer;
private transient Optional<Serializer<StateT>> writerStateSerializer;

private final int indexID;
Expand Down Expand Up @@ -110,6 +110,7 @@ public SinkFlowLifeCycle(

@Override
public void init() throws Exception {
this.commitInfoSerializer = sinkAction.getSink().getCommitInfoSerializer();
this.writerStateSerializer = sinkAction.getSink().getWriterStateSerializer();
this.committer = sinkAction.getSink().createCommitter();
this.lastCommitInfo = Optional.empty();
Expand Down Expand Up @@ -167,7 +168,7 @@ public void received(Record<?> record) {
throw e;
}
List<StateT> states = writer.snapshotState(barrier.getId());
if (!writerStateSerializer.isPresent()) {
if (states == null || states.isEmpty()) {
runningTask.addState(
barrier, ActionStateKey.of(sinkAction), Collections.emptyList());
} else {
Expand All @@ -184,10 +185,14 @@ public void received(Record<?> record) {
runningTask
.getExecutionContext()
.sendToMember(
new SinkPrepareCommitOperation(
new SinkPrepareCommitOperation<CommitInfoT>(
barrier,
committerTaskLocation,
SerializationUtils.serialize(commitInfoT)),
commitInfoT == null
? null
: commitInfoSerializer
.get()
.serialize(commitInfoT)),
committerTaskAddress)
.join();
}
Expand Down Expand Up @@ -243,22 +248,13 @@ public void notifyCheckpointAborted(long checkpointId) throws Exception {

@Override
public void restoreState(List<ActionSubtaskState> actionStateList) throws Exception {
List<StateT> states = new ArrayList<>();
if (writerStateSerializer.isPresent()) {
states =
actionStateList.stream()
.filter(state -> writerStateSerializer.isPresent())
.map(ActionSubtaskState::getState)
.flatMap(Collection::stream)
.map(
bytes ->
sneaky(
() ->
writerStateSerializer
.get()
.deserialize(bytes)))
.collect(Collectors.toList());
}
List<StateT> states =
actionStateList.stream()
.map(ActionSubtaskState::getState)
.flatMap(Collection::stream)
.filter(Objects::nonNull)
.map(bytes -> sneaky(() -> writerStateSerializer.get().deserialize(bytes)))
.collect(Collectors.toList());
if (states.isEmpty()) {
this.writer =
sinkAction
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@
import org.apache.seatunnel.api.source.SourceReader;
import org.apache.seatunnel.api.source.SourceSplit;
import org.apache.seatunnel.api.table.type.Record;
import org.apache.seatunnel.common.utils.SerializationUtils;
import org.apache.seatunnel.engine.core.checkpoint.CheckpointType;
import org.apache.seatunnel.engine.core.checkpoint.InternalCheckpointListener;
import org.apache.seatunnel.engine.core.dag.actions.SourceAction;
Expand Down Expand Up @@ -59,7 +58,6 @@
import java.util.concurrent.atomic.AtomicReference;
import java.util.stream.Collectors;

import static org.apache.seatunnel.engine.common.utils.ExceptionUtil.sneaky;
import static org.apache.seatunnel.engine.server.task.AbstractTask.serializeStates;

@Slf4j
Expand Down Expand Up @@ -338,21 +336,17 @@ public void restoreState(List<ActionSubtaskState> actionStateList) throws Except
if (actionStateList.isEmpty()) {
return;
}
List<SplitT> splits =
List<byte[]> splits =
actionStateList.stream()
.map(ActionSubtaskState::getState)
.flatMap(Collection::stream)
.filter(Objects::nonNull)
.map(bytes -> sneaky(() -> splitSerializer.deserialize(bytes)))
.collect(Collectors.toList());
try {
runningTask
.getExecutionContext()
.sendToMember(
new RestoredSplitOperation(
enumeratorTaskLocation,
SerializationUtils.serialize(splits.toArray()),
indexID),
new RestoredSplitOperation(enumeratorTaskLocation, splits, indexID),
enumeratorTaskAddress)
.get();
} catch (InterruptedException | ExecutionException e) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@

package org.apache.seatunnel.engine.server.task.operation.sink;

import org.apache.seatunnel.common.utils.SerializationUtils;
import org.apache.seatunnel.engine.server.SeaTunnelServer;
import org.apache.seatunnel.engine.server.TaskExecutionService;
import org.apache.seatunnel.engine.server.execution.TaskLocation;
Expand All @@ -33,7 +32,7 @@
import java.io.IOException;

@NoArgsConstructor
public class SinkPrepareCommitOperation extends BarrierFlowOperation {
public class SinkPrepareCommitOperation<CommitInfoT> extends BarrierFlowOperation {
private byte[] commitInfos;

public SinkPrepareCommitOperation(
Expand Down Expand Up @@ -73,16 +72,26 @@ public int getClassId() {
public void run() throws Exception {
TaskExecutionService taskExecutionService =
((SeaTunnelServer) getService()).getTaskExecutionService();
SinkAggregatedCommitterTask<?, ?> committerTask =
SinkAggregatedCommitterTask<CommitInfoT, ?> committerTask =
taskExecutionService.getTask(taskLocation);
ClassLoader classLoader =
ClassLoader taskClassLoader =
taskExecutionService
.getExecutionContext(taskLocation.getTaskGroupLocation())
.getClassLoader();
ClassLoader mainClassLoader = Thread.currentThread().getContextClassLoader();

if (commitInfos != null) {
committerTask.receivedWriterCommitInfo(
barrier.getId(), SerializationUtils.deserialize(commitInfos, classLoader));
CommitInfoT deserializeCommitInfo = null;
try {
Thread.currentThread().setContextClassLoader(taskClassLoader);
deserializeCommitInfo =
committerTask.getCommitInfoSerializer().deserialize(commitInfos);
} finally {
Thread.currentThread().setContextClassLoader(mainClassLoader);
}
committerTask.receivedWriterCommitInfo(barrier.getId(), deserializeCommitInfo);
}

committerTask.triggerBarrier(barrier);
}
}
Loading

0 comments on commit f14a1ee

Please sign in to comment.