diff --git a/bun.lockb b/bun.lockb index 3fc99d0..4edd0d1 100755 Binary files a/bun.lockb and b/bun.lockb differ diff --git a/example/index.ts b/example/index.ts index f1f9fd6..515cf5c 100644 --- a/example/index.ts +++ b/example/index.ts @@ -1,4 +1,4 @@ -import { Elysia, ws, t } from 'elysia' +import { Elysia, t } from 'elysia' import { trpc, compile as c } from '../src' import { initTRPC } from '@trpc/server' @@ -34,11 +34,11 @@ const router = p.router({ export type Router = typeof router new Elysia() - .use(ws()) .get('/', () => 'tRPC') .use( trpc(router, { - createContext + createContext, + useSubscription: true }) ) .listen(8080, ({ hostname, port }) => { diff --git a/package.json b/package.json index 609ab86..adaab34 100644 --- a/package.json +++ b/package.json @@ -1,6 +1,6 @@ { "name": "@elysiajs/trpc", - "version": "0.5.2", + "version": "0.7", "description": "A plugin for Elysia that add support for using tRPC", "author": { "name": "saltyAom", @@ -40,13 +40,13 @@ "@types/node": "^20.1.4", "@types/ws": "^8.5.4", "bun-types": "^0.5.8", - "elysia": "0.5.12", + "elysia": "^0.7.15", "eslint": "^8.40.0", "rimraf": "4.4.1", "typescript": "^5.0.4" }, "peerDependencies": { "@trpc/server": ">= 10.0.0", - "elysia": ">= 0.5.12" + "elysia": ">= 0.7.15" } } diff --git a/src/index.ts b/src/index.ts index e313d3c..80d4c50 100644 --- a/src/index.ts +++ b/src/index.ts @@ -1,13 +1,22 @@ -import { Elysia, getSchemaValidator } from 'elysia' - -import { callProcedure, TRPCError, type Router } from '@trpc/server' +import { + DefinitionBase, + Elysia, + InputSchema, + MergeSchema, + RouteSchema, + UnwrapSchema, + getSchemaValidator +} from 'elysia' + +import { TRPCError, callProcedure, type Router } from '@trpc/server' import { fetchRequestHandler } from '@trpc/server/adapters/fetch' -import { isObservable, Unsubscribable } from '@trpc/server/observable' +import { Unsubscribable, isObservable } from '@trpc/server/observable' -import { transformTRPCResponse, getTRPCErrorFromUnknown } from './utils' +import { getTRPCErrorFromUnknown, transformTRPCResponse } from './utils' import type { TSchema } from '@sinclair/typebox' import type { TRPCClientIncomingRequest, TRPCOptions } from './types' +import { getErrorShape } from '@trpc/server/shared' export function compile(schema: T) { const check = getSchemaValidator(schema, {}) @@ -34,6 +43,8 @@ const getPath = (url: string) => { return url.slice(start, end) } +type ClientSubscripted = Map + export const trpc = ( router: Router, @@ -41,8 +52,8 @@ export const trpc = endpoint: '/trpc' } ) => - (eri: Elysia) => { - let app = eri + (eri: Elysia): Elysia => { + const app = eri .onParse(async ({ request: { url } }) => { if (getPath(url).startsWith(endpoint)) return true }) @@ -63,28 +74,47 @@ export const trpc = }) }) - const observers: Map = new Map() + const observers: Map = new Map() - // @ts-ignore - if (app.wsRouter) - app.ws(endpoint, { + if (options.useSubscription) + app.ws(endpoint, { + open(ws) { + const id = + ws.data.headers['sec-websocket-key'] ?? + crypto.randomUUID() + + // @ts-ignore + ws.data.id = id + }, async message(ws, message) { + // @ts-ignore + const id = ws.data.id + + if (!observers.get(id)) { + observers.set(id, new Map()) + } + + const msg = + typeof message === 'string' + ? JSON.parse(message) + : message + const messages: TRPCClientIncomingRequest[] = Array.isArray( - message + msg ) - ? message - : [message] + ? msg + : [msg] - let observer: Unsubscribable | undefined + await Promise.allSettled(messages.map((incoming) => {})) for (const incoming of messages) { - if(!incoming.method || !incoming.params) { - continue - } - if (incoming.method === 'subscription.stop') { + const clientObservers = observers.get(id) + const observer = clientObservers?.get( + incoming.id.toString() + ) observer?.unsubscribe() - observers.delete(ws.data.id.toString()) + clientObservers?.delete(incoming.id.toString()) return void ws.send( JSON.stringify({ @@ -97,99 +127,122 @@ export const trpc = ) } - const result = await callProcedure({ - procedures: router._def.procedures, - path: incoming.params.path, - rawInput: incoming.params.input?.json, - type: incoming.method, - ctx: {} - }) + if (!incoming.method || !incoming.params) { + continue + } - if (incoming.method !== 'subscription') - return void ws.send( + const sendErrorMessage = (err: unknown) => { + ws.send( JSON.stringify( transformTRPCResponse(router, { id: incoming.id, jsonrpc: incoming.jsonrpc, - result: { - type: 'data', - data: result - } + error: getErrorShape({ + error: getTRPCErrorFromUnknown(err), + type: incoming.method as 'subscription', + path: incoming.params.path, + input: incoming.params.input, + ctx: {}, + config: router._def._config + }) }) ) ) + } - ws.send( - JSON.stringify({ - id: incoming.id, - jsonrpc: incoming.jsonrpc, - result: { - type: 'started' - } - }) - ) - - if (!isObservable(result)) - throw new TRPCError({ - message: `Subscription ${incoming.params.path} did not return an observable`, - code: 'INTERNAL_SERVER_ERROR' + try { + const result = await callProcedure({ + procedures: router._def.procedures, + path: incoming.params.path, + rawInput: incoming.params.input?.json, + type: incoming.method, + ctx: {} }) - observer = result.subscribe({ - next(data) { - ws.send( + if (incoming.method !== 'subscription') { + return void ws.send( JSON.stringify( transformTRPCResponse(router, { id: incoming.id, jsonrpc: incoming.jsonrpc, result: { type: 'data', - data + data: result } }) ) ) - }, - error(err) { - ws.send( - JSON.stringify( - transformTRPCResponse(router, { - id: incoming.id, - jsonrpc: incoming.jsonrpc, - error: router.getErrorShape({ - error: getTRPCErrorFromUnknown( - err - ), - type: incoming.method as 'subscription', - path: incoming.params.path, - input: incoming.params.input, - ctx: {} + } + + ws.send( + JSON.stringify({ + id: incoming.id, + jsonrpc: incoming.jsonrpc, + result: { + type: 'started' + } + }) + ) + + if (!isObservable(result)) { + throw new TRPCError({ + message: `Subscription ${incoming.params.path} did not return an observable`, + code: 'INTERNAL_SERVER_ERROR' + }) + } + + const observer = result.subscribe({ + next(data) { + ws.send( + JSON.stringify( + transformTRPCResponse(router, { + id: incoming.id, + jsonrpc: incoming.jsonrpc, + result: { + type: 'data', + data + } }) - }) + ) ) - ) - }, - complete() { - ws.send( - JSON.stringify( - transformTRPCResponse(router, { - id: incoming.id, - jsonrpc: incoming.jsonrpc, - result: { - type: 'stopped' - } - }) + }, + error(err) { + sendErrorMessage(err) + }, + complete() { + ws.send( + JSON.stringify( + transformTRPCResponse(router, { + id: incoming.id, + jsonrpc: incoming.jsonrpc, + result: { + type: 'stopped' + } + }) + ) ) - ) - } - }) + } + }) - observers.set(ws.data.id.toString(), observer) + observers + .get(id) + ?.set(incoming.id.toString(), observer) + } catch (err) { + sendErrorMessage(err) + } } }, close(ws) { - observers.get(ws.data.id.toString())?.unsubscribe() - observers.delete(ws.data.id.toString()) + // @ts-ignore + const id = ws.data.id + + const clientObservers = observers.get(id) + + clientObservers?.forEach((val, key) => { + val.unsubscribe() + }) + + observers.delete(id) } }) diff --git a/src/types.ts b/src/types.ts index f0abb4f..98130fa 100644 --- a/src/types.ts +++ b/src/types.ts @@ -1,4 +1,3 @@ -import type { Router } from '@trpc/server' import type { FetchHandlerRequestOptions } from '@trpc/server/adapters/fetch' export interface TRPCClientIncomingRequest { @@ -24,4 +23,5 @@ export interface TRPCOptions * @default '/trpc' */ endpoint?: string + useSubscription?: boolean; }