diff --git a/java-client/barrage/src/main/java/io/deephaven/client/impl/BarrageSubscriptionImpl.java b/java-client/barrage/src/main/java/io/deephaven/client/impl/BarrageSubscriptionImpl.java index 2c3d6066d2a..8f53607335d 100644 --- a/java-client/barrage/src/main/java/io/deephaven/client/impl/BarrageSubscriptionImpl.java +++ b/java-client/barrage/src/main/java/io/deephaven/client/impl/BarrageSubscriptionImpl.java @@ -11,27 +11,29 @@ import io.deephaven.barrage.flatbuf.BarrageMessageWrapper; import io.deephaven.barrage.flatbuf.BarrageSubscriptionRequest; import io.deephaven.base.log.LogOutput; +import io.deephaven.chunk.ChunkType; +import io.deephaven.engine.liveness.ReferenceCountedLivenessNode; import io.deephaven.engine.rowset.RowSet; +import io.deephaven.engine.table.TableDefinition; +import io.deephaven.engine.table.impl.util.BarrageMessage; +import io.deephaven.engine.table.impl.util.BarrageMessage.Listener; import io.deephaven.extensions.barrage.BarrageSubscriptionOptions; import io.deephaven.extensions.barrage.table.BarrageTable; import io.deephaven.extensions.barrage.util.BarrageMessageConsumer; import io.deephaven.extensions.barrage.util.BarrageProtoUtil; import io.deephaven.extensions.barrage.util.BarrageStreamReader; import io.deephaven.extensions.barrage.util.BarrageUtil; -import io.deephaven.engine.table.TableDefinition; -import io.deephaven.engine.liveness.ReferenceCountedLivenessNode; -import io.deephaven.chunk.ChunkType; -import io.deephaven.engine.table.impl.util.BarrageMessage; import io.deephaven.internal.log.LoggerFactory; import io.deephaven.io.logger.Logger; import io.grpc.CallOptions; import io.grpc.ClientCall; +import io.grpc.Context; import io.grpc.MethodDescriptor; import io.grpc.protobuf.ProtoUtils; import io.grpc.stub.ClientCallStreamObserver; import io.grpc.stub.ClientCalls; import io.grpc.stub.ClientResponseObserver; -import org.apache.arrow.flight.impl.Flight; +import org.apache.arrow.flight.impl.Flight.FlightData; import org.apache.arrow.flight.impl.FlightServiceGrpc; import org.jetbrains.annotations.Nullable; @@ -45,12 +47,12 @@ public class BarrageSubscriptionImpl extends ReferenceCountedLivenessNode implem private final String logName; private final TableHandle tableHandle; private final BarrageSubscriptionOptions options; - private final ClientCall call; + private final ClientCallStreamObserver observer; private BarrageTable resultTable; private boolean subscribed = false; - private volatile boolean connected = false; + private volatile boolean connected = true; /** * Represents a BarrageSubscription. @@ -71,59 +73,66 @@ public BarrageSubscriptionImpl( resultTable = BarrageTable.make(tableDefinition, false); resultTable.addParentReference(this); - final MethodDescriptor subscribeDescriptor = + final MethodDescriptor subscribeDescriptor = getClientDoExchangeDescriptor(options, resultTable.getWireChunkTypes(), resultTable.getWireTypes(), resultTable.getWireComponentTypes(), new BarrageStreamReader()); - this.call = session.channel().newCall(subscribeDescriptor, CallOptions.DEFAULT); - - ClientCalls.asyncBidiStreamingCall(call, new ClientResponseObserver() { - @Override - public void beforeStart(final ClientCallStreamObserver requestStream) { - requestStream.disableAutoInboundFlowControl(); - } + // We need to ensure that the DoExchange RPC does not get attached to the server RPC when this is being called + // from a Deephaven server RPC thread. If we need to generalize this in the future, we may wrap this logic in a + // Channel or interceptor; inject the appropriate Context to use; or have the server RPC set a more appropriate + // Context along the stack. + final ClientCall call; + final Context previous = Context.ROOT.attach(); + try { + call = session.channel().newCall(subscribeDescriptor, CallOptions.DEFAULT); + } finally { + Context.ROOT.detach(previous); + } + observer = (ClientCallStreamObserver) ClientCalls + .asyncBidiStreamingCall(call, new DoExchangeObserver()); - @Override - public void onNext(final BarrageMessage barrageMessage) { - if (barrageMessage == null) { - return; - } - try { - final BarrageMessage.Listener listener = resultTable; - if (!connected || listener == null) { - return; - } - listener.handleBarrageMessage(barrageMessage); - } finally { - barrageMessage.close(); - } - } + // Allow the server to send us all commands when there is sufficient bandwidth: + observer.request(Integer.MAX_VALUE); + } - @Override - public void onError(final Throwable t) { - log.error().append(BarrageSubscriptionImpl.this) - .append(": Error detected in subscription: ") - .append(t).endl(); + private class DoExchangeObserver implements ClientResponseObserver { + @Override + public void beforeStart(final ClientCallStreamObserver requestStream) { + requestStream.disableAutoInboundFlowControl(); + } - final BarrageMessage.Listener listener = resultTable; + @Override + public void onNext(final BarrageMessage barrageMessage) { + if (barrageMessage == null) { + return; + } + try (barrageMessage) { + final Listener listener = resultTable; if (!connected || listener == null) { return; } - listener.handleBarrageError(t); - handleDisconnect(); + listener.handleBarrageMessage(barrageMessage); } + } - @Override - public void onCompleted() { - handleDisconnect(); + @Override + public void onError(final Throwable t) { + log.error().append(BarrageSubscriptionImpl.this) + .append(": Error detected in subscription: ") + .append(t).endl(); + + final Listener listener = resultTable; + if (!connected || listener == null) { + return; } - }); - - // Allow the server to send us all commands when there is sufficient bandwidth: - call.request(Integer.MAX_VALUE); + listener.handleBarrageError(t); + handleDisconnect(); + } - // Although this is a white lie, the call is established - this.connected = true; + @Override + public void onCompleted() { + handleDisconnect(); + } } @Override @@ -134,7 +143,7 @@ public synchronized BarrageTable entireTable() { } if (!subscribed) { // Send the initial subscription: - call.sendMessage(Flight.FlightData.newBuilder() + observer.onNext(FlightData.newBuilder() .setAppMetadata(ByteStringAccess.wrap(makeRequestInternal(null, null, options))) .build()); subscribed = true; @@ -162,7 +171,7 @@ public synchronized void close() { if (!connected) { return; } - call.halfClose(); + observer.onCompleted(); cleanup(); } @@ -245,7 +254,7 @@ public static MethodDescriptor descriptorFor( * @param the options related to deserialization * @return the client side method descriptor */ - public static MethodDescriptor getClientDoExchangeDescriptor( + public static MethodDescriptor getClientDoExchangeDescriptor( final Options options, final ChunkType[] columnChunkTypes, final Class[] columnTypes, @@ -253,7 +262,7 @@ public static MethodDescriptor getC final BarrageMessageConsumer.StreamReader streamReader) { return descriptorFor( MethodDescriptor.MethodType.BIDI_STREAMING, FlightServiceGrpc.SERVICE_NAME, "DoExchange", - ProtoUtils.marshaller(Flight.FlightData.getDefaultInstance()), + ProtoUtils.marshaller(FlightData.getDefaultInstance()), new BarrageDataMarshaller<>(options, columnChunkTypes, columnTypes, componentTypes, streamReader), FlightServiceGrpc.getDoExchangeMethod()); }