diff --git a/src/main/java/net/snowflake/client/core/SFLoginInput.java b/src/main/java/net/snowflake/client/core/SFLoginInput.java index 3d53bf104..13a52604c 100644 --- a/src/main/java/net/snowflake/client/core/SFLoginInput.java +++ b/src/main/java/net/snowflake/client/core/SFLoginInput.java @@ -49,6 +49,7 @@ public class SFLoginInput { private String inFlightCtx; // Opaque string sent for Snowsight account activation private boolean disableConsoleLogin = true; + private boolean disableSamlURLCheck = false; // Additional headers to add for Snowsight. Map additionalHttpHeadersForSnowsight; @@ -378,6 +379,15 @@ SFLoginInput setInFlightCtx(String inFlightCtx) { return this; } + boolean getDisableSamlURLCheck() { + return disableSamlURLCheck; + } + + SFLoginInput setDisableSamlURLCheck(boolean disableSamlURLCheck) { + this.disableSamlURLCheck = disableSamlURLCheck; + return this; + } + Map getAdditionalHttpHeadersForSnowsight() { return additionalHttpHeadersForSnowsight; } diff --git a/src/main/java/net/snowflake/client/core/SFSession.java b/src/main/java/net/snowflake/client/core/SFSession.java index 5f653019d..3d0900940 100644 --- a/src/main/java/net/snowflake/client/core/SFSession.java +++ b/src/main/java/net/snowflake/client/core/SFSession.java @@ -608,7 +608,12 @@ public synchronized void open() throws SFException, SnowflakeSQLException { connectionPropertiesMap.get(SFSessionProperty.DISABLE_CONSOLE_LOGIN) != null ? getBooleanValue( connectionPropertiesMap.get(SFSessionProperty.DISABLE_CONSOLE_LOGIN)) - : true); + : true) + .setDisableSamlURLCheck( + connectionPropertiesMap.get(SFSessionProperty.DISABLE_SAML_URL_CHECK) != null + ? getBooleanValue( + connectionPropertiesMap.get(SFSessionProperty.DISABLE_SAML_URL_CHECK)) + : false); // Enable or disable OOB telemetry based on connection parameter. Default is disabled. // The value may still change later when session parameters from the server are read. diff --git a/src/main/java/net/snowflake/client/core/SFSessionProperty.java b/src/main/java/net/snowflake/client/core/SFSessionProperty.java index 0ca91809c..359448d24 100644 --- a/src/main/java/net/snowflake/client/core/SFSessionProperty.java +++ b/src/main/java/net/snowflake/client/core/SFSessionProperty.java @@ -82,7 +82,9 @@ public enum SFSessionProperty { DISABLE_GCS_DEFAULT_CREDENTIALS("disableGcsDefaultCredentials", false, Boolean.class), - JDBC_ARROW_TREAT_DECIMAL_AS_INT("JDBC_ARROW_TREAT_DECIMAL_AS_INT", false, Boolean.class); + JDBC_ARROW_TREAT_DECIMAL_AS_INT("JDBC_ARROW_TREAT_DECIMAL_AS_INT", false, Boolean.class), + + DISABLE_SAML_URL_CHECK("disableSamlURLCheck", false, Boolean.class); // property key in string private String propertyKey; diff --git a/src/main/java/net/snowflake/client/core/SessionUtil.java b/src/main/java/net/snowflake/client/core/SessionUtil.java index a3421e841..ec856112d 100644 --- a/src/main/java/net/snowflake/client/core/SessionUtil.java +++ b/src/main/java/net/snowflake/client/core/SessionUtil.java @@ -1154,6 +1154,16 @@ private static String federatedFlowStep4( loginInput.getHttpClientSettingsKey()); // step 5 + validateSAML(responseHtml, loginInput); + } catch (IOException | URISyntaxException ex) { + handleFederatedFlowError(loginInput, ex); + } + return responseHtml; + } + + private static void validateSAML(String responseHtml, SFLoginInput loginInput) + throws SnowflakeSQLException, MalformedURLException { + if (!loginInput.getDisableSamlURLCheck()) { String postBackUrl = getPostBackUrlFromHTML(responseHtml); if (!isPrefixEqual(postBackUrl, loginInput.getServerUrl())) { URL idpDestinationUrl = new URL(postBackUrl); @@ -1167,18 +1177,13 @@ private static String federatedFlowStep4( clientDestinationHostName, idpDestinationHostName); - // Session is in process of getting created, so exception constructor takes in null session - // value + // Session is in process of getting created, so exception constructor takes in null throw new SnowflakeSQLLoggedException( null, ErrorCode.IDP_INCORRECT_DESTINATION.getMessageCode(), - SqlState.SQLCLIENT_UNABLE_TO_ESTABLISH_SQLCONNECTION - /* session = */ ); + SqlState.SQLCLIENT_UNABLE_TO_ESTABLISH_SQLCONNECTION); } - } catch (IOException | URISyntaxException ex) { - handleFederatedFlowError(loginInput, ex); } - return responseHtml; } /** diff --git a/src/test/java/net/snowflake/client/core/SessionUtilLatestIT.java b/src/test/java/net/snowflake/client/core/SessionUtilLatestIT.java index dd6d5e7bd..f936ee616 100644 --- a/src/test/java/net/snowflake/client/core/SessionUtilLatestIT.java +++ b/src/test/java/net/snowflake/client/core/SessionUtilLatestIT.java @@ -465,4 +465,171 @@ public void testOktaAuthRetry() throws Throwable { SessionUtil.openSession(loginInput, connectionPropertiesMap, "ALL"); } } + + /** + * Tests the disableSamlURLCheck. If the disableSamlUrl is provided to the login input with true, + * the driver will skip checking the format of the saml URL response. This latest test will work + * with jdbc > 3.16.0 + * + * @throws Throwable + */ + @Test + public void testOktaDisableSamlUrlCheck() throws Throwable { + SFLoginInput loginInput = createOktaLoginInput(); + loginInput.setDisableSamlURLCheck(true); + Map connectionPropertiesMap = initConnectionPropertiesMap(); + try (MockedStatic mockedHttpUtil = mockStatic(HttpUtil.class)) { + mockedHttpUtil + .when( + () -> + HttpUtil.executeGeneralRequest( + Mockito.any(HttpPost.class), + Mockito.anyInt(), + Mockito.anyInt(), + Mockito.anyInt(), + Mockito.anyInt(), + Mockito.nullable(HttpClientSettingsKey.class))) + .thenReturn( + "{\"data\":{\"tokenUrl\":\"https://testauth.okta.com/api/v1/authn\"," + + "\"ssoUrl\":\"https://testauth.okta.com/app/snowflake/abcdefghijklmnopqrstuvwxyz/sso/saml\"," + + "\"proofKey\":null},\"code\":null,\"message\":null,\"success\":true}"); + + mockedHttpUtil + .when( + () -> + HttpUtil.executeRequestWithoutCookies( + Mockito.any(HttpRequestBase.class), + Mockito.anyInt(), + Mockito.anyInt(), + Mockito.anyInt(), + Mockito.anyInt(), + Mockito.anyInt(), + Mockito.nullable(AtomicBoolean.class), + Mockito.nullable(HttpClientSettingsKey.class))) + .thenReturn( + "{\"expiresAt\":\"2023-10-13T19:18:09.000Z\",\"status\":\"SUCCESS\",\"sessionToken\":\"testsessiontoken\"}"); + + mockedHttpUtil + .when( + () -> + HttpUtil.executeGeneralRequest( + Mockito.any(HttpGet.class), + Mockito.anyInt(), + Mockito.anyInt(), + Mockito.anyInt(), + Mockito.anyInt(), + Mockito.nullable(HttpClientSettingsKey.class))) + .thenReturn("
"); + + SessionUtil.openSession(loginInput, connectionPropertiesMap, "ALL"); + } + } + + @Test + public void testInvalidOktaSamlFormat() throws Throwable { + SFLoginInput loginInput = createOktaLoginInput(); + Map connectionPropertiesMap = initConnectionPropertiesMap(); + try (MockedStatic mockedHttpUtil = mockStatic(HttpUtil.class)) { + mockedHttpUtil + .when( + () -> + HttpUtil.executeGeneralRequest( + Mockito.any(HttpPost.class), + Mockito.anyInt(), + Mockito.anyInt(), + Mockito.anyInt(), + Mockito.anyInt(), + Mockito.nullable(HttpClientSettingsKey.class))) + .thenReturn( + "{\"data\":{\"tokenUrl\":\"https://testauth.okta.com/api/v1/authn\"," + + "\"ssoUrl\":\"https://testauth.okta.com/app/snowflake/abcdefghijklmnopqrstuvwxyz/sso/saml\"," + + "\"proofKey\":null},\"code\":null,\"message\":null,\"success\":true}"); + + mockedHttpUtil + .when( + () -> + HttpUtil.executeRequestWithoutCookies( + Mockito.any(HttpRequestBase.class), + Mockito.anyInt(), + Mockito.anyInt(), + Mockito.anyInt(), + Mockito.anyInt(), + Mockito.anyInt(), + Mockito.nullable(AtomicBoolean.class), + Mockito.nullable(HttpClientSettingsKey.class))) + .thenReturn( + "{\"expiresAt\":\"2023-10-13T19:18:09.000Z\",\"status\":\"SUCCESS\",\"sessionToken\":\"testsessiontoken\"}"); + + mockedHttpUtil + .when( + () -> + HttpUtil.executeGeneralRequest( + Mockito.any(HttpGet.class), + Mockito.anyInt(), + Mockito.anyInt(), + Mockito.anyInt(), + Mockito.anyInt(), + Mockito.nullable(HttpClientSettingsKey.class))) + .thenReturn("
"); + + SessionUtil.openSession(loginInput, connectionPropertiesMap, "ALL"); + fail("Should be failed because of the invalid form"); + } catch (SnowflakeSQLException ex) { + assertEquals((int) ErrorCode.NETWORK_ERROR.getMessageCode(), ex.getErrorCode()); + } + } + + @Test + public void testOktaWithInvalidHostName() throws Throwable { + SFLoginInput loginInput = createOktaLoginInput(); + Map connectionPropertiesMap = initConnectionPropertiesMap(); + try (MockedStatic mockedHttpUtil = mockStatic(HttpUtil.class)) { + mockedHttpUtil + .when( + () -> + HttpUtil.executeGeneralRequest( + Mockito.any(HttpPost.class), + Mockito.anyInt(), + Mockito.anyInt(), + Mockito.anyInt(), + Mockito.anyInt(), + Mockito.nullable(HttpClientSettingsKey.class))) + .thenReturn( + "{\"data\":{\"tokenUrl\":\"https://testauth.okta.com/api/v1/authn\"," + + "\"ssoUrl\":\"https://testauth.okta.com/app/snowflake/abcdefghijklmnopqrstuvwxyz/sso/saml\"," + + "\"proofKey\":null},\"code\":null,\"message\":null,\"success\":true}"); + + mockedHttpUtil + .when( + () -> + HttpUtil.executeRequestWithoutCookies( + Mockito.any(HttpRequestBase.class), + Mockito.anyInt(), + Mockito.anyInt(), + Mockito.anyInt(), + Mockito.anyInt(), + Mockito.anyInt(), + Mockito.nullable(AtomicBoolean.class), + Mockito.nullable(HttpClientSettingsKey.class))) + .thenReturn( + "{\"expiresAt\":\"2023-10-13T19:18:09.000Z\",\"status\":\"SUCCESS\",\"sessionToken\":\"testsessiontoken\"}"); + + mockedHttpUtil + .when( + () -> + HttpUtil.executeGeneralRequest( + Mockito.any(HttpGet.class), + Mockito.anyInt(), + Mockito.anyInt(), + Mockito.anyInt(), + Mockito.anyInt(), + Mockito.nullable(HttpClientSettingsKey.class))) + .thenReturn("
"); + + SessionUtil.openSession(loginInput, connectionPropertiesMap, "ALL"); + fail("Should be failed because of the invalid form"); + } catch (SnowflakeSQLException ex) { + assertEquals((int) ErrorCode.IDP_INCORRECT_DESTINATION.getMessageCode(), ex.getErrorCode()); + } + } }