
Commit 2fb025c

feat: connect to mlc-llm REST API endpoint

1 parent 2ac0357

22 files changed: +694 −589 lines


app/client/api.ts

Lines changed: 4 additions & 4 deletions
@@ -1,10 +1,10 @@
 import { ChatCompletionFinishReason, CompletionUsage } from "@mlc-ai/web-llm";
-import { CacheType, ModelType } from "../store";
+import { CacheType, Model } from "../store";
 export const ROLES = ["system", "user", "assistant"] as const;
 export type MessageRole = (typeof ROLES)[number];
 
 export const Models = ["gpt-3.5-turbo", "gpt-4"] as const;
-export type ChatModel = ModelType;
+export type ChatModel = Model;
 
 export interface MultimodalContent {
   type: "text" | "image_url";
@@ -40,7 +40,6 @@ export interface ChatOptions {
     usage?: CompletionUsage,
   ) => void;
   onError?: (err: Error) => void;
-  onController?: (controller: AbortController) => void;
 }
 
 export interface LLMUsage {
@@ -60,7 +59,7 @@ export interface ModelRecord {
   buffer_size_required_bytes?: number;
   low_resource_required?: boolean;
   required_features?: string[];
-  recommended_config: {
+  recommended_config?: {
     temperature?: number;
     top_p?: number;
     presence_penalty?: number;
@@ -71,4 +70,5 @@ export interface ModelRecord {
 export abstract class LLMApi {
   abstract chat(options: ChatOptions): Promise<void>;
   abstract abort(): Promise<void>;
+  abstract models(): Promise<ModelRecord[] | Model[]>;
 }
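
Note: LLMApi gains an abstract models() method, so every client must now report the models it can serve. A minimal sketch of a conforming client follows; the class, its reply, and its returned list are hypothetical, for illustration only:

import { ChatOptions, LLMApi, ModelRecord } from "./api";

// Hypothetical stub: demonstrates the shape of the extended LLMApi contract.
class StubLLMApi extends LLMApi {
  async chat(options: ChatOptions): Promise<void> {
    // Echo a canned reply; a real client would call an engine or endpoint.
    options.onFinish("stub reply", undefined);
  }

  async abort(): Promise<void> {
    // Nothing in flight to cancel in this stub.
  }

  async models(): Promise<ModelRecord[]> {
    // A static list; WebLLMApi returns DEFAULT_MODELS, MlcLLMApi queries /v1/models.
    return [{ name: "stub-model", display_name: "Stub" }];
  }
}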

app/client/mlcllm.ts

Lines changed: 100 additions & 0 deletions
@@ -0,0 +1,100 @@
+import log from "loglevel";
+import { ChatOptions, LLMApi } from "./api";
+import { ChatCompletionFinishReason } from "@mlc-ai/web-llm";
+
+export class MlcLLMApi implements LLMApi {
+  private endpoint: string;
+  private abortController: AbortController | null = null;
+
+  constructor(endpoint: string) {
+    this.endpoint = endpoint;
+  }
+
+  async chat(options: ChatOptions) {
+    const { messages, config } = options;
+
+    const payload = {
+      messages: messages,
+      ...config,
+    };
+
+    // Instantiate a new AbortController for this request
+    this.abortController = new AbortController();
+    const { signal } = this.abortController;
+
+    let reply: string = "";
+    let stopReason: ChatCompletionFinishReason | undefined;
+
+    try {
+      const response = await fetch(`${this.endpoint}/v1/chat/completions`, {
+        method: "POST",
+        headers: {
+          "Content-Type": "application/json",
+        },
+        body: JSON.stringify(payload),
+        signal,
+      });
+
+      if (config.stream) {
+        const reader = response.body!.getReader();
+        while (true) {
+          const { value, done } = await reader.read();
+          if (done) break;
+          // Extracting the data part from the server response
+          const chunk = new TextDecoder("utf-8").decode(value);
+          const result = chunk.match(/data: (.+)/);
+          if (result) {
+            const data = JSON.parse(result[1]);
+            if (data.choices && data.choices.length > 0) {
+              reply += data.choices[0].delta.content; // Append the content
+              options.onUpdate?.(reply, chunk); // Handle the chunk update
+
+              if (data.choices[0].finish_reason) {
+                stopReason = data.choices[0].finish_reason;
+              }
+            }
+          }
+
+          if (chunk === "[DONE]") {
+            // Ending the stream when "[DONE]" is found
+            break;
+          }
+        }
+        options.onFinish(reply, stopReason);
+      } else {
+        const data = await response.json();
+        options.onFinish(
+          data.choices[0].message.content,
+          data.choices[0].finish_reason,
+        );
+      }
+    } catch (error: any) {
+      if (error.name === "AbortError") {
+        log.info("MLC_LLM: chat aborted");
+      } else {
+        log.error("MLC_LLM: Fetch error:", error);
+        options.onError?.(error);
+      }
+    }
+  }
+
+  // Implements the abort method to cancel the request
+  async abort() {
+    this.abortController?.abort();
+  }
+
+  async models() {
+    try {
+      const response = await fetch(`${this.endpoint}/v1/models`, {
+        method: "GET",
+      });
+      const jsonRes = await response.json();
+      return jsonRes.data.map((model: { id: string }) => ({
+        name: model.id,
+        display_name: model.id.split("/")[model.id.split("/").length - 1],
+      }));
+    } catch (error: any) {
+      log.error("MLC_LLM: Fetch error:", error);
+    }
+  }
+}
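
For orientation, a hedged usage sketch of the new client. The endpoint and config fields below are placeholders (real values come from the app's LLMConfig), and it assumes an mlc-llm REST server is already listening:

import { MlcLLMApi } from "./mlcllm";

// Placeholder endpoint; e.g. a locally running mlc-llm REST server.
const api = new MlcLLMApi("http://localhost:8000");

const pending = api.chat({
  messages: [{ role: "user", content: "Hello" }],
  config: { stream: true }, // spread into the request payload alongside messages
  onUpdate: (reply) => console.log(reply), // cumulative text streamed so far
  onFinish: (reply, stop) => console.log(reply, stop),
  onError: (err) => console.error(err),
});

// A stop button can cancel the in-flight fetch via the stored AbortController:
api.abort();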

app/client/webllm.ts

Lines changed: 5 additions & 2 deletions
@@ -17,6 +17,7 @@ import {
 import { ChatOptions, LLMApi, LLMConfig, RequestMessage } from "./api";
 import { LogLevel } from "@mlc-ai/web-llm";
 import { fixMessage } from "../utils";
+import { DEFAULT_MODELS } from "../constant";
 
 const KEEP_ALIVE_INTERVAL = 5_000;
 
@@ -215,6 +216,8 @@ export class WebLLMApi implements LLMApi {
       usage: chatCompletion.usage,
     };
   }
-}
 
-export const WebLLMContext = createContext<WebLLMApi | undefined>(undefined);
+  async models() {
+    return DEFAULT_MODELS;
+  }
+}

app/components/chat.tsx

Lines changed: 24 additions & 12 deletions
@@ -45,7 +45,8 @@ import {
   createMessage,
   useAppConfig,
   DEFAULT_TOPIC,
-  ModelType,
+  Model,
+  ModelClient,
 } from "../store";
 
 import {
@@ -92,9 +93,9 @@ import { ChatCommandPrefix, useChatCommand, useCommand } from "../command";
 import { prettyObject } from "../utils/format";
 import { ExportMessageModal } from "./exporter";
 import { MultimodalContent } from "../client/api";
-import { WebLLMContext } from "../client/webllm";
 import { Template, useTemplateStore } from "../store/template";
 import Image from "next/image";
+import { MLCLLMContext, WebLLMContext } from "../context";
 
 export function ScrollDownToast(prop: { show: boolean; onclick: () => void }) {
   return (
@@ -125,7 +126,7 @@ export function SessionConfigModel(props: { onClose: () => void }) {
   };
 
   return (
-    <div className="modal-template">
+    <div className="screen-model-container">
       <Modal
         title={Locale.Context.Edit}
         onClose={() => props.onClose()}
@@ -556,7 +557,7 @@ export function ChatActions(props: {
         onClose={() => setShowModelSelector(false)}
         onSelection={(s) => {
           if (s.length === 0) return;
-          config.selectModel(s[0] as ModelType);
+          config.selectModel(s[0] as Model);
           showToast(s[0]);
         }}
       />
@@ -606,6 +607,12 @@ function _Chat() {
   const [uploading, setUploading] = useState(false);
   const [showEditPromptModal, setShowEditPromptModal] = useState(false);
   const webllm = useContext(WebLLMContext)!;
+  const mlcllm = useContext(MLCLLMContext)!;
+
+  const llm =
+    config.modelClientType === ModelClient.MLCLLM_API ? mlcllm : webllm;
+
+  const models = config.models;
 
   // prompt hints
   const promptStore = usePromptStore();
@@ -685,7 +692,7 @@
 
     if (isStreaming) return;
 
-    chatStore.onUserInput(userInput, webllm, attachImages);
+    chatStore.onUserInput(userInput, llm, attachImages);
     setAttachImages([]);
     localStorage.setItem(LAST_INPUT_KEY, userInput);
     setUserInput("");
@@ -713,7 +720,7 @@
 
   // stop response
   const onUserStop = () => {
-    webllm.abort();
+    llm.abort();
     chatStore.stopStreaming();
   };
 
@@ -836,7 +843,7 @@
     // resend the message
     const textContent = getMessageTextContent(userMessage);
     const images = getMessageImages(userMessage);
-    chatStore.onUserInput(textContent, webllm, images);
+    chatStore.onUserInput(textContent, llm, images);
     inputRef.current?.focus();
   };
 
@@ -867,7 +874,13 @@
           ]
         : [],
     );
-  }, [config.sendPreviewBubble, context, session.messages, userInput]);
+  }, [
+    config.sendPreviewBubble,
+    context,
+    session.messages,
+    session.messages.length,
+    userInput,
+  ]);
 
   const [msgRenderIndex, _setMsgRenderIndex] = useState(
     Math.max(0, renderMessages.length - CHAT_PAGE_SIZE),
@@ -1183,10 +1196,9 @@
             )}
             {message.role === "assistant" && (
               <div className={styles["chat-message-role-name"]}>
-                {config.models.find((m) => m.name === message.model)
-                  ? config.models.find(
-                      (m) => m.name === message.model,
-                    )!.display_name
+                {models.find((m) => m.name === message.model)
+                  ? models.find((m) => m.name === message.model)!
+                      .display_name
                   : message.model}
               </div>
             )}
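
The client switch in _Chat() keys off config.modelClientType. ModelClient itself is defined in ../store, which is outside this diff; a plausible sketch of the enum and the selection it drives, with member values guessed for illustration:

// Assumed shape of ModelClient (lives in ../store, not shown in this excerpt;
// the member values here are guesses).
enum ModelClient {
  WEBLLM = "webllm",
  MLCLLM_API = "mlcllm_api",
}

// The selection _Chat() performs, lifted out as a pure function for clarity:
// every send/stop/resend path then goes through the chosen client.
function pickLLM<T>(clientType: ModelClient, webllm: T, mlcllm: T): T {
  return clientType === ModelClient.MLCLLM_API ? mlcllm : webllm;
}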

app/components/emoji.tsx

Lines changed: 2 additions & 2 deletions
@@ -4,7 +4,7 @@ import EmojiPicker, {
   Theme as EmojiTheme,
 } from "emoji-picker-react";
 
-import { ModelType } from "../store";
+import { Model } from "../store";
 
 import MlcIcon from "../icons/mlc.svg";
 
@@ -31,7 +31,7 @@ export function AvatarPicker(props: {
   );
 }
 
-export function Avatar(props: { model?: ModelType; avatar?: string }) {
+export function Avatar(props: { model?: Model; avatar?: string }) {
   if (props.model) {
     return (
       <div className="bot-avatar mlc-icon no-dark">

app/components/exporter.tsx

Lines changed: 2 additions & 2 deletions
@@ -1,5 +1,5 @@
 /* eslint-disable @next/next/no-img-element */
-import { ChatMessage, ModelType, useAppConfig, useChatStore } from "../store";
+import { ChatMessage, Model, useAppConfig, useChatStore } from "../store";
 import Locale from "../locales";
 import styles from "./exporter.module.scss";
 import {
@@ -46,7 +46,7 @@ const Markdown = dynamic(async () => (await import("./markdown")).Markdown, {
 
 export function ExportMessageModal(props: { onClose: () => void }) {
   return (
-    <div className="modal-template">
+    <div className="screen-model-container">
       <Modal
         title={Locale.Export.Title}
         onClose={props.onClose}

app/components/home.tsx

Lines changed: 38 additions & 5 deletions
@@ -6,7 +6,7 @@ import styles from "./home.module.scss";
 
 import log from "loglevel";
 import dynamic from "next/dynamic";
-import { useState, useEffect, useRef } from "react";
+import { useState, useEffect, useRef, useMemo, useCallback } from "react";
 import {
   HashRouter as Router,
   Routes,
@@ -20,13 +20,15 @@ import LoadingIcon from "../icons/three-dots.svg";
 
 import Locale from "../locales";
 import { getCSSVar, useMobileScreen } from "../utils";
-import { Path, SlotID } from "../constant";
+import { DEFAULT_MODELS, Path, SlotID } from "../constant";
 import { ErrorBoundary } from "./error";
 import { getISOLang, getLang } from "../locales";
 import { SideBar } from "./sidebar";
 import { useAppConfig } from "../store/config";
-import { WebLLMApi, WebLLMContext } from "../client/webllm";
-import { useChatStore } from "../store";
+import { WebLLMApi } from "../client/webllm";
+import { ModelClient, useChatStore } from "../store";
+import { MLCLLMContext, WebLLMContext } from "../context";
+import { MlcLLMApi } from "../client/mlcllm";
 
 export function Loading(props: { noLogo?: boolean }) {
   return (
@@ -251,6 +253,17 @@ const useWebLLM = () => {
   return { webllm, isWebllmActive };
 };
 
+const useMlcLLM = () => {
+  const config = useAppConfig();
+  const [mlcllm, setMlcLlm] = useState<MlcLLMApi | undefined>(undefined);
+
+  useEffect(() => {
+    setMlcLlm(new MlcLLMApi(config.modelConfig.mlc_endpoint));
+  }, [config.modelConfig.mlc_endpoint, setMlcLlm]);
+
+  return mlcllm;
+};
+
 const useLoadUrlParam = () => {
   const config = useAppConfig();
 
@@ -306,14 +319,32 @@ const useLogLevel = (webllm?: WebLLMApi) => {
   }, [config.logLevel, webllm?.webllm?.engine]);
 };
 
+const useModels = (mlcllm: MlcLLMApi | undefined) => {
+  const config = useAppConfig();
+
+  useEffect(() => {
+    if (config.modelClientType == ModelClient.WEBLLM) {
+      config.setModels(DEFAULT_MODELS);
+    } else if (config.modelClientType == ModelClient.MLCLLM_API) {
+      if (mlcllm) {
+        mlcllm.models().then((models) => {
+          config.setModels(models);
+        });
+      }
+    }
+  }, [config.modelClientType, mlcllm]);
+};
+
 export function Home() {
   const hasHydrated = useHasHydrated();
   const { webllm, isWebllmActive } = useWebLLM();
+  const mlcllm = useMlcLLM();
 
   useSwitchTheme();
   useHtmlLang();
   useLoadUrlParam();
   useStopStreamingMessages();
+  useModels(mlcllm);
   useLogLevel(webllm);
 
   if (!hasHydrated || !webllm || !isWebllmActive) {
@@ -328,7 +359,9 @@
     <ErrorBoundary>
       <Router>
         <WebLLMContext.Provider value={webllm}>
-          <Screen />
+          <MLCLLMContext.Provider value={mlcllm}>
+            <Screen />
+          </MLCLLMContext.Provider>
         </WebLLMContext.Provider>
       </Router>
     </ErrorBoundary>
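
Both providers import their contexts from app/context.ts, which replaces the WebLLMContext export removed from webllm.ts above. That file is not included in this excerpt; given the imports and the JSX in Home(), a plausible sketch of its contents:

// Assumed content of app/context.ts (the file itself is not shown in this commit).
import { createContext } from "react";
import { WebLLMApi } from "./client/webllm";
import { MlcLLMApi } from "./client/mlcllm";

export const WebLLMContext = createContext<WebLLMApi | undefined>(undefined);
export const MLCLLMContext = createContext<MlcLLMApi | undefined>(undefined);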
