Skip to content

Commit

Permalink
Set Context.ROOT for Barrage subscription DoExchange (#1612)
Browse files Browse the repository at this point in the history
  • Loading branch information
devinrsmith authored Nov 30, 2021
1 parent b59e6fe commit 280bc14
Showing 1 changed file with 60 additions and 51 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -11,27 +11,29 @@
import io.deephaven.barrage.flatbuf.BarrageMessageWrapper;
import io.deephaven.barrage.flatbuf.BarrageSubscriptionRequest;
import io.deephaven.base.log.LogOutput;
import io.deephaven.chunk.ChunkType;
import io.deephaven.engine.liveness.ReferenceCountedLivenessNode;
import io.deephaven.engine.rowset.RowSet;
import io.deephaven.engine.table.TableDefinition;
import io.deephaven.engine.table.impl.util.BarrageMessage;
import io.deephaven.engine.table.impl.util.BarrageMessage.Listener;
import io.deephaven.extensions.barrage.BarrageSubscriptionOptions;
import io.deephaven.extensions.barrage.table.BarrageTable;
import io.deephaven.extensions.barrage.util.BarrageMessageConsumer;
import io.deephaven.extensions.barrage.util.BarrageProtoUtil;
import io.deephaven.extensions.barrage.util.BarrageStreamReader;
import io.deephaven.extensions.barrage.util.BarrageUtil;
import io.deephaven.engine.table.TableDefinition;
import io.deephaven.engine.liveness.ReferenceCountedLivenessNode;
import io.deephaven.chunk.ChunkType;
import io.deephaven.engine.table.impl.util.BarrageMessage;
import io.deephaven.internal.log.LoggerFactory;
import io.deephaven.io.logger.Logger;
import io.grpc.CallOptions;
import io.grpc.ClientCall;
import io.grpc.Context;
import io.grpc.MethodDescriptor;
import io.grpc.protobuf.ProtoUtils;
import io.grpc.stub.ClientCallStreamObserver;
import io.grpc.stub.ClientCalls;
import io.grpc.stub.ClientResponseObserver;
import org.apache.arrow.flight.impl.Flight;
import org.apache.arrow.flight.impl.Flight.FlightData;
import org.apache.arrow.flight.impl.FlightServiceGrpc;
import org.jetbrains.annotations.Nullable;

Expand All @@ -45,12 +47,12 @@ public class BarrageSubscriptionImpl extends ReferenceCountedLivenessNode implem
private final String logName;
private final TableHandle tableHandle;
private final BarrageSubscriptionOptions options;
private final ClientCall<Flight.FlightData, BarrageMessage> call;
private final ClientCallStreamObserver<FlightData> observer;

private BarrageTable resultTable;

private boolean subscribed = false;
private volatile boolean connected = false;
private volatile boolean connected = true;

/**
* Represents a BarrageSubscription.
Expand All @@ -71,59 +73,66 @@ public BarrageSubscriptionImpl(
resultTable = BarrageTable.make(tableDefinition, false);
resultTable.addParentReference(this);

final MethodDescriptor<Flight.FlightData, BarrageMessage> subscribeDescriptor =
final MethodDescriptor<FlightData, BarrageMessage> subscribeDescriptor =
getClientDoExchangeDescriptor(options, resultTable.getWireChunkTypes(), resultTable.getWireTypes(),
resultTable.getWireComponentTypes(), new BarrageStreamReader());

this.call = session.channel().newCall(subscribeDescriptor, CallOptions.DEFAULT);

ClientCalls.asyncBidiStreamingCall(call, new ClientResponseObserver<Flight.FlightData, BarrageMessage>() {
@Override
public void beforeStart(final ClientCallStreamObserver<Flight.FlightData> requestStream) {
requestStream.disableAutoInboundFlowControl();
}
// We need to ensure that the DoExchange RPC does not get attached to the server RPC when this is being called
// from a Deephaven server RPC thread. If we need to generalize this in the future, we may wrap this logic in a
// Channel or interceptor; inject the appropriate Context to use; or have the server RPC set a more appropriate
// Context along the stack.
final ClientCall<FlightData, BarrageMessage> call;
final Context previous = Context.ROOT.attach();
try {
call = session.channel().newCall(subscribeDescriptor, CallOptions.DEFAULT);
} finally {
Context.ROOT.detach(previous);
}
observer = (ClientCallStreamObserver<FlightData>) ClientCalls
.asyncBidiStreamingCall(call, new DoExchangeObserver());

@Override
public void onNext(final BarrageMessage barrageMessage) {
if (barrageMessage == null) {
return;
}
try {
final BarrageMessage.Listener listener = resultTable;
if (!connected || listener == null) {
return;
}
listener.handleBarrageMessage(barrageMessage);
} finally {
barrageMessage.close();
}
}
// Allow the server to send us all commands when there is sufficient bandwidth:
observer.request(Integer.MAX_VALUE);
}

@Override
public void onError(final Throwable t) {
log.error().append(BarrageSubscriptionImpl.this)
.append(": Error detected in subscription: ")
.append(t).endl();
private class DoExchangeObserver implements ClientResponseObserver<FlightData, BarrageMessage> {
@Override
public void beforeStart(final ClientCallStreamObserver<FlightData> requestStream) {
requestStream.disableAutoInboundFlowControl();
}

final BarrageMessage.Listener listener = resultTable;
@Override
public void onNext(final BarrageMessage barrageMessage) {
if (barrageMessage == null) {
return;
}
try (barrageMessage) {
final Listener listener = resultTable;
if (!connected || listener == null) {
return;
}
listener.handleBarrageError(t);
handleDisconnect();
listener.handleBarrageMessage(barrageMessage);
}
}

@Override
public void onCompleted() {
handleDisconnect();
@Override
public void onError(final Throwable t) {
log.error().append(BarrageSubscriptionImpl.this)
.append(": Error detected in subscription: ")
.append(t).endl();

final Listener listener = resultTable;
if (!connected || listener == null) {
return;
}
});

// Allow the server to send us all commands when there is sufficient bandwidth:
call.request(Integer.MAX_VALUE);
listener.handleBarrageError(t);
handleDisconnect();
}

// Although this is a white lie, the call is established
this.connected = true;
@Override
public void onCompleted() {
handleDisconnect();
}
}

@Override
Expand All @@ -134,7 +143,7 @@ public synchronized BarrageTable entireTable() {
}
if (!subscribed) {
// Send the initial subscription:
call.sendMessage(Flight.FlightData.newBuilder()
observer.onNext(FlightData.newBuilder()
.setAppMetadata(ByteStringAccess.wrap(makeRequestInternal(null, null, options)))
.build());
subscribed = true;
Expand Down Expand Up @@ -162,7 +171,7 @@ public synchronized void close() {
if (!connected) {
return;
}
call.halfClose();
observer.onCompleted();
cleanup();
}

Expand Down Expand Up @@ -245,15 +254,15 @@ public static <ReqT, RespT> MethodDescriptor<ReqT, RespT> descriptorFor(
* @param <Options> the options related to deserialization
* @return the client side method descriptor
*/
public static <Options> MethodDescriptor<Flight.FlightData, BarrageMessage> getClientDoExchangeDescriptor(
public static <Options> MethodDescriptor<FlightData, BarrageMessage> getClientDoExchangeDescriptor(
final Options options,
final ChunkType[] columnChunkTypes,
final Class<?>[] columnTypes,
final Class<?>[] componentTypes,
final BarrageMessageConsumer.StreamReader<Options> streamReader) {
return descriptorFor(
MethodDescriptor.MethodType.BIDI_STREAMING, FlightServiceGrpc.SERVICE_NAME, "DoExchange",
ProtoUtils.marshaller(Flight.FlightData.getDefaultInstance()),
ProtoUtils.marshaller(FlightData.getDefaultInstance()),
new BarrageDataMarshaller<>(options, columnChunkTypes, columnTypes, componentTypes, streamReader),
FlightServiceGrpc.getDoExchangeMethod());
}
Expand Down

0 comments on commit 280bc14

Please sign in to comment.