Skip to content

Commit

Permalink
Integrate VertexAI with Auth
Browse files Browse the repository at this point in the history
  • Loading branch information
dlarocque committed May 6, 2024
1 parent 627b561 commit dddb566
Show file tree
Hide file tree
Showing 6 changed files with 62 additions and 1 deletion.
3 changes: 2 additions & 1 deletion packages/vertexai/src/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -40,8 +40,9 @@ function registerVertex(): void {
container => {
// getImmediate for FirebaseApp will always succeed
const app = container.getProvider('app').getImmediate();
const auth = container.getProvider('auth-internal');
const appCheckProvider = container.getProvider('app-check-internal');
return new VertexAIService(app, appCheckProvider);
return new VertexAIService(app, auth, appCheckProvider);
},
ComponentType.PUBLIC
).setMultipleInstances(true)
Expand Down
5 changes: 5 additions & 0 deletions packages/vertexai/src/models/generative-model.ts
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,11 @@ export class GenerativeModel {
this._apiSettings.getAppCheckToken = () =>
(vertexAI as VertexAIService).appCheck!.getToken();
}

if ((vertexAI as VertexAIService).auth) {
this._apiSettings.getAuthToken = () =>
(vertexAI as VertexAIService).auth!.getToken();
}
}
if (modelParams.model.includes('/')) {
if (modelParams.model.startsWith('models/')) {
Expand Down
37 changes: 37 additions & 0 deletions packages/vertexai/src/requests/request.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,7 @@ describe('request methods', () => {
apiKey: 'key',
project: 'myproject',
location: 'moon',
getAuthToken: () => Promise.resolve({ accessToken: 'authtoken' }),
getAppCheckToken: () => Promise.resolve({ token: 'appchecktoken' })
};
const fakeUrl = new RequestUrl(
Expand Down Expand Up @@ -173,6 +174,42 @@ describe('request methods', () => {
const headers = await getHeaders(fakeUrl);
expect(headers.has('X-Firebase-AppCheck')).to.be.false;
});
it('adds auth token if it exists', async () => {
const headers = await getHeaders(fakeUrl);
expect(headers.get('Authorization')).to.equal('Firebase authtoken');
});
it('ignores auth token header if no auth service', async () => {
const fakeUrl = new RequestUrl(
'models/model-name',
Task.GENERATE_CONTENT,
{
apiKey: 'key',
project: 'myproject',
location: 'moon'
},
true,
{}
);
const headers = await getHeaders(fakeUrl);
expect(headers.has('Authorization')).to.be.false;
});
it('ignores auth token header if returned token was undefined', async () => {
const fakeUrl = new RequestUrl(
'models/model-name',
Task.GENERATE_CONTENT,
{
apiKey: 'key',
project: 'myproject',
location: 'moon',
//@ts-ignore
getAppCheckToken: () => Promise.resolve()
},
true,
{}
);
const headers = await getHeaders(fakeUrl);
expect(headers.has('Authorization')).to.be.false;
});
});
describe('makeRequest', () => {
it('no error', async () => {
Expand Down
8 changes: 8 additions & 0 deletions packages/vertexai/src/requests/request.ts
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,14 @@ export async function getHeaders(url: RequestUrl): Promise<Headers> {
headers.append('X-Firebase-AppCheck', appCheckToken.token);
}
}

if (url.apiSettings.getAuthToken) {
const authToken = await url.apiSettings.getAuthToken();
if (authToken) {
headers.append('Authorization', `Firebase ${authToken.accessToken}`);
}
}

return headers;
}

Expand Down
8 changes: 8 additions & 0 deletions packages/vertexai/src/service.ts
Original file line number Diff line number Diff line change
Expand Up @@ -22,17 +22,25 @@ import {
FirebaseAppCheckInternal
} from '@firebase/app-check-interop-types';
import { Provider } from '@firebase/component';
import {
FirebaseAuthInternal,
FirebaseAuthInternalName
} from '@firebase/auth-interop-types';
import { DEFAULT_LOCATION } from './constants';

export class VertexAIService implements VertexAI, _FirebaseService {
auth: FirebaseAuthInternal | null;
appCheck: FirebaseAppCheckInternal | null;
location: string;

constructor(
public app: FirebaseApp,
authProvider?: Provider<FirebaseAuthInternalName>,
appCheckProvider?: Provider<AppCheckInternalComponentName>
) {
const appCheck = appCheckProvider?.getImmediate({ optional: true });
const auth = authProvider?.getImmediate({ optional: true });
this.auth = auth || null;
this.appCheck = appCheck || null;
// TODO: add in user-set location option when that feature is available
this.location = DEFAULT_LOCATION;
Expand Down
2 changes: 2 additions & 0 deletions packages/vertexai/src/types/internal.ts
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,12 @@
*/

import { AppCheckTokenResult } from '@firebase/app-check-interop-types';
import { FirebaseAuthTokenData } from '@firebase/auth-interop-types';

export interface ApiSettings {
apiKey: string;
project: string;
location: string;
getAuthToken?: () => Promise<FirebaseAuthTokenData | null>;
getAppCheckToken?: () => Promise<AppCheckTokenResult>;
}

0 comments on commit dddb566

Please sign in to comment.