Skip to content

Commit

Permalink
Merge pull request quarkusio#43054 from mkouba/issue-42569
Browse files Browse the repository at this point in the history
WebSockets next: improve default strategies for unhandled failures
  • Loading branch information
mkouba authored Sep 5, 2024
2 parents a367a33 + 06c8510 commit 1473784
Show file tree
Hide file tree
Showing 10 changed files with 65 additions and 41 deletions.
4 changes: 2 additions & 2 deletions docs/src/main/asciidoc/websockets-next-reference.adoc
Original file line number Diff line number Diff line change
Expand Up @@ -424,8 +424,8 @@ The method that declares a most-specific supertype of the actual exception is se
NOTE: The `@io.quarkus.websockets.next.OnError` annotation can be also used to declare a global error handler, i.e. a method that is not declared on a WebSocket endpoint. Such a method may not accept `@PathParam` paremeters. Error handlers declared on an endpoint take precedence over the global error handlers.

When an error occurs but no error handler can handle the failure, Quarkus uses the strategy specified by `quarkus.websockets-next.server.unhandled-failure-strategy`.
By default, the connection is closed.
Alternatively, an error message can be logged or no operation performed.
For server endpoints, the error message is logged and the connection is closed by default.
For client endpoints, the error message is logged by default.

[[serialization]]
=== Serialization and deserialization
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
package io.quarkus.websockets.next.test.client;

import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertFalse;
import static org.junit.jupiter.api.Assertions.assertTrue;

import java.net.URI;
Expand All @@ -12,18 +11,19 @@
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.RegisterExtension;

import io.netty.handler.codec.http.websocketx.WebSocketCloseStatus;
import io.quarkus.test.QuarkusUnitTest;
import io.quarkus.test.common.http.TestHTTPResource;
import io.quarkus.websockets.next.WebSocketClientConnection;
import io.quarkus.websockets.next.WebSocketConnector;

public class UnhandledMessageFailureLogStrategyTest {
public class UnhandledMessageFailureCloseStrategyTest {

@RegisterExtension
public static final QuarkusUnitTest test = new QuarkusUnitTest()
.withApplicationRoot(root -> {
root.addClasses(ServerEndpoint.class, ClientMessageErrorEndpoint.class);
}).overrideConfigKey("quarkus.websockets-next.client.unhandled-failure-strategy", "log");
}).overrideConfigKey("quarkus.websockets-next.client.unhandled-failure-strategy", "close");

@Inject
WebSocketConnector<ClientMessageErrorEndpoint> connector;
Expand All @@ -37,10 +37,11 @@ void testError() throws InterruptedException {
.baseUri(testUri)
.connectAndAwait();
connection.sendTextAndAwait("foo");
assertFalse(connection.isClosed());
connection.sendTextAndAwait("bar");
assertTrue(ClientMessageErrorEndpoint.MESSAGE_LATCH.await(5, TimeUnit.SECONDS));
assertEquals("bar", ClientMessageErrorEndpoint.MESSAGES.get(0));
assertTrue(ServerEndpoint.CLOSED_LATCH.await(5, TimeUnit.SECONDS));
assertTrue(ClientMessageErrorEndpoint.CLOSED_LATCH.await(5, TimeUnit.SECONDS));
assertTrue(connection.isClosed());
assertEquals(WebSocketCloseStatus.INVALID_MESSAGE_TYPE.code(), connection.closeReason().getCode());
assertTrue(ClientMessageErrorEndpoint.MESSAGES.isEmpty());
}

}
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package io.quarkus.websockets.next.test.client;

import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertFalse;
import static org.junit.jupiter.api.Assertions.assertTrue;

import java.net.URI;
Expand All @@ -11,7 +12,6 @@
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.RegisterExtension;

