From f55a2fd64bc08a2c946ae452296761ad3a580d6a Mon Sep 17 00:00:00 2001 From: Youssef Gaber <1728215+gabrola@users.noreply.github.com> Date: Thu, 27 Feb 2025 20:45:29 +0200 Subject: [PATCH 1/2] feat(runtime): inject enhanced client or tx context so it can be retrieved in extensions --- .../runtime/src/enhancements/node/proxy.ts | 7 +- .../proxy/extension-context.test.ts | 71 +++++++++++++++++++ 2 files changed, 77 insertions(+), 1 deletion(-) create mode 100644 tests/integration/tests/enhancements/proxy/extension-context.test.ts diff --git a/packages/runtime/src/enhancements/node/proxy.ts b/packages/runtime/src/enhancements/node/proxy.ts index e063f002b..c64e664ed 100644 --- a/packages/runtime/src/enhancements/node/proxy.ts +++ b/packages/runtime/src/enhancements/node/proxy.ts @@ -289,7 +289,7 @@ export function makeProxy( return propVal; } - return createHandlerProxy(makeHandler(target, prop), propVal, prop, errorTransformer); + return createHandlerProxy(makeHandler(target, prop), propVal, prop, target, errorTransformer); }, }); @@ -303,10 +303,15 @@ function createHandlerProxy( handler: T, origTarget: any, model: string, + dbOrTx: any, errorTransformer?: ErrorTransformer ): T { return new Proxy(handler, { get(target, propKey) { + if (propKey === '$zenstack_parent') { + return dbOrTx; + } + const prop = target[propKey as keyof T]; if (typeof prop !== 'function') { // the proxy handler doesn't have this method, fall back to the original target diff --git a/tests/integration/tests/enhancements/proxy/extension-context.test.ts b/tests/integration/tests/enhancements/proxy/extension-context.test.ts new file mode 100644 index 000000000..227d8f0d0 --- /dev/null +++ b/tests/integration/tests/enhancements/proxy/extension-context.test.ts @@ -0,0 +1,71 @@ +import { loadSchema } from '@zenstackhq/testtools'; + +describe('Proxy Extension Context', () => { + it('works', async () => { + const { enhance } = await loadSchema( + ` + model Counter { + model String @unique + value Int + + @@allow('all', true) + } + + model Address { + id String @id @default(cuid()) + city String + + @@allow('all', true) + } + ` + ); + + const db = enhance(); + const dbExtended = db.$extends({ + model: { + $allModels: { + async createWithCounter(this: any, args: any) { + const modelName = this.$name; + const dbOrTx = this.$zenstack_parent; + + const fn = async (tx: any) => { + const counter = await tx.counter.findUnique({ + where: { model: modelName }, + }); + + await tx.counter.upsert({ + where: { model: modelName }, + update: { value: (counter?.value ?? 0) + 1 }, + create: { model: modelName, value: 1 }, + }); + + return tx[modelName].create(args); + }; + + if (dbOrTx['$transaction']) { + // not running in a transaction, so we need to create a new transaction + return dbOrTx.$transaction(fn); + } + + return fn(dbOrTx); + }, + }, + }, + }); + + const cities = ['Vienna', 'New York', 'Delhi']; + + await Promise.all([ + ...cities.map((city) => dbExtended.address.createWithCounter({ data: { city } })), + ...cities.map((city) => + dbExtended.$transaction((tx: any) => tx.address.createWithCounter({ data: { city: `${city}$tx` } })) + ), + ]); + + // expecting object + await expect(dbExtended.counter.findUniqueOrThrow({ where: { model: 'Address' } })).resolves.toMatchObject({ + model: 'Address', + value: cities.length * 2, + }); + }); +}); From 981980e4b74e168a0501f0ceb9981045b14f7c77 Mon Sep 17 00:00:00 2001 From: Youssef Gaber <1728215+gabrola@users.noreply.github.com> Date: Sat, 1 Mar 2025 08:04:23 +0200 Subject: [PATCH 2/2] proxy $parent + more robust test --- packages/runtime/src/enhancements/node/proxy.ts | 4 ++-- .../enhancements/proxy/extension-context.test.ts | 15 +++++++++++---- 2 files changed, 13 insertions(+), 6 deletions(-) diff --git a/packages/runtime/src/enhancements/node/proxy.ts b/packages/runtime/src/enhancements/node/proxy.ts index c64e664ed..181d7b26b 100644 --- a/packages/runtime/src/enhancements/node/proxy.ts +++ b/packages/runtime/src/enhancements/node/proxy.ts @@ -289,7 +289,7 @@ export function makeProxy( return propVal; } - return createHandlerProxy(makeHandler(target, prop), propVal, prop, target, errorTransformer); + return createHandlerProxy(makeHandler(target, prop), propVal, prop, proxy, errorTransformer); }, }); @@ -308,7 +308,7 @@ function createHandlerProxy( ): T { return new Proxy(handler, { get(target, propKey) { - if (propKey === '$zenstack_parent') { + if (propKey === '$parent') { return dbOrTx; } diff --git a/tests/integration/tests/enhancements/proxy/extension-context.test.ts b/tests/integration/tests/enhancements/proxy/extension-context.test.ts index 227d8f0d0..f84fd1e84 100644 --- a/tests/integration/tests/enhancements/proxy/extension-context.test.ts +++ b/tests/integration/tests/enhancements/proxy/extension-context.test.ts @@ -22,11 +22,19 @@ describe('Proxy Extension Context', () => { const db = enhance(); const dbExtended = db.$extends({ + client: { + $one() { + return 1; + } + }, model: { $allModels: { async createWithCounter(this: any, args: any) { const modelName = this.$name; - const dbOrTx = this.$zenstack_parent; + const dbOrTx = this.$parent; + + // prisma exposes some internal properties, makes sure these are still preserved + expect(dbOrTx._engine).toBeDefined(); const fn = async (tx: any) => { const counter = await tx.counter.findUnique({ @@ -35,8 +43,8 @@ describe('Proxy Extension Context', () => { await tx.counter.upsert({ where: { model: modelName }, - update: { value: (counter?.value ?? 0) + 1 }, - create: { model: modelName, value: 1 }, + update: { value: (counter?.value ?? 0) + tx.$one() }, + create: { model: modelName, value: tx.$one() }, }); return tx[modelName].create(args); @@ -62,7 +70,6 @@ describe('Proxy Extension Context', () => { ), ]); - // expecting object await expect(dbExtended.counter.findUniqueOrThrow({ where: { model: 'Address' } })).resolves.toMatchObject({ model: 'Address', value: cities.length * 2,