diff --git a/README.md b/README.md index 998f306a..a38cf94f 100644 --- a/README.md +++ b/README.md @@ -268,6 +268,14 @@ The Chrome DevTools MCP server supports the following configuration option: Path to a file to write debug logs to. Set the env variable `DEBUG` to `*` to enable verbose logs. Useful for submitting bug reports. - **Type:** string +- **`--allowedOrigins`** + Semicolon-separated list of origins the browser is allowed to request. If not specified, all origins are allowed (except those in blockedOrigins). Example: https://example.com;https://api.example.com + - **Type:** string + +- **`--blockedOrigins`** + Semicolon-separated list of origins the browser is blocked from requesting. Takes precedence over allowedOrigins. Example: https://ads.example.com;https://tracker.example.com + - **Type:** string + Pass them via the `args` property in the JSON configuration. For example: diff --git a/package-lock.json b/package-lock.json index 8c0ee9d5..7040887d 100644 --- a/package-lock.json +++ b/package-lock.json @@ -13,7 +13,8 @@ "core-js": "3.45.1", "debug": "4.4.3", "puppeteer-core": "24.22.3", - "yargs": "18.0.0" + "yargs": "18.0.0", + "zod": "3.24.1" }, "bin": { "chrome-devtools-mcp": "build/src/index.js" @@ -6191,9 +6192,9 @@ } }, "node_modules/zod": { - "version": "3.24.3", - "resolved": "https://registry.npmjs.org/zod/-/zod-3.24.3.tgz", - "integrity": "sha512-HhY1oqzWCQWuUqvBFnsyrtZRhyPeR7SUGv+C4+MsisMuVfSPx8HpwWqH8tRahSlt6M3PiFAcoeFhZAqIXTxoSg==", + "version": "3.24.1", + "resolved": "https://registry.npmjs.org/zod/-/zod-3.24.1.tgz", + "integrity": "sha512-muH7gBL9sI1nciMZV67X5fTKKBLtwpZ5VBp1vsOQzj1MhrBZ4wlVCm3gedKZWLp0Oyel8sIGfeiz54Su+OVT+A==", "license": "MIT", "funding": { "url": "https://github.com/sponsors/colinhacks" diff --git a/package.json b/package.json index 3f222050..3866b6d8 100644 --- a/package.json +++ b/package.json @@ -41,7 +41,8 @@ "core-js": "3.45.1", "debug": "4.4.3", "puppeteer-core": "24.22.3", - "yargs": "18.0.0" + "yargs": 
/**
 * Turns on request interception for every page that is already open and
 * registers a listener so that pages created later are covered as well.
 */
async #setupRequestInterception() {
  const pages = await this.browser.pages();
  for (const page of pages) {
    await this.#enableRequestInterceptionForPage(page);
  }

  this.browser.on('targetcreated', async target => {
    // Nothing awaits the promise returned by an async event listener, so any
    // rejection (for example from `target.page()`) has to be handled here
    // rather than surfacing as an unhandled rejection.
    try {
      const page = await target.page();
      if (page) {
        await this.#enableRequestInterceptionForPage(page);
      }
    } catch (error) {
      this.logger(
        `Failed to set up request interception for new target: ${error}`,
      );
    }
  });
}

/**
 * Enables request interception on `page` and resolves every intercepted
 * request according to the configured UrlValidator: requests the validator
 * rejects are aborted with `blockedbyclient`, all others are continued.
 *
 * Failures are logged and never thrown.
 * NOTE(review): when `setRequestInterception` fails the page is left without
 * any origin restriction — confirm that best-effort behaviour is intended.
 *
 * @param page - The page whose requests should be validated.
 */
async #enableRequestInterceptionForPage(page: Page) {
  try {
    await page.setRequestInterception(true);
  } catch (error) {
    this.logger(`Failed to enable request interception for page: ${error}`);
    return;
  }

  page.on('request', interceptedRequest => {
    if (interceptedRequest.isInterceptResolutionHandled()) {
      return;
    }

    const url = interceptedRequest.url();
    const allowed = !this.#urlValidator || this.#urlValidator.isAllowed(url);
    // A priority (0) is passed so that other 'request' listeners on the same
    // page can still take part in resolving the request.
    const resolution = allowed
      ? interceptedRequest.continue({}, 0)
      : interceptedRequest.abort('blockedbyclient', 0);
    // The listener is synchronous, so the promise cannot be awaited; attach a
    // handler instead of discarding it so a rejection is logged.
    resolution.catch((error: unknown) => {
      this.logger(`Failed to resolve intercepted request to ${url}: ${error}`);
    });
  });
}

