Skip to content

Commit

Permalink
Merge pull request #684 from mbfreder/alb-incorrect-server-name
Browse files Browse the repository at this point in the history
Fix: Incorrect ServerName with ALB
  • Loading branch information
deki authored Nov 19, 2023
2 parents b4c3dc5 + 58fdc7c commit 0cfbd8e
Show file tree
Hide file tree
Showing 4 changed files with 22 additions and 4 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
*/
package com.amazonaws.serverless.proxy.internal;

import com.amazonaws.serverless.proxy.model.AlbContext;
import com.amazonaws.serverless.proxy.model.AwsProxyRequestContext;
import com.amazonaws.serverless.proxy.model.ContainerConfig;
import edu.umd.cs.findbugs.annotations.SuppressFBWarnings;
import org.slf4j.Logger;
Expand All @@ -21,6 +23,7 @@
import java.io.IOException;
import java.util.HashSet;
import java.util.Locale;
import java.util.Objects;
import java.util.Set;

/**
Expand Down Expand Up @@ -60,11 +63,15 @@ public static boolean isValidScheme(String scheme) {
return SCHEMES.contains(scheme);
}

public static boolean isValidHost(String host, String apiId, String region) {
public static boolean isValidHost(String host, String apiId, AlbContext elb, String region) {
if (host == null) {
return false;
}
if (host.endsWith(".amazonaws.com")) {
if (!Objects.isNull(elb)) {
String albhost = new StringBuilder().append(region)
.append(".elb.amazonaws.com").toString();
return host.endsWith(albhost) || LambdaContainerHandler.getContainerConfig().getCustomDomainNames().contains(host);
} else if (host.endsWith(".amazonaws.com")) {
String defaultHost = new StringBuilder().append(apiId)
.append(".execute-api.")
.append(region)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -348,7 +348,7 @@ public String getServerName() {

if (headers != null && headers.containsKey(HOST_HEADER_NAME)) {
String hostHeader = headers.getFirst(HOST_HEADER_NAME);
if (SecurityUtils.isValidHost(hostHeader, request.getRequestContext().getApiId(), region)) {
if (SecurityUtils.isValidHost(hostHeader, request.getRequestContext().getApiId(), request.getRequestContext().getElb(), region)) {
return hostHeader;
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -390,7 +390,7 @@ public String getServerName() {

if (request.getMultiValueHeaders() != null && request.getMultiValueHeaders().containsKey(HOST_HEADER_NAME)) {
String hostHeader = request.getMultiValueHeaders().getFirst(HOST_HEADER_NAME);
if (SecurityUtils.isValidHost(hostHeader, request.getRequestContext().getApiId(), region)) {
if (SecurityUtils.isValidHost(hostHeader, request.getRequestContext().getApiId(), request.getRequestContext().getElb(), region)) {
return hostHeader;
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import com.amazonaws.serverless.proxy.internal.testutils.AwsProxyRequestBuilder;

import com.amazonaws.services.lambda.runtime.Context;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.MethodSource;

Expand Down Expand Up @@ -640,6 +641,16 @@ void serverName_hostHeader_returnsHostHeaderOnly(String type) {
assertEquals("testapi.com", serverName);
}

@Test
void serverName_albHostHeader_returnsHostHeader() {
initAwsProxyHttpServletRequestTest("ALB");
AwsProxyRequestBuilder proxyReq = new AwsProxyRequestBuilder("/test", "GET")
.header(HttpHeaders.HOST, "testapi.us-east-1.elb.amazonaws.com");
HttpServletRequest servletReq = getRequest(proxyReq, null, null);
String serverName = servletReq.getServerName();
assertEquals("testapi.us-east-1.elb.amazonaws.com", serverName);
}

private AwsProxyRequestBuilder getRequestWithHeaders() {
return new AwsProxyRequestBuilder("/hello", "GET")
.header(CUSTOM_HEADER_KEY, CUSTOM_HEADER_VALUE)
Expand Down

0 comments on commit 0cfbd8e

Please sign in to comment.