import io.netty.handler.codec.http.websocketx.WebSocketCloseStatus;
import io.quarkus.test.QuarkusUnitTest;
import io.quarkus.test.common.http.TestHTTPResource;
import io.quarkus.websockets.next.WebSocketClientConnection;
Expand All @@ -37,11 +37,10 @@ void testError() throws InterruptedException {
.baseUri(testUri)
.connectAndAwait();
connection.sendTextAndAwait("foo");
assertTrue(ServerEndpoint.CLOSED_LATCH.await(5, TimeUnit.SECONDS));
assertTrue(ClientMessageErrorEndpoint.CLOSED_LATCH.await(5, TimeUnit.SECONDS));
assertTrue(connection.isClosed());
assertEquals(WebSocketCloseStatus.INTERNAL_SERVER_ERROR.code(), connection.closeReason().getCode());
assertTrue(ClientMessageErrorEndpoint.MESSAGES.isEmpty());
assertFalse(connection.isClosed());
connection.sendTextAndAwait("bar");
assertTrue(ClientMessageErrorEndpoint.MESSAGE_LATCH.await(5, TimeUnit.SECONDS));
assertEquals("bar", ClientMessageErrorEndpoint.MESSAGES.get(0));
}

}
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
package io.quarkus.websockets.next.test.client;

import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertFalse;
import static org.junit.jupiter.api.Assertions.assertNull;
import static org.junit.jupiter.api.Assertions.assertTrue;

import java.net.URI;
Expand All @@ -13,18 +11,19 @@
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.RegisterExtension;

import io.netty.handler.codec.http.websocketx.WebSocketCloseStatus;
import io.quarkus.test.QuarkusUnitTest;
import io.quarkus.test.common.http.TestHTTPResource;
import io.quarkus.websockets.next.WebSocketClientConnection;
import io.quarkus.websockets.next.WebSocketConnector;

public class UnhandledOpenFailureLogStrategyTest {
public class UnhandledOpenFailureCloseStrategyTest {

@RegisterExtension
public static final QuarkusUnitTest test = new QuarkusUnitTest()
.withApplicationRoot(root -> {
root.addClasses(ServerEndpoint.class, ClientOpenErrorEndpoint.class);
}).overrideConfigKey("quarkus.websockets-next.client.unhandled-failure-strategy", "log");
}).overrideConfigKey("quarkus.websockets-next.client.unhandled-failure-strategy", "close");

@Inject
WebSocketConnector<ClientOpenErrorEndpoint> connector;
Expand All @@ -37,11 +36,11 @@ void testError() throws InterruptedException {
WebSocketClientConnection connection = connector
.baseUri(testUri)
.connectAndAwait();
connection.sendTextAndAwait("foo");
assertFalse(connection.isClosed());
assertNull(connection.closeReason());
assertTrue(ClientOpenErrorEndpoint.MESSAGE_LATCH.await(5, TimeUnit.SECONDS));
assertEquals("foo", ClientOpenErrorEndpoint.MESSAGES.get(0));
assertTrue(ServerEndpoint.CLOSED_LATCH.await(5, TimeUnit.SECONDS));
assertTrue(ClientOpenErrorEndpoint.CLOSED_LATCH.await(5, TimeUnit.SECONDS));
assertTrue(connection.isClosed());
assertEquals(WebSocketCloseStatus.INVALID_MESSAGE_TYPE.code(), connection.closeReason().getCode());
assertTrue(ClientOpenErrorEndpoint.MESSAGES.isEmpty());
}

}
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
package io.quarkus.websockets.next.test.client;

import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertFalse;
import static org.junit.jupiter.api.Assertions.assertNull;
import static org.junit.jupiter.api.Assertions.assertTrue;

import java.net.URI;
Expand All @@ -11,7 +13,6 @@
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.RegisterExtension;

