Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Issue #11307 - Explicit demand control in WebSocket endpoints with only onWebSocketFrame #12342

Merged
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,46 @@ public void fail(Throwable x)
};
}

/**
* Creates a nested callback that runs completed after
* completing the nested callback.
*
* @param callback The nested callback
* @param completed The completion to run after the nested callback is completed
* @return a new callback.
*/
static Callback from(Callback callback, Runnable completed)
{
return new Callback()
{
@Override
public void succeed()
{
try
{
callback.succeed();
}
finally
{
completed.run();
}
}

@Override
public void fail(Throwable x)
{
try
{
callback.fail(x);
}
finally
{
completed.run();
}
}
};
}

/**
* <p>Method to invoke to succeed the callback.</p>
*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -101,4 +101,22 @@ public String toString()
boolean isRsv2();

boolean isRsv3();

default CloseStatus getCloseStatus()
{
return null;
}

record CloseStatus(int statusCode, String reason)
{
}

/**
* The effective opcode of the frame accounting for the CONTINUATION opcode.
* If the frame is a CONTINUATION frame for a TEXT message, this will return TEXT.
* If the frame is a CONTINUATION frame for a BINARY message, this will return BINARY.
* Otherwise, this will return the same opcode as the frame.
* @return the effective opcode of the frame.
*/
byte getEffectiveOpCode();
}
Original file line number Diff line number Diff line change
Expand Up @@ -227,6 +227,7 @@ default void onWebSocketOpen(Session session)
* or data frames either BINARY or TEXT.</p>
*
* @param frame the received frame
* @param callback the callback to complete once the frame has been processed.
*/
default void onWebSocketFrame(Frame frame, Callback callback)
{
Expand Down Expand Up @@ -284,6 +285,7 @@ default void onWebSocketPartialText(String payload, boolean last)
* <p>A WebSocket BINARY message has been received.</p>
*
* @param payload the raw payload array received
* @param callback the callback to complete when the payload has been processed
*/
default void onWebSocketBinary(ByteBuffer payload, Callback callback)
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,14 +16,32 @@
import java.nio.ByteBuffer;

import org.eclipse.jetty.websocket.core.Frame;
import org.eclipse.jetty.websocket.core.OpCode;

public class JettyWebSocketFrame implements org.eclipse.jetty.websocket.api.Frame
{
private final Frame frame;
private final byte effectiveOpCode;

/**
* @param frame the core websocket {@link Frame} to wrap as a {@link org.eclipse.jetty.websocket.api.Frame}.
* @deprecated there is no alternative intended to publicly construct a {@link JettyWebSocketFrame}.
*/
@Deprecated(forRemoval = true, since = "12.1.0")
public JettyWebSocketFrame(Frame frame)
{
this(frame, frame.getOpCode());
}

/**
* @param frame the core websocket {@link Frame} to wrap as a Jetty API {@link org.eclipse.jetty.websocket.api.Frame}.
* @param effectiveOpCode the effective OpCode of the Frame, where any CONTINUATION should be replaced with the
* initial opcode of that websocket message.
*/
JettyWebSocketFrame(Frame frame, byte effectiveOpCode)
{
this.frame = frame;
this.effectiveOpCode = effectiveOpCode;
}

@Override
Expand Down Expand Up @@ -92,6 +110,21 @@ public boolean isRsv3()
return frame.isRsv3();
}

@Override
public byte getEffectiveOpCode()
{
return effectiveOpCode;
}

@Override
public CloseStatus getCloseStatus()
{
if (getOpCode() != OpCode.CLOSE)
return null;
org.eclipse.jetty.websocket.core.CloseStatus closeStatus = org.eclipse.jetty.websocket.core.CloseStatus.getCloseStatus(frame);
return new CloseStatus(closeStatus.getCode(), closeStatus.getReason());
}

@Override
public String toString()
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@
import java.lang.reflect.InvocationTargetException;
import java.nio.ByteBuffer;
import java.nio.channels.ClosedChannelException;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.atomic.AtomicBoolean;

import org.eclipse.jetty.util.BufferUtil;
Expand Down Expand Up @@ -69,6 +68,7 @@ public class JettyWebSocketFrameHandler implements FrameHandler
private MessageSink binarySink;
private MessageSink activeMessageSink;
private WebSocketSession session;
private byte messageType;

public JettyWebSocketFrameHandler(WebSocketContainer container, Object endpointInstance, JettyWebSocketFrameHandlerMetadata metadata)
{
Expand Down Expand Up @@ -193,44 +193,36 @@ private static MessageSink createMessageSink(Class<? extends MessageSink> sinkCl
@Override
public void onFrame(Frame frame, Callback coreCallback)
{
CompletableFuture<Void> frameCallback = null;
if (frame.getOpCode() == OpCode.TEXT || frame.getOpCode() == OpCode.BINARY)
messageType = frame.getOpCode();

if (frameHandle != null)
{
try
{
frameCallback = new org.eclipse.jetty.websocket.api.Callback.Completable();
frameHandle.invoke(new JettyWebSocketFrame(frame), frameCallback);
byte effectiveOpCode = frame.isDataFrame() ? messageType : frame.getOpCode();
frameHandle.invoke(new JettyWebSocketFrame(frame, effectiveOpCode),
org.eclipse.jetty.websocket.api.Callback.from(coreCallback::succeeded, coreCallback::failed));
}
catch (Throwable cause)
{
coreCallback.failed(new WebSocketException(endpointInstance.getClass().getSimpleName() + " FRAME method error: " + cause.getMessage(), cause));
return;
}

autoDemand();
return;
}

Callback.Completable eventCallback = new Callback.Completable();
switch (frame.getOpCode())
{
case OpCode.CLOSE -> onCloseFrame(frame, eventCallback);
case OpCode.PING -> onPingFrame(frame, eventCallback);
case OpCode.PONG -> onPongFrame(frame, eventCallback);
case OpCode.TEXT -> onTextFrame(frame, eventCallback);
case OpCode.BINARY -> onBinaryFrame(frame, eventCallback);
case OpCode.CONTINUATION -> onContinuationFrame(frame, eventCallback);
case OpCode.TEXT -> onTextFrame(frame, coreCallback);
case OpCode.BINARY -> onBinaryFrame(frame, coreCallback);
case OpCode.CONTINUATION -> onContinuationFrame(frame, coreCallback);
case OpCode.PING -> onPingFrame(frame, coreCallback);
case OpCode.PONG -> onPongFrame(frame, coreCallback);
case OpCode.CLOSE -> onCloseFrame(frame, coreCallback);
default -> coreCallback.failed(new IllegalStateException());
};

// Combine the callback from the frame handler and the event handler.
CompletableFuture<Void> callback = eventCallback;
if (frameCallback != null)
callback = frameCallback.thenCompose(ignored -> eventCallback);
callback.whenComplete((r, x) ->
{
if (x == null)
coreCallback.succeeded();
else
coreCallback.failed(x);
});
}
}

@Override
Expand Down Expand Up @@ -358,6 +350,7 @@ private void onPongFrame(Frame frame, Callback callback)
}
else
{
callback.succeeded();
internalDemand();
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ public void setAutoDemand(boolean autoDemand)
public void setBinaryHandle(Class<? extends MessageSink> sinkClass, MethodHandle binary, Object origin)
{
assertNotSet(this.binaryHandle, "BINARY Handler", origin);
assertNotSet(this.frameHandle, "FRAME Handler", origin);
lachlan-roberts marked this conversation as resolved.
Show resolved Hide resolved
this.binaryHandle = binary;
this.binarySink = sinkClass;
}
Expand Down Expand Up @@ -85,6 +86,10 @@ public MethodHandle getErrorHandle()
public void setFrameHandle(MethodHandle frame, Object origin)
{
assertNotSet(this.frameHandle, "FRAME Handler", origin);
assertNotSet(this.textHandle, "TEXT Handler", origin);
assertNotSet(this.binaryHandle, "BINARY Handler", origin);
assertNotSet(this.pingHandle, "PING Handler", origin);
assertNotSet(this.pongHandle, "PONG Handler", origin);
this.frameHandle = frame;
}

Expand All @@ -107,6 +112,7 @@ public MethodHandle getOpenHandle()
public void setPingHandle(MethodHandle ping, Object origin)
{
assertNotSet(this.pingHandle, "PING Handler", origin);
assertNotSet(this.frameHandle, "FRAME Handler", origin);
this.pingHandle = ping;
}

Expand All @@ -118,6 +124,7 @@ public MethodHandle getPingHandle()
public void setPongHandle(MethodHandle pong, Object origin)
{
assertNotSet(this.pongHandle, "PONG Handler", origin);
assertNotSet(this.frameHandle, "FRAME Handler", origin);
this.pongHandle = pong;
}

Expand All @@ -129,6 +136,7 @@ public MethodHandle getPongHandle()
public void setTextHandle(Class<? extends MessageSink> sinkClass, MethodHandle text, Object origin)
{
assertNotSet(this.textHandle, "TEXT Handler", origin);
assertNotSet(this.frameHandle, "FRAME Handler", origin);
this.textHandle = text;
this.textSink = sinkClass;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@
import static org.hamcrest.MatcherAssert.assertThat;
import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.is;
import static org.junit.jupiter.api.Assertions.assertNotNull;
import static org.junit.jupiter.api.Assertions.assertNull;
import static org.junit.jupiter.api.Assertions.assertTrue;

Expand Down Expand Up @@ -72,12 +73,36 @@ public void onMessage(String message) throws IOException
public static class ListenerSocket implements Session.Listener
{
final List<Frame> frames = new CopyOnWriteArrayList<>();
final List<Callback> callbacks = new CopyOnWriteArrayList<>();
Session session;

@Override
public void onWebSocketOpen(Session session)
{
this.session = session;
session.demand();
}

@Override
public void onWebSocketFrame(Frame frame, Callback callback)
{
frames.add(frame);
callback.succeed();
callbacks.add(callback);

// Because no pingListener is registered, the frameListener is responsible for handling pings.
if (frame.getOpCode() == OpCode.PING)
{
session.sendPong(frame.getPayload(), Callback.from(session::demand, callback::fail));
return;
}
else if (frame.getOpCode() == OpCode.CLOSE)
{
Frame.CloseStatus closeStatus = frame.getCloseStatus();
session.close(closeStatus.statusCode(), closeStatus.reason(), Callback.NOOP);
return;
}

session.demand();
}
}

Expand Down Expand Up @@ -109,27 +134,19 @@ public void onWebSocketFrame(Frame frame, Callback callback)
if (frame.getOpCode() == OpCode.TEXT)
textMessages.add(BufferUtil.toString(frame.getPayload()));
callback.succeed();
session.demand();
}
}

@WebSocket(autoDemand = false)
public static class PingSocket extends ListenerSocket
{
Session session;

@Override
public void onWebSocketOpen(Session session)
{
this.session = session;
session.demand();
}

@Override
public void onWebSocketFrame(Frame frame, Callback callback)
{
super.onWebSocketFrame(frame, callback);
if (frame.getType() == Frame.Type.TEXT)
session.sendPing(ByteBuffer.wrap("server-ping".getBytes(StandardCharsets.UTF_8)), Callback.NOOP);
super.onWebSocketFrame(frame, callback);
}
}

Expand Down Expand Up @@ -217,13 +234,23 @@ public void testNoAutoDemand() throws Exception
Frame frame0 = listenerSocket.frames.get(0);
assertThat(frame0.getType(), is(Frame.Type.PONG));
assertThat(StandardCharsets.UTF_8.decode(frame0.getPayload()).toString(), is("ping-0"));
Callback callback0 = listenerSocket.callbacks.get(0);
assertNotNull(callback0);
callback0.succeed();

Frame frame1 = listenerSocket.frames.get(1);
assertThat(frame1.getType(), is(Frame.Type.PONG));
assertThat(StandardCharsets.UTF_8.decode(frame1.getPayload()).toString(), is("ping-1"));
Callback callback1 = listenerSocket.callbacks.get(1);
assertNotNull(callback1);
callback1.succeed();

session.close();
await().atMost(5, TimeUnit.SECONDS).until(listenerSocket.frames::size, is(3));
assertThat(listenerSocket.frames.get(2).getType(), is(Frame.Type.CLOSE));
Callback closeCallback = listenerSocket.callbacks.get(2);
assertNotNull(closeCallback);
closeCallback.succeed();
}

@Test
Expand Down
Loading
Loading