From 8bd694a658dae4e2fd99401d5d97460e41b96857 Mon Sep 17 00:00:00 2001 From: ymc9 <104139426+ymc9@users.noreply.github.com> Date: Sun, 13 Apr 2025 17:13:02 -0700 Subject: [PATCH] perf: improve TS code generation performance --- .../enhancer/policy/expression-writer.ts | 4 +- .../enhancer/policy/policy-guard-generator.ts | 233 ++++++++---------- .../src/plugins/enhancer/policy/utils.ts | 193 +++++++-------- packages/sdk/src/code-gen.ts | 57 +++++ packages/sdk/src/model-meta-generator.ts | 78 +++--- 5 files changed, 300 insertions(+), 265 deletions(-) diff --git a/packages/schema/src/plugins/enhancer/policy/expression-writer.ts b/packages/schema/src/plugins/enhancer/policy/expression-writer.ts index 0d792bdc1..2c3334fb3 100644 --- a/packages/schema/src/plugins/enhancer/policy/expression-writer.ts +++ b/packages/schema/src/plugins/enhancer/policy/expression-writer.ts @@ -21,6 +21,7 @@ import { } from '@zenstackhq/language/ast'; import { DELEGATE_AUX_RELATION_PREFIX, PolicyOperationKind } from '@zenstackhq/runtime'; import { + CodeWriter, ExpressionContext, getFunctionExpressionContext, getIdFields, @@ -37,7 +38,6 @@ import { } from '@zenstackhq/sdk'; import { lowerCaseFirst } from 'lower-case-first'; import invariant from 'tiny-invariant'; -import { CodeBlockWriter } from 'ts-morph'; import { name } from '..'; import { isCheckInvocation } from '../../../utils/ast-utils'; @@ -77,7 +77,7 @@ export class ExpressionWriter { /** * Constructs a new ExpressionWriter */ - constructor(private readonly writer: CodeBlockWriter, private readonly options: ExpressionWriterOptions) { + constructor(private readonly writer: CodeWriter, private readonly options: ExpressionWriterOptions) { this.plainExprBuilder = new TypeScriptExpressionTransformer({ context: ExpressionContext.AccessPolicy, isPostGuard: this.options.isPostGuard, diff --git a/packages/schema/src/plugins/enhancer/policy/policy-guard-generator.ts b/packages/schema/src/plugins/enhancer/policy/policy-guard-generator.ts index 9ffe41dcb..e7651754d 100644 --- a/packages/schema/src/plugins/enhancer/policy/policy-guard-generator.ts +++ b/packages/schema/src/plugins/enhancer/policy/policy-guard-generator.ts @@ -15,7 +15,9 @@ import { } from '@zenstackhq/language/ast'; import { PolicyCrudKind, type PolicyOperationKind } from '@zenstackhq/runtime'; import { + type CodeWriter, ExpressionContext, + FastWriter, PluginOptions, PolicyAnalysisResult, RUNTIME_PACKAGE, @@ -32,14 +34,7 @@ import { getPrismaClientImportSpec } from '@zenstackhq/sdk/prisma'; import { streamAst } from 'langium'; import { lowerCaseFirst } from 'lower-case-first'; import path from 'path'; -import { - CodeBlockWriter, - FunctionDeclaration, - Project, - SourceFile, - VariableDeclarationKind, - WriterFunction, -} from 'ts-morph'; +import { FunctionDeclarationStructure, OptionalKind, Project, SourceFile, VariableDeclarationKind } from 'ts-morph'; import { isCheckInvocation } from '../../../utils/ast-utils'; import { ConstraintTransformer } from './constraint-transformer'; import { @@ -56,6 +51,8 @@ import { * Generates source file that contains Prisma query guard objects used for injecting database queries */ export class PolicyGenerator { + private extraFunctions: OptionalKind[] = []; + constructor(private options: PluginOptions) {} generate(project: Project, model: Model, output: string) { @@ -65,23 +62,28 @@ export class PolicyGenerator { const models = getDataModels(model); + const writer = new FastWriter(); + writer.block(() => { + this.writePolicy(writer, models); + this.writeValidationMeta(writer, models); + this.writeAuthSelector(models, writer); + }); + sf.addVariableStatement({ declarationKind: VariableDeclarationKind.Const, declarations: [ { name: 'policy', type: 'PolicyDef', - initializer: (writer) => { - writer.block(() => { - this.writePolicy(writer, models, sf); - this.writeValidationMeta(writer, models); - this.writeAuthSelector(models, writer); - }); - }, + initializer: writer.result, }, ], }); + if (this.extraFunctions.length > 0) { + sf.addFunctions(this.extraFunctions); + } + sf.addStatements('export default policy'); // save ts files if requested explicitly or the user provided @@ -121,7 +123,7 @@ export class PolicyGenerator { } } - private writePolicy(writer: CodeBlockWriter, models: DataModel[], sourceFile: SourceFile) { + private writePolicy(writer: CodeWriter, models: DataModel[]) { writer.write('policy:'); writer.inlineBlock(() => { for (const model of models) { @@ -129,10 +131,10 @@ export class PolicyGenerator { writer.block(() => { // model-level guards - this.writeModelLevelDefs(model, writer, sourceFile); + this.writeModelLevelDefs(model, writer); // field-level guards - this.writeFieldLevelDefs(model, writer, sourceFile); + this.writeFieldLevelDefs(model, writer); }); writer.writeLine(','); @@ -145,55 +147,45 @@ export class PolicyGenerator { // writes model-level policy def for each operation kind for a model // `[modelName]: { [operationKind]: [funcName] },` - private writeModelLevelDefs(model: DataModel, writer: CodeBlockWriter, sourceFile: SourceFile) { + private writeModelLevelDefs(model: DataModel, writer: CodeWriter) { const policies = analyzePolicies(model); writer.write('modelLevel:'); writer.inlineBlock(() => { - this.writeModelReadDef(model, policies, writer, sourceFile); - this.writeModelCreateDef(model, policies, writer, sourceFile); - this.writeModelUpdateDef(model, policies, writer, sourceFile); - this.writeModelPostUpdateDef(model, policies, writer, sourceFile); - this.writeModelDeleteDef(model, policies, writer, sourceFile); + this.writeModelReadDef(model, policies, writer); + this.writeModelCreateDef(model, policies, writer); + this.writeModelUpdateDef(model, policies, writer); + this.writeModelPostUpdateDef(model, policies, writer); + this.writeModelDeleteDef(model, policies, writer); }); writer.writeLine(','); } // writes `read: ...` for a given model - private writeModelReadDef( - model: DataModel, - policies: PolicyAnalysisResult, - writer: CodeBlockWriter, - sourceFile: SourceFile - ) { + private writeModelReadDef(model: DataModel, policies: PolicyAnalysisResult, writer: CodeWriter) { writer.write(`read:`); writer.inlineBlock(() => { - this.writeCommonModelDef(model, 'read', policies, writer, sourceFile); + this.writeCommonModelDef(model, 'read', policies, writer); }); writer.writeLine(','); } // writes `create: ...` for a given model - private writeModelCreateDef( - model: DataModel, - policies: PolicyAnalysisResult, - writer: CodeBlockWriter, - sourceFile: SourceFile - ) { + private writeModelCreateDef(model: DataModel, policies: PolicyAnalysisResult, writer: CodeWriter) { writer.write(`create:`); writer.inlineBlock(() => { - this.writeCommonModelDef(model, 'create', policies, writer, sourceFile); + this.writeCommonModelDef(model, 'create', policies, writer); // create policy has an additional input checker for validating the payload - this.writeCreateInputChecker(model, writer, sourceFile); + this.writeCreateInputChecker(model, writer); }); writer.writeLine(','); } // writes `inputChecker: [funcName]` for a given model - private writeCreateInputChecker(model: DataModel, writer: CodeBlockWriter, sourceFile: SourceFile) { + private writeCreateInputChecker(model: DataModel, writer: CodeWriter) { if (this.canCheckCreateBasedOnInput(model)) { - const inputCheckFunc = this.generateCreateInputCheckerFunction(model, sourceFile); - writer.write(`inputChecker: ${inputCheckFunc.getName()!},`); + const inputCheckFuncName = this.generateCreateInputCheckerFunction(model); + writer.write(`inputChecker: ${inputCheckFuncName},`); } } @@ -237,19 +229,18 @@ export class PolicyGenerator { } // generates a function for checking "create" input - private generateCreateInputCheckerFunction(model: DataModel, sourceFile: SourceFile) { - const statements: (string | WriterFunction)[] = []; + private generateCreateInputCheckerFunction(model: DataModel) { + const statements: string[] = []; const allows = getPolicyExpressions(model, 'allow', 'create'); const denies = getPolicyExpressions(model, 'deny', 'create'); generateNormalizedAuthRef(model, allows, denies, statements); - statements.push((writer) => { - if (allows.length === 0) { - writer.write('return false;'); - return; - } - + // write allow and deny rules + const writer = new FastWriter(); + if (allows.length === 0) { + writer.write('return false;'); + } else { const transformer = new TypeScriptExpressionTransformer({ context: ExpressionContext.AccessPolicy, fieldReferenceContext: 'input', @@ -275,10 +266,12 @@ export class PolicyGenerator { expr = expr ? `${expr} && (${allowStmt})` : allowStmt; writer.write('return ' + expr); - }); + } + statements.push(writer.result); - const func = sourceFile.addFunction({ - name: model.name + '_create_input', + const funcName = model.name + '_create_input'; + this.extraFunctions.push({ + name: funcName, returnType: 'boolean', parameters: [ { @@ -293,33 +286,23 @@ export class PolicyGenerator { statements, }); - return func; + return funcName; } // writes `update: ...` for a given model - private writeModelUpdateDef( - model: DataModel, - policies: PolicyAnalysisResult, - writer: CodeBlockWriter, - sourceFile: SourceFile - ) { + private writeModelUpdateDef(model: DataModel, policies: PolicyAnalysisResult, writer: CodeWriter) { writer.write(`update:`); writer.inlineBlock(() => { - this.writeCommonModelDef(model, 'update', policies, writer, sourceFile); + this.writeCommonModelDef(model, 'update', policies, writer); }); writer.writeLine(','); } // writes `postUpdate: ...` for a given model - private writeModelPostUpdateDef( - model: DataModel, - policies: PolicyAnalysisResult, - writer: CodeBlockWriter, - sourceFile: SourceFile - ) { + private writeModelPostUpdateDef(model: DataModel, policies: PolicyAnalysisResult, writer: CodeWriter) { writer.write(`postUpdate:`); writer.inlineBlock(() => { - this.writeCommonModelDef(model, 'postUpdate', policies, writer, sourceFile); + this.writeCommonModelDef(model, 'postUpdate', policies, writer); // post-update policy has an additional selector for reading the pre-update entity data this.writePostUpdatePreValueSelector(model, writer); @@ -327,7 +310,7 @@ export class PolicyGenerator { writer.writeLine(','); } - private writePostUpdatePreValueSelector(model: DataModel, writer: CodeBlockWriter) { + private writePostUpdatePreValueSelector(model: DataModel, writer: CodeWriter) { const allows = getPolicyExpressions(model, 'allow', 'postUpdate'); const denies = getPolicyExpressions(model, 'deny', 'postUpdate'); const preValueSelect = generateSelectForRules([...allows, ...denies], 'postUpdate'); @@ -337,15 +320,10 @@ export class PolicyGenerator { } // writes `delete: ...` for a given model - private writeModelDeleteDef( - model: DataModel, - policies: PolicyAnalysisResult, - writer: CodeBlockWriter, - sourceFile: SourceFile - ) { + private writeModelDeleteDef(model: DataModel, policies: PolicyAnalysisResult, writer: CodeWriter) { writer.write(`delete:`); writer.inlineBlock(() => { - this.writeCommonModelDef(model, 'delete', policies, writer, sourceFile); + this.writeCommonModelDef(model, 'delete', policies, writer); }); } @@ -354,23 +332,22 @@ export class PolicyGenerator { model: DataModel, kind: PolicyOperationKind, policies: PolicyAnalysisResult, - writer: CodeBlockWriter, - sourceFile: SourceFile + writer: CodeWriter ) { const allows = getPolicyExpressions(model, 'allow', kind); const denies = getPolicyExpressions(model, 'deny', kind); // policy guard - this.writePolicyGuard(model, kind, policies, allows, denies, writer, sourceFile); + this.writePolicyGuard(model, kind, policies, allows, denies, writer); // permission checker if (kind !== 'postUpdate') { - this.writePermissionChecker(model, kind, policies, allows, denies, writer, sourceFile); + this.writePermissionChecker(model, kind, policies, allows, denies, writer); } // write cross-model comparison rules as entity checker functions // because they cannot be checked inside Prisma - const { functionName, selector } = this.writeEntityChecker(model, kind, sourceFile, false); + const { functionName, selector } = this.writeEntityChecker(model, kind, false); if (this.shouldUseEntityChecker(model, kind, true, false)) { writer.write(`entityChecker: { func: ${functionName}, selector: ${JSON.stringify(selector)} },`); @@ -420,18 +397,12 @@ export class PolicyGenerator { }); } - private writeEntityChecker( - target: DataModel | DataModelField, - kind: PolicyOperationKind, - sourceFile: SourceFile, - forOverride: boolean - ) { + private writeEntityChecker(target: DataModel | DataModelField, kind: PolicyOperationKind, forOverride: boolean) { const allows = getPolicyExpressions(target, 'allow', kind, forOverride, 'all'); const denies = getPolicyExpressions(target, 'deny', kind, forOverride, 'all'); const model = isDataModel(target) ? target : (target.$container as DataModel); const func = generateEntityCheckerFunction( - sourceFile, model, kind, allows, @@ -439,9 +410,10 @@ export class PolicyGenerator { isDataModelField(target) ? target : undefined, forOverride ); + this.extraFunctions.push(func); const selector = generateSelectForRules([...allows, ...denies], kind, false, kind !== 'postUpdate') ?? {}; - return { functionName: func.getName()!, selector }; + return { functionName: func.name, selector }; } // writes `guard: ...` for a given policy operation kind @@ -451,46 +423,48 @@ export class PolicyGenerator { policies: ReturnType, allows: Expression[], denies: Expression[], - writer: CodeBlockWriter, - sourceFile: SourceFile + writer: CodeWriter ) { // first handle several cases where a constant function can be used if (kind === 'update' && allows.length === 0) { // no allow rule for 'update', policy is constant based on if there's // post-update counterpart - let func: FunctionDeclaration; + let func: OptionalKind; if (getPolicyExpressions(model, 'allow', 'postUpdate').length === 0) { - func = generateConstantQueryGuardFunction(sourceFile, model, kind, false); + func = generateConstantQueryGuardFunction(model, kind, false); } else { - func = generateConstantQueryGuardFunction(sourceFile, model, kind, true); + func = generateConstantQueryGuardFunction(model, kind, true); } - writer.write(`guard: ${func.getName()!},`); + this.extraFunctions.push(func); + writer.write(`guard: ${func.name},`); return; } if (kind === 'postUpdate' && allows.length === 0 && denies.length === 0) { // no 'postUpdate' rule, always allow - const func = generateConstantQueryGuardFunction(sourceFile, model, kind, true); - writer.write(`guard: ${func.getName()},`); + const func = generateConstantQueryGuardFunction(model, kind, true); + this.extraFunctions.push(func); + writer.write(`guard: ${func.name},`); return; } if (kind in policies && typeof policies[kind as keyof typeof policies] === 'boolean') { // constant policy const func = generateConstantQueryGuardFunction( - sourceFile, model, kind, policies[kind as keyof typeof policies] as boolean ); - writer.write(`guard: ${func.getName()!},`); + this.extraFunctions.push(func); + writer.write(`guard: ${func.name},`); return; } // generate a policy function that evaluates a partial prisma query - const guardFunc = generateQueryGuardFunction(sourceFile, model, kind, allows, denies); - writer.write(`guard: ${guardFunc.getName()!},`); + const guardFunc = generateQueryGuardFunction(model, kind, allows, denies); + this.extraFunctions.push(guardFunc); + writer.write(`guard: ${guardFunc.name},`); } // writes `permissionChecker: ...` for a given policy operation kind @@ -500,8 +474,7 @@ export class PolicyGenerator { policies: PolicyAnalysisResult, allows: Expression[], denies: Expression[], - writer: CodeBlockWriter, - sourceFile: SourceFile + writer: CodeWriter ) { if (this.options.generatePermissionChecker !== true) { return; @@ -524,16 +497,15 @@ export class PolicyGenerator { return; } - const guardFunc = this.generatePermissionCheckerFunction(model, kind, allows, denies, sourceFile); - writer.write(`permissionChecker: ${guardFunc.getName()!},`); + const guardFuncName = this.generatePermissionCheckerFunction(model, kind, allows, denies); + writer.write(`permissionChecker: ${guardFuncName},`); } private generatePermissionCheckerFunction( model: DataModel, kind: string, allows: Expression[], - denies: Expression[], - sourceFile: SourceFile + denies: Expression[] ) { const statements: string[] = []; @@ -545,8 +517,9 @@ export class PolicyGenerator { statements.push(`return ${transformed};`); - const func = sourceFile.addFunction({ - name: `${model.name}$checker$${kind}`, + const funcName = `${model.name}$checker$${kind}`; + this.extraFunctions.push({ + name: funcName, returnType: 'PermissionCheckerConstraint', parameters: [ { @@ -557,23 +530,23 @@ export class PolicyGenerator { statements, }); - return func; + return funcName; } // #endregion // #region Field-level definitions - private writeFieldLevelDefs(model: DataModel, writer: CodeBlockWriter, sf: SourceFile) { + private writeFieldLevelDefs(model: DataModel, writer: CodeWriter) { writer.write('fieldLevel:'); writer.inlineBlock(() => { - this.writeFieldReadDef(model, writer, sf); - this.writeFieldUpdateDef(model, writer, sf); + this.writeFieldReadDef(model, writer); + this.writeFieldUpdateDef(model, writer); }); writer.writeLine(','); } - private writeFieldReadDef(model: DataModel, writer: CodeBlockWriter, sourceFile: SourceFile) { + private writeFieldReadDef(model: DataModel, writer: CodeWriter) { writer.writeLine('read:'); writer.block(() => { for (const field of model.fields) { @@ -589,12 +562,13 @@ export class PolicyGenerator { writer.block(() => { // guard - const guardFunc = generateQueryGuardFunction(sourceFile, model, 'read', allows, denies, field); - writer.write(`guard: ${guardFunc.getName()},`); + const guardFunc = generateQueryGuardFunction(model, 'read', allows, denies, field); + this.extraFunctions.push(guardFunc); + writer.write(`guard: ${guardFunc.name},`); // checker function // write all field-level rules as entity checker function - const { functionName, selector } = this.writeEntityChecker(field, 'read', sourceFile, false); + const { functionName, selector } = this.writeEntityChecker(field, 'read', false); if (this.shouldUseEntityChecker(field, 'read', false, false)) { writer.write( @@ -606,7 +580,6 @@ export class PolicyGenerator { // override guard function const denies = getPolicyExpressions(field, 'deny', 'read'); const overrideGuardFunc = generateQueryGuardFunction( - sourceFile, model, 'read', overrideAllows, @@ -614,10 +587,11 @@ export class PolicyGenerator { field, true ); - writer.write(`overrideGuard: ${overrideGuardFunc.getName()},`); + this.extraFunctions.push(overrideGuardFunc); + writer.write(`overrideGuard: ${overrideGuardFunc.name},`); // additional entity checker for override - const { functionName, selector } = this.writeEntityChecker(field, 'read', sourceFile, true); + const { functionName, selector } = this.writeEntityChecker(field, 'read', true); if (this.shouldUseEntityChecker(field, 'read', false, true)) { writer.write( `overrideEntityChecker: { func: ${functionName}, selector: ${JSON.stringify( @@ -633,7 +607,7 @@ export class PolicyGenerator { writer.writeLine(','); } - private writeFieldUpdateDef(model: DataModel, writer: CodeBlockWriter, sourceFile: SourceFile) { + private writeFieldUpdateDef(model: DataModel, writer: CodeWriter) { writer.writeLine('update:'); writer.block(() => { for (const field of model.fields) { @@ -649,12 +623,13 @@ export class PolicyGenerator { writer.block(() => { // guard - const guardFunc = generateQueryGuardFunction(sourceFile, model, 'update', allows, denies, field); - writer.write(`guard: ${guardFunc.getName()},`); + const guardFunc = generateQueryGuardFunction(model, 'update', allows, denies, field); + this.extraFunctions.push(guardFunc); + writer.write(`guard: ${guardFunc.name},`); // write cross-model comparison rules as entity checker functions // because they cannot be checked inside Prisma - const { functionName, selector } = this.writeEntityChecker(field, 'update', sourceFile, false); + const { functionName, selector } = this.writeEntityChecker(field, 'update', false); if (this.shouldUseEntityChecker(field, 'update', true, false)) { writer.write( `entityChecker: { func: ${functionName}, selector: ${JSON.stringify(selector)} },` @@ -664,7 +639,6 @@ export class PolicyGenerator { if (overrideAllows.length > 0) { // override guard const overrideGuardFunc = generateQueryGuardFunction( - sourceFile, model, 'update', overrideAllows, @@ -672,11 +646,12 @@ export class PolicyGenerator { field, true ); - writer.write(`overrideGuard: ${overrideGuardFunc.getName()},`); + this.extraFunctions.push(overrideGuardFunc); + writer.write(`overrideGuard: ${overrideGuardFunc.name},`); // write cross-model comparison override rules as entity checker functions // because they cannot be checked inside Prisma - const { functionName, selector } = this.writeEntityChecker(field, 'update', sourceFile, true); + const { functionName, selector } = this.writeEntityChecker(field, 'update', true); if (this.shouldUseEntityChecker(field, 'update', true, true)) { writer.write( `overrideEntityChecker: { func: ${functionName}, selector: ${JSON.stringify( @@ -696,7 +671,7 @@ export class PolicyGenerator { //#region Auth selector - private writeAuthSelector(models: DataModel[], writer: CodeBlockWriter) { + private writeAuthSelector(models: DataModel[], writer: CodeWriter) { const authSelector = this.generateAuthSelector(models); if (authSelector) { writer.write(`authSelector: ${JSON.stringify(authSelector)},`); @@ -744,7 +719,7 @@ export class PolicyGenerator { // #region Validation meta - private writeValidationMeta(writer: CodeBlockWriter, models: DataModel[]) { + private writeValidationMeta(writer: CodeWriter, models: DataModel[]) { writer.write('validation:'); writer.inlineBlock(() => { for (const model of models) { diff --git a/packages/schema/src/plugins/enhancer/policy/utils.ts b/packages/schema/src/plugins/enhancer/policy/utils.ts index fee0cc15a..ae9a7846f 100644 --- a/packages/schema/src/plugins/enhancer/policy/utils.ts +++ b/packages/schema/src/plugins/enhancer/policy/utils.ts @@ -2,6 +2,7 @@ import type { PolicyKind, PolicyOperationKind } from '@zenstackhq/runtime'; import { ExpressionContext, + FastWriter, PluginError, TypeScriptExpressionTransformer, TypeScriptExpressionTransformerError, @@ -39,7 +40,7 @@ import { } from '@zenstackhq/sdk/ast'; import deepmerge from 'deepmerge'; import { getContainerOfType, streamAllContents, streamAst, streamContents } from 'langium'; -import { SourceFile, WriterFunction } from 'ts-morph'; +import { FunctionDeclarationStructure, OptionalKind } from 'ts-morph'; import { name } from '..'; import { isCheckInvocation, isCollectionPredicate, isFutureInvocation } from '../../../utils/ast-utils'; import { ExpressionWriter, FALSE, TRUE } from './expression-writer'; @@ -265,13 +266,8 @@ export function generateSelectForRules( /** * Generates a constant query guard function */ -export function generateConstantQueryGuardFunction( - sourceFile: SourceFile, - model: DataModel, - kind: PolicyOperationKind, - value: boolean -) { - const func = sourceFile.addFunction({ +export function generateConstantQueryGuardFunction(model: DataModel, kind: PolicyOperationKind, value: boolean) { + return { name: getQueryGuardFunctionName(model, undefined, false, kind), returnType: 'any', parameters: [ @@ -286,16 +282,13 @@ export function generateConstantQueryGuardFunction( }, ], statements: [`return ${value ? TRUE : FALSE};`], - }); - - return func; + } as OptionalKind; } /** * Generates a query guard function that returns a partial Prisma query for the given model or field */ export function generateQueryGuardFunction( - sourceFile: SourceFile, model: DataModel, kind: PolicyOperationKind, allows: Expression[], @@ -303,7 +296,7 @@ export function generateQueryGuardFunction( forField?: DataModelField, fieldOverride = false ) { - const statements: (string | WriterFunction)[] = []; + const statements: string[] = []; const allowRules = allows.filter((rule) => !hasCrossModelComparison(rule)); const denyRules = denies.filter((rule) => !hasCrossModelComparison(rule)); @@ -325,100 +318,101 @@ export function generateQueryGuardFunction( if (!hasFieldAccess) { // none of the rules reference model fields, we can compile down to a plain boolean // function in this case (so we can skip doing SQL queries when validating) - statements.push((writer) => { - const transformer = new TypeScriptExpressionTransformer({ - context: ExpressionContext.AccessPolicy, - isPostGuard: kind === 'postUpdate', - operationContext: kind, + const writer = new FastWriter(); + const transformer = new TypeScriptExpressionTransformer({ + context: ExpressionContext.AccessPolicy, + isPostGuard: kind === 'postUpdate', + operationContext: kind, + }); + try { + denyRules.forEach((rule) => { + writer.write(`if (${transformer.transform(rule, false)}) { return ${FALSE}; }`); }); - try { - denyRules.forEach((rule) => { - writer.write(`if (${transformer.transform(rule, false)}) { return ${FALSE}; }`); - }); - allowRules.forEach((rule) => { - writer.write(`if (${transformer.transform(rule, false)}) { return ${TRUE}; }`); - }); - } catch (err) { - if (err instanceof TypeScriptExpressionTransformerError) { - throw new PluginError(name, err.message); - } else { - throw err; - } + allowRules.forEach((rule) => { + writer.write(`if (${transformer.transform(rule, false)}) { return ${TRUE}; }`); + }); + } catch (err) { + if (err instanceof TypeScriptExpressionTransformerError) { + throw new PluginError(name, err.message); + } else { + throw err; } + } - if (forField) { - if (allows.length === 0) { - // if there's no allow rule, for field-level rules, by default we allow - writer.write(`return ${TRUE};`); - } else { - if (allowRules.length < allows.length) { - writer.write(`return ${TRUE};`); - } else { - // if there's any allow rule, we deny unless any allow rule evaluates to true - writer.write(`return ${FALSE};`); - } - } + if (forField) { + if (allows.length === 0) { + // if there's no allow rule, for field-level rules, by default we allow + writer.write(`return ${TRUE};`); } else { if (allowRules.length < allows.length) { - // some rules are filtered out here and will be generated as additional - // checker functions, so we allow here to avoid a premature denial writer.write(`return ${TRUE};`); } else { - // for model-level rules, the default is always deny unless for 'postUpdate' - writer.write(`return ${kind === 'postUpdate' ? TRUE : FALSE};`); + // if there's any allow rule, we deny unless any allow rule evaluates to true + writer.write(`return ${FALSE};`); } } - }); - } else { - statements.push((writer) => { - writer.write('return '); - const exprWriter = new ExpressionWriter(writer, { - isPostGuard: kind === 'postUpdate', - operationContext: kind, - }); - const writeDenies = () => { - writer.conditionalWrite(denyRules.length > 1, '{ AND: ['); - denyRules.forEach((expr, i) => { - writer.inlineBlock(() => { - writer.write('NOT: '); - exprWriter.write(expr); - }); - writer.conditionalWrite(i !== denyRules.length - 1, ','); - }); - writer.conditionalWrite(denyRules.length > 1, ']}'); - }; - - const writeAllows = () => { - writer.conditionalWrite(allowRules.length > 1, '{ OR: ['); - allowRules.forEach((expr, i) => { - exprWriter.write(expr); - writer.conditionalWrite(i !== allowRules.length - 1, ','); - }); - writer.conditionalWrite(allowRules.length > 1, ']}'); - }; - - if (allowRules.length > 0 && denyRules.length > 0) { - // include both allow and deny rules - writer.write('{ AND: ['); - writeDenies(); - writer.write(','); - writeAllows(); - writer.write(']}'); - } else if (denyRules.length > 0) { - // only deny rules - writeDenies(); - } else if (allowRules.length > 0) { - // only allow rules - writeAllows(); + } else { + if (allowRules.length < allows.length) { + // some rules are filtered out here and will be generated as additional + // checker functions, so we allow here to avoid a premature denial + writer.write(`return ${TRUE};`); } else { - // disallow any operation unless for 'postUpdate' + // for model-level rules, the default is always deny unless for 'postUpdate' writer.write(`return ${kind === 'postUpdate' ? TRUE : FALSE};`); } - writer.write(';'); + } + + statements.push(writer.result); + } else { + const writer = new FastWriter(); + writer.write('return '); + const exprWriter = new ExpressionWriter(writer, { + isPostGuard: kind === 'postUpdate', + operationContext: kind, }); + const writeDenies = () => { + writer.conditionalWrite(denyRules.length > 1, '{ AND: ['); + denyRules.forEach((expr, i) => { + writer.inlineBlock(() => { + writer.write('NOT: '); + exprWriter.write(expr); + }); + writer.conditionalWrite(i !== denyRules.length - 1, ','); + }); + writer.conditionalWrite(denyRules.length > 1, ']}'); + }; + + const writeAllows = () => { + writer.conditionalWrite(allowRules.length > 1, '{ OR: ['); + allowRules.forEach((expr, i) => { + exprWriter.write(expr); + writer.conditionalWrite(i !== allowRules.length - 1, ','); + }); + writer.conditionalWrite(allowRules.length > 1, ']}'); + }; + + if (allowRules.length > 0 && denyRules.length > 0) { + // include both allow and deny rules + writer.write('{ AND: ['); + writeDenies(); + writer.write(','); + writeAllows(); + writer.write(']}'); + } else if (denyRules.length > 0) { + // only deny rules + writeDenies(); + } else if (allowRules.length > 0) { + // only allow rules + writeAllows(); + } else { + // disallow any operation unless for 'postUpdate' + writer.write(`return ${kind === 'postUpdate' ? TRUE : FALSE};`); + } + writer.write(';'); + statements.push(writer.result); } - const func = sourceFile.addFunction({ + return { name: getQueryGuardFunctionName(model, forField, fieldOverride, kind), returnType: 'any', parameters: [ @@ -433,13 +427,10 @@ export function generateQueryGuardFunction( }, ], statements, - }); - - return func; + } as OptionalKind; } export function generateEntityCheckerFunction( - sourceFile: SourceFile, model: DataModel, kind: PolicyOperationKind, allows: Expression[], @@ -447,7 +438,7 @@ export function generateEntityCheckerFunction( forField?: DataModelField, fieldOverride = false ) { - const statements: (string | WriterFunction)[] = []; + const statements: string[] = []; generateNormalizedAuthRef(model, allows, denies, statements); @@ -488,7 +479,7 @@ export function generateEntityCheckerFunction( } } - const func = sourceFile.addFunction({ + return { name: getEntityCheckerFunctionName(model, forField, fieldOverride, kind), returnType: 'any', parameters: [ @@ -502,9 +493,7 @@ export function generateEntityCheckerFunction( }, ], statements, - }); - - return func; + } as OptionalKind; } /** @@ -514,7 +503,7 @@ export function generateNormalizedAuthRef( model: DataModel, allows: Expression[], denies: Expression[], - statements: (string | WriterFunction)[] + statements: string[] ) { // check if any allow or deny rule contains 'auth()' invocation const hasAuthRef = [...allows, ...denies].some((rule) => streamAst(rule).some((child) => isAuthInvocation(child))); diff --git a/packages/sdk/src/code-gen.ts b/packages/sdk/src/code-gen.ts index 67833b788..1b68ccf62 100644 --- a/packages/sdk/src/code-gen.ts +++ b/packages/sdk/src/code-gen.ts @@ -70,3 +70,60 @@ export async function emitProject(project: Project) { throw new PluginError('', `Error emitting generated code`); } } + +/* + * Abstraction for source code writer. + */ +export interface CodeWriter { + block(callback: () => void): void; + inlineBlock(callback: () => void): void; + write(text: string): void; + writeLine(text: string): void; + conditionalWrite(condition: boolean, text: string): void; +} + +/** + * A fast code writer. + */ +export class FastWriter implements CodeWriter { + private content = ''; + private indentLevel = 0; + + constructor(private readonly indentSize = 4) {} + + get result() { + return this.content; + } + + block(callback: () => void) { + this.content += '{\n'; + this.indentLevel++; + callback(); + this.indentLevel--; + this.content += '\n}'; + } + + inlineBlock(callback: () => void) { + this.content += '{'; + callback(); + this.content += '}'; + } + + write(text: string) { + this.content += this.indent(text); + } + + writeLine(text: string) { + this.content += `${this.indent(text)}\n`; + } + + conditionalWrite(condition: boolean, text: string) { + if (condition) { + this.write(text); + } + } + + private indent(text: string) { + return ' '.repeat(this.indentLevel * this.indentSize) + text; + } +} diff --git a/packages/sdk/src/model-meta-generator.ts b/packages/sdk/src/model-meta-generator.ts index 88e064512..c5b866417 100644 --- a/packages/sdk/src/model-meta-generator.ts +++ b/packages/sdk/src/model-meta-generator.ts @@ -19,9 +19,11 @@ import { import type { RuntimeAttribute } from '@zenstackhq/runtime'; import { streamAst } from 'langium'; import { lowerCaseFirst } from 'lower-case-first'; -import { CodeBlockWriter, Project, SourceFile, VariableDeclarationKind } from 'ts-morph'; +import { FunctionDeclarationStructure, OptionalKind, Project, VariableDeclarationKind } from 'ts-morph'; import { + CodeWriter, ExpressionContext, + FastWriter, getAttribute, getAttributeArg, getAttributeArgs, @@ -73,12 +75,20 @@ export function generate( options: ModelMetaGeneratorOptions ) { const sf = project.createSourceFile(options.output, undefined, { overwrite: true }); + + const writer = new FastWriter(); + const extraFunctions: OptionalKind[] = []; + generateModelMetadata(models, typeDefs, writer, options, extraFunctions); + sf.addVariableStatement({ declarationKind: VariableDeclarationKind.Const, - declarations: [ - { name: 'metadata', initializer: (writer) => generateModelMetadata(models, typeDefs, sf, writer, options) }, - ], + declarations: [{ name: 'metadata', initializer: writer.result }], }); + + if (extraFunctions.length > 0) { + sf.addFunctions(extraFunctions); + } + sf.addStatements('export default metadata;'); if (options.preserveTsFiles) { @@ -91,13 +101,13 @@ export function generate( function generateModelMetadata( dataModels: DataModel[], typeDefs: TypeDef[], - sourceFile: SourceFile, - writer: CodeBlockWriter, - options: ModelMetaGeneratorOptions + writer: CodeWriter, + options: ModelMetaGeneratorOptions, + extraFunctions: OptionalKind[] ) { writer.block(() => { - writeModels(sourceFile, writer, dataModels, options); - writeTypeDefs(sourceFile, writer, typeDefs, options); + writeModels(writer, dataModels, options, extraFunctions); + writeTypeDefs(writer, typeDefs, options, extraFunctions); writeDeleteCascade(writer, dataModels); writeShortNameMap(options, writer); writeAuthModel(writer, dataModels, typeDefs); @@ -105,10 +115,10 @@ function generateModelMetadata( } function writeModels( - sourceFile: SourceFile, - writer: CodeBlockWriter, + writer: CodeWriter, dataModels: DataModel[], - options: ModelMetaGeneratorOptions + options: ModelMetaGeneratorOptions, + extraFunctions: OptionalKind[] ) { writer.write('models:'); writer.block(() => { @@ -117,7 +127,7 @@ function writeModels( writer.block(() => { writer.write(`name: '${model.name}',`); writeBaseTypes(writer, model); - writeFields(sourceFile, writer, model, options); + writeFields(writer, model, options, extraFunctions); writeUniqueConstraints(writer, model); if (options.generateAttributes) { writeModelAttributes(writer, model); @@ -131,10 +141,10 @@ function writeModels( } function writeTypeDefs( - sourceFile: SourceFile, - writer: CodeBlockWriter, + writer: CodeWriter, typedDefs: TypeDef[], - options: ModelMetaGeneratorOptions + options: ModelMetaGeneratorOptions, + extraFunctions: OptionalKind[] ) { if (typedDefs.length === 0) { return; @@ -145,7 +155,7 @@ function writeTypeDefs( writer.write(`${lowerCaseFirst(typeDef.name)}:`); writer.block(() => { writer.write(`name: '${typeDef.name}',`); - writeFields(sourceFile, writer, typeDef, options); + writeFields(writer, typeDef, options, extraFunctions); }); writer.writeLine(','); } @@ -153,7 +163,7 @@ function writeTypeDefs( writer.writeLine(','); } -function writeBaseTypes(writer: CodeBlockWriter, model: DataModel) { +function writeBaseTypes(writer: CodeWriter, model: DataModel) { if (model.superTypes.length > 0) { writer.write('baseTypes: ['); writer.write(model.superTypes.map((t) => `'${t.ref?.name}'`).join(', ')); @@ -161,14 +171,14 @@ function writeBaseTypes(writer: CodeBlockWriter, model: DataModel) { } } -function writeAuthModel(writer: CodeBlockWriter, dataModels: DataModel[], typeDefs: TypeDef[]) { +function writeAuthModel(writer: CodeWriter, dataModels: DataModel[], typeDefs: TypeDef[]) { const authModel = getAuthDecl([...dataModels, ...typeDefs]); if (authModel) { writer.writeLine(`authModel: '${authModel.name}'`); } } -function writeDeleteCascade(writer: CodeBlockWriter, dataModels: DataModel[]) { +function writeDeleteCascade(writer: CodeWriter, dataModels: DataModel[]) { writer.write('deleteCascade:'); writer.block(() => { for (const model of dataModels) { @@ -181,7 +191,7 @@ function writeDeleteCascade(writer: CodeBlockWriter, dataModels: DataModel[]) { writer.writeLine(','); } -function writeUniqueConstraints(writer: CodeBlockWriter, model: DataModel) { +function writeUniqueConstraints(writer: CodeWriter, model: DataModel) { const constraints = getUniqueConstraints(model); if (constraints.length > 0) { writer.write('uniqueConstraints:'); @@ -197,7 +207,7 @@ function writeUniqueConstraints(writer: CodeBlockWriter, model: DataModel) { } } -function writeModelAttributes(writer: CodeBlockWriter, model: DataModel) { +function writeModelAttributes(writer: CodeWriter, model: DataModel) { const attrs = getAttributes(model); if (attrs.length > 0) { writer.write(` @@ -205,7 +215,7 @@ attributes: ${JSON.stringify(attrs)},`); } } -function writeDiscriminator(writer: CodeBlockWriter, model: DataModel) { +function writeDiscriminator(writer: CodeWriter, model: DataModel) { const delegateAttr = getAttribute(model, '@@delegate'); if (!delegateAttr) { return; @@ -220,10 +230,10 @@ function writeDiscriminator(writer: CodeBlockWriter, model: DataModel) { } function writeFields( - sourceFile: SourceFile, - writer: CodeBlockWriter, + writer: CodeWriter, container: DataModel | TypeDef, - options: ModelMetaGeneratorOptions + options: ModelMetaGeneratorOptions, + extraFunctions: OptionalKind[] ) { writer.write('fields:'); writer.block(() => { @@ -279,7 +289,7 @@ function writeFields( } } - const defaultValueProvider = generateDefaultValueProvider(f, sourceFile); + const defaultValueProvider = generateDefaultValueProvider(f, extraFunctions); if (defaultValueProvider) { writer.write(` defaultValueProvider: ${defaultValueProvider},`); @@ -496,7 +506,10 @@ function getDeleteCascades(model: DataModel): string[] { .map((m) => m.name); } -function generateDefaultValueProvider(field: DataModelField | TypeDefField, sourceFile: SourceFile) { +function generateDefaultValueProvider( + field: DataModelField | TypeDefField, + extraFunctions: OptionalKind[] +) { const defaultAttr = getAttribute(field, '@default'); if (!defaultAttr) { return undefined; @@ -515,8 +528,9 @@ function generateDefaultValueProvider(field: DataModelField | TypeDefField, sour // generates a provider function like: // function $default$Model$field(user: any) { ... } - const func = sourceFile.addFunction({ - name: `$default$${field.$container.name}$${field.name}`, + const funcName = `$default$${field.$container.name}$${field.name}`; + extraFunctions.push({ + name: funcName, parameters: [{ name: 'user', type: 'any' }], returnType: 'unknown', statements: (writer) => { @@ -526,7 +540,7 @@ function generateDefaultValueProvider(field: DataModelField | TypeDefField, sour }, }); - return func.getName(); + return funcName; } function isAutoIncrement(field: DataModelField) { @@ -543,7 +557,7 @@ function isAutoIncrement(field: DataModelField) { return isInvocationExpr(arg) && arg.function.$refText === 'autoincrement'; } -function writeShortNameMap(options: ModelMetaGeneratorOptions, writer: CodeBlockWriter) { +function writeShortNameMap(options: ModelMetaGeneratorOptions, writer: CodeWriter) { if (options.shortNameMap && options.shortNameMap.size > 0) { writer.write('shortNameMap:'); writer.block(() => {