From 58fdc7c00a3df436270d00954fe476a3a162fc9e Mon Sep 17 00:00:00 2001 From: mbfreder Date: Thu, 16 Nov 2023 10:22:06 -0800 Subject: [PATCH] Fix: Incorrect ServerName with ALB --- .../serverless/proxy/internal/SecurityUtils.java | 11 +++++++++-- .../servlet/AwsHttpApiV2ProxyHttpServletRequest.java | 2 +- .../internal/servlet/AwsProxyHttpServletRequest.java | 2 +- .../servlet/AwsProxyHttpServletRequestTest.java | 11 +++++++++++ 4 files changed, 22 insertions(+), 4 deletions(-) diff --git a/aws-serverless-java-container-core/src/main/java/com/amazonaws/serverless/proxy/internal/SecurityUtils.java b/aws-serverless-java-container-core/src/main/java/com/amazonaws/serverless/proxy/internal/SecurityUtils.java index d553d17a7..3f52f8950 100644 --- a/aws-serverless-java-container-core/src/main/java/com/amazonaws/serverless/proxy/internal/SecurityUtils.java +++ b/aws-serverless-java-container-core/src/main/java/com/amazonaws/serverless/proxy/internal/SecurityUtils.java @@ -12,6 +12,8 @@ */ package com.amazonaws.serverless.proxy.internal; +import com.amazonaws.serverless.proxy.model.AlbContext; +import com.amazonaws.serverless.proxy.model.AwsProxyRequestContext; import com.amazonaws.serverless.proxy.model.ContainerConfig; import edu.umd.cs.findbugs.annotations.SuppressFBWarnings; import org.slf4j.Logger; @@ -21,6 +23,7 @@ import java.io.IOException; import java.util.HashSet; import java.util.Locale; +import java.util.Objects; import java.util.Set; /** @@ -60,11 +63,15 @@ public static boolean isValidScheme(String scheme) { return SCHEMES.contains(scheme); } - public static boolean isValidHost(String host, String apiId, String region) { + public static boolean isValidHost(String host, String apiId, AlbContext elb, String region) { if (host == null) { return false; } - if (host.endsWith(".amazonaws.com")) { + if (!Objects.isNull(elb)) { + String albhost = new StringBuilder().append(region) + .append(".elb.amazonaws.com").toString(); + return host.endsWith(albhost) || LambdaContainerHandler.getContainerConfig().getCustomDomainNames().contains(host); + } else if (host.endsWith(".amazonaws.com")) { String defaultHost = new StringBuilder().append(apiId) .append(".execute-api.") .append(region) diff --git a/aws-serverless-java-container-core/src/main/java/com/amazonaws/serverless/proxy/internal/servlet/AwsHttpApiV2ProxyHttpServletRequest.java b/aws-serverless-java-container-core/src/main/java/com/amazonaws/serverless/proxy/internal/servlet/AwsHttpApiV2ProxyHttpServletRequest.java index bdf11b819..c46abce22 100644 --- a/aws-serverless-java-container-core/src/main/java/com/amazonaws/serverless/proxy/internal/servlet/AwsHttpApiV2ProxyHttpServletRequest.java +++ b/aws-serverless-java-container-core/src/main/java/com/amazonaws/serverless/proxy/internal/servlet/AwsHttpApiV2ProxyHttpServletRequest.java @@ -357,7 +357,7 @@ public String getServerName() { if (headers != null && headers.containsKey(HOST_HEADER_NAME)) { String hostHeader = headers.getFirst(HOST_HEADER_NAME); - if (SecurityUtils.isValidHost(hostHeader, request.getRequestContext().getApiId(), region)) { + if (SecurityUtils.isValidHost(hostHeader, request.getRequestContext().getApiId(), request.getRequestContext().getElb(), region)) { return hostHeader; } } diff --git a/aws-serverless-java-container-core/src/main/java/com/amazonaws/serverless/proxy/internal/servlet/AwsProxyHttpServletRequest.java b/aws-serverless-java-container-core/src/main/java/com/amazonaws/serverless/proxy/internal/servlet/AwsProxyHttpServletRequest.java index a4ee15150..798ca82a4 100644 --- a/aws-serverless-java-container-core/src/main/java/com/amazonaws/serverless/proxy/internal/servlet/AwsProxyHttpServletRequest.java +++ b/aws-serverless-java-container-core/src/main/java/com/amazonaws/serverless/proxy/internal/servlet/AwsProxyHttpServletRequest.java @@ -405,7 +405,7 @@ public String getServerName() { if (request.getMultiValueHeaders() != null && request.getMultiValueHeaders().containsKey(HOST_HEADER_NAME)) { String hostHeader = request.getMultiValueHeaders().getFirst(HOST_HEADER_NAME); - if (SecurityUtils.isValidHost(hostHeader, request.getRequestContext().getApiId(), region)) { + if (SecurityUtils.isValidHost(hostHeader, request.getRequestContext().getApiId(), request.getRequestContext().getElb(), region)) { return hostHeader; } } diff --git a/aws-serverless-java-container-core/src/test/java/com/amazonaws/serverless/proxy/internal/servlet/AwsProxyHttpServletRequestTest.java b/aws-serverless-java-container-core/src/test/java/com/amazonaws/serverless/proxy/internal/servlet/AwsProxyHttpServletRequestTest.java index 855fae403..76507508b 100644 --- a/aws-serverless-java-container-core/src/test/java/com/amazonaws/serverless/proxy/internal/servlet/AwsProxyHttpServletRequestTest.java +++ b/aws-serverless-java-container-core/src/test/java/com/amazonaws/serverless/proxy/internal/servlet/AwsProxyHttpServletRequestTest.java @@ -5,6 +5,7 @@ import com.amazonaws.serverless.proxy.internal.testutils.AwsProxyRequestBuilder; import com.amazonaws.services.lambda.runtime.Context; +import org.junit.jupiter.api.Test; import org.junit.jupiter.params.ParameterizedTest; import org.junit.jupiter.params.provider.MethodSource; @@ -640,6 +641,16 @@ void serverName_hostHeader_returnsHostHeaderOnly(String type) { assertEquals("testapi.com", serverName); } + @Test + void serverName_albHostHeader_returnsHostHeader() { + initAwsProxyHttpServletRequestTest("ALB"); + AwsProxyRequestBuilder proxyReq = new AwsProxyRequestBuilder("/test", "GET") + .header(HttpHeaders.HOST, "testapi.us-east-1.elb.amazonaws.com"); + HttpServletRequest servletReq = getRequest(proxyReq, null, null); + String serverName = servletReq.getServerName(); + assertEquals("testapi.us-east-1.elb.amazonaws.com", serverName); + } + private AwsProxyRequestBuilder getRequestWithHeaders() { return new AwsProxyRequestBuilder("/hello", "GET") .header(CUSTOM_HEADER_KEY, CUSTOM_HEADER_VALUE)