Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

chore: merge from dev #1054

Merged
merged 1 commit into from
Feb 25, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion jest.config.ts
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,9 @@ export default {
// Automatically clear mock calls, instances, contexts and results before every test
clearMocks: true,

globalSetup: path.join(__dirname, './test-setup.ts'),
globalSetup: path.join(__dirname, './script/test-global-setup.ts'),

setupFiles: [path.join(__dirname, './script/set-test-env.ts')],

// Indicates whether the coverage information should be collected while executing the test
collectCoverage: true,
Expand Down
4 changes: 2 additions & 2 deletions package.json
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@
"scripts": {
"build": "pnpm -r build",
"lint": "pnpm -r lint",
"test": "ZENSTACK_TEST=1 pnpm -r --parallel run test --silent --forceExit",
"test-ci": "ZENSTACK_TEST=1 pnpm -r --parallel run test --silent --forceExit",
"test": "pnpm -r --parallel run test --silent --forceExit",
"test-ci": "pnpm -r --parallel run test --silent --forceExit",
"test-scaffold": "tsx script/test-scaffold.ts",
"publish-all": "pnpm --filter \"./packages/**\" -r publish --access public",
"publish-preview": "pnpm --filter \"./packages/**\" -r publish --force --registry https://preview.registry.zenstack.dev/",
Expand Down
2 changes: 1 addition & 1 deletion packages/plugins/openapi/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
"build": "pnpm lint --max-warnings=0 && pnpm clean && tsc && copyfiles ./package.json ./README.md ./LICENSE dist && copyfiles -u 1 ./src/plugin.zmodel dist && pnpm pack dist --pack-destination '../../../../.build'",
"watch": "tsc --watch",
"lint": "eslint src --ext ts",
"test": "ZENSTACK_TEST=1 jest",
"test": "jest",
"prepublishOnly": "pnpm build"
},
"keywords": [
Expand Down
2 changes: 1 addition & 1 deletion packages/plugins/swr/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
"build": "pnpm lint --max-warnings=0 && pnpm clean && tsc && tsup-node --config ./tsup.config.ts && copyfiles ./package.json ./README.md ./LICENSE dist && pnpm pack dist --pack-destination '../../../../.build'",
"watch": "concurrently \"tsc --watch\" \"tsup-node --config ./tsup.config.ts --watch\"",
"lint": "eslint src --ext ts",
"test": "ZENSTACK_TEST=1 jest",
"test": "jest",
"prepublishOnly": "pnpm build"
},
"publishConfig": {
Expand Down
2 changes: 1 addition & 1 deletion packages/plugins/tanstack-query/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@
"build": "pnpm lint --max-warnings=0 && pnpm clean && tsc && tsup-node --config ./tsup.config.ts && tsup-node --config ./tsup-v5.config.ts && node scripts/postbuild && copyfiles ./package.json ./README.md ./LICENSE dist && pnpm pack dist --pack-destination '../../../../.build'",
"watch": "concurrently \"tsc --watch\" \"tsup-node --config ./tsup.config.ts --watch\" \"tsup-node --config ./tsup-v5.config.ts --watch\"",
"lint": "eslint src --ext ts",
"test": "ZENSTACK_TEST=1 jest",
"test": "jest",
"prepublishOnly": "pnpm build"
},
"publishConfig": {
Expand Down
2 changes: 1 addition & 1 deletion packages/plugins/trpc/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
"build": "pnpm lint --max-warnings=0 && pnpm clean && tsc && copyfiles ./package.json ./README.md ./LICENSE 'res/**/*' dist && pnpm pack dist --pack-destination '../../../../.build'",
"watch": "tsc --watch",
"lint": "eslint src --ext ts",
"test": "ZENSTACK_TEST=1 jest",
"test": "jest",
"prepublishOnly": "pnpm build"
},
"publishConfig": {
Expand Down
65 changes: 56 additions & 9 deletions packages/runtime/src/enhancements/policy/handler.ts
Original file line number Diff line number Diff line change
Expand Up @@ -481,7 +481,7 @@ export class PolicyProxyHandler<DbClient extends DbClientContract> implements Pr
// Validates the given create payload against Zod schema if any
private validateCreateInputSchema(model: string, data: any) {
const schema = this.policyUtils.getZodSchema(model, 'create');
if (schema) {
if (schema && data) {
const parseResult = schema.safeParse(data);
if (!parseResult.success) {
throw this.policyUtils.deniedByPolicy(
Expand Down Expand Up @@ -514,26 +514,29 @@ export class PolicyProxyHandler<DbClient extends DbClientContract> implements Pr

args = this.policyUtils.clone(args);

// do static input validation and check if post-create checks are needed
// go through create items, statically check input to determine if post-create
// check is needed, and also validate zod schema
let needPostCreateCheck = false;
for (const item of enumerate(args.data)) {
const validationResult = this.validateCreateInputSchema(this.model, item);
if (validationResult !== item) {
this.policyUtils.replace(item, validationResult);
}

const inputCheck = this.policyUtils.checkInputGuard(this.model, item, 'create');
if (inputCheck === false) {
// unconditionally deny
throw this.policyUtils.deniedByPolicy(
this.model,
'create',
undefined,
CrudFailureReason.ACCESS_POLICY_VIOLATION
);
} else if (inputCheck === true) {
const r = this.validateCreateInputSchema(this.model, item);
if (r !== item) {
this.policyUtils.replace(item, r);
}
// unconditionally allow
Comment on lines +517 to +536
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

While iterating over args.data for createMany operations, the validation and input check logic is sound. However, consider optimizing this loop to avoid potential performance issues with large datasets by minimizing the number of times validateCreateInputSchema and checkInputGuard are called.

- for (const item of enumerate(args.data)) {
+ for (const [index, item] of enumerate(args.data).entries()) {

Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation.

Suggested change
// go through create items, statically check input to determine if post-create
// check is needed, and also validate zod schema
let needPostCreateCheck = false;
for (const item of enumerate(args.data)) {
const validationResult = this.validateCreateInputSchema(this.model, item);
if (validationResult !== item) {
this.policyUtils.replace(item, validationResult);
}
const inputCheck = this.policyUtils.checkInputGuard(this.model, item, 'create');
if (inputCheck === false) {
// unconditionally deny
throw this.policyUtils.deniedByPolicy(
this.model,
'create',
undefined,
CrudFailureReason.ACCESS_POLICY_VIOLATION
);
} else if (inputCheck === true) {
const r = this.validateCreateInputSchema(this.model, item);
if (r !== item) {
this.policyUtils.replace(item, r);
}
// unconditionally allow
// go through create items, statically check input to determine if post-create
// check is needed, and also validate zod schema
let needPostCreateCheck = false;
for (const [index, item] of enumerate(args.data).entries()) {
const validationResult = this.validateCreateInputSchema(this.model, item);
if (validationResult !== item) {
this.policyUtils.replace(item, validationResult);
}
const inputCheck = this.policyUtils.checkInputGuard(this.model, item, 'create');
if (inputCheck === false) {
// unconditionally deny
throw this.policyUtils.deniedByPolicy(
this.model,
'create',
undefined,
CrudFailureReason.ACCESS_POLICY_VIOLATION
);
} else if (inputCheck === true) {
// unconditionally allow

} else if (inputCheck === undefined) {
// static policy check is not possible, need to do post-create check
needPostCreateCheck = true;
break;
}
}

Expand Down Expand Up @@ -808,7 +811,13 @@ export class PolicyProxyHandler<DbClient extends DbClientContract> implements Pr

// check if the update actually writes to this model
let thisModelUpdate = false;
const updatePayload: any = (args as any).data ?? args;
const updatePayload = (args as any).data ?? args;

const validatedPayload = this.validateUpdateInputSchema(model, updatePayload);
if (validatedPayload !== updatePayload) {
this.policyUtils.replace(updatePayload, validatedPayload);
}

if (updatePayload) {
for (const key of Object.keys(updatePayload)) {
const field = resolveField(this.modelMeta, model, key);
Expand Down Expand Up @@ -879,6 +888,8 @@ export class PolicyProxyHandler<DbClient extends DbClientContract> implements Pr
);
}

args.data = this.validateUpdateInputSchema(model, args.data);

const updateGuard = this.policyUtils.getAuthGuard(db, model, 'update');
if (this.policyUtils.isTrue(updateGuard) || this.policyUtils.isFalse(updateGuard)) {
// injects simple auth guard into where clause
Expand Down Expand Up @@ -939,7 +950,10 @@ export class PolicyProxyHandler<DbClient extends DbClientContract> implements Pr
await _registerPostUpdateCheck(model, uniqueFilter);

// convert upsert to update
context.parent.update = { where: args.where, data: args.update };
context.parent.update = {
where: args.where,
data: this.validateUpdateInputSchema(model, args.update),
};
delete context.parent.upsert;

// continue visiting the new payload
Expand Down Expand Up @@ -1038,6 +1052,37 @@ export class PolicyProxyHandler<DbClient extends DbClientContract> implements Pr
return { result, postWriteChecks };
}

// Validates the given update payload against Zod schema if any
private validateUpdateInputSchema(model: string, data: any) {
const schema = this.policyUtils.getZodSchema(model, 'update');
if (schema && data) {
// update payload can contain non-literal fields, like:
// { x: { increment: 1 } }
// we should only validate literal fields

const literalData = Object.entries(data).reduce<any>(
(acc, [k, v]) => ({ ...acc, ...(typeof v !== 'object' ? { [k]: v } : {}) }),
{}
);

const parseResult = schema.safeParse(literalData);
if (!parseResult.success) {
throw this.policyUtils.deniedByPolicy(
model,
'update',
`input failed validation: ${fromZodError(parseResult.error)}`,
CrudFailureReason.DATA_VALIDATION_VIOLATION,
parseResult.error
);
}

// schema may have transformed field values, use it to overwrite the original data
return { ...data, ...parseResult.data };
} else {
return data;
}
}

private isUnsafeMutate(model: string, args: any) {
if (!args) {
return false;
Expand Down Expand Up @@ -1072,6 +1117,8 @@ export class PolicyProxyHandler<DbClient extends DbClientContract> implements Pr
args = this.policyUtils.clone(args);
this.policyUtils.injectAuthGuardAsWhere(this.prisma, args, this.model, 'update');

args.data = this.validateUpdateInputSchema(this.model, args.data);

Comment on lines +1120 to +1121
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Applying Zod schema validation to updateMany operations is consistent with the approach for single updates. However, consider the impact on performance and ensure that the validation logic is optimized for handling bulk operations efficiently.

- args.data = this.validateUpdateInputSchema(this.model, args.data);
+ // Consider batching or optimizing validation for bulk operations

Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation.

Suggested change
args.data = this.validateUpdateInputSchema(this.model, args.data);
// Consider batching or optimizing validation for bulk operations

if (this.policyUtils.hasAuthGuard(this.model, 'postUpdate') || this.policyUtils.getZodSchema(this.model)) {
// use a transaction to do post-update checks
const postWriteChecks: PostWriteCheckRecord[] = [];
Expand Down
21 changes: 18 additions & 3 deletions packages/runtime/src/enhancements/policy/policy-utils.ts
Original file line number Diff line number Diff line change
Expand Up @@ -317,7 +317,7 @@ export class PolicyUtil extends QueryUtils {
/**
* Checks if the given model has a policy guard for the given operation.
*/
hasAuthGuard(model: string, operation: PolicyOperationKind): boolean {
hasAuthGuard(model: string, operation: PolicyOperationKind) {
const guard = this.policy.guard[lowerCaseFirst(model)];
if (!guard) {
return false;
Expand All @@ -326,6 +326,21 @@ export class PolicyUtil extends QueryUtils {
return typeof provider !== 'boolean' || provider !== true;
}

/**
* Checks if the given model has any field-level override policy guard for the given operation.
*/
hasOverrideAuthGuard(model: string, operation: PolicyOperationKind) {
const guard = this.requireGuard(model);
switch (operation) {
case 'read':
return Object.keys(guard).some((k) => k.startsWith(FIELD_LEVEL_OVERRIDE_READ_GUARD_PREFIX));
case 'update':
return Object.keys(guard).some((k) => k.startsWith(FIELD_LEVEL_OVERRIDE_UPDATE_GUARD_PREFIX));
default:
return false;
}
}

/**
* Checks model creation policy based on static analysis to the input args.
*
Expand Down Expand Up @@ -632,7 +647,7 @@ export class PolicyUtil extends QueryUtils {
preValue?: any
) {
let guard = this.getAuthGuard(db, model, operation, preValue);
if (this.isFalse(guard)) {
if (this.isFalse(guard) && !this.hasOverrideAuthGuard(model, operation)) {
throw this.deniedByPolicy(
model,
operation,
Expand Down Expand Up @@ -805,7 +820,7 @@ export class PolicyUtil extends QueryUtils {
*/
tryReject(db: CrudContract, model: string, operation: PolicyOperationKind) {
const guard = this.getAuthGuard(db, model, operation);
if (this.isFalse(guard)) {
if (this.isFalse(guard) && !this.hasOverrideAuthGuard(model, operation)) {
throw this.deniedByPolicy(model, operation, undefined, CrudFailureReason.ACCESS_POLICY_VIOLATION);
}
}
Expand Down
2 changes: 1 addition & 1 deletion packages/schema/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@
"bundle": "rimraf bundle && pnpm lint --max-warnings=0 && node build/bundle.js --minify",
"watch": "tsc --watch",
"lint": "eslint src tests --ext ts",
"test": "ZENSTACK_TEST=1 jest",
"test": "jest",
"prepublishOnly": "pnpm build",
"postinstall": "node bin/post-install.js"
},
Expand Down
4 changes: 2 additions & 2 deletions packages/schema/src/cli/cli-util.ts
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ import { PLUGIN_MODULE_NAME, STD_LIB_MODULE_NAME } from '../language-server/cons
import { ZModelFormatter } from '../language-server/zmodel-formatter';
import { createZModelServices, ZModelServices } from '../language-server/zmodel-module';
import { mergeBaseModel, resolveImport, resolveTransitiveImports } from '../utils/ast-utils';
import { findPackageJson } from '../utils/pkg-utils';
import { findUp } from '../utils/pkg-utils';
import { getVersion } from '../utils/version-utils';
import { CliError } from './cli-error';

Expand Down Expand Up @@ -289,7 +289,7 @@ export async function formatDocument(fileName: string) {

export function getDefaultSchemaLocation() {
// handle override from package.json
const pkgJsonPath = findPackageJson();
const pkgJsonPath = findUp(['package.json']);
if (pkgJsonPath) {
const pkgJson = JSON.parse(fs.readFileSync(pkgJsonPath, 'utf-8'));
if (typeof pkgJson?.zenstack?.schema === 'string') {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -64,8 +64,8 @@ export default class DataModelValidator implements AstValidator<DataModel> {
}

const isArray = idField.type.array;
const isScalar = SCALAR_TYPES.includes(idField.type.type as typeof SCALAR_TYPES[number])
const isValidType = isScalar || isEnum(idField.type.reference?.ref)
const isScalar = SCALAR_TYPES.includes(idField.type.type as (typeof SCALAR_TYPES)[number]);
const isValidType = isScalar || isEnum(idField.type.reference?.ref);
Comment on lines +67 to +68
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The logic for determining if a field's type is scalar or an enum is correctly implemented. However, consider adding a comment explaining why these specific checks are necessary for fields with @id attributes, as it might not be immediately clear to someone unfamiliar with the domain-specific rules.

+ // Ensure @id fields are either scalar types or enums for database compatibility
  const isScalar = SCALAR_TYPES.includes(idField.type.type as (typeof SCALAR_TYPES)[number]);
  const isValidType = isScalar || isEnum(idField.type.reference?.ref);

Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation.

Suggested change
const isScalar = SCALAR_TYPES.includes(idField.type.type as (typeof SCALAR_TYPES)[number]);
const isValidType = isScalar || isEnum(idField.type.reference?.ref);
// Ensure @id fields are either scalar types or enums for database compatibility
const isScalar = SCALAR_TYPES.includes(idField.type.type as (typeof SCALAR_TYPES)[number]);
const isValidType = isScalar || isEnum(idField.type.reference?.ref);


if (isArray || !isValidType) {
accept('error', 'Field with @id attribute must be of scalar or enum type', { node: idField });
Expand Down Expand Up @@ -121,7 +121,7 @@ export default class DataModelValidator implements AstValidator<DataModel> {
fields = (arg.value as ArrayExpr).items as ReferenceExpr[];
if (fields.length === 0) {
if (accept) {
accept('error', `"fields" value cannot be emtpy`, {
accept('error', `"fields" value cannot be empty`, {
node: arg,
});
}
Expand All @@ -131,7 +131,7 @@ export default class DataModelValidator implements AstValidator<DataModel> {
references = (arg.value as ArrayExpr).items as ReferenceExpr[];
if (references.length === 0) {
if (accept) {
accept('error', `"references" value cannot be emtpy`, {
accept('error', `"references" value cannot be empty`, {
node: arg,
});
}
Expand All @@ -157,6 +157,17 @@ export default class DataModelValidator implements AstValidator<DataModel> {
}
} else {
for (let i = 0; i < fields.length; i++) {
if (!field.type.optional && fields[i].$resolvedType?.nullable) {
// if relation is not optional, then fk field must not be nullable
if (accept) {
accept(
'error',
`relation "${field.name}" is not optional, but field "${fields[i].target.$refText}" is optional`,
{ node: fields[i].target.ref! }
);
Comment on lines +160 to +167
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This block of code introduces a validation check to ensure that if a relation is not optional, then the foreign key field must not be nullable. This is an important check for data integrity and consistency. However, consider adding a brief comment to explain the rationale behind this check for future maintainability.

+ // Validate non-optional relations to ensure foreign key fields are not nullable
  if (!field.type.optional && fields[i].$resolvedType?.nullable) {
      if (accept) {
          accept(
              'error',
              `relation "${field.name}" is not optional, but field "${fields[i].target.$refText}" is optional`,
              { node: fields[i].target.ref! }
          );
      }
  }

Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation.

Suggested change
if (!field.type.optional && fields[i].$resolvedType?.nullable) {
// if relation is not optional, then fk field must not be nullable
if (accept) {
accept(
'error',
`relation "${field.name}" is not optional, but field "${fields[i].target.$refText}" is optional`,
{ node: fields[i].target.ref! }
);
// Validate non-optional relations to ensure foreign key fields are not nullable
if (!field.type.optional && fields[i].$resolvedType?.nullable) {
// if relation is not optional, then fk field must not be nullable
if (accept) {
accept(
'error',
`relation "${field.name}" is not optional, but field "${fields[i].target.$refText}" is optional`,
{ node: fields[i].target.ref! }
);

}
}

if (!fields[i].$resolvedType) {
if (accept) {
accept('error', `field reference is unresolved`, { node: fields[i] });
Expand Down
4 changes: 2 additions & 2 deletions packages/schema/src/plugins/prisma/schema-generator.ts
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ import { name } from '.';
import { getStringLiteral } from '../../language-server/validator/utils';
import telemetry from '../../telemetry';
import { execPackage } from '../../utils/exec-utils';
import { findPackageJson } from '../../utils/pkg-utils';
import { findUp } from '../../utils/pkg-utils';
import {
AttributeArgValue,
ModelFieldType,
Expand Down Expand Up @@ -666,7 +666,7 @@ function isDescendantOf(model: DataModel, superModel: DataModel): boolean {

export function getDefaultPrismaOutputFile(schemaPath: string) {
// handle override from package.json
const pkgJsonPath = findPackageJson(path.dirname(schemaPath));
const pkgJsonPath = findUp(['package.json'], path.dirname(schemaPath));
if (pkgJsonPath) {
const pkgJson = JSON.parse(fs.readFileSync(pkgJsonPath, 'utf-8'));
if (typeof pkgJson?.zenstack?.prisma === 'string') {
Expand Down
Loading
Loading