Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Extend z.discriminatedUnion to support intersections, unions, and discriminated unions of objects as options #3956

Draft
wants to merge 2 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .husky/pre-commit
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
#!/bin/sh
. "$(dirname "$0")/_/husky.sh"

npx lint-staged
yarn build:deno
git add deno
npx lint-staged
72 changes: 72 additions & 0 deletions deno/lib/__tests__/discriminated-unions.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -319,3 +319,75 @@ test("readonly array of options", () => {
z.discriminatedUnion("type", options).parse({ type: "x", val: 1 })
).toEqual({ type: "x", val: 1 });
});

test("valid - unions and intersections of objects", () => {
const union = z.discriminatedUnion("type", [
z.object({ type: z.literal("a"), a: z.string() }),
z.object({ type: z.literal("b") }).and(z.object({ b: z.string() })),
z
.object({ type: z.literal("c"), c: z.string() })
.or(z.object({ type: z.literal("c"), c: z.number() }))
.or(z.object({ type: z.literal("d"), d: z.string() })),
z
.object({ type: z.literal("e"), e: z.string() })
.and(z.object({ foo: z.string() }).or(z.object({ bar: z.string() }))),
z
.object({ type: z.literal("f"), f: z.string() })
.or(z.object({ type: z.literal("f") }).and(z.object({ f: z.number() }))),
z.discriminatedUnion("foo", [
z.object({ type: z.literal("g"), foo: z.literal("bar") }),
z.object({ type: z.literal("h"), foo: z.literal("baz") }),
]),
z
.object({ type: z.literal("i").or(z.literal("j")) })
.and(
z.object({
type: z.literal("i").or(z.literal("j")).and(z.literal("i")),
})
)
.and(z.object({ type: z.literal("i") }))
.and(z.object({ foo: z.string() })),
]);

expect(union.parse({ type: "a", a: "123" })).toEqual({ type: "a", a: "123" });
expect(union.parse({ type: "b", b: "123" })).toEqual({ type: "b", b: "123" });
expect(union.parse({ type: "c", c: "123" })).toEqual({ type: "c", c: "123" });
expect(union.parse({ type: "c", c: 123 })).toEqual({ type: "c", c: 123 });
expect(union.parse({ type: "d", d: "123" })).toEqual({ type: "d", d: "123" });
expect(() => {
union.parse({ type: "d", c: "123" });
}).toThrow();
expect(union.parse({ type: "e", e: "123", foo: "456" })).toEqual({
type: "e",
e: "123",
foo: "456",
});
expect(union.parse({ type: "e", e: "123", bar: "456" })).toEqual({
type: "e",
e: "123",
bar: "456",
});
expect(() => {
union.parse({ type: "e", e: "123" });
}).toThrow();
expect(union.parse({ type: "f", f: "123" })).toEqual({ type: "f", f: "123" });
expect(union.parse({ type: "f", f: 123 })).toEqual({ type: "f", f: 123 });
expect(union.parse({ type: "g", foo: "bar" })).toEqual({
type: "g",
foo: "bar",
});
expect(union.parse({ type: "h", foo: "baz" })).toEqual({
type: "h",
foo: "baz",
});
expect(() => {
union.parse({ type: "h", foo: "bar" });
}).toThrow();
expect(union.parse({ type: "i", foo: "123" })).toEqual({
type: "i",
foo: "123",
});
expect(() => {
union.parse({ type: "j", foo: "123" });
}).toThrow();
});
178 changes: 151 additions & 27 deletions deno/lib/types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,12 @@ import {
} from "./helpers/parseUtil.ts";
import { partialUtil } from "./helpers/partialUtil.ts";
import { Primitive } from "./helpers/typeAliases.ts";
import { getParsedType, objectUtil, util, ZodParsedType } from "./helpers/util.ts";
import {
getParsedType,
objectUtil,
util,
ZodParsedType,
} from "./helpers/util.ts";
import type { StandardSchemaV1 } from "./standard-schema.ts";
import {
IssueData,
Expand Down Expand Up @@ -3075,7 +3080,9 @@ export type AnyZodObject = ZodObject<any, any, any>;
////////// //////////
////////////////////////////////////////
////////////////////////////////////////
export type ZodUnionOptions = Readonly<[ZodTypeAny, ...ZodTypeAny[]]>;
export type ZodUnionOptions<T extends ZodTypeAny = ZodTypeAny> = Readonly<
[T, ...T[]]
>;
export interface ZodUnionDef<
T extends ZodUnionOptions = Readonly<
[ZodTypeAny, ZodTypeAny, ...ZodTypeAny[]]
Expand Down Expand Up @@ -3217,46 +3224,161 @@ export class ZodUnion<T extends ZodUnionOptions> extends ZodType<
/////////////////////////////////////////////////////
/////////////////////////////////////////////////////

const getDiscriminator = <T extends ZodTypeAny>(type: T): Primitive[] => {
type AnyZodDiscriminatedUnionOption =
| SomeZodObject
| ZodIntersection<
AnyZodDiscriminatedUnionOption,
AnyZodDiscriminatedUnionOption
>
| ZodUnion<ZodUnionOptions<AnyZodDiscriminatedUnionOption>>
| ZodDiscriminatedUnion<any, readonly AnyZodDiscriminatedUnionOption[]>;

export type ZodDiscriminatedUnionOption<Discriminator extends string> =
| ZodObject<
{ [key in Discriminator]: ZodTypeAny } & ZodRawShape,
UnknownKeysParam,
ZodTypeAny
>
| ZodIntersection<
ZodDiscriminatedUnionOption<Discriminator>,
AnyZodDiscriminatedUnionOption
>
| ZodIntersection<
AnyZodDiscriminatedUnionOption,
ZodDiscriminatedUnionOption<Discriminator>
>
| ZodUnion<ZodUnionOptions<ZodDiscriminatedUnionOption<Discriminator>>>
| ZodDiscriminatedUnion<
any,
readonly ZodDiscriminatedUnionOption<Discriminator>[]
>;

const getDiscriminatorValues = (
discriminator: string,
type: AnyZodDiscriminatedUnionOption,
values: Set<Primitive>
) => {
if (type instanceof ZodObject) {
getDiscriminator(type.shape[discriminator], values);
} else if (type instanceof ZodIntersection) {
const leftHasDiscriminator = hasDiscriminator(
discriminator,
type._def.left
);
const rightHasDiscriminator = hasDiscriminator(
discriminator,
type._def.right
);

if (leftHasDiscriminator && rightHasDiscriminator) {
const leftValues = new Set<Primitive>();
const rightValues = new Set<Primitive>();

getDiscriminatorValues(discriminator, type._def.left, leftValues);
getDiscriminatorValues(discriminator, type._def.right, rightValues);

for (const value of leftValues) {
if (rightValues.has(value)) {
values.add(value);
}
}
} else if (leftHasDiscriminator) {
getDiscriminatorValues(discriminator, type._def.left, values);
} else if (rightHasDiscriminator) {
getDiscriminatorValues(discriminator, type._def.right, values);
}
} else if (
type instanceof ZodUnion ||
type instanceof ZodDiscriminatedUnion
) {
for (const optionType of type.options) {
getDiscriminatorValues(discriminator, optionType, values);
}
}
};

const hasDiscriminator = (
discriminator: string,
type: AnyZodDiscriminatedUnionOption
): boolean => {
if (type instanceof ZodObject) {
return discriminator in type.shape;
} else if (type instanceof ZodIntersection) {
return (
hasDiscriminator(discriminator, type._def.left) ||
hasDiscriminator(discriminator, type._def.right)
);
} else if (
type instanceof ZodUnion ||
type instanceof ZodDiscriminatedUnion
) {
return type.options.some((optionType) =>
hasDiscriminator(discriminator, optionType)
);
} else {
return false;
}
};

const getDiscriminator = <T extends ZodTypeAny>(
type: T,
values: Set<Primitive>
) => {
if (type instanceof ZodLazy) {
return getDiscriminator(type.schema);
getDiscriminator(type.schema, values);
} else if (type instanceof ZodEffects) {
return getDiscriminator(type.innerType());
getDiscriminator(type.innerType(), values);
} else if (type instanceof ZodLiteral) {
return [type.value];
values.add(type.value);
} else if (type instanceof ZodEnum) {
return type.options;
for (const value of type.options) {
values.add(value);
}
} else if (type instanceof ZodNativeEnum) {
// eslint-disable-next-line ban/ban
return util.objectValues(type.enum as any);
for (const value of util.objectValues(type.enum as any)) {
values.add(value);
}
} else if (type instanceof ZodDefault) {
return getDiscriminator(type._def.innerType);
getDiscriminator(type._def.innerType, values);
} else if (type instanceof ZodUndefined) {
return [undefined];
values.add(undefined);
} else if (type instanceof ZodNull) {
return [null];
values.add(null);
} else if (type instanceof ZodOptional) {
return [undefined, ...getDiscriminator(type.unwrap())];
values.add(undefined);
getDiscriminator(type.unwrap(), values);
} else if (type instanceof ZodNullable) {
return [null, ...getDiscriminator(type.unwrap())];
values.add(null);
getDiscriminator(type.unwrap(), values);
} else if (type instanceof ZodBranded) {
return getDiscriminator(type.unwrap());
getDiscriminator(type.unwrap(), values);
} else if (type instanceof ZodReadonly) {
return getDiscriminator(type.unwrap());
getDiscriminator(type.unwrap(), values);
} else if (type instanceof ZodCatch) {
return getDiscriminator(type._def.innerType);
} else {
return [];
getDiscriminator(type._def.innerType, values);
} else if (type instanceof ZodIntersection) {
const leftValues = new Set<Primitive>();
const rightValues = new Set<Primitive>();

getDiscriminator(type._def.left, leftValues);
getDiscriminator(type._def.right, rightValues);

for (const value of leftValues) {
if (rightValues.has(value)) {
values.add(value);
}
}
} else if (
type instanceof ZodUnion ||
type instanceof ZodDiscriminatedUnion
) {
for (const optionType of type.options) {
getDiscriminator(optionType, values);
}
}
};

export type ZodDiscriminatedUnionOption<Discriminator extends string> =
ZodObject<
{ [key in Discriminator]: ZodTypeAny } & ZodRawShape,
UnknownKeysParam,
ZodTypeAny
>;

export interface ZodDiscriminatedUnionDef<
Discriminator extends string,
Options extends readonly ZodDiscriminatedUnionOption<string>[] = ZodDiscriminatedUnionOption<string>[]
Expand Down Expand Up @@ -3352,9 +3474,10 @@ export class ZodDiscriminatedUnion<
const optionsMap: Map<Primitive, Types[number]> = new Map();

// try {
const discriminatorValues = new Set<Primitive>();
for (const type of options) {
const discriminatorValues = getDiscriminator(type.shape[discriminator]);
if (!discriminatorValues.length) {
getDiscriminatorValues(discriminator, type, discriminatorValues);
if (discriminatorValues.size < 1) {
throw new Error(
`A discriminator value for key \`${discriminator}\` could not be extracted from all schema options`
);
Expand All @@ -3370,6 +3493,7 @@ export class ZodDiscriminatedUnion<

optionsMap.set(value, type);
}
discriminatorValues.clear();
}

return new ZodDiscriminatedUnion<
Expand Down
6 changes: 3 additions & 3 deletions package.json
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@
],
"license": "MIT",
"lint-staged": {
"src/*.ts": [
"(src|deno/lib)/**/*.ts": [
"eslint --cache --fix",
"prettier --ignore-unknown --write"
],
Expand All @@ -84,8 +84,8 @@
]
},
"scripts": {
"prettier:check": "prettier --check src/**/*.ts deno/lib/**/*.ts *.md --no-error-on-unmatched-pattern",
"prettier:fix": "prettier --write src/**/*.ts deno/lib/**/*.ts *.md --ignore-unknown --no-error-on-unmatched-pattern",
"prettier:check": "prettier --check 'src/**/*.ts' 'deno/lib/**/*.ts' *.md --no-error-on-unmatched-pattern",
"prettier:fix": "prettier --write 'src/**/*.ts' 'deno/lib/**/*.ts' *.md --ignore-unknown --no-error-on-unmatched-pattern",
"lint:check": "eslint --cache --ext .ts ./src",
"lint:fix": "eslint --cache --fix --ext .ts ./src",
"check": "yarn lint:check && yarn prettier:check",
Expand Down
Loading