Skip to content

Commit

Permalink
Allow SSE events to be filtered out from REST Client
Browse files Browse the repository at this point in the history
(cherry picked from commit c9d1eea)
  • Loading branch information
geoand authored and gsmet committed Nov 21, 2023
1 parent fca5cb8 commit 7122897
Show file tree
Hide file tree
Showing 6 changed files with 236 additions and 7 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import java.util.concurrent.CopyOnWriteArrayList;
import java.util.concurrent.TimeUnit;
import java.util.function.Consumer;
import java.util.function.Predicate;

import jakarta.ws.rs.GET;
import jakarta.ws.rs.POST;
Expand All @@ -23,6 +24,7 @@
import org.eclipse.microprofile.rest.client.inject.RegisterRestClient;
import org.jboss.resteasy.reactive.RestStreamElementType;
import org.jboss.resteasy.reactive.client.SseEvent;
import org.jboss.resteasy.reactive.client.SseEventFilter;
import org.jboss.resteasy.reactive.server.jackson.JacksonBasicMessageBodyReader;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.RegisterExtension;
Expand Down Expand Up @@ -136,6 +138,25 @@ public void accept(SseEvent<Dto> event) {
new EventContainer("id1", "name1", new Dto("name1", "1"))));
}

@Test
void shouldBeAbleReadEntireEventWhileAlsoBeingAbleToFilterEvents() {
var resultList = new CopyOnWriteArrayList<>();
createClient()
.eventWithFilter()
.subscribe().with(new Consumer<>() {
@Override
public void accept(SseEvent<Dto> event) {
resultList.add(new EventContainer(event.id(), event.name(), event.data()));
}
});
await().atMost(5, TimeUnit.SECONDS)
.untilAsserted(
() -> assertThat(resultList).containsExactly(
new EventContainer("id", "n0", new Dto("name0", "0")),
new EventContainer("id", "n1", new Dto("name1", "1")),
new EventContainer("id", "n2", new Dto("name2", "2"))));
}

static class EventContainer {
final String id;
final String name;
Expand Down Expand Up @@ -212,6 +233,26 @@ public interface SseClient {
@Path("/event")
@Produces(MediaType.SERVER_SENT_EVENTS)
Multi<SseEvent<Dto>> event();

@GET
@Path("/event-with-filter")
@Produces(MediaType.SERVER_SENT_EVENTS)
@SseEventFilter(CustomFilter.class)
Multi<SseEvent<Dto>> eventWithFilter();
}

public static class CustomFilter implements Predicate<SseEvent<String>> {

@Override
public boolean test(SseEvent<String> event) {
if ("heartbeat".equals(event.id())) {
return false;
}
if ("END".equals(event.data())) {
return false;
}
return true;
}
}

@Path("/sse")
Expand Down Expand Up @@ -261,6 +302,50 @@ public void event(@Context SseEventSink sink, @Context Sse sse) {
}
}
}

@GET
@Path("/event-with-filter")
@Produces(MediaType.SERVER_SENT_EVENTS)
public void eventWithFilter(@Context SseEventSink sink, @Context Sse sse) {
try (sink) {
sink.send(sse.newEventBuilder()
.id("id")
.mediaType(MediaType.APPLICATION_JSON_TYPE)
.data(Dto.class, new Dto("name0", "0"))
.name("n0")
.build());

sink.send(sse.newEventBuilder()
.id("heartbeat")
.comment("heartbeat")
.mediaType(MediaType.APPLICATION_JSON_TYPE)
.build());

sink.send(sse.newEventBuilder()
.id("id")
.mediaType(MediaType.APPLICATION_JSON_TYPE)
.data(Dto.class, new Dto("name1", "1"))
.name("n1")
.build());

sink.send(sse.newEventBuilder()
.id("heartbeat")
.comment("heartbeat")
.build());

sink.send(sse.newEventBuilder()
.id("id")
.mediaType(MediaType.APPLICATION_JSON_TYPE)
.data(Dto.class, new Dto("name2", "2"))
.name("n2")
.build());

sink.send(sse.newEventBuilder()
.id("end")
.data("END")
.build());
}
}
}

@Path("/sse-rest-stream-element-type")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
import org.eclipse.microprofile.rest.client.annotation.RegisterProviders;
import org.eclipse.microprofile.rest.client.ext.ResponseExceptionMapper;
import org.jboss.jandex.DotName;
import org.jboss.resteasy.reactive.client.SseEventFilter;

