diff --git a/core/src/main/java/com/linecorp/armeria/client/ClientRequestContext.java b/core/src/main/java/com/linecorp/armeria/client/ClientRequestContext.java index 80631674a73..25ecab100a1 100644 --- a/core/src/main/java/com/linecorp/armeria/client/ClientRequestContext.java +++ b/core/src/main/java/com/linecorp/armeria/client/ClientRequestContext.java @@ -293,6 +293,13 @@ ClientRequestContext newDerivedContext(RequestId id, @Nullable HttpRequest req, @UnstableApi String authority(); + /** + * Returns the host part of {@link #authority()}, without a port number. + */ + @Nullable + @UnstableApi + String host(); + /** * Returns the {@link URI} constructed based on {@link ClientRequestContext#sessionProtocol()}, * {@link ClientRequestContext#authority()}, {@link ClientRequestContext#path()} and diff --git a/core/src/main/java/com/linecorp/armeria/client/ClientRequestContextWrapper.java b/core/src/main/java/com/linecorp/armeria/client/ClientRequestContextWrapper.java index 9497e005357..91f0c564648 100644 --- a/core/src/main/java/com/linecorp/armeria/client/ClientRequestContextWrapper.java +++ b/core/src/main/java/com/linecorp/armeria/client/ClientRequestContextWrapper.java @@ -72,6 +72,11 @@ public String authority() { return unwrap().authority(); } + @Override + public String host() { + return unwrap().host(); + } + @Override public URI uri() { return unwrap().uri(); diff --git a/core/src/main/java/com/linecorp/armeria/client/RedirectingClient.java b/core/src/main/java/com/linecorp/armeria/client/RedirectingClient.java index 72c95c7de73..de2562ff5fd 100644 --- a/core/src/main/java/com/linecorp/armeria/client/RedirectingClient.java +++ b/core/src/main/java/com/linecorp/armeria/client/RedirectingClient.java @@ -19,15 +19,16 @@ import static com.linecorp.armeria.internal.client.ClientUtil.executeWithFallback; import static com.linecorp.armeria.internal.client.RedirectingClientUtil.allowAllDomains; import static com.linecorp.armeria.internal.client.RedirectingClientUtil.allowSameDomain; +import static com.linecorp.armeria.internal.common.ArmeriaHttpUtil.findAuthority; -import java.net.URI; -import java.net.URISyntaxException; +import java.util.Iterator; import java.util.Set; import java.util.concurrent.CompletableFuture; import java.util.function.BiPredicate; import java.util.function.Function; import com.google.common.annotations.VisibleForTesting; +import com.google.common.base.Splitter; import com.google.common.collect.ImmutableSet; import com.google.common.collect.LinkedListMultimap; import com.google.common.collect.Multimap; @@ -48,8 +49,9 @@ import com.linecorp.armeria.common.HttpStatus; import com.linecorp.armeria.common.RequestHeaders; import com.linecorp.armeria.common.RequestHeadersBuilder; +import com.linecorp.armeria.common.RequestTarget; +import com.linecorp.armeria.common.RequestTargetForm; import com.linecorp.armeria.common.ResponseHeaders; -import com.linecorp.armeria.common.Scheme; import com.linecorp.armeria.common.SessionProtocol; import com.linecorp.armeria.common.annotation.Nullable; import com.linecorp.armeria.common.logging.RequestLogBuilder; @@ -70,6 +72,8 @@ final class RedirectingClient extends SimpleDecoratingHttpClient { private static final Set<SessionProtocol> httpAndHttps = Sets.immutableEnumSet(SessionProtocol.HTTP, SessionProtocol.HTTPS); + private static final Splitter pathSplitter = Splitter.on('/'); + static Function<? super HttpClient, RedirectingClient> newDecorator( ClientBuilderParams params, RedirectConfig redirectConfig) { final boolean undefinedUri = Clients.isUndefinedUri(params.uri()); @@ -212,42 +216,54 @@ private void execute0(ClientRequestContext ctx, RedirectContext redirectCtx, } final RequestHeaders requestHeaders = log.requestHeaders(); - final URI redirectUri; - try { - redirectUri = URI.create(requestHeaders.path()).resolve(location); - if (redirectUri.isAbsolute()) { - final SessionProtocol redirectProtocol = Scheme.parse(redirectUri.getScheme()) - .sessionProtocol(); - if (!allowedProtocols.contains(redirectProtocol)) { - handleException(ctx, derivedCtx, reqDuplicator, responseFuture, response, - UnexpectedProtocolRedirectException.of( - redirectProtocol, allowedProtocols)); - return; - } - if (!domainFilter.test(ctx, redirectUri.getHost())) { - handleException(ctx, derivedCtx, reqDuplicator, responseFuture, response, - UnexpectedDomainRedirectException.of(redirectUri.getHost())); - return; - } - } - } catch (Throwable t) { - handleException(ctx, derivedCtx, reqDuplicator, responseFuture, response, t); + // Resolve the actual redirect location. + final RequestTarget nextReqTarget = resolveLocation(ctx, location); + if (nextReqTarget == null) { + handleException(ctx, derivedCtx, reqDuplicator, responseFuture, response, + new IllegalArgumentException("Invalid redirect location: " + location)); return; } - final HttpRequestDuplicator newReqDuplicator = - newReqDuplicator(reqDuplicator, responseHeaders, requestHeaders, redirectUri); + final String nextScheme = nextReqTarget.scheme(); + final String nextAuthority = nextReqTarget.authority(); + final String nextHost = nextReqTarget.host(); + assert nextReqTarget.form() == RequestTargetForm.ABSOLUTE && + nextScheme != null && nextAuthority != null && nextHost != null + : "resolveLocation() must return an absolute request target: " + nextReqTarget; - final String redirectFullUri; try { - redirectFullUri = buildFullUri(ctx, redirectUri, newReqDuplicator.headers()); + // Reject if: + // 1) the protocol is not same with the original one; and + // 2) the protocol is not in the allow-list. + final SessionProtocol nextProtocol = SessionProtocol.of(nextScheme); + if (ctx.sessionProtocol() != nextProtocol && + !allowedProtocols.contains(nextProtocol)) { + handleException(ctx, derivedCtx, reqDuplicator, responseFuture, response, + UnexpectedProtocolRedirectException.of( + nextProtocol, allowedProtocols)); + return; + } + + // Reject if: + // 1) the host is not same with the original one; and + // 2) the host does not pass the domain filter. + if (!nextHost.equals(ctx.host()) && + !domainFilter.test(ctx, nextHost)) { + handleException(ctx, derivedCtx, reqDuplicator, responseFuture, response, + UnexpectedDomainRedirectException.of(nextHost)); + return; + } } catch (Throwable t) { handleException(ctx, derivedCtx, reqDuplicator, responseFuture, response, t); return; } - if (isCyclicRedirects(redirectCtx, redirectFullUri, newReqDuplicator.headers())) { + final HttpRequestDuplicator newReqDuplicator = + newReqDuplicator(reqDuplicator, responseHeaders, requestHeaders, + nextReqTarget.toString(), nextAuthority); + + if (isCyclicRedirects(redirectCtx, nextReqTarget.toString(), newReqDuplicator.headers().method())) { handleException(ctx, derivedCtx, reqDuplicator, responseFuture, response, CyclicRedirectsException.of(redirectCtx.originalUri(), redirectCtx.redirectUris().values())); @@ -274,17 +290,127 @@ private void execute0(ClientRequestContext ctx, RedirectContext redirectCtx, }); } + @Nullable + @VisibleForTesting + static RequestTarget resolveLocation(ClientRequestContext ctx, String location) { + final long length = location.length(); + assert length > 0; + + final String resolvedUri; + if (location.charAt(0) == '/') { + if (length > 1 && location.charAt(1) == '/') { + // No scheme, e.g. //foo.com/bar + resolvedUri = ctx.sessionProtocol().uriText() + ':' + location; + } else { + // No scheme, no authority, e.g. /bar + resolvedUri = ctx.sessionProtocol().uriText() + "://" + ctx.authority() + location; + } + } else { + final int authorityIdx = findAuthority(location); + if (authorityIdx < 0) { + // A relative path, e.g. ./bar + resolvedUri = resolveRelativeLocation(ctx, location); + if (resolvedUri == null) { + return null; + } + } else { + // A full absolute URI, e.g. http://foo.com/bar + // Note that we should normalize an explicit scheme such as `h1c` into `http` or `https`, + // because otherwise a potentially malicious peer can force us to use inefficient protocols + // like HTTP/1. + final SessionProtocol proto = SessionProtocol.find(location.substring(0, authorityIdx - 3)); + if (proto != null) { + switch (proto) { + case HTTP: + case HTTPS: + resolvedUri = location; + break; + default: + if (proto.isHttp()) { + resolvedUri = "http://" + location.substring(authorityIdx); + } else if (proto.isHttps()) { + resolvedUri = "https://" + location.substring(authorityIdx); + } else { + return null; + } + } + } else { + // Unknown scheme. + return null; + } + } + } + + return RequestTarget.forClient(resolvedUri); + } + + @Nullable + private static String resolveRelativeLocation(ClientRequestContext ctx, String location) { + final String originalPath = ctx.path(); + + // Find the base path, e.g. + // - /foo -> / + // - /foo/ -> /foo/ + // - /foo/bar -> /foo/ + final int lastSlashIdx = originalPath.lastIndexOf('/'); + assert lastSlashIdx >= 0 : "originalPath doesn't contain a slash: " + originalPath; + + // Generate the full path. + final String fullPath = originalPath.substring(0, lastSlashIdx + 1) + location; + final Iterator<String> it = pathSplitter.split(fullPath).iterator(); + // Splitter will always emit an empty string as the first component, so we skip it. + assert it.hasNext() && it.next().isEmpty() : fullPath; + + // Resolve `.` and `..` from the full path. + try (TemporaryThreadLocals tmp = TemporaryThreadLocals.acquire()) { + final StringBuilder buf = tmp.stringBuilder(); + buf.append(ctx.sessionProtocol().uriText()).append("://").append(ctx.authority()); + final int authorityEndIdx = buf.length(); + while (it.hasNext()) { + final String component = it.next(); + switch (component) { + case ".": + if (!it.hasNext()) { + // Append '/' only when the '.' is the last component, e.g. /foo/. -> /foo/ + buf.append('/'); + } + break; + case "..": + final int idx = buf.lastIndexOf("/"); + if (idx < authorityEndIdx) { + // Too few parents + return null; + } + if (it.hasNext()) { + // Don't keep the '/' because the next component will add it anyway, + // e.g. /foo/../bar -> /bar + buf.delete(idx, buf.length()); + } else { + // Keep the last '/' if the '..' is the last component, + // e.g. /foo/bar/.. -> /foo/ + buf.delete(idx + 1, buf.length()); + } + break; + default: + buf.append('/').append(component); + break; + } + } + + return buf.toString(); + } + } + private static HttpRequestDuplicator newReqDuplicator(HttpRequestDuplicator reqDuplicator, ResponseHeaders responseHeaders, - RequestHeaders requestHeaders, URI newUri) { + RequestHeaders requestHeaders, + String nextUri, + String nextAuthority) { + final RequestHeadersBuilder builder = requestHeaders.toBuilder(); - builder.path(newUri.toString()); - final String newAuthority = newUri.getAuthority(); - if (newAuthority != null) { - // Update the old authority with the new one because the request is redirected to a different - // domain. - builder.authority(newAuthority); - } + builder.path(nextUri); + builder.authority(nextAuthority); + final HttpMethod method = requestHeaders.method(); if (responseHeaders.status() == HttpStatus.SEE_OTHER && !(method == HttpMethod.GET || method == HttpMethod.HEAD)) { @@ -343,36 +469,15 @@ private static void abortResponse(HttpResponse originalRes, ClientRequestContext } } - private static String buildFullUri(ClientRequestContext ctx, URI redirectUri, RequestHeaders newHeaders) - throws URISyntaxException { - // Build the full uri so we don't consider the situation, which session protocol or port is changed, - // as a cyclic redirects. - if (redirectUri.isAbsolute()) { - if (redirectUri.getPort() > 0) { - return redirectUri.toString(); - } - final int port; - if (redirectUri.getScheme().startsWith("https")) { - port = SessionProtocol.HTTPS.defaultPort(); - } else { - port = SessionProtocol.HTTP.defaultPort(); - } - return new URI(redirectUri.getScheme(), redirectUri.getRawUserInfo(), redirectUri.getHost(), port, - redirectUri.getRawPath(), redirectUri.getRawQuery(), redirectUri.getRawFragment()) - .toString(); - } - return buildUri(ctx, newHeaders); - } - - private static boolean isCyclicRedirects(RedirectContext redirectCtx, String redirectUri, - RequestHeaders newHeaders) { - final boolean added = redirectCtx.addRedirectUri(newHeaders.method(), redirectUri); + private static boolean isCyclicRedirects(RedirectContext redirectCtx, + String redirectUri, HttpMethod method) { + final boolean added = redirectCtx.addRedirectUri(method, redirectUri); if (!added) { return true; } return redirectCtx.originalUri().equals(redirectUri) && - redirectCtx.request().method() == newHeaders.method(); + redirectCtx.request().method() == method; } private static String buildUri(ClientRequestContext ctx, RequestHeaders headers) { @@ -391,15 +496,15 @@ private static String buildUri(ClientRequestContext ctx, RequestHeaders headers) if (authority == null) { authority = endpoint.authority(); } - setAuthorityAndPort(ctx, endpoint, sb, authority); + appendAuthority(ctx, endpoint, sb, authority); sb.append(headers.path()); originalUri = sb.toString(); } return originalUri; } - private static void setAuthorityAndPort(ClientRequestContext ctx, Endpoint endpoint, StringBuilder sb, - String authority) { + private static void appendAuthority(ClientRequestContext ctx, Endpoint endpoint, StringBuilder sb, + String authority) { // Add port number as well so that we don't raise a CyclicRedirectsException when the port is // different. diff --git a/core/src/main/java/com/linecorp/armeria/common/AbstractRequestContextBuilder.java b/core/src/main/java/com/linecorp/armeria/common/AbstractRequestContextBuilder.java index b9e21288ef1..ba21918e7f7 100644 --- a/core/src/main/java/com/linecorp/armeria/common/AbstractRequestContextBuilder.java +++ b/core/src/main/java/com/linecorp/armeria/common/AbstractRequestContextBuilder.java @@ -154,7 +154,7 @@ protected AbstractRequestContextBuilder(boolean server, RpcRequest rpcReq, URI u this.reqTarget = reqTarget; } else { reqTarget = DefaultRequestTarget.createWithoutValidation( - RequestTargetForm.ORIGIN, null, null, + RequestTargetForm.ORIGIN, null, null, null, -1, uri.getRawPath(), uri.getRawPath(), uri.getRawQuery(), uri.getRawFragment()); } } diff --git a/core/src/main/java/com/linecorp/armeria/common/HttpHeadersBase.java b/core/src/main/java/com/linecorp/armeria/common/HttpHeadersBase.java index a8edc350e6d..e67e1309d6e 100644 --- a/core/src/main/java/com/linecorp/armeria/common/HttpHeadersBase.java +++ b/core/src/main/java/com/linecorp/armeria/common/HttpHeadersBase.java @@ -57,6 +57,7 @@ import com.linecorp.armeria.common.annotation.Nullable; import com.linecorp.armeria.internal.common.util.StringUtil; +import com.linecorp.armeria.internal.common.util.TemporaryThreadLocals; import io.netty.util.AsciiString; @@ -214,18 +215,17 @@ URI uri() { checkState(scheme != null, ":scheme header does not exist."); final String authority = authority(); - final StringBuilder sb = new StringBuilder( - scheme.length() + 1 + - (authority != null ? (authority.length() + 2) : 0) + - path.length()); - sb.append(scheme); - sb.append(':'); - if (authority != null) { - sb.append("//"); - sb.append(authority); + try (TemporaryThreadLocals tmp = TemporaryThreadLocals.acquire()) { + final StringBuilder sb = tmp.stringBuilder(); + sb.append(scheme); + sb.append(':'); + if (authority != null) { + sb.append("//"); + sb.append(authority); + } + sb.append(path); + uri = sb.toString(); } - sb.append(path); - uri = sb.toString(); } try { diff --git a/core/src/main/java/com/linecorp/armeria/common/RequestTarget.java b/core/src/main/java/com/linecorp/armeria/common/RequestTarget.java index f2d2332d7bc..5b7241a52f8 100644 --- a/core/src/main/java/com/linecorp/armeria/common/RequestTarget.java +++ b/core/src/main/java/com/linecorp/armeria/common/RequestTarget.java @@ -98,6 +98,24 @@ static RequestTarget forClient(String reqTarget, @Nullable String prefix) { @Nullable String authority(); + /** + * Returns the host of this {@link RequestTarget}. Unlike {@link #authority()}, host doesn't include + * a port number. + * + * @return a non-empty string if {@link #form()} is {@link RequestTargetForm#ABSOLUTE}. + * {@code null} otherwise. + */ + @Nullable + String host(); + + /** + * Returns the port of this {@link RequestTarget}. + * + * @return a positive port number if {@link #form()} is {@link RequestTargetForm#ABSOLUTE} and + * {@link #authority()} has the port number. Zero or a negative value otherwise. + */ + int port(); + /** * Returns the path of this {@link RequestTarget}, which always starts with {@code '/'}. */ diff --git a/core/src/main/java/com/linecorp/armeria/internal/client/DefaultClientRequestContext.java b/core/src/main/java/com/linecorp/armeria/internal/client/DefaultClientRequestContext.java index 24824fcd08a..59052252f66 100644 --- a/core/src/main/java/com/linecorp/armeria/internal/client/DefaultClientRequestContext.java +++ b/core/src/main/java/com/linecorp/armeria/internal/client/DefaultClientRequestContext.java @@ -475,7 +475,7 @@ private void autoFillSchemeAuthorityAndOrigin() { // The connection will be established with the IP address but `host` set to the `Endpoint` // could be used for SNI. It would make users send HTTPS requests with CSLB or configure a reverse // proxy based on an authority. - final String host = HostAndPort.fromString(removeUserInfo(authority)).getHost(); + final String host = authorityToHost(authority); if (!NetUtil.isValidIpV4Address(host) && !NetUtil.isValidIpV6Address(host)) { endpoint = endpoint.withHost(host); } @@ -498,6 +498,10 @@ private void autoFillSchemeAuthorityAndOrigin() { internalRequestHeaders = headersBuilder.build(); } + private static String authorityToHost(String authority) { + return HostAndPort.fromString(removeUserInfo(authority)).getHost(); + } + private static String removeUserInfo(String authority) { final int indexOfDelimiter = authority.lastIndexOf('@'); if (indexOfDelimiter == -1) { @@ -761,11 +765,39 @@ private String origin() { return origin; } + @Override + public String host() { + final String authority = authority(); + if (authority == null) { + return null; + } + + return authorityToHost(authority); + } + @Override public URI uri() { final String scheme = getScheme(sessionProtocol()); - try { - return new URI(scheme, authority(), path(), query(), fragment()); + final String authority = authority(); + final String path = path(); + final String query = query(); + final String fragment = fragment(); + try (TemporaryThreadLocals tmp = TemporaryThreadLocals.acquire()) { + final StringBuilder buf = tmp.stringBuilder(); + buf.append(scheme); + if (authority != null) { + buf.append("://").append(authority); + } else { + buf.append(':'); + } + buf.append(path); + if (query != null) { + buf.append('?').append(query); + } + if (fragment != null) { + buf.append('#').append(fragment); + } + return new URI(buf.toString()); } catch (URISyntaxException e) { throw new IllegalStateException("not a valid URI", e); } diff --git a/core/src/main/java/com/linecorp/armeria/internal/common/ArmeriaHttpUtil.java b/core/src/main/java/com/linecorp/armeria/internal/common/ArmeriaHttpUtil.java index 38de509d23e..1b627000526 100644 --- a/core/src/main/java/com/linecorp/armeria/internal/common/ArmeriaHttpUtil.java +++ b/core/src/main/java/com/linecorp/armeria/internal/common/ArmeriaHttpUtil.java @@ -252,6 +252,27 @@ private static LoadingCache<AsciiString, String> buildCache(String spec) { return Caffeine.from(spec).build(AsciiString::toString); } + /** + * Returns the index of the authority part if the specified {@code reqTarget} is an absolute URI. + * Returns {@code -1} otherwise. + */ + public static int findAuthority(String reqTarget) { + final int firstColonIdx = reqTarget.indexOf(':'); + if (firstColonIdx <= 0 || reqTarget.length() <= firstColonIdx + 3) { + return -1; + } + final int firstSlashIdx = reqTarget.indexOf('/'); + if (firstSlashIdx <= 0 || firstSlashIdx < firstColonIdx) { + return -1; + } + + if (reqTarget.charAt(firstColonIdx + 1) == '/' && reqTarget.charAt(firstColonIdx + 2) == '/') { + return firstColonIdx + 3; + } + + return -1; + } + /** * Concatenates the specified {@code prefix} and {@code path} into an absolute path. * diff --git a/core/src/main/java/com/linecorp/armeria/internal/common/DefaultRequestTarget.java b/core/src/main/java/com/linecorp/armeria/internal/common/DefaultRequestTarget.java index bc20077e969..c7e30cc72bd 100644 --- a/core/src/main/java/com/linecorp/armeria/internal/common/DefaultRequestTarget.java +++ b/core/src/main/java/com/linecorp/armeria/internal/common/DefaultRequestTarget.java @@ -15,6 +15,7 @@ */ package com.linecorp.armeria.internal.common; +import static com.linecorp.armeria.internal.common.ArmeriaHttpUtil.findAuthority; import static io.netty.util.internal.StringUtil.decodeHexNibble; import static java.util.Objects.requireNonNull; @@ -24,6 +25,7 @@ import com.linecorp.armeria.common.RequestTarget; import com.linecorp.armeria.common.RequestTargetForm; +import com.linecorp.armeria.common.Scheme; import com.linecorp.armeria.common.annotation.Nullable; import com.linecorp.armeria.internal.common.util.TemporaryThreadLocals; @@ -122,6 +124,8 @@ boolean mustPreserveEncoding(int cp) { RequestTargetForm.ASTERISK, null, null, + null, + -1, "*", "*", null, @@ -185,9 +189,10 @@ public static RequestTarget forClient(String reqTarget, @Nullable String prefix) */ public static RequestTarget createWithoutValidation( RequestTargetForm form, @Nullable String scheme, @Nullable String authority, - String path, String pathWithMatrixVariables, @Nullable String query, @Nullable String fragment) { + @Nullable String host, int port, String path, String pathWithMatrixVariables, + @Nullable String query, @Nullable String fragment) { return new DefaultRequestTarget( - form, scheme, authority, path, pathWithMatrixVariables, query, fragment); + form, scheme, authority, host, port, path, pathWithMatrixVariables, query, fragment); } private final RequestTargetForm form; @@ -195,6 +200,9 @@ public static RequestTarget createWithoutValidation( private final String scheme; @Nullable private final String authority; + @Nullable + private final String host; + private final int port; private final String path; private final String maybePathWithMatrixVariables; @Nullable @@ -203,16 +211,20 @@ public static RequestTarget createWithoutValidation( private final String fragment; private boolean cached; - private DefaultRequestTarget(RequestTargetForm form, @Nullable String scheme, @Nullable String authority, + private DefaultRequestTarget(RequestTargetForm form, @Nullable String scheme, + @Nullable String authority, @Nullable String host, int port, String path, String maybePathWithMatrixVariables, @Nullable String query, @Nullable String fragment) { - assert (scheme != null && authority != null) || - (scheme == null && authority == null) : "scheme: " + scheme + ", authority: " + authority; + assert (scheme != null && authority != null && host != null) || + (scheme == null && authority == null && host == null) + : "scheme: " + scheme + ", authority: " + authority + ", host: " + host; this.form = form; this.scheme = scheme; this.authority = authority; + this.host = host; + this.port = port; this.path = path; this.maybePathWithMatrixVariables = maybePathWithMatrixVariables; this.query = query; @@ -234,6 +246,17 @@ public String authority() { return authority; } + @Override + @Nullable + public String host() { + return host; + } + + @Override + public int port() { + return port; + } + @Override public String path() { return path; @@ -369,6 +392,8 @@ private static RequestTarget slowForServer(String reqTarget, boolean allowSemico return new DefaultRequestTarget(RequestTargetForm.ORIGIN, null, null, + null, + -1, matrixVariablesRemovedPath, encodedPath, encodeQueryToPercents(query), @@ -436,12 +461,7 @@ private static RequestTarget slowAbsoluteFormForClient(String reqTarget, int aut } if (nextPos < 0) { - return new DefaultRequestTarget(RequestTargetForm.ABSOLUTE, - schemeAndAuthority.getScheme(), - schemeAndAuthority.getRawAuthority(), - "/", - "/", null, - null); + return newAbsoluteTarget(schemeAndAuthority, "/", null, null); } return slowForClient(reqTarget, schemeAndAuthority, nextPos); @@ -569,41 +589,63 @@ private static RequestTarget slowForClient(String reqTarget, final String encodedFragment = encodeFragmentToPercents(fragment); if (schemeAndAuthority != null) { - return new DefaultRequestTarget(RequestTargetForm.ABSOLUTE, - schemeAndAuthority.getScheme(), - schemeAndAuthority.getRawAuthority(), - encodedPath, - encodedPath, encodedQuery, - encodedFragment); + return newAbsoluteTarget(schemeAndAuthority, encodedPath, encodedQuery, encodedFragment); } else { return new DefaultRequestTarget(RequestTargetForm.ORIGIN, null, null, + null, + -1, encodedPath, encodedPath, encodedQuery, encodedFragment); } } - /** - * Returns the index of the authority part if the specified {@code reqTarget} is an absolute URI. - * Returns {@code -1} otherwise. - */ - private static int findAuthority(String reqTarget) { - final int firstColonIdx = reqTarget.indexOf(':'); - if (firstColonIdx <= 0 || reqTarget.length() <= firstColonIdx + 3) { - return -1; - } - final int firstSlashIdx = reqTarget.indexOf('/'); - if (firstSlashIdx <= 0 || firstSlashIdx < firstColonIdx) { - return -1; - } + private static DefaultRequestTarget newAbsoluteTarget( + URI schemeAndAuthority, String encodedPath, + @Nullable String encodedQuery, @Nullable String encodedFragment) { - if (reqTarget.charAt(firstColonIdx + 1) == '/' && reqTarget.charAt(firstColonIdx + 2) == '/') { - return firstColonIdx + 3; + final String scheme = schemeAndAuthority.getScheme(); + final String maybeAuthority = schemeAndAuthority.getRawAuthority(); + final String maybeHost = schemeAndAuthority.getHost(); + final int maybePort = schemeAndAuthority.getPort(); + final String authority; + final String host; + final int port; + if (maybeHost == null) { + authority = maybeAuthority; + host = maybeAuthority; + port = -1; + } else { + host = maybeHost; + + // Specify the port number only when necessary, so that https://foo/ and https://foo:443/ + // are considered equal. + if (maybePort >= 0) { + final Scheme parsedScheme = Scheme.tryParse(scheme); + if (parsedScheme == null || parsedScheme.sessionProtocol().defaultPort() != maybePort) { + authority = maybeAuthority; + port = maybePort; + } else { + authority = maybeHost; + port = -1; + } + } else { + authority = maybeHost; + port = -1; + } } - return -1; + return new DefaultRequestTarget(RequestTargetForm.ABSOLUTE, + scheme, + authority, + host, + port, + encodedPath, + encodedPath, + encodedQuery, + encodedFragment); } @Nullable diff --git a/core/src/main/java/com/linecorp/armeria/internal/common/NonWrappingRequestContext.java b/core/src/main/java/com/linecorp/armeria/internal/common/NonWrappingRequestContext.java index f3190ca3a47..82665f30b03 100644 --- a/core/src/main/java/com/linecorp/armeria/internal/common/NonWrappingRequestContext.java +++ b/core/src/main/java/com/linecorp/armeria/internal/common/NonWrappingRequestContext.java @@ -69,6 +69,7 @@ public abstract class NonWrappingRequestContext implements RequestContextExtensi @Nullable private String decodedPath; + private final Request originalRequest; @Nullable private volatile HttpRequest req; diff --git a/core/src/main/java/com/linecorp/armeria/server/RoutingContext.java b/core/src/main/java/com/linecorp/armeria/server/RoutingContext.java index 3f63f9bb846..09d01399166 100644 --- a/core/src/main/java/com/linecorp/armeria/server/RoutingContext.java +++ b/core/src/main/java/com/linecorp/armeria/server/RoutingContext.java @@ -16,6 +16,7 @@ package com.linecorp.armeria.server; +import static com.google.common.base.Preconditions.checkArgument; import static com.linecorp.armeria.internal.common.DefaultRequestTarget.removeMatrixVariables; import static java.util.Objects.requireNonNull; @@ -148,13 +149,19 @@ default String query() { */ default RoutingContext withPath(String path) { requireNonNull(path, "path"); + final String pathWithoutMatrixVariables = removeMatrixVariables(path); + checkArgument(pathWithoutMatrixVariables != null, + "path with invalid matrix variables: %s", path); + final RequestTarget oldReqTarget = requestTarget(); final RequestTarget newReqTarget = DefaultRequestTarget.createWithoutValidation( oldReqTarget.form(), oldReqTarget.scheme(), oldReqTarget.authority(), - removeMatrixVariables(path), + oldReqTarget.host(), + oldReqTarget.port(), + pathWithoutMatrixVariables, path, oldReqTarget.query(), oldReqTarget.fragment()); diff --git a/core/src/test/java/com/linecorp/armeria/client/DomainSocketClientTest.java b/core/src/test/java/com/linecorp/armeria/client/DomainSocketClientTest.java index 587fc7072d8..32e9424fd15 100644 --- a/core/src/test/java/com/linecorp/armeria/client/DomainSocketClientTest.java +++ b/core/src/test/java/com/linecorp/armeria/client/DomainSocketClientTest.java @@ -105,6 +105,11 @@ void shouldSupportConnectingToDomainSocket(SessionProtocol protocol, boolean use final String expectedAddress = domainSocketAddress(useAbstractNamespace).toString(); assertThat(ctx.localAddress()).hasToString(expectedAddress); assertThat(ctx.remoteAddress()).hasToString(expectedAddress); + + final String expectedUri = (protocol.isTls() ? "https" : "http") + + baseUri.replaceAll("^[^:]+", "") + + "/greet"; + assertThat(ctx.uri()).hasToString(expectedUri); } } diff --git a/core/src/test/java/com/linecorp/armeria/client/RedirectingClientTest.java b/core/src/test/java/com/linecorp/armeria/client/RedirectingClientTest.java index df92f9aa82b..33b13ba3c21 100644 --- a/core/src/test/java/com/linecorp/armeria/client/RedirectingClientTest.java +++ b/core/src/test/java/com/linecorp/armeria/client/RedirectingClientTest.java @@ -28,16 +28,25 @@ import org.junit.jupiter.api.Test; import org.junit.jupiter.api.extension.RegisterExtension; +import com.google.common.collect.ImmutableList; + +import com.linecorp.armeria.client.logging.LoggingClient; import com.linecorp.armeria.client.redirect.RedirectConfig; import com.linecorp.armeria.client.redirect.TooManyRedirectsException; import com.linecorp.armeria.client.redirect.UnexpectedDomainRedirectException; import com.linecorp.armeria.client.redirect.UnexpectedProtocolRedirectException; import com.linecorp.armeria.common.AggregatedHttpResponse; import com.linecorp.armeria.common.HttpData; +import com.linecorp.armeria.common.HttpHeaderNames; import com.linecorp.armeria.common.HttpMethod; +import com.linecorp.armeria.common.HttpRequest; import com.linecorp.armeria.common.HttpResponse; import com.linecorp.armeria.common.HttpStatus; +import com.linecorp.armeria.common.RequestTarget; +import com.linecorp.armeria.common.ResponseHeaders; import com.linecorp.armeria.common.SessionProtocol; +import com.linecorp.armeria.common.annotation.Nullable; +import com.linecorp.armeria.common.logging.RequestLog; import com.linecorp.armeria.internal.testing.MockAddressResolverGroup; import com.linecorp.armeria.server.ServerBuilder; import com.linecorp.armeria.server.ServerPort; @@ -98,7 +107,9 @@ protected void configure(ServerBuilder sb) throws Exception { }); sb.service("/removeDotSegments/foo", (ctx, req) -> HttpResponse.ofRedirect("./bar")) - .service("/removeDotSegments/bar", (ctx, req) -> HttpResponse.of(200)); + .service("/removeDotSegments/bar", (ctx, req) -> HttpResponse.of(200)) + .service("/removeDoubleDotSegments/foo", + (ctx, req) -> HttpResponse.ofRedirect("../removeDotSegments/bar")); sb.service("/loop", (ctx, req) -> HttpResponse.ofRedirect("loop1")) .service("/loop1", (ctx, req) -> HttpResponse.ofRedirect("loop2")) @@ -112,6 +123,16 @@ protected void configure(ServerBuilder sb) throws Exception { return HttpResponse.ofRedirect(HttpStatus.SEE_OTHER, "/differentHttpMethod"); } }); + + sb.service("/unencodedLocation", + (ctx, req) -> HttpResponse.ofRedirect("/unencodedLocation/foo bar?value=${P}")); + sb.service("/unencodedLocation/foo%20bar", (ctx, req) -> { + if ("${P}".equals(ctx.queryParam("value"))) { + return HttpResponse.of(200); + } else { + return HttpResponse.of(400); + } + }); } private int otherHttpPort(ServiceRequestContext ctx) { @@ -147,6 +168,7 @@ void redirect_failExceedingTotalAttempts() { private static AggregatedHttpResponse sendRequest(int maxRedirects) { final WebClient client = WebClient.builder(server.httpUri()) + .decorator(LoggingClient.newDecorator()) .followRedirects(RedirectConfig.builder() .maxRedirects(maxRedirects) .build()) @@ -169,6 +191,7 @@ void protocolChange() { requestCounter.set(0); final WebClient client1 = WebClient.builder(server.httpUri()) .factory(ClientFactory.insecure()) + .decorator(LoggingClient.newDecorator()) // Allows HTTPS by default when allowProtocols is not specified. .followRedirects() .build(); @@ -256,14 +279,17 @@ void referenceResolution() { .factory(ClientFactory.insecure()) .followRedirects() .build(); - final AggregatedHttpResponse res = client.get("/removeDotSegments/foo").aggregate().join(); - assertThat(res.status()).isSameAs(HttpStatus.OK); + final AggregatedHttpResponse res1 = client.get("/removeDotSegments/foo").aggregate().join(); + assertThat(res1.status()).isSameAs(HttpStatus.OK); + final AggregatedHttpResponse res2 = client.get("/removeDoubleDotSegments/foo").aggregate().join(); + assertThat(res2.status()).isSameAs(HttpStatus.OK); } @Test void cyclicRedirectsException() { final WebClient client = Clients.builder(server.httpUri()) .followRedirects() + .decorator(LoggingClient.newDecorator()) .build(WebClient.class); assertThatThrownBy(() -> client.get("/loop").aggregate().join()) .hasMessageContainingAll("The original URI:", "/loop", "Redirect URIs:", "/loop1", "/loop2") @@ -285,6 +311,132 @@ void notCyclicRedirectsWhenHttpMethodDiffers() { } } + @Test + void unencodedLocation() { + final WebClient client = WebClient.builder(server.httpUri()) + .followRedirects() + .build(); + + final ClientRequestContext ctx; + try (ClientRequestContextCaptor captor = Clients.newContextCaptor()) { + assertThat(client.get("/unencodedLocation").aggregate().join().status()).isSameAs(HttpStatus.OK); + ctx = captor.get(); + } + + final ImmutableList<RequestLog> logs = ctx.log().whenComplete().join() + .children().stream() + .map(log -> log.whenComplete().join()) + .collect(toImmutableList()); + + assertThat(logs.size()).isEqualTo(2); + + final ResponseHeaders log1headers = logs.get(0).responseHeaders(); + assertThat(log1headers.status()).isEqualTo(HttpStatus.TEMPORARY_REDIRECT); + assertThat(log1headers.get(HttpHeaderNames.LOCATION)) + .isEqualTo("/unencodedLocation/foo bar?value=${P}"); + + final RequestLog log2 = logs.get(1); + assertThat(log2.requestHeaders().path()).isEqualTo("/unencodedLocation/foo%20bar?value=$%7BP%7D"); + assertThat(log2.responseHeaders().status()).isEqualTo(HttpStatus.OK); + assertThat(log2.context().uri().toString()).endsWith("/unencodedLocation/foo%20bar?value=$%7BP%7D"); + } + + @Test + void testResolveLocation() { + // Absolute paths and URIs should supersede the original path. + assertThat(resolveLocation("/a/", "/")).isEqualTo("h2c://foo/"); + assertThat(resolveLocation("/a/", "/b")).isEqualTo("h2c://foo/b"); + assertThat(resolveLocation("/a/", "//bar")).isEqualTo("h2c://bar/"); + assertThat(resolveLocation("/a/", "//bar/b")).isEqualTo("h2c://bar/b"); + assertThat(resolveLocation("/a/", "https://bar")).isEqualTo("https://bar/"); + assertThat(resolveLocation("/a/", "https://bar/b")).isEqualTo("https://bar/b"); + + // Should reject the absolute URI with an unknown scheme. + assertThat(resolveLocation("/a/", "a://bar")).isNull(); + + // Should normalize the scheme into "http" or "https" in an absolute URI, + // because we should not trust the response blindly, e.g. DDoS by enforcing HTTP/1. + assertThat(resolveLocation("/a/", "h1c://bar")).isEqualTo("http://bar/"); + assertThat(resolveLocation("/a/", "h1://bar")).isEqualTo("https://bar/"); + + // Simple cases + assertThat(resolveLocation("/", "b")).isEqualTo("h2c://foo/b"); + assertThat(resolveLocation("/", "b/")).isEqualTo("h2c://foo/b/"); + assertThat(resolveLocation("/", "b/c")).isEqualTo("h2c://foo/b/c"); + assertThat(resolveLocation("/", "b/c/")).isEqualTo("h2c://foo/b/c/"); + + assertThat(resolveLocation("/a", "b")).isEqualTo("h2c://foo/b"); + assertThat(resolveLocation("/a", "b/")).isEqualTo("h2c://foo/b/"); + assertThat(resolveLocation("/a", "b/c")).isEqualTo("h2c://foo/b/c"); + assertThat(resolveLocation("/a", "b/c/")).isEqualTo("h2c://foo/b/c/"); + + assertThat(resolveLocation("/a/", "b")).isEqualTo("h2c://foo/a/b"); + assertThat(resolveLocation("/a/", "b/")).isEqualTo("h2c://foo/a/b/"); + assertThat(resolveLocation("/a/", "b/c")).isEqualTo("h2c://foo/a/b/c"); + assertThat(resolveLocation("/a/", "b/c/")).isEqualTo("h2c://foo/a/b/c/"); + + // Single-dot cases + assertThat(resolveLocation("/", ".")).isEqualTo("h2c://foo/"); + assertThat(resolveLocation("/", "b/.")).isEqualTo("h2c://foo/b/"); + assertThat(resolveLocation("/", "b/./")).isEqualTo("h2c://foo/b/"); + assertThat(resolveLocation("/", "b/./c")).isEqualTo("h2c://foo/b/c"); + + assertThat(resolveLocation("/a", ".")).isEqualTo("h2c://foo/"); + assertThat(resolveLocation("/a", "b/.")).isEqualTo("h2c://foo/b/"); + assertThat(resolveLocation("/a", "b/./c")).isEqualTo("h2c://foo/b/c"); + + assertThat(resolveLocation("/a/", ".")).isEqualTo("h2c://foo/a/"); + assertThat(resolveLocation("/a/", "b/.")).isEqualTo("h2c://foo/a/b/"); + assertThat(resolveLocation("/a/", "b/./c")).isEqualTo("h2c://foo/a/b/c"); + + // Double-dot cases + assertThat(resolveLocation("/", "..")).isNull(); + assertThat(resolveLocation("/", "b/..")).isEqualTo("h2c://foo/"); + assertThat(resolveLocation("/", "b/../")).isEqualTo("h2c://foo/"); + assertThat(resolveLocation("/", "b/../c")).isEqualTo("h2c://foo/c"); + + assertThat(resolveLocation("/a", "..")).isNull(); + assertThat(resolveLocation("/a", "b/..")).isEqualTo("h2c://foo/"); + assertThat(resolveLocation("/a", "b/../c")).isEqualTo("h2c://foo/c"); + + assertThat(resolveLocation("/a/", "..")).isEqualTo("h2c://foo/"); + assertThat(resolveLocation("/a/", "b/..")).isEqualTo("h2c://foo/a/"); + assertThat(resolveLocation("/a/", "b/../c")).isEqualTo("h2c://foo/a/c"); + + // Multiple single- or double- dots + assertThat(resolveLocation("/", "././a")).isEqualTo("h2c://foo/a"); + assertThat(resolveLocation("/", "a/././b")).isEqualTo("h2c://foo/a/b"); + assertThat(resolveLocation("/", "a/./.")).isEqualTo("h2c://foo/a/"); + assertThat(resolveLocation("/", "a/././")).isEqualTo("h2c://foo/a/"); + + assertThat(resolveLocation("/a", "././b")).isEqualTo("h2c://foo/b"); + assertThat(resolveLocation("/a", "b/././c")).isEqualTo("h2c://foo/b/c"); + assertThat(resolveLocation("/a", "b/./.")).isEqualTo("h2c://foo/b/"); + assertThat(resolveLocation("/a", "b/././")).isEqualTo("h2c://foo/b/"); + + assertThat(resolveLocation("/a/b/", "../../c")).isEqualTo("h2c://foo/c"); + assertThat(resolveLocation("/a/b/", "c/../../d")).isEqualTo("h2c://foo/a/d"); + assertThat(resolveLocation("/a/b/", "c/../..")).isEqualTo("h2c://foo/a/"); + assertThat(resolveLocation("/a/b/", "c/../../")).isEqualTo("h2c://foo/a/"); + + assertThat(resolveLocation("/a/b", "../../c")).isNull(); + assertThat(resolveLocation("/a/b/c", "../../d")).isEqualTo("h2c://foo/d"); + assertThat(resolveLocation("/a/b/c", "d/../../e")).isEqualTo("h2c://foo/a/e"); + assertThat(resolveLocation("/a/b/c", "d/../..")).isEqualTo("h2c://foo/a/"); + assertThat(resolveLocation("/a/b/c", "d/../../")).isEqualTo("h2c://foo/a/"); + } + + @Nullable + private String resolveLocation(String originalPath, String redirectLocation) { + final HttpRequest req = HttpRequest.builder() + .get(originalPath) + .header(HttpHeaderNames.AUTHORITY, "foo") + .build(); + final ClientRequestContext ctx = ClientRequestContext.of(req); + final RequestTarget result = RedirectingClient.resolveLocation(ctx, redirectLocation); + return result != null ? result.toString() : null; + } + private static ClientFactory localhostAccessingClientFactory() { return ClientFactory.builder().addressResolverGroupFactory( eventLoopGroup -> MockAddressResolverGroup.localhost()).build(); diff --git a/core/src/test/java/com/linecorp/armeria/internal/common/DefaultRequestTargetTest.java b/core/src/test/java/com/linecorp/armeria/internal/common/DefaultRequestTargetTest.java index 8a2a1fbad79..8a51f856220 100644 --- a/core/src/test/java/com/linecorp/armeria/internal/common/DefaultRequestTargetTest.java +++ b/core/src/test/java/com/linecorp/armeria/internal/common/DefaultRequestTargetTest.java @@ -465,6 +465,11 @@ void shouldAcceptAsteriskPath(Mode mode) { // IP addresses "a://127.0.0.1/, a, 127.0.0.1, /,,", "a://[::1]:80/, a, [::1]:80, /,,", + // default port numbers should be omitted + "http://a:80/, http, a, /,,", + "http://a:443/, http, a:443, /,,", + "https://a:80/, https, a:80, /,,", + "https://a:443/, https, a, /,,", }) void clientShouldAcceptAbsoluteUri(String uri, String expectedScheme, String expectedAuthority, String expectedPath,