import io.netty.handler.codec.http.websocketx.WebSocketCloseStatus;
import io.quarkus.test.QuarkusUnitTest;
import io.quarkus.test.common.http.TestHTTPResource;
import io.quarkus.websockets.next.WebSocketClientConnection;
Expand All @@ -36,11 +37,11 @@ void testError() throws InterruptedException {
WebSocketClientConnection connection = connector
.baseUri(testUri)
.connectAndAwait();
assertTrue(ServerEndpoint.CLOSED_LATCH.await(5, TimeUnit.SECONDS));
assertTrue(ClientOpenErrorEndpoint.CLOSED_LATCH.await(5, TimeUnit.SECONDS));
assertTrue(connection.isClosed());
assertEquals(WebSocketCloseStatus.INTERNAL_SERVER_ERROR.code(), connection.closeReason().getCode());
assertTrue(ClientOpenErrorEndpoint.MESSAGES.isEmpty());
connection.sendTextAndAwait("foo");
assertFalse(connection.isClosed());
assertNull(connection.closeReason());
assertTrue(ClientOpenErrorEndpoint.MESSAGE_LATCH.await(5, TimeUnit.SECONDS));
assertEquals("foo", ClientOpenErrorEndpoint.MESSAGES.get(0));
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,15 @@
*/
public enum UnhandledFailureStrategy {
/**
* Close the connection.
* Log the error message and close the connection.
*/
LOG_AND_CLOSE,
/**
* Close the connection silently.
*/
CLOSE,
/**
* Log an error message.
* Log the error message.
*/
LOG,
/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,9 +43,12 @@ public interface WebSocketsClientRuntimeConfig {
/**
* The strategy used when an error occurs but no error handler can handle the failure.
* <p>
* By default, the connection is closed when an unhandled failure occurs.
* By default, the error message is logged when an unhandled failure occurs.
* <p>
* Note that clients should not close the WebSocket connection arbitrarily. See also RFC-6455
* <a href="https://datatracker.ietf.org/doc/html/rfc6455#section-7.3">section 7.3</a>.
*/
@WithDefault("close")
@WithDefault("log")
UnhandledFailureStrategy unhandledFailureStrategy();

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,9 +49,9 @@ public interface WebSocketsServerRuntimeConfig {
/**
* The strategy used when an error occurs but no error handler can handle the failure.
* <p>
* By default, the connection is closed when an unhandled failure occurs.
* By default, the error message is logged and the connection is closed when an unhandled failure occurs.
*/
@WithDefault("close")
@WithDefault("log-and-close")
UnhandledFailureStrategy unhandledFailureStrategy();

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,10 @@

import org.jboss.logging.Logger;

import io.netty.handler.codec.http.websocketx.WebSocketCloseStatus;
import io.quarkus.arc.ArcContainer;
import io.quarkus.arc.InjectableContext;
import io.quarkus.runtime.LaunchMode;
import io.quarkus.security.AuthenticationFailedException;
import io.quarkus.security.ForbiddenException;
import io.quarkus.security.UnauthorizedException;
Expand Down Expand Up @@ -253,18 +255,32 @@ public void handle(Void event) {
private static void handleFailure(UnhandledFailureStrategy strategy, Throwable cause, String message,
WebSocketConnectionBase connection) {
switch (strategy) {
case CLOSE -> closeConnection(cause, connection);
case LOG_AND_CLOSE -> logAndClose(cause, message, connection);
case CLOSE -> closeConnection(cause, message, connection);
case LOG -> logFailure(cause, message, connection);
case NOOP -> LOG.tracef("Unhandled failure ignored: %s", connection);
default -> throw new IllegalArgumentException("Unexpected strategy: " + strategy);
}
}

private static void closeConnection(Throwable cause, WebSocketConnectionBase connection) {
private static void logAndClose(Throwable cause, String message, WebSocketConnectionBase connection) {
logFailure(cause, message, connection);
closeConnection(cause, message, connection);
}

private static void closeConnection(Throwable cause, String message, WebSocketConnectionBase connection) {
if (connection.isClosed()) {
return;
}
connection.close(CloseReason.INTERNAL_SERVER_ERROR).subscribe().with(
CloseReason closeReason;
int statusCode = connection instanceof WebSocketClientConnectionImpl ? WebSocketCloseStatus.INVALID_MESSAGE_TYPE.code()
: WebSocketCloseStatus.INTERNAL_SERVER_ERROR.code();
if (LaunchMode.current().isDevOrTest()) {
closeReason = new CloseReason(statusCode, cause.getMessage());
} else {
closeReason = new CloseReason(statusCode);
}
connection.close(closeReason).subscribe().with(
v -> LOG.debugf("Connection closed due to unhandled failure %s: %s", cause, connection),
t -> LOG.errorf("Unable to close connection [%s] due to unhandled failure [%s]: %s", connection.id(), cause,
t));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -84,9 +84,10 @@ void binaryMessageSent(WebSocketConnectionBase connection, Buffer payload) {

void connectionClosed(WebSocketConnectionBase connection) {
if (LOG.isDebugEnabled()) {
LOG.debugf("%s connection closed, Connection[%s]",
LOG.debugf("%s connection closed, Connection[%s], %s",
typeToString(),
connectionToString(connection));
connectionToString(connection),
connection.closeReason());
}
}

Expand Down

0 comments on commit 1473784

Please sign in to comment.