Skip to content

Commit

Permalink
Produce a TrailersOnly response when gRPC endpoints throw an error. (#…
Browse files Browse the repository at this point in the history
…1573)

Motivation:
On some erroneous cases we can optimise the gRPC response to return TrailerOnly [1] when headers aren't written to the wire yet.

Modifications:
gRPC router was altered to generate a TrailersOnly response when an error happens before headers are sent out.

Result:
An optimised response that contains only a single frame.
  • Loading branch information
tkountis authored May 25, 2021
1 parent bdc0c9f commit ce49400
Show file tree
Hide file tree
Showing 8 changed files with 344 additions and 82 deletions.
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright © 2019-2020 Apple Inc. and the ServiceTalk project authors
* Copyright © 2019-2021 Apple Inc. and the ServiceTalk project authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -97,8 +97,8 @@ final class GrpcRouter {

private static final GrpcStatus STATUS_UNIMPLEMENTED = fromCodeValue(UNIMPLEMENTED.value());
private static final StreamingHttpService NOT_FOUND_SERVICE = (ctx, request, responseFactory) -> {
final StreamingHttpResponse response = newResponse(responseFactory, null, STATUS_UNIMPLEMENTED,
ctx.executionContext().bufferAllocator());
final StreamingHttpResponse response = newErrorResponse(responseFactory, null, STATUS_UNIMPLEMENTED,
null, ctx.executionContext().bufferAllocator());
response.version(request.version());
return succeeded(response);
};
Expand Down Expand Up @@ -264,10 +264,10 @@ public Single<HttpResponse> handle(final HttpServiceContext ctx, final HttpReque
.payloadBody(rawResp,
serializationProvider.serializerFor(responseEncoding,
responseClass)))
.onErrorReturn(cause -> newErrorResponse(responseFactory,
finalServiceContext, cause, ctx.executionContext().bufferAllocator()));
.onErrorReturn(cause -> newErrorResponse(responseFactory, finalServiceContext,
null, cause, ctx.executionContext().bufferAllocator()));
} catch (Throwable t) {
return succeeded(newErrorResponse(responseFactory, serviceContext, t,
return succeeded(newErrorResponse(responseFactory, serviceContext, null, t,
ctx.executionContext().bufferAllocator()));
}
}
Expand Down Expand Up @@ -320,7 +320,7 @@ public Single<StreamingHttpResponse> handle(final HttpServiceContext ctx,
serializationProvider.serializerFor(responseEncoding, responseClass),
ctx.executionContext().bufferAllocator()));
} catch (Throwable t) {
return succeeded(newErrorResponse(responseFactory, serviceContext, t,
return succeeded(newErrorResponse(responseFactory, serviceContext, null, t,
ctx.executionContext().bufferAllocator()));
}
}
Expand Down Expand Up @@ -465,7 +465,7 @@ public HttpResponse handle(final HttpServiceContext ctx, final HttpRequest reque
ctx.executionContext().bufferAllocator()).payloadBody(response,
serializationProvider.serializerFor(responseEncoding, responseClass));
} catch (Throwable t) {
return newErrorResponse(responseFactory, serviceContext, t,
return newErrorResponse(responseFactory, serviceContext, null, t,
ctx.executionContext().bufferAllocator());
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -414,7 +414,7 @@ public Single<StreamingHttpResponse> handle(final HttpServiceContext ctx,
private static StreamingHttpResponse convertToGrpcErrorResponse(
final HttpServiceContext ctx, final StreamingHttpResponseFactory responseFactory,
final Throwable cause) {
return newErrorResponse(responseFactory, null, cause, ctx.executionContext().bufferAllocator());
return newErrorResponse(responseFactory, null, null, cause, ctx.executionContext().bufferAllocator());
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -51,8 +51,6 @@
import javax.annotation.Nullable;

import static io.servicetalk.buffer.api.CharSequences.newAsciiString;
import static io.servicetalk.concurrent.api.Publisher.empty;
import static io.servicetalk.concurrent.api.Publisher.failed;
import static io.servicetalk.encoding.api.Identity.identity;
import static io.servicetalk.encoding.api.internal.HeaderUtils.encodingFor;
import static io.servicetalk.grpc.api.GrpcStatusCode.CANCELLED;
Expand Down Expand Up @@ -167,18 +165,32 @@ static HttpResponse newResponse(final HttpResponseFactory responseFactory,
return response;
}

static HttpResponse newErrorResponse(final HttpResponseFactory responseFactory,
@Nullable final GrpcServiceContext context,
final Throwable cause, final BufferAllocator allocator) {
HttpResponse response = newResponse(responseFactory, context, allocator);
setStatus(response.trailers(), cause, allocator);
static HttpResponse newErrorResponse(
final HttpResponseFactory responseFactory, @Nullable final GrpcServiceContext context,
@Nullable final GrpcStatus status, @Nullable final Throwable cause, final BufferAllocator allocator) {
assert status != null || cause != null;
final HttpResponse response = responseFactory.ok();
initResponse(response, context);
if (status != null) {
setStatus(response.headers(), status, null, allocator);
} else {
setStatus(response.headers(), cause, allocator);
}
return response;
}

static StreamingHttpResponse newErrorResponse(final StreamingHttpResponseFactory responseFactory,
@Nullable final GrpcServiceContext context, final Throwable cause,
final BufferAllocator allocator) {
return newStreamingResponse(responseFactory, context).transform(new ErrorUpdater(cause, allocator));
static StreamingHttpResponse newErrorResponse(
final StreamingHttpResponseFactory responseFactory, @Nullable final GrpcServiceContext context,
@Nullable final GrpcStatus status, @Nullable final Throwable cause, final BufferAllocator allocator) {
assert (status != null && cause == null) || (status == null && cause != null);
final StreamingHttpResponse response = responseFactory.ok();
initResponse(response, context);
if (status != null) {
setStatus(response.headers(), status, null, allocator);
} else {
setStatus(response.headers(), cause, allocator);
}
return response;
}

private static StreamingHttpResponse newStreamingResponse(final StreamingHttpResponseFactory responseFactory,
Expand Down Expand Up @@ -243,13 +255,19 @@ static <Resp> Publisher<Resp> validateResponseAndGetPayload(final StreamingHttpR
// HTTP1-based implementation translates them into response headers so we need to look for a grpc-status in both
// headers and trailers. Since this is streaming response and we have the headers now, we check for the
// grpc-status here first. If there is no grpc-status in headers, we look for it in trailers later.

final HttpHeaders headers = response.headers();
ensureGrpcContentType(response.status(), headers);
final GrpcStatusCode grpcStatusCode = extractGrpcStatusCodeFromHeaders(headers);
if (grpcStatusCode != null) {
final GrpcStatusException grpcStatusException = convertToGrpcStatusException(grpcStatusCode, headers);
return response.messageBody().ignoreElements()
.concat(grpcStatusException != null ? failed(grpcStatusException) : empty());
if (grpcStatusException != null) {
// Give priority to the error if it happens, to allow delayed requests or streams to terminate.
return Publisher.<Resp>failed(grpcStatusException)
.concat(response.messageBody().ignoreElements());
} else {
return response.messageBody().ignoreElements().toPublisher();
}
}

response.transform(ENSURE_GRPC_STATUS_RECEIVED);
Expand Down Expand Up @@ -496,20 +514,4 @@ protected HttpHeaders payloadFailed(final Throwable cause, final HttpHeaders tra
return trailers;
}
}

private static final class ErrorUpdater extends StatelessTrailersTransformer<Buffer> {
private final Throwable cause;
private final BufferAllocator allocator;

ErrorUpdater(final Throwable cause, final BufferAllocator allocator) {
this.cause = cause;
this.allocator = allocator;
}

@Override
protected HttpHeaders payloadComplete(final HttpHeaders trailers) {
setStatus(trailers, cause, allocator);
return trailers;
}
}
}
Loading

0 comments on commit ce49400

Please sign in to comment.