Skip to content

Commit

Permalink
feat(core): set oidc issuer to custom domain
Browse files Browse the repository at this point in the history
  • Loading branch information
wangsijie committed Mar 18, 2024
1 parent 58885b3 commit 3456273
Show file tree
Hide file tree
Showing 6 changed files with 53 additions and 19 deletions.
11 changes: 8 additions & 3 deletions packages/core/src/app/init.ts
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ import type Koa from 'koa';
import { EnvSet } from '#src/env-set/index.js';
import { TenantNotFoundError, tenantPool } from '#src/tenants/index.js';
import { consoleLog } from '#src/utils/console.js';
import { getTenantId } from '#src/utils/tenant.js';
import { getTenantId, getTenantIdFromCustomDomain } from '#src/utils/tenant.js';

const logListening = (type: 'core' | 'admin' = 'core') => {
const urlSet = type === 'core' ? EnvSet.values.urlSet : EnvSet.values.adminUrlSet;
Expand All @@ -29,15 +29,20 @@ export default async function initApp(app: Koa): Promise<void> {
return next();
}

const tenantId = await getTenantId(ctx.URL);
const tenantIdFromCustomDomain = await getTenantIdFromCustomDomain(ctx.URL);
const tenantId = tenantIdFromCustomDomain ?? (await getTenantId(ctx.URL, true));

if (!tenantId) {
ctx.status = 404;

return next();
}

const tenant = await trySafe(tenantPool.get(tenantId), (error) => {
// If the request is a custom domain of the tenant, use the custom endpoint to build "OIDC issuer"
// otherwise, build from the default endpoint (subdomain).
const customEndpoint = tenantIdFromCustomDomain ? ctx.URL.origin : undefined;

const tenant = await trySafe(tenantPool.get(tenantId, customEndpoint), (error) => {
ctx.status = error instanceof TenantNotFoundError ? 404 : 500;
void appInsights.trackException(error);
});
Expand Down
6 changes: 4 additions & 2 deletions packages/core/src/env-set/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ export class EnvSet {
return this.#oidc;
}

async load() {
async load(customDomain?: string) {
const pool = await createPoolByEnv(
this.databaseUrl,
EnvSet.values.isUnitTest,
Expand All @@ -77,7 +77,9 @@ export class EnvSet {
});

const oidcConfigs = await getOidcConfigs();
const endpoint = getTenantEndpoint(this.tenantId, EnvSet.values);
const endpoint = customDomain
? new URL(customDomain)
: getTenantEndpoint(this.tenantId, EnvSet.values);
this.#oidc = await loadOidcValues(appendPath(endpoint, '/oidc').href, oidcConfigs);
}

Expand Down
5 changes: 3 additions & 2 deletions packages/core/src/tenants/Tenant.ts
Original file line number Diff line number Diff line change
Expand Up @@ -36,10 +36,11 @@ import type TenantContext from './TenantContext.js';
import { getTenantDatabaseDsn } from './utils.js';

export default class Tenant implements TenantContext {
static async create(id: string, redisCache: RedisCache): Promise<Tenant> {
static async create(id: string, redisCache: RedisCache, customDomain?: string): Promise<Tenant> {
// Treat the default database URL as the management URL
const envSet = new EnvSet(id, await getTenantDatabaseDsn(id));
await envSet.load();
// Custom endpoint is used for building OIDC issuer URL when the request is a custom domain
await envSet.load(customDomain);

return new Tenant(envSet, id, new WellKnownCache(id, redisCache));
}
Expand Down
11 changes: 6 additions & 5 deletions packages/core/src/tenants/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,9 @@ export class TenantPool {
},
});

async get(tenantId: string): Promise<Tenant> {
const tenantPromise = this.cache.get(tenantId);
async get(tenantId: string, customDomain?: string): Promise<Tenant> {
const cacheKey = `${tenantId}-${customDomain ?? 'default'}`;
const tenantPromise = this.cache.get(cacheKey);

if (tenantPromise) {
const tenant = await tenantPromise;
Expand All @@ -27,9 +28,9 @@ export class TenantPool {
// Otherwise, create a new tenant instance and store in LRU cache, using the code below.
}

consoleLog.info('Init tenant:', tenantId);
const newTenantPromise = Tenant.create(tenantId, redisCache);
this.cache.set(tenantId, newTenantPromise);
consoleLog.info('Init tenant:', tenantId, customDomain);
const newTenantPromise = Tenant.create(tenantId, redisCache, customDomain);
this.cache.set(cacheKey, newTenantPromise);

return newTenantPromise;
}
Expand Down
10 changes: 10 additions & 0 deletions packages/core/src/utils/tenant.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -153,4 +153,14 @@ describe('getTenantId()', () => {
findActiveDomain.mockResolvedValueOnce({ domain: 'logto.mock.com', tenantId: 'mock' });
await expect(getTenantId(new URL('https://logto.mock.com'))).resolves.toBe('mock');
});

it('should skip custom domain searching', async () => {
process.env = {
...backupEnv,
ENDPOINT: 'https://foo.*.logto.mock/app',
NODE_ENV: 'production',
};
findActiveDomain.mockResolvedValueOnce({ domain: 'logto.mock.com', tenantId: 'mock' });
await expect(getTenantId(new URL('https://logto.mock.com'), true)).resolves.toBeUndefined();
});
});
29 changes: 22 additions & 7 deletions packages/core/src/utils/tenant.ts
Original file line number Diff line number Diff line change
Expand Up @@ -53,17 +53,20 @@ export const clearCustomDomainCache = async (url: URL | string) => {
await trySafe(async () => redisCache.delete(getDomainCacheKey(url)));
};

const getTenantIdFromCustomDomain = async (
/**
* Get tenant ID from the custom domain URL.
*/
export const getTenantIdFromCustomDomain = async (
url: URL,
pool: CommonQueryMethods
pool?: CommonQueryMethods
): Promise<string | undefined> => {
const cachedValue = await trySafe(async () => redisCache.get(getDomainCacheKey(url)));

if (cachedValue) {
return cachedValue;
}

const { findActiveDomain } = createDomainsQueries(pool);
const { findActiveDomain } = createDomainsQueries(pool ?? (await EnvSet.sharedPool));

const domain = await findActiveDomain(url.hostname);

Expand All @@ -74,7 +77,17 @@ const getTenantIdFromCustomDomain = async (
return domain?.tenantId;
};

export const getTenantId = async (url: URL) => {
/**
* Get tenant ID from the current request's URL.
*
* @param url The current request's URL
* @param skipCustomDomain Indicating whether to skip looking for custom domain
* @returns tenantId or undefined
*/
export const getTenantId = async (
url: URL,
skipCustomDomain?: boolean
): Promise<string | undefined> => {
const {
values: {
isMultiTenancy,
Expand Down Expand Up @@ -107,10 +120,12 @@ export const getTenantId = async (url: URL) => {
return matchPathBasedTenantId(urlSet, url);
}

const customDomainTenantId = await getTenantIdFromCustomDomain(url, pool);
if (!skipCustomDomain) {
const customDomainTenantId = await getTenantIdFromCustomDomain(url, pool);

if (customDomainTenantId) {
return customDomainTenantId;
if (customDomainTenantId) {
return customDomainTenantId;
}
}

return matchDomainBasedTenantId(urlSet.endpoint, url);
Expand Down

0 comments on commit 3456273

Please sign in to comment.