diff --git a/app/client/webllm.ts b/app/client/webllm.ts index 9e739db1..5ff1a145 100644 --- a/app/client/webllm.ts +++ b/app/client/webllm.ts @@ -36,21 +36,33 @@ type WebLLMHandler = ServiceWorkerWebLLMHandler | WebWorkerWebLLMHandler; export class WebLLMApi implements LLMApi { private llmConfig?: LLMConfig; private initialized = false; - webllm: WebLLMHandler; + webllm?: WebLLMHandler; + private type: "serviceWorker" | "webWorker"; + private logLevel: LogLevel; + private isEngineInitialized = false; constructor( type: "serviceWorker" | "webWorker", logLevel: LogLevel = "WARN", ) { + this.type = type; + this.logLevel = logLevel; + } + + private initEngine() { + if (this.isEngineInitialized || typeof window === "undefined") { + return; + } + const engineConfig = { appConfig: { ...prebuiltAppConfig, useIndexedDBCache: this.llmConfig?.cache === "index_db", }, - logLevel, + logLevel: this.logLevel, }; - if (type === "serviceWorker") { + if (this.type === "serviceWorker") { log.info("Create ServiceWorkerMLCEngine"); this.webllm = { type: "serviceWorker", @@ -68,12 +80,17 @@ export class WebLLMApi implements LLMApi { ), }; } + this.isEngineInitialized = true; } private async initModel(onUpdate?: (message: string, chunk: string) => void) { if (!this.llmConfig) { throw Error("llmConfig is undefined"); } + this.initEngine(); + if (!this.webllm) { + throw Error("Engine not initialized"); + } this.webllm.engine.setInitProgressCallback((report: InitProgressReport) => { onUpdate?.(report.text, report.text); }); @@ -153,7 +170,8 @@ export class WebLLMApi implements LLMApi { } async abort() { - await this.webllm.engine?.interruptGenerate(); + this.initEngine(); + await this.webllm?.engine?.interruptGenerate(); } private isDifferentConfig(config: LLMConfig): boolean { @@ -227,6 +245,10 @@ export class WebLLMApi implements LLMApi { extraBody.enable_thinking = this.llmConfig?.enable_thinking ?? false; } + this.initEngine(); + if (!this.webllm) { + throw Error("Engine not initialized"); + } const completion = await this.webllm.engine.chatCompletion({ stream: stream, messages: (newMessages || messages) as ChatCompletionMessageParam[], diff --git a/app/components/home.tsx b/app/components/home.tsx index f9908a6f..4e55cef8 100644 --- a/app/components/home.tsx +++ b/app/components/home.tsx @@ -239,9 +239,9 @@ const useWebLLM = () => { } }, []); - if (webllm?.webllm.type === "serviceWorker") { + if (webllm?.webllm?.type === "serviceWorker") { setInterval(() => { - if (webllm) { + if (webllm?.webllm) { // 10s per heartbeat, dead after 30 seconds of inactivity setWebllmAlive( !!webllm.webllm.engine && @@ -314,7 +314,7 @@ const useLogLevel = (webllm?: WebLLMApi) => { useEffect(() => { log.setLevel(config.logLevel); if (webllm?.webllm?.engine) { - webllm.webllm.engine.setLogLevel(config.logLevel); + webllm?.webllm?.engine.setLogLevel(config.logLevel); } }, [config.logLevel, webllm?.webllm?.engine]); };