Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,23 @@ public <O> AiMatchHappen<O, D, I, RF, F> match(Operators.Whether<I> whether,
node -> processor.process(new AiState<>(node, this.flow())).state), this.flow());
}

/**
 * Creates a conditional branch from the given predicate and handler.
 *
 * @param whether The branch-activation condition, a {@link Operators.Whether}{@code <}{@link I}{@code >}.
 * @param processor The branch handler, a {@link Operators.Then}{@code <}{@link I}{@code , }{@link O}{@code >}.
 * @param <O> The output type produced by this conditional branch.
 * @return The conditional branch as {@link AiWhenHappen}{@code <}{@link O}{@code , }{@link D}{@code ,
 * }{@link I}{@code , }{@link RF}{@code , }{@link F}{@code >}.
 * @throws IllegalArgumentException If {@code processor} is {@code null}.
 */
public <O> AiWhenHappen<O, D, I, RF, F> when(Operators.Whether<I> whether,
Operators.Then<I, O> processor) {
Validation.notNull(processor, "Ai branch processor cannot be null.");
return new AiWhenHappen<>(this.conditions.when(whether, processor), this.flow());
}

/**
* 指定条件和对应的处理器创建条件跳转分支。
*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,18 @@ public AiState<O, D, I, RF, F> id(String id) {
return this;
}

/**
 * Sets the maximum concurrency level for this state's processing pipeline.
 * Delegates to the underlying state node; useful for throttling a step that is
 * slower than its upstream producer (back-pressure control).
 *
 * @param concurrency The maximum number of concurrent operations allowed (must be positive).
 * @return The current state instance for method chaining.
 * @throws IllegalArgumentException If the concurrency value is zero or negative.
 */
public AiState<O, D, I, RF, F> concurrency(int concurrency) {
this.state.concurrency(concurrency);
return this;
}

/**
* 获取当前节点的数据订阅者。
*
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
/*---------------------------------------------------------------------------------------------
* Copyright (c) 2025 Huawei Technologies Co., Ltd. All rights reserved.
* This file is a part of the ModelEngine Project.
* Licensed under the MIT License. See License.txt in the project root for license information.
*--------------------------------------------------------------------------------------------*/

package modelengine.fel.engine.activities;

import modelengine.fel.engine.flows.AiFlow;
import modelengine.fit.waterflow.domain.flow.Flow;
import modelengine.fit.waterflow.domain.states.WhenHappen;
import modelengine.fit.waterflow.domain.stream.operators.Operators;
import modelengine.fitframework.inspection.Validation;

/**
* Represents a conditional branch that matches when conditions in an AI processing flow.
* This class handles the branching logic when specific conditions are met in the workflow.
*
* @param <O> The output data type of the current node.
* @param <D> The initial data type of the containing flow.
* @param <I> The input parameter data type.
* @param <RF> The internal flow type, extending {@link Flow}{@code <D>}.
* @param <F> The AI flow type, extending {@link AiFlow}{@code <D, RF>}.
*
* @author 宋永坦
* @since 2025-06-12
*/
public class AiWhenHappen<O, D, I, RF extends Flow<D>, F extends AiFlow<D, RF>> {
    /** The underlying conditional-branch generator that accumulates when/then pairs. */
    private final WhenHappen<O, D, I, RF> whenHappen;

    /** The AI flow that owns this conditional node. */
    private final F flow;

    /**
     * Creates a new AI flow matching generator that handles conditional branching.
     * This constructor initializes a stateful processor for when/then style pattern matching
     * within AI workflows.
     *
     * @param matchHappen The core matching generator that evaluates conditions.
     * @param flow The parent AI flow.
     * @throws IllegalArgumentException If either parameter is {@code null}.
     */
    public AiWhenHappen(WhenHappen<O, D, I, RF> matchHappen, F flow) {
        this.whenHappen = Validation.notNull(matchHappen, "WhenHappen cannot be null.");
        this.flow = Validation.notNull(flow, "Flow cannot be null.");
    }

    /**
     * Appends another conditional branch with the specified predicate and handler.
     *
     * @param whether The condition predicate that determines branch activation.
     * @param processor The transformation handler to execute when the condition is met.
     * @return A new {@link AiWhenHappen} instance representing the extended conditional chain.
     * @throws IllegalArgumentException If {@code processor} is {@code null}.
     */
    public AiWhenHappen<O, D, I, RF, F> when(Operators.Whether<I> whether, Operators.Then<I, O> processor) {
        Validation.notNull(processor, "Ai branch processor cannot be null.");
        return new AiWhenHappen<>(this.whenHappen.when(whether, processor), this.flow);
    }