/**
 * Creates a context for `browser` and runs its asynchronous initialisation.
 *
 * @param browser - Browser instance the context operates on.
 * @param logger - Debug logger used by the context.
 * @param urlValidator - Optional origin validator passed through to the
 *   constructor.
 * @returns The initialised context.
 */
static async from(
  browser: Browser,
  logger: Debugger,
  urlValidator?: UrlValidator,
) {
  const context = new McpContext(browser, logger, urlValidator);
  await context.#init();
  return context;
}
Example: https://ads.example.com;https://tracker.example.com', + }, }; export function parseArguments(version: string, argv = process.argv) { @@ -78,6 +88,14 @@ export function parseArguments(version: string, argv = process.argv) { ['$0 --channel dev', 'Use Chrome Dev installed on this system'], ['$0 --channel stable', 'Use stable Chrome installed on this system'], ['$0 --logFile /tmp/log.txt', 'Save logs to a file'], + [ + '$0 --allowedOrigins "https://example.com;https://api.example.com"', + 'Only allow requests to specific origins', + ], + [ + '$0 --blockedOrigins "https://ads.example.com;https://tracker.com"', + 'Block requests to specific origins', + ], ['$0 --help', 'Print CLI options'], ]); diff --git a/src/main.ts b/src/main.ts index 6add9a90..46135e0f 100644 --- a/src/main.ts +++ b/src/main.ts @@ -32,6 +32,7 @@ import * as screenshotTools from './tools/screenshot.js'; import * as scriptTools from './tools/script.js'; import * as snapshotTools from './tools/snapshot.js'; import type {ToolDefinition} from './tools/ToolDefinition.js'; +import {UrlValidator} from './utils/urlValidator.js'; function readPackageJson(): {version?: string} { const currentDir = import.meta.dirname; @@ -55,6 +56,14 @@ export const args = parseArguments(version); const logFile = args.logFile ? saveLogsToFile(args.logFile) : undefined; logger(`Starting Chrome DevTools MCP Server v${version}`); + +const allowedOrigins = UrlValidator.parseOrigins(args.allowedOrigins); +const blockedOrigins = UrlValidator.parseOrigins(args.blockedOrigins); +const urlValidator = + allowedOrigins.length > 0 || blockedOrigins.length > 0 + ? 
/**
 * Decides whether the browser may request a URL, based on an optional
 * allowlist and an optional blocklist of origins.
 *
 * Evaluation order in `isAllowed`:
 * 1. `about:`, `data:`, `blob:` and `file:` URLs are always allowed.
 * 2. An origin that matches the blocklist is rejected (the blocklist wins).
 * 3. With an empty allowlist everything else is allowed; otherwise the origin
 *    has to match the allowlist.
 *
 * A pattern is an origin such as `https://example.com`. `*` is a wildcard for
 * one or more characters other than `/`, e.g. `https://*.example.com`.
 */
export class UrlValidator {
  readonly #allowedOrigins: string[];
  readonly #blockedOrigins: string[];
  // Matchers are compiled once in the constructor so that wildcard patterns
  // are not converted into a RegExp again for every checked URL.
  readonly #allowedMatchers: Array<(origin: string) => boolean>;
  readonly #blockedMatchers: Array<(origin: string) => boolean>;
  readonly #logger: Debugger;

