diff --git a/README.md b/README.md index 6413fa7..019d90f 100644 --- a/README.md +++ b/README.md @@ -365,6 +365,8 @@ The list of all supported types as of now is: - `ZodRecord` - `ZodUnknown` +Extending an instance of `ZodObject` is also supported and results in an OpenApi definition with `allOf` + ### Unsupported types In case you try to create an OpenAPI schema from a zod schema that is not one of the aforementioned types then you'd receive an `UnknownZodTypeError`. diff --git a/spec/lodash.spec.ts b/spec/lodash.spec.ts new file mode 100644 index 0000000..1186697 --- /dev/null +++ b/spec/lodash.spec.ts @@ -0,0 +1,38 @@ +import { objectEquals } from '../src/lib/lodash'; + +describe('Lodash', () => { + describe('objectEquals', () => { + it('can compare plain values', () => { + expect(objectEquals(3, 4)).toEqual(false); + expect(objectEquals(3, 3)).toEqual(true); + + expect(objectEquals('3', '4')).toEqual(false); + expect(objectEquals('3', '3')).toEqual(true); + }); + + it('can compare objects', () => { + expect(objectEquals({ a: 3 }, { b: 3 })).toEqual(false); + expect(objectEquals({ a: 3 }, { a: '3' })).toEqual(false); + + expect(objectEquals({ a: 3 }, { a: 3, b: false })).toEqual(false); + + expect(objectEquals({ a: 3 }, { a: 3 })).toEqual(true); + }); + + it('can compare nested objects', () => { + expect( + objectEquals( + { test: { a: ['asd', 3, true] } }, + { test: { a: ['asd', 3, true, { b: null }] } } + ) + ).toEqual(false); + + expect( + objectEquals( + { test: { a: ['asd', 3, true, { b: null }] } }, + { test: { a: ['asd', 3, true, { b: null }] } } + ) + ).toEqual(true); + }); + }); +}); diff --git a/spec/polymorphism.spec.ts b/spec/polymorphism.spec.ts new file mode 100644 index 0000000..cb99a65 --- /dev/null +++ b/spec/polymorphism.spec.ts @@ -0,0 +1,141 @@ +import * as z from 'zod'; +import { extendZodWithOpenApi } from '../src/zod-extensions'; +import { expectSchema } from './lib/helpers'; + +// TODO: setupTests.ts +extendZodWithOpenApi(z); + +describe('Polymorphism', () => { + it('can use allOf for extended schemas', () => { + const BaseSchema = z.object({ id: z.string() }).openapi({ + refId: 'Base', + }); + + const ExtendedSchema = BaseSchema.extend({ + bonus: z.number(), + }).openapi({ + refId: 'Extended', + }); + + expectSchema([BaseSchema, ExtendedSchema], { + Base: { + type: 'object', + required: ['id'], + properties: { + id: { + type: 'string', + }, + }, + }, + Extended: { + allOf: [ + { $ref: '#/components/schemas/Base' }, + { + type: 'object', + required: ['bonus'], + properties: { + bonus: { + type: 'number', + }, + }, + }, + ], + }, + }); + }); + + it('can apply nullable', () => { + const BaseSchema = z.object({ id: z.ostring() }).openapi({ + refId: 'Base', + }); + + const ExtendedSchema = BaseSchema.extend({ + bonus: z.onumber(), + }) + .nullable() + .openapi({ + refId: 'Extended', + }); + + expectSchema([BaseSchema, ExtendedSchema], { + Base: { + type: 'object', + properties: { + id: { + type: 'string', + }, + }, + }, + Extended: { + allOf: [ + { $ref: '#/components/schemas/Base' }, + { + type: 'object', + properties: { + bonus: { + type: 'number', + }, + }, + nullable: true, + }, + ], + }, + }); + }); + + it('can override properties', () => { + const AnimalSchema = z + .object({ + name: z.ostring(), + type: z.enum(['dog', 'cat']).optional(), + }) + .openapi({ + refId: 'Animal', + discriminator: { + propertyName: 'type', + }, + }); + + const DogSchema = AnimalSchema.extend({ + type: z.string().openapi({ const: 'dog' }), + }).openapi({ + refId: 'Dog', + discriminator: { + propertyName: 'type', + }, + }); + + expectSchema([AnimalSchema, DogSchema], { + Animal: { + discriminator: { + propertyName: 'type', + }, + type: 'object', + properties: { + name: { + type: 'string', + }, + type: { type: 'string', enum: ['dog', 'cat'] }, + }, + }, + Dog: { + discriminator: { + propertyName: 'type', + }, + allOf: [ + { $ref: '#/components/schemas/Animal' }, + { + type: 'object', + properties: { + type: { + type: 'string', + const: 'dog', + }, + }, + required: ['type'], + }, + ], + }, + }); + }); +}); diff --git a/spec/simple.spec.ts b/spec/simple.spec.ts index 1191038..f172982 100644 --- a/spec/simple.spec.ts +++ b/spec/simple.spec.ts @@ -324,6 +324,60 @@ describe('Simple', () => { }); }); + it('supports nullable for registered schemas', () => { + const StringSchema = z.string().openapi({ refId: 'String' }); + + const TestSchema = z + .object({ key: StringSchema.nullable() }) + .openapi({ refId: 'Test' }); + + expectSchema([StringSchema, TestSchema], { + String: { + type: 'string', + }, + Test: { + type: 'object', + properties: { + key: { + allOf: [ + { $ref: '#/components/schemas/String' }, + { nullable: true }, + ], + }, + }, + required: ['key'], + }, + }); + }); + + it('supports .openapi for registered schemas', () => { + const StringSchema = z.string().openapi({ refId: 'String' }); + + const TestSchema = z + .object({ + key: StringSchema.openapi({ example: 'test', deprecated: true }), + }) + .openapi({ refId: 'Test' }); + + expectSchema([StringSchema, TestSchema], { + String: { + type: 'string', + }, + Test: { + type: 'object', + properties: { + key: { + allOf: [ + { $ref: '#/components/schemas/String' }, + { example: 'test', deprecated: true }, + ], + }, + }, + required: ['key'], + }, + }); + }); + describe('defaults', () => { it('supports defaults', () => { expectSchema( diff --git a/src/lib/lodash.ts b/src/lib/lodash.ts index 6a1e335..e132ff6 100644 --- a/src/lib/lodash.ts +++ b/src/lib/lodash.ts @@ -54,3 +54,35 @@ export function omitBy< export function compact(arr: (T | null | undefined)[]) { return arr.filter((elem): elem is T => !isNil(elem)); } + +export function objectEquals(x: any, y: any): boolean { + if (x === null || x === undefined || y === null || y === undefined) { + return x === y; + } + + if (x === y || x.valueOf() === y.valueOf()) { + return true; + } + + if (Array.isArray(x)) { + if (!Array.isArray(y)) { + return false; + } + + if (x.length !== y.length) { + return false; + } + } + + // if they are strictly equal, they both need to be object at least + if (!(x instanceof Object) || !(y instanceof Object)) { + return false; + } + + // recursive object equality check + const keysX = Object.keys(x); + return ( + Object.keys(y).every(keyY => keysX.indexOf(keyY) !== -1) && + keysX.every(key => objectEquals(x[key], y[key])) + ); +} diff --git a/src/openapi-generator.ts b/src/openapi-generator.ts index 9a47500..c059244 100644 --- a/src/openapi-generator.ts +++ b/src/openapi-generator.ts @@ -22,7 +22,14 @@ import type { ZodType, ZodTypeAny, } from 'zod'; -import { compact, isNil, mapValues, omit, omitBy } from './lib/lodash'; +import { + compact, + isNil, + mapValues, + objectEquals, + omit, + omitBy, +} from './lib/lodash'; import { ZodOpenAPIMetadata } from './zod-extensions'; import { OpenAPIComponentObject, @@ -335,9 +342,24 @@ export class OpenAPIGenerator { const refId = metadata?.refId; if (refId && this.schemaRefs[refId]) { - return { + const referenceObject = { $ref: `#/components/schemas/${refId}`, }; + + const nullableMetadata = zodSchema.isNullable() ? { nullable: true } : {}; + + const appliedMetadata = this.applySchemaMetadata( + nullableMetadata, + metadata + ); + + if (Object.keys(appliedMetadata).length > 0) { + return { + allOf: [referenceObject, appliedMetadata], + }; + } + + return referenceObject; } const result = metadata?.type @@ -655,6 +677,8 @@ export class OpenAPIGenerator { zodSchema: ZodObject, isNullable: boolean ): SchemaObject { + const extendedFrom = zodSchema._def.openapi?.extendedFrom; + const propTypes = zodSchema._def.shape(); const unknownKeysOption = zodSchema._unknownKeys as UnknownKeysParam; @@ -662,19 +686,63 @@ export class OpenAPIGenerator { .filter(([_key, type]) => !this.isOptionalSchema(type)) .map(([key, _type]) => key); - return { - type: 'object', + const schemaProperties = mapValues(propTypes, propSchema => + this.generateInnerSchema(propSchema) + ); + + let alreadyRegistered: string[] = []; + let alreadyRequired: string[] = []; - properties: mapValues(propTypes, propSchema => - this.generateInnerSchema(propSchema) - ), + if (extendedFrom) { + const registeredSchema = this.schemaRefs[extendedFrom]; - required: requiredProperties.length > 0 ? requiredProperties : undefined, + if (!registeredSchema) { + throw new Error( + `Attempt to extend an unregistered schema with id ${extendedFrom}.` + ); + } - additionalProperties: unknownKeysOption === 'passthrough' || undefined, + const registeredProperties = registeredSchema.properties ?? {}; - nullable: isNullable ? true : undefined, + alreadyRegistered = Object.keys(registeredProperties).filter(propKey => { + return objectEquals( + schemaProperties[propKey], + registeredProperties[propKey] + ); + }); + + alreadyRequired = registeredSchema.required ?? []; + } + + const properties = omit(schemaProperties, alreadyRegistered); + + const additionallyRequired = requiredProperties.filter( + prop => !alreadyRequired.includes(prop) + ); + + const objectData = { + type: 'object' as const, + + properties, + + ...(isNullable ? { nullable: true } : {}), + + ...(additionallyRequired.length > 0 + ? { required: additionallyRequired } + : {}), + + ...(unknownKeysOption === 'passthrough' + ? { additionalProperties: true } + : {}), }; + + if (extendedFrom) { + return { + allOf: [{ $ref: `#/components/schemas/${extendedFrom}` }, objectData], + }; + } + + return objectData; } private flattenUnionTypes(schema: ZodSchema): ZodSchema[] { @@ -719,7 +787,7 @@ export class OpenAPIGenerator { private buildSchemaMetadata(metadata: ZodOpenAPIMetadata) { // A place to omit all custom keys added to the openapi - return omitBy(omit(metadata, ['param', 'refId']), isNil); + return omitBy(omit(metadata, ['param', 'refId', 'extendedFrom']), isNil); } private buildParameterMetadata( diff --git a/src/zod-extensions.ts b/src/zod-extensions.ts index d8e0c03..f0c9a8c 100644 --- a/src/zod-extensions.ts +++ b/src/zod-extensions.ts @@ -1,8 +1,10 @@ import { ParameterObject, SchemaObject } from 'openapi3-ts'; import type { z } from 'zod'; +import { isZodType } from './lib/zod-is-type'; export interface ZodOpenAPIMetadata extends SchemaObject { refId?: string; + extendedFrom?: string; param?: Partial & { example?: T }; example?: T; } @@ -31,7 +33,7 @@ export function extendZodWithOpenApi(zod: typeof z) { zod.ZodSchema.prototype.openapi = function (openapi) { const { param, ...restOfOpenApi } = openapi ?? {}; - return new (this as any).constructor({ + const result = new (this as any).constructor({ ...this._def, openapi: { ...this._def.openapi, @@ -42,6 +44,23 @@ export function extendZodWithOpenApi(zod: typeof z) { }, }, }); + + if (isZodType(this, 'ZodObject')) { + const initialExtend = this.extend; + + // TODO: This does an overload everytime. So .extend().openapi() makes this change twice + result.extend = function (...args: any) { + const extendedResult = initialExtend.apply(result, args); + + extendedResult._def.openapi = { + extendedFrom: result._def.openapi?.refId, + }; + + return extendedResult; + }; + } + + return result; }; const zodOptional = zod.ZodSchema.prototype.optional as any;