diff --git a/packages/credential-provider-ini/src/fromIni.ts b/packages/credential-provider-ini/src/fromIni.ts index 9fcb65e376db..160e4fbe08c5 100644 --- a/packages/credential-provider-ini/src/fromIni.ts +++ b/packages/credential-provider-ini/src/fromIni.ts @@ -39,7 +39,8 @@ export interface FromIniInit extends SourceProfileInit, CredentialProviderOption roleAssumerWithWebIdentity?: (params: AssumeRoleWithWebIdentityParams) => Promise; /** - * STSClientConfig to be used for creating STS Client for assuming role. + * STSClientConfig or SSOClientConfig to be used for creating inner client + * for auth operations. * @internal */ clientConfig?: any; diff --git a/packages/credential-provider-ini/src/resolveSsoCredentials.spec.ts b/packages/credential-provider-ini/src/resolveSsoCredentials.spec.ts index 2953cb51fce0..e9fb9f38475f 100644 --- a/packages/credential-provider-ini/src/resolveSsoCredentials.spec.ts +++ b/packages/credential-provider-ini/src/resolveSsoCredentials.spec.ts @@ -56,4 +56,39 @@ describe(resolveSsoCredentials.name, () => { profile: mockProfileName, }); }); + + it("passes through clientConfig and parentClientConfig to the fromSSO provider", async () => { + const mockProfileName = "mockProfileName"; + const mockCreds: AwsCredentialIdentity = { + accessKeyId: "mockAccessKeyId", + secretAccessKey: "mockSecretAccessKey", + }; + const requestHandler = vi.fn(); + const logger = vi.fn(); + + vi.mocked(fromSSO).mockReturnValue(() => Promise.resolve(mockCreds)); + + const receivedCreds = await resolveSsoCredentials( + mockProfileName, + {}, + { + clientConfig: { + requestHandler, + }, + parentClientConfig: { + logger, + }, + } + ); + expect(receivedCreds).toStrictEqual(mockCreds); + expect(fromSSO).toHaveBeenCalledWith({ + profile: mockProfileName, + clientConfig: { + requestHandler, + }, + parentClientConfig: { + logger, + }, + }); + }); }); diff --git a/packages/credential-provider-ini/src/resolveSsoCredentials.ts b/packages/credential-provider-ini/src/resolveSsoCredentials.ts index e26ac7b5726d..9e0a9d701855 100644 --- a/packages/credential-provider-ini/src/resolveSsoCredentials.ts +++ b/packages/credential-provider-ini/src/resolveSsoCredentials.ts @@ -1,20 +1,19 @@ import { setCredentialFeature } from "@aws-sdk/core/client"; import type { SsoProfile } from "@aws-sdk/credential-provider-sso"; -import type { CredentialProviderOptions } from "@aws-sdk/types"; import type { IniSection, Profile } from "@smithy/types"; +import type { FromIniInit } from "./fromIni"; + /** * @internal */ -export const resolveSsoCredentials = async ( - profile: string, - profileData: IniSection, - options: CredentialProviderOptions = {} -) => { +export const resolveSsoCredentials = async (profile: string, profileData: IniSection, options: FromIniInit = {}) => { const { fromSSO } = await import("@aws-sdk/credential-provider-sso"); return fromSSO({ profile, logger: options.logger, + parentClientConfig: options.parentClientConfig, + clientConfig: options.clientConfig, })().then((creds) => { if (profileData.sso_session) { return setCredentialFeature(creds, "CREDENTIALS_PROFILE_SSO", "r"); diff --git a/packages/credential-provider-sso/src/fromSSO.ts b/packages/credential-provider-sso/src/fromSSO.ts index 0e1d3ab56aae..1960b4855b63 100644 --- a/packages/credential-provider-sso/src/fromSSO.ts +++ b/packages/credential-provider-sso/src/fromSSO.ts @@ -133,6 +133,7 @@ export const fromSSO = ssoRoleName: sso_role_name, ssoClient: ssoClient, clientConfig: init.clientConfig, + parentClientConfig: init.parentClientConfig, profile: profileName, }); } else if (!ssoStartUrl || !ssoAccountId || !ssoRegion || !ssoRoleName) { @@ -150,6 +151,7 @@ export const fromSSO = ssoRoleName, ssoClient, clientConfig: init.clientConfig, + parentClientConfig: init.parentClientConfig, profile: profileName, }); } diff --git a/packages/credential-provider-sso/src/resolveSSOCredentials.ts b/packages/credential-provider-sso/src/resolveSSOCredentials.ts index cb80be5ab0cb..ecd5a06e5d0c 100644 --- a/packages/credential-provider-sso/src/resolveSSOCredentials.ts +++ b/packages/credential-provider-sso/src/resolveSSOCredentials.ts @@ -20,6 +20,7 @@ export const resolveSSOCredentials = async ({ ssoRoleName, ssoClient, clientConfig, + parentClientConfig, profile, logger, }: FromSSOInit & SsoCredentialsParameters): Promise => { @@ -65,6 +66,7 @@ export const resolveSSOCredentials = async ({ ssoClient || new SSOClient( Object.assign({}, clientConfig ?? {}, { + logger: clientConfig?.logger ?? parentClientConfig?.logger, region: clientConfig?.region ?? ssoRegion, }) ); diff --git a/packages/token-providers/src/fromSso.spec.ts b/packages/token-providers/src/fromSso.spec.ts index 84418060d8ca..c688ce544b07 100644 --- a/packages/token-providers/src/fromSso.spec.ts +++ b/packages/token-providers/src/fromSso.spec.ts @@ -48,6 +48,7 @@ describe(fromSso.name, () => { accessToken: "mockNewAccessToken", expiresIn: 3600, refreshToken: "mockNewRefreshToken", + $metadata: {}, }; const mockNewToken = { token: mockNewTokenFromService.accessToken, @@ -166,7 +167,7 @@ describe(fromSso.name, () => { const { fromSso } = await import("./fromSso"); await expect(fromSso(mockInit)()).resolves.toStrictEqual(mockNewToken); expect(getNewSsoOidcToken).toHaveBeenCalledTimes(1); - expect(getNewSsoOidcToken).toHaveBeenCalledWith(mockSsoToken, mockSsoSession.sso_region); + expect(getNewSsoOidcToken).toHaveBeenCalledWith(mockSsoToken, mockSsoSession.sso_region, mockInit); // Simulate token expiration. const ssoTokenExpiryError = new TokenProviderError(`SSO Token is expired. ${REFRESH_MESSAGE}`, false); @@ -182,7 +183,7 @@ describe(fromSso.name, () => { const { fromSso } = await import("./fromSso"); await expect(fromSso(mockInit)()).resolves.toStrictEqual(mockNewToken); expect(getNewSsoOidcToken).toHaveBeenCalledTimes(1); - expect(getNewSsoOidcToken).toHaveBeenCalledWith(mockSsoToken, mockSsoSession.sso_region); + expect(getNewSsoOidcToken).toHaveBeenCalledWith(mockSsoToken, mockSsoSession.sso_region, mockInit); // Return a valid token for second call. const mockValidSsoToken = { @@ -230,7 +231,11 @@ describe(fromSso.name, () => { token: mockValidSsoTokenInExpiryWindow.accessToken, expiration: new Date(mockValidSsoTokenInExpiryWindow.expiresAt), }); - expect(getNewSsoOidcToken).toHaveBeenCalledWith(mockValidSsoTokenInExpiryWindow, mockSsoSession.sso_region); + expect(getNewSsoOidcToken).toHaveBeenCalledWith( + mockValidSsoTokenInExpiryWindow, + mockSsoSession.sso_region, + mockInit + ); }; const throwErrorExpiredTokenTest = async (fromSsoImpl: typeof fromSso) => { @@ -239,7 +244,7 @@ describe(fromSso.name, () => { throw ssoTokenExpiryError; }); await expect(fromSsoImpl(mockInit)()).rejects.toStrictEqual(ssoTokenExpiryError); - expect(getNewSsoOidcToken).toHaveBeenCalledWith(mockSsoToken, mockSsoSession.sso_region); + expect(getNewSsoOidcToken).toHaveBeenCalledWith(mockSsoToken, mockSsoSession.sso_region, mockInit); }; afterEach(() => { @@ -285,7 +290,7 @@ describe(fromSso.name, () => { const { fromSso } = await import("./fromSso"); await expect(fromSso(mockInit)()).resolves.toStrictEqual(mockNewToken); expect(getNewSsoOidcToken).toHaveBeenCalledTimes(1); - expect(getNewSsoOidcToken).toHaveBeenCalledWith(mockSsoToken, mockSsoSession.sso_region); + expect(getNewSsoOidcToken).toHaveBeenCalledWith(mockSsoToken, mockSsoSession.sso_region, mockInit); expect(writeSSOTokenToFile).toHaveBeenCalledWith(mockSsoSessionName, { ...mockSsoToken, diff --git a/packages/token-providers/src/fromSso.ts b/packages/token-providers/src/fromSso.ts index a65e6febd3ef..35957281f1cb 100644 --- a/packages/token-providers/src/fromSso.ts +++ b/packages/token-providers/src/fromSso.ts @@ -20,7 +20,12 @@ import { writeSSOTokenToFile } from "./writeSSOTokenToFile"; */ const lastRefreshAttemptTime = new Date(0); -export interface FromSsoInit extends SourceProfileInit, CredentialProviderOptions {} +export interface FromSsoInit extends SourceProfileInit, CredentialProviderOptions { + /** + * @see SSOOIDCClientConfig in \@aws-sdk/client-sso-oidc. + */ + clientConfig?: any; +} /** * Creates a token provider that will read from SSO token cache or ssoOidc.createToken() call. @@ -101,7 +106,7 @@ export const fromSso = try { lastRefreshAttemptTime.setTime(Date.now()); - const newSsoOidcToken = await getNewSsoOidcToken(ssoToken, ssoRegion); + const newSsoOidcToken = await getNewSsoOidcToken(ssoToken, ssoRegion, init); validateTokenKey("accessToken", newSsoOidcToken.accessToken); validateTokenKey("expiresIn", newSsoOidcToken.expiresIn); const newTokenExpiration = new Date(Date.now() + newSsoOidcToken.expiresIn! * 1000); diff --git a/packages/token-providers/src/getNewSsoOidcToken.spec.ts b/packages/token-providers/src/getNewSsoOidcToken.spec.ts index b84c51c2bf32..e12ab1302880 100644 --- a/packages/token-providers/src/getNewSsoOidcToken.spec.ts +++ b/packages/token-providers/src/getNewSsoOidcToken.spec.ts @@ -49,7 +49,7 @@ describe(getNewSsoOidcToken.name, () => { } catch (error) { expect(error).toStrictEqual(mockError); } - expect(getSsoOidcClient).toHaveBeenCalledWith(mockSsoRegion); + expect(getSsoOidcClient).toHaveBeenCalledWith(mockSsoRegion, {}); expect(mockSend).not.toHaveBeenCalled(); expect(CreateTokenCommand).not.toHaveBeenCalled(); }); @@ -63,7 +63,7 @@ describe(getNewSsoOidcToken.name, () => { } catch (error) { expect(error).toStrictEqual(mockError); } - expect(getSsoOidcClient).toHaveBeenCalledWith(mockSsoRegion); + expect(getSsoOidcClient).toHaveBeenCalledWith(mockSsoRegion, {}); expect(mockSendWithError).toHaveBeenCalledWith(mockCreateTokenArgs); expect(CreateTokenCommand).toHaveBeenCalledWith(mockCreateTokenArgs); }); @@ -78,7 +78,7 @@ describe(getNewSsoOidcToken.name, () => { } catch (error) { expect(error).toStrictEqual(mockError); } - expect(getSsoOidcClient).toHaveBeenCalledWith(mockSsoRegion); + expect(getSsoOidcClient).toHaveBeenCalledWith(mockSsoRegion, {}); expect(mockSend).not.toHaveBeenCalled(); expect(CreateTokenCommand).toHaveBeenCalledWith(mockCreateTokenArgs); }); @@ -90,6 +90,6 @@ describe(getNewSsoOidcToken.name, () => { expect(newSsoOidcToken).toEqual(mockNewToken as any); expect(CreateTokenCommand).toHaveBeenCalledWith(mockCreateTokenArgs); expect(mockSend).toHaveBeenCalledWith(mockCreateTokenArgs); - expect(getSsoOidcClient).toHaveBeenCalledWith(mockSsoRegion); + expect(getSsoOidcClient).toHaveBeenCalledWith(mockSsoRegion, {}); }); }); diff --git a/packages/token-providers/src/getNewSsoOidcToken.ts b/packages/token-providers/src/getNewSsoOidcToken.ts index 91096f79f171..4458c728af51 100644 --- a/packages/token-providers/src/getNewSsoOidcToken.ts +++ b/packages/token-providers/src/getNewSsoOidcToken.ts @@ -1,16 +1,17 @@ import { SSOToken } from "@smithy/shared-ini-file-loader"; +import { FromSsoInit } from "./fromSso"; import { getSsoOidcClient } from "./getSsoOidcClient"; /** * Returns a new SSO OIDC token from ssoOids.createToken() API call. * @internal */ -export const getNewSsoOidcToken = async (ssoToken: SSOToken, ssoRegion: string) => { +export const getNewSsoOidcToken = async (ssoToken: SSOToken, ssoRegion: string, init: FromSsoInit = {}) => { // @ts-ignore Cannot find module '@aws-sdk/client-sso-oidc' const { CreateTokenCommand } = await import("@aws-sdk/client-sso-oidc"); - const ssoOidcClient = await getSsoOidcClient(ssoRegion); + const ssoOidcClient = await getSsoOidcClient(ssoRegion, init); return ssoOidcClient.send( new CreateTokenCommand({ clientId: ssoToken.clientId, diff --git a/packages/token-providers/src/getSsoOidcClient.spec.ts b/packages/token-providers/src/getSsoOidcClient.spec.ts index bbaf420c3e6b..62750502b55e 100644 --- a/packages/token-providers/src/getSsoOidcClient.spec.ts +++ b/packages/token-providers/src/getSsoOidcClient.spec.ts @@ -5,6 +5,9 @@ vi.mock("@aws-sdk/client-sso-oidc"); describe("getSsoOidcClient", () => { const mockSsoRegion = "mockSsoRegion"; + const mockRequestHandler = { + protocol: "http", + }; const getMockClient = (region: string) => ({ region }); beforeEach(() => { @@ -22,24 +25,22 @@ describe("getSsoOidcClient", () => { expect(SSOOIDCClient).toHaveBeenCalledTimes(1); }); - it("returns SSOOIDC client from hash if already created", async () => { - const { getSsoOidcClient } = await import("./getSsoOidcClient"); - expect(await getSsoOidcClient(mockSsoRegion)).toEqual(getMockClient(mockSsoRegion) as any); - expect(SSOOIDCClient).toHaveBeenCalledTimes(1); - expect(await getSsoOidcClient(mockSsoRegion)).toEqual(getMockClient(mockSsoRegion) as any); - expect(SSOOIDCClient).toHaveBeenCalledTimes(1); - }); - - it("creates new SSOOIDC client per region", async () => { + it("passes through clientConfig and parentClientConfig.logger", async () => { const { getSsoOidcClient } = await import("./getSsoOidcClient"); const mockSsoRegion1 = `${mockSsoRegion}1`; - expect(await getSsoOidcClient(mockSsoRegion1)).toEqual(getMockClient(mockSsoRegion1) as any); + expect( + await getSsoOidcClient(mockSsoRegion1, { + clientConfig: { requestHandler: mockRequestHandler }, + parentClientConfig: { logger: console }, + }) + ).toEqual({ + region: mockSsoRegion1, + } as any); expect(SSOOIDCClient).toHaveBeenCalledTimes(1); - expect(SSOOIDCClient).toHaveBeenCalledWith({ region: mockSsoRegion1 }); - - const mockSsoRegion2 = `${mockSsoRegion}2`; - expect(await getSsoOidcClient(mockSsoRegion2)).toEqual(getMockClient(mockSsoRegion2) as any); - expect(SSOOIDCClient).toHaveBeenCalledTimes(2); - expect(SSOOIDCClient).toHaveBeenNthCalledWith(2, { region: mockSsoRegion2 }); + expect(SSOOIDCClient).toHaveBeenCalledWith({ + region: mockSsoRegion1, + requestHandler: mockRequestHandler, + logger: console, + }); }); }); diff --git a/packages/token-providers/src/getSsoOidcClient.ts b/packages/token-providers/src/getSsoOidcClient.ts index ec366ffcd73a..00993c5fd045 100644 --- a/packages/token-providers/src/getSsoOidcClient.ts +++ b/packages/token-providers/src/getSsoOidcClient.ts @@ -1,23 +1,18 @@ -const ssoOidcClientsHash: Record = {}; +import { FromSsoInit } from "./fromSso"; /** - * Returns a SSOOIDC client for the given region. If the client has already been created, - * it will be returned from the hash. + * Returns a SSOOIDC client for the given region. * @internal */ -export const getSsoOidcClient = async (ssoRegion: string) => { +export const getSsoOidcClient = async (ssoRegion: string, init: FromSsoInit = {}) => { // @ts-ignore Cannot find module '@aws-sdk/client-sso-oidc' const { SSOOIDCClient } = await import("@aws-sdk/client-sso-oidc"); - // return ssoOidsClient if already created. - if (ssoOidcClientsHash[ssoRegion]) { - return ssoOidcClientsHash[ssoRegion]; - } - - // Create new SSOOIDC client, and store is in hash. - // If we need to support configuration of SsoOidc client in future through code, - // the provision to pass region from client configuration needs to be added. - const ssoOidcClient = new SSOOIDCClient({ region: ssoRegion }); - ssoOidcClientsHash[ssoRegion] = ssoOidcClient; + const ssoOidcClient = new SSOOIDCClient( + Object.assign({}, init.clientConfig ?? {}, { + region: ssoRegion ?? init.clientConfig?.region, + logger: init.clientConfig?.logger ?? init.parentClientConfig?.logger, + }) + ); return ssoOidcClient; };