Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use one websocket per browser client #2606

Merged
merged 14 commits into from
Oct 27, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,130 @@
package io.grpc.servlet.web.websocket;

import com.google.common.io.BaseEncoding;
import io.grpc.Attributes;
import io.grpc.InternalMetadata;
import io.grpc.Metadata;
import io.grpc.ServerStreamTracer;
import io.grpc.internal.ServerTransportListener;
import jakarta.websocket.Endpoint;
import jakarta.websocket.EndpointConfig;
import jakarta.websocket.Session;

import java.io.IOException;
import java.io.UncheckedIOException;
import java.nio.ByteBuffer;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

public abstract class AbstractWebSocketServerStream extends Endpoint {
private static final byte[] BINARY_HEADER_SUFFIX_ARR =
Metadata.BINARY_HEADER_SUFFIX.getBytes(StandardCharsets.US_ASCII);
protected final ServerTransportListener transportListener;
protected final List<? extends ServerStreamTracer.Factory> streamTracerFactories;
protected final int maxInboundMessageSize;
protected final Attributes attributes;

// assigned on open, always available
protected Session websocketSession;

protected AbstractWebSocketServerStream(ServerTransportListener transportListener,
List<? extends ServerStreamTracer.Factory> streamTracerFactories, int maxInboundMessageSize,
Attributes attributes) {
this.transportListener = transportListener;
this.streamTracerFactories = streamTracerFactories;
this.maxInboundMessageSize = maxInboundMessageSize;
this.attributes = attributes;
}

protected static Metadata readHeaders(ByteBuffer headerPayload) {
// Headers are passed as ascii, ":"-separated key/value pairs, separated on "\r\n". The client
// implementation shows that values might be comma-separated, but we'll pass that through directly as a plain
// string.
List<byte[]> byteArrays = new ArrayList<>();
while (headerPayload.hasRemaining()) {
int nameStart = headerPayload.position();
while (headerPayload.hasRemaining() && headerPayload.get() != ':');
int nameEnd = headerPayload.position() - 1;
int valueStart = headerPayload.position() + 1;// assumes that the colon is followed by a space

while (headerPayload.hasRemaining() && headerPayload.get() != '\n');
int valueEnd = headerPayload.position() - 2;// assumes that \n is preceded by a \r, this isnt generally
// safe?
if (valueEnd < valueStart) {
valueEnd = valueStart;
}
int endOfLinePosition = headerPayload.position();

byte[] headerBytes = new byte[nameEnd - nameStart];
headerPayload.position(nameStart);
headerPayload.get(headerBytes);

byteArrays.add(headerBytes);
if (Arrays.equals(headerBytes, "content-type".getBytes(StandardCharsets.US_ASCII))) {
// rewrite grpc-web content type to matching grpc content type, regardless of what it said
byteArrays.add("grpc+proto".getBytes(StandardCharsets.US_ASCII));
// TODO support other formats like text, non-proto
headerPayload.position(valueEnd);
continue;
}

byte[] valueBytes = new byte[valueEnd - valueStart];
headerPayload.position(valueStart);
headerPayload.get(valueBytes);
if (endsWithBinHeaderSuffix(headerBytes)) {
byteArrays.add(BaseEncoding.base64().decode(ByteBuffer.wrap(valueBytes).asCharBuffer()));
} else {
byteArrays.add(valueBytes);
}

headerPayload.position(endOfLinePosition);
}

// add a te:trailers, as gRPC will expect it
byteArrays.add("te".getBytes(StandardCharsets.US_ASCII));
byteArrays.add("trailers".getBytes(StandardCharsets.US_ASCII));

// TODO to support text encoding

return InternalMetadata.newMetadata(byteArrays.toArray(new byte[][] {}));
}

private static boolean endsWithBinHeaderSuffix(byte[] headerBytes) {
// This is intended to be equiv to
// header.endsWith(Metadata.BINARY_HEADER_SUFFIX), without actually making a string for it
if (headerBytes.length < BINARY_HEADER_SUFFIX_ARR.length) {
return false;
}
for (int i = 0; i < BINARY_HEADER_SUFFIX_ARR.length; i++) {
if (headerBytes[headerBytes.length - 3 + i] != BINARY_HEADER_SUFFIX_ARR[i]) {
return false;
}
}
return true;
}

@Override
public void onOpen(Session websocketSession, EndpointConfig config) {
this.websocketSession = websocketSession;

websocketSession.addMessageHandler(String.class, this::onMessage);
websocketSession.addMessageHandler(ByteBuffer.class, message -> {
try {
onMessage(message);
} catch (IOException e) {
throw new UncheckedIOException(e);
}
});

// Configure defaults present in some servlet containers to avoid some confusing limits. Subclasses
// can override this method to control those defaults on their own.
websocketSession.setMaxIdleTimeout(0);
websocketSession.setMaxBinaryMessageBufferSize(Integer.MAX_VALUE);
}

public abstract void onMessage(String message);

public abstract void onMessage(ByteBuffer message) throws IOException;
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,133 @@
package io.grpc.servlet.web.websocket;

import com.google.common.util.concurrent.MoreExecutors;
import io.grpc.Attributes;
import io.grpc.InternalLogId;
import io.grpc.Metadata;
import io.grpc.Status;
import io.grpc.internal.AbstractServerStream;
import io.grpc.internal.ReadableBuffer;
import io.grpc.internal.SerializingExecutor;
import io.grpc.internal.ServerTransportListener;
import io.grpc.internal.StatsTraceContext;
import io.grpc.internal.TransportTracer;
import io.grpc.internal.WritableBufferAllocator;
import jakarta.websocket.Session;

import java.io.IOException;
import java.nio.ByteBuffer;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;
import java.util.logging.Level;
import java.util.logging.Logger;

public abstract class AbstractWebsocketStreamImpl extends AbstractServerStream {
public final class WebsocketTransportState extends TransportState {

private final SerializingExecutor transportThreadExecutor =
new SerializingExecutor(MoreExecutors.directExecutor());
private final Logger logger;

private WebsocketTransportState(int maxMessageSize, StatsTraceContext statsTraceCtx,
TransportTracer transportTracer, Logger logger) {
super(maxMessageSize, statsTraceCtx, transportTracer);
this.logger = logger;
}

@Override
public void runOnTransportThread(Runnable r) {
transportThreadExecutor.execute(r);
}

@Override
public void bytesRead(int numBytes) {
// no-op, no flow-control yet
}

@Override
public void deframeFailed(Throwable cause) {
if (logger.isLoggable(Level.FINE)) {
logger.log(Level.FINE, String.format("[{%s}] Exception processing message", logId), cause);
}
cancel(Status.fromThrowable(cause));
}
}

protected final TransportState transportState;
protected final Session websocketSession;
protected final InternalLogId logId;
protected final Attributes attributes;

public AbstractWebsocketStreamImpl(WritableBufferAllocator bufferAllocator, StatsTraceContext statsTraceCtx,
int maxInboundMessageSize, Session websocketSession, InternalLogId logId, Attributes attributes,
Logger logger) {
super(bufferAllocator, statsTraceCtx);
transportState =
new WebsocketTransportState(maxInboundMessageSize, statsTraceCtx, new TransportTracer(), logger);
this.websocketSession = websocketSession;
this.logId = logId;
this.attributes = attributes;
}

protected static void writeAsciiHeadersToMessage(byte[][] serializedHeaders, ByteBuffer message) {
for (int i = 0; i < serializedHeaders.length; i += 2) {
message.put(serializedHeaders[i]);
message.put((byte) ':');
message.put((byte) ' ');
message.put(serializedHeaders[i + 1]);
message.put((byte) '\r');
message.put((byte) '\n');
}
}

@Override
public int streamId() {
return -1;
}

@Override
public Attributes getAttributes() {
return attributes;
}

public void createStream(ServerTransportListener transportListener, String methodName, Metadata headers) {
transportListener.streamCreated(this, methodName, headers);
transportState().onStreamAllocated();
}

public void inboundDataReceived(ReadableBuffer message, boolean endOfStream) {
transportState().inboundDataReceived(message, endOfStream);
}

public void transportReportStatus(Status status) {
transportState().transportReportStatus(status);
}

@Override
public TransportState transportState() {
return transportState;
}

protected void cancelSink(Status status) {
if (!websocketSession.isOpen() && Status.Code.DEADLINE_EXCEEDED == status.getCode()) {
return;
}
transportState.runOnTransportThread(() -> transportState.transportReportStatus(status));
// There is no way to RST_STREAM with CANCEL code, so write trailers instead
close(Status.CANCELLED.withCause(status.asRuntimeException()), new Metadata());
CountDownLatch countDownLatch = new CountDownLatch(1);
transportState.runOnTransportThread(() -> {
try {
websocketSession.close();
} catch (IOException ioException) {
// already closing, ignore
}
countDownLatch.countDown();
});
try {
countDownLatch.await(5, TimeUnit.SECONDS);
} catch (InterruptedException e) {
Thread.currentThread().interrupt();
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
package io.grpc.servlet.web.websocket;

import io.grpc.internal.WritableBuffer;

import static java.lang.Math.max;
import static java.lang.Math.min;

final class ByteArrayWritableBuffer implements WritableBuffer {

private final int capacity;
final byte[] bytes;
private int index;

ByteArrayWritableBuffer(int capacityHint) {
this.bytes = new byte[min(1024 * 1024, max(4096, capacityHint))];
this.capacity = bytes.length;
}

@Override
public void write(byte[] src, int srcIndex, int length) {
System.arraycopy(src, srcIndex, bytes, index, length);
index += length;
}

@Override
public void write(byte b) {
bytes[index++] = b;
}

@Override
public int writableBytes() {
return capacity - index;
}

@Override
public int readableBytes() {
return index;
}

@Override
public void release() {}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
package io.grpc.servlet.web.websocket;

import jakarta.websocket.CloseReason;
import jakarta.websocket.Endpoint;
import jakarta.websocket.EndpointConfig;
import jakarta.websocket.Session;

import java.io.IOException;
import java.io.UncheckedIOException;
import java.util.HashMap;
import java.util.Map;
import java.util.function.Supplier;

/**
* Supports multiple endpoints, based on the negotiated sub-protocol. If a protocol isn't supported, an error will be
* sent to the client.
*/
public class GrpcWebsocket extends Endpoint {
private final Map<String, Supplier<Endpoint>> endpointFactories = new HashMap<>();
private Endpoint endpoint;

public GrpcWebsocket(Map<String, Supplier<Endpoint>> endpoints) {
endpointFactories.putAll(endpoints);
}

public void onOpen(Session session, EndpointConfig endpointConfig) {
Supplier<Endpoint> supplier = endpointFactories.get(session.getNegotiatedSubprotocol());
if (supplier == null) {
try {
session.close(new CloseReason(CloseReason.CloseCodes.PROTOCOL_ERROR, "Unsupported subprotocol"));
} catch (IOException e) {
throw new UncheckedIOException(e);
}
return;
}

endpoint = supplier.get();
endpoint.onOpen(session, endpointConfig);
}

@Override
public void onClose(Session session, CloseReason closeReason) {
endpoint.onClose(session, closeReason);
}

@Override
public void onError(Session session, Throwable thr) {
endpoint.onError(session, thr);
}
}
Loading