From a589b8170333476233d48476587a1d4363c90bd0 Mon Sep 17 00:00:00 2001 From: jwilson Date: Fri, 23 Dec 2016 17:21:59 -0500 Subject: [PATCH] Support 'Expect: 100-continue' as a request header. This will stall reading the response body until the server returns an intermediate '100 Continue' response. Closes: https://github.com/square/okhttp/issues/675 --- .../okhttp3/internal/http2/Http2Server.java | 6 +- .../okhttp3/mockwebserver/MockWebServer.java | 33 +++++-- .../okhttp3/mockwebserver/SocketPolicy.java | 9 +- .../src/test/java/okhttp3/CallTest.java | 85 ++++++++++++++++++- .../internal/http2/Http2ConnectionTest.java | 47 +++++++--- .../src/main/java/okhttp3/OkHttpClient.java | 4 + .../main/java/okhttp3/internal/Internal.java | 3 + .../internal/connection/RealConnection.java | 4 +- .../internal/http/CallServerInterceptor.java | 26 ++++-- .../java/okhttp3/internal/http/HttpCodec.java | 12 ++- .../okhttp3/internal/http1/Http1Codec.java | 39 +++++---- .../okhttp3/internal/http2/Http2Codec.java | 37 ++++++-- .../internal/http2/Http2Connection.java | 4 +- .../okhttp3/internal/http2/Http2Stream.java | 49 ++++++----- 14 files changed, 275 insertions(+), 83 deletions(-) diff --git a/mockwebserver/src/main/java/okhttp3/internal/http2/Http2Server.java b/mockwebserver/src/main/java/okhttp3/internal/http2/Http2Server.java index 9823d01343cf..c9a2da9a397a 100644 --- a/mockwebserver/src/main/java/okhttp3/internal/http2/Http2Server.java +++ b/mockwebserver/src/main/java/okhttp3/internal/http2/Http2Server.java @@ -126,7 +126,7 @@ private void send404(Http2Stream stream, String path) throws IOException { new Header(":version", "HTTP/1.1"), new Header("content-type", "text/plain") ); - stream.reply(responseHeaders, true); + stream.sendResponseHeaders(responseHeaders, true); BufferedSink out = Okio.buffer(stream.getSink()); out.writeUtf8("Not found: " + path); out.close(); @@ -138,7 +138,7 @@ private void serveDirectory(Http2Stream stream, File[] files) throws IOException new Header(":version", "HTTP/1.1"), new Header("content-type", "text/html; charset=UTF-8") ); - stream.reply(responseHeaders, true); + stream.sendResponseHeaders(responseHeaders, true); BufferedSink out = Okio.buffer(stream.getSink()); for (File file : files) { String target = file.isDirectory() ? (file.getName() + "/") : file.getName(); @@ -153,7 +153,7 @@ private void serveFile(Http2Stream stream, File file) throws IOException { new Header(":version", "HTTP/1.1"), new Header("content-type", contentType(file)) ); - stream.reply(responseHeaders, true); + stream.sendResponseHeaders(responseHeaders, true); Source source = Okio.source(file); try { BufferedSink out = Okio.buffer(stream.getSink()); diff --git a/mockwebserver/src/main/java/okhttp3/mockwebserver/MockWebServer.java b/mockwebserver/src/main/java/okhttp3/mockwebserver/MockWebServer.java index 76ee2fbec01a..6e482dd46143 100644 --- a/mockwebserver/src/main/java/okhttp3/mockwebserver/MockWebServer.java +++ b/mockwebserver/src/main/java/okhttp3/mockwebserver/MockWebServer.java @@ -86,6 +86,7 @@ import static okhttp3.mockwebserver.SocketPolicy.DISCONNECT_AT_START; import static okhttp3.mockwebserver.SocketPolicy.DISCONNECT_DURING_REQUEST_BODY; import static okhttp3.mockwebserver.SocketPolicy.DISCONNECT_DURING_RESPONSE_BODY; +import static okhttp3.mockwebserver.SocketPolicy.EXPECT_CONTINUE; import static okhttp3.mockwebserver.SocketPolicy.FAIL_HANDSHAKE; import static okhttp3.mockwebserver.SocketPolicy.NO_RESPONSE; import static okhttp3.mockwebserver.SocketPolicy.RESET_STREAM_AT_START; @@ -590,7 +591,7 @@ private RecordedRequest readRequest(Socket socket, BufferedSource source, Buffer Headers.Builder headers = new Headers.Builder(); long contentLength = -1; boolean chunked = false; - boolean expectContinue = false; + boolean readBody = true; String header; while ((header = source.readUtf8LineStrict()).length() != 0) { Internal.instance.addLenient(headers, header); @@ -603,23 +604,26 @@ private RecordedRequest readRequest(Socket socket, BufferedSource source, Buffer chunked = true; } if (lowercaseHeader.startsWith("expect:") - && lowercaseHeader.substring(7).trim().equals("100-continue")) { - expectContinue = true; + && lowercaseHeader.substring(7).trim().equalsIgnoreCase("100-continue")) { + readBody = false; } } - if (expectContinue) { + if (!readBody && dispatcher.peek().getSocketPolicy() == EXPECT_CONTINUE) { sink.writeUtf8("HTTP/1.1 100 Continue\r\n"); sink.writeUtf8("Content-Length: 0\r\n"); sink.writeUtf8("\r\n"); sink.flush(); + readBody = true; } boolean hasBody = false; TruncatingBuffer requestBody = new TruncatingBuffer(bodyLimit); List chunkSizes = new ArrayList<>(); MockResponse policy = dispatcher.peek(); - if (contentLength != -1) { + if (!readBody) { + // Don't read the body unless we've invited the client to send it. + } else if (contentLength != -1) { hasBody = contentLength > 0; throttledTransfer(policy, socket, source, Okio.buffer(requestBody), contentLength, true); } else if (chunked) { @@ -881,6 +885,7 @@ private RecordedRequest readRequest(Http2Stream stream) throws IOException { Headers.Builder httpHeaders = new Headers.Builder(); String method = "<:method omitted>"; String path = "<:path omitted>"; + boolean readBody = true; for (int i = 0, size = streamHeaders.size(); i < size; i++) { ByteString name = streamHeaders.get(i).name; String value = streamHeaders.get(i).value.utf8(); @@ -893,11 +898,23 @@ private RecordedRequest readRequest(Http2Stream stream) throws IOException { } else { throw new IllegalStateException(); } + if (name.utf8().equals("expect") && value.equalsIgnoreCase("100-continue")) { + // Don't read the body unless we've invited the client to send it. + readBody = false; + } + } + + if (!readBody && dispatcher.peek().getSocketPolicy() == EXPECT_CONTINUE) { + stream.sendResponseHeaders(Collections.singletonList( + new Header(Header.RESPONSE_STATUS, ByteString.encodeUtf8("100 Continue"))), true); + stream.getConnection().flush(); + readBody = true; } Buffer body = new Buffer(); - body.writeAll(stream.getSource()); - body.close(); + if (readBody) { + body.writeAll(stream.getSource()); + } String requestLine = method + ' ' + path + " HTTP/1.1"; List chunkSizes = Collections.emptyList(); // No chunked encoding for HTTP/2. @@ -928,7 +945,7 @@ private void writeResponse(Http2Stream stream, MockResponse response) throws IOE Buffer body = response.getBody(); boolean closeStreamAfterHeaders = body != null || !response.getPushPromises().isEmpty(); - stream.reply(http2Headers, closeStreamAfterHeaders); + stream.sendResponseHeaders(http2Headers, closeStreamAfterHeaders); pushPromises(stream, response.getPushPromises()); if (body != null) { BufferedSink sink = Okio.buffer(stream.getSink()); diff --git a/mockwebserver/src/main/java/okhttp3/mockwebserver/SocketPolicy.java b/mockwebserver/src/main/java/okhttp3/mockwebserver/SocketPolicy.java index 153d3fd129fc..521c5a0e5840 100644 --- a/mockwebserver/src/main/java/okhttp3/mockwebserver/SocketPolicy.java +++ b/mockwebserver/src/main/java/okhttp3/mockwebserver/SocketPolicy.java @@ -97,5 +97,12 @@ public enum SocketPolicy { * Fail HTTP/2 requests without processing them by sending an {@linkplain * MockResponse#getHttp2ErrorCode() HTTP/2 error code}. */ - RESET_STREAM_AT_START + RESET_STREAM_AT_START, + + /** + * Transmit a {@code HTTP/1.1 100 Continue} response before reading the HTTP request body. + * Typically this response is sent when a client makes a request with the header {@code + * Expect: 100-continue}. + */ + EXPECT_CONTINUE } diff --git a/okhttp-tests/src/test/java/okhttp3/CallTest.java b/okhttp-tests/src/test/java/okhttp3/CallTest.java index 9f508842b3d2..5b7e450d1a62 100644 --- a/okhttp-tests/src/test/java/okhttp3/CallTest.java +++ b/okhttp-tests/src/test/java/okhttp3/CallTest.java @@ -2185,7 +2185,8 @@ private InetSocketAddress startNullServer() throws IOException { } @Test public void expect100ContinueNonEmptyRequestBody() throws Exception { - server.enqueue(new MockResponse()); + server.enqueue(new MockResponse() + .setSocketPolicy(SocketPolicy.EXPECT_CONTINUE)); Request request = new Request.Builder() .url(server.url("/")) @@ -2214,6 +2215,88 @@ private InetSocketAddress startNullServer() throws IOException { .assertSuccessful(); } + @Test public void expect100ContinueEmptyRequestBody_HTTP2() throws Exception { + enableProtocol(Protocol.HTTP_2); + expect100ContinueEmptyRequestBody(); + } + + @Test public void expect100ContinueTimesOutWithoutContinue() throws Exception { + server.enqueue(new MockResponse() + .setSocketPolicy(SocketPolicy.NO_RESPONSE)); + + client = client.newBuilder() + .readTimeout(500, TimeUnit.MILLISECONDS) + .build(); + + Request request = new Request.Builder() + .url(server.url("/")) + .header("Expect", "100-continue") + .post(RequestBody.create(MediaType.parse("text/plain"), "abc")) + .build(); + + Call call = client.newCall(request); + try { + call.execute(); + fail(); + } catch (SocketTimeoutException expected) { + } + + RecordedRequest recordedRequest = server.takeRequest(); + assertEquals("", recordedRequest.getBody().readUtf8()); + } + + @Test public void expect100ContinueTimesOutWithoutContinue_HTTP2() throws Exception { + enableProtocol(Protocol.HTTP_2); + expect100ContinueTimesOutWithoutContinue(); + } + + @Test public void serverRespondsWithUnsolicited100Continue() throws Exception { + server.enqueue(new MockResponse() + .setStatus("HTTP/1.1 100 Continue")); + + Request request = new Request.Builder() + .url(server.url("/")) + .post(RequestBody.create(MediaType.parse("text/plain"), "abc")) + .build(); + + Call call = client.newCall(request); + Response response = call.execute(); + assertEquals(100, response.code()); + assertEquals("Continue", response.message()); + assertEquals("", response.body().string()); + + RecordedRequest recordedRequest = server.takeRequest(); + assertEquals("abc", recordedRequest.getBody().readUtf8()); + } + + @Test public void serverRespondsWithUnsolicited100Continue_HTTP2() throws Exception { + enableProtocol(Protocol.HTTP_2); + serverRespondsWithUnsolicited100Continue(); + } + + @Test public void successfulExpectContinueAndConnectionReuse() throws Exception { + server.enqueue(new MockResponse() + .setSocketPolicy(SocketPolicy.EXPECT_CONTINUE)); + server.enqueue(new MockResponse()); + + executeSynchronously("/", "Expect", "100-continue"); + executeSynchronously("/"); + + assertEquals(0, server.takeRequest().getSequenceNumber()); + assertEquals(1, server.takeRequest().getSequenceNumber()); + } + + @Test public void unsuccessfulExpectContinueAndConnectionReuse() throws Exception { + server.enqueue(new MockResponse()); + server.enqueue(new MockResponse()); + + executeSynchronously("/", "Expect", "100-continue"); + executeSynchronously("/"); + + assertEquals(0, server.takeRequest().getSequenceNumber()); + assertEquals(1, server.takeRequest().getSequenceNumber()); + } + /** We forbid non-ASCII characters in outgoing request headers, but accept UTF-8. */ @Test public void responseHeaderParsingIsLenient() throws Exception { Headers headers = new Headers.Builder() diff --git a/okhttp-tests/src/test/java/okhttp3/internal/http2/Http2ConnectionTest.java b/okhttp-tests/src/test/java/okhttp3/internal/http2/Http2ConnectionTest.java index 37bec5e529e8..18fcab5a1221 100644 --- a/okhttp-tests/src/test/java/okhttp3/internal/http2/Http2ConnectionTest.java +++ b/okhttp-tests/src/test/java/okhttp3/internal/http2/Http2ConnectionTest.java @@ -291,7 +291,7 @@ public final class Http2ConnectionTest { connection.okHttpSettings.set(INITIAL_WINDOW_SIZE, windowSize); Http2Stream stream = connection.newStream(headerEntries("b", "banana"), false); assertEquals(0, stream.unacknowledgedBytesRead); - assertEquals(headerEntries("a", "android"), stream.getResponseHeaders()); + assertEquals(headerEntries("a", "android"), stream.takeResponseHeaders()); Source in = stream.getSource(); Buffer buffer = new Buffer(); buffer.writeAll(in); @@ -513,7 +513,7 @@ public final class Http2ConnectionTest { BufferedSink out = Okio.buffer(stream.getSink()); out.writeUtf8("c3po"); out.close(); - assertEquals(headerEntries("a", "android"), stream.getResponseHeaders()); + assertEquals(headerEntries("a", "android"), stream.takeResponseHeaders()); assertStreamData("robot", stream.getSource()); connection.ping().roundTripTime(); assertEquals(0, connection.openStreamCount()); @@ -914,7 +914,7 @@ public final class Http2ConnectionTest { // play it back Http2Connection connection = connect(peer); Http2Stream stream = connection.newStream(headerEntries("c", "cola"), false); - assertEquals(headerEntries("a", "android"), stream.getResponseHeaders()); + assertEquals(headerEntries("a", "android"), stream.takeResponseHeaders()); connection.ping().roundTripTime(); // Ensure that the 2nd SYN REPLY has been received. // verify the peer received what was expected @@ -940,7 +940,7 @@ public final class Http2ConnectionTest { // play it back Http2Connection connection = connect(peer); Http2Stream stream = connection.newStream(headerEntries("b", "banana"), false); - assertEquals(headerEntries("a", "android"), stream.getResponseHeaders()); + assertEquals(headerEntries("a", "android"), stream.takeResponseHeaders()); assertStreamData("robot", stream.getSource()); // verify the peer received what was expected @@ -973,7 +973,7 @@ public final class Http2ConnectionTest { // play it back Http2Connection connection = connect(peer); Http2Stream stream = connection.newStream(headerEntries("a", "android"), false); - assertEquals(headerEntries("b", "banana"), stream.getResponseHeaders()); + assertEquals(headerEntries("b", "banana"), stream.takeResponseHeaders()); // verify the peer received what was expected InFrame synStream = peer.takeFrame(); @@ -997,7 +997,7 @@ public final class Http2ConnectionTest { Http2Connection connection = connect(peer); Http2Stream stream = connection.newStream(headerEntries("a", "android"), false); try { - stream.getResponseHeaders(); + stream.takeResponseHeaders(); fail(); } catch (IOException expected) { assertEquals("stream was reset: REFUSED_STREAM", expected.getMessage()); @@ -1189,7 +1189,7 @@ public final class Http2ConnectionTest { stream.readTimeout().timeout(500, TimeUnit.MILLISECONDS); long startNanos = System.nanoTime(); try { - stream.getResponseHeaders(); + stream.takeResponseHeaders(); fail(); } catch (InterruptedIOException expected) { } @@ -1362,7 +1362,8 @@ public final class Http2ConnectionTest { Http2Connection connection = connect(peer); Http2Stream stream = connection.newStream(headerEntries("b", "banana"), true); connection.ping().roundTripTime(); // Ensure that the HEADERS has been received. - assertEquals(headerEntries("a", "android", "c", "c3po"), stream.getResponseHeaders()); + assertEquals(Arrays.asList(new Header("a", "android"), null, new Header("c", "c3po")), + stream.takeResponseHeaders()); // verify the peer received what was expected InFrame synStream = peer.takeFrame(); @@ -1371,6 +1372,30 @@ public final class Http2ConnectionTest { assertEquals(Http2.TYPE_PING, ping.type); } + @Test public void readMultipleSetsOfResponseHeaders() throws Exception { + // write the mocking script + peer.sendFrame().settings(new Settings()); + peer.acceptFrame(); // ACK + peer.acceptFrame(); // SYN_STREAM + peer.sendFrame().synReply(false, 3, headerEntries("a", "android")); + peer.acceptFrame(); // PING + peer.sendFrame().ping(true, 1, 0); // PING + peer.sendFrame().synReply(true, 3, headerEntries("c", "cola")); + peer.play(); + + // play it back + Http2Connection connection = connect(peer); + Http2Stream stream = connection.newStream(headerEntries("b", "banana"), true); + stream.getConnection().flush(); + assertEquals(headerEntries("a", "android"), stream.takeResponseHeaders()); + connection.ping().roundTripTime(); + assertEquals(headerEntries("c", "cola"), stream.takeResponseHeaders()); + + // verify the peer received what was expected + assertEquals(Http2.TYPE_HEADERS, peer.takeFrame().type); + assertEquals(Http2.TYPE_PING, peer.takeFrame().type); + } + @Test public void readSendsWindowUpdate() throws Exception { int windowSize = 100; int windowUpdateThreshold = 50; @@ -1396,7 +1421,7 @@ public final class Http2ConnectionTest { connection.okHttpSettings.set(INITIAL_WINDOW_SIZE, windowSize); Http2Stream stream = connection.newStream(headerEntries("b", "banana"), false); assertEquals(0, stream.unacknowledgedBytesRead); - assertEquals(headerEntries("a", "android"), stream.getResponseHeaders()); + assertEquals(headerEntries("a", "android"), stream.takeResponseHeaders()); Source in = stream.getSource(); Buffer buffer = new Buffer(); buffer.writeAll(in); @@ -1474,7 +1499,7 @@ public final class Http2ConnectionTest { // play it back Http2Connection connection = connect(peer); Http2Stream stream = connection.newStream(headerEntries("b", "banana"), false); - assertEquals(headerEntries("a", "android"), stream.getResponseHeaders()); + assertEquals(headerEntries("a", "android"), stream.takeResponseHeaders()); Source in = stream.getSource(); try { Okio.buffer(in).readByteString(101); @@ -1540,7 +1565,7 @@ public final class Http2ConnectionTest { Http2Stream stream = connection.newStream(headerEntries("b", "banana"), false); try { - stream.getResponseHeaders(); + stream.takeResponseHeaders(); fail(); } catch (IOException expected) { assertEquals("stream was reset: PROTOCOL_ERROR", expected.getMessage()); diff --git a/okhttp/src/main/java/okhttp3/OkHttpClient.java b/okhttp/src/main/java/okhttp3/OkHttpClient.java index d55351e841a6..eb62e1e011c9 100644 --- a/okhttp/src/main/java/okhttp3/OkHttpClient.java +++ b/okhttp/src/main/java/okhttp3/OkHttpClient.java @@ -156,6 +156,10 @@ public class OkHttpClient implements Cloneable, Call.Factory, WebSocket.Factory return connectionPool.routeDatabase; } + @Override public int code(Response.Builder responseBuilder) { + return responseBuilder.code; + } + @Override public void apply(ConnectionSpec tlsConfiguration, SSLSocket sslSocket, boolean isFallback) { tlsConfiguration.apply(sslSocket, isFallback); diff --git a/okhttp/src/main/java/okhttp3/internal/Internal.java b/okhttp/src/main/java/okhttp3/internal/Internal.java index 2610a4451635..759f4c38431e 100644 --- a/okhttp/src/main/java/okhttp3/internal/Internal.java +++ b/okhttp/src/main/java/okhttp3/internal/Internal.java @@ -26,6 +26,7 @@ import okhttp3.HttpUrl; import okhttp3.OkHttpClient; import okhttp3.Request; +import okhttp3.Response; import okhttp3.internal.cache.InternalCache; import okhttp3.internal.connection.RealConnection; import okhttp3.internal.connection.RouteDatabase; @@ -59,6 +60,8 @@ public abstract RealConnection get( public abstract RouteDatabase routeDatabase(ConnectionPool connectionPool); + public abstract int code(Response.Builder responseBuilder); + public abstract void apply(ConnectionSpec tlsConfiguration, SSLSocket sslSocket, boolean isFallback); diff --git a/okhttp/src/main/java/okhttp3/internal/connection/RealConnection.java b/okhttp/src/main/java/okhttp3/internal/connection/RealConnection.java index 547781d26f73..0583b6c71b5e 100644 --- a/okhttp/src/main/java/okhttp3/internal/connection/RealConnection.java +++ b/okhttp/src/main/java/okhttp3/internal/connection/RealConnection.java @@ -294,7 +294,9 @@ private Request createTunnel(int readTimeout, int writeTimeout, Request tunnelRe sink.timeout().timeout(writeTimeout, MILLISECONDS); tunnelConnection.writeRequest(tunnelRequest.headers(), requestLine); tunnelConnection.finishRequest(); - Response response = tunnelConnection.readResponse().request(tunnelRequest).build(); + Response response = tunnelConnection.readResponseHeaders(false) + .request(tunnelRequest) + .build(); // The response body from a CONNECT should be empty, but if it is not then we should consume // it before proceeding. long contentLength = HttpHeaders.contentLength(response); diff --git a/okhttp/src/main/java/okhttp3/internal/http/CallServerInterceptor.java b/okhttp/src/main/java/okhttp3/internal/http/CallServerInterceptor.java index 6adbc28b56f6..f7c0e5a56106 100644 --- a/okhttp/src/main/java/okhttp3/internal/http/CallServerInterceptor.java +++ b/okhttp/src/main/java/okhttp3/internal/http/CallServerInterceptor.java @@ -42,16 +42,32 @@ public CallServerInterceptor(boolean forWebSocket) { long sentRequestMillis = System.currentTimeMillis(); httpCodec.writeRequestHeaders(request); + Response.Builder responseBuilder = null; if (HttpMethod.permitsRequestBody(request.method()) && request.body() != null) { - Sink requestBodyOut = httpCodec.createRequestBody(request, request.body().contentLength()); - BufferedSink bufferedRequestBody = Okio.buffer(requestBodyOut); - request.body().writeTo(bufferedRequestBody); - bufferedRequestBody.close(); + // If there's a "Expect: 100-continue" header on the request, wait for a "HTTP/1.1 100 + // Continue" response before transmitting the request body. If we don't get that, return what + // we did get (such as a 4xx response) without ever transmitting the request body. + if ("100-continue".equalsIgnoreCase(request.header("Expect"))) { + httpCodec.flushRequest(); + responseBuilder = httpCodec.readResponseHeaders(true); + } + + // Write the request body, unless an "Expect: 100-continue" expectation failed. + if (responseBuilder == null) { + Sink requestBodyOut = httpCodec.createRequestBody(request, request.body().contentLength()); + BufferedSink bufferedRequestBody = Okio.buffer(requestBodyOut); + request.body().writeTo(bufferedRequestBody); + bufferedRequestBody.close(); + } } httpCodec.finishRequest(); - Response response = httpCodec.readResponseHeaders() + if (responseBuilder == null) { + responseBuilder = httpCodec.readResponseHeaders(false); + } + + Response response = responseBuilder .request(request) .handshake(streamAllocation.connection().handshake()) .sentRequestAtMillis(sentRequestMillis) diff --git a/okhttp/src/main/java/okhttp3/internal/http/HttpCodec.java b/okhttp/src/main/java/okhttp3/internal/http/HttpCodec.java index b227399e9135..ad9759acce3e 100644 --- a/okhttp/src/main/java/okhttp3/internal/http/HttpCodec.java +++ b/okhttp/src/main/java/okhttp3/internal/http/HttpCodec.java @@ -37,10 +37,18 @@ public interface HttpCodec { void writeRequestHeaders(Request request) throws IOException; /** Flush the request to the underlying socket. */ + void flushRequest() throws IOException; + + /** Flush the request to the underlying socket and signal no more bytes will be transmitted. */ void finishRequest() throws IOException; - /** Read and return response headers. */ - Response.Builder readResponseHeaders() throws IOException; + /** + * Parses bytes of a response header from an HTTP transport. + * + * @param expectContinue true to return null if this is an intermediate response with a "100" + * response code. Otherwise this method never returns null. + */ + Response.Builder readResponseHeaders(boolean expectContinue) throws IOException; /** Returns a stream that reads the response body. */ ResponseBody openResponseBody(Response response) throws IOException; diff --git a/okhttp/src/main/java/okhttp3/internal/http1/Http1Codec.java b/okhttp/src/main/java/okhttp3/internal/http1/Http1Codec.java index dc2dd1791ef2..625c173bf3a2 100644 --- a/okhttp/src/main/java/okhttp3/internal/http1/Http1Codec.java +++ b/okhttp/src/main/java/okhttp3/internal/http1/Http1Codec.java @@ -55,7 +55,7 @@ *
  • Open a sink to write the request body. Either {@linkplain #newFixedLengthSink * fixed-length} or {@link #newChunkedSink chunked}. *
  • Write to and then close that sink. - *
  • {@linkplain #readResponse Read response headers}. + *
  • {@linkplain #readResponseHeaders Read response headers}. *
  • Open a source to read the response body. Either {@linkplain #newFixedLengthSource * fixed-length}, {@linkplain #newChunkedSource chunked} or {@linkplain * #newUnknownLengthSource unknown length}. @@ -128,10 +128,6 @@ public Http1Codec(OkHttpClient client, StreamAllocation streamAllocation, Buffer writeRequest(request.headers(), requestLine); } - @Override public Response.Builder readResponseHeaders() throws IOException { - return readResponse(); - } - @Override public ResponseBody openResponseBody(Response response) throws IOException { Source source = getTransferStream(response); return new RealResponseBody(response.headers(), Okio.buffer(source)); @@ -162,6 +158,10 @@ public boolean isClosed() { return state == STATE_CLOSED; } + @Override public void flushRequest() throws IOException { + sink.flush(); + } + @Override public void finishRequest() throws IOException { sink.flush(); } @@ -180,27 +180,26 @@ public void writeRequest(Headers headers, String requestLine) throws IOException state = STATE_OPEN_REQUEST_BODY; } - /** Parses bytes of a response header from an HTTP transport. */ - public Response.Builder readResponse() throws IOException { + @Override public Response.Builder readResponseHeaders(boolean expectContinue) throws IOException { if (state != STATE_OPEN_REQUEST_BODY && state != STATE_READ_RESPONSE_HEADERS) { throw new IllegalStateException("state: " + state); } try { - while (true) { - StatusLine statusLine = StatusLine.parse(source.readUtf8LineStrict()); - - Response.Builder responseBuilder = new Response.Builder() - .protocol(statusLine.protocol) - .code(statusLine.code) - .message(statusLine.message) - .headers(readHeaders()); - - if (statusLine.code != HTTP_CONTINUE) { - state = STATE_OPEN_RESPONSE_BODY; - return responseBuilder; - } + StatusLine statusLine = StatusLine.parse(source.readUtf8LineStrict()); + + Response.Builder responseBuilder = new Response.Builder() + .protocol(statusLine.protocol) + .code(statusLine.code) + .message(statusLine.message) + .headers(readHeaders()); + + if (expectContinue && statusLine.code == HTTP_CONTINUE) { + return null; } + + state = STATE_OPEN_RESPONSE_BODY; + return responseBuilder; } catch (EOFException e) { // Provide more context if the server ends the stream before sending a response. IOException exception = new IOException("unexpected end of stream on " + streamAllocation); diff --git a/okhttp/src/main/java/okhttp3/internal/http2/Http2Codec.java b/okhttp/src/main/java/okhttp3/internal/http2/Http2Codec.java index abd4b5ef3c8e..66a72ab12a9b 100644 --- a/okhttp/src/main/java/okhttp3/internal/http2/Http2Codec.java +++ b/okhttp/src/main/java/okhttp3/internal/http2/Http2Codec.java @@ -40,6 +40,7 @@ import okio.Sink; import okio.Source; +import static okhttp3.internal.http.StatusLine.HTTP_CONTINUE; import static okhttp3.internal.http2.Header.RESPONSE_STATUS; import static okhttp3.internal.http2.Header.TARGET_AUTHORITY; import static okhttp3.internal.http2.Header.TARGET_METHOD; @@ -107,12 +108,21 @@ public Http2Codec( stream.writeTimeout().timeout(client.writeTimeoutMillis(), TimeUnit.MILLISECONDS); } + @Override public void flushRequest() throws IOException { + connection.flush(); + } + @Override public void finishRequest() throws IOException { stream.getSink().close(); } - @Override public Response.Builder readResponseHeaders() throws IOException { - return readHttp2HeadersList(stream.getResponseHeaders()); + @Override public Response.Builder readResponseHeaders(boolean expectContinue) throws IOException { + List
    headers = stream.takeResponseHeaders(); + Response.Builder responseBuilder = readHttp2HeadersList(headers); + if (expectContinue && Internal.instance.code(responseBuilder) == HTTP_CONTINUE) { + return null; + } + return responseBuilder; } public static List
    http2HeadersList(Request request) { @@ -135,22 +145,31 @@ public static List
    http2HeadersList(Request request) { /** Returns headers for a name value block containing an HTTP/2 response. */ public static Response.Builder readHttp2HeadersList(List
    headerBlock) throws IOException { - String status = null; - + StatusLine statusLine = null; Headers.Builder headersBuilder = new Headers.Builder(); for (int i = 0, size = headerBlock.size(); i < size; i++) { - ByteString name = headerBlock.get(i).name; + Header header = headerBlock.get(i); + + // If there were multiple header blocks they will be delimited by nulls. Discard existing + // header blocks if the existing header block is a '100 Continue' intermediate response. + if (header == null) { + if (statusLine != null && statusLine.code == HTTP_CONTINUE) { + statusLine = null; + headersBuilder = new Headers.Builder(); + } + continue; + } - String value = headerBlock.get(i).value.utf8(); + ByteString name = header.name; + String value = header.value.utf8(); if (name.equals(RESPONSE_STATUS)) { - status = value; + statusLine = StatusLine.parse("HTTP/1.1 " + value); } else if (!HTTP_2_SKIPPED_RESPONSE_HEADERS.contains(name)) { Internal.instance.addLenient(headersBuilder, name.utf8(), value); } } - if (status == null) throw new ProtocolException("Expected ':status' header not present"); + if (statusLine == null) throw new ProtocolException("Expected ':status' header not present"); - StatusLine statusLine = StatusLine.parse("HTTP/1.1 " + status); return new Response.Builder() .protocol(Protocol.HTTP_2) .code(statusLine.code) diff --git a/okhttp/src/main/java/okhttp3/internal/http2/Http2Connection.java b/okhttp/src/main/java/okhttp3/internal/http2/Http2Connection.java index d79c1b3b288f..2548705eed8e 100644 --- a/okhttp/src/main/java/okhttp3/internal/http2/Http2Connection.java +++ b/okhttp/src/main/java/okhttp3/internal/http2/Http2Connection.java @@ -865,8 +865,8 @@ public abstract static class Listener { /** * Handle a new stream from this connection's peer. Implementations should respond by either - * {@linkplain Http2Stream#reply replying to the stream} or {@linkplain Http2Stream#close - * closing it}. This response does not need to be synchronous. + * {@linkplain Http2Stream#sendResponseHeaders replying to the stream} or {@linkplain + * Http2Stream#close closing it}. This response does not need to be synchronous. */ public abstract void onStream(Http2Stream stream) throws IOException; diff --git a/okhttp/src/main/java/okhttp3/internal/http2/Http2Stream.java b/okhttp/src/main/java/okhttp3/internal/http2/Http2Stream.java index 3d036860a284..8659fd6a01c5 100644 --- a/okhttp/src/main/java/okhttp3/internal/http2/Http2Stream.java +++ b/okhttp/src/main/java/okhttp3/internal/http2/Http2Stream.java @@ -51,12 +51,15 @@ public final class Http2Stream { final int id; final Http2Connection connection; - /** Headers sent by the stream initiator. Immutable and non null. */ + /** Request headers. Immutable and non null. */ private final List
    requestHeaders; - /** Headers sent in the stream reply. Null if reply is either not sent or not sent yet. */ + /** Response headers yet to be {@linkplain #takeResponseHeaders taken}. */ private List
    responseHeaders; + /** True if response headers have been sent or received. */ + private boolean hasResponseHeaders; + private final FramedDataSource source; final FramedDataSink sink; final StreamTimeout readTimeout = new StreamTimeout(); @@ -106,7 +109,7 @@ public synchronized boolean isOpen() { } if ((source.finished || source.closed) && (sink.finished || sink.closed) - && responseHeaders != null) { + && hasResponseHeaders) { return false; } return true; @@ -127,10 +130,14 @@ public List
    getRequestHeaders() { } /** - * Returns the stream's response headers, blocking if necessary if they have not been received - * yet. + * Removes and returns the stream's received response headers, blocking if necessary until headers + * have been received. If the returned list contains multiple blocks of headers the blocks will be + * delimited by 'null'. */ - public synchronized List
    getResponseHeaders() throws IOException { + public synchronized List
    takeResponseHeaders() throws IOException { + if (!isLocallyInitiated()) { + throw new IllegalStateException("servers cannot read response headers"); + } readTimeout.enter(); try { while (responseHeaders == null && errorCode == null) { @@ -139,7 +146,11 @@ public synchronized List
    getResponseHeaders() throws IOException { } finally { readTimeout.exitAndThrowIfTimedOut(); } - if (responseHeaders != null) return responseHeaders; + List
    result = responseHeaders; + if (result != null) { + responseHeaders = null; + return result; + } throw new StreamResetException(errorCode); } @@ -157,17 +168,14 @@ public synchronized ErrorCode getErrorCode() { * @param out true to create an output stream that we can use to send data to the remote peer. * Corresponds to {@code FLAG_FIN}. */ - public void reply(List
    responseHeaders, boolean out) throws IOException { + public void sendResponseHeaders(List
    responseHeaders, boolean out) throws IOException { assert (!Thread.holdsLock(Http2Stream.this)); + if (responseHeaders == null) { + throw new NullPointerException("responseHeaders == null"); + } boolean outFinished = false; synchronized (this) { - if (responseHeaders == null) { - throw new NullPointerException("responseHeaders == null"); - } - if (this.responseHeaders != null) { - throw new IllegalStateException("reply already sent"); - } - this.responseHeaders = responseHeaders; + this.hasResponseHeaders = true; if (!out) { this.sink.finished = true; outFinished = true; @@ -196,12 +204,12 @@ public Source getSource() { /** * Returns a sink that can be used to write data to the peer. * - * @throws IllegalStateException if this stream was initiated by the peer and a {@link #reply} has - * not yet been sent. + * @throws IllegalStateException if this stream was initiated by the peer and a {@link + * #sendResponseHeaders} has not yet been sent. */ public Sink getSink() { synchronized (this) { - if (responseHeaders == null && !isLocallyInitiated()) { + if (!hasResponseHeaders && !isLocallyInitiated()) { throw new IllegalStateException("reply before requesting the sink"); } } @@ -251,6 +259,7 @@ void receiveHeaders(List
    headers) { assert (!Thread.holdsLock(Http2Stream.this)); boolean open = true; synchronized (this) { + hasResponseHeaders = true; if (responseHeaders == null) { responseHeaders = headers; open = isOpen(); @@ -258,6 +267,7 @@ void receiveHeaders(List
    headers) { } else { List
    newHeaders = new ArrayList<>(); newHeaders.addAll(responseHeaders); + newHeaders.add(null); // Delimit separate blocks of headers with null. newHeaders.addAll(headers); this.responseHeaders = newHeaders; } @@ -320,8 +330,7 @@ private final class FramedDataSource implements Source { this.maxByteCount = maxByteCount; } - @Override public long read(Buffer sink, long byteCount) - throws IOException { + @Override public long read(Buffer sink, long byteCount) throws IOException { if (byteCount < 0) throw new IllegalArgumentException("byteCount < 0: " + byteCount); long read;