    /**
     * Provides a default processing logic and terminates the conditional node.
     *
     * @param processor The handler to process unmatched inputs.
     * @return An {@link AiState} representing the terminal node of the conditional flow.
     * @throws IllegalArgumentException If {@code processor} is {@code null}.
     */
    public AiState<O, D, O, RF, F> others(Operators.Then<I, O> processor) {
        Validation.notNull(processor, "Ai branch processor cannot be null.");
        return new AiState<>(this.whenHappen.others(processor), this.flow);
    }
}
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,9 @@ public class Conversation<D, R> {
*/
public Conversation(AiProcessFlow<D, R> flow, FlowSession session) {
this.flow = Validation.notNull(flow, "Flow cannot be null.");
if (session != null) {
session.begin();
}
this.session =
(session == null) ? this.setConverseListener(new FlowSession(true)) : this.setSubConverseListener(session);
this.session.begin();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ public FlowEmitter<O> invoke(I data) {
public Pattern<I, O> sync() {
return new SimplePattern<>(data -> {
FlowSession require = AiFlowSession.require();
FlowSession session = new FlowSession(true);
FlowSession session = new FlowSession(require.preserved());
Window window = session.begin();
session.copySessionState(require);
ConverseLatch<O> conversation = this.getFlow().converse(session).offer(data);
Expand All @@ -116,7 +116,7 @@ public Flow<I> origin() {
*/
protected static <O> FlowSession buildFlowSession(FlowEmitter<O> emitter) {
FlowSession mainSession = AiFlowSession.require();
FlowSession flowSession = FlowSession.newRootSession(mainSession, true);
FlowSession flowSession = FlowSession.newRootSession(mainSession, mainSession.preserved());
flowSession.setInnerState(PARENT_SESSION_ID_KEY, mainSession.getId());
ResultAction<O> resultAction = emitter::emit;
flowSession.setInnerState(RESULT_ACTION_KEY, resultAction);
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,226 @@
/*---------------------------------------------------------------------------------------------
* Copyright (c) 2025 Huawei Technologies Co., Ltd. All rights reserved.
* This file is a part of the ModelEngine Project.
* Licensed under the MIT License. See License.txt in the project root for license information.
*--------------------------------------------------------------------------------------------*/

package modelengine.fel.engine;

import modelengine.fel.core.chat.ChatMessage;
import modelengine.fel.core.chat.ChatOption;
import modelengine.fel.core.chat.support.AiMessage;
import modelengine.fel.core.util.Tip;
import modelengine.fel.engine.flows.AiFlows;
import modelengine.fel.engine.flows.AiProcessFlow;
import modelengine.fel.engine.flows.ConverseLatch;
import modelengine.fel.engine.operators.models.ChatFlowModel;
import modelengine.fel.engine.operators.prompts.Prompts;
import modelengine.fit.waterflow.domain.context.FlowSession;
import modelengine.fit.waterflow.domain.context.StateContext;
import modelengine.fit.waterflow.domain.utils.SleepUtil;
import modelengine.fitframework.flowable.Choir;

import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.DisplayName;
import org.junit.jupiter.api.Nested;
import org.junit.jupiter.api.Test;

import java.util.concurrent.atomic.AtomicInteger;

/**
* Test cases demonstrating different flow control scenarios in AI processing pipelines.
* Contains nested test classes for specific flow control mechanisms.
*
* @author 宋永坦
* @since 2025-06-11
*/
public class AiFlowCaseTest {
    /** Global time multiplier applied to every simulated delay; increase to slow all cases uniformly. */
    private static final int SPEED = 1;

    /**
     * Verifies conditional routing: think-section chunks skip the logging branch while every chunk
     * still passes through both desensitization steps.
     */
    @Nested
    class DesensitizeCase {
        // Mock streaming model: "<think>", 48 numbered think chunks, "</think>", then 50 answer chunks (100 total).
        private final ChatFlowModel model = new ChatFlowModel((prompt, chatOption) -> Choir.create(emitter -> {
            emitter.emit(new AiMessage("<think>"));
            int takeTime = 10 * SPEED;
            SleepUtil.sleep(takeTime);
            for (int i = 0; i < 48; i++) {
                emitter.emit(new AiMessage(String.valueOf(i)));
                SleepUtil.sleep(takeTime);
            }
            emitter.emit(new AiMessage("</think>"));
            SleepUtil.sleep(takeTime);
            for (int i = 100; i < 150; i++) {
                emitter.emit(new AiMessage(String.valueOf(i)));
                SleepUtil.sleep(takeTime);
            }
            emitter.complete();
        }), ChatOption.custom().model("modelName").stream(true).build());

        // Flow under test: classify each chunk, log only non-think content, then desensitize in two passes.
        private final AiProcessFlow<Tip, String> flow = AiFlows.<Tip>create()
                .prompt(Prompts.human("{{0}}"))
                .generate(model)
                .map(this::classic)
                .conditions()
                .when(chunk -> chunk.isThinkContent, input -> input)
                .others(input -> {
                    this.log(input);
                    return input;
                })
                .map(this::mockDesensitize1)
                .map(this::mockDesensitize2)
                .close();

        @Test
        @DisplayName("DesensitizeCase")
        void run() {
            AtomicInteger counter = new AtomicInteger(0);
            long startTime = System.currentTimeMillis();
            System.out.printf("time:%s, start.\n", startTime);
            ConverseLatch<String> result = flow.converse(new FlowSession(true)).doOnConsume(answer -> {
                System.out.printf("time:%s, chunk=%s\n", System.currentTimeMillis(), answer);
                counter.incrementAndGet();
            }).offer(Tip.fromArray("hi"));
            result.await();
            System.out.printf("time:%s, cost=%s\n", System.currentTimeMillis(), System.currentTimeMillis() - startTime);
            Assertions.assertEquals(100, counter.get());
        }

        /**
         * Classifies a streamed chunk as think content or answer content, tracking the open/close
         * marker pair in the flow's state context under the "isThinking" key.
         */
        private Chunk classic(ChatMessage message, StateContext ctx) {
            if (message.text().trim().equals("<think>")) {
                ctx.setState("isThinking", true);
                return new Chunk(true, message.text());
            }
            if (message.text().trim().equals("</think>")) {
                ctx.setState("isThinking", false);
                return new Chunk(true, message.text());
            }
            if (Boolean.TRUE.equals(ctx.getState("isThinking"))) {
                return new Chunk(true, message.text());
            }
            return new Chunk(false, message.text());
        }

        // First desensitization pass: masks the digit '3'.
        private String mockDesensitize1(Chunk chunk) {
            SleepUtil.sleep(10 * SPEED);
            return chunk.content.replace("3", "*");
        }

        // Second desensitization pass: masks the digit '4'.
        private String mockDesensitize2(String chunk) {
            SleepUtil.sleep(10 * SPEED);
            return chunk.replace("4", "*");
        }

        // Invoked only from the others() branch, i.e. for non-think content.
        private void log(Chunk chunk) {
            System.out.println("log content:" + chunk.content);
        }

        /** A streamed piece of model output tagged with whether it belongs to the think section. */
        private static class Chunk {
            private final boolean isThinkContent;
            private final String content;

            private Chunk(boolean isThinkContent, String content) {
                this.isThinkContent = isThinkContent;
                this.content = content;
            }
        }
    }

    /**
     * Simulates a backpressure scenario where:
     * <ol>
     * <li>The LLM generates data faster than the TTS can process it.</li>
     * <li>TTS processing is constrained to a single thread.</li>
     * <li>TTS processing speed is artificially slowed.</li>
     * </ol>
     */
    @Nested
    class BackPressureCase {
        // Mock model: emits 100 chunks at a 5ms cadence, faster than downstream can consume them.
        private final ChatFlowModel model = new ChatFlowModel((prompt, chatOption) -> Choir.create(emitter -> {
            for (int i = 0; i < 100; i++) {
                emitter.emit(new AiMessage(String.valueOf(i)));
                SleepUtil.sleep(5 * SPEED);
            }
            emitter.complete();
            System.out.printf("time:%s, generate completed.\n", System.currentTimeMillis());
        }), ChatOption.custom().model("modelName").stream(true).build());

        private final AiProcessFlow<Tip, String> flow = AiFlows.<Tip>create()
                .prompt(Prompts.human("{{0}}"))
                .generate(model)
                .map(this::mockDesensitize).concurrency(1) // Limit processing to 1 concurrent thread
                .map(this::mockTTS).concurrency(1) // Limit processing to 1 concurrent thread
                .close();

        @Test
        @DisplayName("BackPressureCase")
        void run() {
            AtomicInteger counter = new AtomicInteger(0);
            long startTime = System.currentTimeMillis();
            System.out.printf("time:%s, start.\n", startTime);
            ConverseLatch<String> result = flow.converse(new FlowSession(false)).doOnConsume(answer -> {
                System.out.printf("time:%s, chunk=%s\n", System.currentTimeMillis(), answer);
                counter.incrementAndGet();
            }).offer(Tip.fromArray("hi"));
            result.await();
            System.out.printf("time:%s, cost=%s\n", System.currentTimeMillis(), System.currentTimeMillis() - startTime);
            Assertions.assertEquals(100, counter.get());
        }

        private String mockDesensitize(ChatMessage chunk) {
            // Simulate time-consuming operation with a delay.
            SleepUtil.sleep(10 * SPEED);
            return chunk.text().replace("3", "*");
        }

        private String mockTTS(String chunk) {
            // Simulate time-consuming operation with a delay.
            SleepUtil.sleep(10 * SPEED);
            return chunk;
        }
    }

    /**
     * Demonstrates concurrent processing with balanced throughput where:
     * <ol>
     * <li>LLM generates data at moderate pace.</li>
     * <li>Downstream processing runs with 3 concurrent threads.</li>
     * <li>Processing speed is slightly slower than generation (3 : 1).</li>
     * </ol>
     */
    @Nested
    class ConcurrencyCase {
        // Mock model: emits 100 chunks at a 10ms cadence.
        private final ChatFlowModel model = new ChatFlowModel((prompt, chatOption) -> Choir.create(emitter -> {
            for (int i = 0; i < 100; i++) {
                emitter.emit(new AiMessage(String.valueOf(i)));
                SleepUtil.sleep(10 * SPEED);
            }
            emitter.complete();
        }), ChatOption.custom().model("modelName").stream(true).build());

        private final AiProcessFlow<Tip, String> flow = AiFlows.<Tip>create()
                .prompt(Prompts.human("{{0}}"))
                .generate(model)
                .map(this::mockDesensitize).concurrency(3) // Set processing to 3 concurrent threads
                .close();

        @Test
        @DisplayName("ConcurrencyCase")
        void run() {
            AtomicInteger counter = new AtomicInteger(0);
            long startTime = System.currentTimeMillis();
            System.out.printf("time:%s, start.\n", startTime);
            ConverseLatch<String> result = flow.converse(new FlowSession(false)).doOnConsume(answer -> {
                System.out.printf("time:%s, chunk=%s\n", System.currentTimeMillis(), answer);
                counter.incrementAndGet();
            }).offer(Tip.fromArray("hi"));
            result.await();
            System.out.printf("time:%s, cost=%s\n", System.currentTimeMillis(), System.currentTimeMillis() - startTime);
            Assertions.assertEquals(100, counter.get());
        }

        private String mockDesensitize(ChatMessage chunk) {
            // Simulate slower processing at 1/3 speed of LLM generation.
            SleepUtil.sleep(30 * SPEED);
            return chunk.text().replace("3", "*");
        }
    }
}
Original file line number Diff line number Diff line change
Expand Up @@ -54,4 +54,18 @@ default <I> void send(ProcessType type, Subscriber<I, ?> subscriber, List<FlowCo
 * @param <I> 流程实例执行时的入参数据类型,用于泛型推导
*/
<I> void sendCallback(List<FlowContext<I>> contexts);

/**
 * Directly processes a list of flow contexts through the specified subscriber.
 * This default implementation forwards the contexts to the subscriber as-is,
 * with no intermediate transformation or routing.
 *
 * @param <I> The type of input data contained in the flow contexts.
 * @param type The type of processing to be performed.
 * @param subscriber The subscriber that will handle the processing.
 * @param context The list of flow contexts to be processed.
 */
default <I> void directProcess(ProcessType type, Subscriber<I, ?> subscriber, List<FlowContext<I>> context) {
subscriber.process(type, context);
}
}
Loading