  /**
   * @param options - Origin patterns. Both lists default to empty.
   * @param logger - Receives configuration and block messages.
   */
  constructor(
    options: {
      allowedOrigins?: string[];
      blockedOrigins?: string[];
    },
    logger: Debugger,
  ) {
    // The logger is assigned first because pattern normalisation may log.
    this.#logger = logger;
    this.#allowedOrigins = (options.allowedOrigins ?? []).map(pattern =>
      this.#normalizePattern(pattern),
    );
    this.#blockedOrigins = (options.blockedOrigins ?? []).map(pattern =>
      this.#normalizePattern(pattern),
    );
    this.#allowedMatchers = this.#allowedOrigins.map(pattern =>
      this.#compileMatcher(pattern),
    );
    this.#blockedMatchers = this.#blockedOrigins.map(pattern =>
      this.#compileMatcher(pattern),
    );

    if (this.#allowedOrigins.length > 0) {
      this.#logger(
        `URL validation enabled. Allowed origins: ${this.#allowedOrigins.join(', ')}`,
      );
    }
    if (this.#blockedOrigins.length > 0) {
      this.#logger(
        `URL validation enabled. Blocked origins: ${this.#blockedOrigins.join(', ')}`,
      );
    }
  }

  /**
   * Splits a semicolon-separated list of origins, trimming whitespace and
   * dropping empty entries.
   *
   * @param originsString - Raw option value; may be undefined or empty.
   * @returns The individual origin patterns, or an empty array.
   */
  static parseOrigins(originsString?: string): string[] {
    if (!originsString) {
      return [];
    }
    return originsString
      .split(';')
      .map(o => o.trim())
      .filter(o => o.length > 0);
  }

  /**
   * Checks a URL against the configured lists.
   *
   * @param url - Absolute URL of the request.
   * @returns `true` when the request may proceed, `false` when it must be
   *   blocked.
   */
  isAllowed(url: string): boolean {
    if (this.#isSpecialUrl(url)) {
      return true;
    }

    let origin: string;
    try {
      origin = new URL(url).origin;
    } catch {
      // Without an origin the URL can match neither list: it cannot be on the
      // blocklist, but it also cannot satisfy a configured allowlist.
      const allowed = this.#allowedMatchers.length === 0;
      if (!allowed) {
        this.#logger(
          `Blocked request to ${url} (unparsable URL, not in allowlist)`,
        );
      }
      return allowed;
    }

    if (this.#blockedMatchers.some(matches => matches(origin))) {
      this.#logger(`Blocked request to ${url} (origin: ${origin})`);
      return false;
    }

    if (this.#allowedMatchers.length === 0) {
      return true;
    }

    const allowed = this.#allowedMatchers.some(matches => matches(origin));
    if (!allowed) {
      this.#logger(
        `Blocked request to ${url} (origin: ${origin} not in allowlist)`,
      );
    }
    return allowed;
  }

  /** Returns true when at least one allowed or blocked origin is configured. */
  hasRestrictions(): boolean {
    return this.#allowedOrigins.length > 0 || this.#blockedOrigins.length > 0;
  }

  /** Schemes that are never subject to origin checks. */
  #isSpecialUrl(url: string): boolean {
    const lowerUrl = url.toLowerCase();
    return (
      lowerUrl.startsWith('about:') ||
      lowerUrl.startsWith('data:') ||
      lowerUrl.startsWith('blob:') ||
      lowerUrl.startsWith('file:')
    );
  }

  /**
   * Brings a configured pattern into the form produced by `URL.origin`
   * (lower-case scheme and host, no default port, no path or trailing
   * slash). Without this, `https://Example.com/` would never equal the
   * origin of a request to that host.
   */
  #normalizePattern(pattern: string): string {
    const trimmed = pattern.trim();
    if (trimmed.includes('*')) {
      // Wildcard patterns are not passed through the URL parser; apply only
      // the case and trailing-slash normalisation by hand.
      return trimmed.toLowerCase().replace(/\/+$/, '');
    }
    try {
      const {origin} = new URL(trimmed);
      // Opaque origins serialise as the string "null"; keep the raw text.
      if (origin !== 'null') {
        return origin;
      }
    } catch {
      // Not an absolute URL; handled by the fallback below.
    }
    this.#logger(
      `Origin pattern "${trimmed}" could not be parsed as an origin; matching it literally`,
    );
    return trimmed;
  }

  /** Builds the predicate for one normalised pattern. */
  #compileMatcher(pattern: string): (origin: string) => boolean {
    if (!pattern.includes('*')) {
      return origin => origin === pattern;
    }
    const regex = this.#patternToRegex(pattern);
    return origin => regex.test(origin);
  }

  /**
   * Converts a wildcard pattern into an anchored RegExp: regex
   * metacharacters are escaped and each `*` becomes "one or more characters
   * other than `/`".
   */
  #patternToRegex(pattern: string): RegExp {
    const escaped = pattern.replace(/[.+?^${}()|[\]\\]/g, '\\$&');
    const regexPattern = escaped.replace(/\*/g, '[^\\/]+');
    return new RegExp(`^${regexPattern}$`);
  }
}
/** Shape of the options accepted by the UrlValidator constructor. */
interface OriginLists {
  allowedOrigins?: string[];
  blockedOrigins?: string[];
}

