Skip to content

Commit

Permalink
Relax the validation of Location header when redirecting (#5477)
Browse files Browse the repository at this point in the history
Reported by @ohadgur at
https://discord.com/channels/1087271586832318494/1209914423494311948

Motivation:

If a client is configured to follow redirects with `followRedirects()`,
the client will validate the value of `Location` header before sending a
follow-up request to the given redirect location. `RedirectingClient`
validates and resolves the target location using `URI.resolve()` which
rejects poorly encoded `Location` header values such as:

- `Location: /foo bar` (space should be percent-encoded)
- `Location: /?${}` (`$`, `{` and `}` should be percent-encoded.)

Such strict validation might not be suitable in real world scenarios and
we could
normalize them just like a client validates and normalizes the initial
request target.

Modifications:

- `RedirectingClient` now uses `RequestTarget.forClient()` to parse and
normalize the target location so it is more tolerant to poorly encoded
`Location` header values.
- `RedirectingClient` now implements its own relative path resolution
logic. See
`RedirectingClient.resolveLocation()` and `resolveRelativeLocation() for
the detail.
- Added `host` and `port` properties to `RequestTarget`.
- Added `host` property to `ClientRequestContext`.
- Moved `DefaultRequestTarget.findAuthority()` to `ArmeriaHttpUtil` to
reuse it in `RedirectingClient`.
- Miscellaneous:
- Fixed a potential bug where `RoutingContext.newPath()` creates a new
`RequestTarget` whose `path` is `null`
- Fixed a bug where `ClientRequestContext.uri()` returns a
double-encoded URI
(Special thanks to @ohadgur for reporting this bug and suggesting a
fix.)
- Improved `RequestTarget.forClient()` to remove the port number in an
absolute
request target when possible, so that `http://a:80` is normalized into
`http://a`.

Result:

- (Bug fix) An Armeria client is now more tolerant to poorly encoded
`Location` header values when following redirects.
- (New feature) You can now get the host and port part separately from a
`RequestTarget`
  using `RequestTarget.host()` and `RequestTarget.port()`.
- (New feature) You can now get the host part from authority of the
request URI using
  `ClientRequestContext.host()`.
- (Improvement) `RequestTarget.forClient()` now removes a redundant port
number from
the specified URI for simpler request target comparison. For example,
`https://foo` and
  `https://foo:443` are considered equal.

---------

Co-authored-by: minux <songmw725@gmail.com>
  • Loading branch information
trustin and minwoox authored Mar 29, 2024
1 parent 8678567 commit db3973d
Show file tree
Hide file tree
Showing 14 changed files with 516 additions and 116 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -293,6 +293,13 @@ ClientRequestContext newDerivedContext(RequestId id, @Nullable HttpRequest req,
@UnstableApi
String authority();

/**
* Returns the host part of {@link #authority()}, without a port number.
*/
@Nullable
@UnstableApi
String host();

/**
* Returns the {@link URI} constructed based on {@link ClientRequestContext#sessionProtocol()},
* {@link ClientRequestContext#authority()}, {@link ClientRequestContext#path()} and
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,11 @@ public String authority() {
return unwrap().authority();
}

@Override
public String host() {
return unwrap().host();
}

@Override
public URI uri() {
return unwrap().uri();
Expand Down
233 changes: 169 additions & 64 deletions core/src/main/java/com/linecorp/armeria/client/RedirectingClient.java
Original file line number Diff line number Diff line change
Expand Up @@ -19,15 +19,16 @@
import static com.linecorp.armeria.internal.client.ClientUtil.executeWithFallback;
import static com.linecorp.armeria.internal.client.RedirectingClientUtil.allowAllDomains;
import static com.linecorp.armeria.internal.client.RedirectingClientUtil.allowSameDomain;
import static com.linecorp.armeria.internal.common.ArmeriaHttpUtil.findAuthority;

import java.net.URI;
import java.net.URISyntaxException;
import java.util.Iterator;
import java.util.Set;
import java.util.concurrent.CompletableFuture;
import java.util.function.BiPredicate;
import java.util.function.Function;

import com.google.common.annotations.VisibleForTesting;
import com.google.common.base.Splitter;
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.LinkedListMultimap;
import com.google.common.collect.Multimap;
Expand All @@ -48,8 +49,9 @@
import com.linecorp.armeria.common.HttpStatus;
import com.linecorp.armeria.common.RequestHeaders;
import com.linecorp.armeria.common.RequestHeadersBuilder;
import com.linecorp.armeria.common.RequestTarget;
import com.linecorp.armeria.common.RequestTargetForm;
import com.linecorp.armeria.common.ResponseHeaders;
import com.linecorp.armeria.common.Scheme;
import com.linecorp.armeria.common.SessionProtocol;
import com.linecorp.armeria.common.annotation.Nullable;
import com.linecorp.armeria.common.logging.RequestLogBuilder;
Expand All @@ -70,6 +72,8 @@ final class RedirectingClient extends SimpleDecoratingHttpClient {
private static final Set<SessionProtocol> httpAndHttps =
Sets.immutableEnumSet(SessionProtocol.HTTP, SessionProtocol.HTTPS);

private static final Splitter pathSplitter = Splitter.on('/');

static Function<? super HttpClient, RedirectingClient> newDecorator(
ClientBuilderParams params, RedirectConfig redirectConfig) {
final boolean undefinedUri = Clients.isUndefinedUri(params.uri());
Expand Down Expand Up @@ -212,42 +216,54 @@ private void execute0(ClientRequestContext ctx, RedirectContext redirectCtx,
}

final RequestHeaders requestHeaders = log.requestHeaders();
final URI redirectUri;
try {
redirectUri = URI.create(requestHeaders.path()).resolve(location);
if (redirectUri.isAbsolute()) {
final SessionProtocol redirectProtocol = Scheme.parse(redirectUri.getScheme())
.sessionProtocol();
if (!allowedProtocols.contains(redirectProtocol)) {
handleException(ctx, derivedCtx, reqDuplicator, responseFuture, response,
UnexpectedProtocolRedirectException.of(
redirectProtocol, allowedProtocols));
return;
}

if (!domainFilter.test(ctx, redirectUri.getHost())) {
handleException(ctx, derivedCtx, reqDuplicator, responseFuture, response,
UnexpectedDomainRedirectException.of(redirectUri.getHost()));
return;
}
}
} catch (Throwable t) {
handleException(ctx, derivedCtx, reqDuplicator, responseFuture, response, t);
// Resolve the actual redirect location.
final RequestTarget nextReqTarget = resolveLocation(ctx, location);
if (nextReqTarget == null) {
handleException(ctx, derivedCtx, reqDuplicator, responseFuture, response,
new IllegalArgumentException("Invalid redirect location: " + location));
return;
}

final HttpRequestDuplicator newReqDuplicator =
newReqDuplicator(reqDuplicator, responseHeaders, requestHeaders, redirectUri);
final String nextScheme = nextReqTarget.scheme();
final String nextAuthority = nextReqTarget.authority();
final String nextHost = nextReqTarget.host();
assert nextReqTarget.form() == RequestTargetForm.ABSOLUTE &&
nextScheme != null && nextAuthority != null && nextHost != null
: "resolveLocation() must return an absolute request target: " + nextReqTarget;

final String redirectFullUri;
try {
redirectFullUri = buildFullUri(ctx, redirectUri, newReqDuplicator.headers());
// Reject if:
// 1) the protocol is not same with the original one; and
// 2) the protocol is not in the allow-list.
final SessionProtocol nextProtocol = SessionProtocol.of(nextScheme);
if (ctx.sessionProtocol() != nextProtocol &&
!allowedProtocols.contains(nextProtocol)) {
handleException(ctx, derivedCtx, reqDuplicator, responseFuture, response,
UnexpectedProtocolRedirectException.of(
nextProtocol, allowedProtocols));
return;
}

// Reject if:
// 1) the host is not same with the original one; and
// 2) the host does not pass the domain filter.
if (!nextHost.equals(ctx.host()) &&
!domainFilter.test(ctx, nextHost)) {
handleException(ctx, derivedCtx, reqDuplicator, responseFuture, response,
UnexpectedDomainRedirectException.of(nextHost));
return;
}
} catch (Throwable t) {
handleException(ctx, derivedCtx, reqDuplicator, responseFuture, response, t);
return;
}

if (isCyclicRedirects(redirectCtx, redirectFullUri, newReqDuplicator.headers())) {
final HttpRequestDuplicator newReqDuplicator =
newReqDuplicator(reqDuplicator, responseHeaders, requestHeaders,
nextReqTarget.toString(), nextAuthority);

if (isCyclicRedirects(redirectCtx, nextReqTarget.toString(), newReqDuplicator.headers().method())) {
handleException(ctx, derivedCtx, reqDuplicator, responseFuture, response,
CyclicRedirectsException.of(redirectCtx.originalUri(),
redirectCtx.redirectUris().values()));
Expand All @@ -274,17 +290,127 @@ private void execute0(ClientRequestContext ctx, RedirectContext redirectCtx,
});
}

@Nullable
@VisibleForTesting
static RequestTarget resolveLocation(ClientRequestContext ctx, String location) {
final long length = location.length();
assert length > 0;

final String resolvedUri;
if (location.charAt(0) == '/') {
if (length > 1 && location.charAt(1) == '/') {
// No scheme, e.g. //foo.com/bar
resolvedUri = ctx.sessionProtocol().uriText() + ':' + location;
} else {
// No scheme, no authority, e.g. /bar
resolvedUri = ctx.sessionProtocol().uriText() + "://" + ctx.authority() + location;
}
} else {
final int authorityIdx = findAuthority(location);
if (authorityIdx < 0) {
// A relative path, e.g. ./bar
resolvedUri = resolveRelativeLocation(ctx, location);
if (resolvedUri == null) {
return null;
}
} else {
// A full absolute URI, e.g. http://foo.com/bar
// Note that we should normalize an explicit scheme such as `h1c` into `http` or `https`,
// because otherwise a potentially malicious peer can force us to use inefficient protocols
// like HTTP/1.
final SessionProtocol proto = SessionProtocol.find(location.substring(0, authorityIdx - 3));
if (proto != null) {
switch (proto) {
case HTTP:
case HTTPS:
resolvedUri = location;
break;
default:
if (proto.isHttp()) {
resolvedUri = "http://" + location.substring(authorityIdx);
} else if (proto.isHttps()) {
resolvedUri = "https://" + location.substring(authorityIdx);
} else {
return null;
}
}
} else {
// Unknown scheme.
return null;
}
}
}

return RequestTarget.forClient(resolvedUri);
}

@Nullable
private static String resolveRelativeLocation(ClientRequestContext ctx, String location) {
final String originalPath = ctx.path();

// Find the base path, e.g.
// - /foo -> /
// - /foo/ -> /foo/
// - /foo/bar -> /foo/
final int lastSlashIdx = originalPath.lastIndexOf('/');
assert lastSlashIdx >= 0 : "originalPath doesn't contain a slash: " + originalPath;

// Generate the full path.
final String fullPath = originalPath.substring(0, lastSlashIdx + 1) + location;
final Iterator<String> it = pathSplitter.split(fullPath).iterator();
// Splitter will always emit an empty string as the first component, so we skip it.
assert it.hasNext() && it.next().isEmpty() : fullPath;

// Resolve `.` and `..` from the full path.
try (TemporaryThreadLocals tmp = TemporaryThreadLocals.acquire()) {
final StringBuilder buf = tmp.stringBuilder();
buf.append(ctx.sessionProtocol().uriText()).append("://").append(ctx.authority());
final int authorityEndIdx = buf.length();
while (it.hasNext()) {
final String component = it.next();
switch (component) {
case ".":
if (!it.hasNext()) {
// Append '/' only when the '.' is the last component, e.g. /foo/. -> /foo/
buf.append('/');
}
break;
case "..":
final int idx = buf.lastIndexOf("/");
if (idx < authorityEndIdx) {
// Too few parents
return null;
}
if (it.hasNext()) {
// Don't keep the '/' because the next component will add it anyway,
// e.g. /foo/../bar -> /bar
buf.delete(idx, buf.length());
} else {
// Keep the last '/' if the '..' is the last component,
// e.g. /foo/bar/.. -> /foo/
buf.delete(idx + 1, buf.length());
}
break;
default:
buf.append('/').append(component);
break;
}
}

return buf.toString();
}
}

private static HttpRequestDuplicator newReqDuplicator(HttpRequestDuplicator reqDuplicator,
ResponseHeaders responseHeaders,
RequestHeaders requestHeaders, URI newUri) {
RequestHeaders requestHeaders,
String nextUri,
String nextAuthority) {

final RequestHeadersBuilder builder = requestHeaders.toBuilder();
builder.path(newUri.toString());
final String newAuthority = newUri.getAuthority();
if (newAuthority != null) {
// Update the old authority with the new one because the request is redirected to a different
// domain.
builder.authority(newAuthority);
}
builder.path(nextUri);
builder.authority(nextAuthority);

final HttpMethod method = requestHeaders.method();
if (responseHeaders.status() == HttpStatus.SEE_OTHER &&
!(method == HttpMethod.GET || method == HttpMethod.HEAD)) {
Expand Down Expand Up @@ -343,36 +469,15 @@ private static void abortResponse(HttpResponse originalRes, ClientRequestContext
}
}

private static String buildFullUri(ClientRequestContext ctx, URI redirectUri, RequestHeaders newHeaders)
throws URISyntaxException {
// Build the full uri so we don't consider the situation, which session protocol or port is changed,
// as a cyclic redirects.
if (redirectUri.isAbsolute()) {
if (redirectUri.getPort() > 0) {
return redirectUri.toString();
}
final int port;
if (redirectUri.getScheme().startsWith("https")) {
port = SessionProtocol.HTTPS.defaultPort();
} else {
port = SessionProtocol.HTTP.defaultPort();
}
return new URI(redirectUri.getScheme(), redirectUri.getRawUserInfo(), redirectUri.getHost(), port,
redirectUri.getRawPath(), redirectUri.getRawQuery(), redirectUri.getRawFragment())
.toString();
}
return buildUri(ctx, newHeaders);
}

private static boolean isCyclicRedirects(RedirectContext redirectCtx, String redirectUri,
RequestHeaders newHeaders) {
final boolean added = redirectCtx.addRedirectUri(newHeaders.method(), redirectUri);
private static boolean isCyclicRedirects(RedirectContext redirectCtx,
String redirectUri, HttpMethod method) {
final boolean added = redirectCtx.addRedirectUri(method, redirectUri);
if (!added) {
return true;
}

return redirectCtx.originalUri().equals(redirectUri) &&
redirectCtx.request().method() == newHeaders.method();
redirectCtx.request().method() == method;
}

private static String buildUri(ClientRequestContext ctx, RequestHeaders headers) {
Expand All @@ -391,15 +496,15 @@ private static String buildUri(ClientRequestContext ctx, RequestHeaders headers)
if (authority == null) {
authority = endpoint.authority();
}
setAuthorityAndPort(ctx, endpoint, sb, authority);
appendAuthority(ctx, endpoint, sb, authority);
sb.append(headers.path());
originalUri = sb.toString();
}
return originalUri;
}

private static void setAuthorityAndPort(ClientRequestContext ctx, Endpoint endpoint, StringBuilder sb,
String authority) {
private static void appendAuthority(ClientRequestContext ctx, Endpoint endpoint, StringBuilder sb,
String authority) {
// Add port number as well so that we don't raise a CyclicRedirectsException when the port is
// different.

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -154,7 +154,7 @@ protected AbstractRequestContextBuilder(boolean server, RpcRequest rpcReq, URI u
this.reqTarget = reqTarget;
} else {
reqTarget = DefaultRequestTarget.createWithoutValidation(
RequestTargetForm.ORIGIN, null, null,
RequestTargetForm.ORIGIN, null, null, null, -1,
uri.getRawPath(), uri.getRawPath(), uri.getRawQuery(), uri.getRawFragment());
}
}
Expand Down
22 changes: 11 additions & 11 deletions core/src/main/java/com/linecorp/armeria/common/HttpHeadersBase.java
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@

import com.linecorp.armeria.common.annotation.Nullable;
import com.linecorp.armeria.internal.common.util.StringUtil;
import com.linecorp.armeria.internal.common.util.TemporaryThreadLocals;

import io.netty.util.AsciiString;

Expand Down Expand Up @@ -214,18 +215,17 @@ URI uri() {
checkState(scheme != null, ":scheme header does not exist.");
final String authority = authority();

final StringBuilder sb = new StringBuilder(
scheme.length() + 1 +
(authority != null ? (authority.length() + 2) : 0) +
path.length());
sb.append(scheme);
sb.append(':');
if (authority != null) {
sb.append("//");
sb.append(authority);
try (TemporaryThreadLocals tmp = TemporaryThreadLocals.acquire()) {
final StringBuilder sb = tmp.stringBuilder();
sb.append(scheme);
sb.append(':');
if (authority != null) {
sb.append("//");
sb.append(authority);
}
sb.append(path);
uri = sb.toString();
}
sb.append(path);
uri = sb.toString();
}

try {
Expand Down
Loading

0 comments on commit db3973d

Please sign in to comment.