Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(delegate): delegate model's guards are not properly including concrete models #1932

Merged
merged 5 commits into from
Jan 2, 2025
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 3 additions & 14 deletions packages/schema/src/plugins/enhancer/enhance/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@ import {
isArrayExpr,
isDataModel,
isGeneratorDecl,
isReferenceExpr,
isTypeDef,
type Model,
} from '@zenstackhq/sdk/ast';
Expand All @@ -45,6 +44,7 @@ import {
} from 'ts-morph';
import { upperCaseFirst } from 'upper-case-first';
import { name } from '..';
import { getConcreteModels, getDiscriminatorField } from '../../../utils/ast-utils';
import { execPackage } from '../../../utils/exec-utils';
import { CorePlugins, getPluginCustomOutputFolder } from '../../plugin-utils';
import { trackPrismaSchemaError } from '../../prisma';
Expand Down Expand Up @@ -407,9 +407,7 @@ export function enhance(prisma: any, context?: EnhancementContext<${authTypePara
this.model.declarations
.filter((d): d is DataModel => isDelegateModel(d))
.forEach((dm) => {
const concreteModels = this.model.declarations.filter(
(d): d is DataModel => isDataModel(d) && d.superTypes.some((s) => s.ref === dm)
);
const concreteModels = getConcreteModels(dm);
if (concreteModels.length > 0) {
delegateInfo.push([dm, concreteModels]);
}
Expand Down Expand Up @@ -579,7 +577,7 @@ export function enhance(prisma: any, context?: EnhancementContext<${authTypePara
const typeName = typeAlias.getName();
const payloadRecord = delegateInfo.find(([delegate]) => `$${delegate.name}Payload` === typeName);
if (payloadRecord) {
const discriminatorDecl = this.getDiscriminatorField(payloadRecord[0]);
const discriminatorDecl = getDiscriminatorField(payloadRecord[0]);
if (discriminatorDecl) {
source = `${payloadRecord[1]
.map(
Expand Down Expand Up @@ -826,15 +824,6 @@ export function enhance(prisma: any, context?: EnhancementContext<${authTypePara
.filter((n) => n.getName().startsWith(DELEGATE_AUX_RELATION_PREFIX));
}

private getDiscriminatorField(delegate: DataModel) {
const delegateAttr = getAttribute(delegate, '@@delegate');
if (!delegateAttr) {
return undefined;
}
const arg = delegateAttr.args[0]?.value;
return isReferenceExpr(arg) ? (arg.target.ref as DataModelField) : undefined;
}

private saveSourceFile(sf: SourceFile) {
if (this.options.preserveTsFiles) {
saveSourceFile(sf);
Expand Down
24 changes: 13 additions & 11 deletions packages/schema/src/plugins/enhancer/policy/expression-writer.ts
Original file line number Diff line number Diff line change
Expand Up @@ -839,16 +839,18 @@ export class ExpressionWriter {
operation = this.options.operationContext;
}

this.block(() => {
if (operation === 'postUpdate') {
// 'postUpdate' policies are not delegated to relations, just use constant `false` here
// e.g.:
// @@allow('all', check(author)) should not delegate "postUpdate" to author
this.writer.write(`${fieldRef.target.$refText}: ${FALSE}`);
} else {
const targetGuardFunc = getQueryGuardFunctionName(targetModel, undefined, false, operation);
this.writer.write(`${fieldRef.target.$refText}: ${targetGuardFunc}(context, db)`);
}
});
this.block(() =>
this.writeFieldCondition(fieldRef, () => {
if (operation === 'postUpdate') {
// 'postUpdate' policies are not delegated to relations, just use constant `false` here
// e.g.:
// @@allow('all', check(author)) should not delegate "postUpdate" to author
this.writer.write(FALSE);
} else {
const targetGuardFunc = getQueryGuardFunctionName(targetModel, undefined, false, operation);
this.writer.write(`${targetGuardFunc}(context, db)`);
}
})
);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ import {
hasAttribute,
hasValidationAttributes,
isAuthInvocation,
isDelegateModel,
isForeignKeyField,
saveSourceFile,
} from '@zenstackhq/sdk';
Expand Down Expand Up @@ -454,36 +455,44 @@ export class PolicyGenerator {
writer: CodeBlockWriter,
sourceFile: SourceFile
) {
if (kind === 'update' && allows.length === 0) {
// no allow rule for 'update', policy is constant based on if there's
// post-update counterpart
let func: FunctionDeclaration;
if (getPolicyExpressions(model, 'allow', 'postUpdate').length === 0) {
func = generateConstantQueryGuardFunction(sourceFile, model, kind, false);
} else {
func = generateConstantQueryGuardFunction(sourceFile, model, kind, true);
const isDelegate = isDelegateModel(model);

if (!isDelegate) {
// handle cases where a constant function can be used
// note that this doesn't apply to delegate models because
// all concrete models inheriting it need to be considered

if (kind === 'update' && allows.length === 0) {
// no allow rule for 'update', policy is constant based on if there's
// post-update counterpart
let func: FunctionDeclaration;
if (getPolicyExpressions(model, 'allow', 'postUpdate').length === 0) {
func = generateConstantQueryGuardFunction(sourceFile, model, kind, false);
} else {
func = generateConstantQueryGuardFunction(sourceFile, model, kind, true);
}
writer.write(`guard: ${func.getName()!},`);
return;
}
writer.write(`guard: ${func.getName()!},`);
return;
}

if (kind === 'postUpdate' && allows.length === 0 && denies.length === 0) {
// no 'postUpdate' rule, always allow
const func = generateConstantQueryGuardFunction(sourceFile, model, kind, true);
writer.write(`guard: ${func.getName()},`);
return;
}
if (kind === 'postUpdate' && allows.length === 0 && denies.length === 0) {
// no 'postUpdate' rule, always allow
const func = generateConstantQueryGuardFunction(sourceFile, model, kind, true);
writer.write(`guard: ${func.getName()},`);
return;
}

if (kind in policies && typeof policies[kind as keyof typeof policies] === 'boolean') {
// constant policy
const func = generateConstantQueryGuardFunction(
sourceFile,
model,
kind,
policies[kind as keyof typeof policies] as boolean
);
writer.write(`guard: ${func.getName()!},`);
return;
if (kind in policies && typeof policies[kind as keyof typeof policies] === 'boolean') {
// constant policy
const func = generateConstantQueryGuardFunction(
sourceFile,
model,
kind,
policies[kind as keyof typeof policies] as boolean
);
writer.write(`guard: ${func.getName()!},`);
return;
}
}

// generate a policy function that evaluates a partial prisma query
Expand Down
72 changes: 69 additions & 3 deletions packages/schema/src/plugins/enhancer/policy/utils.ts
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/* eslint-disable @typescript-eslint/no-explicit-any */
import type { PolicyKind, PolicyOperationKind } from '@zenstackhq/runtime';
import { DELEGATE_AUX_RELATION_PREFIX, type PolicyKind, type PolicyOperationKind } from '@zenstackhq/runtime';
import {
ExpressionContext,
PluginError,
Expand All @@ -15,6 +15,7 @@ import {
getQueryGuardFunctionName,
isAuthInvocation,
isDataModelFieldReference,
isDelegateModel,
isEnumFieldReference,
isFromStdlib,
isFutureExpr,
Expand All @@ -39,9 +40,16 @@ import {
} from '@zenstackhq/sdk/ast';
import deepmerge from 'deepmerge';
import { getContainerOfType, streamAllContents, streamAst, streamContents } from 'langium';
import { lowerCaseFirst } from 'lower-case-first';
import { SourceFile, WriterFunction } from 'ts-morph';
import { name } from '..';
import { isCheckInvocation, isCollectionPredicate, isFutureInvocation } from '../../../utils/ast-utils';
import {
getConcreteModels,
getDiscriminatorField,
isCheckInvocation,
isCollectionPredicate,
isFutureInvocation,
} from '../../../utils/ast-utils';
import { ExpressionWriter, FALSE, TRUE } from './expression-writer';

/**
Expand Down Expand Up @@ -303,8 +311,11 @@ export function generateQueryGuardFunction(
forField?: DataModelField,
fieldOverride = false
) {
const statements: (string | WriterFunction)[] = [];
if (isDelegateModel(model) && !forField) {
return generateDelegateQueryGuardFunction(sourceFile, model, kind);
}

const statements: (string | WriterFunction)[] = [];
const allowRules = allows.filter((rule) => !hasCrossModelComparison(rule));
const denyRules = denies.filter((rule) => !hasCrossModelComparison(rule));

Expand Down Expand Up @@ -438,6 +449,61 @@ export function generateQueryGuardFunction(
return func;
}

function generateDelegateQueryGuardFunction(sourceFile: SourceFile, model: DataModel, kind: PolicyOperationKind) {
const concreteModels = getConcreteModels(model);

const discriminator = getDiscriminatorField(model);
if (!discriminator) {
throw new PluginError(name, `Model '${model.name}' does not have a discriminator field`);
}

const func = sourceFile.addFunction({
name: getQueryGuardFunctionName(model, undefined, false, kind),
returnType: 'any',
parameters: [
{
name: 'context',
type: 'QueryContext',
},
{
// for generating field references used by field comparison in the same model
name: 'db',
type: 'CrudContract',
},
],
statements: (writer) => {
writer.write('return ');
if (concreteModels.length === 0) {
writer.write(TRUE);
} else {
writer.block(() => {
// union all concrete model's guards
writer.writeLine('OR: [');
concreteModels.forEach((concrete) => {
writer.block(() => {
writer.write('AND: [');
// discriminator condition
writer.write(`{ ${discriminator.name}: '${concrete.name}' },`);
// concrete model guard
writer.write(
`{ ${DELEGATE_AUX_RELATION_PREFIX}_${lowerCaseFirst(
concrete.name
)}: ${getQueryGuardFunctionName(concrete, undefined, false, kind)}(context, db) }`
);
writer.writeLine(']');
});
writer.write(',');
});
writer.writeLine(']');
});
}
writer.write(';');
},
});

return func;
}

export function generateEntityCheckerFunction(
sourceFile: SourceFile,
model: DataModel,
Expand Down
5 changes: 2 additions & 3 deletions packages/schema/src/plugins/prisma/schema-generator.ts
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ import path from 'path';
import semver from 'semver';
import { name } from '.';
import { getStringLiteral } from '../../language-server/validator/utils';
import { getConcreteModels } from '../../utils/ast-utils';
import { execPackage } from '../../utils/exec-utils';
import { isDefaultWithAuth } from '../enhancer/enhancer-utils';
import {
Expand Down Expand Up @@ -320,9 +321,7 @@ export class PrismaSchemaGenerator {
}

// collect concrete models inheriting this model
const concreteModels = decl.$container.declarations.filter(
(d) => isDataModel(d) && d !== decl && d.superTypes.some((base) => base.ref === decl)
);
const concreteModels = getConcreteModels(decl);

// generate an optional relation field in delegate base model to each concrete model
concreteModels.forEach((concrete) => {
Expand Down
28 changes: 27 additions & 1 deletion packages/schema/src/utils/ast-utils.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,19 +2,21 @@ import {
BinaryExpr,
DataModel,
DataModelAttribute,
DataModelField,
Expression,
InheritableNode,
isBinaryExpr,
isDataModel,
isDataModelField,
isInvocationExpr,
isModel,
isReferenceExpr,
isTypeDef,
Model,
ModelImport,
TypeDef,
} from '@zenstackhq/language/ast';
import { getInheritanceChain, getRecursiveBases, isDelegateModel, isFromStdlib } from '@zenstackhq/sdk';
import { getAttribute, getInheritanceChain, getRecursiveBases, isDelegateModel, isFromStdlib } from '@zenstackhq/sdk';
import {
AstNode,
copyAstNode,
Expand Down Expand Up @@ -310,3 +312,27 @@ export function findUpInheritance(start: DataModel, target: DataModel): DataMode
}
return undefined;
}

/**
* Gets all concrete models that inherit from the given delegate model
*/
export function getConcreteModels(dataModel: DataModel): DataModel[] {
if (!isDelegateModel(dataModel)) {
return [];
}
return dataModel.$container.declarations.filter(
(d): d is DataModel => isDataModel(d) && d !== dataModel && d.superTypes.some((base) => base.ref === dataModel)
);
}

/**
* Gets the discriminator field for the given delegate model
*/
export function getDiscriminatorField(delegate: DataModel) {
const delegateAttr = getAttribute(delegate, '@@delegate');
if (!delegateAttr) {
return undefined;
}
const arg = delegateAttr.args[0]?.value;
return isReferenceExpr(arg) ? (arg.target.ref as DataModelField) : undefined;
}
Loading
Loading