/** Builds a validator for the given lists with a fresh test logger. */
const validatorFor = (lists: OriginLists): UrlValidator =>
  new UrlValidator(lists, createLogger());

/** Asserts the exact verdict `isAllowed` returns for `url`. */
const expectVerdict = (
  validator: UrlValidator,
  url: string,
  expected: boolean,
): void => {
  assert.strictEqual(validator.isAllowed(url), expected);
};

describe('UrlValidator', () => {
  describe('parseOrigins', () => {
    // One row per test: title, raw option value, expected parsed origins.
    const cases: Array<[string, string | undefined, string[]]> = [
      [
        'should parse semicolon-separated origins',
        'https://example.com;https://api.example.com',
        ['https://example.com', 'https://api.example.com'],
      ],
      [
        'should trim whitespace',
        ' https://example.com ; https://api.example.com ',
        ['https://example.com', 'https://api.example.com'],
      ],
      [
        'should filter empty strings',
        'https://example.com;;',
        ['https://example.com'],
      ],
      ['should return empty array for undefined', undefined, []],
    ];

    for (const [title, input, expected] of cases) {
      it(title, () => {
        assert.deepStrictEqual(UrlValidator.parseOrigins(input), expected);
      });
    }
  });

  describe('isAllowed', () => {
    it('should allow all URLs when no restrictions', () => {
      const unrestricted = validatorFor({});
      expectVerdict(unrestricted, 'https://example.com', true);
      expectVerdict(unrestricted, 'https://blocked.com', true);
    });

    it('should allow special URLs', () => {
      const allowlisted = validatorFor({
        allowedOrigins: ['https://example.com'],
      });
      expectVerdict(allowlisted, 'about:blank', true);
      expectVerdict(allowlisted, 'data:text/html,test', true);
      expectVerdict(allowlisted, 'blob:test', true);
      expectVerdict(allowlisted, 'file:///test', true);
    });

    it('should block URLs not in allowlist', () => {
      const allowlisted = validatorFor({
        allowedOrigins: ['https://example.com'],
      });
      expectVerdict(allowlisted, 'https://example.com', true);
      expectVerdict(allowlisted, 'https://example.com/path', true);
      expectVerdict(allowlisted, 'https://other.com', false);
    });

    it('should block URLs in blocklist', () => {
      const blocklisted = validatorFor({
        blockedOrigins: ['https://blocked.com'],
      });
      expectVerdict(blocklisted, 'https://example.com', true);
      expectVerdict(blocklisted, 'https://blocked.com', false);
    });

    it('should prioritize blocklist over allowlist', () => {
      const conflicting = validatorFor({
        allowedOrigins: ['https://example.com'],
        blockedOrigins: ['https://example.com'],
      });
      expectVerdict(conflicting, 'https://example.com', false);
    });

    it('should support wildcard patterns', () => {
      const subdomainsOnly = validatorFor({
        allowedOrigins: ['https://*.example.com'],
      });
      expectVerdict(subdomainsOnly, 'https://api.example.com', true);
      expectVerdict(subdomainsOnly, 'https://cdn.example.com', true);
      expectVerdict(subdomainsOnly, 'https://example.com', false);
      expectVerdict(subdomainsOnly, 'https://other.com', false);
    });

    it('should support wildcard in blocklist', () => {
      const adsBlocked = validatorFor({
        blockedOrigins: ['https://*.ads.example.com'],
      });
      expectVerdict(adsBlocked, 'https://tracker.ads.example.com', false);
      expectVerdict(adsBlocked, 'https://stats.ads.example.com', false);
      expectVerdict(adsBlocked, 'https://example.com', true);
    });
  });

  describe('hasRestrictions', () => {
    // One row per test: title, configured lists, expected result.
    const cases: Array<[string, OriginLists, boolean]> = [
      ['should return false when no restrictions', {}, false],
      [
        'should return true when allowlist is set',
        {allowedOrigins: ['https://example.com']},
        true,
      ],
      [
        'should return true when blocklist is set',
        {blockedOrigins: ['https://blocked.com']},
        true,
      ],
    ];

    for (const [title, lists, expected] of cases) {
      it(title, () => {
        assert.strictEqual(validatorFor(lists).hasRestrictions(), expected);
      });
    }
  });
});