Skip to content

Commit

Permalink
Support 'Expect: 100-continue' as a request header.
Browse files Browse the repository at this point in the history
This will stall reading the response body until the server returns
an intermediate '100 Continue' response.

Closes: #675
  • Loading branch information
squarejesse committed Dec 27, 2016
1 parent dc5cbfd commit a589b81
Show file tree
Hide file tree
Showing 14 changed files with 275 additions and 83 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,7 @@ private void send404(Http2Stream stream, String path) throws IOException {
new Header(":version", "HTTP/1.1"),
new Header("content-type", "text/plain")
);
stream.reply(responseHeaders, true);
stream.sendResponseHeaders(responseHeaders, true);
BufferedSink out = Okio.buffer(stream.getSink());
out.writeUtf8("Not found: " + path);
out.close();
Expand All @@ -138,7 +138,7 @@ private void serveDirectory(Http2Stream stream, File[] files) throws IOException
new Header(":version", "HTTP/1.1"),
new Header("content-type", "text/html; charset=UTF-8")
);
stream.reply(responseHeaders, true);
stream.sendResponseHeaders(responseHeaders, true);
BufferedSink out = Okio.buffer(stream.getSink());
for (File file : files) {
String target = file.isDirectory() ? (file.getName() + "/") : file.getName();
Expand All @@ -153,7 +153,7 @@ private void serveFile(Http2Stream stream, File file) throws IOException {
new Header(":version", "HTTP/1.1"),
new Header("content-type", contentType(file))
);
stream.reply(responseHeaders, true);
stream.sendResponseHeaders(responseHeaders, true);
Source source = Okio.source(file);
try {
BufferedSink out = Okio.buffer(stream.getSink());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,7 @@
import static okhttp3.mockwebserver.SocketPolicy.DISCONNECT_AT_START;
import static okhttp3.mockwebserver.SocketPolicy.DISCONNECT_DURING_REQUEST_BODY;
import static okhttp3.mockwebserver.SocketPolicy.DISCONNECT_DURING_RESPONSE_BODY;
import static okhttp3.mockwebserver.SocketPolicy.EXPECT_CONTINUE;
import static okhttp3.mockwebserver.SocketPolicy.FAIL_HANDSHAKE;
import static okhttp3.mockwebserver.SocketPolicy.NO_RESPONSE;
import static okhttp3.mockwebserver.SocketPolicy.RESET_STREAM_AT_START;
Expand Down Expand Up @@ -590,7 +591,7 @@ private RecordedRequest readRequest(Socket socket, BufferedSource source, Buffer
Headers.Builder headers = new Headers.Builder();
long contentLength = -1;
boolean chunked = false;
boolean expectContinue = false;
boolean readBody = true;
String header;
while ((header = source.readUtf8LineStrict()).length() != 0) {
Internal.instance.addLenient(headers, header);
Expand All @@ -603,23 +604,26 @@ private RecordedRequest readRequest(Socket socket, BufferedSource source, Buffer
chunked = true;
}
if (lowercaseHeader.startsWith("expect:")
&& lowercaseHeader.substring(7).trim().equals("100-continue")) {
expectContinue = true;
&& lowercaseHeader.substring(7).trim().equalsIgnoreCase("100-continue")) {
readBody = false;
}
}

if (expectContinue) {
if (!readBody && dispatcher.peek().getSocketPolicy() == EXPECT_CONTINUE) {
sink.writeUtf8("HTTP/1.1 100 Continue\r\n");
sink.writeUtf8("Content-Length: 0\r\n");
sink.writeUtf8("\r\n");
sink.flush();
readBody = true;
}

boolean hasBody = false;
TruncatingBuffer requestBody = new TruncatingBuffer(bodyLimit);
List<Integer> chunkSizes = new ArrayList<>();
MockResponse policy = dispatcher.peek();
if (contentLength != -1) {
if (!readBody) {
// Don't read the body unless we've invited the client to send it.
} else if (contentLength != -1) {
hasBody = contentLength > 0;
throttledTransfer(policy, socket, source, Okio.buffer(requestBody), contentLength, true);
} else if (chunked) {
Expand Down Expand Up @@ -881,6 +885,7 @@ private RecordedRequest readRequest(Http2Stream stream) throws IOException {
Headers.Builder httpHeaders = new Headers.Builder();
String method = "<:method omitted>";
String path = "<:path omitted>";
boolean readBody = true;
for (int i = 0, size = streamHeaders.size(); i < size; i++) {
ByteString name = streamHeaders.get(i).name;
String value = streamHeaders.get(i).value.utf8();
Expand All @@ -893,11 +898,23 @@ private RecordedRequest readRequest(Http2Stream stream) throws IOException {
} else {
throw new IllegalStateException();
}
if (name.utf8().equals("expect") && value.equalsIgnoreCase("100-continue")) {
// Don't read the body unless we've invited the client to send it.
readBody = false;
}
}

if (!readBody && dispatcher.peek().getSocketPolicy() == EXPECT_CONTINUE) {
stream.sendResponseHeaders(Collections.singletonList(
new Header(Header.RESPONSE_STATUS, ByteString.encodeUtf8("100 Continue"))), true);
stream.getConnection().flush();
readBody = true;
}

Buffer body = new Buffer();
body.writeAll(stream.getSource());
body.close();
if (readBody) {
body.writeAll(stream.getSource());
}

String requestLine = method + ' ' + path + " HTTP/1.1";
List<Integer> chunkSizes = Collections.emptyList(); // No chunked encoding for HTTP/2.
Expand Down Expand Up @@ -928,7 +945,7 @@ private void writeResponse(Http2Stream stream, MockResponse response) throws IOE

Buffer body = response.getBody();
boolean closeStreamAfterHeaders = body != null || !response.getPushPromises().isEmpty();
stream.reply(http2Headers, closeStreamAfterHeaders);
stream.sendResponseHeaders(http2Headers, closeStreamAfterHeaders);
pushPromises(stream, response.getPushPromises());
if (body != null) {
BufferedSink sink = Okio.buffer(stream.getSink());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -97,5 +97,12 @@ public enum SocketPolicy {
* Fail HTTP/2 requests without processing them by sending an {@linkplain
* MockResponse#getHttp2ErrorCode() HTTP/2 error code}.
*/
RESET_STREAM_AT_START
RESET_STREAM_AT_START,

/**
* Transmit a {@code HTTP/1.1 100 Continue} response before reading the HTTP request body.
* Typically this response is sent when a client makes a request with the header {@code
* Expect: 100-continue}.
*/
EXPECT_CONTINUE
}
85 changes: 84 additions & 1 deletion okhttp-tests/src/test/java/okhttp3/CallTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -2185,7 +2185,8 @@ private InetSocketAddress startNullServer() throws IOException {
}

@Test public void expect100ContinueNonEmptyRequestBody() throws Exception {
server.enqueue(new MockResponse());
server.enqueue(new MockResponse()
.setSocketPolicy(SocketPolicy.EXPECT_CONTINUE));

Request request = new Request.Builder()
.url(server.url("/"))
Expand Down Expand Up @@ -2214,6 +2215,88 @@ private InetSocketAddress startNullServer() throws IOException {
.assertSuccessful();
}

@Test public void expect100ContinueEmptyRequestBody_HTTP2() throws Exception {
enableProtocol(Protocol.HTTP_2);
expect100ContinueEmptyRequestBody();
}

@Test public void expect100ContinueTimesOutWithoutContinue() throws Exception {
server.enqueue(new MockResponse()
.setSocketPolicy(SocketPolicy.NO_RESPONSE));

client = client.newBuilder()
.readTimeout(500, TimeUnit.MILLISECONDS)
.build();

Request request = new Request.Builder()
.url(server.url("/"))
.header("Expect", "100-continue")
.post(RequestBody.create(MediaType.parse("text/plain"), "abc"))
.build();

Call call = client.newCall(request);
try {
call.execute();
fail();
} catch (SocketTimeoutException expected) {
}

RecordedRequest recordedRequest = server.takeRequest();
assertEquals("", recordedRequest.getBody().readUtf8());
}

@Test public void expect100ContinueTimesOutWithoutContinue_HTTP2() throws Exception {
enableProtocol(Protocol.HTTP_2);
expect100ContinueTimesOutWithoutContinue();
}

@Test public void serverRespondsWithUnsolicited100Continue() throws Exception {
server.enqueue(new MockResponse()
.setStatus("HTTP/1.1 100 Continue"));

Request request = new Request.Builder()
.url(server.url("/"))
.post(RequestBody.create(MediaType.parse("text/plain"), "abc"))
.build();

Call call = client.newCall(request);
Response response = call.execute();
assertEquals(100, response.code());
assertEquals("Continue", response.message());
assertEquals("", response.body().string());

RecordedRequest recordedRequest = server.takeRequest();
assertEquals("abc", recordedRequest.getBody().readUtf8());
}

@Test public void serverRespondsWithUnsolicited100Continue_HTTP2() throws Exception {
enableProtocol(Protocol.HTTP_2);
serverRespondsWithUnsolicited100Continue();
}

@Test public void successfulExpectContinueAndConnectionReuse() throws Exception {
server.enqueue(new MockResponse()
.setSocketPolicy(SocketPolicy.EXPECT_CONTINUE));
server.enqueue(new MockResponse());

executeSynchronously("/", "Expect", "100-continue");
executeSynchronously("/");

assertEquals(0, server.takeRequest().getSequenceNumber());
assertEquals(1, server.takeRequest().getSequenceNumber());
}

@Test public void unsuccessfulExpectContinueAndConnectionReuse() throws Exception {
server.enqueue(new MockResponse());
server.enqueue(new MockResponse());

executeSynchronously("/", "Expect", "100-continue");
executeSynchronously("/");

assertEquals(0, server.takeRequest().getSequenceNumber());
assertEquals(1, server.takeRequest().getSequenceNumber());
}

/** We forbid non-ASCII characters in outgoing request headers, but accept UTF-8. */
@Test public void responseHeaderParsingIsLenient() throws Exception {
Headers headers = new Headers.Builder()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -291,7 +291,7 @@ public final class Http2ConnectionTest {
connection.okHttpSettings.set(INITIAL_WINDOW_SIZE, windowSize);
Http2Stream stream = connection.newStream(headerEntries("b", "banana"), false);
assertEquals(0, stream.unacknowledgedBytesRead);
assertEquals(headerEntries("a", "android"), stream.getResponseHeaders());
assertEquals(headerEntries("a", "android"), stream.takeResponseHeaders());
Source in = stream.getSource();
Buffer buffer = new Buffer();
buffer.writeAll(in);
Expand Down Expand Up @@ -513,7 +513,7 @@ public final class Http2ConnectionTest {
BufferedSink out = Okio.buffer(stream.getSink());
out.writeUtf8("c3po");
out.close();
assertEquals(headerEntries("a", "android"), stream.getResponseHeaders());
assertEquals(headerEntries("a", "android"), stream.takeResponseHeaders());
assertStreamData("robot", stream.getSource());
connection.ping().roundTripTime();
assertEquals(0, connection.openStreamCount());
Expand Down Expand Up @@ -914,7 +914,7 @@ public final class Http2ConnectionTest {
// play it back
Http2Connection connection = connect(peer);
Http2Stream stream = connection.newStream(headerEntries("c", "cola"), false);
assertEquals(headerEntries("a", "android"), stream.getResponseHeaders());
assertEquals(headerEntries("a", "android"), stream.takeResponseHeaders());
connection.ping().roundTripTime(); // Ensure that the 2nd SYN REPLY has been received.

// verify the peer received what was expected
Expand All @@ -940,7 +940,7 @@ public final class Http2ConnectionTest {
// play it back
Http2Connection connection = connect(peer);
Http2Stream stream = connection.newStream(headerEntries("b", "banana"), false);
assertEquals(headerEntries("a", "android"), stream.getResponseHeaders());
assertEquals(headerEntries("a", "android"), stream.takeResponseHeaders());
assertStreamData("robot", stream.getSource());

// verify the peer received what was expected
Expand Down Expand Up @@ -973,7 +973,7 @@ public final class Http2ConnectionTest {
// play it back
Http2Connection connection = connect(peer);
Http2Stream stream = connection.newStream(headerEntries("a", "android"), false);
assertEquals(headerEntries("b", "banana"), stream.getResponseHeaders());
assertEquals(headerEntries("b", "banana"), stream.takeResponseHeaders());

// verify the peer received what was expected
InFrame synStream = peer.takeFrame();
Expand All @@ -997,7 +997,7 @@ public final class Http2ConnectionTest {
Http2Connection connection = connect(peer);
Http2Stream stream = connection.newStream(headerEntries("a", "android"), false);
try {
stream.getResponseHeaders();
stream.takeResponseHeaders();
fail();
} catch (IOException expected) {
assertEquals("stream was reset: REFUSED_STREAM", expected.getMessage());
Expand Down Expand Up @@ -1189,7 +1189,7 @@ public final class Http2ConnectionTest {
stream.readTimeout().timeout(500, TimeUnit.MILLISECONDS);
long startNanos = System.nanoTime();
try {
stream.getResponseHeaders();
stream.takeResponseHeaders();
fail();
} catch (InterruptedIOException expected) {
}
Expand Down Expand Up @@ -1362,7 +1362,8 @@ public final class Http2ConnectionTest {
Http2Connection connection = connect(peer);
Http2Stream stream = connection.newStream(headerEntries("b", "banana"), true);
connection.ping().roundTripTime(); // Ensure that the HEADERS has been received.
assertEquals(headerEntries("a", "android", "c", "c3po"), stream.getResponseHeaders());
assertEquals(Arrays.asList(new Header("a", "android"), null, new Header("c", "c3po")),
stream.takeResponseHeaders());

// verify the peer received what was expected
InFrame synStream = peer.takeFrame();
Expand All @@ -1371,6 +1372,30 @@ public final class Http2ConnectionTest {
assertEquals(Http2.TYPE_PING, ping.type);
}

@Test public void readMultipleSetsOfResponseHeaders() throws Exception {
// write the mocking script
peer.sendFrame().settings(new Settings());
peer.acceptFrame(); // ACK
peer.acceptFrame(); // SYN_STREAM
peer.sendFrame().synReply(false, 3, headerEntries("a", "android"));
peer.acceptFrame(); // PING
peer.sendFrame().ping(true, 1, 0); // PING
peer.sendFrame().synReply(true, 3, headerEntries("c", "cola"));
peer.play();

// play it back
Http2Connection connection = connect(peer);
Http2Stream stream = connection.newStream(headerEntries("b", "banana"), true);
stream.getConnection().flush();
assertEquals(headerEntries("a", "android"), stream.takeResponseHeaders());
connection.ping().roundTripTime();
assertEquals(headerEntries("c", "cola"), stream.takeResponseHeaders());

// verify the peer received what was expected
assertEquals(Http2.TYPE_HEADERS, peer.takeFrame().type);
assertEquals(Http2.TYPE_PING, peer.takeFrame().type);
}

@Test public void readSendsWindowUpdate() throws Exception {
int windowSize = 100;
int windowUpdateThreshold = 50;
Expand All @@ -1396,7 +1421,7 @@ public final class Http2ConnectionTest {
connection.okHttpSettings.set(INITIAL_WINDOW_SIZE, windowSize);
Http2Stream stream = connection.newStream(headerEntries("b", "banana"), false);
assertEquals(0, stream.unacknowledgedBytesRead);
assertEquals(headerEntries("a", "android"), stream.getResponseHeaders());
assertEquals(headerEntries("a", "android"), stream.takeResponseHeaders());
Source in = stream.getSource();
Buffer buffer = new Buffer();
buffer.writeAll(in);
Expand Down Expand Up @@ -1474,7 +1499,7 @@ public final class Http2ConnectionTest {
// play it back
Http2Connection connection = connect(peer);
Http2Stream stream = connection.newStream(headerEntries("b", "banana"), false);
assertEquals(headerEntries("a", "android"), stream.getResponseHeaders());
assertEquals(headerEntries("a", "android"), stream.takeResponseHeaders());
Source in = stream.getSource();
try {
Okio.buffer(in).readByteString(101);
Expand Down Expand Up @@ -1540,7 +1565,7 @@ public final class Http2ConnectionTest {

Http2Stream stream = connection.newStream(headerEntries("b", "banana"), false);
try {
stream.getResponseHeaders();
stream.takeResponseHeaders();
fail();
} catch (IOException expected) {
assertEquals("stream was reset: PROTOCOL_ERROR", expected.getMessage());
Expand Down
4 changes: 4 additions & 0 deletions okhttp/src/main/java/okhttp3/OkHttpClient.java
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,10 @@ public class OkHttpClient implements Cloneable, Call.Factory, WebSocket.Factory
return connectionPool.routeDatabase;
}

@Override public int code(Response.Builder responseBuilder) {
return responseBuilder.code;
}

@Override
public void apply(ConnectionSpec tlsConfiguration, SSLSocket sslSocket, boolean isFallback) {
tlsConfiguration.apply(sslSocket, isFallback);
Expand Down
3 changes: 3 additions & 0 deletions okhttp/src/main/java/okhttp3/internal/Internal.java
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
import okhttp3.HttpUrl;
import okhttp3.OkHttpClient;
import okhttp3.Request;
import okhttp3.Response;
import okhttp3.internal.cache.InternalCache;
import okhttp3.internal.connection.RealConnection;
import okhttp3.internal.connection.RouteDatabase;
Expand Down Expand Up @@ -59,6 +60,8 @@ public abstract RealConnection get(

public abstract RouteDatabase routeDatabase(ConnectionPool connectionPool);

public abstract int code(Response.Builder responseBuilder);

public abstract void apply(ConnectionSpec tlsConfiguration, SSLSocket sslSocket,
boolean isFallback);

Expand Down
Loading

0 comments on commit a589b81

Please sign in to comment.