diff --git a/bun.lockb b/bun.lockb index 78894db9b..cbaee6d06 100755 Binary files a/bun.lockb and b/bun.lockb differ diff --git a/deno_dist/helper/ssg/ssg.ts b/deno_dist/helper/ssg/ssg.ts index 8016bf48d..8188403f0 100644 --- a/deno_dist/helper/ssg/ssg.ts +++ b/deno_dist/helper/ssg/ssg.ts @@ -30,8 +30,13 @@ export interface ToSSGResult { error?: Error } -const generateFilePath = (routePath: string, outDir: string, mimeType: string) => { - const extension = determineExtension(mimeType) +const generateFilePath = ( + routePath: string, + outDir: string, + mimeType: string, + extensionMap?: Record +) => { + const extension = determineExtension(mimeType, extensionMap) if (routePath.endsWith(`.${extension}`)) { return joinPaths(outDir, routePath) @@ -62,17 +67,22 @@ const parseResponseContent = async (response: Response): Promise { - switch (mimeType) { - case 'text/html': - return 'html' - case 'text/xml': - case 'application/xml': - return 'xml' - default: { - return getExtension(mimeType) || 'html' - } +export const defaultExtensionMap: Record = { + 'text/html': 'html', + 'text/xml': 'xml', + 'application/xml': 'xml', + 'application/yaml': 'yaml', +} + +const determineExtension = ( + mimeType: string, + userExtensionMap?: Record +): string => { + const extensionMap = userExtensionMap || defaultExtensionMap + if (mimeType in extensionMap) { + return extensionMap[mimeType] } + return getExtension(mimeType) || 'html' } export type BeforeRequestHook = (req: Request) => Request | false | Promise @@ -85,6 +95,7 @@ export interface ToSSGOptions { afterResponseHook?: AfterResponseHook afterGenerateHook?: AfterGenerateHook concurrency?: number + extensionMap?: Record } /** @@ -204,14 +215,15 @@ const createdDirs: Set = new Set() export const saveContentToFile = async ( data: Promise<{ routePath: string; content: string | ArrayBuffer; mimeType: string } | undefined>, fsModule: FileSystemModule, - outDir: string + outDir: string, + extensionMap?: Record ): Promise => { const awaitedData = await data if (!awaitedData) { return } const { routePath, content, mimeType } = awaitedData - const filePath = generateFilePath(routePath, outDir, mimeType) + const filePath = generateFilePath(routePath, outDir, mimeType, extensionMap) const dirPath = dirname(filePath) if (!createdDirs.has(dirPath)) { diff --git a/deno_dist/jsx/dom/index.ts b/deno_dist/jsx/dom/index.ts index 37570eab3..9bde6b33d 100644 --- a/deno_dist/jsx/dom/index.ts +++ b/deno_dist/jsx/dom/index.ts @@ -15,6 +15,7 @@ import { useMemo, useLayoutEffect, useReducer, + useId, useDebugValue, } from '../hooks/index.ts' import { Suspense, ErrorBoundary } from './components.ts' @@ -69,6 +70,7 @@ export { useMemo, useLayoutEffect, useReducer, + useId, useDebugValue, Suspense, ErrorBoundary, @@ -94,6 +96,7 @@ export default { useMemo, useLayoutEffect, useReducer, + useId, useDebugValue, Suspense, ErrorBoundary, diff --git a/deno_dist/jsx/hooks/index.ts b/deno_dist/jsx/hooks/index.ts index 7bf9c968e..56a96bcd4 100644 --- a/deno_dist/jsx/hooks/index.ts +++ b/deno_dist/jsx/hooks/index.ts @@ -349,6 +349,9 @@ export const useMemo = (factory: () => T, deps: readonly unknown[]): T => { return memoArray[hookIndex][0] as T } +let idCounter = 0 +export const useId = (): string => useMemo(() => `:r${(idCounter++).toString(32)}:`, []) + // Define to avoid errors. This hook currently does nothing. // eslint-disable-next-line @typescript-eslint/no-unused-vars export const useDebugValue = (_value: unknown, _formatter?: (value: unknown) => string): void => {} diff --git a/deno_dist/jsx/index.ts b/deno_dist/jsx/index.ts index 446bad18d..9cb9e0017 100644 --- a/deno_dist/jsx/index.ts +++ b/deno_dist/jsx/index.ts @@ -15,6 +15,7 @@ import { useMemo, useLayoutEffect, useReducer, + useId, useDebugValue, } from './hooks/index.ts' import { Suspense } from './streaming.ts' @@ -34,6 +35,7 @@ export { useRef, useCallback, useReducer, + useId, useDebugValue, use, startTransition, @@ -60,6 +62,7 @@ export default { useRef, useCallback, useReducer, + useId, useDebugValue, use, startTransition, diff --git a/deno_dist/middleware.ts b/deno_dist/middleware.ts index 290c43c26..5b421668b 100644 --- a/deno_dist/middleware.ts +++ b/deno_dist/middleware.ts @@ -11,8 +11,10 @@ export * from './jsx/index.ts' export * from './middleware/jsx-renderer/index.ts' export { jwt } from './middleware/jwt/index.ts' export * from './middleware/logger/index.ts' +export * from './middleware/method-override/index.ts' export * from './middleware/powered-by/index.ts' export * from './middleware/timing/index.ts' export * from './middleware/pretty-json/index.ts' export * from './middleware/secure-headers/index.ts' +export * from './middleware/trailing-slash/index.ts' export * from './adapter/deno/serve-static.ts' diff --git a/deno_dist/middleware/basic-auth/index.ts b/deno_dist/middleware/basic-auth/index.ts index 1e0388071..cd05874d2 100644 --- a/deno_dist/middleware/basic-auth/index.ts +++ b/deno_dist/middleware/basic-auth/index.ts @@ -1,3 +1,4 @@ +import type { Context } from '../../context.ts' import { HTTPException } from '../../http-exception.ts' import type { HonoRequest } from '../../request.ts' import type { MiddlewareHandler } from '../../types.ts' @@ -26,31 +27,59 @@ const auth = (req: HonoRequest) => { return { username: userPass[1], password: userPass[2] } } +type BasicAuthOptions = + | { + username: string + password: string + realm?: string + hashFunction?: Function + } + | { + verifyUser: (username: string, password: string, c: Context) => boolean | Promise + realm?: string + hashFunction?: Function + } + export const basicAuth = ( - options: { username: string; password: string; realm?: string; hashFunction?: Function }, + options: BasicAuthOptions, ...users: { username: string; password: string }[] ): MiddlewareHandler => { - if (!options) { - throw new Error('basic auth middleware requires options for "username and password"') + const usernamePasswordInOptions = 'username' in options && 'password' in options + const verifyUserInOptions = 'verifyUser' in options + + if (!(usernamePasswordInOptions || verifyUserInOptions)) { + throw new Error( + 'basic auth middleware requires options for "username and password" or "verifyUser"' + ) } if (!options.realm) { options.realm = 'Secure Area' } - users.unshift({ username: options.username, password: options.password }) + + if (usernamePasswordInOptions) { + users.unshift({ username: options.username, password: options.password }) + } return async function basicAuth(ctx, next) { const requestUser = auth(ctx.req) if (requestUser) { - for (const user of users) { - const [usernameEqual, passwordEqual] = await Promise.all([ - timingSafeEqual(user.username, requestUser.username, options.hashFunction), - timingSafeEqual(user.password, requestUser.password, options.hashFunction), - ]) - if (usernameEqual && passwordEqual) { + if (verifyUserInOptions) { + if (await options.verifyUser(requestUser.username, requestUser.password, ctx)) { await next() return } + } else { + for (const user of users) { + const [usernameEqual, passwordEqual] = await Promise.all([ + timingSafeEqual(user.username, requestUser.username, options.hashFunction), + timingSafeEqual(user.password, requestUser.password, options.hashFunction), + ]) + if (usernameEqual && passwordEqual) { + await next() + return + } + } } } const res = new Response('Unauthorized', { diff --git a/deno_dist/middleware/bearer-auth/index.ts b/deno_dist/middleware/bearer-auth/index.ts index 2c6aed92e..5b65b8f98 100644 --- a/deno_dist/middleware/bearer-auth/index.ts +++ b/deno_dist/middleware/bearer-auth/index.ts @@ -1,3 +1,4 @@ +import type { Context } from '../../context.ts' import { HTTPException } from '../../http-exception.ts' import type { MiddlewareHandler } from '../../types.ts' import { timingSafeEqual } from '../../utils/buffer.ts' @@ -5,13 +6,22 @@ import { timingSafeEqual } from '../../utils/buffer.ts' const TOKEN_STRINGS = '[A-Za-z0-9._~+/-]+=*' const PREFIX = 'Bearer' -export const bearerAuth = (options: { - token: string | string[] - realm?: string - prefix?: string - hashFunction?: Function -}): MiddlewareHandler => { - if (!options.token) { +type BearerAuthOptions = + | { + token: string | string[] + realm?: string + prefix?: string + hashFunction?: Function + } + | { + realm?: string + prefix?: string + verifyToken: (token: string, c: Context) => boolean | Promise + hashFunction?: Function + } + +export const bearerAuth = (options: BearerAuthOptions): MiddlewareHandler => { + if (!('token' in options || 'verifyToken' in options)) { throw new Error('bearer auth middleware requires options for "token"') } if (!options.realm) { @@ -49,7 +59,9 @@ export const bearerAuth = (options: { throw new HTTPException(400, { res }) } else { let equal = false - if (typeof options.token === 'string') { + if ('verifyToken' in options) { + equal = await options.verifyToken(match[1], c) + } else if (typeof options.token === 'string') { equal = await timingSafeEqual(options.token, match[1], options.hashFunction) } else if (Array.isArray(options.token) && options.token.length > 0) { for (const token of options.token) { diff --git a/deno_dist/middleware/cache/index.ts b/deno_dist/middleware/cache/index.ts index 74035657d..5abd6f5d4 100644 --- a/deno_dist/middleware/cache/index.ts +++ b/deno_dist/middleware/cache/index.ts @@ -5,6 +5,7 @@ export const cache = (options: { cacheName: string wait?: boolean cacheControl?: string + vary?: string | string[] }): MiddlewareHandler => { if (!globalThis.caches) { console.log('Cache Middleware is not enabled because caches is not defined.') @@ -15,16 +16,28 @@ export const cache = (options: { options.wait = false } - const directives = options.cacheControl?.split(',').map((directive) => directive.toLowerCase()) + const cacheControlDirectives = options.cacheControl + ?.split(',') + .map((directive) => directive.toLowerCase()) + const varyDirectives = Array.isArray(options.vary) + ? options.vary + : options.vary?.split(',').map((directive) => directive.trim()) + // RFC 7231 Section 7.1.4 specifies that "*" is not allowed in Vary header. + // See: https://datatracker.ietf.org/doc/html/rfc7231#section-7.1.4 + if (options.vary?.includes('*')) { + throw new Error( + 'Middleware vary configuration cannot include "*", as it disallows effective caching.' + ) + } const addHeader = (c: Context) => { - if (directives) { + if (cacheControlDirectives) { const existingDirectives = c.res.headers .get('Cache-Control') ?.split(',') .map((d) => d.trim().split('=', 1)[0]) ?? [] - for (const directive of directives) { + for (const directive of cacheControlDirectives) { let [name, value] = directive.trim().split('=', 2) name = name.toLowerCase() if (!existingDirectives.includes(name)) { @@ -32,6 +45,26 @@ export const cache = (options: { } } } + + if (varyDirectives) { + const existingDirectives = + c.res.headers + .get('Vary') + ?.split(',') + .map((d) => d.trim()) ?? [] + + const vary = Array.from( + new Set( + [...existingDirectives, ...varyDirectives].map((directive) => directive.toLowerCase()) + ) + ).sort() + + if (vary.includes('*')) { + c.header('Vary', '*') + } else { + c.header('Vary', vary.join(', ')) + } + } } return async function cache(c, next) { diff --git a/deno_dist/middleware/cors/index.ts b/deno_dist/middleware/cors/index.ts index 652078cbb..fe3b530e2 100644 --- a/deno_dist/middleware/cors/index.ts +++ b/deno_dist/middleware/cors/index.ts @@ -1,7 +1,8 @@ +import type { Context } from '../../context.ts' import type { MiddlewareHandler } from '../../types.ts' type CORSOptions = { - origin: string | string[] | ((origin: string) => string | undefined | null) + origin: string | string[] | ((origin: string, c: Context) => string | undefined | null) allowMethods?: string[] allowHeaders?: string[] maxAge?: number @@ -36,7 +37,7 @@ export const cors = (options?: CORSOptions): MiddlewareHandler => { c.res.headers.set(key, value) } - const allowOrigin = findAllowOrigin(c.req.header('origin') || '') + const allowOrigin = findAllowOrigin(c.req.header('origin') || '', c) if (allowOrigin) { set('Access-Control-Allow-Origin', allowOrigin) } diff --git a/deno_dist/middleware/jwt/index.ts b/deno_dist/middleware/jwt/index.ts index bc1472d14..490177fd6 100644 --- a/deno_dist/middleware/jwt/index.ts +++ b/deno_dist/middleware/jwt/index.ts @@ -3,8 +3,8 @@ import { getCookie } from '../../helper/cookie/index.ts' import { HTTPException } from '../../http-exception.ts' import type { MiddlewareHandler } from '../../types.ts' import { Jwt } from '../../utils/jwt/index.ts' -import type { AlgorithmTypes } from '../../utils/jwt/types.ts' import '../../context.ts' +import type { SignatureAlgorithm } from '../../utils/jwt/jwa.ts' declare module '../../context.ts' { interface ContextVariableMap { @@ -16,7 +16,7 @@ declare module '../../context.ts' { export const jwt = (options: { secret: string cookie?: string - alg?: string + alg?: SignatureAlgorithm }): MiddlewareHandler => { if (!options) { throw new Error('JWT auth middleware requires options for "secret') @@ -32,11 +32,13 @@ export const jwt = (options: { if (credentials) { const parts = credentials.split(/\s+/) if (parts.length !== 2) { + const errDescription = 'invalid credentials structure' throw new HTTPException(401, { + message: errDescription, res: unauthorizedResponse({ ctx, error: 'invalid_request', - errDescription: 'invalid credentials structure', + errDescription, }), }) } else { @@ -47,30 +49,34 @@ export const jwt = (options: { } if (!token) { + const errDescription = 'no authorization included in request' throw new HTTPException(401, { + message: errDescription, res: unauthorizedResponse({ ctx, error: 'invalid_request', - errDescription: 'no authorization included in request', + errDescription, }), }) } let payload - let msg = '' + let cause try { - payload = await Jwt.verify(token, options.secret, options.alg as AlgorithmTypes) + payload = await Jwt.verify(token, options.secret, options.alg) } catch (e) { - msg = `${e}` + cause = e } if (!payload) { throw new HTTPException(401, { + message: 'Unauthorized', res: unauthorizedResponse({ ctx, error: 'invalid_token', - statusText: msg, + statusText: 'Unauthorized', errDescription: 'token verification failure', }), + cause, }) } diff --git a/deno_dist/middleware/method-override/index.ts b/deno_dist/middleware/method-override/index.ts new file mode 100644 index 000000000..60e399ac9 --- /dev/null +++ b/deno_dist/middleware/method-override/index.ts @@ -0,0 +1,134 @@ +import { URLSearchParams } from 'node:url' +import type { Context } from '../../context.ts' +import type { Hono } from '../../hono.ts' +import type { MiddlewareHandler } from '../../types.ts' +import { parseBody } from '../../utils/body.ts' + +type MethodOverrideOptions = { + // eslint-disable-next-line @typescript-eslint/no-explicit-any + app: Hono +} & ( + | { + // Default is 'form' and the value is `_method` + form?: string + header?: never + query?: never + } + | { + form?: never + header: string + query?: never + } + | { + form?: never + header?: never + query: string + } +) + +const DEFAULT_METHOD_FORM_NAME = '_method' + +/** + * Method Override Middleware + * + * @example + * // with form input method + * const app = new Hono() + * app.use('/books/*', methodOverride({ app })) // the default `form` value is `_method` + * app.use('/authors/*', methodOverride({ app, form: 'method' })) + * + * @example + * // with custom header + * app.use('/books/*', methodOverride({ app, header: 'X-HTTP-METHOD-OVERRIDE' })) + * + * @example + * // with query parameter + * app.use('/books/*', methodOverride({ app, query: '_method' })) + */ +export const methodOverride = (options: MethodOverrideOptions): MiddlewareHandler => + async function methodOverride(c, next) { + if (c.req.method === 'GET') { + return await next() + } + + const app = options.app + // Method override by form + if (!(options.header || options.query)) { + const contentType = c.req.header('content-type') + const methodFormName = options.form || DEFAULT_METHOD_FORM_NAME + const clonedRequest = c.req.raw.clone() + const newRequest = clonedRequest.clone() + // Content-Type is `multipart/form-data` + if (contentType?.startsWith('multipart/form-data')) { + const form = await clonedRequest.formData() + const method = form.get(methodFormName) + if (method) { + const newForm = await newRequest.formData() + newForm.delete(methodFormName) + const newHeaders = new Headers(clonedRequest.headers) + newHeaders.delete('content-type') + newHeaders.delete('content-length') + const request = new Request(c.req.url, { + body: newForm, + headers: newHeaders, + method: method as string, + }) + return app.fetch(request, c.env, getExecutionCtx(c)) + } + } + // Content-Type is `application/x-www-form-urlencoded` + if (contentType === 'application/x-www-form-urlencoded') { + const params = await parseBody>(clonedRequest) + const method = params[methodFormName] + if (method) { + delete params[methodFormName] + const newParams = new URLSearchParams(params) + const request = new Request(newRequest, { + body: newParams, + method: method as string, + }) + return app.fetch(request, c.env, getExecutionCtx(c)) + } + } + } + // Method override by header + else if (options.header) { + const headerName = options.header + const method = c.req.header(headerName) + if (method) { + const newHeaders = new Headers(c.req.raw.headers) + newHeaders.delete(headerName) + const request = new Request(c.req.raw, { + headers: newHeaders, + method, + }) + return app.fetch(request, c.env, getExecutionCtx(c)) + } + } + // Method override by query + else if (options.query) { + const queryName = options.query + const method = c.req.query(queryName) + if (method) { + const url = new URL(c.req.url) + url.searchParams.delete(queryName) + const request = new Request(url.toString(), { + body: c.req.raw.body, + headers: c.req.raw.headers, + method, + }) + return app.fetch(request, c.env, getExecutionCtx(c)) + } + } + await next() + } + +const getExecutionCtx = (c: Context) => { + let executionCtx: ExecutionContext | undefined + try { + executionCtx = c.executionCtx + } catch { + // Do nothing + } + return executionCtx +} diff --git a/deno_dist/middleware/trailing-slash/index.ts b/deno_dist/middleware/trailing-slash/index.ts new file mode 100644 index 000000000..f09492eb0 --- /dev/null +++ b/deno_dist/middleware/trailing-slash/index.ts @@ -0,0 +1,46 @@ +import type { MiddlewareHandler } from '../../types.ts' + +/** + * Trim the trailing slash from the URL if it does have one. For example, `/path/to/page/` will be redirected to `/path/to/page`. + * @access public + * @example app.use(trimTrailingSlash()) + */ +export const trimTrailingSlash = (): MiddlewareHandler => { + return async function trimTrailingSlash(c, next) { + await next() + + if ( + c.res.status === 404 && + c.req.method === 'GET' && + c.req.path !== '/' && + c.req.path[c.req.path.length - 1] === '/' + ) { + const url = new URL(c.req.url) + url.pathname = url.pathname.substring(0, url.pathname.length - 1) + + c.res = c.redirect(url.toString(), 301) + } + } +} + +/** + * Append a trailing slash to the URL if it doesn't have one. For example, `/path/to/page` will be redirected to `/path/to/page/`. + * @access public + * @example app.use(appendTrailingSlash()) + */ +export const appendTrailingSlash = (): MiddlewareHandler => { + return async function appendTrailingSlash(c, next) { + await next() + + if ( + c.res.status === 404 && + c.req.method === 'GET' && + c.req.path[c.req.path.length - 1] !== '/' + ) { + const url = new URL(c.req.url) + url.pathname += '/' + + c.res = c.redirect(url.toString(), 301) + } + } +} diff --git a/deno_dist/request.ts b/deno_dist/request.ts index d68206c59..3e76cb95d 100644 --- a/deno_dist/request.ts +++ b/deno_dist/request.ts @@ -194,18 +194,27 @@ export class HonoRequest

{ private cachedBody = (key: keyof Body) => { const { bodyCache, raw } = this const cachedBody = bodyCache[key] + if (cachedBody) { return cachedBody } - /** - * If an arrayBuffer cache is exist, - * use it for creating a text, json, and others. - */ - if (bodyCache.arrayBuffer) { - return (async () => { - return await new Response(bodyCache.arrayBuffer)[key]() - })() + + if (!bodyCache[key]) { + for (const keyOfBodyCache of Object.keys(bodyCache)) { + if (keyOfBodyCache === 'parsedBody') { + continue + } + return (async () => { + // @ts-expect-error bodyCache[keyOfBodyCache] can be passed as a body + let body = await bodyCache[keyOfBodyCache] + if (keyOfBodyCache === 'json') { + body = JSON.stringify(body) + } + return await new Response(body)[key]() + })() + } } + return (bodyCache[key] = raw[key]()) } diff --git a/deno_dist/utils/jwt/index.ts b/deno_dist/utils/jwt/index.ts index ef804944d..53d2eda74 100644 --- a/deno_dist/utils/jwt/index.ts +++ b/deno_dist/utils/jwt/index.ts @@ -1 +1,2 @@ -export * as Jwt from './jwt.ts' +import { sign, verify, decode } from './jwt.ts' +export const Jwt = { sign, verify, decode } diff --git a/deno_dist/utils/jwt/jwa.ts b/deno_dist/utils/jwt/jwa.ts new file mode 100644 index 000000000..8512b811c --- /dev/null +++ b/deno_dist/utils/jwt/jwa.ts @@ -0,0 +1,20 @@ +// JSON Web Algorithms (JWA) +// https://datatracker.ietf.org/doc/html/rfc7518 + +export enum AlgorithmTypes { + HS256 = 'HS256', + HS384 = 'HS384', + HS512 = 'HS512', + RS256 = 'RS256', + RS384 = 'RS384', + RS512 = 'RS512', + PS256 = 'PS256', + PS384 = 'PS384', + PS512 = 'PS512', + ES256 = 'ES256', + ES384 = 'ES384', + ES512 = 'ES512', + EdDSA = 'EdDSA', +} + +export type SignatureAlgorithm = keyof typeof AlgorithmTypes diff --git a/deno_dist/utils/jwt/jws.ts b/deno_dist/utils/jwt/jws.ts new file mode 100644 index 000000000..c49cae394 --- /dev/null +++ b/deno_dist/utils/jwt/jws.ts @@ -0,0 +1,224 @@ +import { getRuntimeKey } from '../../helper.ts' +import { decodeBase64 } from '../encode.ts' +import type { SignatureAlgorithm } from './jwa.ts' +import { JwtAlgorithmNotImplemented } from './types.ts' +import { CryptoKeyUsage } from './types.ts' +import { utf8Encoder } from './utf8.ts' + +// JSON Web Signature (JWS) +// https://datatracker.ietf.org/doc/html/rfc7515 + +type KeyImporterAlgorithm = Parameters[2] +type KeyAlgorithm = + | AlgorithmIdentifier + | RsaHashedImportParams + | (RsaPssParams & RsaHashedImportParams) + | (EcdsaParams & EcKeyImportParams) + | HmacImportParams + +export type SignatureKey = string | JsonWebKey | CryptoKey + +export async function signing( + privateKey: SignatureKey, + alg: SignatureAlgorithm, + data: BufferSource +): Promise { + const algorithm = getKeyAlgorithm(alg) + const cryptoKey = await importPrivateKey(privateKey, algorithm) + return await crypto.subtle.sign(algorithm, cryptoKey, data) +} + +export async function verifying( + publicKey: SignatureKey, + alg: SignatureAlgorithm, + signature: BufferSource, + data: BufferSource +): Promise { + const algorithm = getKeyAlgorithm(alg) + const cryptoKey = await importPublicKey(publicKey, algorithm) + return await crypto.subtle.verify(algorithm, cryptoKey, signature, data) +} + +function pemToBinary(pem: string): Uint8Array { + return decodeBase64(pem.replace(/-+(BEGIN|END).*/g, '').replace(/\s/g, '')) +} + +async function importPrivateKey(key: SignatureKey, alg: KeyImporterAlgorithm): Promise { + if (!crypto.subtle || !crypto.subtle.importKey) { + throw new Error('`crypto.subtle.importKey` is undefined. JWT auth middleware requires it.') + } + if (isCryptoKey(key)) { + if (key.type !== 'private') { + throw new Error(`unexpected non private key: CryptoKey.type is ${key.type}`) + } + return key + } + const usages = [CryptoKeyUsage.Sign] + if (typeof key === 'object') { + // https://developer.mozilla.org/en-US/docs/Web/API/SubtleCrypto/importKey#json_web_key_import + return await crypto.subtle.importKey('jwk', key, alg, false, usages) + } + if (key.includes('PRIVATE')) { + // https://developer.mozilla.org/en-US/docs/Web/API/SubtleCrypto/importKey#pkcs_8_import + return await crypto.subtle.importKey('pkcs8', pemToBinary(key), alg, false, usages) + } + return await crypto.subtle.importKey('raw', utf8Encoder.encode(key), alg, false, usages) +} + +async function importPublicKey(key: SignatureKey, alg: KeyImporterAlgorithm): Promise { + if (!crypto.subtle || !crypto.subtle.importKey) { + throw new Error('`crypto.subtle.importKey` is undefined. JWT auth middleware requires it.') + } + if (isCryptoKey(key)) { + if (key.type === 'public' || key.type === 'secret') { + return key + } + key = await exportPublicJwkFrom(key) + } + if (typeof key === 'string' && key.includes('PRIVATE')) { + // https://developer.mozilla.org/en-US/docs/Web/API/SubtleCrypto/importKey#pkcs_8_import + const privateKey = await crypto.subtle.importKey('pkcs8', pemToBinary(key), alg, true, [ + CryptoKeyUsage.Sign, + ]) + key = await exportPublicJwkFrom(privateKey) + } + const usages = [CryptoKeyUsage.Verify] + if (typeof key === 'object') { + // https://developer.mozilla.org/en-US/docs/Web/API/SubtleCrypto/importKey#json_web_key_import + return await crypto.subtle.importKey('jwk', key, alg, false, usages) + } + if (key.includes('PUBLIC')) { + // https://developer.mozilla.org/en-US/docs/Web/API/SubtleCrypto/importKey#subjectpublickeyinfo_import + return await crypto.subtle.importKey('spki', pemToBinary(key), alg, false, usages) + } + return await crypto.subtle.importKey('raw', utf8Encoder.encode(key), alg, false, usages) +} + +// https://datatracker.ietf.org/doc/html/rfc7517 +async function exportPublicJwkFrom(privateKey: CryptoKey): Promise { + if (privateKey.type !== 'private') { + throw new Error(`unexpected key type: ${privateKey.type}`) + } + if (!privateKey.extractable) { + throw new Error('unexpected private key is unextractable') + } + const jwk = await crypto.subtle.exportKey('jwk', privateKey) + const { kty } = jwk // common + const { alg, e, n } = jwk // rsa + const { crv, x, y } = jwk // elliptic-curve + return { kty, alg, e, n, crv, x, y, key_ops: [CryptoKeyUsage.Verify] } +} + +function getKeyAlgorithm(name: SignatureAlgorithm): KeyAlgorithm { + switch (name) { + case 'HS256': + return { + name: 'HMAC', + hash: { + name: 'SHA-256', + }, + } satisfies HmacImportParams + case 'HS384': + return { + name: 'HMAC', + hash: { + name: 'SHA-384', + }, + } satisfies HmacImportParams + case 'HS512': + return { + name: 'HMAC', + hash: { + name: 'SHA-512', + }, + } satisfies HmacImportParams + case 'RS256': + return { + name: 'RSASSA-PKCS1-v1_5', + hash: { + name: 'SHA-256', + }, + } satisfies RsaHashedImportParams + case 'RS384': + return { + name: 'RSASSA-PKCS1-v1_5', + hash: { + name: 'SHA-384', + }, + } satisfies RsaHashedImportParams + case 'RS512': + return { + name: 'RSASSA-PKCS1-v1_5', + hash: { + name: 'SHA-512', + }, + } satisfies RsaHashedImportParams + case 'PS256': + return { + name: 'RSA-PSS', + hash: { + name: 'SHA-256', + }, + saltLength: 32, // 256 >> 3 + } satisfies RsaPssParams & RsaHashedImportParams + case 'PS384': + return { + name: 'RSA-PSS', + hash: { + name: 'SHA-384', + }, + saltLength: 48, // 384 >> 3 + } satisfies RsaPssParams & RsaHashedImportParams + case 'PS512': + return { + name: 'RSA-PSS', + hash: { + name: 'SHA-512', + }, + saltLength: 64, // 512 >> 3, + } satisfies RsaPssParams & RsaHashedImportParams + case 'ES256': + return { + name: 'ECDSA', + hash: { + name: 'SHA-256', + }, + namedCurve: 'P-256', + } satisfies EcdsaParams & EcKeyImportParams + case 'ES384': + return { + name: 'ECDSA', + hash: { + name: 'SHA-384', + }, + namedCurve: 'P-384', + } satisfies EcdsaParams & EcKeyImportParams + case 'ES512': + return { + name: 'ECDSA', + hash: { + name: 'SHA-512', + }, + namedCurve: 'P-521', + } satisfies EcdsaParams & EcKeyImportParams + case 'EdDSA': + // Currently, supported only Safari and Deno, Node.js. + // See: https://developer.mozilla.org/en-US/docs/Web/API/SubtleCrypto/verify + return { + name: 'Ed25519', + namedCurve: 'Ed25519', + } + default: + throw new JwtAlgorithmNotImplemented(name) + } +} + +function isCryptoKey(key: SignatureKey): key is CryptoKey { + const runtime = getRuntimeKey() + // @ts-expect-error CryptoKey hasn't exported to global in node v18 + if (runtime === 'node' && !!crypto.webcrypto) { + // @ts-expect-error CryptoKey hasn't exported to global in node v18 + return key instanceof crypto.webcrypto.CryptoKey + } + return key instanceof CryptoKey +} diff --git a/deno_dist/utils/jwt/jwt.ts b/deno_dist/utils/jwt/jwt.ts index 73ceedc0f..d15693463 100644 --- a/deno_dist/utils/jwt/jwt.ts +++ b/deno_dist/utils/jwt/jwt.ts @@ -1,44 +1,17 @@ import { encodeBase64Url, decodeBase64Url } from '../../utils/encode.ts' -import type { AlgorithmTypes } from './types.ts' -import { JwtTokenIssuedAt } from './types.ts' +import type { SignatureAlgorithm } from './jwa.ts' +import { AlgorithmTypes } from './jwa.ts' +import type { SignatureKey } from './jws.ts' +import { signing, verifying } from './jws.ts' +import { JwtHeaderInvalid, type JWTPayload } from './types.ts' import { JwtTokenInvalid, JwtTokenNotBefore, JwtTokenExpired, JwtTokenSignatureMismatched, - JwtAlgorithmNotImplemented, + JwtTokenIssuedAt, } from './types.ts' - -interface AlgorithmParams { - name: string - namedCurve?: string - hash?: { - name: string - } -} - -enum CryptoKeyFormat { - RAW = 'raw', - PKCS8 = 'pkcs8', - SPKI = 'spki', - JWK = 'jwk', -} - -enum CryptoKeyUsage { - Ecrypt = 'encrypt', - Decrypt = 'decrypt', - Sign = 'sign', - Verify = 'verify', - Deriverkey = 'deriveKey', - DeriveBits = 'deriveBits', - WrapKey = 'wrapKey', - UnwrapKey = 'unwrapKey', -} - -type AlgorithmTypeName = keyof typeof AlgorithmTypes - -const utf8Encoder = new TextEncoder() -const utf8Decoder = new TextDecoder() +import { utf8Decoder, utf8Encoder } from './utf8.ts' const encodeJwtPart = (part: unknown): string => encodeBase64Url(utf8Encoder.encode(JSON.stringify(part))).replace(/=/g, '') @@ -47,65 +20,34 @@ const encodeSignaturePart = (buf: ArrayBufferLike): string => encodeBase64Url(bu const decodeJwtPart = (part: string): unknown => JSON.parse(utf8Decoder.decode(decodeBase64Url(part))) -const param = (name: AlgorithmTypeName): AlgorithmParams => { - switch (name.toUpperCase()) { - case 'HS256': - return { - name: 'HMAC', - hash: { - name: 'SHA-256', - }, - } - case 'HS384': - return { - name: 'HMAC', - hash: { - name: 'SHA-384', - }, - } - case 'HS512': - return { - name: 'HMAC', - hash: { - name: 'SHA-512', - }, - } - default: - throw new JwtAlgorithmNotImplemented(name) - } +export interface TokenHeader { + alg: SignatureAlgorithm + typ: 'JWT' } -const signing = async ( - data: string, - secret: string, - alg: AlgorithmTypeName = 'HS256' -): Promise => { - if (!crypto.subtle || !crypto.subtle.importKey) { - throw new Error('`crypto.subtle.importKey` is undefined. JWT auth middleware requires it.') - } - - const utf8Encoder = new TextEncoder() - const cryptoKey = await crypto.subtle.importKey( - CryptoKeyFormat.RAW, - utf8Encoder.encode(secret), - param(alg), - false, - [CryptoKeyUsage.Sign] +// eslint-disable-next-line +export function isTokenHeader(obj: any): obj is TokenHeader { + return ( + typeof obj === 'object' && + obj !== null && + 'alg' in obj && + Object.values(AlgorithmTypes).includes(obj.alg) && + 'typ' in obj && + obj.typ === 'JWT' ) - return await crypto.subtle.sign(param(alg), cryptoKey, utf8Encoder.encode(data)) } export const sign = async ( - payload: unknown, - secret: string, - alg: AlgorithmTypeName = 'HS256' + payload: JWTPayload, + privateKey: SignatureKey, + alg: SignatureAlgorithm = 'HS256' ): Promise => { const encodedPayload = encodeJwtPart(payload) - const encodedHeader = encodeJwtPart({ alg, typ: 'JWT' }) + const encodedHeader = encodeJwtPart({ alg, typ: 'JWT' } satisfies TokenHeader) const partialToken = `${encodedHeader}.${encodedPayload}` - const signaturePart = await signing(partialToken, secret, alg) + const signaturePart = await signing(privateKey, alg, utf8Encoder.encode(partialToken)) const signature = encodeSignaturePart(signaturePart) return `${partialToken}.${signature}` @@ -113,8 +55,8 @@ export const sign = async ( export const verify = async ( token: string, - secret: string, - alg: AlgorithmTypeName = 'HS256' + publicKey: SignatureKey, + alg: SignatureAlgorithm = 'HS256' // eslint-disable-next-line @typescript-eslint/no-explicit-any ): Promise => { const tokenParts = token.split('.') @@ -122,7 +64,10 @@ export const verify = async ( throw new JwtTokenInvalid(token) } - const { payload } = decode(token) + const { header, payload } = decode(token) + if (!isTokenHeader(header)) { + throw new JwtHeaderInvalid(header) + } const now = Math.floor(Date.now() / 1000) if (payload.nbf && payload.nbf > now) { throw new JwtTokenNotBefore(token) @@ -134,10 +79,14 @@ export const verify = async ( throw new JwtTokenIssuedAt(now, payload.iat) } - const signaturePart = tokenParts.slice(0, 2).join('.') - const signature = await signing(signaturePart, secret, alg) - const encodedSignature = encodeSignaturePart(signature) - if (encodedSignature !== tokenParts[2]) { + const headerPayload = token.substring(0, token.lastIndexOf('.')) + const verified = await verifying( + publicKey, + alg, + decodeBase64Url(tokenParts[2]), + utf8Encoder.encode(headerPayload) + ) + if (!verified) { throw new JwtTokenSignatureMismatched(token) } diff --git a/deno_dist/utils/jwt/types.ts b/deno_dist/utils/jwt/types.ts index 2f052e9c1..738afd572 100644 --- a/deno_dist/utils/jwt/types.ts +++ b/deno_dist/utils/jwt/types.ts @@ -33,6 +33,13 @@ export class JwtTokenIssuedAt extends Error { } } +export class JwtHeaderInvalid extends Error { + constructor(header: object) { + super(`jwt header is invalid: ${JSON.stringify(header)}`) + this.name = 'JwtHeaderInvalid' + } +} + export class JwtTokenSignatureMismatched extends Error { constructor(token: string) { super(`token(${token}) signature mismatched`) @@ -40,8 +47,34 @@ export class JwtTokenSignatureMismatched extends Error { } } -export enum AlgorithmTypes { - HS256 = 'HS256', - HS384 = 'HS384', - HS512 = 'HS512', +export enum CryptoKeyUsage { + Encrypt = 'encrypt', + Decrypt = 'decrypt', + Sign = 'sign', + Verify = 'verify', + DeriveKey = 'deriveKey', + DeriveBits = 'deriveBits', + WrapKey = 'wrapKey', + UnwrapKey = 'unwrapKey', } + +/** + * JWT Payload + */ +export type JWTPayload = + | (unknown & {}) + | { + [key: string]: unknown + /** + * The token is checked to ensure it has not expired. + */ + exp?: number + /** + * The token is checked to ensure it is not being used before a specified time. + */ + nbf?: number + /** + * The token is checked to ensure it is not issued in the future. + */ + iat?: number + } diff --git a/deno_dist/utils/jwt/utf8.ts b/deno_dist/utils/jwt/utf8.ts new file mode 100644 index 000000000..107407aee --- /dev/null +++ b/deno_dist/utils/jwt/utf8.ts @@ -0,0 +1,2 @@ +export const utf8Encoder = new TextEncoder() +export const utf8Decoder = new TextDecoder() diff --git a/deno_dist/validator/validator.ts b/deno_dist/validator/validator.ts index bb3b01308..e4d8103b4 100644 --- a/deno_dist/validator/validator.ts +++ b/deno_dist/validator/validator.ts @@ -60,7 +60,6 @@ export const validator = < return async (c, next) => { let value = {} const contentType = c.req.header('Content-Type') - const bodyTypes = ['text', 'arrayBuffer', 'blob'] switch (target) { case 'json': @@ -68,26 +67,8 @@ export const validator = < const message = `Invalid HTTP header: Content-Type=${contentType}` throw new HTTPException(400, { message }) } - - if (c.req.bodyCache.json) { - value = await c.req.bodyCache.json - break - } - try { - let arrayBuffer: ArrayBuffer | undefined = undefined - for (const type of bodyTypes) { - // @ts-expect-error bodyCache[type] is not typed - const body = c.req.bodyCache[type] - if (body) { - arrayBuffer = await new Response(await body).arrayBuffer() - break - } - } - arrayBuffer ??= await c.req.raw.arrayBuffer() - value = await new Response(arrayBuffer).json() - c.req.bodyCache.json = value - c.req.bodyCache.arrayBuffer = arrayBuffer + value = await c.req.json() } catch { const message = 'Malformed JSON in request body' throw new HTTPException(400, { message }) @@ -104,16 +85,7 @@ export const validator = < } try { - let arrayBuffer: ArrayBuffer | undefined = undefined - for (const type of bodyTypes) { - // @ts-expect-error bodyCache[type] is not typed - const body = c.req.bodyCache[type] - if (body) { - arrayBuffer = await new Response(await body).arrayBuffer() - break - } - } - arrayBuffer ??= await c.req.arrayBuffer() + const arrayBuffer = await c.req.arrayBuffer() const formData = await bufferToFormData(arrayBuffer, contentType) const form: BodyData = {} formData.forEach((value, key) => { @@ -121,7 +93,6 @@ export const validator = < }) value = form c.req.bodyCache.formData = formData - c.req.bodyCache.arrayBuffer = arrayBuffer } catch (e) { let message = 'Malformed FormData request.' message += e instanceof Error ? ` ${e.message}` : ` ${String(e)}` diff --git a/package.json b/package.json index 4223b74b5..1d09d4aac 100644 --- a/package.json +++ b/package.json @@ -1,6 +1,6 @@ { "name": "hono", - "version": "4.1.7", + "version": "4.2.0-rc.1", "description": "Ultrafast web framework for the Edges", "main": "dist/cjs/index.js", "type": "module", @@ -114,6 +114,11 @@ "import": "./dist/middleware/etag/index.js", "require": "./dist/cjs/middleware/etag/index.js" }, + "./trailing-slash": { + "types": "./dist/types/middleware/trailing-slash/index.d.ts", + "import": "./dist/middleware/trailing-slash/index.js", + "require": "./dist/cjs/middleware/trailing-slash/index.js" + }, "./html": { "types": "./dist/types/helper/html/index.d.ts", "import": "./dist/helper/html/index.js", @@ -184,6 +189,11 @@ "import": "./dist/middleware/logger/index.js", "require": "./dist/cjs/middleware/logger/index.js" }, + "./method-override": { + "types": "./dist/types/middleware/method-override/index.d.ts", + "import": "./dist/middleware/method-override/index.js", + "require": "./dist/cjs/middleware/method-override/index.js" + }, "./powered-by": { "types": "./dist/types/middleware/powered-by/index.d.ts", "import": "./dist/middleware/powered-by/index.js", @@ -408,6 +418,9 @@ "logger": [ "./dist/types/middleware/logger" ], + "method-override": [ + "./dist/types/middleware/method-override" + ], "powered-by": [ "./dist/types/middleware/powered-by" ], diff --git a/src/helper/ssg/ssg.test.tsx b/src/helper/ssg/ssg.test.tsx index 654b4356a..e55cbd4a4 100644 --- a/src/helper/ssg/ssg.test.tsx +++ b/src/helper/ssg/ssg.test.tsx @@ -4,7 +4,7 @@ import { Hono } from '../../hono' import { jsx } from '../../jsx' import { poweredBy } from '../../middleware/powered-by' import { SSG_DISABLED_RESPONSE, ssgParams, isSSGContext, disableSSG, onlySSG } from './middleware' -import { fetchRoutesContent, saveContentToFile, toSSG } from './ssg' +import { fetchRoutesContent, saveContentToFile, toSSG, defaultExtensionMap } from './ssg' import type { BeforeRequestHook, AfterResponseHook, @@ -291,6 +291,9 @@ describe('fetchRoutesContent function', () => { beforeEach(() => { app = new Hono() app.get('/text', (c) => c.text('Text Response')) + app.get('/text-utf8', (c) => { + return c.text('Text Response', 200, { 'Content-Type': 'text/plain;charset=UTF-8' }) + }) app.get('/html', (c) => c.html('

HTML Response

')) app.get('/json', (c) => c.json({ message: 'JSON Response' })) app.use('*', poweredBy()) @@ -303,6 +306,10 @@ describe('fetchRoutesContent function', () => { content: 'Text Response', mimeType: 'text/plain', }) + expect(htmlMap.get('/text-utf8')).toEqual({ + content: 'Text Response', + mimeType: 'text/plain', + }) expect(htmlMap.get('/html')).toEqual({ content: '

HTML Response

', mimeType: 'text/html', @@ -458,6 +465,78 @@ describe('saveContentToFile function', () => { } expect(fsMock.mkdir).toHaveBeenCalledWith('static-check-extensions', { recursive: true }) }) + + it('should correctly create .yaml files for YAML content', async () => { + const yamlContent = 'title: YAML Example\nvalue: This is a YAML file.' + const mimeType = 'application/yaml' + const routePath = '/example' + + const yamlData = { + routePath: routePath, + content: yamlContent, + mimeType: mimeType, + } + + const fsMock: FileSystemModule = { + writeFile: vi.fn(() => Promise.resolve()), + mkdir: vi.fn(() => Promise.resolve()), + } + + await saveContentToFile(Promise.resolve(yamlData), fsMock, './static') + + expect(fsMock.writeFile).toHaveBeenCalledWith('static/example.yaml', yamlContent) + }) + + it('should correctly create .yml files for YAML content', async () => { + const yamlContent = 'title: YAML Example\nvalue: This is a YAML file.' + const yamlMimeType = 'application/yaml' + const yamlRoutePath = '/yaml' + + const yamlData = { + routePath: yamlRoutePath, + content: yamlContent, + mimeType: yamlMimeType, + } + + const yamlMimeType2 = 'x-yaml' + const yamlRoutePath2 = '/yaml2' + const yamlData2 = { + routePath: yamlRoutePath2, + content: yamlContent, + mimeType: yamlMimeType2, + } + + const htmlMimeType = 'text/html' + const htmlRoutePath = '/html' + + const htmlData = { + routePath: htmlRoutePath, + content: yamlContent, + mimeType: htmlMimeType, + } + + const fsMock: FileSystemModule = { + writeFile: vi.fn(() => Promise.resolve()), + mkdir: vi.fn(() => Promise.resolve()), + } + + const extensionMap = { + 'application/yaml': 'yml', + 'x-yaml': 'xyml', + } + await saveContentToFile(Promise.resolve(yamlData), fsMock, './static', extensionMap) + await saveContentToFile(Promise.resolve(yamlData2), fsMock, './static', extensionMap) + await saveContentToFile(Promise.resolve(htmlData), fsMock, './static', extensionMap) + await saveContentToFile(Promise.resolve(htmlData), fsMock, './static', { + ...defaultExtensionMap, + ...extensionMap, + }) + + expect(fsMock.writeFile).toHaveBeenCalledWith('static/yaml.yml', yamlContent) + expect(fsMock.writeFile).toHaveBeenCalledWith('static/yaml2.xyml', yamlContent) + expect(fsMock.writeFile).toHaveBeenCalledWith('static/html.htm', yamlContent) // extensionMap + expect(fsMock.writeFile).toHaveBeenCalledWith('static/html.html', yamlContent) // default + extensionMap + }) }) describe('Dynamic route handling', () => { diff --git a/src/helper/ssg/ssg.ts b/src/helper/ssg/ssg.ts index 17040e5f9..583a0ec6b 100644 --- a/src/helper/ssg/ssg.ts +++ b/src/helper/ssg/ssg.ts @@ -30,8 +30,13 @@ export interface ToSSGResult { error?: Error } -const generateFilePath = (routePath: string, outDir: string, mimeType: string) => { - const extension = determineExtension(mimeType) +const generateFilePath = ( + routePath: string, + outDir: string, + mimeType: string, + extensionMap?: Record +) => { + const extension = determineExtension(mimeType, extensionMap) if (routePath.endsWith(`.${extension}`)) { return joinPaths(outDir, routePath) @@ -62,17 +67,22 @@ const parseResponseContent = async (response: Response): Promise { - switch (mimeType) { - case 'text/html': - return 'html' - case 'text/xml': - case 'application/xml': - return 'xml' - default: { - return getExtension(mimeType) || 'html' - } +export const defaultExtensionMap: Record = { + 'text/html': 'html', + 'text/xml': 'xml', + 'application/xml': 'xml', + 'application/yaml': 'yaml', +} + +const determineExtension = ( + mimeType: string, + userExtensionMap?: Record +): string => { + const extensionMap = userExtensionMap || defaultExtensionMap + if (mimeType in extensionMap) { + return extensionMap[mimeType] } + return getExtension(mimeType) || 'html' } export type BeforeRequestHook = (req: Request) => Request | false | Promise @@ -85,6 +95,7 @@ export interface ToSSGOptions { afterResponseHook?: AfterResponseHook afterGenerateHook?: AfterGenerateHook concurrency?: number + extensionMap?: Record } /** @@ -204,14 +215,15 @@ const createdDirs: Set = new Set() export const saveContentToFile = async ( data: Promise<{ routePath: string; content: string | ArrayBuffer; mimeType: string } | undefined>, fsModule: FileSystemModule, - outDir: string + outDir: string, + extensionMap?: Record ): Promise => { const awaitedData = await data if (!awaitedData) { return } const { routePath, content, mimeType } = awaitedData - const filePath = generateFilePath(routePath, outDir, mimeType) + const filePath = generateFilePath(routePath, outDir, mimeType, extensionMap) const dirPath = dirname(filePath) if (!createdDirs.has(dirPath)) { diff --git a/src/jsx/dom/index.ts b/src/jsx/dom/index.ts index e81bf6efd..8b202a437 100644 --- a/src/jsx/dom/index.ts +++ b/src/jsx/dom/index.ts @@ -15,6 +15,7 @@ import { useMemo, useLayoutEffect, useReducer, + useId, useDebugValue, } from '../hooks' import { Suspense, ErrorBoundary } from './components' @@ -69,6 +70,7 @@ export { useMemo, useLayoutEffect, useReducer, + useId, useDebugValue, Suspense, ErrorBoundary, @@ -94,6 +96,7 @@ export default { useMemo, useLayoutEffect, useReducer, + useId, useDebugValue, Suspense, ErrorBoundary, diff --git a/src/jsx/hooks/dom.test.tsx b/src/jsx/hooks/dom.test.tsx index 86a5e4ed9..491d7514b 100644 --- a/src/jsx/hooks/dom.test.tsx +++ b/src/jsx/hooks/dom.test.tsx @@ -14,10 +14,11 @@ import { useDeferredValue, startViewTransition, useViewTransition, + useId, useDebugValue, } from '.' -describe('useReducer()', () => { +describe('Hooks', () => { beforeAll(() => { global.requestAnimationFrame = (cb) => setTimeout(cb) }) @@ -34,45 +35,97 @@ describe('useReducer()', () => { root = document.getElementById('root') as HTMLElement }) - it('simple', async () => { - const App = () => { - const [state, dispatch] = useReducer((state: number, action: number) => state + action, 0) - return ( -
- -
- ) - } - render(, root) - expect(root.innerHTML).toBe('
') - root.querySelector('button')?.click() - await Promise.resolve() - expect(root.innerHTML).toBe('
') + describe('useReducer()', () => { + it('simple', async () => { + const App = () => { + const [state, dispatch] = useReducer((state: number, action: number) => state + action, 0) + return ( +
+ +
+ ) + } + render(, root) + expect(root.innerHTML).toBe('
') + root.querySelector('button')?.click() + await Promise.resolve() + expect(root.innerHTML).toBe('
') + }) }) -}) -describe('startTransition()', () => { - beforeAll(() => { - global.requestAnimationFrame = (cb) => setTimeout(cb) - }) + describe('startTransition()', () => { + it('no error', async () => { + const App = () => { + const [count, setCount] = useState(0) + return ( + Loading...}> +
+ +
+
+ ) + } + render(, root) + expect(root.innerHTML).toBe('
') + root.querySelector('button')?.click() + await Promise.resolve() + expect(root.innerHTML).toBe('
') + }) - let dom: JSDOM - let root: HTMLElement - beforeEach(() => { - dom = new JSDOM('
', { - runScripts: 'dangerously', + it('got an error', async () => { + let resolve: () => void + const promise = new Promise((r) => (resolve = r)) + + const Counter = ({ count }: { count: number }) => { + use(promise) + return
{count}
+ } + + const App = () => { + const [count, setCount] = useState(0) + return ( + Loading...}> +
+ +
+
+ ) + } + render(, root) + expect(root.innerHTML).toBe('
') + root.querySelector('button')?.click() + expect(root.innerHTML).toBe('
') + resolve!() + await new Promise((r) => setTimeout(r)) + expect(root.innerHTML).toBe('
') }) - global.document = dom.window.document - global.HTMLElement = dom.window.HTMLElement - global.Text = dom.window.Text - root = document.getElementById('root') as HTMLElement }) - it('no error', async () => { - const App = () => { - const [count, setCount] = useState(0) - return ( - Loading...}> + describe('useTransition()', () => { + it('pending', async () => { + let called = 0 + const App = () => { + const [count, setCount] = useState(0) + const [isPending, startTransition] = useTransition() + called++ + + return (
-
- ) - } - render(, root) - expect(root.innerHTML).toBe('
') - root.querySelector('button')?.click() - await Promise.resolve() - expect(root.innerHTML).toBe('
') - }) - - it('got an error', async () => { - let resolve: () => void - const promise = new Promise((r) => (resolve = r)) + ) + } + render(, root) + expect(root.innerHTML).toBe('
') + root.querySelector('button')?.click() + await Promise.resolve() + expect(root.innerHTML).toBe('
') + expect(called).toBe(2) + await new Promise((r) => setTimeout(r)) + await new Promise((r) => setTimeout(r)) + expect(root.innerHTML).toBe('
') + expect(called).toBe(3) + }) - const Counter = ({ count }: { count: number }) => { - use(promise) - return
{count}
- } + it('multiple setState at once', async () => { + let called = 0 + const App = () => { + const [count1, setCount1] = useState(0) + const [count2, setCount2] = useState(0) + const [isPending, startTransition] = useTransition() + called++ - const App = () => { - const [count, setCount] = useState(0) - return ( - Loading...}> + return (
-
- ) - } - render(, root) - expect(root.innerHTML).toBe('
') - root.querySelector('button')?.click() - expect(root.innerHTML).toBe('
') - resolve!() - await new Promise((r) => setTimeout(r)) - expect(root.innerHTML).toBe('
') - }) -}) - -describe('useTransition()', () => { - let dom: JSDOM - let root: HTMLElement - beforeEach(() => { - dom = new JSDOM('
', { - runScripts: 'dangerously', + ) + } + render(, root) + expect(root.innerHTML).toBe('
') + root.querySelector('button')?.click() + await Promise.resolve() + expect(root.innerHTML).toBe('
') + expect(called).toBe(2) + await new Promise((r) => setTimeout(r)) + await new Promise((r) => setTimeout(r)) + expect(root.innerHTML).toBe('
') + expect(called).toBe(3) }) - global.document = dom.window.document - global.HTMLElement = dom.window.HTMLElement - global.requestAnimationFrame = (cb: Function) => setTimeout(cb) - global.Text = dom.window.Text - root = document.getElementById('root') as HTMLElement - }) - - it('pending', async () => { - let called = 0 - const App = () => { - const [count, setCount] = useState(0) - const [isPending, startTransition] = useTransition() - called++ - - return ( -
- -
- ) - } - render(, root) - expect(root.innerHTML).toBe('
') - root.querySelector('button')?.click() - await Promise.resolve() - expect(root.innerHTML).toBe('
') - expect(called).toBe(2) - await new Promise((r) => setTimeout(r)) - await new Promise((r) => setTimeout(r)) - expect(root.innerHTML).toBe('
') - expect(called).toBe(3) - }) - it('multiple setState at once', async () => { - let called = 0 - const App = () => { - const [count1, setCount1] = useState(0) - const [count2, setCount2] = useState(0) - const [isPending, startTransition] = useTransition() - called++ + it('multiple startTransaction at once', async () => { + let called = 0 + const App = () => { + const [count1, setCount1] = useState(0) + const [count2, setCount2] = useState(0) + const [isPending, startTransition] = useTransition() + called++ - return ( -
- -
- ) - } - render(, root) - expect(root.innerHTML).toBe('
') - root.querySelector('button')?.click() - await Promise.resolve() - expect(root.innerHTML).toBe('
') - expect(called).toBe(2) - await new Promise((r) => setTimeout(r)) - await new Promise((r) => setTimeout(r)) - expect(root.innerHTML).toBe('
') - expect(called).toBe(3) + return ( +
+ +
+ ) + } + render(, root) + expect(root.innerHTML).toBe('
') + expect(called).toBe(1) + root.querySelector('button')?.click() + await Promise.resolve() + expect(root.innerHTML).toBe('
') + expect(called).toBe(2) + await new Promise((r) => setTimeout(r)) + await new Promise((r) => setTimeout(r)) + expect(root.innerHTML).toBe('
') + expect(called).toBe(3) // + isPending=true + isPending=false + }) }) - it('multiple startTransaction at once', async () => { - let called = 0 - const App = () => { - const [count1, setCount1] = useState(0) - const [count2, setCount2] = useState(0) - const [isPending, startTransition] = useTransition() - called++ + describe('useDeferredValue()', () => { + it('deferred', async () => { + const promiseMap = {} as Record> + const getPromise = (count: number) => { + return (promiseMap[count] ||= new Promise((r) => setTimeout(() => r(count + 1)))) + } + const ShowCount = ({ count }: { count: number }) => { + if (count === 0) { + return
0
+ } - return ( -
- -
- ) - } - render(, root) - expect(root.innerHTML).toBe('
') - expect(called).toBe(1) - root.querySelector('button')?.click() - await Promise.resolve() - expect(root.innerHTML).toBe('
') - expect(called).toBe(2) - await new Promise((r) => setTimeout(r)) - await new Promise((r) => setTimeout(r)) - expect(root.innerHTML).toBe('
') - expect(called).toBe(3) // + isPending=true + isPending=false - }) -}) + const c = use(getPromise(count)) + return
{c}
+ } -describe('useDeferredValue()', () => { - let dom: JSDOM - let root: HTMLElement - beforeEach(() => { - dom = new JSDOM('
', { - runScripts: 'dangerously', + const App = () => { + const [count, setCount] = useState(0) + const c = useDeferredValue(count) + return ( + <> +
+ +
+ Loading...}> + + + + ) + } + render(, root) + expect(root.innerHTML).toBe('
0
') + root.querySelector('button')?.click() + await Promise.resolve() + expect(root.innerHTML).toBe('
0
') + await new Promise((r) => setTimeout(r)) + expect(root.innerHTML).toBe('
2
') }) - global.document = dom.window.document - global.HTMLElement = dom.window.HTMLElement - global.Text = dom.window.Text - root = document.getElementById('root') as HTMLElement }) - it('deferred', async () => { - const promiseMap = {} as Record> - const getPromise = (count: number) => { - return (promiseMap[count] ||= new Promise((r) => setTimeout(() => r(count + 1)))) - } - const ShowCount = ({ count }: { count: number }) => { - if (count === 0) { - return
0
- } + describe('startViewTransition()', () => { + afterEach(() => { + // eslint-disable-next-line @typescript-eslint/no-explicit-any + delete (dom.window.document as any).startViewTransition + }) - const c = use(getPromise(count)) - return
{c}
- } + it('supported browser', async () => { + // eslint-disable-next-line @typescript-eslint/no-explicit-any + ;(dom.window.document as any).startViewTransition = vi.fn((cb: Function) => { + Promise.resolve().then(() => cb()) + return { finished: Promise.resolve() } + }) - const App = () => { - const [count, setCount] = useState(0) - const c = useDeferredValue(count) - return ( - <> -
- -
+ const App = () => { + const [count, setCount] = useState(0) + return ( Loading...}> - +
+ +
- - ) - } - render(, root) - expect(root.innerHTML).toBe('
0
') - root.querySelector('button')?.click() - await Promise.resolve() - expect(root.innerHTML).toBe('
0
') - await new Promise((r) => setTimeout(r)) - expect(root.innerHTML).toBe('
2
') - }) -}) - -describe('startViewTransition()', () => { - let dom: JSDOM - let root: HTMLElement - beforeEach(() => { - dom = new JSDOM('
', { - runScripts: 'dangerously', + ) + } + render(, root) + expect(root.innerHTML).toBe('
') + root.querySelector('button')?.click() + await Promise.resolve() + expect(root.innerHTML).toBe('
') + await Promise.resolve() // updated in microtask + expect(root.innerHTML).toBe('
') }) - global.document = dom.window.document - global.HTMLElement = dom.window.HTMLElement - global.Text = dom.window.Text - root = document.getElementById('root') as HTMLElement - }) - afterEach(() => { - // eslint-disable-next-line @typescript-eslint/no-explicit-any - delete (dom.window.document as any).startViewTransition - }) - it('supported browser', async () => { - // eslint-disable-next-line @typescript-eslint/no-explicit-any - ;(dom.window.document as any).startViewTransition = vi.fn((cb: Function) => { - Promise.resolve().then(() => cb()) - return { finished: Promise.resolve() } + it('unsupported browser', async () => { + const App = () => { + const [count, setCount] = useState(0) + return ( + Loading...}> +
+ +
+
+ ) + } + render(, root) + expect(root.innerHTML).toBe('
') + root.querySelector('button')?.click() + await Promise.resolve() + expect(root.innerHTML).toBe('
') }) - const App = () => { - const [count, setCount] = useState(0) - return ( - Loading...}> + it('with useTransition()', async () => { + // eslint-disable-next-line @typescript-eslint/no-explicit-any + ;(dom.window.document as any).startViewTransition = vi.fn((cb: Function) => { + Promise.resolve().then(() => cb()) + return { finished: Promise.resolve() } + }) + + let called = 0 + const App = () => { + const [count, setCount] = useState(0) + const [isPending, startTransition] = useTransition() + called++ + + return (
-
- ) - } - render(, root) - expect(root.innerHTML).toBe('
') - root.querySelector('button')?.click() - await Promise.resolve() - expect(root.innerHTML).toBe('
') - await Promise.resolve() // updated in microtask - expect(root.innerHTML).toBe('
') + ) + } + render(, root) + expect(root.innerHTML).toBe('
') + root.querySelector('button')?.click() + await new Promise((r) => setTimeout(r)) + expect(root.innerHTML).toBe('
') + expect(called).toBe(2) + await new Promise((r) => setTimeout(r)) + await new Promise((r) => setTimeout(r)) + expect(root.innerHTML).toBe('
') + expect(called).toBe(3) + }) }) - it('unsupported browser', async () => { - const App = () => { - const [count, setCount] = useState(0) - return ( - Loading...}> + describe('useViewTransition()', () => { + afterEach(() => { + // eslint-disable-next-line @typescript-eslint/no-explicit-any + delete (dom.window.document as any).startViewTransition + }) + + it('supported browser', async () => { + let resolved: (() => void) | undefined + const promise = new Promise((r) => (resolved = r)) + let called = 0 + // eslint-disable-next-line @typescript-eslint/no-explicit-any + ;(global.document as any).startViewTransition = vi.fn((cb: Function) => { + Promise.resolve().then(() => cb()) + return { finished: promise } + }) + + const App = () => { + const [count, setCount] = useState(0) + const [isUpdating, startViewTransition] = useViewTransition() + called++ + + return (
-
- ) - } - render(, root) - expect(root.innerHTML).toBe('
') - root.querySelector('button')?.click() - await Promise.resolve() - expect(root.innerHTML).toBe('
') - }) - - it('with useTransition()', async () => { - // eslint-disable-next-line @typescript-eslint/no-explicit-any - ;(dom.window.document as any).startViewTransition = vi.fn((cb: Function) => { - Promise.resolve().then(() => cb()) - return { finished: Promise.resolve() } + ) + } + render(, root) + expect(root.innerHTML).toBe('
') + root.querySelector('button')?.click() + await new Promise((r) => setTimeout(r)) + expect(root.innerHTML).toBe('
') + expect(called).toBe(2) + resolved?.() + await new Promise((r) => setTimeout(r)) + expect(root.innerHTML).toBe('
') + expect(called).toBe(3) }) - - let called = 0 - const App = () => { - const [count, setCount] = useState(0) - const [isPending, startTransition] = useTransition() - called++ - - return ( -
- -
- ) - } - render(, root) - expect(root.innerHTML).toBe('
') - root.querySelector('button')?.click() - await new Promise((r) => setTimeout(r)) - expect(root.innerHTML).toBe('
') - expect(called).toBe(2) - await new Promise((r) => setTimeout(r)) - await new Promise((r) => setTimeout(r)) - expect(root.innerHTML).toBe('
') - expect(called).toBe(3) }) -}) -describe('useViewTransition()', () => { - let dom: JSDOM - let root: HTMLElement - beforeEach(() => { - dom = new JSDOM('
', { - runScripts: 'dangerously', + describe('useId()', () => { + let dom: JSDOM + let root: HTMLElement + beforeEach(() => { + dom = new JSDOM('
', { + runScripts: 'dangerously', + }) + global.document = dom.window.document + global.HTMLElement = dom.window.HTMLElement + global.Text = dom.window.Text + root = document.getElementById('root') as HTMLElement }) - global.document = dom.window.document - global.HTMLElement = dom.window.HTMLElement - global.Text = dom.window.Text - root = document.getElementById('root') as HTMLElement - }) - afterEach(() => { - // eslint-disable-next-line @typescript-eslint/no-explicit-any - delete (dom.window.document as any).startViewTransition - }) - it('supported browser', async () => { - let resolved: (() => void) | undefined - const promise = new Promise((r) => (resolved = r)) - let called = 0 - // eslint-disable-next-line @typescript-eslint/no-explicit-any - ;(global.document as any).startViewTransition = vi.fn((cb: Function) => { - Promise.resolve().then(() => cb()) - return { finished: promise } + it('simple', () => { + const App = () => { + const id = useId() + return
+ } + render(, root) + expect(root.innerHTML).toBe('
') }) - const App = () => { - const [count, setCount] = useState(0) - const [isUpdating, startViewTransition] = useViewTransition() - called++ - - return ( -
- -
- ) - } - render(, root) - expect(root.innerHTML).toBe('
') - root.querySelector('button')?.click() - await new Promise((r) => setTimeout(r)) - expect(root.innerHTML).toBe('
') - expect(called).toBe(2) - resolved?.() - await new Promise((r) => setTimeout(r)) - expect(root.innerHTML).toBe('
') - expect(called).toBe(3) - }) -}) - -describe('useDebugValue()', () => { - let dom: JSDOM - let root: HTMLElement - beforeEach(() => { - dom = new JSDOM('
', { - runScripts: 'dangerously', + it('memoized', async () => { + let setCount: (c: number) => void = () => {} + const App = () => { + const id = useId() + const [count, _setCount] = useState(0) + setCount = _setCount + return
{count}
+ } + render(, root) + expect(root.innerHTML).toBe('
0
') + setCount(1) + await Promise.resolve() + expect(root.innerHTML).toBe('
1
') }) - global.document = dom.window.document - global.HTMLElement = dom.window.HTMLElement - global.Text = dom.window.Text - root = document.getElementById('root') as HTMLElement }) - it('simple', () => { - const spy = vi.fn() - const App = () => { - useDebugValue('hello', spy) - return
- } - render(, root) - expect(root.innerHTML).toBe('
') - expect(spy).not.toBeCalled() + describe('useDebugValue()', () => { + it('simple', () => { + const spy = vi.fn() + const App = () => { + useDebugValue('hello', spy) + return
+ } + render(, root) + expect(root.innerHTML).toBe('
') + expect(spy).not.toBeCalled() + }) }) }) diff --git a/src/jsx/hooks/index.ts b/src/jsx/hooks/index.ts index 6254bc7a0..dfd069e02 100644 --- a/src/jsx/hooks/index.ts +++ b/src/jsx/hooks/index.ts @@ -349,6 +349,9 @@ export const useMemo = (factory: () => T, deps: readonly unknown[]): T => { return memoArray[hookIndex][0] as T } +let idCounter = 0 +export const useId = (): string => useMemo(() => `:r${(idCounter++).toString(32)}:`, []) + // Define to avoid errors. This hook currently does nothing. // eslint-disable-next-line @typescript-eslint/no-unused-vars export const useDebugValue = (_value: unknown, _formatter?: (value: unknown) => string): void => {} diff --git a/src/jsx/index.ts b/src/jsx/index.ts index 56d827ff0..5d286b7e8 100644 --- a/src/jsx/index.ts +++ b/src/jsx/index.ts @@ -15,6 +15,7 @@ import { useMemo, useLayoutEffect, useReducer, + useId, useDebugValue, } from './hooks' import { Suspense } from './streaming' @@ -34,6 +35,7 @@ export { useRef, useCallback, useReducer, + useId, useDebugValue, use, startTransition, @@ -60,6 +62,7 @@ export default { useRef, useCallback, useReducer, + useId, useDebugValue, use, startTransition, diff --git a/src/middleware.ts b/src/middleware.ts index 57ba64783..65954a4ed 100644 --- a/src/middleware.ts +++ b/src/middleware.ts @@ -11,8 +11,10 @@ export * from './jsx' export * from './middleware/jsx-renderer' export { jwt } from './middleware/jwt' export * from './middleware/logger' +export * from './middleware/method-override' export * from './middleware/powered-by' export * from './middleware/timing' export * from './middleware/pretty-json' export * from './middleware/secure-headers' +export * from './middleware/trailing-slash' export * from './adapter/deno/serve-static' diff --git a/src/middleware/basic-auth/index.test.ts b/src/middleware/basic-auth/index.test.ts index 354861445..147b3c95a 100644 --- a/src/middleware/basic-auth/index.test.ts +++ b/src/middleware/basic-auth/index.test.ts @@ -70,6 +70,19 @@ describe('Basic Auth by Middleware', () => { return auth(c, next) }) + app.use('/verify-user/*', async (c, next) => { + const auth = basicAuth({ + verifyUser: (username, password, c) => { + return ( + c.req.path === '/verify-user' && + username === 'dynamic-user' && + password === 'hono-password' + ) + }, + }) + return auth(c, next) + }) + app.get('/auth/*', (c) => { handlerExecuted = true return c.text('auth') @@ -92,6 +105,11 @@ describe('Basic Auth by Middleware', () => { return c.text('nested') }) + app.get('/verify-user', (c) => { + handlerExecuted = true + return c.text('verify-user') + }) + it('Should not authorize', async () => { const req = new Request('http://localhost/auth/a') const res = await app.request(req) @@ -184,4 +202,28 @@ describe('Basic Auth by Middleware', () => { expect(res.status).toBe(401) expect(await res.text()).toBe('Unauthorized') }) + + it('Should authorize - verifyUser', async () => { + const credential = Buffer.from('dynamic-user' + ':' + 'hono-password').toString('base64') + + const req = new Request('http://localhost/verify-user') + req.headers.set('Authorization', `Basic ${credential}`) + const res = await app.request(req) + expect(handlerExecuted).toBeTruthy() + expect(res).not.toBeNull() + expect(res.status).toBe(200) + expect(await res.text()).toBe('verify-user') + }) + + it('Should not authorize - verifyUser', async () => { + const credential = Buffer.from('foo' + ':' + 'bar').toString('base64') + + const req = new Request('http://localhost/verify-user') + req.headers.set('Authorization', `Basic ${credential}`) + const res = await app.request(req) + expect(handlerExecuted).toBeFalsy() + expect(res).not.toBeNull() + expect(res.status).toBe(401) + expect(await res.text()).toBe('Unauthorized') + }) }) diff --git a/src/middleware/basic-auth/index.ts b/src/middleware/basic-auth/index.ts index 5548dad7c..a830219c1 100644 --- a/src/middleware/basic-auth/index.ts +++ b/src/middleware/basic-auth/index.ts @@ -1,3 +1,4 @@ +import type { Context } from '../../context' import { HTTPException } from '../../http-exception' import type { HonoRequest } from '../../request' import type { MiddlewareHandler } from '../../types' @@ -26,31 +27,59 @@ const auth = (req: HonoRequest) => { return { username: userPass[1], password: userPass[2] } } +type BasicAuthOptions = + | { + username: string + password: string + realm?: string + hashFunction?: Function + } + | { + verifyUser: (username: string, password: string, c: Context) => boolean | Promise + realm?: string + hashFunction?: Function + } + export const basicAuth = ( - options: { username: string; password: string; realm?: string; hashFunction?: Function }, + options: BasicAuthOptions, ...users: { username: string; password: string }[] ): MiddlewareHandler => { - if (!options) { - throw new Error('basic auth middleware requires options for "username and password"') + const usernamePasswordInOptions = 'username' in options && 'password' in options + const verifyUserInOptions = 'verifyUser' in options + + if (!(usernamePasswordInOptions || verifyUserInOptions)) { + throw new Error( + 'basic auth middleware requires options for "username and password" or "verifyUser"' + ) } if (!options.realm) { options.realm = 'Secure Area' } - users.unshift({ username: options.username, password: options.password }) + + if (usernamePasswordInOptions) { + users.unshift({ username: options.username, password: options.password }) + } return async function basicAuth(ctx, next) { const requestUser = auth(ctx.req) if (requestUser) { - for (const user of users) { - const [usernameEqual, passwordEqual] = await Promise.all([ - timingSafeEqual(user.username, requestUser.username, options.hashFunction), - timingSafeEqual(user.password, requestUser.password, options.hashFunction), - ]) - if (usernameEqual && passwordEqual) { + if (verifyUserInOptions) { + if (await options.verifyUser(requestUser.username, requestUser.password, ctx)) { await next() return } + } else { + for (const user of users) { + const [usernameEqual, passwordEqual] = await Promise.all([ + timingSafeEqual(user.username, requestUser.username, options.hashFunction), + timingSafeEqual(user.password, requestUser.password, options.hashFunction), + ]) + if (usernameEqual && passwordEqual) { + await next() + return + } + } } } const res = new Response('Unauthorized', { diff --git a/src/middleware/bearer-auth/index.test.ts b/src/middleware/bearer-auth/index.test.ts index 08f1b3504..81dea1aa0 100644 --- a/src/middleware/bearer-auth/index.test.ts +++ b/src/middleware/bearer-auth/index.test.ts @@ -43,6 +43,19 @@ describe('Bearer Auth by Middleware', () => { handlerExecuted = true return c.text('auths') }) + + app.use( + '/auth-verify-token/*', + bearerAuth({ + verifyToken: async (token, c) => { + return c.req.path === '/auth-verify-token' && token === 'dynamic-token' + }, + }) + ) + app.get('/auth-verify-token/*', (c) => { + handlerExecuted = true + return c.text('auth-verify-token') + }) }) it('Should authorize', async () => { @@ -144,4 +157,23 @@ describe('Bearer Auth by Middleware', () => { expect(handlerExecuted).toBeTruthy() expect(await res2.text()).toBe('auths') }) + + it('Should authorize - verifyToken option', async () => { + const res = await app.request('/auth-verify-token', { + headers: { Authorization: 'Bearer dynamic-token' }, + }) + expect(res).not.toBeNull() + expect(res.status).toBe(200) + expect(handlerExecuted).toBeTruthy() + expect(await res.text()).toBe('auth-verify-token') + }) + + it('Should not authorize - verifyToken option', async () => { + const res = await app.request('/auth-verify-token', { + headers: { Authorization: 'Bearer invalid-token' }, + }) + expect(res).not.toBeNull() + expect(handlerExecuted).toBeFalsy() + expect(res.status).toBe(401) + }) }) diff --git a/src/middleware/bearer-auth/index.ts b/src/middleware/bearer-auth/index.ts index 6037ea41d..97bfecaf0 100644 --- a/src/middleware/bearer-auth/index.ts +++ b/src/middleware/bearer-auth/index.ts @@ -1,3 +1,4 @@ +import type { Context } from '../../context' import { HTTPException } from '../../http-exception' import type { MiddlewareHandler } from '../../types' import { timingSafeEqual } from '../../utils/buffer' @@ -5,13 +6,22 @@ import { timingSafeEqual } from '../../utils/buffer' const TOKEN_STRINGS = '[A-Za-z0-9._~+/-]+=*' const PREFIX = 'Bearer' -export const bearerAuth = (options: { - token: string | string[] - realm?: string - prefix?: string - hashFunction?: Function -}): MiddlewareHandler => { - if (!options.token) { +type BearerAuthOptions = + | { + token: string | string[] + realm?: string + prefix?: string + hashFunction?: Function + } + | { + realm?: string + prefix?: string + verifyToken: (token: string, c: Context) => boolean | Promise + hashFunction?: Function + } + +export const bearerAuth = (options: BearerAuthOptions): MiddlewareHandler => { + if (!('token' in options || 'verifyToken' in options)) { throw new Error('bearer auth middleware requires options for "token"') } if (!options.realm) { @@ -49,7 +59,9 @@ export const bearerAuth = (options: { throw new HTTPException(400, { res }) } else { let equal = false - if (typeof options.token === 'string') { + if ('verifyToken' in options) { + equal = await options.verifyToken(match[1], c) + } else if (typeof options.token === 'string') { equal = await timingSafeEqual(options.token, match[1], options.hashFunction) } else if (Array.isArray(options.token) && options.token.length > 0) { for (const token of options.token) { diff --git a/src/middleware/cache/index.test.ts b/src/middleware/cache/index.test.ts index d9c4ce174..1e0750fcc 100644 --- a/src/middleware/cache/index.test.ts +++ b/src/middleware/cache/index.test.ts @@ -51,7 +51,59 @@ describe('Cache Middleware', () => { return c.text('cached') }) - app.use('/not-found/*', cache({ cacheName: 'my-app-v1', wait: true, cacheControl: 'max-age=10' })) + app.use('/vary1/*', cache({ cacheName: 'my-app-v1', wait: true, vary: ['Accept'] })) + app.get('/vary1/', (c) => { + return c.text('cached') + }) + + app.use('/vary2/*', cache({ cacheName: 'my-app-v1', wait: true, vary: ['Accept'] })) + app.get('/vary2/', (c) => { + c.header('Vary', 'Accept-Encoding') + return c.text('cached') + }) + + app.use( + '/vary3/*', + cache({ cacheName: 'my-app-v1', wait: true, vary: ['Accept', 'Accept-Encoding'] }) + ) + app.get('/vary3/', (c) => { + c.header('Vary', 'Accept-Language') + return c.text('cached') + }) + + app.use( + '/vary4/*', + cache({ cacheName: 'my-app-v1', wait: true, vary: ['Accept', 'Accept-Encoding'] }) + ) + app.get('/vary4/', (c) => { + c.header('Vary', 'Accept, Accept-Language') + return c.text('cached') + }) + + app.use('/vary5/*', cache({ cacheName: 'my-app-v1', wait: true, vary: 'Accept' })) + app.get('/vary5/', (c) => { + return c.text('cached with Accept and Accept-Encoding headers') + }) + + app.use( + '/vary6/*', + cache({ cacheName: 'my-app-v1', wait: true, vary: 'Accept, Accept-Encoding' }) + ) + app.get('/vary6/', (c) => { + c.header('Vary', 'Accept, Accept-Language') + return c.text('cached with Accept and Accept-Encoding headers as array') + }) + + app.use('/vary7/*', cache({ cacheName: 'my-app-v1', wait: true, vary: ['Accept'] })) + app.get('/vary7/', (c) => { + c.header('Vary', '*') + return c.text('cached') + }) + + app.use( + '/not-found/*', + cache({ cacheName: 'my-app-v1', wait: true, cacheControl: 'max-age=10', vary: ['Accept'] }) + ) const ctx = new Context() @@ -93,11 +145,62 @@ describe('Cache Middleware', () => { expect(res.headers.get('cache-control')).toBe('private, max-age=10') }) + it('Should correctly apply a single Vary header from middleware', async () => { + const res = await app.request('http://localhost/vary1/') + expect(res).not.toBeNull() + expect(res.status).toBe(200) + expect(res.headers.get('vary')).toBe('accept') + }) + + it('Should merge Vary headers from middleware and handler without duplicating', async () => { + const res = await app.request('http://localhost/vary2/') + expect(res).not.toBeNull() + expect(res.status).toBe(200) + expect(res.headers.get('vary')).toBe('accept, accept-encoding') + }) + + it('Should deduplicate while merging multiple Vary headers from middleware and handler', async () => { + const res = await app.request('http://localhost/vary3/') + expect(res.headers.get('vary')).toBe('accept, accept-encoding, accept-language') + }) + + it('Should prevent duplication of Vary headers when identical ones are set by both middleware and handler', async () => { + const res = await app.request('http://localhost/vary4/') + expect(res.headers.get('vary')).toBe('accept, accept-encoding, accept-language') + }) + + it('Should correctly apply and return a single Vary header with Accept specified by middleware', async () => { + const res = await app.request('http://localhost/vary5/') + expect(res).not.toBeNull() + expect(res.status).toBe(200) + expect(res.headers.get('vary')).toBe('accept') + }) + + it('Should merge Vary headers specified by middleware as a string with additional headers added by handler', async () => { + const res = await app.request('http://localhost/vary6/') + expect(res).not.toBeNull() + expect(res.status).toBe(200) + expect(res.headers.get('vary')).toBe('accept, accept-encoding, accept-language') + }) + + it('Should prioritize the "*" Vary header from handler over any set by middleware', async () => { + const res = await app.request('http://localhost/vary7/') + expect(res).not.toBeNull() + expect(res.status).toBe(200) + expect(res.headers.get('vary')).toBe('*') + }) + + it('Should not allow "*" as a Vary header in middleware configuration due to its impact on caching effectiveness', async () => { + expect(() => cache({ cacheName: 'my-app-v1', wait: true, vary: ['*'] })).toThrow() + expect(() => cache({ cacheName: 'my-app-v1', wait: true, vary: '*' })).toThrow() + }) + it('Should not cache if it is not found', async () => { const res = await app.request('/not-found/') expect(res).not.toBeNull() expect(res.status).toBe(404) expect(res.headers.get('cache-control')).toBeFalsy() + expect(res.headers.get('vary')).toBeFalsy() }) it('Should not be enabled if caches is not defined', async () => { diff --git a/src/middleware/cache/index.ts b/src/middleware/cache/index.ts index accff4a60..2652e3cce 100644 --- a/src/middleware/cache/index.ts +++ b/src/middleware/cache/index.ts @@ -5,6 +5,7 @@ export const cache = (options: { cacheName: string wait?: boolean cacheControl?: string + vary?: string | string[] }): MiddlewareHandler => { if (!globalThis.caches) { console.log('Cache Middleware is not enabled because caches is not defined.') @@ -15,16 +16,28 @@ export const cache = (options: { options.wait = false } - const directives = options.cacheControl?.split(',').map((directive) => directive.toLowerCase()) + const cacheControlDirectives = options.cacheControl + ?.split(',') + .map((directive) => directive.toLowerCase()) + const varyDirectives = Array.isArray(options.vary) + ? options.vary + : options.vary?.split(',').map((directive) => directive.trim()) + // RFC 7231 Section 7.1.4 specifies that "*" is not allowed in Vary header. + // See: https://datatracker.ietf.org/doc/html/rfc7231#section-7.1.4 + if (options.vary?.includes('*')) { + throw new Error( + 'Middleware vary configuration cannot include "*", as it disallows effective caching.' + ) + } const addHeader = (c: Context) => { - if (directives) { + if (cacheControlDirectives) { const existingDirectives = c.res.headers .get('Cache-Control') ?.split(',') .map((d) => d.trim().split('=', 1)[0]) ?? [] - for (const directive of directives) { + for (const directive of cacheControlDirectives) { let [name, value] = directive.trim().split('=', 2) name = name.toLowerCase() if (!existingDirectives.includes(name)) { @@ -32,6 +45,26 @@ export const cache = (options: { } } } + + if (varyDirectives) { + const existingDirectives = + c.res.headers + .get('Vary') + ?.split(',') + .map((d) => d.trim()) ?? [] + + const vary = Array.from( + new Set( + [...existingDirectives, ...varyDirectives].map((directive) => directive.toLowerCase()) + ) + ).sort() + + if (vary.includes('*')) { + c.header('Vary', '*') + } else { + c.header('Vary', vary.join(', ')) + } + } } return async function cache(c, next) { diff --git a/src/middleware/cors/index.ts b/src/middleware/cors/index.ts index 716058f8c..415f69d03 100644 --- a/src/middleware/cors/index.ts +++ b/src/middleware/cors/index.ts @@ -1,7 +1,8 @@ +import type { Context } from '../../context' import type { MiddlewareHandler } from '../../types' type CORSOptions = { - origin: string | string[] | ((origin: string) => string | undefined | null) + origin: string | string[] | ((origin: string, c: Context) => string | undefined | null) allowMethods?: string[] allowHeaders?: string[] maxAge?: number @@ -36,7 +37,7 @@ export const cors = (options?: CORSOptions): MiddlewareHandler => { c.res.headers.set(key, value) } - const allowOrigin = findAllowOrigin(c.req.header('origin') || '') + const allowOrigin = findAllowOrigin(c.req.header('origin') || '', c) if (allowOrigin) { set('Access-Control-Allow-Origin', allowOrigin) } diff --git a/src/middleware/jwt/index.test.ts b/src/middleware/jwt/index.test.ts index 32e9f10ef..9f2130c5d 100644 --- a/src/middleware/jwt/index.test.ts +++ b/src/middleware/jwt/index.test.ts @@ -1,4 +1,5 @@ import { Hono } from '../../hono' +import { HTTPException } from '../../http-exception' import { jwt } from '.' describe('JWT', () => { @@ -214,4 +215,30 @@ describe('JWT', () => { expect(handlerExecuted).toBeFalsy() }) }) + + describe('Error handling with `cause`', () => { + const app = new Hono() + + app.use('/auth/*', jwt({ secret: 'a-secret' })) + app.get('/auth/*', (c) => c.text('Authorized')) + + app.onError((e, c) => { + if (e instanceof HTTPException && e.cause instanceof Error) { + return c.json({ name: e.cause.name, message: e.cause.message }, 401) + } + return c.text(e.message, 401) + }) + + it('Should not authorize', async () => { + const credential = 'abc.def.ghi' + const req = new Request('http://localhost/auth') + req.headers.set('Authorization', `Bearer ${credential}`) + const res = await app.request(req) + expect(res.status).toBe(401) + expect(await res.json()).toEqual({ + name: 'JwtTokenInvalid', + message: `invalid JWT token: ${credential}`, + }) + }) + }) }) diff --git a/src/middleware/jwt/index.ts b/src/middleware/jwt/index.ts index 439405add..28a082a30 100644 --- a/src/middleware/jwt/index.ts +++ b/src/middleware/jwt/index.ts @@ -3,8 +3,8 @@ import { getCookie } from '../../helper/cookie' import { HTTPException } from '../../http-exception' import type { MiddlewareHandler } from '../../types' import { Jwt } from '../../utils/jwt' -import type { AlgorithmTypes } from '../../utils/jwt/types' import '../../context' +import type { SignatureAlgorithm } from '../../utils/jwt/jwa' declare module '../../context' { interface ContextVariableMap { @@ -16,7 +16,7 @@ declare module '../../context' { export const jwt = (options: { secret: string cookie?: string - alg?: string + alg?: SignatureAlgorithm }): MiddlewareHandler => { if (!options) { throw new Error('JWT auth middleware requires options for "secret') @@ -32,11 +32,13 @@ export const jwt = (options: { if (credentials) { const parts = credentials.split(/\s+/) if (parts.length !== 2) { + const errDescription = 'invalid credentials structure' throw new HTTPException(401, { + message: errDescription, res: unauthorizedResponse({ ctx, error: 'invalid_request', - errDescription: 'invalid credentials structure', + errDescription, }), }) } else { @@ -47,30 +49,34 @@ export const jwt = (options: { } if (!token) { + const errDescription = 'no authorization included in request' throw new HTTPException(401, { + message: errDescription, res: unauthorizedResponse({ ctx, error: 'invalid_request', - errDescription: 'no authorization included in request', + errDescription, }), }) } let payload - let msg = '' + let cause try { - payload = await Jwt.verify(token, options.secret, options.alg as AlgorithmTypes) + payload = await Jwt.verify(token, options.secret, options.alg) } catch (e) { - msg = `${e}` + cause = e } if (!payload) { throw new HTTPException(401, { + message: 'Unauthorized', res: unauthorizedResponse({ ctx, error: 'invalid_token', - statusText: msg, + statusText: 'Unauthorized', errDescription: 'token verification failure', }), + cause, }) } diff --git a/src/middleware/method-override/index.test.ts b/src/middleware/method-override/index.test.ts new file mode 100644 index 000000000..519aca0f4 --- /dev/null +++ b/src/middleware/method-override/index.test.ts @@ -0,0 +1,182 @@ +import { Hono } from '../../hono' +import { methodOverride } from './index' + +describe('Method Override Middleware', () => { + describe('Form', () => { + const app = new Hono() + app.use('/posts/*', methodOverride({ app })) + app.use('/posts-custom/*', methodOverride({ app, form: 'custom-input-name' })) + app.on(['post', 'delete'], ['/posts', '/posts-custom'], async (c) => { + const form = await c.req.formData() + return c.json({ + method: c.req.method, + message: form.get('message'), + contentType: c.req.header('content-type') ?? '', + }) + }) + + describe('multipart/form-data', () => { + it('Should override POST to DELETE', async () => { + const form = new FormData() + form.append('message', 'Hello') + form.append('_method', 'DELETE') + const res = await app.request('/posts', { + body: form, + method: 'POST', + }) + expect(res.status).toBe(200) + const data = await res.json() + expect(data.method).toBe('DELETE') + expect(data.message).toBe('Hello') + expect(data.contentType).toMatch(/^multipart\/form-data;/) + }) + + it('Should override POST to DELETE - with a custom form input name', async () => { + const form = new FormData() + form.append('message', 'Hello') + form.append('custom-input-name', 'DELETE') + const res = await app.request('/posts-custom', { + body: form, + method: 'POST', + }) + expect(res.status).toBe(200) + const data = await res.json() + expect(data.method).toBe('DELETE') + expect(data.message).toBe('Hello') + expect(data.contentType).toMatch(/^multipart\/form-data;/) + }) + + it('Should override POST to PATCH - not found', async () => { + const form = new FormData() + form.append('message', 'Hello') + form.append('_method', 'PATCH') + const res = await app.request('/posts', { + body: form, + method: 'POST', + }) + expect(res.status).toBe(404) + }) + }) + + describe('application/x-www-form-urlencoded', () => { + it('Should override POST to DELETE', async () => { + const params = new URLSearchParams() + params.append('message', 'Hello') + params.append('_method', 'DELETE') + const res = await app.request('/posts', { + body: params, + headers: { + 'Content-Type': 'application/x-www-form-urlencoded', + }, + method: 'POST', + }) + expect(res.status).toBe(200) + const data = await res.json() + expect(data.method).toBe('DELETE') + expect(data.message).toBe('Hello') + expect(data.contentType).toBe('application/x-www-form-urlencoded') + }) + + it('Should override POST to DELETE - with a custom form input name', async () => { + const params = new URLSearchParams() + params.append('message', 'Hello') + params.append('custom-input-name', 'DELETE') + const res = await app.request('/posts-custom', { + body: params, + headers: { + 'Content-Type': 'application/x-www-form-urlencoded', + }, + method: 'POST', + }) + expect(res.status).toBe(200) + const data = await res.json() + expect(data.method).toBe('DELETE') + expect(data.message).toBe('Hello') + expect(data.contentType).toBe('application/x-www-form-urlencoded') + }) + + it('Should override POST to PATCH - not found', async () => { + const form = new FormData() + form.append('message', 'Hello') + form.append('_method', 'PATCH') + const res = await app.request('/posts', { + body: form, + method: 'POST', + }) + expect(res.status).toBe(404) + }) + }) + }) + + describe('Header', () => { + const app = new Hono() + app.use('/posts/*', methodOverride({ app, header: 'X-METHOD-OVERRIDE' })) + app.on(['get', 'post', 'delete'], '/posts', async (c) => { + return c.json({ + method: c.req.method, + headerValue: c.req.header('X-METHOD-OVERRIDE') ?? null, + }) + }) + + it('Should override POST to DELETE', async () => { + const res = await app.request('/posts', { + method: 'POST', + headers: { + 'X-METHOD-OVERRIDE': 'DELETE', + }, + }) + expect(res.status).toBe(200) + expect(await res.json()).toEqual({ + method: 'DELETE', + headerValue: null, + }) + }) + + it('Should not override GET request', async () => { + const res = await app.request('/posts', { + method: 'GET', + headers: { + 'X-METHOD-OVERRIDE': 'DELETE', + }, + }) + expect(res.status).toBe(200) + expect(await res.json()).toEqual({ + method: 'GET', + headerValue: 'DELETE', // It does not modify the headers. + }) + }) + }) + + describe('Query', () => { + const app = new Hono() + app.use('/posts/*', methodOverride({ app, query: '_method' })) + app.on(['get', 'post', 'delete'], '/posts', async (c) => { + return c.json({ + method: c.req.method, + queryValue: c.req.query('_method') ?? null, + }) + }) + + it('Should override POST to DELETE', async () => { + const res = await app.request('/posts?_method=delete', { + method: 'POST', + }) + expect(res.status).toBe(200) + expect(await res.json()).toEqual({ + method: 'DELETE', + queryValue: null, + }) + }) + + it('Should not override GET request', async () => { + const res = await app.request('/posts?_method=delete', { + method: 'GET', + }) + expect(res.status).toBe(200) + expect(await res.json()).toEqual({ + method: 'GET', + queryValue: 'delete', // It does not modify the queries. + }) + }) + }) +}) diff --git a/src/middleware/method-override/index.ts b/src/middleware/method-override/index.ts new file mode 100644 index 000000000..0e942ec29 --- /dev/null +++ b/src/middleware/method-override/index.ts @@ -0,0 +1,134 @@ +import { URLSearchParams } from 'url' +import type { Context } from '../../context' +import type { Hono } from '../../hono' +import type { MiddlewareHandler } from '../../types' +import { parseBody } from '../../utils/body' + +type MethodOverrideOptions = { + // eslint-disable-next-line @typescript-eslint/no-explicit-any + app: Hono +} & ( + | { + // Default is 'form' and the value is `_method` + form?: string + header?: never + query?: never + } + | { + form?: never + header: string + query?: never + } + | { + form?: never + header?: never + query: string + } +) + +const DEFAULT_METHOD_FORM_NAME = '_method' + +/** + * Method Override Middleware + * + * @example + * // with form input method + * const app = new Hono() + * app.use('/books/*', methodOverride({ app })) // the default `form` value is `_method` + * app.use('/authors/*', methodOverride({ app, form: 'method' })) + * + * @example + * // with custom header + * app.use('/books/*', methodOverride({ app, header: 'X-HTTP-METHOD-OVERRIDE' })) + * + * @example + * // with query parameter + * app.use('/books/*', methodOverride({ app, query: '_method' })) + */ +export const methodOverride = (options: MethodOverrideOptions): MiddlewareHandler => + async function methodOverride(c, next) { + if (c.req.method === 'GET') { + return await next() + } + + const app = options.app + // Method override by form + if (!(options.header || options.query)) { + const contentType = c.req.header('content-type') + const methodFormName = options.form || DEFAULT_METHOD_FORM_NAME + const clonedRequest = c.req.raw.clone() + const newRequest = clonedRequest.clone() + // Content-Type is `multipart/form-data` + if (contentType?.startsWith('multipart/form-data')) { + const form = await clonedRequest.formData() + const method = form.get(methodFormName) + if (method) { + const newForm = await newRequest.formData() + newForm.delete(methodFormName) + const newHeaders = new Headers(clonedRequest.headers) + newHeaders.delete('content-type') + newHeaders.delete('content-length') + const request = new Request(c.req.url, { + body: newForm, + headers: newHeaders, + method: method as string, + }) + return app.fetch(request, c.env, getExecutionCtx(c)) + } + } + // Content-Type is `application/x-www-form-urlencoded` + if (contentType === 'application/x-www-form-urlencoded') { + const params = await parseBody>(clonedRequest) + const method = params[methodFormName] + if (method) { + delete params[methodFormName] + const newParams = new URLSearchParams(params) + const request = new Request(newRequest, { + body: newParams, + method: method as string, + }) + return app.fetch(request, c.env, getExecutionCtx(c)) + } + } + } + // Method override by header + else if (options.header) { + const headerName = options.header + const method = c.req.header(headerName) + if (method) { + const newHeaders = new Headers(c.req.raw.headers) + newHeaders.delete(headerName) + const request = new Request(c.req.raw, { + headers: newHeaders, + method, + }) + return app.fetch(request, c.env, getExecutionCtx(c)) + } + } + // Method override by query + else if (options.query) { + const queryName = options.query + const method = c.req.query(queryName) + if (method) { + const url = new URL(c.req.url) + url.searchParams.delete(queryName) + const request = new Request(url.toString(), { + body: c.req.raw.body, + headers: c.req.raw.headers, + method, + }) + return app.fetch(request, c.env, getExecutionCtx(c)) + } + } + await next() + } + +const getExecutionCtx = (c: Context) => { + let executionCtx: ExecutionContext | undefined + try { + executionCtx = c.executionCtx + } catch { + // Do nothing + } + return executionCtx +} diff --git a/src/middleware/trailing-slash/index.test.ts b/src/middleware/trailing-slash/index.test.ts new file mode 100644 index 000000000..8b74c3add --- /dev/null +++ b/src/middleware/trailing-slash/index.test.ts @@ -0,0 +1,85 @@ +import { Hono } from '../../hono' +import { trimTrailingSlash, appendTrailingSlash } from '.' + +describe('Resolve trailing slash', () => { + let app: Hono + + it('Trim', async () => { + app = new Hono({ strict: true }) + + app.use('*', trimTrailingSlash()) + + app.get('/', async (c) => { + return c.text('ok') + }) + app.get('/the/example/endpoint/without/trailing/slash', async (c) => { + return c.text('ok') + }) + + let resp: Response, loc: URL + + resp = await app.request('/') + expect(resp).not.toBeNull() + expect(resp.status).toBe(200) + + resp = await app.request('/the/example/endpoint/without/trailing/slash') + expect(resp).not.toBeNull() + expect(resp.status).toBe(200) + + resp = await app.request('/the/example/endpoint/without/trailing/slash/') + loc = new URL(resp.headers.get('location')!) + expect(resp).not.toBeNull() + expect(resp.status).toBe(301) + expect(loc.pathname).toBe('/the/example/endpoint/without/trailing/slash') + + resp = await app.request('/the/example/endpoint/without/trailing/slash/?exampleParam=1') + loc = new URL(resp.headers.get('location')!) + expect(resp).not.toBeNull() + expect(resp.status).toBe(301) + expect(loc.pathname).toBe('/the/example/endpoint/without/trailing/slash') + expect(loc.searchParams.get('exampleParam')).toBe('1') + }) + + it('Append', async () => { + app = new Hono({ strict: true }) + + app.use('*', appendTrailingSlash()) + + app.get('/', async (c) => { + return c.text('ok') + }) + app.get('/the/example/endpoint/with/trailing/slash/', async (c) => { + return c.text('ok') + }) + app.get('/the/example/simulate/a.file', async (c) => { + return c.text('ok') + }) + + let resp: Response, loc: URL + + resp = await app.request('/') + expect(resp).not.toBeNull() + expect(resp.status).toBe(200) + + resp = await app.request('/the/example/simulate/a.file') + expect(resp).not.toBeNull() + expect(resp.status).toBe(200) + + resp = await app.request('/the/example/endpoint/with/trailing/slash/') + expect(resp).not.toBeNull() + expect(resp.status).toBe(200) + + resp = await app.request('/the/example/endpoint/with/trailing/slash') + loc = new URL(resp.headers.get('location')!) + expect(resp).not.toBeNull() + expect(resp.status).toBe(301) + expect(loc.pathname).toBe('/the/example/endpoint/with/trailing/slash/') + + resp = await app.request('/the/example/endpoint/with/trailing/slash?exampleParam=1') + loc = new URL(resp.headers.get('location')!) + expect(resp).not.toBeNull() + expect(resp.status).toBe(301) + expect(loc.pathname).toBe('/the/example/endpoint/with/trailing/slash/') + expect(loc.searchParams.get('exampleParam')).toBe('1') + }) +}) diff --git a/src/middleware/trailing-slash/index.ts b/src/middleware/trailing-slash/index.ts new file mode 100644 index 000000000..ead3975b3 --- /dev/null +++ b/src/middleware/trailing-slash/index.ts @@ -0,0 +1,46 @@ +import type { MiddlewareHandler } from '../../types' + +/** + * Trim the trailing slash from the URL if it does have one. For example, `/path/to/page/` will be redirected to `/path/to/page`. + * @access public + * @example app.use(trimTrailingSlash()) + */ +export const trimTrailingSlash = (): MiddlewareHandler => { + return async function trimTrailingSlash(c, next) { + await next() + + if ( + c.res.status === 404 && + c.req.method === 'GET' && + c.req.path !== '/' && + c.req.path[c.req.path.length - 1] === '/' + ) { + const url = new URL(c.req.url) + url.pathname = url.pathname.substring(0, url.pathname.length - 1) + + c.res = c.redirect(url.toString(), 301) + } + } +} + +/** + * Append a trailing slash to the URL if it doesn't have one. For example, `/path/to/page` will be redirected to `/path/to/page/`. + * @access public + * @example app.use(appendTrailingSlash()) + */ +export const appendTrailingSlash = (): MiddlewareHandler => { + return async function appendTrailingSlash(c, next) { + await next() + + if ( + c.res.status === 404 && + c.req.method === 'GET' && + c.req.path[c.req.path.length - 1] !== '/' + ) { + const url = new URL(c.req.url) + url.pathname += '/' + + c.res = c.redirect(url.toString(), 301) + } + } +} diff --git a/src/request.test.ts b/src/request.test.ts index fb25a11e2..9163f2a38 100644 --- a/src/request.test.ts +++ b/src/request.test.ts @@ -159,16 +159,27 @@ describe('headers', () => { }) }) -describe('Body methods', () => { +const text = '{"foo":"bar"}' +const json = { foo: 'bar' } +const buffer = new TextEncoder().encode('{"foo":"bar"}').buffer + +describe('Body methods with caching', () => { test('req.text()', async () => { const req = new HonoRequest( new Request('http://localhost', { method: 'POST', - body: 'foo', + body: text, + }) + ) + expect(await req.text()).toEqual(text) + expect(await req.text()).toEqual(text) + expect(await req.json()).toEqual(json) + expect(await req.arrayBuffer()).toEqual(buffer) + expect(await req.blob()).toEqual( + new Blob([text], { + type: 'text/plain;charset=utf-8', }) ) - expect(await req.text()).toBe('foo') - expect(await req.text()).toBe('foo') // Should be cached }) test('req.json()', async () => { @@ -178,12 +189,19 @@ describe('Body methods', () => { body: '{"foo":"bar"}', }) ) - expect(await req.json()).toEqual({ foo: 'bar' }) - expect(await req.json()).toEqual({ foo: 'bar' }) // Should be cached + expect(await req.json()).toEqual(json) + expect(await req.json()).toEqual(json) + expect(await req.text()).toEqual(text) + expect(await req.arrayBuffer()).toEqual(buffer) + expect(await req.blob()).toEqual( + new Blob([text], { + type: 'text/plain;charset=utf-8', + }) + ) }) test('req.arrayBuffer()', async () => { - const buffer = new ArrayBuffer(8) + const buffer = new TextEncoder().encode('{"foo":"bar"}').buffer const req = new HonoRequest( new Request('http://localhost', { method: 'POST', @@ -191,12 +209,19 @@ describe('Body methods', () => { }) ) expect(await req.arrayBuffer()).toEqual(buffer) - expect(await req.arrayBuffer()).toEqual(buffer) // Should be cached + expect(await req.arrayBuffer()).toEqual(buffer) + expect(await req.text()).toEqual(text) + expect(await req.json()).toEqual(json) + expect(await req.blob()).toEqual( + new Blob([text], { + type: '', + }) + ) }) test('req.blob()', async () => { - const blob = new Blob(['foo'], { - type: 'text/plain', + const blob = new Blob(['{"foo":"bar"}'], { + type: 'application/json', }) const req = new HonoRequest( new Request('http://localhost', { @@ -205,7 +230,10 @@ describe('Body methods', () => { }) ) expect(await req.blob()).toEqual(blob) - expect(await req.blob()).toEqual(blob) // Should be cached + expect(await req.blob()).toEqual(blob) + expect(await req.text()).toEqual(text) + expect(await req.json()).toEqual(json) + expect(await req.arrayBuffer()).toEqual(buffer) }) test('req.formData()', async () => { @@ -218,6 +246,9 @@ describe('Body methods', () => { }) ) expect((await req.formData()).get('foo')).toBe('bar') - expect((await req.formData()).get('foo')).toBe('bar') // Should be cached + expect((await req.formData()).get('foo')).toBe('bar') + expect(async () => await req.text()).not.toThrow() + expect(async () => await req.arrayBuffer()).not.toThrow() + expect(async () => await req.blob()).not.toThrow() }) }) diff --git a/src/request.ts b/src/request.ts index c82584e2e..5953a04ff 100644 --- a/src/request.ts +++ b/src/request.ts @@ -194,18 +194,27 @@ export class HonoRequest

{ private cachedBody = (key: keyof Body) => { const { bodyCache, raw } = this const cachedBody = bodyCache[key] + if (cachedBody) { return cachedBody } - /** - * If an arrayBuffer cache is exist, - * use it for creating a text, json, and others. - */ - if (bodyCache.arrayBuffer) { - return (async () => { - return await new Response(bodyCache.arrayBuffer)[key]() - })() + + if (!bodyCache[key]) { + for (const keyOfBodyCache of Object.keys(bodyCache)) { + if (keyOfBodyCache === 'parsedBody') { + continue + } + return (async () => { + // @ts-expect-error bodyCache[keyOfBodyCache] can be passed as a body + let body = await bodyCache[keyOfBodyCache] + if (keyOfBodyCache === 'json') { + body = JSON.stringify(body) + } + return await new Response(body)[key]() + })() + } } + return (bodyCache[key] = raw[key]()) } diff --git a/src/test-utils/setup-vitest.ts b/src/test-utils/setup-vitest.ts index c14b2fc07..eef20bb04 100644 --- a/src/test-utils/setup-vitest.ts +++ b/src/test-utils/setup-vitest.ts @@ -1,11 +1,14 @@ // @denoify-ignore -import crypto from 'node:crypto' +import * as nodeCrypto from 'node:crypto' import { vi } from 'vitest' /** * crypto */ -vi.stubGlobal('crypto', crypto) +if (!globalThis.crypto) { + vi.stubGlobal('crypto', nodeCrypto) + vi.stubGlobal('CryptoKey', nodeCrypto.webcrypto.CryptoKey) +} /** * Cache API diff --git a/src/utils/jwt/index.ts b/src/utils/jwt/index.ts index 1c887fe82..5f76e3e2e 100644 --- a/src/utils/jwt/index.ts +++ b/src/utils/jwt/index.ts @@ -1 +1,2 @@ -export * as Jwt from './jwt' +import { sign, verify, decode } from './jwt' +export const Jwt = { sign, verify, decode } diff --git a/src/utils/jwt/types.test.ts b/src/utils/jwt/jwa.test.ts similarity index 65% rename from src/utils/jwt/types.test.ts rename to src/utils/jwt/jwa.test.ts index 9f0d7c3b0..bb3802f26 100644 --- a/src/utils/jwt/types.test.ts +++ b/src/utils/jwt/jwa.test.ts @@ -1,10 +1,14 @@ -import { AlgorithmTypes } from './types' +import { AlgorithmTypes } from './jwa' describe('Types', () => { it('AlgorithmTypes', () => { expect('HS256' as AlgorithmTypes).toBe(AlgorithmTypes.HS256) expect('HS384' as AlgorithmTypes).toBe(AlgorithmTypes.HS384) expect('HS512' as AlgorithmTypes).toBe(AlgorithmTypes.HS512) + expect('RS256' as AlgorithmTypes).toBe(AlgorithmTypes.RS256) + expect('RS384' as AlgorithmTypes).toBe(AlgorithmTypes.RS384) + expect('RS512' as AlgorithmTypes).toBe(AlgorithmTypes.RS512) + // eslint-disable-next-line @typescript-eslint/ban-ts-comment // @ts-ignore expect(undefined as AlgorithmTypes).toBe(undefined) diff --git a/src/utils/jwt/jwa.ts b/src/utils/jwt/jwa.ts new file mode 100644 index 000000000..8512b811c --- /dev/null +++ b/src/utils/jwt/jwa.ts @@ -0,0 +1,20 @@ +// JSON Web Algorithms (JWA) +// https://datatracker.ietf.org/doc/html/rfc7518 + +export enum AlgorithmTypes { + HS256 = 'HS256', + HS384 = 'HS384', + HS512 = 'HS512', + RS256 = 'RS256', + RS384 = 'RS384', + RS512 = 'RS512', + PS256 = 'PS256', + PS384 = 'PS384', + PS512 = 'PS512', + ES256 = 'ES256', + ES384 = 'ES384', + ES512 = 'ES512', + EdDSA = 'EdDSA', +} + +export type SignatureAlgorithm = keyof typeof AlgorithmTypes diff --git a/src/utils/jwt/jws.ts b/src/utils/jwt/jws.ts new file mode 100644 index 000000000..05f84ab81 --- /dev/null +++ b/src/utils/jwt/jws.ts @@ -0,0 +1,224 @@ +import { getRuntimeKey } from '../../helper' +import { decodeBase64 } from '../encode' +import type { SignatureAlgorithm } from './jwa' +import { JwtAlgorithmNotImplemented } from './types' +import { CryptoKeyUsage } from './types' +import { utf8Encoder } from './utf8' + +// JSON Web Signature (JWS) +// https://datatracker.ietf.org/doc/html/rfc7515 + +type KeyImporterAlgorithm = Parameters[2] +type KeyAlgorithm = + | AlgorithmIdentifier + | RsaHashedImportParams + | (RsaPssParams & RsaHashedImportParams) + | (EcdsaParams & EcKeyImportParams) + | HmacImportParams + +export type SignatureKey = string | JsonWebKey | CryptoKey + +export async function signing( + privateKey: SignatureKey, + alg: SignatureAlgorithm, + data: BufferSource +): Promise { + const algorithm = getKeyAlgorithm(alg) + const cryptoKey = await importPrivateKey(privateKey, algorithm) + return await crypto.subtle.sign(algorithm, cryptoKey, data) +} + +export async function verifying( + publicKey: SignatureKey, + alg: SignatureAlgorithm, + signature: BufferSource, + data: BufferSource +): Promise { + const algorithm = getKeyAlgorithm(alg) + const cryptoKey = await importPublicKey(publicKey, algorithm) + return await crypto.subtle.verify(algorithm, cryptoKey, signature, data) +} + +function pemToBinary(pem: string): Uint8Array { + return decodeBase64(pem.replace(/-+(BEGIN|END).*/g, '').replace(/\s/g, '')) +} + +async function importPrivateKey(key: SignatureKey, alg: KeyImporterAlgorithm): Promise { + if (!crypto.subtle || !crypto.subtle.importKey) { + throw new Error('`crypto.subtle.importKey` is undefined. JWT auth middleware requires it.') + } + if (isCryptoKey(key)) { + if (key.type !== 'private') { + throw new Error(`unexpected non private key: CryptoKey.type is ${key.type}`) + } + return key + } + const usages = [CryptoKeyUsage.Sign] + if (typeof key === 'object') { + // https://developer.mozilla.org/en-US/docs/Web/API/SubtleCrypto/importKey#json_web_key_import + return await crypto.subtle.importKey('jwk', key, alg, false, usages) + } + if (key.includes('PRIVATE')) { + // https://developer.mozilla.org/en-US/docs/Web/API/SubtleCrypto/importKey#pkcs_8_import + return await crypto.subtle.importKey('pkcs8', pemToBinary(key), alg, false, usages) + } + return await crypto.subtle.importKey('raw', utf8Encoder.encode(key), alg, false, usages) +} + +async function importPublicKey(key: SignatureKey, alg: KeyImporterAlgorithm): Promise { + if (!crypto.subtle || !crypto.subtle.importKey) { + throw new Error('`crypto.subtle.importKey` is undefined. JWT auth middleware requires it.') + } + if (isCryptoKey(key)) { + if (key.type === 'public' || key.type === 'secret') { + return key + } + key = await exportPublicJwkFrom(key) + } + if (typeof key === 'string' && key.includes('PRIVATE')) { + // https://developer.mozilla.org/en-US/docs/Web/API/SubtleCrypto/importKey#pkcs_8_import + const privateKey = await crypto.subtle.importKey('pkcs8', pemToBinary(key), alg, true, [ + CryptoKeyUsage.Sign, + ]) + key = await exportPublicJwkFrom(privateKey) + } + const usages = [CryptoKeyUsage.Verify] + if (typeof key === 'object') { + // https://developer.mozilla.org/en-US/docs/Web/API/SubtleCrypto/importKey#json_web_key_import + return await crypto.subtle.importKey('jwk', key, alg, false, usages) + } + if (key.includes('PUBLIC')) { + // https://developer.mozilla.org/en-US/docs/Web/API/SubtleCrypto/importKey#subjectpublickeyinfo_import + return await crypto.subtle.importKey('spki', pemToBinary(key), alg, false, usages) + } + return await crypto.subtle.importKey('raw', utf8Encoder.encode(key), alg, false, usages) +} + +// https://datatracker.ietf.org/doc/html/rfc7517 +async function exportPublicJwkFrom(privateKey: CryptoKey): Promise { + if (privateKey.type !== 'private') { + throw new Error(`unexpected key type: ${privateKey.type}`) + } + if (!privateKey.extractable) { + throw new Error('unexpected private key is unextractable') + } + const jwk = await crypto.subtle.exportKey('jwk', privateKey) + const { kty } = jwk // common + const { alg, e, n } = jwk // rsa + const { crv, x, y } = jwk // elliptic-curve + return { kty, alg, e, n, crv, x, y, key_ops: [CryptoKeyUsage.Verify] } +} + +function getKeyAlgorithm(name: SignatureAlgorithm): KeyAlgorithm { + switch (name) { + case 'HS256': + return { + name: 'HMAC', + hash: { + name: 'SHA-256', + }, + } satisfies HmacImportParams + case 'HS384': + return { + name: 'HMAC', + hash: { + name: 'SHA-384', + }, + } satisfies HmacImportParams + case 'HS512': + return { + name: 'HMAC', + hash: { + name: 'SHA-512', + }, + } satisfies HmacImportParams + case 'RS256': + return { + name: 'RSASSA-PKCS1-v1_5', + hash: { + name: 'SHA-256', + }, + } satisfies RsaHashedImportParams + case 'RS384': + return { + name: 'RSASSA-PKCS1-v1_5', + hash: { + name: 'SHA-384', + }, + } satisfies RsaHashedImportParams + case 'RS512': + return { + name: 'RSASSA-PKCS1-v1_5', + hash: { + name: 'SHA-512', + }, + } satisfies RsaHashedImportParams + case 'PS256': + return { + name: 'RSA-PSS', + hash: { + name: 'SHA-256', + }, + saltLength: 32, // 256 >> 3 + } satisfies RsaPssParams & RsaHashedImportParams + case 'PS384': + return { + name: 'RSA-PSS', + hash: { + name: 'SHA-384', + }, + saltLength: 48, // 384 >> 3 + } satisfies RsaPssParams & RsaHashedImportParams + case 'PS512': + return { + name: 'RSA-PSS', + hash: { + name: 'SHA-512', + }, + saltLength: 64, // 512 >> 3, + } satisfies RsaPssParams & RsaHashedImportParams + case 'ES256': + return { + name: 'ECDSA', + hash: { + name: 'SHA-256', + }, + namedCurve: 'P-256', + } satisfies EcdsaParams & EcKeyImportParams + case 'ES384': + return { + name: 'ECDSA', + hash: { + name: 'SHA-384', + }, + namedCurve: 'P-384', + } satisfies EcdsaParams & EcKeyImportParams + case 'ES512': + return { + name: 'ECDSA', + hash: { + name: 'SHA-512', + }, + namedCurve: 'P-521', + } satisfies EcdsaParams & EcKeyImportParams + case 'EdDSA': + // Currently, supported only Safari and Deno, Node.js. + // See: https://developer.mozilla.org/en-US/docs/Web/API/SubtleCrypto/verify + return { + name: 'Ed25519', + namedCurve: 'Ed25519', + } + default: + throw new JwtAlgorithmNotImplemented(name) + } +} + +function isCryptoKey(key: SignatureKey): key is CryptoKey { + const runtime = getRuntimeKey() + // @ts-expect-error CryptoKey hasn't exported to global in node v18 + if (runtime === 'node' && !!crypto.webcrypto) { + // @ts-expect-error CryptoKey hasn't exported to global in node v18 + return key instanceof crypto.webcrypto.CryptoKey + } + return key instanceof CryptoKey +} diff --git a/src/utils/jwt/jwt.test.ts b/src/utils/jwt/jwt.test.ts index 26b0b2837..ac1447c32 100644 --- a/src/utils/jwt/jwt.test.ts +++ b/src/utils/jwt/jwt.test.ts @@ -1,8 +1,9 @@ /* eslint-disable @typescript-eslint/ban-ts-comment */ import { vi } from 'vitest' +import { encodeBase64 } from '../encode' +import { AlgorithmTypes } from './jwa' import * as JWT from './jwt' import { - AlgorithmTypes, JwtAlgorithmNotImplemented, JwtTokenExpired, JwtTokenInvalid, @@ -11,6 +12,26 @@ import { JwtTokenSignatureMismatched, } from './types' +describe('isTokenHeader', () => { + it('should return true for valid TokenHeader', () => { + const validTokenHeader: JWT.TokenHeader = { + alg: AlgorithmTypes.HS256, + typ: 'JWT', + } + + expect(JWT.isTokenHeader(validTokenHeader)).toBe(true) + }) + + it('should return false for invalid TokenHeader', () => { + const invalidTokenHeader = { + alg: 'invalid', + typ: 'JWT', + } + + expect(JWT.isTokenHeader(invalidTokenHeader)).toBe(false) + }) +}) + describe('JWT', () => { it('JwtAlgorithmNotImplemented', async () => { const payload = { message: 'hello world' } @@ -61,7 +82,7 @@ describe('JWT', () => { it('JwtTokenExpired', async () => { const tok = - 'eyJraWQiOiJFemF6bVZWbnd0TUpUNEFveFVtT0dILWJ0Y2VUVFM3djBYcEJuMm5ZZ2VjIiwiYWxnIjoiSFMyNTYifQ.eyJyb2xlIjoiYXBpX3JvbGUiLCJleHAiOjE2MzMwNDY0MDB9.Gmq_dozOnwzqkMUMEm7uny7cMZuF1d0QkCnmRXAbTEk' + 'eyJ0eXAiOiJKV1QiLCJhbGciOiJIUzI1NiJ9.eyJpYXQiOjE2MzMwNDYxMDAsImV4cCI6MTYzMzA0NjQwMH0.H-OI1TWAbmK8RonvcpPaQcNvOKS9sxinEOsgKwjoiVo' const secret = 'a-secret' let err let authorized @@ -196,4 +217,264 @@ describe('JWT', () => { expect(authorized).toBeUndefined() expect(err instanceof JwtTokenSignatureMismatched).toBe(true) }) + + const rsTestCases = [ + { + alg: AlgorithmTypes.RS256, + hash: 'SHA-256', + }, + { + alg: AlgorithmTypes.RS384, + hash: 'SHA-384', + }, + { + alg: AlgorithmTypes.RS512, + hash: 'SHA-512', + }, + ] + for (const tc of rsTestCases) { + it(`${tc.alg} sign & verify`, async () => { + const alg = tc.alg + const payload = { message: 'hello world' } + const keyPair = await generateRSAKey(tc.hash) + const pemPrivateKey = await exportPEMPrivateKey(keyPair.privateKey) + const pemPublicKey = await exportPEMPublicKey(keyPair.publicKey) + const jwkPublicKey = await exportJWK(keyPair.publicKey) + + const tok = await JWT.sign(payload, pemPrivateKey, alg) + expect(await JWT.verify(tok, pemPublicKey, alg)).toEqual(payload) + expect(await JWT.verify(tok, pemPrivateKey, alg)).toEqual(payload) + expect(await JWT.verify(tok, jwkPublicKey, alg)).toEqual(payload) + + const keyPair2 = await generateRSAKey(tc.hash) + const unexpectedPemPublicKey = await exportPEMPublicKey(keyPair2.publicKey) + + let err = null + let authorized + try { + authorized = await JWT.verify(tok, unexpectedPemPublicKey, alg) + } catch (e) { + err = e + } + expect(authorized).toBeUndefined() + expect(err instanceof JwtTokenSignatureMismatched).toBe(true) + }) + + it(`${tc.alg} sign & verify w/ CryptoKey`, async () => { + const alg = tc.alg + const payload = { message: 'hello world' } + const keyPair = await generateRSAKey(tc.hash) + + const tok = await JWT.sign(payload, keyPair.privateKey, alg) + expect(await JWT.verify(tok, keyPair.privateKey, alg)).toEqual(payload) + expect(await JWT.verify(tok, keyPair.publicKey, alg)).toEqual(payload) + }) + } + + const psTestCases = [ + { + alg: AlgorithmTypes.PS256, + hash: 'SHA-256', + }, + { + alg: AlgorithmTypes.PS384, + hash: 'SHA-384', + }, + { + alg: AlgorithmTypes.PS512, + hash: 'SHA-512', + }, + ] + for (const tc of psTestCases) { + it(`${tc.alg} sign & verify`, async () => { + const alg = tc.alg + const payload = { message: 'hello world' } + const keyPair = await generateRSAPSSKey(tc.hash) + const pemPrivateKey = await exportPEMPrivateKey(keyPair.privateKey) + const pemPublicKey = await exportPEMPublicKey(keyPair.publicKey) + const jwkPublicKey = await exportJWK(keyPair.publicKey) + + const tok = await JWT.sign(payload, pemPrivateKey, alg) + expect(await JWT.verify(tok, pemPublicKey, alg)).toEqual(payload) + expect(await JWT.verify(tok, pemPrivateKey, alg)).toEqual(payload) + expect(await JWT.verify(tok, jwkPublicKey, alg)).toEqual(payload) + + const keyPair2 = await generateRSAPSSKey(tc.hash) + const unexpectedPemPublicKey = await exportPEMPublicKey(keyPair2.publicKey) + + let err = null + let authorized + try { + authorized = await JWT.verify(tok, unexpectedPemPublicKey, alg) + } catch (e) { + err = e + } + expect(authorized).toBeUndefined() + expect(err instanceof JwtTokenSignatureMismatched).toBe(true) + }) + + it(`${tc.alg} sign & verify w/ CryptoKey`, async () => { + const alg = tc.alg + const payload = { message: 'hello world' } + const keyPair = await generateRSAPSSKey(tc.hash) + + const tok = await JWT.sign(payload, keyPair.privateKey, alg) + expect(await JWT.verify(tok, keyPair.privateKey, alg)).toEqual(payload) + expect(await JWT.verify(tok, keyPair.publicKey, alg)).toEqual(payload) + }) + } + + const esTestCases = [ + { + alg: AlgorithmTypes.ES256, + namedCurve: 'P-256', + }, + { + alg: AlgorithmTypes.ES384, + namedCurve: 'P-384', + }, + { + alg: AlgorithmTypes.ES512, + namedCurve: 'P-521', + }, + ] + for (const tc of esTestCases) { + it(`${tc.alg} sign & verify`, async () => { + const alg = tc.alg + const payload = { message: 'hello world' } + const keyPair = await generateECDSAKey(tc.namedCurve) + const pemPrivateKey = await exportPEMPrivateKey(keyPair.privateKey) + const pemPublicKey = await exportPEMPublicKey(keyPair.publicKey) + const jwkPublicKey = await exportJWK(keyPair.publicKey) + + const tok = await JWT.sign(payload, pemPrivateKey, alg) + expect(await JWT.verify(tok, pemPublicKey, alg)).toEqual(payload) + expect(await JWT.verify(tok, pemPrivateKey, alg)).toEqual(payload) + expect(await JWT.verify(tok, jwkPublicKey, alg)).toEqual(payload) + + const keyPair2 = await generateECDSAKey(tc.namedCurve) + const unexpectedPemPublicKey = await exportPEMPublicKey(keyPair2.publicKey) + + let err = null + let authorized + try { + authorized = await JWT.verify(tok, unexpectedPemPublicKey, alg) + } catch (e) { + err = e + } + expect(authorized).toBeUndefined() + expect(err instanceof JwtTokenSignatureMismatched).toBe(true) + }) + + it(`${tc.alg} sign & verify w/ CryptoKey`, async () => { + const alg = tc.alg + const payload = { message: 'hello world' } + const keyPair = await generateECDSAKey(tc.namedCurve) + + const tok = await JWT.sign(payload, keyPair.privateKey, alg) + expect(await JWT.verify(tok, keyPair.privateKey, alg)).toEqual(payload) + expect(await JWT.verify(tok, keyPair.publicKey, alg)).toEqual(payload) + }) + } + + it('EdDSA sign & verify', async () => { + const alg = 'EdDSA' + const payload = { message: 'hello world' } + const keyPair = await generateEd25519Key() + const pemPrivateKey = await exportPEMPrivateKey(keyPair.privateKey) + const pemPublicKey = await exportPEMPublicKey(keyPair.publicKey) + const jwkPublicKey = await exportJWK(keyPair.publicKey) + + const tok = await JWT.sign(payload, pemPrivateKey, alg) + expect(await JWT.verify(tok, pemPublicKey, alg)).toEqual(payload) + expect(await JWT.verify(tok, pemPrivateKey, alg)).toEqual(payload) + expect(await JWT.verify(tok, jwkPublicKey, alg)).toEqual(payload) + + const keyPair2 = await generateEd25519Key() + const unexpectedPemPublicKey = await exportPEMPublicKey(keyPair2.publicKey) + + let err = null + let authorized + try { + authorized = await JWT.verify(tok, unexpectedPemPublicKey, alg) + } catch (e) { + err = e + } + expect(authorized).toBeUndefined() + expect(err instanceof JwtTokenSignatureMismatched).toBe(true) + }) + + it('EdDSA sign & verify w/ CryptoKey', async () => { + const alg = 'EdDSA' + const payload = { message: 'hello world' } + const keyPair = await generateEd25519Key() + + const tok = await JWT.sign(payload, keyPair.privateKey, alg) + expect(await JWT.verify(tok, keyPair.privateKey, alg)).toEqual(payload) + expect(await JWT.verify(tok, keyPair.publicKey, alg)).toEqual(payload) + }) }) + +async function exportPEMPrivateKey(key: CryptoKey): Promise { + const exported = await crypto.subtle.exportKey('pkcs8', key) + const pem = `-----BEGIN PRIVATE KEY-----\n${encodeBase64(exported)}\n-----END PRIVATE KEY-----` + return pem +} + +async function exportPEMPublicKey(key: CryptoKey): Promise { + const exported = await crypto.subtle.exportKey('spki', key) + const pem = `-----BEGIN PUBLIC KEY-----\n${encodeBase64(exported)}\n-----END PUBLIC KEY-----` + return pem +} + +async function exportJWK(key: CryptoKey): Promise { + return await crypto.subtle.exportKey('jwk', key) +} + +async function generateRSAKey(hash: string): Promise { + return await crypto.subtle.generateKey( + { + hash, + modulusLength: 2048, + publicExponent: new Uint8Array([1, 0, 1]), + name: 'RSASSA-PKCS1-v1_5', + }, + true, + ['sign', 'verify'] + ) +} + +async function generateRSAPSSKey(hash: string): Promise { + return await crypto.subtle.generateKey( + { + hash, + modulusLength: 2048, + publicExponent: new Uint8Array([1, 0, 1]), + name: 'RSA-PSS', + }, + true, + ['sign', 'verify'] + ) +} + +async function generateECDSAKey(namedCurve: string): Promise { + return await crypto.subtle.generateKey( + { + name: 'ECDSA', + namedCurve, + }, + true, + ['sign', 'verify'] + ) +} + +async function generateEd25519Key(): Promise { + return await crypto.subtle.generateKey( + { + name: 'Ed25519', + namedCurve: 'Ed25519', + }, + true, + ['sign', 'verify'] + ) +} diff --git a/src/utils/jwt/jwt.ts b/src/utils/jwt/jwt.ts index d7eba1787..c25d95b78 100644 --- a/src/utils/jwt/jwt.ts +++ b/src/utils/jwt/jwt.ts @@ -1,44 +1,17 @@ import { encodeBase64Url, decodeBase64Url } from '../../utils/encode' -import type { AlgorithmTypes } from './types' -import { JwtTokenIssuedAt } from './types' +import type { SignatureAlgorithm } from './jwa' +import { AlgorithmTypes } from './jwa' +import type { SignatureKey } from './jws' +import { signing, verifying } from './jws' +import { JwtHeaderInvalid, type JWTPayload } from './types' import { JwtTokenInvalid, JwtTokenNotBefore, JwtTokenExpired, JwtTokenSignatureMismatched, - JwtAlgorithmNotImplemented, + JwtTokenIssuedAt, } from './types' - -interface AlgorithmParams { - name: string - namedCurve?: string - hash?: { - name: string - } -} - -enum CryptoKeyFormat { - RAW = 'raw', - PKCS8 = 'pkcs8', - SPKI = 'spki', - JWK = 'jwk', -} - -enum CryptoKeyUsage { - Ecrypt = 'encrypt', - Decrypt = 'decrypt', - Sign = 'sign', - Verify = 'verify', - Deriverkey = 'deriveKey', - DeriveBits = 'deriveBits', - WrapKey = 'wrapKey', - UnwrapKey = 'unwrapKey', -} - -type AlgorithmTypeName = keyof typeof AlgorithmTypes - -const utf8Encoder = new TextEncoder() -const utf8Decoder = new TextDecoder() +import { utf8Decoder, utf8Encoder } from './utf8' const encodeJwtPart = (part: unknown): string => encodeBase64Url(utf8Encoder.encode(JSON.stringify(part))).replace(/=/g, '') @@ -47,65 +20,34 @@ const encodeSignaturePart = (buf: ArrayBufferLike): string => encodeBase64Url(bu const decodeJwtPart = (part: string): unknown => JSON.parse(utf8Decoder.decode(decodeBase64Url(part))) -const param = (name: AlgorithmTypeName): AlgorithmParams => { - switch (name.toUpperCase()) { - case 'HS256': - return { - name: 'HMAC', - hash: { - name: 'SHA-256', - }, - } - case 'HS384': - return { - name: 'HMAC', - hash: { - name: 'SHA-384', - }, - } - case 'HS512': - return { - name: 'HMAC', - hash: { - name: 'SHA-512', - }, - } - default: - throw new JwtAlgorithmNotImplemented(name) - } +export interface TokenHeader { + alg: SignatureAlgorithm + typ: 'JWT' } -const signing = async ( - data: string, - secret: string, - alg: AlgorithmTypeName = 'HS256' -): Promise => { - if (!crypto.subtle || !crypto.subtle.importKey) { - throw new Error('`crypto.subtle.importKey` is undefined. JWT auth middleware requires it.') - } - - const utf8Encoder = new TextEncoder() - const cryptoKey = await crypto.subtle.importKey( - CryptoKeyFormat.RAW, - utf8Encoder.encode(secret), - param(alg), - false, - [CryptoKeyUsage.Sign] +// eslint-disable-next-line +export function isTokenHeader(obj: any): obj is TokenHeader { + return ( + typeof obj === 'object' && + obj !== null && + 'alg' in obj && + Object.values(AlgorithmTypes).includes(obj.alg) && + 'typ' in obj && + obj.typ === 'JWT' ) - return await crypto.subtle.sign(param(alg), cryptoKey, utf8Encoder.encode(data)) } export const sign = async ( - payload: unknown, - secret: string, - alg: AlgorithmTypeName = 'HS256' + payload: JWTPayload, + privateKey: SignatureKey, + alg: SignatureAlgorithm = 'HS256' ): Promise => { const encodedPayload = encodeJwtPart(payload) - const encodedHeader = encodeJwtPart({ alg, typ: 'JWT' }) + const encodedHeader = encodeJwtPart({ alg, typ: 'JWT' } satisfies TokenHeader) const partialToken = `${encodedHeader}.${encodedPayload}` - const signaturePart = await signing(partialToken, secret, alg) + const signaturePart = await signing(privateKey, alg, utf8Encoder.encode(partialToken)) const signature = encodeSignaturePart(signaturePart) return `${partialToken}.${signature}` @@ -113,8 +55,8 @@ export const sign = async ( export const verify = async ( token: string, - secret: string, - alg: AlgorithmTypeName = 'HS256' + publicKey: SignatureKey, + alg: SignatureAlgorithm = 'HS256' // eslint-disable-next-line @typescript-eslint/no-explicit-any ): Promise => { const tokenParts = token.split('.') @@ -122,7 +64,10 @@ export const verify = async ( throw new JwtTokenInvalid(token) } - const { payload } = decode(token) + const { header, payload } = decode(token) + if (!isTokenHeader(header)) { + throw new JwtHeaderInvalid(header) + } const now = Math.floor(Date.now() / 1000) if (payload.nbf && payload.nbf > now) { throw new JwtTokenNotBefore(token) @@ -134,10 +79,14 @@ export const verify = async ( throw new JwtTokenIssuedAt(now, payload.iat) } - const signaturePart = tokenParts.slice(0, 2).join('.') - const signature = await signing(signaturePart, secret, alg) - const encodedSignature = encodeSignaturePart(signature) - if (encodedSignature !== tokenParts[2]) { + const headerPayload = token.substring(0, token.lastIndexOf('.')) + const verified = await verifying( + publicKey, + alg, + decodeBase64Url(tokenParts[2]), + utf8Encoder.encode(headerPayload) + ) + if (!verified) { throw new JwtTokenSignatureMismatched(token) } diff --git a/src/utils/jwt/types.ts b/src/utils/jwt/types.ts index 2f052e9c1..738afd572 100644 --- a/src/utils/jwt/types.ts +++ b/src/utils/jwt/types.ts @@ -33,6 +33,13 @@ export class JwtTokenIssuedAt extends Error { } } +export class JwtHeaderInvalid extends Error { + constructor(header: object) { + super(`jwt header is invalid: ${JSON.stringify(header)}`) + this.name = 'JwtHeaderInvalid' + } +} + export class JwtTokenSignatureMismatched extends Error { constructor(token: string) { super(`token(${token}) signature mismatched`) @@ -40,8 +47,34 @@ export class JwtTokenSignatureMismatched extends Error { } } -export enum AlgorithmTypes { - HS256 = 'HS256', - HS384 = 'HS384', - HS512 = 'HS512', +export enum CryptoKeyUsage { + Encrypt = 'encrypt', + Decrypt = 'decrypt', + Sign = 'sign', + Verify = 'verify', + DeriveKey = 'deriveKey', + DeriveBits = 'deriveBits', + WrapKey = 'wrapKey', + UnwrapKey = 'unwrapKey', } + +/** + * JWT Payload + */ +export type JWTPayload = + | (unknown & {}) + | { + [key: string]: unknown + /** + * The token is checked to ensure it has not expired. + */ + exp?: number + /** + * The token is checked to ensure it is not being used before a specified time. + */ + nbf?: number + /** + * The token is checked to ensure it is not issued in the future. + */ + iat?: number + } diff --git a/src/utils/jwt/utf8.ts b/src/utils/jwt/utf8.ts new file mode 100644 index 000000000..107407aee --- /dev/null +++ b/src/utils/jwt/utf8.ts @@ -0,0 +1,2 @@ +export const utf8Encoder = new TextEncoder() +export const utf8Decoder = new TextDecoder() diff --git a/src/validator/validator.ts b/src/validator/validator.ts index 9f729c0db..b11745479 100644 --- a/src/validator/validator.ts +++ b/src/validator/validator.ts @@ -60,7 +60,6 @@ export const validator = < return async (c, next) => { let value = {} const contentType = c.req.header('Content-Type') - const bodyTypes = ['text', 'arrayBuffer', 'blob'] switch (target) { case 'json': @@ -68,26 +67,8 @@ export const validator = < const message = `Invalid HTTP header: Content-Type=${contentType}` throw new HTTPException(400, { message }) } - - if (c.req.bodyCache.json) { - value = await c.req.bodyCache.json - break - } - try { - let arrayBuffer: ArrayBuffer | undefined = undefined - for (const type of bodyTypes) { - // @ts-expect-error bodyCache[type] is not typed - const body = c.req.bodyCache[type] - if (body) { - arrayBuffer = await new Response(await body).arrayBuffer() - break - } - } - arrayBuffer ??= await c.req.raw.arrayBuffer() - value = await new Response(arrayBuffer).json() - c.req.bodyCache.json = value - c.req.bodyCache.arrayBuffer = arrayBuffer + value = await c.req.json() } catch { const message = 'Malformed JSON in request body' throw new HTTPException(400, { message }) @@ -104,16 +85,7 @@ export const validator = < } try { - let arrayBuffer: ArrayBuffer | undefined = undefined - for (const type of bodyTypes) { - // @ts-expect-error bodyCache[type] is not typed - const body = c.req.bodyCache[type] - if (body) { - arrayBuffer = await new Response(await body).arrayBuffer() - break - } - } - arrayBuffer ??= await c.req.arrayBuffer() + const arrayBuffer = await c.req.arrayBuffer() const formData = await bufferToFormData(arrayBuffer, contentType) const form: BodyData = {} formData.forEach((value, key) => { @@ -121,7 +93,6 @@ export const validator = < }) value = form c.req.bodyCache.formData = formData - c.req.bodyCache.arrayBuffer = arrayBuffer } catch (e) { let message = 'Malformed FormData request.' message += e instanceof Error ? ` ${e.message}` : ` ${String(e)}`