import io.quarkus.rest.client.reactive.ClientExceptionMapper;
import io.quarkus.rest.client.reactive.ClientFormParam;
Expand Down Expand Up @@ -41,6 +42,8 @@ public class DotNames {

static final DotName METHOD = DotName.createSimple(Method.class.getName());

public static final DotName SSE_EVENT_FILTER = DotName.createSimple(SseEventFilter.class);

private DotNames() {
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@
import org.jboss.resteasy.reactive.common.util.QuarkusMultivaluedHashMap;

import io.quarkus.arc.deployment.AdditionalBeanBuildItem;
import io.quarkus.arc.deployment.BeanArchiveIndexBuildItem;
import io.quarkus.arc.deployment.CustomScopeAnnotationsBuildItem;
import io.quarkus.arc.deployment.GeneratedBeanBuildItem;
import io.quarkus.arc.deployment.GeneratedBeanGizmoAdaptor;
Expand Down Expand Up @@ -371,6 +372,42 @@ void registerCompressionInterceptors(BuildProducer<ReflectiveClassBuildItem> ref
}
}

@BuildStep
void handleSseEventFilter(BuildProducer<ReflectiveClassBuildItem> reflectiveClasses,
BeanArchiveIndexBuildItem beanArchiveIndexBuildItem) {
var index = beanArchiveIndexBuildItem.getIndex();
Collection<AnnotationInstance> instances = index.getAnnotations(DotNames.SSE_EVENT_FILTER);
if (instances.isEmpty()) {
return;
}

List<String> filterClassNames = new ArrayList<>(instances.size());
for (AnnotationInstance instance : instances) {
if (instance.target().kind() != AnnotationTarget.Kind.METHOD) {
continue;
}
if (instance.value() == null) {
continue; // can't happen
}
Type filterType = instance.value().asClass();
DotName filterClassName = filterType.name();
ClassInfo filterClassInfo = index.getClassByName(filterClassName.toString());
if (filterClassInfo == null) {
log.warn("Unable to find class '" + filterType.name() + "' in index");
} else if (!filterClassInfo.hasNoArgsConstructor()) {
throw new RestClientDefinitionException(
"Classes used in @SseEventFilter must have a no-args constructor. Offending class is '"
+ filterClassName + "'");
} else {
filterClassNames.add(filterClassName.toString());
}
}
reflectiveClasses.produce(ReflectiveClassBuildItem
.builder(filterClassNames.toArray(new String[0]))
.constructors(true)
.build());
}

@BuildStep
@Record(ExecutionTime.STATIC_INIT)
void addRestClientBeans(Capabilities capabilities,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,38 @@
*/
public interface SseEvent<T> {

/**
* Get event identifier.
* <p>
* Contains value of SSE {@code "id"} field. This field is optional. Method may return {@code null}, if the event
* identifier is not specified.
*
* @return event id.
*/
String id();

/**
* Get event name.
* <p>
* Contains value of SSE {@code "event"} field. This field is optional. Method may return {@code null}, if the event
* name is not specified.
*
* @return event name, or {@code null} if not set.
*/
String name();

/**
* Get a comment string that accompanies the event.
* <p>
* Contains value of the comment associated with SSE event. This field is optional. Method may return {@code null}, if
* the event comment is not specified.
*
* @return comment associated with the event.
*/
String comment();

/**
* Get event data.
*/
T data();
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
package org.jboss.resteasy.reactive.client;

import java.lang.annotation.ElementType;
import java.lang.annotation.Retention;
import java.lang.annotation.RetentionPolicy;
import java.lang.annotation.Target;
import java.util.function.Predicate;

/**
* Used when not all SSE events streamed from the server should be included in the event stream returned by the client.
* <p>
* IMPORTANT: implementations MUST contain a no-args constructor
*/
@Target(ElementType.METHOD)
@Retention(RetentionPolicy.RUNTIME)
public @interface SseEventFilter {

/**
* Predicate which decides whether an event should be included in the event stream returned by the client.
*/
Class<? extends Predicate<SseEvent<String>>> value();
}
Original file line number Diff line number Diff line change
Expand Up @@ -2,16 +2,19 @@

import java.io.ByteArrayInputStream;
import java.io.IOException;
import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Method;
import java.lang.reflect.ParameterizedType;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicReference;
import java.util.function.Predicate;

import jakarta.ws.rs.client.Entity;
import jakarta.ws.rs.core.GenericType;
import jakarta.ws.rs.core.MediaType;
import jakarta.ws.rs.core.Response;

import org.jboss.resteasy.reactive.client.SseEvent;
import org.jboss.resteasy.reactive.client.SseEventFilter;
import org.jboss.resteasy.reactive.common.jaxrs.ResponseImpl;
import org.jboss.resteasy.reactive.common.util.RestMediaType;

Expand Down Expand Up @@ -45,8 +48,8 @@ public <R> Multi<R> get(GenericType<R> responseType) {

/**
* We need this class to work around a bug in Mutiny where we can register our cancel listener
* after the subscription is cancelled and we never get notified
* See https://github.com/smallrye/smallrye-mutiny/issues/417
* after the subscription is cancelled, and we never get notified
* See <a href="https://github.com/smallrye/smallrye-mutiny/issues/417">...</a>
*/
static class MultiRequest<R> {

Expand Down Expand Up @@ -127,9 +130,11 @@ public <R> Multi<R> method(String name, Entity<?> entity, GenericType<R> respons
if (!emitter.isCancelled()) {
if (response.getStatus() == 200
&& MediaType.SERVER_SENT_EVENTS_TYPE.isCompatible(response.getMediaType())) {
registerForSse(multiRequest, responseType, response, vertxResponse,
registerForSse(
multiRequest, responseType, vertxResponse,
(String) restClientRequestContext.getProperties()
.get(RestClientRequestContext.DEFAULT_CONTENT_TYPE_PROP));
.get(RestClientRequestContext.DEFAULT_CONTENT_TYPE_PROP),
restClientRequestContext.getInvokedMethod());
} else if (response.getStatus() == 200
&& RestMediaType.APPLICATION_STREAM_JSON_TYPE.isCompatible(response.getMediaType())) {
registerForJsonStream(multiRequest, restClientRequestContext, responseType, response,
Expand All @@ -156,14 +161,16 @@ private boolean isNewlineDelimited(ResponseImpl response) {
@SuppressWarnings({ "unchecked", "rawtypes" })
private <R> void registerForSse(MultiRequest<? super R> multiRequest,
GenericType<R> responseType,
Response response,
HttpClientResponse vertxResponse, String defaultContentType) {
HttpClientResponse vertxResponse, String defaultContentType,
Method invokedMethod) {

boolean returnSseEvent = SseEvent.class.equals(responseType.getRawType());
GenericType responseTypeFirstParam = responseType.getType() instanceof ParameterizedType
? new GenericType(((ParameterizedType) responseType.getType()).getActualTypeArguments()[0])
: null;

Predicate<SseEvent<String>> eventPredicate = createEventPredicate(invokedMethod);

// honestly, isn't reconnect contradictory with completion?
// FIXME: Reconnect settings?
// For now we don't want multi to reconnect
Expand All @@ -172,8 +179,39 @@ private <R> void registerForSse(MultiRequest<? super R> multiRequest,

multiRequest.onCancel(sseSource::close);
sseSource.register(event -> {

// TODO: we might want to cut down on the allocations here...

if (eventPredicate != null) {
boolean keep = eventPredicate.test(new SseEvent<>() {
@Override
public String id() {
return event.getId();
}

@Override
public String name() {
return event.getName();
}

@Override
public String comment() {
return event.getComment();
}

@Override
public String data() {
return event.readData();
}
});
if (!keep) {
return;
}
}

// DO NOT pass the response mime type because it's SSE: let the event pick between the X-SSE-Content-Type header or
// the content-type SSE field

if (returnSseEvent) {
multiRequest.emit((R) new SseEvent() {
@Override
Expand Down Expand Up @@ -212,6 +250,23 @@ public Object data() {
sseSource.registerAfterRequest(vertxResponse);
}

private Predicate<SseEvent<String>> createEventPredicate(Method invokedMethod) {
if (invokedMethod == null) {
return null; // should never happen
}

SseEventFilter filterAnnotation = invokedMethod.getAnnotation(SseEventFilter.class);
if (filterAnnotation == null) {
return null;
}

try {
return filterAnnotation.value().getConstructor().newInstance();
} catch (InstantiationException | IllegalAccessException | InvocationTargetException | NoSuchMethodException e) {
throw new RuntimeException(e);
}
}

private <R> void registerForChunks(MultiRequest<? super R> multiRequest,
RestClientRequestContext restClientRequestContext,
GenericType<R> responseType,
Expand Down

0 comments on commit 7122897

Please sign in to comment.