diff --git a/framework/fel/java/fel-core/src/main/java/modelengine/fel/core/memory/support/RecentMemory.java b/framework/fel/java/fel-core/src/main/java/modelengine/fel/core/memory/support/RecentMemory.java new file mode 100644 index 00000000..424a4423 --- /dev/null +++ b/framework/fel/java/fel-core/src/main/java/modelengine/fel/core/memory/support/RecentMemory.java @@ -0,0 +1,99 @@ +/*--------------------------------------------------------------------------------------------- + * Copyright (c) 2025 Huawei Technologies Co., Ltd. All rights reserved. + * This file is a part of the ModelEngine Project. + * Licensed under the MIT License. See License.txt in the project root for license information. + *--------------------------------------------------------------------------------------------*/ + +package modelengine.fel.core.memory.support; + +import modelengine.fel.core.chat.ChatMessage; +import modelengine.fel.core.memory.Memory; +import modelengine.fel.core.template.BulkStringTemplate; +import modelengine.fel.core.template.support.DefaultBulkStringTemplate; +import modelengine.fitframework.inspection.Validation; +import modelengine.fitframework.util.MapBuilder; + +import java.util.List; +import java.util.Map; +import java.util.Queue; +import java.util.concurrent.ArrayBlockingQueue; +import java.util.function.Function; +import java.util.stream.Collectors; + +import static modelengine.fitframework.inspection.Validation.notNull; + +/** + * 表示使用最近一定次数历史记录的实现。 + * + * @author 宋永坦 + * @since 2025-07-04 + */ +public class RecentMemory implements Memory { + private final Queue records; + private final BulkStringTemplate bulkTemplate; + private final Function> extractor; + + /** + * 指定最大保留历史记录数量的构造方法。 + * + * @param maxCount 表示最大保留历史记录数量的 {@code int}。 + * @throws IllegalArgumentException 当 {@code maxCount < 0} 时。 + */ + public RecentMemory(int maxCount) { + this(maxCount, + new DefaultBulkStringTemplate("{{type}}:{{text}}", "\n"), + message -> MapBuilder.get() + .put("type", message.type().getRole()) + .put("text", message.text()) + .build()); + } + + /** + * 指定最大保留历史记录数量、渲染模板、抽取方法的构造方法。 + * + * @param maxCount 表示最大保留历史记录数量的 {@code int}。 + * @param bulkTemplate 表示批量字符串模板的 {@link BulkStringTemplate}。 + * @param extractor 表示将 {@link ChatMessage} 转换成 + * {@link Map}{@code <}{@link String}, {@link String}{@code >} 的处理函数。 + * @throws IllegalArgumentException 当 {@code maxCount < 0}、{@code bulkTemplate}、{@code extractor} 为 {@code null} 时。 + */ + public RecentMemory(int maxCount, BulkStringTemplate bulkTemplate, + Function> extractor) { + Validation.greaterThanOrEquals(maxCount, 0, "The max count should >= 0."); + this.records = new ArrayBlockingQueue<>(maxCount); + this.bulkTemplate = notNull(bulkTemplate, "The bulkTemplate cannot be null."); + this.extractor = notNull(extractor, "The extractor cannot be null."); + } + + @Override + public void add(ChatMessage message) { + notNull(message, "The message cannot be null."); + if (!this.records.offer(message)) { + this.records.poll(); + this.records.offer(message); + } + } + + @Override + public void set(List messages) { + notNull(messages, "The messages cannot be null."); + messages.forEach(this::add); + } + + @Override + public void clear() { + this.records.clear(); + } + + @Override + public List messages() { + return this.records.stream().toList(); + } + + @Override + public String text() { + return this.records.stream() + .map(this.extractor) + .collect(Collectors.collectingAndThen(Collectors.toList(), this.bulkTemplate::render)); + } +} diff --git a/framework/fel/java/fel-core/src/test/java/modelengine/fel/core/memory/support/RecentMemoryTest.java b/framework/fel/java/fel-core/src/test/java/modelengine/fel/core/memory/support/RecentMemoryTest.java new file mode 100644 index 00000000..ee6ddc97 --- /dev/null +++ b/framework/fel/java/fel-core/src/test/java/modelengine/fel/core/memory/support/RecentMemoryTest.java @@ -0,0 +1,62 @@ +/*--------------------------------------------------------------------------------------------- + * Copyright (c) 2025 Huawei Technologies Co., Ltd. All rights reserved. + * This file is a part of the ModelEngine Project. + * Licensed under the MIT License. See License.txt in the project root for license information. + *--------------------------------------------------------------------------------------------*/ + +package modelengine.fel.core.memory.support; + +import modelengine.fel.core.chat.ChatMessage; +import modelengine.fel.core.chat.support.AiMessage; + +import org.junit.jupiter.api.Test; + +import java.util.Arrays; +import java.util.List; + +import static org.junit.jupiter.api.Assertions.*; + +/** + * 表示 {@link RecentMemory} 的测试。 + * + * @author 宋永坦 + * @since 2025-07-04 + */ +class RecentMemoryTest { + private final List inputChatMessages = + Arrays.asList(new AiMessage("1"), new AiMessage("2"), new AiMessage("3")); + + @Test + void shouldKeepAllMessagesWhenAddGivenLessMessage() { + RecentMemory recentMemory = new RecentMemory(4); + this.inputChatMessages.forEach(recentMemory::add); + List messages = recentMemory.messages(); + + assertEquals(inputChatMessages.size(), messages.size()); + for (int i = 0; i < inputChatMessages.size(); ++i) { + assertEquals(inputChatMessages.get(i).text(), messages.get(i).text()); + } + } + + @Test + void shouldKeepMaxCountMessagesWhenAddGivenOverMaxCountMessages() { + RecentMemory recentMemory = new RecentMemory(2); + this.inputChatMessages.forEach(recentMemory::add); + List messages = recentMemory.messages(); + + assertEquals(2, messages.size()); + assertEquals(inputChatMessages.get(1).text(), messages.get(0).text()); + assertEquals(inputChatMessages.get(2).text(), messages.get(1).text()); + } + + @Test + void shouldKeepMaxCountMessagesWhenSetGivenOverMaxCountMessages() { + RecentMemory recentMemory = new RecentMemory(2); + recentMemory.set(this.inputChatMessages); + List messages = recentMemory.messages(); + + assertEquals(2, messages.size()); + assertEquals(inputChatMessages.get(1).text(), messages.get(0).text()); + assertEquals(inputChatMessages.get(2).text(), messages.get(1).text()); + } +} \ No newline at end of file diff --git a/framework/fel/java/fel-flow/pom.xml b/framework/fel/java/fel-flow/pom.xml index 04faa3f8..cd1b0a3a 100644 --- a/framework/fel/java/fel-flow/pom.xml +++ b/framework/fel/java/fel-flow/pom.xml @@ -74,5 +74,10 @@ assertj-core test + + org.mockito + mockito-core + test + \ No newline at end of file diff --git a/framework/fel/java/fel-flow/src/main/java/modelengine/fel/engine/flows/Conversation.java b/framework/fel/java/fel-flow/src/main/java/modelengine/fel/engine/flows/Conversation.java index 25b40a98..54963968 100644 --- a/framework/fel/java/fel-flow/src/main/java/modelengine/fel/engine/flows/Conversation.java +++ b/framework/fel/java/fel-flow/src/main/java/modelengine/fel/engine/flows/Conversation.java @@ -9,10 +9,10 @@ import modelengine.fel.core.chat.ChatMessage; import modelengine.fel.core.chat.ChatOption; import modelengine.fel.core.memory.Memory; +import modelengine.fel.core.memory.support.RecentMemory; import modelengine.fel.engine.activities.AiStart; import modelengine.fel.engine.activities.FlowCallBack; import modelengine.fel.engine.operators.models.StreamingConsumer; -import modelengine.fel.engine.operators.sources.Source; import modelengine.fel.engine.util.StateKey; import modelengine.fit.waterflow.domain.context.FlowSession; import modelengine.fit.waterflow.domain.stream.operators.Operators; @@ -33,6 +33,8 @@ * @since 2024-04-28 */ public class Conversation { + private static final int DEFAULT_HISTORY_COUNT = 20; + private final AiProcessFlow flow; private final FlowSession session; private final AtomicReference> converseListener = new AtomicReference<>(null); @@ -66,6 +68,7 @@ public Conversation(AiProcessFlow flow, FlowSession session) { @SafeVarargs public final ConverseLatch offer(D... data) { ConverseLatch latch = setListener(this.flow); + this.initMemory(); FlowSession newSession = FlowSession.newRootSession(this.session, this.session.preserved()); newSession.getWindow().setFrom(null); this.flow.start().offer(data, newSession); @@ -85,6 +88,7 @@ public final ConverseLatch offer(D... data) { public ConverseLatch offer(String nodeId, List data) { Validation.notBlank(nodeId, "invalid nodeId."); ConverseLatch latch = setListener(this.flow); + this.initMemory(); FlowSession newSession = new FlowSession(this.session); newSession.getWindow().setFrom(null); this.flow.origin().offer(nodeId, data.toArray(new Object[0]), newSession); @@ -231,4 +235,10 @@ private FlowSession setConverseListener(FlowSession session) { session.setInnerState(StateKey.CONVERSE_LISTENER, new AtomicReference<>(new ConcurrentHashMap<>())); return session; } + + private void initMemory() { + if (this.session.getInnerState(StateKey.HISTORY) == null) { + this.session.setInnerState(StateKey.HISTORY, new RecentMemory(DEFAULT_HISTORY_COUNT)); + } + } } diff --git a/framework/fel/java/fel-flow/src/main/java/modelengine/fel/engine/operators/models/LlmEmitter.java b/framework/fel/java/fel-flow/src/main/java/modelengine/fel/engine/operators/models/LlmEmitter.java index c13d0f67..278d3df9 100644 --- a/framework/fel/java/fel-flow/src/main/java/modelengine/fel/engine/operators/models/LlmEmitter.java +++ b/framework/fel/java/fel-flow/src/main/java/modelengine/fel/engine/operators/models/LlmEmitter.java @@ -8,12 +8,15 @@ import modelengine.fel.core.chat.ChatMessage; import modelengine.fel.core.chat.Prompt; +import modelengine.fel.core.chat.support.HumanMessage; +import modelengine.fel.core.memory.Memory; import modelengine.fel.engine.util.StateKey; import modelengine.fit.waterflow.bridge.fitflow.FitBoundedEmitter; import modelengine.fit.waterflow.domain.context.FlowSession; import modelengine.fitframework.flowable.Publisher; import modelengine.fitframework.inspection.Validation; import modelengine.fitframework.util.ObjectUtils; +import modelengine.fitframework.util.StringUtils; /** * 流式模型发射器。 @@ -26,6 +29,8 @@ public class LlmEmitter extends FitBoundedEmitter consumer; + private final Memory memory; + private final ChatMessage question; /** * 初始化 {@link LlmEmitter}。 @@ -38,6 +43,9 @@ public LlmEmitter(Publisher publisher, Prompt prompt, FlowSession session) { super(publisher, data -> data); Validation.notNull(session, "The session cannot be null."); this.consumer = ObjectUtils.nullIf(session.getInnerState(StateKey.STREAMING_CONSUMER), EMPTY_CONSUMER); + this.memory = session.getInnerState(StateKey.HISTORY); + this.question = + ObjectUtils.getIfNull(session.getInnerState(StateKey.HISTORY_INPUT), () -> getDefaultQuestion(prompt)); } @Override @@ -46,4 +54,21 @@ public void emit(ChatMessage data, FlowSession trans) { this.chunkAcc.merge(data); this.consumer.accept(this.chunkAcc, data); } + + @Override + public void complete() { + if (this.memory != null && this.chunkAcc.toolCalls().isEmpty()) { + this.memory.add(this.question); + this.memory.add(this.chunkAcc); + } + super.complete(); + } + + private static ChatMessage getDefaultQuestion(Prompt prompt) { + int size = prompt.messages().size(); + if (size == 0) { + return new HumanMessage(StringUtils.EMPTY); + } + return prompt.messages().get(size - 1); + } } diff --git a/framework/fel/java/fel-flow/src/test/java/modelengine/fel/engine/operators/models/LlmEmitterTest.java b/framework/fel/java/fel-flow/src/test/java/modelengine/fel/engine/operators/models/LlmEmitterTest.java new file mode 100644 index 00000000..55d65530 --- /dev/null +++ b/framework/fel/java/fel-flow/src/test/java/modelengine/fel/engine/operators/models/LlmEmitterTest.java @@ -0,0 +1,77 @@ +/*--------------------------------------------------------------------------------------------- + * Copyright (c) 2025 Huawei Technologies Co., Ltd. All rights reserved. + * This file is a part of the ModelEngine Project. + * Licensed under the MIT License. See License.txt in the project root for license information. + *--------------------------------------------------------------------------------------------*/ + +package modelengine.fel.engine.operators.models; + +import modelengine.fel.core.chat.ChatMessage; +import modelengine.fel.core.chat.Prompt; +import modelengine.fel.core.chat.support.AiMessage; +import modelengine.fel.core.chat.support.ChatMessages; +import modelengine.fel.core.memory.Memory; +import modelengine.fel.core.tool.ToolCall; +import modelengine.fel.engine.util.StateKey; +import modelengine.fit.waterflow.domain.context.FlowSession; +import modelengine.fitframework.flowable.Choir; +import modelengine.fitframework.util.StringUtils; + +import org.junit.jupiter.api.Test; +import org.mockito.ArgumentCaptor; +import org.mockito.Mockito; + +import java.util.Arrays; +import java.util.Collections; +import java.util.List; + +import static org.junit.jupiter.api.Assertions.*; + +/** + * 表示 {@link LlmEmitter} 的测试。 + * + * @author 宋永坦 + * @since 2025-07-05 + */ +class LlmEmitterTest { + @Test + void shouldAddMemoryWhenCompleteGivenLlmOutput() { + String output = "data1"; + Prompt prompt = ChatMessages.fromList(Collections.emptyList()); + Choir dataSource = Choir.create(emitter -> { + emitter.emit(new AiMessage(output)); + emitter.complete(); + }); + FlowSession flowSession = new FlowSession(); + Memory mockMemory = Mockito.mock(Memory.class); + ArgumentCaptor captor = ArgumentCaptor.forClass(ChatMessage.class); + Mockito.doNothing().when(mockMemory).add(captor.capture()); + flowSession.setInnerState(StateKey.HISTORY, mockMemory); + + LlmEmitter llmEmitter = new LlmEmitter<>(dataSource, prompt, flowSession); + llmEmitter.start(flowSession); + + List captured = captor.getAllValues(); + assertEquals(2, captured.size()); + assertEquals(StringUtils.EMPTY, captured.get(0).text()); + assertEquals(output, captured.get(1).text()); + } + + @Test + void shouldNotAddMemoryWhenCompleteGivenLlmToolCallOutput() { + String output = "data1"; + Prompt prompt = ChatMessages.fromList(Collections.emptyList()); + Choir dataSource = Choir.create(emitter -> { + emitter.emit(new AiMessage(output, Arrays.asList(ToolCall.custom().id("id1").build()))); + emitter.complete(); + }); + FlowSession flowSession = new FlowSession(); + Memory mockMemory = Mockito.mock(Memory.class); + flowSession.setInnerState(StateKey.HISTORY, mockMemory); + + LlmEmitter llmEmitter = new LlmEmitter<>(dataSource, prompt, flowSession); + llmEmitter.start(flowSession); + + Mockito.verify(mockMemory, Mockito.times(0)).add(Mockito.any()); + } +} \ No newline at end of file