Skip to content

Commit

Permalink
Maybe final touch
Browse files Browse the repository at this point in the history
  • Loading branch information
trustin committed Mar 8, 2024
1 parent 46b73f7 commit 7d883b9
Show file tree
Hide file tree
Showing 7 changed files with 251 additions and 171 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -293,6 +293,13 @@ ClientRequestContext newDerivedContext(RequestId id, @Nullable HttpRequest req,
@UnstableApi
String authority();

/**
* Returns the host part of {@link #authority()}, without a port number.
*/
@Nullable
@UnstableApi
String host();

/**
* Returns the {@link URI} constructed based on {@link ClientRequestContext#sessionProtocol()},
* {@link ClientRequestContext#authority()}, {@link ClientRequestContext#path()} and
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,11 @@ public String authority() {
return unwrap().authority();
}

@Override
public String host() {
return unwrap().host();
}

@Override
public URI uri() {
return unwrap().uri();
Expand Down
186 changes: 95 additions & 91 deletions core/src/main/java/com/linecorp/armeria/client/RedirectingClient.java
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,6 @@
import com.linecorp.armeria.common.RequestTarget;
import com.linecorp.armeria.common.RequestTargetForm;
import com.linecorp.armeria.common.ResponseHeaders;
import com.linecorp.armeria.common.Scheme;
import com.linecorp.armeria.common.SessionProtocol;
import com.linecorp.armeria.common.annotation.Nullable;
import com.linecorp.armeria.common.logging.RequestLogBuilder;
Expand Down Expand Up @@ -219,55 +218,52 @@ private void execute0(ClientRequestContext ctx, RedirectContext redirectCtx,
final RequestHeaders requestHeaders = log.requestHeaders();

// Resolve the actual redirect location.
final String resolvedLocation = resolveLocation(requestHeaders.path(), location);
if (resolvedLocation == null) {
final RequestTarget nextReqTarget = resolveLocation(ctx, location);
if (nextReqTarget == null) {
handleException(ctx, derivedCtx, reqDuplicator, responseFuture, response,
new IllegalArgumentException("Invalid redirect location: " + location));
return;
}

// Parse and normalize the redirect location.
final RequestTarget redirectTarget = RequestTarget.forClient(resolvedLocation);
if (redirectTarget == null) {
handleException(ctx, derivedCtx, reqDuplicator, responseFuture, response,
new IllegalArgumentException("Invalid redirect location: " + location));
return;
}
final String nextScheme = nextReqTarget.scheme();
final String nextAuthority = nextReqTarget.authority();
final String nextHost = nextReqTarget.host();
assert nextReqTarget.form() == RequestTargetForm.ABSOLUTE &&
nextScheme != null && nextAuthority != null && nextHost != null
: "resolveLocation() must return an absolute request target: " + nextReqTarget;

try {
if (redirectTarget.form() == RequestTargetForm.ABSOLUTE) {
final SessionProtocol redirectProtocol = Scheme.parse(redirectTarget.scheme())
.sessionProtocol();
if (!allowedProtocols.contains(redirectProtocol)) {
handleException(ctx, derivedCtx, reqDuplicator, responseFuture, response,
UnexpectedProtocolRedirectException.of(
redirectProtocol, allowedProtocols));
return;
}
// Reject if:
// 1) the protocol is not same with the original one; and
// 2) the protocol is not in the allow-list.
final SessionProtocol nextProtocol = SessionProtocol.of(nextScheme);
if (ctx.sessionProtocol() != nextProtocol &&
!allowedProtocols.contains(nextProtocol)) {
handleException(ctx, derivedCtx, reqDuplicator, responseFuture, response,
UnexpectedProtocolRedirectException.of(
nextProtocol, allowedProtocols));
return;
}

if (!domainFilter.test(ctx, redirectTarget.host())) {
handleException(ctx, derivedCtx, reqDuplicator, responseFuture, response,
UnexpectedDomainRedirectException.of(redirectTarget.host()));
return;
}
// Reject if:
// 1) the host is not same with the original one; and
// 2) the host does not pass the domain filter.
if (!nextHost.equals(ctx.host()) &&
!domainFilter.test(ctx, nextHost)) {
handleException(ctx, derivedCtx, reqDuplicator, responseFuture, response,
UnexpectedDomainRedirectException.of(nextHost));
return;
}
} catch (Throwable t) {
handleException(ctx, derivedCtx, reqDuplicator, responseFuture, response, t);
return;
}

final HttpRequestDuplicator newReqDuplicator =
newReqDuplicator(reqDuplicator, responseHeaders, requestHeaders, redirectTarget);
newReqDuplicator(reqDuplicator, responseHeaders, requestHeaders,
nextReqTarget.toString(), nextAuthority);

final String redirectFullUri;
try {
redirectFullUri = buildFullUri(ctx, redirectTarget, newReqDuplicator.headers());
} catch (Throwable t) {
handleException(ctx, derivedCtx, reqDuplicator, responseFuture, response, t);
return;
}

if (isCyclicRedirects(redirectCtx, redirectFullUri, newReqDuplicator.headers())) {
if (isCyclicRedirects(redirectCtx, nextReqTarget.toString(), newReqDuplicator.headers().method())) {
handleException(ctx, derivedCtx, reqDuplicator, responseFuture, response,
CyclicRedirectsException.of(redirectCtx.originalUri(),
redirectCtx.redirectUris().values()));
Expand Down Expand Up @@ -296,19 +292,62 @@ private void execute0(ClientRequestContext ctx, RedirectContext redirectCtx,

@Nullable
@VisibleForTesting
static String resolveLocation(String originalPath, String redirectLocation) {
assert !isNullOrEmpty(redirectLocation) : "redirectLocation is null or empty";

// Use as-is if 1) an absolute path or 2) an absolute URI.
if (redirectLocation.charAt(0) == '/' || findAuthority(redirectLocation) >= 0) {
return redirectLocation;
static RequestTarget resolveLocation(ClientRequestContext ctx, String location) {
final long length = location.length();
assert length > 0 : location;

final String resolvedUri;
if (location.charAt(0) == '/') {
if (length > 1 && location.charAt(1) == '/') {
// No scheme, e.g. //foo.com/bar
resolvedUri = ctx.sessionProtocol().uriText() + ':' + location;
} else {
// No scheme, no authority, e.g. /bar
resolvedUri = ctx.sessionProtocol().uriText() + "://" + ctx.authority() + location;
}
} else {
final int authorityIdx = findAuthority(location);
if (authorityIdx < 0) {
// A relative path, e.g. ./bar
resolvedUri = resolveRelativeLocation(ctx, location);
if (resolvedUri == null) {
return null;
}
} else {
// A full absolute URI, e.g. http://foo.com/bar
// Note that we should normalize an explicit scheme such as `h1c` into `http` or `https`,
// because otherwise a potentially malicious peer can force us to use inefficient protocols
// like HTTP/1.
final SessionProtocol proto = SessionProtocol.find(location.substring(0, authorityIdx - 3));
if (proto != null) {
switch (proto) {
case HTTP:
case HTTPS:
resolvedUri = location;
break;
default:
if (proto.isHttp()) {
resolvedUri = "http://" + location.substring(authorityIdx);
} else if (proto.isHttps()) {
resolvedUri = "https://" + location.substring(authorityIdx);
} else {
return null;
}
}
} else {
// Unknown scheme.
return null;
}
}
}

return resolveLocationSlow(originalPath, redirectLocation);
return RequestTarget.forClient(resolvedUri);
}

@Nullable
private static String resolveLocationSlow(String originalPath, String redirectLocation) {
private static String resolveRelativeLocation(ClientRequestContext ctx, String location) {
final String originalPath = ctx.path();

// Find the base path, e.g.
// - /foo -> /
// - /foo/ -> /foo/
Expand All @@ -317,14 +356,16 @@ private static String resolveLocationSlow(String originalPath, String redirectLo
assert lastSlashIdx >= 0 : "originalPath doesn't contain a slash: " + originalPath;

// Generate the full path.
final String fullPath = originalPath.substring(0, lastSlashIdx + 1) + redirectLocation;
final String fullPath = originalPath.substring(0, lastSlashIdx + 1) + location;
final Iterator<String> it = pathSplitter.split(fullPath).iterator();
// Splitter will always emit an empty string as the first component, so we skip it.
assert it.hasNext() && it.next().isEmpty() : fullPath;

// Resolve `.` and `..` from the full path.
try (TemporaryThreadLocals tmp = TemporaryThreadLocals.acquire()) {
final StringBuilder buf = tmp.stringBuilder();
buf.append(ctx.sessionProtocol().uriText()).append("://").append(ctx.authority());
final int authorityEndIdx = buf.length();
while (it.hasNext()) {
final String component = it.next();
switch (component) {
Expand All @@ -336,7 +377,7 @@ private static String resolveLocationSlow(String originalPath, String redirectLo
break;
case "..":
final int idx = buf.lastIndexOf("/");
if (idx < 0) {
if (idx < authorityEndIdx) {
// Too few parents
return null;
}
Expand All @@ -363,15 +404,13 @@ private static String resolveLocationSlow(String originalPath, String redirectLo
private static HttpRequestDuplicator newReqDuplicator(HttpRequestDuplicator reqDuplicator,
ResponseHeaders responseHeaders,
RequestHeaders requestHeaders,
RequestTarget newTarget) {
String nextUri,
String nextAuthority) {

final RequestHeadersBuilder builder = requestHeaders.toBuilder();
builder.path(newTarget.toString());
final String newAuthority = newTarget.authority();
if (newAuthority != null) {
// Update the old authority with the new one because the request is redirected to a different
// domain.
builder.authority(newAuthority);
}
builder.path(nextUri);
builder.authority(nextAuthority);

final HttpMethod method = requestHeaders.method();
if (responseHeaders.status() == HttpStatus.SEE_OTHER &&
!(method == HttpMethod.GET || method == HttpMethod.HEAD)) {
Expand Down Expand Up @@ -430,37 +469,15 @@ private static void abortResponse(HttpResponse originalRes, ClientRequestContext
}
}

private static String buildFullUri(ClientRequestContext ctx,
RequestTarget redirectTarget, RequestHeaders newHeaders) {
// Build the full URI, so we don't consider the situation, which session protocol or port is changed,
// as a cyclic redirects.
if (redirectTarget.form() != RequestTargetForm.ABSOLUTE) {
return buildUri(ctx, newHeaders);
}

if (redirectTarget.port() > 0) {
return redirectTarget.toString();
}

final int port;
if (redirectTarget.scheme().startsWith("https")) {
port = SessionProtocol.HTTPS.defaultPort();
} else {
port = SessionProtocol.HTTP.defaultPort();
}

return buildUri(redirectTarget, port);
}

private static boolean isCyclicRedirects(RedirectContext redirectCtx, String redirectUri,
RequestHeaders newHeaders) {
final boolean added = redirectCtx.addRedirectUri(newHeaders.method(), redirectUri);
private static boolean isCyclicRedirects(RedirectContext redirectCtx,
String redirectUri, HttpMethod method) {
final boolean added = redirectCtx.addRedirectUri(method, redirectUri);
if (!added) {
return true;
}

return redirectCtx.originalUri().equals(redirectUri) &&
redirectCtx.request().method() == newHeaders.method();
redirectCtx.request().method() == method;
}

private static String buildUri(ClientRequestContext ctx, RequestHeaders headers) {
Expand All @@ -486,19 +503,6 @@ private static String buildUri(ClientRequestContext ctx, RequestHeaders headers)
return originalUri;
}

private static String buildUri(RequestTarget redirectTarget, int port) {
final String originalUri;
try (TemporaryThreadLocals threadLocals = TemporaryThreadLocals.acquire()) {
final StringBuilder sb = threadLocals.stringBuilder();
sb.append(redirectTarget.scheme());
sb.append("://");
sb.append(redirectTarget.host()).append(':').append(port);
sb.append(redirectTarget.path());
originalUri = sb.toString();
}
return originalUri;
}

private static void appendAuthority(ClientRequestContext ctx, Endpoint endpoint, StringBuilder sb,
String authority) {
// Add port number as well so that we don't raise a CyclicRedirectsException when the port is
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -464,7 +464,7 @@ private void autoFillSchemeAuthorityAndOrigin() {
// The connection will be established with the IP address but `host` set to the `Endpoint`
// could be used for SNI. It would make users send HTTPS requests with CSLB or configure a reverse
// proxy based on an authority.
final String host = HostAndPort.fromString(removeUserInfo(authority)).getHost();
final String host = authorityToHost(authority);
if (!NetUtil.isValidIpV4Address(host) && !NetUtil.isValidIpV6Address(host)) {
endpoint = endpoint.withHost(host);
}
Expand All @@ -487,6 +487,10 @@ private void autoFillSchemeAuthorityAndOrigin() {
internalRequestHeaders = headersBuilder.build();
}

private static String authorityToHost(String authority) {
return HostAndPort.fromString(removeUserInfo(authority)).getHost();
}

private static String removeUserInfo(String authority) {
final int indexOfDelimiter = authority.lastIndexOf('@');
if (indexOfDelimiter == -1) {
Expand Down Expand Up @@ -750,6 +754,16 @@ private String origin() {
return origin;
}

@Override
public String host() {
final String authority = authority();
if (authority == null) {
return null;
}

return authorityToHost(authority);
}

@Override
public URI uri() {
final String scheme = getScheme(sessionProtocol());
Expand Down
Loading

0 comments on commit 7d883b9

Please sign in to comment.