diff --git a/framework/fel/java/fel-community/model-openai/pom.xml b/framework/fel/java/fel-community/model-openai/pom.xml index 8bfe6306..8177f875 100644 --- a/framework/fel/java/fel-community/model-openai/pom.xml +++ b/framework/fel/java/fel-community/model-openai/pom.xml @@ -29,6 +29,10 @@ org.fitframework fit-util + + org.fitframework.service + fit-security + @@ -53,6 +57,15 @@ org.assertj assertj-core + + org.fitframework + fit-test-framework + + + com.h2database + h2 + test + @@ -90,7 +103,7 @@ + todir="../../../../../build/plugins"/> diff --git a/framework/fel/java/fel-community/model-openai/src/main/java/modelengine/fel/community/model/openai/OpenAiModel.java b/framework/fel/java/fel-community/model-openai/src/main/java/modelengine/fel/community/model/openai/OpenAiModel.java index c1c0e57a..df62005e 100644 --- a/framework/fel/java/fel-community/model-openai/src/main/java/modelengine/fel/community/model/openai/OpenAiModel.java +++ b/framework/fel/java/fel-community/model-openai/src/main/java/modelengine/fel/community/model/openai/OpenAiModel.java @@ -17,29 +17,51 @@ import modelengine.fel.community.model.openai.entity.embed.OpenAiEmbedding; import modelengine.fel.community.model.openai.entity.embed.OpenAiEmbeddingRequest; import modelengine.fel.community.model.openai.entity.embed.OpenAiEmbeddingResponse; +import modelengine.fel.community.model.openai.entity.image.OpenAiImageRequest; +import modelengine.fel.community.model.openai.entity.image.OpenAiImageResponse; +import modelengine.fel.community.model.openai.enums.ModelProcessingState; import modelengine.fel.community.model.openai.util.HttpUtils; import modelengine.fel.core.chat.ChatMessage; import modelengine.fel.core.chat.ChatModel; import modelengine.fel.core.chat.ChatOption; import modelengine.fel.core.chat.Prompt; +import modelengine.fel.core.chat.support.AiMessage; import modelengine.fel.core.embed.EmbedModel; import modelengine.fel.core.embed.EmbedOption; import modelengine.fel.core.embed.Embedding; +import modelengine.fel.core.image.ImageModel; +import modelengine.fel.core.image.ImageOption; +import modelengine.fel.core.model.http.SecureConfig; +import modelengine.fit.http.client.HttpClassicClient; import modelengine.fit.http.client.HttpClassicClientFactory; import modelengine.fit.http.client.HttpClassicClientRequest; import modelengine.fit.http.client.HttpClassicClientResponse; import modelengine.fit.http.entity.ObjectEntity; import modelengine.fit.http.protocol.HttpRequestMethod; +import modelengine.fit.security.Decryptor; import modelengine.fitframework.annotation.Component; +import modelengine.fitframework.annotation.Fit; +import modelengine.fitframework.conf.Config; import modelengine.fitframework.exception.FitException; import modelengine.fitframework.flowable.Choir; +import modelengine.fitframework.ioc.BeanContainer; +import modelengine.fitframework.ioc.BeanFactory; +import modelengine.fitframework.log.Logger; import modelengine.fitframework.resource.UrlUtils; +import modelengine.fitframework.resource.web.Media; import modelengine.fitframework.serialization.ObjectSerializer; import modelengine.fitframework.util.CollectionUtils; +import modelengine.fitframework.util.LazyLoader; +import modelengine.fitframework.util.MapBuilder; +import modelengine.fitframework.util.ObjectUtils; import modelengine.fitframework.util.StringUtils; import java.io.IOException; +import java.util.HashMap; import java.util.List; +import java.util.Map; +import java.util.concurrent.atomic.AtomicReference; +import java.util.stream.Collectors; /** * 表示 openai 模型服务。 @@ -48,31 +70,53 @@ * @since 2024-08-07 */ @Component -public class OpenAiModel implements EmbedModel, ChatModel { +public class OpenAiModel implements EmbedModel, ChatModel, ImageModel { + private static final Logger log = Logger.get(OpenAiModel.class); + private static final Map HTTPS_CONFIG_KEY_MAPS = MapBuilder.get() + .put("client.http.secure.ignore-trust", Boolean.FALSE) + .put("client.http.secure.ignore-hostname", Boolean.FALSE) + .put("client.http.secure.trust-store-file", Boolean.FALSE) + .put("client.http.secure.trust-store-password", Boolean.TRUE) + .put("client.http.secure.key-store-file", Boolean.FALSE) + .put("client.http.secure.key-store-password", Boolean.TRUE) + .build(); + private static final String RESPONSE_TEMPLATE = "{0}{1}"; + private final HttpClassicClientFactory httpClientFactory; - private final HttpClassicClientFactory.Config config; + private final HttpClassicClientFactory.Config clientConfig; private final String baseUrl; private final String defaultApiKey; private final ObjectSerializer serializer; + private final Config config; + private final Decryptor decryptor; + private final LazyLoader httpClient; /** * 创建 openai 嵌入模型服务的实例。 * * @param httpClientFactory 表示 http 客户端工厂的 {@link HttpClassicClientFactory}。 - * @param config 表示 openai http 配置的 {@link OpenAiConfig}。 + * @param clientConfig 表示 openai http 配置的 {@link OpenAiConfig}。 * @param serializer 表示对象序列化器的 {@link ObjectSerializer}。 + * @param config 表示配置信息的 {@link Config}。 + * @param container 表示 bean 容器的 {@link BeanContainer}。 * @throws IllegalArgumentException 当 {@code httpClientFactory}、{@code config} 为 {@code null} 时。 */ - public OpenAiModel(HttpClassicClientFactory httpClientFactory, OpenAiConfig config, ObjectSerializer serializer) { - notNull(config, "The config cannot be null."); + public OpenAiModel(HttpClassicClientFactory httpClientFactory, OpenAiConfig clientConfig, + @Fit(alias = "json") ObjectSerializer serializer, Config config, BeanContainer container) { + notNull(clientConfig, "The config cannot be null."); this.httpClientFactory = notNull(httpClientFactory, "The http client factory cannot be null."); - this.config = HttpClassicClientFactory.Config.builder() - .connectTimeout(config.getConnectTimeout()) - .socketTimeout(config.getReadTimeout()) + this.clientConfig = HttpClassicClientFactory.Config.builder() + .connectTimeout(clientConfig.getConnectTimeout()) + .socketTimeout(clientConfig.getReadTimeout()) .build(); this.serializer = notNull(serializer, "The serializer cannot be null."); - this.baseUrl = config.getApiBase(); - this.defaultApiKey = config.getApiKey(); + this.baseUrl = clientConfig.getApiBase(); + this.defaultApiKey = clientConfig.getApiKey(); + this.httpClient = new LazyLoader<>(this::getHttpClient); + this.config = config; + this.decryptor = container.lookup(Decryptor.class) + .map(BeanFactory::get) + .orElseGet(() -> encrypted -> encrypted); } @Override @@ -80,7 +124,7 @@ public List generate(List inputs, EmbedOption option) { notEmpty(inputs, "The input cannot be empty."); notNull(option, "The embed option cannot be null."); notBlank(option.model(), "The embed model name cannot be null."); - HttpClassicClientRequest request = this.httpClientFactory.create(this.config) + HttpClassicClientRequest request = this.httpClient.get() .createRequest(HttpRequestMethod.POST, UrlUtils.combine(this.baseUrl, OpenAiApi.EMBEDDING_ENDPOINT)); HttpUtils.setBearerAuth(request, StringUtils.blankIf(option.apiKey(), this.defaultApiKey)); request.jsonEntity(new OpenAiEmbeddingRequest(inputs, option.model())); @@ -98,19 +142,61 @@ public List generate(List inputs, EmbedOption option) { public Choir generate(Prompt prompt, ChatOption chatOption) { notNull(prompt, "The prompt cannot be null."); notNull(chatOption, "The chat option cannot be null."); - HttpClassicClientRequest request = this.httpClientFactory.create(this.config) - .createRequest(HttpRequestMethod.POST, UrlUtils.combine(this.baseUrl, OpenAiApi.CHAT_ENDPOINT)); + String modelSource = StringUtils.blankIf(chatOption.baseUrl(), this.baseUrl); + HttpClassicClientRequest request = this.getHttpClient(chatOption.secureConfig()) + .createRequest(HttpRequestMethod.POST, UrlUtils.combine(modelSource, OpenAiApi.CHAT_ENDPOINT)); HttpUtils.setBearerAuth(request, StringUtils.blankIf(chatOption.apiKey(), this.defaultApiKey)); request.jsonEntity(new OpenAiChatCompletionRequest(prompt, chatOption)); return chatOption.stream() ? this.createChatStream(request) : this.createChatCompletion(request); } + @Override + public List generate(String prompt, ImageOption option) { + notNull(prompt, "The prompt cannot be null."); + notNull(option, "The image option cannot be null."); + String modelSource = StringUtils.blankIf(option.baseUrl(), this.baseUrl); + HttpClassicClientRequest request = this.httpClient.get() + .createRequest(HttpRequestMethod.POST, UrlUtils.combine(modelSource, OpenAiApi.IMAGE_ENDPOINT)); + HttpUtils.setBearerAuth(request, StringUtils.blankIf(option.apiKey(), this.defaultApiKey)); + request.jsonEntity(new OpenAiImageRequest(option.model(), option.size(), prompt)); + Class clazz = OpenAiImageResponse.class; + try (HttpClassicClientResponse response = request.exchange(clazz)) { + return response.objectEntity() + .map(entity -> entity.object().media()) + .orElseThrow(() -> new FitException("The response body is abnormal.")); + } catch (IOException e) { + throw new IllegalStateException("Failed to close response.", e); + } + } + private Choir createChatStream(HttpClassicClientRequest request) { + AtomicReference modelProcessingState = + new AtomicReference<>(ModelProcessingState.INITIAL); return request.exchangeStream(String.class) .filter(str -> !StringUtils.equals(str, "[DONE]")) .map(str -> this.serializer.deserialize(str, OpenAiChatCompletionResponse.class)) - .map(OpenAiChatCompletionResponse::message); + .map(response -> getChatMessage(response, modelProcessingState)); + } + + private ChatMessage getChatMessage(OpenAiChatCompletionResponse response, + AtomicReference state) { + // 适配reasoning_content格式返回的模型推理内容,模型生成内容顺序为先reasoning_content后content + // 在第一个reasoning_content chunk之前增加标签,并且在第一个content chunk之前增加标签 + if (state.get() == ModelProcessingState.INITIAL && StringUtils.isNotEmpty(response.reasoningContent().text())) { + String text = "" + response.reasoningContent().text(); + state.set(ModelProcessingState.THINKING); + return new AiMessage(text, response.message().toolCalls()); + } + if (state.get() == ModelProcessingState.THINKING && StringUtils.isNotEmpty(response.message().text())) { + String text = "" + response.message().text(); + state.set(ModelProcessingState.RESPONDING); + return new AiMessage(text, response.message().toolCalls()); + } + if (state.get() == ModelProcessingState.THINKING) { + return new AiMessage(response.reasoningContent().text(), response.message().toolCalls()); + } + return response.message(); } private Choir createChatCompletion(HttpClassicClientRequest request) { @@ -119,9 +205,64 @@ private Choir createChatCompletion(HttpClassicClientRequest request OpenAiChatCompletionResponse chatCompletionResponse = response.objectEntity() .map(ObjectEntity::object) .orElseThrow(() -> new FitException("The response body is abnormal.")); - return Choir.just(chatCompletionResponse.message()); + String finalMessage = chatCompletionResponse.message().text(); + if (StringUtils.isNotBlank(chatCompletionResponse.reasoningContent().text())) { + finalMessage = StringUtils.format(RESPONSE_TEMPLATE, + chatCompletionResponse.reasoningContent().text(), + finalMessage); + } + return Choir.just(new AiMessage(finalMessage, chatCompletionResponse.message().toolCalls())); } catch (IOException e) { throw new FitException(e); } } + + private HttpClassicClient getHttpClient() { + Map custom = HTTPS_CONFIG_KEY_MAPS.keySet() + .stream() + .filter(sslKey -> this.config.keys().contains(Config.canonicalizeKey(sslKey))) + .collect(Collectors.toMap(sslKey -> sslKey, sslKey -> { + Object value = this.config.get(sslKey, Object.class); + if (HTTPS_CONFIG_KEY_MAPS.get(sslKey)) { + value = this.decryptor.decrypt(ObjectUtils.cast(value)); + } + return value; + })); + + return this.httpClientFactory.create(HttpClassicClientFactory.Config.builder() + .socketTimeout(this.clientConfig.socketTimeout()) + .connectTimeout(this.clientConfig.connectTimeout()) + .custom(custom) + .build()); + } + + private HttpClassicClient getHttpClient(SecureConfig secureConfig) { + if (secureConfig == null) { + return getHttpClient(); + } + + Map custom = buildHttpsConfig(secureConfig); + return this.httpClientFactory.create(HttpClassicClientFactory.Config.builder() + .socketTimeout(this.clientConfig.socketTimeout()) + .connectTimeout(this.clientConfig.connectTimeout()) + .custom(custom) + .build()); + } + + private Map buildHttpsConfig(SecureConfig secureConfig) { + Map result = new HashMap<>(); + putConfigIfNotNull(secureConfig.ignoreTrust(), "client.http.secure.ignore-trust", result); + putConfigIfNotNull(secureConfig.ignoreHostName(), "client.http.secure.ignore-hostname", result); + putConfigIfNotNull(secureConfig.trustStoreFile(), "client.http.secure.trust-store-file", result); + putConfigIfNotNull(secureConfig.trustStorePassword(), "client.http.secure.trust-store-password", result); + putConfigIfNotNull(secureConfig.keyStoreFile(), "client.http.secure.key-store-file", result); + putConfigIfNotNull(secureConfig.keyStorePassword(), "client.http.secure.key-store-password", result); + return result; + } + + private static void putConfigIfNotNull(Object value, String key, Map result) { + if (value != null) { + result.put(key, value); + } + } } \ No newline at end of file diff --git a/framework/fel/java/fel-community/model-openai/src/main/java/modelengine/fel/community/model/openai/api/OpenAiApi.java b/framework/fel/java/fel-community/model-openai/src/main/java/modelengine/fel/community/model/openai/api/OpenAiApi.java index 1b64f02b..e59b4014 100644 --- a/framework/fel/java/fel-community/model-openai/src/main/java/modelengine/fel/community/model/openai/api/OpenAiApi.java +++ b/framework/fel/java/fel-community/model-openai/src/main/java/modelengine/fel/community/model/openai/api/OpenAiApi.java @@ -24,6 +24,11 @@ public interface OpenAiApi { */ String EMBEDDING_ENDPOINT = "/embeddings"; + /** + * 图像生成请求的端点。 + */ + String IMAGE_ENDPOINT = "/images/generations"; + /** * 请求头模型密钥字段。 */ diff --git a/framework/fel/java/fel-community/model-openai/src/main/java/modelengine/fel/community/model/openai/entity/chat/OpenAiChatCompletionResponse.java b/framework/fel/java/fel-community/model-openai/src/main/java/modelengine/fel/community/model/openai/entity/chat/OpenAiChatCompletionResponse.java index 516a5335..1ed60581 100644 --- a/framework/fel/java/fel-community/model-openai/src/main/java/modelengine/fel/community/model/openai/entity/chat/OpenAiChatCompletionResponse.java +++ b/framework/fel/java/fel-community/model-openai/src/main/java/modelengine/fel/community/model/openai/entity/chat/OpenAiChatCompletionResponse.java @@ -6,8 +6,6 @@ package modelengine.fel.community.model.openai.entity.chat; -import static modelengine.fitframework.util.ObjectUtils.cast; - import modelengine.fel.core.chat.ChatMessage; import modelengine.fel.core.chat.support.AiMessage; import modelengine.fel.core.tool.ToolCall; @@ -16,7 +14,10 @@ import modelengine.fitframework.util.CollectionUtils; import modelengine.fitframework.util.StringUtils; +import java.util.Collections; import java.util.List; +import java.util.Optional; +import java.util.function.Function; /** * OpenAi API 格式的会话补全响应。 @@ -36,6 +37,21 @@ public class OpenAiChatCompletionResponse { * @return 表示模型回复的 {@link ChatMessage}。 */ public ChatMessage message() { + return extractMessage(OpenAiChatMessage::content, OpenAiChatMessage::toolCalls); + } + + /** + * 获取响应中的模型推理。 + * + * @return 表示模型回复的 {@link ChatMessage}。 + */ + public ChatMessage reasoningContent() { + return extractMessage(OpenAiChatMessage::reasoningContent, OpenAiChatMessage::toolCalls); + } + + private ChatMessage extractMessage( + Function contentExtractor, + Function> toolCallsExtractor) { if (CollectionUtils.isEmpty(choices)) { return EMPTY_RESPONSE; } @@ -43,11 +59,15 @@ public ChatMessage message() { if (openAiChatMessage == null) { return EMPTY_RESPONSE; } - String content = StringUtils.EMPTY; - if (openAiChatMessage.content() instanceof String) { - content = cast(openAiChatMessage.content()); - } - List toolCalls = CollectionUtils.asParent(openAiChatMessage.toolCalls()); + + String content = Optional.ofNullable(contentExtractor.apply(openAiChatMessage)) + .filter(obj -> obj instanceof String) + .map(obj -> (String) obj) + .orElse(StringUtils.EMPTY); + + List toolCalls = Optional.ofNullable(toolCallsExtractor.apply(openAiChatMessage)) + .orElse(Collections.emptyList()); + return new AiMessage(content, toolCalls); } diff --git a/framework/fel/java/fel-community/model-openai/src/main/java/modelengine/fel/community/model/openai/entity/chat/OpenAiChatMessage.java b/framework/fel/java/fel-community/model-openai/src/main/java/modelengine/fel/community/model/openai/entity/chat/OpenAiChatMessage.java index 6be3928b..618e8511 100644 --- a/framework/fel/java/fel-community/model-openai/src/main/java/modelengine/fel/community/model/openai/entity/chat/OpenAiChatMessage.java +++ b/framework/fel/java/fel-community/model-openai/src/main/java/modelengine/fel/community/model/openai/entity/chat/OpenAiChatMessage.java @@ -40,6 +40,8 @@ public class OpenAiChatMessage { private String toolCallId; @Property(name = "tool_calls") private List toolCalls; + @Property(name = "reasoning_content") + private String reasoningContent; /** * 将 {@link ChatMessage} 对象转换为 {@link OpenAiChatMessage} 对象。 @@ -79,6 +81,15 @@ public Object content() { return this.content; } + /** + * 获取模型推理内容。 + * + * @return 表示推理内容的 {@link String}。 + */ + public String reasoningContent() { + return this.reasoningContent; + } + /** * 获取消息的工具调用。 * diff --git a/framework/fel/java/fel-community/model-openai/src/main/java/modelengine/fel/community/model/openai/entity/chat/OpenAiToolCall.java b/framework/fel/java/fel-community/model-openai/src/main/java/modelengine/fel/community/model/openai/entity/chat/OpenAiToolCall.java index 5f7eeb98..1675adc9 100644 --- a/framework/fel/java/fel-community/model-openai/src/main/java/modelengine/fel/community/model/openai/entity/chat/OpenAiToolCall.java +++ b/framework/fel/java/fel-community/model-openai/src/main/java/modelengine/fel/community/model/openai/entity/chat/OpenAiToolCall.java @@ -8,6 +8,7 @@ import modelengine.fel.core.tool.ToolCall; import modelengine.fitframework.inspection.Nonnull; +import modelengine.fitframework.serialization.annotation.SerializeStrategy; /** * 表示 {@link ToolCall} 的 openai 实现。 @@ -15,10 +16,12 @@ * @author 易文渊 * @since 2024-08-17 */ +@SerializeStrategy(include = SerializeStrategy.Include.NON_NULL) public class OpenAiToolCall implements ToolCall { private String id; private final String type = "function"; private FunctionCall function; + private Integer index; /** * 使用 {@link ToolCall} 构造一个新的 {@link OpenAiToolCall}。 @@ -33,6 +36,7 @@ public static OpenAiToolCall from(ToolCall toolCall) { OpenAiToolCall openAiToolCall = new OpenAiToolCall(); openAiToolCall.id = toolCall.id(); openAiToolCall.function = functionCall; + openAiToolCall.index = toolCall.index(); return openAiToolCall; } @@ -42,6 +46,12 @@ public String id() { return this.id; } + @Nonnull + @Override + public Integer index() { + return this.index; + } + @Nonnull @Override public String name() { @@ -64,7 +74,7 @@ public static class FunctionCall { @Override public String toString() { - return "ToolCall{" + "id='" + id + '\'' + ", name='" + this.function.name + '\'' + ", arguments='" - + this.function.arguments + '\'' + '}'; + return "ToolCall{" + "id='" + id + '\'' + "index='" + index + '\'' + ", name='" + this.function.name + '\'' + + ", arguments='" + this.function.arguments + '\'' + '}'; } } \ No newline at end of file diff --git a/framework/fel/java/fel-community/model-openai/src/main/java/modelengine/fel/community/model/openai/entity/image/OpenAiImage.java b/framework/fel/java/fel-community/model-openai/src/main/java/modelengine/fel/community/model/openai/entity/image/OpenAiImage.java new file mode 100644 index 00000000..07149d22 --- /dev/null +++ b/framework/fel/java/fel-community/model-openai/src/main/java/modelengine/fel/community/model/openai/entity/image/OpenAiImage.java @@ -0,0 +1,40 @@ +/*--------------------------------------------------------------------------------------------- + * Copyright (c) 2025 Huawei Technologies Co., Ltd. All rights reserved. + * This file is a part of the ModelEngine Project. + * Licensed under the MIT License. See License.txt in the project root for license information. + *--------------------------------------------------------------------------------------------*/ + +package modelengine.fel.community.model.openai.entity.image; + +import modelengine.fitframework.annotation.Property; +import modelengine.fitframework.exception.FitException; +import modelengine.fitframework.resource.web.Media; +import modelengine.fitframework.util.StringUtils; + +import java.net.MalformedURLException; +import java.net.URL; + +/** + * 表示 OpenAi 格式的图片。 + * + * @author 何嘉斌 + * @since 2024-12-17 + */ +public class OpenAiImage { + @Property(name = "b64_json") + private String b64Json; + private String url; + + /** + * 获取图片媒体资源。 + * + * @return 表示图片媒体资源的 {@link Media}。 + */ + public Media media() { + try { + return StringUtils.isNotBlank(b64Json) ? new Media("image/jpeg", b64Json) : new Media(new URL(url)); + } catch (MalformedURLException ex) { + throw new FitException(ex); + } + } +} \ No newline at end of file diff --git a/framework/fel/java/fel-community/model-openai/src/main/java/modelengine/fel/community/model/openai/entity/image/OpenAiImageRequest.java b/framework/fel/java/fel-community/model-openai/src/main/java/modelengine/fel/community/model/openai/entity/image/OpenAiImageRequest.java new file mode 100644 index 00000000..f428a68c --- /dev/null +++ b/framework/fel/java/fel-community/model-openai/src/main/java/modelengine/fel/community/model/openai/entity/image/OpenAiImageRequest.java @@ -0,0 +1,34 @@ +/*--------------------------------------------------------------------------------------------- + * Copyright (c) 2025 Huawei Technologies Co., Ltd. All rights reserved. + * This file is a part of the ModelEngine Project. + * Licensed under the MIT License. See License.txt in the project root for license information. + *--------------------------------------------------------------------------------------------*/ + +package modelengine.fel.community.model.openai.entity.image; + +import static modelengine.fitframework.inspection.Validation.notBlank; + +/** + * 表示 OpenAi Api 格式的图片生成请求。 + * + * @author 何嘉斌 + * @since 2024-12-17 + */ +public class OpenAiImageRequest { + private final String model; + private final String size; + private final String prompt; + + /** + * 创建一个新的 OpenAi API 格式的图片生成请求。 + * + * @param model 表示调用的模型名称的 {@link String}。 + * @param size 表示生成图片规格的 {@link String}。 + * @param prompt 表示用户输入提示词的 {@link String}。 + */ + public OpenAiImageRequest(String model, String size, String prompt) { + this.model = notBlank(model, "The model cannot be blank."); + this.size = notBlank(size, "The image size cannot be blank."); + this.prompt = notBlank(prompt, "The prompt cannot be blank."); + } +} \ No newline at end of file diff --git a/framework/fel/java/fel-community/model-openai/src/main/java/modelengine/fel/community/model/openai/entity/image/OpenAiImageResponse.java b/framework/fel/java/fel-community/model-openai/src/main/java/modelengine/fel/community/model/openai/entity/image/OpenAiImageResponse.java new file mode 100644 index 00000000..a5911c0f --- /dev/null +++ b/framework/fel/java/fel-community/model-openai/src/main/java/modelengine/fel/community/model/openai/entity/image/OpenAiImageResponse.java @@ -0,0 +1,34 @@ +/*--------------------------------------------------------------------------------------------- + * Copyright (c) 2025 Huawei Technologies Co., Ltd. All rights reserved. + * This file is a part of the ModelEngine Project. + * Licensed under the MIT License. See License.txt in the project root for license information. + *--------------------------------------------------------------------------------------------*/ + +package modelengine.fel.community.model.openai.entity.image; + +import modelengine.fitframework.resource.web.Media; + +import java.util.List; +import java.util.stream.Collectors; + +/** + * 表示 OpenAi API 格式的图片生成响应。 + * + * @author 何嘉斌 + * @since 2024-12-17 + */ +public class OpenAiImageResponse { + /** + * 模型生成的 Image 列表。 + */ + private List data; + + /** + * 获取模型生成的图片列表。 + * + * @return 表示模型嵌入向量列表的 {@link List}{@code <}{@link Media}{@code >}。 + */ + public List media() { + return this.data.stream().map(OpenAiImage::media).collect(Collectors.toList()); + } +} \ No newline at end of file diff --git a/framework/fel/java/fel-community/model-openai/src/main/java/modelengine/fel/community/model/openai/enums/ModelProcessingState.java b/framework/fel/java/fel-community/model-openai/src/main/java/modelengine/fel/community/model/openai/enums/ModelProcessingState.java new file mode 100644 index 00000000..1c9488c3 --- /dev/null +++ b/framework/fel/java/fel-community/model-openai/src/main/java/modelengine/fel/community/model/openai/enums/ModelProcessingState.java @@ -0,0 +1,30 @@ +/*--------------------------------------------------------------------------------------------- + * Copyright (c) 2025 Huawei Technologies Co., Ltd. All rights reserved. + * This file is a part of the ModelEngine Project. + * Licensed under the MIT License. See License.txt in the project root for license information. + *--------------------------------------------------------------------------------------------*/ + +package modelengine.fel.community.model.openai.enums; + +/** + * 模型内容生成状态枚举类。 + * + * @author 孙怡菲 + * @since 2025-04-29 + */ +public enum ModelProcessingState { + /** + * 表示初始状态。 + */ + INITIAL, + + /** + * 表示内部推理状态。 + */ + THINKING, + + /** + * 表示结果生成状态。 + */ + RESPONDING +} diff --git a/framework/fel/java/fel-community/model-openai/src/test/java/modelengine/fel/community/model/openai/OpenAiModelTest.java b/framework/fel/java/fel-community/model-openai/src/test/java/modelengine/fel/community/model/openai/OpenAiModelTest.java new file mode 100644 index 00000000..35c79d1e --- /dev/null +++ b/framework/fel/java/fel-community/model-openai/src/test/java/modelengine/fel/community/model/openai/OpenAiModelTest.java @@ -0,0 +1,95 @@ +/*--------------------------------------------------------------------------------------------- + * Copyright (c) 2025 Huawei Technologies Co., Ltd. All rights reserved. + * This file is a part of the ModelEngine Project. + * Licensed under the MIT License. See License.txt in the project root for license information. + *--------------------------------------------------------------------------------------------*/ + +package modelengine.fel.community.model.openai; + +import static org.assertj.core.api.Assertions.assertThat; + +import modelengine.fel.community.model.openai.config.OpenAiConfig; +import modelengine.fel.core.chat.ChatMessage; +import modelengine.fel.core.chat.ChatOption; +import modelengine.fel.core.chat.support.ChatMessages; +import modelengine.fel.core.chat.support.HumanMessage; +import modelengine.fel.core.embed.EmbedOption; +import modelengine.fel.core.embed.Embedding; +import modelengine.fel.core.image.ImageOption; +import modelengine.fit.http.client.HttpClassicClientFactory; +import modelengine.fitframework.annotation.Fit; +import modelengine.fitframework.conf.Config; +import modelengine.fitframework.flowable.Choir; +import modelengine.fitframework.ioc.BeanContainer; +import modelengine.fitframework.resource.web.Media; +import modelengine.fitframework.serialization.ObjectSerializer; +import modelengine.fitframework.test.annotation.MvcTest; +import modelengine.fitframework.test.domain.mvc.MockMvc; + +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.DisplayName; +import org.junit.jupiter.api.Test; + +import java.util.Arrays; +import java.util.List; +import java.util.stream.Collectors; + +/** + * {@link OpenAiModel} 的模型测试。 + * + * @author 刘信宏 + * @since 2024-09-23 + */ +@MvcTest(classes = TestModelController.class) +public class OpenAiModelTest { + private OpenAiModel openAiModel; + + @Fit + private HttpClassicClientFactory httpClientFactory; + + @Fit + private ObjectSerializer serializer; + + @Fit + private Config config; + + @Fit + private BeanContainer container; + + @Fit + private MockMvc mockMvc; + + @BeforeEach + public void setUp() { + OpenAiConfig openAiConfig = new OpenAiConfig(); + openAiConfig.setApiBase("http://localhost:" + mockMvc.getPort()); + this.openAiModel = new OpenAiModel(this.httpClientFactory, openAiConfig, this.serializer, config, container); + } + + @Test + @DisplayName("测试聊天流式返回") + void testOpenAiChatModelStreamService() { + List contents = Arrays.asList("1", "2", "3"); + Choir choir = this.openAiModel.generate(ChatMessages.from(new HumanMessage("hello")), + ChatOption.custom().stream(true).model("model").build()); + List response = choir.blockAll(); + assertThat(response).extracting(ChatMessage::text).isEqualTo(contents); + } + + @Test + @DisplayName("测试嵌入模型返回") + void testOpenAiEmbeddingModel() { + Embedding embedding = this.openAiModel.generate("1", EmbedOption.custom().model("model").build()); + assertThat(embedding.embedding()).containsExactly(1f, 2f, 3f); + } + + @Test + @DisplayName("测试图片生成模型返回") + void testOpenAiImageModel() { + List images = + this.openAiModel.generate("prompt", ImageOption.custom().model("model").size("256x256").build()); + assertThat(images.stream().map(Media::getData).collect(Collectors.toList())).containsExactly("123", + "456", + "789"); + } +} \ No newline at end of file diff --git a/framework/fel/java/fel-community/model-openai/src/test/java/modelengine/fel/community/model/openai/TestModelController.java b/framework/fel/java/fel-community/model-openai/src/test/java/modelengine/fel/community/model/openai/TestModelController.java new file mode 100644 index 00000000..d32edff4 --- /dev/null +++ b/framework/fel/java/fel-community/model-openai/src/test/java/modelengine/fel/community/model/openai/TestModelController.java @@ -0,0 +1,85 @@ +/*--------------------------------------------------------------------------------------------- + * Copyright (c) 2025 Huawei Technologies Co., Ltd. All rights reserved. + * This file is a part of the ModelEngine Project. + * Licensed under the MIT License. See License.txt in the project root for license information. + *--------------------------------------------------------------------------------------------*/ + +package modelengine.fel.community.model.openai; + +import static modelengine.fel.community.model.openai.api.OpenAiApi.CHAT_ENDPOINT; +import static modelengine.fel.community.model.openai.api.OpenAiApi.EMBEDDING_ENDPOINT; +import static modelengine.fel.community.model.openai.api.OpenAiApi.IMAGE_ENDPOINT; + +import modelengine.fel.community.model.openai.entity.embed.OpenAiEmbeddingResponse; +import modelengine.fel.community.model.openai.entity.image.OpenAiImageResponse; +import modelengine.fit.http.annotation.PostMapping; +import modelengine.fitframework.annotation.Component; +import modelengine.fitframework.flowable.Choir; +import modelengine.fitframework.serialization.ObjectSerializer; + +/** + * 表示测试使用的聊天接口。 + * + * @author 易文渊 + * @since 2024-09-24 + */ +@Component +public class TestModelController { + private final ObjectSerializer serializer; + + /** + * 创建 {@link TestModelController} 的实例。 + * + * @param serializer 表示对象序列化器的 {@link ObjectSerializer}。 + */ + public TestModelController(ObjectSerializer serializer) { + this.serializer = serializer; + } + + /** + * 测试用聊天接口。 + * + * @return 表示流式返回结果的 {@link Choir}{@code <}{@link String}{@code >}。 + */ + @PostMapping(CHAT_ENDPOINT) + public Choir chat() { + return Choir.create(emitter -> { + for (int i = 1; i <= 3; ++i) { + emitter.emit(getMockStreamResponseChunk(String.valueOf(i))); + } + emitter.emit("[DONE]"); + emitter.complete(); + }); + } + + /** + * 测试用嵌入接口。 + * + * @return 表示嵌入响应的 {@link OpenAiEmbeddingResponse}。 + */ + @PostMapping(EMBEDDING_ENDPOINT) + public OpenAiEmbeddingResponse embed() { + String json = "{\"object\":\"list\"," + + "\"data\":[{\"index\":0,\"object\":\"embedding\",\"embedding\":[1.0,2.0,3.0]}]," + + "\"usage\":{\"prompt_tokens\":1,\"total_tokens\":2}}"; + return this.serializer.deserialize(json, OpenAiEmbeddingResponse.class); + } + + private String getMockStreamResponseChunk(String content) { + return "{\"id\": \"0\"," + "\"object\": \"chat.completion.chunk\"," + "\"created\": 0," + + "\"model\": \"test_model\"," + "\"choices\": [{\"index\": 0,\"delta\": {\"content\": \"" + content + + "\"}," + "\"finish_reason\": null}]}"; + } + + /** + * 测试用图片生成接口。 + * + * @return 表示嵌入响应的 {@link OpenAiImageResponse}。 + */ + @PostMapping(IMAGE_ENDPOINT) + public OpenAiImageResponse image() { + String json = "{\"object\":\"list\"," + + "\"data\":[{\"b64_json\":\"123\"}, {\"b64_json\":\"456\"}, {\"b64_json\":\"789\"}]}"; + return this.serializer.deserialize(json, OpenAiImageResponse.class); + } +} \ No newline at end of file diff --git a/framework/fel/java/fel-community/model-openai/src/test/java/modelengine/fel/community/model/openai/entity/image/OpenAiImageEntityTest.java b/framework/fel/java/fel-community/model-openai/src/test/java/modelengine/fel/community/model/openai/entity/image/OpenAiImageEntityTest.java new file mode 100644 index 00000000..36d9bd4c --- /dev/null +++ b/framework/fel/java/fel-community/model-openai/src/test/java/modelengine/fel/community/model/openai/entity/image/OpenAiImageEntityTest.java @@ -0,0 +1,48 @@ +/*--------------------------------------------------------------------------------------------- + * Copyright (c) 2025 Huawei Technologies Co., Ltd. All rights reserved. + * This file is a part of the ModelEngine Project. + * Licensed under the MIT License. See License.txt in the project root for license information. + *--------------------------------------------------------------------------------------------*/ + +package modelengine.fel.community.model.openai.entity.image; + +import static org.assertj.core.api.Assertions.assertThat; + +import modelengine.fit.serialization.json.jackson.JacksonObjectSerializer; +import modelengine.fitframework.resource.web.Media; +import modelengine.fitframework.serialization.ObjectSerializer; + +import org.junit.jupiter.api.DisplayName; +import org.junit.jupiter.api.Test; + +import java.util.Arrays; +import java.util.stream.Collectors; + +/** + * 测试 {@link modelengine.fel.community.model.openai.entity.image} 下对象的序列化和反序列化。 + * + * @author 何嘉斌 + * @since 2024-12-18 + */ +public class OpenAiImageEntityTest { + private static final ObjectSerializer SERIALIZER = new JacksonObjectSerializer(null, null, null, false); + + @Test + @DisplayName("测试序列化图片生成请求成功") + void giveOpenAiImageRequestThenSerializeOk() { + OpenAiImageRequest request = new OpenAiImageRequest("model", "256x256", "prompt"); + String excepted = "{\"model\":\"model\",\"size\":\"256x256\",\"prompt\":\"prompt\"}"; + assertThat(SERIALIZER.serialize(request)).isEqualTo(excepted); + } + + @Test + @DisplayName("测试反序列化图片生成响应成功") + void giveOpenAiImageResponseThenDeserializeToMediaOk() { + String json = "{\"object\":\"list\"," + "\"data\":[{\"url\":\"https://huawei.com\"}, {\"b64_json\":\"456\"}]}"; + OpenAiImageResponse response = SERIALIZER.deserialize(json, OpenAiImageResponse.class); + assertThat(response).extracting(r -> r.media().stream().map(Media::getMime).collect(Collectors.toList())) + .isEqualTo(Arrays.asList(null, "image/jpeg")); + assertThat(response).extracting(r -> r.media().stream().map(Media::getData).collect(Collectors.toList())) + .isEqualTo(Arrays.asList("https://huawei.com", "456")); + } +} \ No newline at end of file diff --git a/framework/fel/java/fel-community/pom.xml b/framework/fel/java/fel-community/pom.xml index b775aa49..5a056e15 100644 --- a/framework/fel/java/fel-community/pom.xml +++ b/framework/fel/java/fel-community/pom.xml @@ -14,5 +14,6 @@ model-openai + tokenizer-hanlp \ No newline at end of file diff --git a/framework/fel/java/fel-community/tokenizer-hanlp/pom.xml b/framework/fel/java/fel-community/tokenizer-hanlp/pom.xml new file mode 100644 index 00000000..93928d44 --- /dev/null +++ b/framework/fel/java/fel-community/tokenizer-hanlp/pom.xml @@ -0,0 +1,90 @@ + + + 4.0.0 + + + org.fitframework.fel + fel-community-parent + 3.5.0-SNAPSHOT + + + fel-tokenizer-hanlp-plugin + + + + + org.fitframework + fit-api + + + + + org.fitframework.fel + fel-core + + + + + com.hankcs + hanlp + + + + + org.junit.jupiter + junit-jupiter + + + org.assertj + assertj-core + + + + + + + org.fitframework + fit-build-maven-plugin + ${fit.version} + + user + 1 + + + + build-plugin + + build-plugin + + + + package-plugin + + package-plugin + + + + + + org.apache.maven.plugins + maven-antrun-plugin + ${maven.antrun.version} + + + package + + + + + + + run + + + + + + + \ No newline at end of file diff --git a/framework/fel/java/fel-community/tokenizer-hanlp/src/main/java/modelengine/fel/community/tokenizer/hanlp/HanlpTokenizer.java b/framework/fel/java/fel-community/tokenizer-hanlp/src/main/java/modelengine/fel/community/tokenizer/hanlp/HanlpTokenizer.java new file mode 100644 index 00000000..7a3061be --- /dev/null +++ b/framework/fel/java/fel-community/tokenizer-hanlp/src/main/java/modelengine/fel/community/tokenizer/hanlp/HanlpTokenizer.java @@ -0,0 +1,42 @@ +/*--------------------------------------------------------------------------------------------- + * Copyright (c) 2025 Huawei Technologies Co., Ltd. All rights reserved. + * This file is a part of the ModelEngine Project. + * Licensed under the MIT License. See License.txt in the project root for license information. + *--------------------------------------------------------------------------------------------*/ + +package modelengine.fel.community.tokenizer.hanlp; + +import com.hankcs.hanlp.HanLP; +import com.hankcs.hanlp.seg.Segment; + +import modelengine.fel.core.tokenizer.Tokenizer; +import modelengine.fitframework.annotation.Component; +import modelengine.fitframework.util.StringUtils; + +import java.util.List; + +/** + * 表示 {@link Tokenizer} 的 hanlp 实现。 + * + * @author 易文渊 + * @since 2024-09-24 + */ +@Component +public class HanlpTokenizer implements Tokenizer { + private final Segment segment = HanLP.newSegment().enablePartOfSpeechTagging(false).enableOffset(false); + + @Override + public List encode(String text) { + throw new UnsupportedOperationException("The operator encode is not support."); + } + + @Override + public String decode(List tokens) { + throw new UnsupportedOperationException("The operator decode is not support."); + } + + @Override + public int countToken(String text) { + return StringUtils.isBlank(text) ? 0 : segment.seg(text).size(); + } +} \ No newline at end of file diff --git a/framework/fel/java/fel-community/tokenizer-hanlp/src/main/resources/application.yml b/framework/fel/java/fel-community/tokenizer-hanlp/src/main/resources/application.yml new file mode 100644 index 00000000..037b4809 --- /dev/null +++ b/framework/fel/java/fel-community/tokenizer-hanlp/src/main/resources/application.yml @@ -0,0 +1,4 @@ +fit: + beans: + packages: + - 'modelengine.fel.community.tokenizer.hanlp' \ No newline at end of file diff --git a/framework/fel/java/fel-community/tokenizer-hanlp/src/test/java/modelengine/fel/community/tokenizer/hanlp/HanlpTokenizerTest.java b/framework/fel/java/fel-community/tokenizer-hanlp/src/test/java/modelengine/fel/community/tokenizer/hanlp/HanlpTokenizerTest.java new file mode 100644 index 00000000..b771f399 --- /dev/null +++ b/framework/fel/java/fel-community/tokenizer-hanlp/src/test/java/modelengine/fel/community/tokenizer/hanlp/HanlpTokenizerTest.java @@ -0,0 +1,31 @@ +/*--------------------------------------------------------------------------------------------- + * Copyright (c) 2025 Huawei Technologies Co., Ltd. All rights reserved. + * This file is a part of the ModelEngine Project. + * Licensed under the MIT License. See License.txt in the project root for license information. + *--------------------------------------------------------------------------------------------*/ + +package modelengine.fel.community.tokenizer.hanlp; + +import static org.assertj.core.api.Assertions.assertThat; + +import modelengine.fitframework.util.StringUtils; + +import org.junit.jupiter.api.DisplayName; +import org.junit.jupiter.api.Test; + +/** + * 表示 {@link HanlpTokenizer} 的测试集。 + * + * @author 易文渊 + * @since 2024-09-24 + */ +@DisplayName("测试 hanlpTokenizer") +public class HanlpTokenizerTest { + @Test + @DisplayName("测试分词") + void testCountToken() { + HanlpTokenizer tokenizer = new HanlpTokenizer(); + assertThat(tokenizer.countToken(StringUtils.EMPTY)).isEqualTo(0); + assertThat(tokenizer.countToken("你好")).isEqualTo(1); + } +} \ No newline at end of file diff --git a/framework/fel/java/fel-core/pom.xml b/framework/fel/java/fel-core/pom.xml index 192c8fe4..508ea866 100644 --- a/framework/fel/java/fel-core/pom.xml +++ b/framework/fel/java/fel-core/pom.xml @@ -25,6 +25,10 @@ org.fitframework fit-util + + org.fitframework.service + fit-http-classic + @@ -43,6 +47,15 @@ org.assertj assertj-core + + org.fitframework + fit-test-framework + + + com.h2database + h2 + test + diff --git a/framework/fel/java/fel-core/src/main/java/modelengine/fel/core/chat/ChatOption.java b/framework/fel/java/fel-core/src/main/java/modelengine/fel/core/chat/ChatOption.java index fcb04a0a..77bab59b 100644 --- a/framework/fel/java/fel-core/src/main/java/modelengine/fel/core/chat/ChatOption.java +++ b/framework/fel/java/fel-core/src/main/java/modelengine/fel/core/chat/ChatOption.java @@ -7,7 +7,7 @@ package modelengine.fel.core.chat; import modelengine.fel.core.tool.ToolInfo; -import modelengine.fitframework.inspection.Nonnull; +import modelengine.fel.core.model.http.SecureConfig; import modelengine.fitframework.pattern.builder.BuilderFactory; import java.util.List; @@ -25,7 +25,6 @@ public interface ChatOption { * * @return 表示模型名字的 {@link String}。 */ - @Nonnull String model(); /** @@ -36,9 +35,15 @@ public interface ChatOption { * * @return 表示是否使用流式接口的 {@code boolean}。 */ - @Nonnull Boolean stream(); + /** + * 大模型服务端地址。 + * + * @return 表示大模型服务端地址的 {@link String}。 + */ + String baseUrl(); + /** * 获取模型接口秘钥。 * @@ -122,6 +127,13 @@ public interface ChatOption { */ List tools(); + /** + * 获取调用大模型服务的安全配置。 + * + * @return 表示调用大模型服务安全配置的 {@link SecureConfig}。 + */ + SecureConfig secureConfig(); + /** * {@link ChatOption} 的构建器。 */ @@ -137,10 +149,18 @@ interface Builder { /** * 设置是否使用流式接口。 * - * @param stream 表示是否使用流式接口的 {@code boolean}。 + * @param stream 表示是否使用流式接口的 {@code Boolean}。 + * @return 表示当前构建器的 {@link Builder}。 + */ + Builder stream(Boolean stream); + + /** + * 设置模型服务端地址。 + * + * @param baseUrl 表示大模型服务端地址的 {@link String}。 * @return 表示当前构建器的 {@link Builder}。 */ - Builder stream(boolean stream); + Builder baseUrl(String baseUrl); /** * 设置模型接口秘钥。 @@ -153,26 +173,26 @@ interface Builder { /** * 设置生成文本的最大长度。 * - * @param maxTokens 表示生成文本最大长度的 {@code int}。 + * @param maxTokens 表示生成文本最大长度的 {@code Integer}。 * @return 表示当前构建器的 {@link Builder}。 */ - Builder maxTokens(int maxTokens); + Builder maxTokens(Integer maxTokens); /** * 设置频率惩罚系数。 * - * @param frequencyPenalty 表示频率惩罚系数的 {@code double}。 + * @param frequencyPenalty 表示频率惩罚系数的 {@code Double}。 * @return 表示当前构建器的 {@link Builder}。 */ - Builder frequencyPenalty(double frequencyPenalty); + Builder frequencyPenalty(Double frequencyPenalty); /** * 设置文本出现惩罚系数。 * - * @param presencePenalty 表示文本出现惩罚系数的 {@code double}。 + * @param presencePenalty 表示文本出现惩罚系数的 {@code Double}。 * @return 表示当前构建器的 {@link Builder}。 */ - Builder presencePenalty(double presencePenalty); + Builder presencePenalty(Double presencePenalty); /** * 设置停止字符串列表。 @@ -185,18 +205,18 @@ interface Builder { /** * 设置采样温度。 * - * @param temperature 表示采样温度的 {@code double}。 + * @param temperature 表示采样温度的 {@code Double}。 * @return 表示当前构建器的 {@link Builder}。 */ - Builder temperature(double temperature); + Builder temperature(Double temperature); /** * 设置采样率。 * - * @param topP 表示采样率的 {@code double}。 + * @param topP 表示采样率的 {@code Double}。 * @return 表示当前构建器的 {@link Builder}。 */ - Builder topP(double topP); + Builder topP(Double topP); /** * 设置模型能使用的工具列表。 @@ -206,6 +226,14 @@ interface Builder { */ Builder tools(List tools); + /** + * 设置调用大模型服务的安全配置。 + * + * @param secureConfig 表示调用大模型服务安全配置的 {@link SecureConfig}。 + * @return 表示当前构建器的 {@link Builder}。 + */ + Builder secureConfig(SecureConfig secureConfig); + /** * 构建对象。 * diff --git a/framework/fel/java/fel-core/src/main/java/modelengine/fel/core/chat/MessageType.java b/framework/fel/java/fel-core/src/main/java/modelengine/fel/core/chat/MessageType.java index 5bc4fdb0..9ddea524 100644 --- a/framework/fel/java/fel-core/src/main/java/modelengine/fel/core/chat/MessageType.java +++ b/framework/fel/java/fel-core/src/main/java/modelengine/fel/core/chat/MessageType.java @@ -6,6 +6,11 @@ package modelengine.fel.core.chat; +import java.util.Arrays; +import java.util.Map; +import java.util.function.Function; +import java.util.stream.Collectors; + /** * 表示消息类型的枚举。 * @@ -35,6 +40,9 @@ public enum MessageType { private final String role; + private static final Map RELATIONSHIP = + Arrays.stream(MessageType.values()).collect(Collectors.toMap(MessageType::getRole, Function.identity())); + MessageType(String role) { this.role = role; } @@ -47,4 +55,14 @@ public enum MessageType { public String getRole() { return role; } + + /** + * 根据字符串获取 {@link MessageType} 的实例。 + * + * @param role 表示消息角色的 {@link String}。 + * @return 表示消息类型的 {@link MessageType}。 + */ + public static MessageType parse(String role) { + return RELATIONSHIP.getOrDefault(role, MessageType.HUMAN); + } } \ No newline at end of file diff --git a/framework/fel/java/fel-core/src/main/java/modelengine/fel/core/chat/support/ChatMessages.java b/framework/fel/java/fel-core/src/main/java/modelengine/fel/core/chat/support/ChatMessages.java index 5744f94b..d11f3efd 100644 --- a/framework/fel/java/fel-core/src/main/java/modelengine/fel/core/chat/support/ChatMessages.java +++ b/framework/fel/java/fel-core/src/main/java/modelengine/fel/core/chat/support/ChatMessages.java @@ -35,6 +35,18 @@ public static ChatMessages from(ChatMessage... message) { return from(Arrays.asList(message)); } + /** + * 使用聊天消息数组创建 {@link ChatMessages} 的实例。 + * + * @param messages 表示聊天消息的 {@link List}{@code }。 + * @return 表示创建成功的 {@link ChatMessages}。 + */ + public static ChatMessages fromList(List messages) { + ChatMessages chatMessages = new ChatMessages(); + chatMessages.messages().addAll(messages); + return chatMessages; + } + /** * 从给定的提示中创建 {@link ChatMessages} 的实例。 * diff --git a/framework/fel/java/fel-core/src/main/java/modelengine/fel/core/chat/support/FlatChatMessage.java b/framework/fel/java/fel-core/src/main/java/modelengine/fel/core/chat/support/FlatChatMessage.java new file mode 100644 index 00000000..ca5064a4 --- /dev/null +++ b/framework/fel/java/fel-core/src/main/java/modelengine/fel/core/chat/support/FlatChatMessage.java @@ -0,0 +1,79 @@ +/*--------------------------------------------------------------------------------------------- + * Copyright (c) 2025 Huawei Technologies Co., Ltd. All rights reserved. + * This file is a part of the ModelEngine Project. + * Licensed under the MIT License. See License.txt in the project root for license information. + *--------------------------------------------------------------------------------------------*/ + +package modelengine.fel.core.chat.support; + +import modelengine.fel.core.chat.ChatMessage; +import modelengine.fel.core.chat.MessageType; +import modelengine.fel.core.tool.ToolCall; +import modelengine.fitframework.inspection.Validation; +import modelengine.fitframework.resource.web.Media; +import modelengine.fitframework.util.ObjectUtils; +import modelengine.fitframework.util.StringUtils; + +import java.util.Collections; +import java.util.List; +import java.util.Optional; + +/** + * 表示聊天消息的传输实现。 + * + * @author 易文渊 + * @since 2024-04-12 + */ +public class FlatChatMessage implements ChatMessage { + private String id; + private String type; + private String text; + private List medias; + private List toolCalls; + + /** + * 根据{@link ChatMessage} 构造消息传输对象。 + * + * @param chatMessage 提供构造参数的 {@link ChatMessage}。 + * @return 表示创建成功的 {@link FlatChatMessage}。 + */ + public static FlatChatMessage from(ChatMessage chatMessage) { + Validation.notNull(chatMessage, "The chat message cannot be null."); + if (chatMessage instanceof FlatChatMessage) { + return (FlatChatMessage) chatMessage; + } + Validation.notNull(chatMessage.type(), "The message type cannot be null."); + FlatChatMessage flatMessage = new FlatChatMessage(); + flatMessage.id = chatMessage.id().orElse(null); + flatMessage.type = chatMessage.type().getRole(); + flatMessage.text = chatMessage.text(); + flatMessage.medias = chatMessage.medias(); + flatMessage.toolCalls = chatMessage.toolCalls(); + return flatMessage; + } + + @Override + public Optional id() { + return Optional.ofNullable(this.id); + } + + @Override + public MessageType type() { + return MessageType.parse(this.type); + } + + @Override + public String text() { + return ObjectUtils.nullIf(this.text, StringUtils.EMPTY); + } + + @Override + public List medias() { + return ObjectUtils.nullIf(this.medias, Collections.emptyList()); + } + + @Override + public List toolCalls() { + return ObjectUtils.nullIf(this.toolCalls, Collections.emptyList()); + } +} \ No newline at end of file diff --git a/framework/fel/java/fel-core/src/main/java/modelengine/fel/core/document/Measurable.java b/framework/fel/java/fel-core/src/main/java/modelengine/fel/core/document/Measurable.java index 24d6fb15..9d660fa1 100644 --- a/framework/fel/java/fel-core/src/main/java/modelengine/fel/core/document/Measurable.java +++ b/framework/fel/java/fel-core/src/main/java/modelengine/fel/core/document/Measurable.java @@ -6,6 +6,8 @@ package modelengine.fel.core.document; +import modelengine.fitframework.inspection.Nonnull; + /** * 表示具有量化能力的对象。 * @@ -19,4 +21,12 @@ public interface Measurable { * @return 表示当前对象的量化分数的 {@code double}。 */ double score(); + + /** + * 获取文档的分组标识。 + * + * @return 表示文档分组标识的 {@link String}。 + */ + @Nonnull + String group(); } \ No newline at end of file diff --git a/framework/fel/java/fel-core/src/main/java/modelengine/fel/core/document/MeasurableDocument.java b/framework/fel/java/fel-core/src/main/java/modelengine/fel/core/document/MeasurableDocument.java index e3091e6e..e796090f 100644 --- a/framework/fel/java/fel-core/src/main/java/modelengine/fel/core/document/MeasurableDocument.java +++ b/framework/fel/java/fel-core/src/main/java/modelengine/fel/core/document/MeasurableDocument.java @@ -9,7 +9,9 @@ import static modelengine.fitframework.inspection.Validation.notNull; import modelengine.fitframework.inspection.Nonnull; +import modelengine.fitframework.inspection.Validation; import modelengine.fitframework.resource.web.Media; +import modelengine.fitframework.util.UuidUtils; import java.util.List; import java.util.Map; @@ -22,8 +24,12 @@ * @since 2024-08-06 */ public class MeasurableDocument implements Document, Measurable { - private final Document document; + private final String id; + private final String text; + private final String groupId; private final double score; + private final Map metadata; + private final List medias; /** * 创建 {@link MeasurableDocument} 的实体。 @@ -33,30 +39,47 @@ public class MeasurableDocument implements Document, Measurable { * @throws IllegalArgumentException 当 {@code document} 为 {@code null} 时。 */ public MeasurableDocument(Document document, double score) { - this.document = notNull(document, "The document cannot be null."); + this(document, score, UuidUtils.randomUuidString()); + } + + /** + * 创建 {@link MeasurableDocument} 的实体。 + * + * @param document 表示原始文档的 {@link Document}。 + * @param score 表示文档评分的 {@code double}。 + * @param groupId 表示文档的分组标识的 {@link String}。 + * @throws IllegalArgumentException 当 {@code document} 为 {@code null} 时。 + */ + public MeasurableDocument(Document document, double score, String groupId) { + notNull(document, "The document cannot be null."); + this.id = document.id(); + this.text = document.text(); this.score = score; + this.groupId = Validation.notBlank(groupId, "The groupId cannot be null."); + this.metadata = document.metadata(); + this.medias = document.medias(); } @Override @Nonnull public String text() { - return this.document.text(); + return this.text; } @Override public List medias() { - return this.document.medias(); + return this.medias; } @Override public String id() { - return this.document.id(); + return this.id; } @Nonnull @Override public Map metadata() { - return this.document.metadata(); + return this.metadata; } @Override @@ -64,6 +87,12 @@ public double score() { return this.score; } + @Override + @Nonnull + public String group() { + return this.groupId; + } + @Override public boolean equals(Object object) { if (this == object) { @@ -73,16 +102,18 @@ public boolean equals(Object object) { return false; } MeasurableDocument that = (MeasurableDocument) object; - return Double.compare(this.score, that.score) == 0 && Objects.equals(this.document, that.document); + return Double.compare(this.score, that.score) == 0 && Objects.equals(this.id, that.id); } @Override public int hashCode() { - return Objects.hash(this.document, this.score); + return Objects.hash(this.id, this.score); } @Override public String toString() { - return "DocumentWithScore{" + "document=" + document + ", score=" + score + '}'; + return "MeasurableDocument{" + "id='" + this.id + '\'' + ", text='" + this.text + '\'' + ", groupId='" + + this.groupId + '\'' + ", score=" + this.score + ", metadata=" + this.metadata + ", medias=" + + this.medias + '}'; } } \ No newline at end of file diff --git a/framework/fel/java/fel-core/src/main/java/modelengine/fel/core/document/support/DefaultContent.java b/framework/fel/java/fel-core/src/main/java/modelengine/fel/core/document/support/DefaultContent.java index ea332f3d..98517da6 100644 --- a/framework/fel/java/fel-core/src/main/java/modelengine/fel/core/document/support/DefaultContent.java +++ b/framework/fel/java/fel-core/src/main/java/modelengine/fel/core/document/support/DefaultContent.java @@ -48,4 +48,9 @@ public String text() { public List medias() { return this.medias; } + + @Override + public String toString() { + return this.text; + } } \ No newline at end of file diff --git a/framework/fel/java/fel-core/src/main/java/modelengine/fel/core/document/support/RerankApi.java b/framework/fel/java/fel-core/src/main/java/modelengine/fel/core/document/support/RerankApi.java new file mode 100644 index 00000000..c23c6c77 --- /dev/null +++ b/framework/fel/java/fel-core/src/main/java/modelengine/fel/core/document/support/RerankApi.java @@ -0,0 +1,20 @@ +/*--------------------------------------------------------------------------------------------- + * Copyright (c) 2025 Huawei Technologies Co., Ltd. All rights reserved. + * This file is a part of the ModelEngine Project. + * Licensed under the MIT License. See License.txt in the project root for license information. + *--------------------------------------------------------------------------------------------*/ + +package modelengine.fel.core.document.support; + +/** + * 提供 Rerank 客户端接口:发送 Rerank API 格式的请求并接收响应。 + * + * @author 马朝阳 + * @since 2024-09-27 + */ +public interface RerankApi { + /** + * Rerank 模型请求的端点。 + */ + String RERANK_ENDPOINT = "/rerank"; +} \ No newline at end of file diff --git a/framework/fel/java/fel-core/src/main/java/modelengine/fel/core/document/support/RerankDocumentProcessor.java b/framework/fel/java/fel-core/src/main/java/modelengine/fel/core/document/support/RerankDocumentProcessor.java new file mode 100644 index 00000000..0cf2fb20 --- /dev/null +++ b/framework/fel/java/fel-core/src/main/java/modelengine/fel/core/document/support/RerankDocumentProcessor.java @@ -0,0 +1,98 @@ +/*--------------------------------------------------------------------------------------------- + * Copyright (c) 2025 Huawei Technologies Co., Ltd. All rights reserved. + * This file is a part of the ModelEngine Project. + * Licensed under the MIT License. See License.txt in the project root for license information. + *--------------------------------------------------------------------------------------------*/ + +package modelengine.fel.core.document.support; + +import modelengine.fel.core.document.DocumentPostProcessor; +import modelengine.fel.core.document.MeasurableDocument; +import modelengine.fit.http.client.HttpClassicClient; +import modelengine.fit.http.client.HttpClassicClientFactory; +import modelengine.fit.http.client.HttpClassicClientRequest; +import modelengine.fit.http.client.HttpClassicClientResponse; +import modelengine.fit.http.entity.Entity; +import modelengine.fit.http.entity.ObjectEntity; +import modelengine.fit.http.protocol.HttpRequestMethod; +import modelengine.fit.http.protocol.HttpResponseStatus; +import modelengine.fitframework.exception.FitException; +import modelengine.fitframework.inspection.Validation; +import modelengine.fitframework.log.Logger; +import modelengine.fitframework.resource.UrlUtils; +import modelengine.fitframework.util.CollectionUtils; +import modelengine.fitframework.util.LazyLoader; +import modelengine.fitframework.util.ObjectUtils; + +import java.io.IOException; +import java.util.Collections; +import java.util.List; +import java.util.stream.Collectors; + +/** + * 表示检索文档的后置重排序接口。 + * + * @author 马朝阳 + * @since 2024-09-14 + */ +public class RerankDocumentProcessor implements DocumentPostProcessor { + private static final Logger log = Logger.get(RerankDocumentProcessor.class); + + private final LazyLoader httpClient; + private final RerankOption rerankOption; + + /** + * 创建 {@link RerankDocumentProcessor} 的实体。 + * + * @param httpClientFactory 表示 {@link HttpClassicClientFactory} 的实例。 + * @param rerankOption 表示 rerank 模型参数的 {@link RerankOption} + */ + public RerankDocumentProcessor(HttpClassicClientFactory httpClientFactory, RerankOption rerankOption) { + Validation.notNull(httpClientFactory, "The httpClientFactory cannot be null."); + this.httpClient = + new LazyLoader<>(() -> httpClientFactory.create(HttpClassicClientFactory.Config.builder().build())); + this.rerankOption = Validation.notNull(rerankOption, "The rerankOption cannot be null."); + } + + /** + * 对检索结果进行重排序。 + * + * @param documents 表示输入文档的 {@link List}{@code <}{@link MeasurableDocument}{@code >}。 + * @return 表示处理后文档的 {@link List}{@code <}{@link MeasurableDocument}{@code >}。 + */ + public List process(List documents) { + if (CollectionUtils.isEmpty(documents)) { + return Collections.emptyList(); + } + List docs = documents.stream().map(MeasurableDocument::text).collect(Collectors.toList()); + RerankRequest fields = new RerankRequest(this.rerankOption, docs); + + HttpClassicClientRequest request = this.httpClient.get() + .createRequest(HttpRequestMethod.POST, + UrlUtils.combine(this.rerankOption.baseUri(), RerankApi.RERANK_ENDPOINT)); + request.entity(Entity.createObject(request, fields)); + RerankResponse rerankResponse = this.rerankExchange(request); + + return rerankResponse.results() + .stream() + .map(result -> new MeasurableDocument(documents.get(result.index()), result.relevanceScore())) + .sorted((document1, document2) -> (int) (document2.score() - document1.score())) + .collect(Collectors.toList()); + } + + private RerankResponse rerankExchange(HttpClassicClientRequest request) { + try (HttpClassicClientResponse response = request.exchange(RerankResponse.class)) { + if (response.statusCode() != HttpResponseStatus.OK.statusCode()) { + log.error("Failed to get rerank model response. [code={}, reason={}]", + response.statusCode(), + response.reasonPhrase()); + throw new FitException("Failed to get rerank model response."); + } + return ObjectUtils.cast(response.objectEntity() + .map(ObjectEntity::object) + .orElseThrow(() -> new FitException("The response body is abnormal."))); + } catch (IOException e) { + throw new IllegalStateException("Failed to request rerank model.", e); + } + } +} \ No newline at end of file diff --git a/framework/fel/java/fel-core/src/main/java/modelengine/fel/core/document/support/RerankOption.java b/framework/fel/java/fel-core/src/main/java/modelengine/fel/core/document/support/RerankOption.java new file mode 100644 index 00000000..d542bc3b --- /dev/null +++ b/framework/fel/java/fel-core/src/main/java/modelengine/fel/core/document/support/RerankOption.java @@ -0,0 +1,109 @@ +/*--------------------------------------------------------------------------------------------- + * Copyright (c) 2025 Huawei Technologies Co., Ltd. All rights reserved. + * This file is a part of the ModelEngine Project. + * Licensed under the MIT License. See License.txt in the project root for license information. + *--------------------------------------------------------------------------------------------*/ + +package modelengine.fel.core.document.support; + +import modelengine.fitframework.pattern.builder.BuilderFactory; + +/** + * 表示 rerank 模型参数的实体。 + * + * @author 马朝阳 + * @since 2024-09-23 + */ +public interface RerankOption { + /** + * 获取调用模型的名字。 + * + * @return 表示模型名字的 {@link String}。 + */ + String model(); + + /** + * 获取搜索查询。 + * + * @return 表示搜索查询的 {@link String}。 + */ + String query(); + + /** + * 获取 Rerank 接口的 Uri。 + * + * @return 表示 Rerank 接口 Uri 的 {@link String}。 + */ + String baseUri(); + + /** + * 获取返回的最相关的文档数量。 + * + * @return 表示返回的最相关的文档数量的 {@link Integer}。 + */ + Integer topN(); + + /** + * {@link RerankOption} 的构建器。 + */ + interface Builder { + /** + * 设置调用模型的名字。 + * + * @param model 表示模型名字的 {@link String}。 + * @return 表示当前构建器的 {@link Builder}。 + */ + Builder model(String model); + + /** + * 设置搜索查询。 + * + * @param query 表示搜索查询的 {@code String}。 + * @return 表示当前构建器的 {@link Builder}。 + */ + Builder query(String query); + + /** + * 设置 Rerank 接口的 Uri。 + * + * @param baseUri 表示 Rerank 接口 Uri 的 {@link String}。 + * @return 表示当前构建器的 {@link Builder}。 + */ + Builder baseUri(String baseUri); + + /** + * 设置返回的最相关的文档数量。 + * + * @param topN 表示返回的最相关的文档数量的 {@link Integer}。 + * @return 表示当前构建器的 {@link Builder}。 + */ + Builder topN(Integer topN); + + /** + * 构建对象。 + * + * @return 表示构建出来的对象的 {@link RerankOption}。 + */ + + RerankOption build(); + } + + /** + * 获取 {@link RerankOption} 的构建器。 + * + * @return 表示 {@link RerankOption} 的构建器的 {@link Builder}。 + */ + static Builder custom() { + return custom(null); + } + + /** + * 获取 {@link RerankOption} 的构建器,同时将指定对象的值进行填充。 + * + * @param value 表示指定对象的 {@link RerankOption}。 + * @return 表示 {@link RerankOption} 的构建器的 {@link Builder} + */ + static Builder custom(RerankOption value) { + return BuilderFactory.get(RerankOption.class, Builder.class).create(value); + } +} diff --git a/framework/fel/java/fel-core/src/main/java/modelengine/fel/core/document/support/RerankRequest.java b/framework/fel/java/fel-core/src/main/java/modelengine/fel/core/document/support/RerankRequest.java new file mode 100644 index 00000000..8384cbb7 --- /dev/null +++ b/framework/fel/java/fel-core/src/main/java/modelengine/fel/core/document/support/RerankRequest.java @@ -0,0 +1,42 @@ +/*--------------------------------------------------------------------------------------------- + * Copyright (c) 2025 Huawei Technologies Co., Ltd. All rights reserved. + * This file is a part of the ModelEngine Project. + * Licensed under the MIT License. See License.txt in the project root for license information. + *--------------------------------------------------------------------------------------------*/ + +package modelengine.fel.core.document.support; + +import modelengine.fitframework.annotation.Property; +import modelengine.fitframework.inspection.Validation; +import modelengine.fitframework.serialization.annotation.SerializeStrategy; + +import java.util.List; + +/** + * 表示 Rerank API 格式的请求。 + * + * @author 马朝阳 + * @since 2024-09-27 + */ +@SerializeStrategy(include = SerializeStrategy.Include.NON_NULL) +public class RerankRequest { + private final String model; + private final String query; + private final List documents; + @Property(name = "top_n") + private final Integer topN; + + /** + * 创建 {@link RerankRequest} 的实体。 + * + * @param documents 表示要重新排序的文档对象。 + * @param rerankOption 表示 rerank 模型参数。 + */ + public RerankRequest(RerankOption rerankOption, List documents) { + Validation.notNull(rerankOption, "The rerankOption cannot be null."); + this.model = rerankOption.model(); + this.query = rerankOption.query(); + this.documents = Validation.notNull(documents, "The documents cannot be null."); + this.topN = rerankOption.topN(); + } +} diff --git a/framework/fel/java/fel-core/src/main/java/modelengine/fel/core/document/support/RerankResponse.java b/framework/fel/java/fel-core/src/main/java/modelengine/fel/core/document/support/RerankResponse.java new file mode 100644 index 00000000..75fc535a --- /dev/null +++ b/framework/fel/java/fel-core/src/main/java/modelengine/fel/core/document/support/RerankResponse.java @@ -0,0 +1,58 @@ +/*--------------------------------------------------------------------------------------------- + * Copyright (c) 2025 Huawei Technologies Co., Ltd. All rights reserved. + * This file is a part of the ModelEngine Project. + * Licensed under the MIT License. See License.txt in the project root for license information. + *--------------------------------------------------------------------------------------------*/ + +package modelengine.fel.core.document.support; + +import modelengine.fitframework.annotation.Property; +import modelengine.fitframework.util.CollectionUtils; + +import java.util.Collections; +import java.util.List; + +/** + * 表示 Rerank API 格式的请求。 + * + * @author 马朝阳 + * @since 2024-09-27 + */ +public class RerankResponse { + private List results; + + /** + * 获取重新排序后的文档列表。 + * + * @return 表示重新排序后的文档列表的 {@link List}{@code <}{@link RerankOrder}{@code >}。 + */ + public List results() { + return CollectionUtils.isEmpty(this.results) + ? Collections.emptyList() + : Collections.unmodifiableList(this.results); + } + + static class RerankOrder { + private int index; + @Property(name = "relevance_score") + private double relevanceScore; + + /** + * 获取文档在原始列表中的索引。 + * + * @return 表示文档在原始列表中的索引的 {@code int}。 + */ + public int index() { + return this.index; + } + + /** + * 获取文档的相关性评分。 + * + * @return 表示文档的相关性评分的 {@code double}。 + */ + public double relevanceScore() { + return this.relevanceScore; + } + } +} diff --git a/framework/fel/java/fel-core/src/main/java/modelengine/fel/core/document/support/postprocessor/RrfPostProcessor.java b/framework/fel/java/fel-core/src/main/java/modelengine/fel/core/document/support/postprocessor/RrfPostProcessor.java new file mode 100644 index 00000000..79d9fb13 --- /dev/null +++ b/framework/fel/java/fel-core/src/main/java/modelengine/fel/core/document/support/postprocessor/RrfPostProcessor.java @@ -0,0 +1,104 @@ +/*--------------------------------------------------------------------------------------------- + * Copyright (c) 2025 Huawei Technologies Co., Ltd. All rights reserved. + * This file is a part of the ModelEngine Project. + * Licensed under the MIT License. See License.txt in the project root for license information. + *--------------------------------------------------------------------------------------------*/ + +package modelengine.fel.core.document.support.postprocessor; + +import modelengine.fel.core.document.DocumentPostProcessor; +import modelengine.fel.core.document.MeasurableDocument; +import modelengine.fitframework.inspection.Validation; +import modelengine.fitframework.util.CollectionUtils; +import modelengine.fitframework.util.MapBuilder; + +import java.util.Collections; +import java.util.Comparator; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.OptionalDouble; +import java.util.function.Function; +import java.util.stream.Collectors; +import java.util.stream.DoubleStream; + +/** + * 基于 RRF 算法的后处理器。 + * + * @author 马朝阳 + * @since 2024-09-29 + */ +public class RrfPostProcessor implements DocumentPostProcessor { + private static final int DEFAULT_FACTOR = 60; + + private static final Map> SCORE_STRATEGY_MAP = + MapBuilder.>get() + .put(RrfScoreStrategyEnum.MAX, DoubleStream::max) + .put(RrfScoreStrategyEnum.AVG, DoubleStream::average) + .build(); + + private final RrfScoreStrategyEnum scoreStrategy; + private final int factor; + + public RrfPostProcessor() { + this(RrfScoreStrategyEnum.MAX, DEFAULT_FACTOR); + } + + public RrfPostProcessor(RrfScoreStrategyEnum scoreStrategy) { + this(scoreStrategy, DEFAULT_FACTOR); + } + + public RrfPostProcessor(RrfScoreStrategyEnum scoreStrategy, int factor) { + this.scoreStrategy = Validation.notNull(scoreStrategy, "The score strategy cannot be null."); + this.factor = Validation.greaterThanOrEquals(factor, 0, "The factor must be non-negative."); + if (!SCORE_STRATEGY_MAP.containsKey(this.scoreStrategy)) { + throw new IllegalArgumentException("The score strategy map not include this strategy."); + } + } + + /** + * 基于 RRF 算法对检索结果去重和重排序。 + * + * @param documents 表示输入文档的 {@link List}{@code <}{@link MeasurableDocument}{@code >}。 + * @return 表示处理后文档的 {@link List}{@code <}{@link MeasurableDocument}{@code >}。 + */ + @Override + public List process(List documents) { + if (CollectionUtils.isEmpty(documents)) { + return Collections.emptyList(); + } + Map rrfDocumentScore = this.getRrfDocumentScore(documents); + return this.getScoreByStrategy(documents) + .stream() + .sorted((document1, document2) -> rrfDocumentScore.get(document2.id()) + .compareTo(rrfDocumentScore.get(document1.id()))) + .collect(Collectors.toList()); + } + + private List getScoreByStrategy(List documents) { + Map> documentsMap = + documents.stream().collect(Collectors.groupingBy(MeasurableDocument::id)); + return documentsMap.values().stream().map(measurableDocuments -> { + DoubleStream doubleStream = measurableDocuments.stream().mapToDouble(MeasurableDocument::score); + double score = SCORE_STRATEGY_MAP.get(this.scoreStrategy).apply(doubleStream).orElse(0.0d); + MeasurableDocument document = measurableDocuments.get(0); + return new MeasurableDocument(document, score, document.group()); + }).collect(Collectors.toList()); + } + + private Map getRrfDocumentScore(List documents) { + Map> groupedDocuments = + documents.stream().collect(Collectors.groupingBy(MeasurableDocument::group)); + groupedDocuments.values() + .forEach(groupedList -> groupedList.sort(Comparator.comparingDouble(MeasurableDocument::score) + .reversed())); + Map idScoreMap = new HashMap<>(); + for (List groupedDocumentList : groupedDocuments.values()) { + for (int i = 0; i < groupedDocumentList.size(); i++) { + MeasurableDocument curr = groupedDocumentList.get(i); + idScoreMap.put(curr.id(), idScoreMap.getOrDefault(curr.id(), 0.0) + (1.0 / (i + 1 + this.factor))); + } + } + return idScoreMap; + } +} diff --git a/framework/fel/java/fel-core/src/main/java/modelengine/fel/core/document/support/postprocessor/RrfScoreStrategyEnum.java b/framework/fel/java/fel-core/src/main/java/modelengine/fel/core/document/support/postprocessor/RrfScoreStrategyEnum.java new file mode 100644 index 00000000..d83b7c09 --- /dev/null +++ b/framework/fel/java/fel-core/src/main/java/modelengine/fel/core/document/support/postprocessor/RrfScoreStrategyEnum.java @@ -0,0 +1,25 @@ +/*--------------------------------------------------------------------------------------------- + * Copyright (c) 2025 Huawei Technologies Co., Ltd. All rights reserved. + * This file is a part of the ModelEngine Project. + * Licensed under the MIT License. See License.txt in the project root for license information. + *--------------------------------------------------------------------------------------------*/ + +package modelengine.fel.core.document.support.postprocessor; + +/** + * RRF 算法 score 选择策略。 + * + * @author 马朝阳 + * @since 2024-09-29 + */ +public enum RrfScoreStrategyEnum { + /** + * 相同文档的分数取最大值。 + */ + MAX, + + /** + * 相同文档的分数取平均值。 + */ + AVG; +} diff --git a/framework/fel/java/fel-core/src/main/java/modelengine/fel/core/image/ImageModel.java b/framework/fel/java/fel-core/src/main/java/modelengine/fel/core/image/ImageModel.java new file mode 100644 index 00000000..386537aa --- /dev/null +++ b/framework/fel/java/fel-core/src/main/java/modelengine/fel/core/image/ImageModel.java @@ -0,0 +1,28 @@ +/*--------------------------------------------------------------------------------------------- + * Copyright (c) 2025 Huawei Technologies Co., Ltd. All rights reserved. + * This file is a part of the ModelEngine Project. + * Licensed under the MIT License. See License.txt in the project root for license information. + *--------------------------------------------------------------------------------------------*/ + +package modelengine.fel.core.image; + +import modelengine.fitframework.resource.web.Media; + +import java.util.List; + +/** + * 表示大模型图像生成服务。 + * + * @author 何嘉斌 + * @since 2024-12-17 + */ +public interface ImageModel { + /** + * 调用图像生成模型生成结果。 + * + * @param prompt 表示提示词的 {@link String}。 + * @param chatOption 表示图像生成模型参数的 {@link ImageOption}。 + * @return 表示图像生成模型生成结果的 {@link List}{@code <}{@link Media}{@code >}。 + */ + List generate(String prompt, ImageOption chatOption); +} \ No newline at end of file diff --git a/framework/fel/java/fel-core/src/main/java/modelengine/fel/core/image/ImageOption.java b/framework/fel/java/fel-core/src/main/java/modelengine/fel/core/image/ImageOption.java new file mode 100644 index 00000000..49aa10cf --- /dev/null +++ b/framework/fel/java/fel-core/src/main/java/modelengine/fel/core/image/ImageOption.java @@ -0,0 +1,98 @@ +/*--------------------------------------------------------------------------------------------- + * Copyright (c) 2025 Huawei Technologies Co., Ltd. All rights reserved. + * This file is a part of the ModelEngine Project. + * Licensed under the MIT License. See License.txt in the project root for license information. + *--------------------------------------------------------------------------------------------*/ + +package modelengine.fel.core.image; + +import modelengine.fitframework.pattern.builder.BuilderFactory; + +/** + * 表示图像生成模型参数的实体。 + * + * @author 何嘉斌 + * @since 2024-12-17 + */ +public interface ImageOption { + /** + * 获取调用模型的名字。 + * + * @return 表示模型名字的 {@link String}。 + */ + String model(); + + /** + * 大模型服务端地址。 + * + * @return 表示大模型服务端地址的 {@link String}。 + */ + String baseUrl(); + + /** + * 获取图片规格。 + * + * @return 表示图片规格的 {@link String}。 + */ + String size(); + + /** + * 获取服务密钥。 + * + * @return 表示服务密钥的 {@link String}。 + */ + String apiKey(); + + /** + * 表示 {@link ImageOption} 的构建器。 + */ + interface Builder { + /** + * 设置模型名称。 + * + * @param model 表示模型名称的 {@link String}。 + * @return 表示当前构建器的 {@link Builder}。 + */ + Builder model(String model); + + /** + * 设置服务密钥。 + * + * @param apiKey 表示服务密钥的 {@link String}。 + * @return 表示当前构建器的 {@link Builder}。 + */ + Builder apiKey(String apiKey); + + /** + * 设置图片规格。 + * + * @param size 表示图片规格的 {@link String}。 + * @return 表示当前构建器的 {@link Builder}。 + */ + Builder size(String size); + + /** + * 设置模型服务端地址。 + * + * @param baseUrl 表示大模型服务端地址的 {@link String}。 + * @return 表示当前构建器的 {@link Builder}。 + */ + Builder baseUrl(String baseUrl); + + /** + * 构建 {@link ImageOption} 实例。 + * + * @return 返回构建成功的 {@link ImageOption} 实例。 + */ + ImageOption build(); + } + + /** + * 获取 {@link Builder} 的实例。 + * + * @return 表示构建器实例的 {@link Builder}。 + */ + static Builder custom() { + return BuilderFactory.get(ImageOption.class, Builder.class).create(null); + } +} \ No newline at end of file diff --git a/framework/fel/java/fel-core/src/main/java/modelengine/fel/core/model/BlockModel.java b/framework/fel/java/fel-core/src/main/java/modelengine/fel/core/model/BlockModel.java new file mode 100644 index 00000000..e5006c79 --- /dev/null +++ b/framework/fel/java/fel-core/src/main/java/modelengine/fel/core/model/BlockModel.java @@ -0,0 +1,17 @@ +/*--------------------------------------------------------------------------------------------- + * Copyright (c) 2025 Huawei Technologies Co., Ltd. All rights reserved. + * This file is a part of the ModelEngine Project. + * Licensed under the MIT License. See License.txt in the project root for license information. + *--------------------------------------------------------------------------------------------*/ + +package modelengine.fel.core.model; + +import modelengine.fel.core.pattern.Model; + +/** + * 阻塞模型。 + * + * @author 刘信宏 + * @since 2024-06-07 + */ +public interface BlockModel extends Model {} diff --git a/framework/fel/java/fel-core/src/main/java/modelengine/fel/core/model/http/SecureConfig.java b/framework/fel/java/fel-core/src/main/java/modelengine/fel/core/model/http/SecureConfig.java new file mode 100644 index 00000000..712e6fce --- /dev/null +++ b/framework/fel/java/fel-core/src/main/java/modelengine/fel/core/model/http/SecureConfig.java @@ -0,0 +1,128 @@ +/*--------------------------------------------------------------------------------------------- + * Copyright (c) 2025 Huawei Technologies Co., Ltd. All rights reserved. + * This file is a part of the ModelEngine Project. + * Licensed under the MIT License. See License.txt in the project root for license information. + *--------------------------------------------------------------------------------------------*/ + +package modelengine.fel.core.model.http; + +import modelengine.fitframework.pattern.builder.BuilderFactory; + +/** + * 表示 http 请求的安全相关配置。 + * + * @author 宋永坦 + * @since 2025-03-30 + */ +public interface SecureConfig { + /** + * 获取客户端 Http 是否忽略服务器根证书。 + * + * @return 表示是否忽略服务器根证书的 {@link String}。 + */ + Boolean ignoreTrust(); + + /** + * 获取客户端 Http 是否忽略对服务器主机名的身份认证。 + * + * @return 表示是否忽略对服务器主机名的身份认证的 {@link String}。 + */ + Boolean ignoreHostName(); + + /** + * 获取客户端 Http 的秘钥库的文件地址。 + * + * @return 表示秘钥库的文件地址的 {@link String}。 + */ + String trustStoreFile(); + + /** + * 获取客户端 Http 的秘钥库的密码。 + * + * @return 表示秘钥库的密码的 {@link String}。 + */ + String trustStorePassword(); + + /** + * 获取客户端 Http 的秘钥库的秘钥项的文件地址。 + * + * @return 表示秘钥库的秘钥项的文件地址的 {@link String}。 + */ + String keyStoreFile(); + + /** + * 获取客户端 Http 的秘钥库的秘钥项的密码。 + * + * @return 表示秘钥库的秘钥项的密码的 {@link String}。 + */ + String keyStorePassword(); + + /** + * {@link SecureConfig} 的构建器。 + */ + interface Builder { + /** + * 设置客户端 Http 是否忽略服务器根证书。 + * + * @param ignoreTrust 表示是否忽略服务器根证书的 {@link String}。 + * @return 表示当前构建器的 {@link Builder}。 + */ + Builder ignoreTrust(Boolean ignoreTrust); + + /** + * 设置客户端 Http 是否忽略对服务器主机名的身份认证。 + * + * @param ignoreHostName 表示是否忽略对服务器主机名的身份认证的 {@link String}。 + * @return 表示当前构建器的 {@link Builder}。 + */ + Builder ignoreHostName(String ignoreHostName); + + /** + * 设置客户端 Http 的秘钥库的文件地址。 + * + * @param trustStoreFile 表示秘钥库的文件地址的 {@link String}。 + * @return 表示当前构建器的 {@link Builder}。 + */ + Builder trustStoreFile(String trustStoreFile); + + /** + * 设置客户端 Http 的秘钥库的密码。 + * + * @param trustStorePassword 表示秘钥库的密码的 {@link String}。 + * @return 表示当前构建器的 {@link Builder}。 + */ + Builder trustStorePassword(String trustStorePassword); + + /** + * 设置客户端 Http 的秘钥库的秘钥项的文件地址。 + * + * @param keyStoreFile 表示秘钥库的秘钥项的文件地址的 {@link String}。 + * @return 表示当前构建器的 {@link Builder}。 + */ + Builder keyStoreFile(String keyStoreFile); + + /** + * 设置客户端 Http 的秘钥库的秘钥项的密码。 + * + * @param keyStorePassword 表示秘钥库的秘钥项的密码的 {@link String}。 + * @return 表示当前构建器的 {@link Builder}。 + */ + Builder keyStorePassword(String keyStorePassword); + + /** + * 构建对象。 + * + * @return 表示构建出来的对象的 {@link SecureConfig}。 + */ + SecureConfig build(); + } + + /** + * 获取 {@link SecureConfig} 的构建器。 + * + * @return 表示 {@link SecureConfig} 的构建器的 {@link Builder}。 + */ + static Builder custom() { + return BuilderFactory.get(SecureConfig.class, Builder.class).create(null); + } +} diff --git a/framework/fel/java/fel-core/src/main/java/modelengine/fel/core/template/support/DefaultStringTemplate.java b/framework/fel/java/fel-core/src/main/java/modelengine/fel/core/template/support/DefaultStringTemplate.java index 997130d2..630823a4 100644 --- a/framework/fel/java/fel-core/src/main/java/modelengine/fel/core/template/support/DefaultStringTemplate.java +++ b/framework/fel/java/fel-core/src/main/java/modelengine/fel/core/template/support/DefaultStringTemplate.java @@ -29,7 +29,7 @@ */ public class DefaultStringTemplate implements StringTemplate { private static final ParameterizedStringResolver FORMATTER = - ParameterizedStringResolver.create("{{", "}}", '/', false); + ParameterizedStringResolver.create("{{", "}}", '\\', false); private final ParameterizedString parameterizedString; diff --git a/framework/fel/java/fel-core/src/main/java/modelengine/fel/core/tokenizer/Tokenizer.java b/framework/fel/java/fel-core/src/main/java/modelengine/fel/core/tokenizer/Tokenizer.java index 51887280..4e9f6f4b 100644 --- a/framework/fel/java/fel-core/src/main/java/modelengine/fel/core/tokenizer/Tokenizer.java +++ b/framework/fel/java/fel-core/src/main/java/modelengine/fel/core/tokenizer/Tokenizer.java @@ -30,4 +30,12 @@ public interface Tokenizer { * @return 表示解码后的字符串的 {@link String}。 */ String decode(List tokens); + + /** + * 计算分词数。 + * + * @param text 表示需要进行分词字符串的 {@link String}。 + * @return 表示分词数的 {@code int}。 + */ + int countToken(String text); } \ No newline at end of file diff --git a/framework/fel/java/fel-core/src/main/java/modelengine/fel/core/tool/ToolCall.java b/framework/fel/java/fel-core/src/main/java/modelengine/fel/core/tool/ToolCall.java index 2c69d41a..19aa2142 100644 --- a/framework/fel/java/fel-core/src/main/java/modelengine/fel/core/tool/ToolCall.java +++ b/framework/fel/java/fel-core/src/main/java/modelengine/fel/core/tool/ToolCall.java @@ -19,17 +19,22 @@ public interface ToolCall { /** * 获取工具调用的唯一标识。 * - * @return 表示工具调用唯一编号的 {@link Integer}。 + * @return 表示工具调用唯一编号的 {@link String}。 */ - @Nonnull String id(); + /** + * 获取工具调用的索引标号。 + * + * @return 表示索引标号的 {@link Integer}。 + */ + Integer index(); + /** * 获取调用的工具名称。 * * @return 表示工具名称的 {@link String}。 */ - @Nonnull String name(); /** @@ -37,7 +42,6 @@ public interface ToolCall { * * @return 表示工具调用参数的 {@link String}。 */ - @Nonnull String arguments(); /** @@ -52,6 +56,14 @@ interface Builder { */ Builder id(String id); + /** + * 设置工具调用的索引标号。 + * + * @param index 表示索引标号的 {@code Integer}。 + * @return 表示当前构建器的 {@link Builder}。 + */ + Builder index(Integer index); + /** * 设置调用的工具名称。 * diff --git a/framework/fel/java/fel-core/src/main/java/modelengine/fel/core/tool/ToolCallChunk.java b/framework/fel/java/fel-core/src/main/java/modelengine/fel/core/tool/ToolCallChunk.java new file mode 100644 index 00000000..5df04293 --- /dev/null +++ b/framework/fel/java/fel-core/src/main/java/modelengine/fel/core/tool/ToolCallChunk.java @@ -0,0 +1,22 @@ +/*--------------------------------------------------------------------------------------------- + * Copyright (c) 2025 Huawei Technologies Co., Ltd. All rights reserved. + * This file is a part of the ModelEngine Project. + * Licensed under the MIT License. See License.txt in the project root for license information. + *--------------------------------------------------------------------------------------------*/ + +package modelengine.fel.core.tool; + +/** + * 表示工具调用请求的实体片段。 + * + * @author 刘信宏 + * @since 2024-12-23 + */ +public interface ToolCallChunk extends ToolCall { + /** + * 合并工具调用的流式报文。 + * + * @param toolCall 表示工具调用请求实体的 {@link ToolCall}。 + */ + void merge(ToolCall toolCall); +} diff --git a/framework/fel/java/fel-core/src/main/java/modelengine/fel/core/tool/ToolInfo.java b/framework/fel/java/fel-core/src/main/java/modelengine/fel/core/tool/ToolInfo.java index 749d9061..a9902ec5 100644 --- a/framework/fel/java/fel-core/src/main/java/modelengine/fel/core/tool/ToolInfo.java +++ b/framework/fel/java/fel-core/src/main/java/modelengine/fel/core/tool/ToolInfo.java @@ -6,7 +6,6 @@ package modelengine.fel.core.tool; -import modelengine.fitframework.inspection.Nonnull; import modelengine.fitframework.pattern.builder.BuilderFactory; import modelengine.fitframework.util.StringUtils; @@ -25,7 +24,6 @@ public interface ToolInfo { * * @return 表示工具分组的 {@link String}。 */ - @Nonnull String namespace(); /** @@ -33,7 +31,6 @@ public interface ToolInfo { * * @return 表示工具名称的 {@link String}。 */ - @Nonnull String name(); /** @@ -41,7 +38,6 @@ public interface ToolInfo { * * @return 表示工具描述的 {@link String}。 */ - @Nonnull String description(); /** @@ -49,7 +45,6 @@ public interface ToolInfo { * * @return 表示工具参数描述的 {@link Map}{@code <}{@link String}{@code , }{@link Object}{@code >}。 */ - @Nonnull Map parameters(); /** @@ -57,7 +52,6 @@ public interface ToolInfo { * * @return 表示工具元数据的 {@link Map}{@code <}{@link String}{@code , }{@link Object}{@code >}。 */ - @Nonnull Map extensions(); /** diff --git a/framework/fel/java/fel-core/src/main/java/modelengine/fel/core/tool/support/DefaultToolCallChunk.java b/framework/fel/java/fel-core/src/main/java/modelengine/fel/core/tool/support/DefaultToolCallChunk.java new file mode 100644 index 00000000..f3005303 --- /dev/null +++ b/framework/fel/java/fel-core/src/main/java/modelengine/fel/core/tool/support/DefaultToolCallChunk.java @@ -0,0 +1,71 @@ +/*--------------------------------------------------------------------------------------------- + * Copyright (c) 2025 Huawei Technologies Co., Ltd. All rights reserved. + * This file is a part of the ModelEngine Project. + * Licensed under the MIT License. See License.txt in the project root for license information. + *--------------------------------------------------------------------------------------------*/ + +package modelengine.fel.core.tool.support; + +import modelengine.fel.core.tool.ToolCall; +import modelengine.fel.core.tool.ToolCallChunk; +import modelengine.fitframework.inspection.Validation; +import modelengine.fitframework.util.LazyLoader; +import modelengine.fitframework.util.ObjectUtils; +import modelengine.fitframework.util.StringUtils; + +/** + * 表示工具调用请求的实体片段默认实现。 + *

+ * 该实现不保证流式片段聚合的线程安全,需要外部使用方保证线程安全。 + *

+ * + * @author 刘信宏 + * @since 2024-12-23 + */ +public class DefaultToolCallChunk implements ToolCallChunk { + private final String id; + private final String name; + private final String arguments; + private final LazyLoader argumentsBuffer; + + /** + * 使用 {@link ToolCall} 构造一个新的 {@link ToolCallChunk}。 + * + * @param toolCall 表示工具调用的 {@link ToolCall}。 + */ + public DefaultToolCallChunk(ToolCall toolCall) { + Validation.notNull(toolCall, "The tool call cannot be null."); + this.id = Validation.notNull(toolCall.id(), "The tool call id cannot be null."); + this.name = toolCall.name(); + this.arguments = toolCall.arguments(); + this.argumentsBuffer = + new LazyLoader<>(() -> new StringBuilder(ObjectUtils.nullIf(this.arguments, StringUtils.EMPTY))); + } + + @Override + public String id() { + return this.id; + } + + @Override + public Integer index() { + // 工具调用的片段不需要index字段。 + return null; + } + + @Override + public String name() { + return this.name; + } + + @Override + public String arguments() { + return this.argumentsBuffer.get().toString(); + } + + @Override + public void merge(ToolCall toolCall) { + Validation.notNull(toolCall, "The tool call cannot be null."); + this.argumentsBuffer.get().append(ObjectUtils.nullIf(toolCall.arguments(), StringUtils.EMPTY)); + } +} diff --git a/framework/fel/java/fel-core/src/main/java/modelengine/fel/core/util/Tip.java b/framework/fel/java/fel-core/src/main/java/modelengine/fel/core/util/Tip.java index 7a1c2d60..c5a51fa4 100644 --- a/framework/fel/java/fel-core/src/main/java/modelengine/fel/core/util/Tip.java +++ b/framework/fel/java/fel-core/src/main/java/modelengine/fel/core/util/Tip.java @@ -25,6 +25,11 @@ public class Tip { private final Map values = new HashMap<>(); private int index = 0; + @Override + public String toString() { + return this.values.toString(); + } + /** * 从键值对创建 {@link Tip} 的实例。 * diff --git a/framework/fel/java/fel-core/src/test/java/modelengine/fel/core/document/support/RerankDocumentProcessorTest.java b/framework/fel/java/fel-core/src/test/java/modelengine/fel/core/document/support/RerankDocumentProcessorTest.java new file mode 100644 index 00000000..8a08a5e7 --- /dev/null +++ b/framework/fel/java/fel-core/src/test/java/modelengine/fel/core/document/support/RerankDocumentProcessorTest.java @@ -0,0 +1,108 @@ +/*--------------------------------------------------------------------------------------------- + * Copyright (c) 2025 Huawei Technologies Co., Ltd. All rights reserved. + * This file is a part of the ModelEngine Project. + * Licensed under the MIT License. See License.txt in the project root for license information. + *--------------------------------------------------------------------------------------------*/ + +package modelengine.fel.core.document.support; + +import static modelengine.fel.core.document.support.TestRerankModelController.FAIL_ENDPOINT; +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.AssertionsForClassTypes.assertThatThrownBy; + +import modelengine.fel.core.document.Document; +import modelengine.fel.core.document.MeasurableDocument; +import modelengine.fit.http.client.HttpClassicClientFactory; +import modelengine.fitframework.annotation.Fit; +import modelengine.fitframework.exception.FitException; +import modelengine.fitframework.test.annotation.MvcTest; +import modelengine.fitframework.test.domain.mvc.MockMvc; + +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.DisplayName; +import org.junit.jupiter.api.Test; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; + +/** + * ReRank 客户端服务测试。 + * + * @author 马朝阳 + * @since 2024-09-14 + */ +@MvcTest(classes = TestRerankModelController.class) +public class RerankDocumentProcessorTest { + private static final String[] DOCS = new String[] {"Burgers", "Carson", "Shanghai", "Beijing", "Test"}; + + private RerankDocumentProcessor client; + + @Fit + private HttpClassicClientFactory httpClientFactory; + + @Fit + private MockMvc mockMvc; + + @BeforeEach + public void setUp() { + this.client = new RerankDocumentProcessor(httpClientFactory, + RerankOption.custom() + .baseUri("http://localhost:" + mockMvc.getPort()) + .model("rerank1") + .query("What is the capital of the united states?") + .topN(3) + .build()); + } + + @Test + @DisplayName("测试 Rerank 接口调用响应成功") + public void testWhenCallRerankModelThenSuccess() { + List texts = Arrays.asList(DOCS[3], DOCS[4], DOCS[0]); + List scores = Arrays.asList(0.999071, 0.7867867, 0.32713068); + List docs = this.client.process(this.getRequest()); + assertThat(docs).extracting(MeasurableDocument::text).isEqualTo(texts); + assertThat(docs).extracting(MeasurableDocument::score).isEqualTo(scores); + } + + @Test + @DisplayName("测试 Rerank 接口调用响应异常") + public void testWhenCallRerankModelThenResponseException() { + RerankDocumentProcessor client1 = new RerankDocumentProcessor(httpClientFactory, + RerankOption.custom().baseUri("http://localhost:" + mockMvc.getPort() + FAIL_ENDPOINT).build()); + assertThatThrownBy(() -> client1.process(this.getRequest())).isInstanceOf(FitException.class); + } + + @Test + @DisplayName("测试 Rerank 接口参数为空响应异常") + public void testWhenCallRerankModelNullParamThenResponseException() { + assertThatThrownBy(() -> new RerankDocumentProcessor(this.httpClientFactory, null)).isInstanceOf( + IllegalArgumentException.class); + assertThatThrownBy(() -> new RerankDocumentProcessor(null, RerankOption.custom().build())).isInstanceOf( + IllegalArgumentException.class); + } + + @Test + @DisplayName("测试 Rerank 接口请求参数为空响应异常") + public void testWhenCallRerankModelNullRequestParamThenResponseException() { + assertThat(this.client.process(new ArrayList<>())).isEqualTo(Collections.emptyList()); + assertThat(this.client.process(null)).isEqualTo(Collections.emptyList()); + } + + private List getRequest() { + List documents = new ArrayList<>(); + Arrays.stream(DOCS) + .forEach(doc -> documents.add(new MeasurableDocument(Document.custom() + .text(doc) + .metadata(new HashMap<>()) + .build(), -1))); + return documents; + } + + private String getMockReRankResponseBody() { + return "{\"results\":[{\"index\":3,\"relevance_score\":0.999071},{\"index\":4,\"relevance_score\":0.7867867}," + + "{\"index\":0,\"relevance_score\":0.32713068}]}"; + } +} \ No newline at end of file diff --git a/framework/fel/java/fel-core/src/test/java/modelengine/fel/core/document/support/RrfPostProcessorTest.java b/framework/fel/java/fel-core/src/test/java/modelengine/fel/core/document/support/RrfPostProcessorTest.java new file mode 100644 index 00000000..1b6777c6 --- /dev/null +++ b/framework/fel/java/fel-core/src/test/java/modelengine/fel/core/document/support/RrfPostProcessorTest.java @@ -0,0 +1,90 @@ +/*--------------------------------------------------------------------------------------------- + * Copyright (c) 2025 Huawei Technologies Co., Ltd. All rights reserved. + * This file is a part of the ModelEngine Project. + * Licensed under the MIT License. See License.txt in the project root for license information. + *--------------------------------------------------------------------------------------------*/ + +package modelengine.fel.core.document.support; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.AssertionsForClassTypes.assertThatThrownBy; + +import modelengine.fel.core.document.Document; +import modelengine.fel.core.document.MeasurableDocument; +import modelengine.fel.core.document.support.postprocessor.RrfPostProcessor; +import modelengine.fel.core.document.support.postprocessor.RrfScoreStrategyEnum; + +import org.junit.jupiter.api.DisplayName; +import org.junit.jupiter.api.Test; + +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; + +/** + * RRF 算法测试。 + * + * @author 马朝阳 + * @since 2024-09-29 + */ +public class RrfPostProcessorTest { + private static final String[] DOCS = new String[] {"A", "B", "C", "D", "E"}; + + @Test + @DisplayName("测试 RFF 算法最大值策略成功") + public void testWhenCallRRFMaxThenSuccess() { + RrfPostProcessor rrf = new RrfPostProcessor(); + List process = rrf.process(getDocumentList()); + assertThat(process).map(MeasurableDocument::score).containsExactly(0.94, 0.69, 0.36, 0.52, 0.32); + assertThat(process).map(MeasurableDocument::id).containsExactly("1", "4", "2", "5", "3"); + } + + @Test + @DisplayName("测试 RFF 算法平均值策略成功") + public void testWhenCallRRFAvgThenSuccess() { + RrfPostProcessor rrf = new RrfPostProcessor(RrfScoreStrategyEnum.AVG); + List process = rrf.process(getDocumentList()); + assertThat(process).map(MeasurableDocument::score).containsExactly(0.84, 0.655, 0.36, 0.52, 0.32); + assertThat(process).map(MeasurableDocument::id).containsExactly("1", "4", "2", "5", "3"); + } + + @Test + @DisplayName("测试 RFF 算法倒数系数") + public void testWhenCallRRFFactorThenSuccess() { + RrfPostProcessor rrf = new RrfPostProcessor(RrfScoreStrategyEnum.AVG, 100); + List process = rrf.process(getDocumentList()); + assertThat(process).map(MeasurableDocument::score).containsExactly(0.84, 0.655, 0.36, 0.52, 0.32); + assertThat(process).map(MeasurableDocument::id).containsExactly("1", "4", "2", "5", "3"); + } + + @Test + @DisplayName("测试 RFF 算法策略失败") + public void testWhenCallRRFArgNullThenFail() { + assertThatThrownBy(() -> new RrfPostProcessor(null, 60)).isInstanceOf(IllegalArgumentException.class); + assertThatThrownBy(() -> new RrfPostProcessor(RrfScoreStrategyEnum.AVG, -1)).isInstanceOf( + IllegalArgumentException.class); + } + + private List getDocumentList() { + List res = new ArrayList<>(); + res.addAll(getGroup("1", new int[] {1, 3, 4}, new double[] {0.74, 0.32, 0.69})); + res.addAll(getGroup("2", new int[] {1, 5}, new double[] {0.94, 0.52})); + res.addAll(getGroup("3", new int[] {2, 4}, new double[] {0.36, 0.62})); + + return res; + } + + private List getGroup(String groupId, int[] ids, double[] scores) { + List documents = new ArrayList<>(); + int scoreId = 0; + for (int id : ids) { + documents.add(new MeasurableDocument(Document.custom() + .text(DOCS[id - 1]) + .id(String.valueOf(id)) + .metadata(new HashMap<>()) + .build(), scores[scoreId], groupId)); + scoreId++; + } + return documents; + } +} diff --git a/framework/fel/java/fel-core/src/test/java/modelengine/fel/core/document/support/TestRerankModelController.java b/framework/fel/java/fel-core/src/test/java/modelengine/fel/core/document/support/TestRerankModelController.java new file mode 100644 index 00000000..74e1fd23 --- /dev/null +++ b/framework/fel/java/fel-core/src/test/java/modelengine/fel/core/document/support/TestRerankModelController.java @@ -0,0 +1,55 @@ +/*--------------------------------------------------------------------------------------------- + * Copyright (c) 2025 Huawei Technologies Co., Ltd. All rights reserved. + * This file is a part of the ModelEngine Project. + * Licensed under the MIT License. See License.txt in the project root for license information. + *--------------------------------------------------------------------------------------------*/ + +package modelengine.fel.core.document.support; + +import modelengine.fit.http.annotation.PostMapping; +import modelengine.fitframework.annotation.Component; +import modelengine.fitframework.serialization.ObjectSerializer; + +/** + * 表示测试使用的 Rerank 接口。 + * + * @author 马朝阳 + * @since 2024-09-27 + */ +@Component +public class TestRerankModelController { + /** + * Rerank 接口失败调用端口。 + */ + public static final String FAIL_ENDPOINT = "/fail"; + + private final ObjectSerializer serializer; + + TestRerankModelController(ObjectSerializer serializer) { + this.serializer = serializer; + } + + /** + * 测试成功用 Rerank 接口。 + * + * @return 表示流式返回结果的 {@link String}。 + */ + @PostMapping(RerankApi.RERANK_ENDPOINT) + public RerankResponse rerankSuccess() { + String json = + "{\"results\":[{\"index\":3,\"relevance_score\":0.999071},{\"index\":4,\"relevance_score\":0.7867867}," + + "{\"index\":0,\"relevance_score\":0.32713068}]}"; + return this.serializer.deserialize(json, RerankResponse.class); + } + + /** + * 测试用 Rerank 接口。 + * + * @return 表示流式返回结果的 {@link String}。 + */ + @PostMapping(FAIL_ENDPOINT + RerankApi.RERANK_ENDPOINT) + public RerankResponse rerankFail() { + String json = "wrong json"; + return this.serializer.deserialize(json, RerankResponse.class); + } +} diff --git a/framework/fel/java/fel-core/src/test/java/modelengine/fel/core/splitter/support/SimpleTokenizer.java b/framework/fel/java/fel-core/src/test/java/modelengine/fel/core/splitter/support/SimpleTokenizer.java index b4489b55..48a29378 100644 --- a/framework/fel/java/fel-core/src/test/java/modelengine/fel/core/splitter/support/SimpleTokenizer.java +++ b/framework/fel/java/fel-core/src/test/java/modelengine/fel/core/splitter/support/SimpleTokenizer.java @@ -6,6 +6,8 @@ package modelengine.fel.core.splitter.support; +import static modelengine.fitframework.inspection.Validation.notNull; + import modelengine.fel.core.tokenizer.Tokenizer; import java.util.ArrayList; @@ -36,4 +38,10 @@ public String decode(List tokens) { } return new String(charArray); } + + @Override + public int countToken(String text) { + notNull(text, "Text cannot be null."); + return text.length(); + } } \ No newline at end of file diff --git a/framework/fel/java/fel-flow/src/main/java/modelengine/fel/engine/activities/AiStart.java b/framework/fel/java/fel-flow/src/main/java/modelengine/fel/engine/activities/AiStart.java index 8fc2f794..0bd7962c 100644 --- a/framework/fel/java/fel-flow/src/main/java/modelengine/fel/engine/activities/AiStart.java +++ b/framework/fel/java/fel-flow/src/main/java/modelengine/fel/engine/activities/AiStart.java @@ -12,6 +12,7 @@ import modelengine.fel.core.document.Content; import modelengine.fel.core.document.Document; import modelengine.fel.core.document.Measurable; +import modelengine.fel.core.model.BlockModel; import modelengine.fel.core.pattern.Parser; import modelengine.fel.core.pattern.Pattern; import modelengine.fel.core.pattern.PostProcessor; @@ -28,7 +29,9 @@ import modelengine.fel.engine.flows.Conversation; import modelengine.fel.engine.operators.models.FlowModel; import modelengine.fel.engine.operators.patterns.AbstractFlowPattern; +import modelengine.fel.engine.operators.patterns.FlowNodeSupportable; import modelengine.fel.engine.operators.patterns.FlowPattern; +import modelengine.fel.engine.operators.patterns.FlowSupportable; import modelengine.fel.engine.operators.patterns.SimpleFlowPattern; import modelengine.fel.engine.operators.prompts.PromptTemplate; import modelengine.fel.engine.util.AiFlowSession; @@ -46,7 +49,6 @@ import java.util.ArrayList; import java.util.Arrays; -import java.util.Collections; import java.util.List; import java.util.Optional; import java.util.function.Supplier; @@ -419,17 +421,30 @@ public AiState parse(Parser parser) { */ public AiState delegate(Pattern pattern) { Validation.notNull(pattern, "Pattern operator cannot be null."); - FlowPattern flowPattern = this.castFlowPattern(pattern); + return this.delegate(new SimpleFlowPattern<>(pattern)); + } + + /** + * 将数据委托给 {@link FlowPattern}{@code <}{@link O}{@code , }{@link R}{@code >} + * 处理,然后自身放弃处理数据。处理后的数据会发送回该节点,作为该节点的处理结果。 + * + * @param pattern 表示异步委托单元的 {@link FlowPattern}{@code <}{@link O}{@code , }{@link R}{@code >}。 + * @param 表示委托节点的输出数据类型。 + * @return 表示委托节点的 {@link AiState}{@code <}{@link R}{@code , }{@link D}{@code , }{@link O}{@code , + * }{@link RF}{@code , }{@link F}{@code >}。 + * @throws IllegalArgumentException 当 {@code pattern} 为 {@code null} 时。 + */ + public AiState delegate(FlowPattern pattern) { + Validation.notNull(pattern, "Pattern operator cannot be null."); Processor orProcessor = this.publisher().flatMap(input -> { - FlowEmitter cachedEmitter = FlowEmitter.from(flowPattern); - AiFlowSession.applyPattern(flowPattern, input.getData(), input.getSession()); - return Flows.source(cachedEmitter); + FlowEmitter emitter = AiFlowSession.applyPattern(pattern, input.getData(), input.getSession()); + return Flows.source(emitter); }, null); this.displayPatternProcessor(pattern, orProcessor); return new AiState<>(new State<>(orProcessor, this.flow().origin()), this.flow()); } - private void displayPatternProcessor(Pattern pattern, Processor processor) { + private void displayPatternProcessor(FlowPattern pattern, Processor processor) { if (pattern instanceof AbstractFlowPattern) { Flow originFlow = ObjectUtils.>cast(pattern).origin(); processor.displayAs("delegate to flow", originFlow, originFlow.start().getId()); @@ -477,13 +492,7 @@ public AiState delegate(Operators.ProcessMap operator) */ public AiState delegate(AiProcessFlow aiFlow) { Validation.notNull(aiFlow, "Flow cannot be null."); - Processor processor = this.publisher().map(input -> { - aiFlow.converse(input.getSession()).offer(input.getData()); - return (R) null; - }, null).displayAs("delegate to flow", aiFlow.origin(), aiFlow.origin().start().getId()); - AiState state = new AiState<>(new State<>(processor, this.flow().origin()), this.flow()); - state.offer(aiFlow); - return state; + return this.delegate(new FlowSupportable<>(aiFlow)); } /** @@ -503,14 +512,7 @@ public AiState delegate(AiProcessFlow aiFlow) { public AiState delegate(AiProcessFlow aiFlow, String nodeId) { Validation.notNull(aiFlow, "Flow cannot be null."); Validation.notBlank(nodeId, "Node id cannot be blank."); - Processor processor = this.publisher().map(input -> { - aiFlow.converse(input.getSession()).offer(nodeId, Collections.singletonList(input.getData())); - return (R) null; - }, null).displayAs("delegate to node", aiFlow.origin(), nodeId); - - AiState state = new AiState<>(new State<>(processor, this.flow().origin()), this.flow()); - state.offer(aiFlow); - return state; + return this.delegate(new FlowNodeSupportable<>(aiFlow, nodeId)); } /** @@ -531,6 +533,22 @@ public final AiState prompt(PromptTemplate... templates) }, null).displayAs("prompt"), this.flow().origin()), this.flow()); } + /** + * 生成大模型阻塞调用节点。 + * + * @param model 表示模型算子实现的 {@link BlockModel}{@code <}{@link M}{@code >}。 + * @param 表示模型节点的输入数据类型。 + * @return 表示大模型阻塞调用节点的 {@link AiState}{@code <}{@link ChatMessage}{@code , }{@link D}{@code , + * }{@link O}{@code , }{@link RF}{@code , }{@link F}{@code >}。 + * @throws IllegalArgumentException 当 {@code model} 为 {@code null} 时。 + */ + public AiState generate(BlockModel model) { + Validation.notNull(model, "Model operator cannot be null."); + return new AiState<>(new State<>(this.publisher() + .map(input -> AiFlowSession.applyPattern(model, input.getData(), input.getSession()), null) + .displayAs("generate"), this.flow().origin()), this.flow()); + } + /** * 生成大模型流式调用节点。 * diff --git a/framework/fel/java/fel-flow/src/main/java/modelengine/fel/engine/activities/AiState.java b/framework/fel/java/fel-flow/src/main/java/modelengine/fel/engine/activities/AiState.java index bb697ed5..1f175ac5 100644 --- a/framework/fel/java/fel-flow/src/main/java/modelengine/fel/engine/activities/AiState.java +++ b/framework/fel/java/fel-flow/src/main/java/modelengine/fel/engine/activities/AiState.java @@ -103,9 +103,12 @@ public Publisher publisher() { @Override public void register(EmitterListener handler) { - if (handler != null) { - this.state.register(handler); - } + this.state.register(handler); + } + + @Override + public void unregister(EmitterListener listener) { + this.state.unregister(listener); } @Override diff --git a/framework/fel/java/fel-flow/src/main/java/modelengine/fel/engine/flows/AiFlows.java b/framework/fel/java/fel-flow/src/main/java/modelengine/fel/engine/flows/AiFlows.java index fff32fc5..49ace2aa 100644 --- a/framework/fel/java/fel-flow/src/main/java/modelengine/fel/engine/flows/AiFlows.java +++ b/framework/fel/java/fel-flow/src/main/java/modelengine/fel/engine/flows/AiFlows.java @@ -8,6 +8,8 @@ import modelengine.fel.engine.activities.AiDataStart; import modelengine.fel.engine.activities.AiStart; +import modelengine.fit.waterflow.domain.context.FlowSession; +import modelengine.fit.waterflow.domain.emitters.Emitter; import modelengine.fit.waterflow.domain.flow.Flows; import modelengine.fit.waterflow.domain.flow.ProcessFlow; import modelengine.fit.waterflow.domain.states.Start; @@ -56,4 +58,16 @@ public static AiDataStart flux(D... data) { AiStart, AiProcessFlow> start = AiFlows.create(); return new AiDataStart<>(start, data); } + + /** + * 通过指定的发射源来构造一个数据前置流。 + * + * @param emitter 表示数据源的 {@link Emitter}{@code <}{@link D}{@code , }{@link FlowSession}{@code >}。 + * @param 表示数据类型。 + * @return 表示数据前置流的 {@link AiDataStart}{@code <}{@link D}{@code , }{@link D}{@code , }{@link D}{@code >}。 + */ + public static AiDataStart source(Emitter emitter) { + AiStart, AiProcessFlow> start = AiFlows.create(); + return new AiDataStart<>(start, emitter); + } } diff --git a/framework/fel/java/fel-flow/src/main/java/modelengine/fel/engine/flows/AiProcessFlow.java b/framework/fel/java/fel-flow/src/main/java/modelengine/fel/engine/flows/AiProcessFlow.java index 40e5abd9..c7d132e9 100644 --- a/framework/fel/java/fel-flow/src/main/java/modelengine/fel/engine/flows/AiProcessFlow.java +++ b/framework/fel/java/fel-flow/src/main/java/modelengine/fel/engine/flows/AiProcessFlow.java @@ -16,6 +16,9 @@ import modelengine.fit.waterflow.domain.stream.reactive.Publisher; import modelengine.fitframework.util.ObjectUtils; +import java.util.Map; +import java.util.concurrent.ConcurrentHashMap; + /** * AI 数据处理流程,在 {@link AiFlow} 的基础上增加流程间的数据流转能力,并对外提供对话语义。 * @@ -26,6 +29,9 @@ */ public class AiProcessFlow extends AiFlow> implements EmitterListener, Emitter { + private final Map, EmitterListener> listeners = + new ConcurrentHashMap<>(); + public AiProcessFlow(ProcessFlow flow) { super(flow); } @@ -38,7 +44,18 @@ public void handle(D data, FlowSession session) { @Override public void register(EmitterListener listener) { if (listener != null) { - this.origin().register((data, token) -> listener.handle(ObjectUtils.cast(data), new FlowSession(token))); + EmitterListener wrapperHandler = + (data, session) -> listener.handle(ObjectUtils.cast(data), session); + this.listeners.put(listener, wrapperHandler); + this.origin().register(wrapperHandler); + } + } + + @Override + public void unregister(EmitterListener listener) { + EmitterListener target = this.listeners.remove(listener); + if (target != null) { + this.origin().unregister(target); } } diff --git a/framework/fel/java/fel-flow/src/main/java/modelengine/fel/engine/flows/Conversation.java b/framework/fel/java/fel-flow/src/main/java/modelengine/fel/engine/flows/Conversation.java index 1a741843..48703e5e 100644 --- a/framework/fel/java/fel-flow/src/main/java/modelengine/fel/engine/flows/Conversation.java +++ b/framework/fel/java/fel-flow/src/main/java/modelengine/fel/engine/flows/Conversation.java @@ -6,10 +6,13 @@ package modelengine.fel.engine.flows; +import modelengine.fel.core.chat.ChatMessage; import modelengine.fel.core.chat.ChatOption; import modelengine.fel.core.memory.Memory; import modelengine.fel.engine.activities.AiStart; import modelengine.fel.engine.activities.FlowCallBack; +import modelengine.fel.engine.operators.models.StreamingConsumer; +import modelengine.fel.engine.operators.sources.Source; import modelengine.fel.engine.util.StateKey; import modelengine.fit.waterflow.domain.context.FlowSession; import modelengine.fit.waterflow.domain.stream.operators.Operators; @@ -45,7 +48,7 @@ public class Conversation { public Conversation(AiProcessFlow flow, FlowSession session) { this.flow = Validation.notNull(flow, "Flow cannot be null."); this.session = - (session == null) ? this.setConverseListener(new FlowSession()) : this.setSubConverseListener(session); + (session == null) ? this.setConverseListener(new FlowSession(true)) : this.setSubConverseListener(session); this.session.begin(); this.callBackBuilder = FlowCallBack.builder(); } @@ -60,7 +63,8 @@ public Conversation(AiProcessFlow flow, FlowSession session) { @SafeVarargs public final ConverseLatch offer(D... data) { ConverseLatch latch = setListener(this.flow); - FlowSession newSession = new FlowSession(this.session); + FlowSession newSession = FlowSession.newRootSession(this.session, this.session.preserved()); + newSession.getWindow().setFrom(null); this.flow.start().offer(data, newSession); newSession.getWindow().complete(); return latch; @@ -79,6 +83,7 @@ public ConverseLatch offer(String nodeId, List data) { Validation.notBlank(nodeId, "invalid nodeId."); ConverseLatch latch = setListener(this.flow); FlowSession newSession = new FlowSession(this.session); + newSession.getWindow().setFrom(null); this.flow.origin().offer(nodeId, data.toArray(new Object[0]), newSession); newSession.getWindow().complete(); return latch; @@ -110,6 +115,20 @@ public Conversation bind(Memory memory) { return this; } + /** + * 绑定流式响应信息消费者到对话上下文,用于消费流程流转过程中的流式信息。 + * + * @param consumer 表示流式响应信息消费者的 {@link StreamingConsumer}{@code <}{@link ChatMessage}{@code , + * }{@link ChatMessage}{@code >}。 + * @return 表示绑定了流式响应信息消费者的对话对象的 {@link Conversation}{@code <}{@link D}{@code , }{@link R}{@code >}。 + * @throws IllegalArgumentException 当 {@code consumer} 为 {@code null} 时。 + */ + public Conversation bind(StreamingConsumer consumer) { + Validation.notNull(consumer, "Streaming consumer cannot be null."); + this.session.setInnerState(StateKey.STREAMING_CONSUMER, consumer); + return this; + } + /** * 绑定自定义参数到对话上下文,后续可以在流程中的如下节点获取: *
    diff --git a/framework/fel/java/fel-flow/src/main/java/modelengine/fel/engine/operators/models/ChatBlockModel.java b/framework/fel/java/fel-flow/src/main/java/modelengine/fel/engine/operators/models/ChatBlockModel.java new file mode 100644 index 00000000..84da4294 --- /dev/null +++ b/framework/fel/java/fel-flow/src/main/java/modelengine/fel/engine/operators/models/ChatBlockModel.java @@ -0,0 +1,90 @@ +/*--------------------------------------------------------------------------------------------- + * Copyright (c) 2025 Huawei Technologies Co., Ltd. All rights reserved. + * This file is a part of the ModelEngine Project. + * Licensed under the MIT License. See License.txt in the project root for license information. + *--------------------------------------------------------------------------------------------*/ + +package modelengine.fel.engine.operators.models; + +import modelengine.fel.core.chat.ChatMessage; +import modelengine.fel.core.chat.ChatOption; +import modelengine.fel.core.chat.MessageType; +import modelengine.fel.core.chat.Prompt; +import modelengine.fel.core.chat.support.AiMessage; +import modelengine.fel.core.chat.ChatModel; +import modelengine.fel.core.memory.Memory; +import modelengine.fel.core.model.BlockModel; +import modelengine.fel.engine.util.AiFlowSession; +import modelengine.fel.engine.util.StateKey; +import modelengine.fit.waterflow.domain.context.FlowSession; +import modelengine.fitframework.inspection.Validation; + +import java.util.List; +import java.util.Optional; + +/** + * 阻塞对话模型实现。 + * + * @author 刘信宏 + * @since 2024-04-16 + */ +public class ChatBlockModel implements BlockModel { + private final ChatModel provider; + private final ChatOption option; + + public ChatBlockModel(ChatModel provider) { + this(provider, ChatOption.custom().build()); + } + + public ChatBlockModel(ChatModel provider, ChatOption option) { + this.provider = Validation.notNull(provider, "The model provider cannot be null."); + this.option = Validation.notNull(option, "The chat options cannot be null."); + } + + @Override + public ChatMessage invoke(Prompt input) { + Validation.notNull(input, "The model input data cannot be null."); + ChatOption dynamicOptions = AiFlowSession.get() + .map(state -> state.getInnerState(StateKey.CHAT_OPTION)) + .orElse(this.option); + List chatMessages = this.provider.generate(input, dynamicOptions).blockAll(); + Validation.notEmpty(chatMessages, "The model chat messages can not be empty."); + ChatMessage message = chatMessages.get(0); + Validation.equals(message.type(), + MessageType.AI, + "The message type must be {0}. [actualMessageType={1}]", + MessageType.AI, + message.type()); + AiMessage answer = new AiMessage(message.text(), message.toolCalls()); + this.updateMemory(answer); + return answer; + } + + /** + * 绑定模型超参数。 + * + * @param option 表示模型超参数的 {@link ChatOption}。 + * @return 表示绑定了超参数的 {@link ChatBlockModel}。 + * @throws IllegalArgumentException 当 {@code options} 为 {@code null} 时。 + */ + public ChatBlockModel bind(ChatOption option) { + Validation.notNull(option, "The chat options cannot be null."); + return new ChatBlockModel(this.provider, option); + } + + private void updateMemory(AiMessage answer) { + if (answer.isToolCall()) { + return; + } + Optional session = AiFlowSession.get(); + if (!session.isPresent()) { + return; + } + Memory memory = session.get().getInnerState(StateKey.HISTORY); + if (memory == null) { + return; + } + memory.add(session.get().getInnerState(StateKey.HISTORY_INPUT)); + memory.add(answer); + } +} diff --git a/framework/fel/java/fel-flow/src/main/java/modelengine/fel/engine/operators/models/ChatChunk.java b/framework/fel/java/fel-flow/src/main/java/modelengine/fel/engine/operators/models/ChatChunk.java new file mode 100644 index 00000000..e75e21e1 --- /dev/null +++ b/framework/fel/java/fel-flow/src/main/java/modelengine/fel/engine/operators/models/ChatChunk.java @@ -0,0 +1,103 @@ +/*--------------------------------------------------------------------------------------------- + * Copyright (c) 2025 Huawei Technologies Co., Ltd. All rights reserved. + * This file is a part of the ModelEngine Project. + * Licensed under the MIT License. See License.txt in the project root for license information. + *--------------------------------------------------------------------------------------------*/ + +package modelengine.fel.engine.operators.models; + +import modelengine.fel.core.chat.ChatMessage; +import modelengine.fel.core.chat.MessageType; +import modelengine.fel.core.tool.ToolCall; +import modelengine.fel.core.tool.ToolCallChunk; +import modelengine.fel.core.tool.support.DefaultToolCallChunk; +import modelengine.fitframework.inspection.Validation; +import modelengine.fitframework.util.CollectionUtils; +import modelengine.fitframework.util.ObjectUtils; +import modelengine.fitframework.util.StringUtils; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; +import java.util.Objects; + +/** + * 大模型流式响应内容片段。 + * + * @author 刘信宏 + * @since 2024-05-16 + */ +public class ChatChunk implements ChatMessage { + private final StringBuilder text = new StringBuilder(); + private final List toolCalls = new ArrayList<>(); + + public ChatChunk() { + } + /** + * 使用文本数据、媒体数据和工具请求初始化 {@link ChatChunk}。 + * + * @param text 表示字符串数据的 {@link String}。 + * @param toolCalls 表示工具请求的 {@link List}{@code <}{@link ToolCall}{@code >}。 + */ + public ChatChunk(String text, List toolCalls) { + this.text.append(ObjectUtils.nullIf(text, StringUtils.EMPTY)); + this.toolCalls.addAll(ObjectUtils.getIfNull(toolCalls, Collections::emptyList)); + } + + /** + * 聚合流式响应内容片段。 + * + * @param message 表示大模型流式响应内容片段的 {@link ChatMessage}。 + */ + public void merge(ChatMessage message) { + Validation.notNull(message, "Chat message can not be null."); + this.merge(message.text(), message.toolCalls()); + } + + @Override + public MessageType type() { + return MessageType.AI; + } + + @Override + public String text() { + return this.text.toString(); + } + + @Override + public List toolCalls() { + return this.toolCalls; + } + + @Override + public String toString() { + String textVal = this.toolCalls.isEmpty() ? this.text() : this.toolCalls.toString(); + return this.type().getRole() + ": " + textVal; + } + + /** + * 合并文本数据、媒体数据和工具请求。 + * + * @param text 表示字符串数据的 {@link String}。 + * @param toolCalls 表示工具请求的 {@link List}{@code <}{@link ToolCall}{@code >}。 + */ + private void merge(String text, List toolCalls) { + this.text.append(ObjectUtils.nullIf(text, StringUtils.EMPTY)); + if (CollectionUtils.isEmpty(toolCalls)) { + return; + } + toolCalls.stream().filter(Objects::nonNull).forEach(toolCall -> { + if (StringUtils.isNotBlank(toolCall.id())) { + this.toolCalls.add(new DefaultToolCallChunk(toolCall)); + return; + } + if (toolCall.index() == null || this.toolCalls.size() <= toolCall.index()) { + return; + } + ToolCall tarToolCall = this.toolCalls.get(toolCall.index()); + if (tarToolCall instanceof ToolCallChunk) { + ObjectUtils.cast(tarToolCall).merge(toolCall); + } + }); + } +} diff --git a/framework/fel/java/fel-flow/src/main/java/modelengine/fel/engine/operators/models/ChatFlowModel.java b/framework/fel/java/fel-flow/src/main/java/modelengine/fel/engine/operators/models/ChatFlowModel.java index c4e664a4..1a00e827 100644 --- a/framework/fel/java/fel-flow/src/main/java/modelengine/fel/engine/operators/models/ChatFlowModel.java +++ b/framework/fel/java/fel-flow/src/main/java/modelengine/fel/engine/operators/models/ChatFlowModel.java @@ -33,7 +33,7 @@ public class ChatFlowModel implements FlowModel { public ChatFlowModel(ChatModel chatModel, ChatOption option) { this.chatModel = notNull(chatModel, "The model provider can not be null."); - this.option = notNull(option, "The chat options can not be null."); + this.option = option; } /** @@ -54,7 +54,8 @@ public FitBoundedEmitter invoke(Prompt input) { FlowSession session = AiFlowSession.get().orElseThrow(() -> new IllegalStateException("The ai session cannot be empty.")); ChatOption dynamicOption = nullIf(session.getInnerState(StateKey.CHAT_OPTION), this.option); + notNull(dynamicOption, "The chat options can not be null."); Choir choir = ObjectUtils.cast(this.chatModel.generate(input, dynamicOption)); - return new LlmEmitter<>(choir); + return new LlmEmitter<>(choir, input, session); } } diff --git a/framework/fel/java/fel-flow/src/main/java/modelengine/fel/engine/operators/models/LlmEmitter.java b/framework/fel/java/fel-flow/src/main/java/modelengine/fel/engine/operators/models/LlmEmitter.java index 4b93424b..c13d0f67 100644 --- a/framework/fel/java/fel-flow/src/main/java/modelengine/fel/engine/operators/models/LlmEmitter.java +++ b/framework/fel/java/fel-flow/src/main/java/modelengine/fel/engine/operators/models/LlmEmitter.java @@ -7,8 +7,13 @@ package modelengine.fel.engine.operators.models; import modelengine.fel.core.chat.ChatMessage; +import modelengine.fel.core.chat.Prompt; +import modelengine.fel.engine.util.StateKey; import modelengine.fit.waterflow.bridge.fitflow.FitBoundedEmitter; +import modelengine.fit.waterflow.domain.context.FlowSession; import modelengine.fitframework.flowable.Publisher; +import modelengine.fitframework.inspection.Validation; +import modelengine.fitframework.util.ObjectUtils; /** * 流式模型发射器。 @@ -17,12 +22,28 @@ * @since 2024-05-16 */ public class LlmEmitter extends FitBoundedEmitter { + private static final StreamingConsumer EMPTY_CONSUMER = (acc, chunk) -> {}; + + private final ChatChunk chunkAcc = new ChatChunk(); + private final StreamingConsumer consumer; + /** * 初始化 {@link LlmEmitter}。 * * @param publisher 表示数据发布者的 {@link Publisher}{@code <}{@link O}{@code >}。 + * @param prompt 表示模型输入的 {@link Prompt}, 用于获取默认用户问题。 + * @param session 表示流程实例运行标识的 {@link FlowSession}。 */ - public LlmEmitter(Publisher publisher) { + public LlmEmitter(Publisher publisher, Prompt prompt, FlowSession session) { super(publisher, data -> data); + Validation.notNull(session, "The session cannot be null."); + this.consumer = ObjectUtils.nullIf(session.getInnerState(StateKey.STREAMING_CONSUMER), EMPTY_CONSUMER); + } + + @Override + public void emit(ChatMessage data, FlowSession trans) { + super.emit(data, this.flowSession); + this.chunkAcc.merge(data); + this.consumer.accept(this.chunkAcc, data); } } diff --git a/framework/fel/java/fel-flow/src/main/java/modelengine/fel/engine/operators/models/StreamingConsumer.java b/framework/fel/java/fel-flow/src/main/java/modelengine/fel/engine/operators/models/StreamingConsumer.java new file mode 100644 index 00000000..73170cd8 --- /dev/null +++ b/framework/fel/java/fel-flow/src/main/java/modelengine/fel/engine/operators/models/StreamingConsumer.java @@ -0,0 +1,24 @@ +/*--------------------------------------------------------------------------------------------- + * Copyright (c) 2025 Huawei Technologies Co., Ltd. All rights reserved. + * This file is a part of the ModelEngine Project. + * Licensed under the MIT License. See License.txt in the project root for license information. + *--------------------------------------------------------------------------------------------*/ + +package modelengine.fel.engine.operators.models; + +/** + * 流式响应信息消费者。 + * + * @author 刘信宏 + * @since 2024-05-17 + */ +@FunctionalInterface +public interface StreamingConsumer { + /** + * 消费流式响应数据。 + * + * @param acc 表示聚合信息的 {@link T}。 + * @param chunk 表示单次流式响应信息 {@link U}。 + */ + void accept(T acc, U chunk); +} diff --git a/framework/fel/java/fel-flow/src/main/java/modelengine/fel/engine/operators/patterns/AbstractAgent.java b/framework/fel/java/fel-flow/src/main/java/modelengine/fel/engine/operators/patterns/AbstractAgent.java index eb1cf831..f12bcce1 100644 --- a/framework/fel/java/fel-flow/src/main/java/modelengine/fel/engine/operators/patterns/AbstractAgent.java +++ b/framework/fel/java/fel-flow/src/main/java/modelengine/fel/engine/operators/patterns/AbstractAgent.java @@ -9,6 +9,7 @@ import static modelengine.fitframework.inspection.Validation.notBlank; import static modelengine.fitframework.inspection.Validation.notNull; +import lombok.Getter; import modelengine.fel.core.chat.ChatMessage; import modelengine.fel.core.chat.Prompt; import modelengine.fel.core.chat.support.AiMessage; @@ -62,13 +63,23 @@ protected AbstractAgent(ChatFlowModel flowModel, String memoryId) { this.memoryId = notBlank(memoryId, "The agent message key cannot be blank."); } + /** + * 获取配置的模型对象。 + * + * @return 配置的模型对象。 + */ + public ChatFlowModel getModel() { + return model; + } + /** * 执行工具调用。 * * @param toolCalls 表示工具调用的 {@link List}{@code <}{@link ToolCall}{@code >}。 + * @param ctx 表示工具调用上下文的 {@link StateContext}。 * @return 表示工具调用结果的 {@link Prompt}。 */ - protected abstract Prompt doToolCall(List toolCalls); + protected abstract Prompt doToolCall(List toolCalls, StateContext ctx); @Override protected AiProcessFlow buildFlow() { @@ -93,6 +104,6 @@ protected AiProcessFlow buildFlow() { private void handleTool(ChatMessage message, StateContext ctx) { ChatMessages lastRequest = ctx.getState(this.memoryId); lastRequest.add(message); - lastRequest.addAll(this.doToolCall(message.toolCalls()).messages()); + lastRequest.addAll(this.doToolCall(message.toolCalls(), ctx).messages()); } } diff --git a/framework/fel/java/fel-flow/src/main/java/modelengine/fel/engine/operators/patterns/AbstractFlowPattern.java b/framework/fel/java/fel-flow/src/main/java/modelengine/fel/engine/operators/patterns/AbstractFlowPattern.java index 4fe00642..ec475432 100644 --- a/framework/fel/java/fel-flow/src/main/java/modelengine/fel/engine/operators/patterns/AbstractFlowPattern.java +++ b/framework/fel/java/fel-flow/src/main/java/modelengine/fel/engine/operators/patterns/AbstractFlowPattern.java @@ -13,9 +13,11 @@ import modelengine.fit.waterflow.domain.context.FlowSession; import modelengine.fit.waterflow.domain.context.Window; import modelengine.fit.waterflow.domain.emitters.EmitterListener; +import modelengine.fit.waterflow.domain.emitters.FlowEmitter; import modelengine.fit.waterflow.domain.flow.Flow; import modelengine.fitframework.inspection.Validation; import modelengine.fitframework.util.LazyLoader; +import modelengine.fitframework.util.ObjectUtils; /** * 流程委托单元。 @@ -24,10 +26,25 @@ * @since 2024-06-04 */ public abstract class AbstractFlowPattern implements FlowPattern { + private static final String RESULT_ACTION_KEY = "resultAction"; + private static final String PARENT_SESSION_ID_KEY = "parentSessionId"; + private final LazyLoader> flowSupplier; + private final EmitterListener dataDispatcher = (data, session) -> { + Object rawResultAction = session.getInnerState(RESULT_ACTION_KEY); + if (rawResultAction == null) { + return; + } + ResultAction resultAction = ObjectUtils.cast(rawResultAction); + resultAction.process(data, session); + }; protected AbstractFlowPattern() { - this.flowSupplier = LazyLoader.of(this::buildFlow); + this.flowSupplier = LazyLoader.of(() -> { + AiProcessFlow flow = buildFlow(); + flow.register(this.dataDispatcher); + return flow; + }); } /** @@ -39,21 +56,25 @@ protected AbstractFlowPattern() { @Override public void register(EmitterListener handler) { - if (handler != null) { - this.getFlow().register(handler); - } + this.getFlow().register(handler); + } + + @Override + public void unregister(EmitterListener listener) { + this.getFlow().unregister(listener); } @Override public void emit(O data, FlowSession session) { - FlowSession flowSession = new FlowSession(session); - this.getFlow().emit(data, flowSession); + this.getFlow().emit(data, session); } @Override - public O invoke(I data) { - this.getFlow().converse(AiFlowSession.require()).offer(data); - return null; + public FlowEmitter invoke(I data) { + FlowEmitter emitter = new FlowEmitter.AutoCompleteEmitter<>(); + FlowSession flowSession = buildFlowSession(emitter); + this.getFlow().converse(flowSession).offer(data); + return emitter; } /** @@ -65,7 +86,7 @@ public O invoke(I data) { public Pattern sync() { return new SimplePattern<>(data -> { FlowSession require = AiFlowSession.require(); - FlowSession session = new FlowSession(); + FlowSession session = new FlowSession(true); Window window = session.begin(); session.copySessionState(require); ConverseLatch conversation = this.getFlow().converse(session).offer(data); @@ -83,7 +104,39 @@ public Flow origin() { return this.getFlow().origin(); } + /** + * Built the flow session for starting the conversation. + * + * @param emitter The {@link FlowEmitter}{@code <}{@link O}{@code >} representing output emitter. + * @return The new {@link FlowSession}. + * @param The output data type. + */ + protected static FlowSession buildFlowSession(FlowEmitter emitter) { + FlowSession mainSession = AiFlowSession.require(); + FlowSession flowSession = FlowSession.newRootSession(mainSession, true); + flowSession.setInnerState(PARENT_SESSION_ID_KEY, mainSession.getId()); + ResultAction resultAction = emitter::emit; + flowSession.setInnerState(RESULT_ACTION_KEY, resultAction); + return flowSession; + } + private AiProcessFlow getFlow() { return Validation.notNull(this.flowSupplier.get(), "The flow cannot be null."); } + + /** + * A functional interface defining an action to be performed with processed results. + * Implementations handle both the result data and its associated flow session context. + * + * @param The type of result data to be processed. + */ + protected interface ResultAction { + /** + * Process the result. + * + * @param data The result of {@link O}. + * @param flowSession The result flow session of {@link FlowSession}. + */ + void process(O data, FlowSession flowSession); + } } diff --git a/framework/fel/java/fel-flow/src/main/java/modelengine/fel/engine/operators/patterns/FlowNodeSupportable.java b/framework/fel/java/fel-flow/src/main/java/modelengine/fel/engine/operators/patterns/FlowNodeSupportable.java new file mode 100644 index 00000000..70f61528 --- /dev/null +++ b/framework/fel/java/fel-flow/src/main/java/modelengine/fel/engine/operators/patterns/FlowNodeSupportable.java @@ -0,0 +1,52 @@ +/*--------------------------------------------------------------------------------------------- + * Copyright (c) 2024 Huawei Technologies Co., Ltd. All rights reserved. + * This file is a part of the ModelEngine Project. + * Licensed under the MIT License. See License.txt in the project root for license information. + *--------------------------------------------------------------------------------------------*/ + +package modelengine.fel.engine.operators.patterns; + +import modelengine.fel.engine.flows.AiProcessFlow; +import modelengine.fit.waterflow.domain.context.FlowSession; +import modelengine.fit.waterflow.domain.emitters.FlowEmitter; +import modelengine.fitframework.inspection.Validation; + +import java.util.Collections; + +/** + * 指定流程节点的异步委托单元的流程实现。 + * + * @param 表示输入数据的类型。 + * @param 表示流程处理完成的数据类型。 + * @author 宋永坦 + * @since 2025-05-16 + */ +public class FlowNodeSupportable extends AbstractFlowPattern { + private final AiProcessFlow flow; + private final String nodeId; + + /** + * 通过 AI 流程初始化 {@link FlowNodeSupportable}{@code <}{@link I}{@code , }{@link O}{@code >}。 + * + * @param flow 表示 AI 流程的 {@link AiProcessFlow}{@code <}{@link I}{@code , }{@link O}{@code >}。 + * @param nodeId 表示流程节点标识的 {@link String}。 + * @throws IllegalArgumentException 当 {@code flow} 为 {@code null} 时。 + */ + public FlowNodeSupportable(AiProcessFlow flow, String nodeId) { + this.flow = Validation.notNull(flow, "The flow cannot be null."); + this.nodeId = Validation.notBlank(nodeId, "The node id cannot be null."); + } + + @Override + protected AiProcessFlow buildFlow() { + return this.flow; + } + + @Override + public FlowEmitter invoke(I data) { + FlowEmitter emitter = new FlowEmitter.AutoCompleteEmitter<>(); + FlowSession flowSession = buildFlowSession(emitter); + this.flow.converse(flowSession).offer(this.nodeId, Collections.singletonList(data)); + return emitter; + } +} diff --git a/framework/fel/java/fel-flow/src/main/java/modelengine/fel/engine/operators/patterns/FlowPattern.java b/framework/fel/java/fel-flow/src/main/java/modelengine/fel/engine/operators/patterns/FlowPattern.java index 7a8d787f..dbd4199f 100644 --- a/framework/fel/java/fel-flow/src/main/java/modelengine/fel/engine/operators/patterns/FlowPattern.java +++ b/framework/fel/java/fel-flow/src/main/java/modelengine/fel/engine/operators/patterns/FlowPattern.java @@ -9,6 +9,7 @@ import modelengine.fel.core.pattern.Pattern; import modelengine.fit.waterflow.domain.context.FlowSession; import modelengine.fit.waterflow.domain.emitters.Emitter; +import modelengine.fit.waterflow.domain.emitters.FlowEmitter; /** * 流程委托单元。 @@ -18,4 +19,4 @@ * @author 刘信宏 * @since 2024-04-22 */ -public interface FlowPattern extends Pattern, Emitter {} +public interface FlowPattern extends Pattern>, Emitter {} diff --git a/framework/fel/java/fel-flow/src/main/java/modelengine/fel/engine/operators/patterns/SimpleFlowPattern.java b/framework/fel/java/fel-flow/src/main/java/modelengine/fel/engine/operators/patterns/SimpleFlowPattern.java index cba97274..4b42ca40 100644 --- a/framework/fel/java/fel-flow/src/main/java/modelengine/fel/engine/operators/patterns/SimpleFlowPattern.java +++ b/framework/fel/java/fel-flow/src/main/java/modelengine/fel/engine/operators/patterns/SimpleFlowPattern.java @@ -10,6 +10,7 @@ import modelengine.fel.engine.util.AiFlowSession; import modelengine.fit.waterflow.domain.context.FlowSession; import modelengine.fit.waterflow.domain.emitters.EmitterListener; +import modelengine.fit.waterflow.domain.emitters.FlowEmitter; import modelengine.fit.waterflow.domain.stream.operators.Operators; import modelengine.fitframework.inspection.Validation; import modelengine.fitframework.util.ObjectUtils; @@ -25,7 +26,7 @@ * @since 2024-04-22 */ public class SimpleFlowPattern implements FlowPattern { - private EmitterListener handler; + private final FlowEmitter emitter = new FlowEmitter<>(); private final Operators.ProcessMap processor; /** @@ -35,38 +36,39 @@ public class SimpleFlowPattern implements FlowPattern { * @throws IllegalArgumentException 当 {@code processor} 为 {@code null} 时。 */ public SimpleFlowPattern(Operators.ProcessMap processor) { - this(processor, null); + this.processor = Validation.notNull(processor, "The processor cannot be null."); } + /** + * 使用委托单元初始化 {@link SimpleFlowPattern}{@code <}{@link I}{@code , }{@link O}{@code >}。 + * + * @param pattern 表示委托单元的 {@link Pattern}{@code <}{@link I}{@code , }{@link O}{@code >}。 + * @throws IllegalArgumentException 当 {@code processor} 为 {@code null} 时。 + */ public SimpleFlowPattern(Pattern pattern) { - this((data, ctx) -> AiFlowSession.applyPattern(pattern, data, ObjectUtils.cast(ctx)), null); - } - - private SimpleFlowPattern(Operators.ProcessMap processor, EmitterListener handler) { - this.processor = Validation.notNull(processor, "The processor cannot be null."); - this.handler = handler; + this((data, ctx) -> AiFlowSession.applyPattern(pattern, data, ObjectUtils.cast(ctx))); } @Override - public O invoke(I data) { + public FlowEmitter invoke(I data) { FlowSession session = AiFlowSession.require(); - this.emit(this.processor.process(data, session), session); - session.getWindow().complete(); - return null; + this.emitter.emit(this.processor.process(data, session)); + this.emitter.complete(); + return this.emitter; } @Override public void register(EmitterListener handler) { - if (handler != null) { - this.handler = handler; - } + this.emitter.register(handler); + } + + @Override + public void unregister(EmitterListener handler) { + this.emitter.unregister(handler); } @Override public void emit(O data, FlowSession session) { - if (this.handler == null) { - return; - } - this.handler.handle(data, session); + this.emitter.emit(data, session); } } diff --git a/framework/fel/java/fel-flow/src/main/java/modelengine/fel/engine/operators/patterns/support/DefaultAgent.java b/framework/fel/java/fel-flow/src/main/java/modelengine/fel/engine/operators/patterns/support/DefaultAgent.java index af40e238..1ca3742e 100644 --- a/framework/fel/java/fel-flow/src/main/java/modelengine/fel/engine/operators/patterns/support/DefaultAgent.java +++ b/framework/fel/java/fel-flow/src/main/java/modelengine/fel/engine/operators/patterns/support/DefaultAgent.java @@ -17,6 +17,7 @@ import modelengine.fel.engine.operators.models.ChatFlowModel; import modelengine.fel.engine.operators.patterns.AbstractAgent; import modelengine.fel.tool.service.ToolExecuteService; +import modelengine.fit.waterflow.domain.context.StateContext; import java.util.List; import java.util.stream.Collectors; @@ -38,7 +39,7 @@ public DefaultAgent(ChatFlowModel flowModel, String namespace, ToolExecuteServic } @Override - protected Prompt doToolCall(List toolCalls) { + protected Prompt doToolCall(List toolCalls, StateContext ctx) { return toolCalls.stream().map(toolCall -> { String text = this.toolExecuteService.execute(this.namespace, toolCall.name(), toolCall.arguments()); return (ChatMessage) new ToolMessage(toolCall.id(), text); diff --git a/framework/fel/java/fel-flow/src/main/java/modelengine/fel/engine/operators/sources/Source.java b/framework/fel/java/fel-flow/src/main/java/modelengine/fel/engine/operators/sources/Source.java new file mode 100644 index 00000000..57e0e260 --- /dev/null +++ b/framework/fel/java/fel-flow/src/main/java/modelengine/fel/engine/operators/sources/Source.java @@ -0,0 +1,57 @@ +/*--------------------------------------------------------------------------------------------- + * Copyright (c) 2025 Huawei Technologies Co., Ltd. All rights reserved. + * This file is a part of the ModelEngine Project. + * Licensed under the MIT License. See License.txt in the project root for license information. + *--------------------------------------------------------------------------------------------*/ + +package modelengine.fel.engine.operators.sources; + +import modelengine.fit.waterflow.domain.context.FlowSession; +import modelengine.fit.waterflow.domain.emitters.Emitter; +import modelengine.fit.waterflow.domain.emitters.EmitterListener; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.SynchronousQueue; +import java.util.concurrent.ThreadPoolExecutor; +import java.util.concurrent.TimeUnit; + +/** + * 数据发射源实现,支持注册多个监听器。仅支持主动逐个发射数据。 + * + * @param 表示数据源的业务数据类型。 + * @author 刘信宏 + * @since 2024-04-28 + */ +public class Source implements Emitter { + private static final Integer MAXIMUM_POOL_SIZE = 50; + private static final ExecutorService THREAD_POOL = new ThreadPoolExecutor(0, MAXIMUM_POOL_SIZE, + 60L, TimeUnit.SECONDS, new SynchronousQueue<>()); + + private final List> listeners = Collections.synchronizedList(new ArrayList<>()); + + /** + * 注销监听器。 + * + * @param listener 监听器。 + */ + public void unregister(EmitterListener listener) { + if (listener != null) { + this.listeners.remove(listener); + } + } + + @Override + public void register(EmitterListener listener) { + if (listener != null) { + this.listeners.add(listener); + } + } + + @Override + public void emit(T data, FlowSession session) { + this.listeners.forEach(handler -> THREAD_POOL.execute(() -> handler.handle(data, session))); + } +} diff --git a/framework/fel/java/fel-flow/src/main/java/modelengine/fel/engine/util/StateKey.java b/framework/fel/java/fel-flow/src/main/java/modelengine/fel/engine/util/StateKey.java index 80249d16..69d5634c 100644 --- a/framework/fel/java/fel-flow/src/main/java/modelengine/fel/engine/util/StateKey.java +++ b/framework/fel/java/fel-flow/src/main/java/modelengine/fel/engine/util/StateKey.java @@ -18,6 +18,11 @@ public interface StateKey { */ String HISTORY = "history"; + /** + * 表示用户原始问题的键。 + */ + String HISTORY_INPUT = "history_input"; + /** * 表示模型超参数的键。 */ @@ -27,4 +32,19 @@ public interface StateKey { * 表示流程对话监听器的键。 */ String CONVERSE_LISTENER = "converse_listener"; + + /** + * 表示流式响应信息消费者的键。 + */ + String STREAMING_CONSUMER = "streaming_consumer"; + + /** + * 表示流式模型节点处理器。 + */ + String STREAMING_PROCESSOR = "streaming_processor"; + + /** + * 表示流式模型节点处理器。 + */ + String STREAMING_FLOW_CONTEXT = "streaming_flow_context"; } diff --git a/framework/fel/java/fel-flow/src/test/java/modelengine/fel/engine/operators/ModelTest.java b/framework/fel/java/fel-flow/src/test/java/modelengine/fel/engine/operators/ModelTest.java index a396fea4..35b0a168 100644 --- a/framework/fel/java/fel-flow/src/test/java/modelengine/fel/engine/operators/ModelTest.java +++ b/framework/fel/java/fel-flow/src/test/java/modelengine/fel/engine/operators/ModelTest.java @@ -12,13 +12,18 @@ import modelengine.fel.core.chat.ChatMessage; import modelengine.fel.core.chat.ChatOption; +import modelengine.fel.core.chat.Prompt; import modelengine.fel.core.chat.support.AiMessage; +import modelengine.fel.core.chat.support.ChatMessages; +import modelengine.fel.core.tool.ToolCall; import modelengine.fel.core.util.Tip; import modelengine.fel.engine.flows.AiFlows; import modelengine.fel.engine.flows.AiProcessFlow; import modelengine.fel.engine.flows.Conversation; import modelengine.fel.engine.operators.models.ChatFlowModel; +import modelengine.fel.engine.operators.patterns.AbstractAgent; import modelengine.fel.engine.operators.prompts.Prompts; +import modelengine.fit.waterflow.domain.context.StateContext; import modelengine.fit.waterflow.domain.utils.SleepUtil; import modelengine.fitframework.flowable.Choir; diff --git a/framework/fel/java/fel-flow/src/test/java/modelengine/fel/engine/operators/PatternTest.java b/framework/fel/java/fel-flow/src/test/java/modelengine/fel/engine/operators/PatternTest.java index 5d087b85..47eac6f0 100644 --- a/framework/fel/java/fel-flow/src/test/java/modelengine/fel/engine/operators/PatternTest.java +++ b/framework/fel/java/fel-flow/src/test/java/modelengine/fel/engine/operators/PatternTest.java @@ -35,12 +35,12 @@ import modelengine.fel.engine.util.AiFlowSession; import modelengine.fit.waterflow.domain.context.FlowSession; import modelengine.fit.waterflow.domain.context.Window; -import modelengine.fit.waterflow.domain.utils.IdGenerator; +import modelengine.fit.waterflow.domain.utils.SleepUtil; import modelengine.fitframework.resource.web.Media; import modelengine.fitframework.util.CollectionUtils; +import modelengine.fitframework.util.ObjectUtils; import modelengine.fitframework.util.StringUtils; -import org.junit.jupiter.api.Disabled; import org.junit.jupiter.api.DisplayName; import org.junit.jupiter.api.Test; @@ -98,7 +98,6 @@ void shouldOkWhenAiFlowWithExampleSelector() { } @Test - @Disabled("暂不支持") @DisplayName("测试 Retriever") void shouldOkWhenAiFlowWithRetriever() { Memory memory = getMockMemory(); @@ -143,9 +142,14 @@ void shouldOkWhenAiFlowWithRetriever() { @DisplayName("测试 SimplePattern") void shouldOkWhenDelegateSimplePattern() { FlowSession session = new FlowSession(); + String key = "key"; + String value = "value"; + session.setState(key, value); SimplePattern pattern = new SimplePattern<>(prompt -> { - String sessionId = AiFlowSession.get().map(IdGenerator::getId).orElse(StringUtils.EMPTY); - return prompt.text() + sessionId; + String inputContextValue = AiFlowSession.get() + .map(target -> ObjectUtils.cast(target.getState(key))) + .orElse(StringUtils.EMPTY); + return prompt.text() + inputContextValue; }); Window token = session.begin(); ConverseLatch offer = AiFlows.create() @@ -157,7 +161,8 @@ void shouldOkWhenDelegateSimplePattern() { token.complete(); String result = offer.await(); - assertThat(result).isEqualTo("human msg." + session.getId()); + assertThat(result).isEqualTo("human msg." + value); + } private static Memory getMockMemory() { diff --git a/framework/fel/java/fel-jacoco-aggregator/pom.xml b/framework/fel/java/fel-jacoco-aggregator/pom.xml new file mode 100644 index 00000000..bc127ae6 --- /dev/null +++ b/framework/fel/java/fel-jacoco-aggregator/pom.xml @@ -0,0 +1,80 @@ + + + 4.0.0 + + + org.fitframework.fel + fel-parent + 3.5.0-SNAPSHOT + + + fel-jacoco-aggregator + pom + + + + 1.0.0-SNAPSHOT + + + + + + org.fitframework.fel + fel-core + ${fel.version} + + + org.fitframework.fel + fel-flow + ${fel.version} + + + org.fitframework.fel + fel-pipeline-core + ${fel.version} + + + + + org.fitframework.fel + fel-langchain-service + ${fel.version} + + + org.fitframework.fel + fel-pipeline-service + ${fel.version} + + + + + org.fitframework.fel + fel-langchain-runnable + ${fel.version} + + + + + + + org.jacoco + jacoco-maven-plugin + ${jacoco.version} + + **/*.jar + + + + + report-aggregate + test + + report-aggregate + + + + + + + \ No newline at end of file diff --git a/framework/fel/java/fel-pipeline-core/pom.xml b/framework/fel/java/fel-pipeline-core/pom.xml new file mode 100644 index 00000000..080c53d1 --- /dev/null +++ b/framework/fel/java/fel-pipeline-core/pom.xml @@ -0,0 +1,86 @@ + + + 4.0.0 + + + org.fitframework.fel + fel-parent + 3.5.0-SNAPSHOT + + + fel-pipeline-core + + + + + org.fitframework + fit-api + + + org.fitframework + fit-util + + + + + org.fitframework.fel + fel-pipeline-service + ${fel.version} + + + + + org.projectlombok + lombok + + + + + org.fitframework.plugin + fit-message-serializer-json-jackson + test + + + org.junit.jupiter + junit-jupiter + test + + + org.assertj + assertj-core + test + + + + + + + org.fitframework + fit-dependency-maven-plugin + ${fit.version} + + + dependency + compile + + dependency + + + + + + org.apache.maven.plugins + maven-jar-plugin + ${maven.jar.version} + + + + FIT Lab + + + + + + + \ No newline at end of file diff --git a/framework/fel/java/fel-pipeline-core/src/main/java/modelengine/fel/pipeline/Pipeline.java b/framework/fel/java/fel-pipeline-core/src/main/java/modelengine/fel/pipeline/Pipeline.java new file mode 100644 index 00000000..4cf13330 --- /dev/null +++ b/framework/fel/java/fel-pipeline-core/src/main/java/modelengine/fel/pipeline/Pipeline.java @@ -0,0 +1,19 @@ +/*--------------------------------------------------------------------------------------------- + * Copyright (c) 2025 Huawei Technologies Co., Ltd. All rights reserved. + * This file is a part of the ModelEngine Project. + * Licensed under the MIT License. See License.txt in the project root for license information. + *--------------------------------------------------------------------------------------------*/ + +package modelengine.fel.pipeline; + +import java.util.function.Function; + +/** + * 流水线标准接口定义。 + * + * @param 表示流水线输入参数类型的 {@link I}。 + * @param 表示流水线输出参数类型的 {@link O}。 + * @author 易文渊 + * @since 2024-06-07 + */ +public interface Pipeline extends Function {} \ No newline at end of file diff --git a/framework/fel/java/fel-pipeline-core/src/main/java/modelengine/fel/pipeline/PipelineInput.java b/framework/fel/java/fel-pipeline-core/src/main/java/modelengine/fel/pipeline/PipelineInput.java new file mode 100644 index 00000000..9ec05432 --- /dev/null +++ b/framework/fel/java/fel-pipeline-core/src/main/java/modelengine/fel/pipeline/PipelineInput.java @@ -0,0 +1,15 @@ +/*--------------------------------------------------------------------------------------------- + * Copyright (c) 2025 Huawei Technologies Co., Ltd. All rights reserved. + * This file is a part of the ModelEngine Project. + * Licensed under the MIT License. See License.txt in the project root for license information. + *--------------------------------------------------------------------------------------------*/ + +package modelengine.fel.pipeline; + +/** + * 流水线输入参数接口定义。 + * + * @author 易文渊 + * @since 2024-06-19 + */ +public interface PipelineInput {} \ No newline at end of file diff --git a/framework/fel/java/fel-pipeline-core/src/main/java/modelengine/fel/pipeline/huggingface/ExplicitPipeline.java b/framework/fel/java/fel-pipeline-core/src/main/java/modelengine/fel/pipeline/huggingface/ExplicitPipeline.java new file mode 100644 index 00000000..90ca58f2 --- /dev/null +++ b/framework/fel/java/fel-pipeline-core/src/main/java/modelengine/fel/pipeline/huggingface/ExplicitPipeline.java @@ -0,0 +1,49 @@ +/*--------------------------------------------------------------------------------------------- + * Copyright (c) 2025 Huawei Technologies Co., Ltd. All rights reserved. + * This file is a part of the ModelEngine Project. + * Licensed under the MIT License. See License.txt in the project root for license information. + *--------------------------------------------------------------------------------------------*/ + +package modelengine.fel.pipeline.huggingface; + +import modelengine.fel.pipeline.Pipeline; +import modelengine.fel.pipeline.PipelineInput; +import modelengine.fel.service.pipeline.HuggingFacePipelineService; +import modelengine.fitframework.inspection.Validation; +import modelengine.fitframework.util.ObjectUtils; + +import java.util.Map; + +/** + * 表示 huggingface pipeline 的特化实现。 + * + * @param 表示流水线输入参数类型的 {@link I}。 + * @param 表示流水线输出参数类型的 {@link O}。 + * @author 易文渊 + * @since 2024-06-04 + */ +public abstract class ExplicitPipeline implements Pipeline { + private final GeneralPipeline generalPipeline; + + private final PipelineTask task; + + /** + * 创建特化流水线的实例。 + * + * @param task 表示任务类型的 {@link PipelineTask}。 + * @param model 表示模型名的 {@link String}。 + * @param service 表示提供 pipeline 服务的 {@link HuggingFacePipelineService}。 + */ + protected ExplicitPipeline(PipelineTask task, String model, HuggingFacePipelineService service) { + Validation.notBlank(model, "The model cannot be blank."); + Validation.notNull(service, "The pipeline service cannot be null."); + this.generalPipeline = new GeneralPipeline(task, model, service); + this.task = Validation.notNull(task, "The pipeline task cannot be null."); + } + + @Override + public O apply(I input) { + Map args = ObjectUtils.cast(ObjectUtils.toJavaObject(input)); + return ObjectUtils.toCustomObject(generalPipeline.apply(args), task.getOutputType()); + } +} \ No newline at end of file diff --git a/framework/fel/java/fel-pipeline-core/src/main/java/modelengine/fel/pipeline/huggingface/GeneralPipeline.java b/framework/fel/java/fel-pipeline-core/src/main/java/modelengine/fel/pipeline/huggingface/GeneralPipeline.java new file mode 100644 index 00000000..00adfe43 --- /dev/null +++ b/framework/fel/java/fel-pipeline-core/src/main/java/modelengine/fel/pipeline/huggingface/GeneralPipeline.java @@ -0,0 +1,54 @@ +/*--------------------------------------------------------------------------------------------- + * Copyright (c) 2025 Huawei Technologies Co., Ltd. All rights reserved. + * This file is a part of the ModelEngine Project. + * Licensed under the MIT License. See License.txt in the project root for license information. + *--------------------------------------------------------------------------------------------*/ + +package modelengine.fel.pipeline.huggingface; + +import modelengine.fel.pipeline.Pipeline; +import modelengine.fel.service.pipeline.HuggingFacePipelineService; +import modelengine.fitframework.inspection.Validation; + +import java.util.Map; + +/** + * 表示 huggingface pipeline 的泛化实现。 + *

    返回结果取决于任务类型,可能是以下值中的一个: + *

      + *
    • {@link Map}{@code <}{@link String}{@code , }{@link Object}{@code >}。
    • + *
    • {@link java.util.List}{@code <}{@link Map}{@code <}{@link String}{@code , }{@link Object}{@code >>}。
    • + *
    + *

    + * + * @author 易文渊 + * @since 2024-06-04 + */ +public class GeneralPipeline implements Pipeline, Object> { + private final String taskId; + private final String model; + + private final HuggingFacePipelineService service; + + /** + * 创建泛化流水线的实例。 + * + * @param task 表示任务类型的 {@link PipelineTask}。 + * @param model 表示模型名的 {@link String}。 + * @param service 表示提供 pipeline 服务的 {@link HuggingFacePipelineService}。 + * @throws IllegalArgumentException
      + *
    • 当 {@code task} 为 {@code null} 时。
    • + *
    • 当 {@code service} 为 {@code null} 时。
    • + *
    + */ + public GeneralPipeline(PipelineTask task, String model, HuggingFacePipelineService service) { + this.taskId = Validation.notNull(task, "The task cannot be null.").getId(); + this.model = Validation.notBlank(model, "The model cannot be blank."); + this.service = Validation.notNull(service, "The service cannot be null."); + } + + @Override + public Object apply(Map args) { + return this.service.call(this.taskId, this.model, args); + } +} \ No newline at end of file diff --git a/framework/fel/java/fel-pipeline-core/src/main/java/modelengine/fel/pipeline/huggingface/PipelineTask.java b/framework/fel/java/fel-pipeline-core/src/main/java/modelengine/fel/pipeline/huggingface/PipelineTask.java new file mode 100644 index 00000000..f2a1cecc --- /dev/null +++ b/framework/fel/java/fel-pipeline-core/src/main/java/modelengine/fel/pipeline/huggingface/PipelineTask.java @@ -0,0 +1,120 @@ +/*--------------------------------------------------------------------------------------------- + * Copyright (c) 2025 Huawei Technologies Co., Ltd. All rights reserved. + * This file is a part of the ModelEngine Project. + * Licensed under the MIT License. See License.txt in the project root for license information. + *--------------------------------------------------------------------------------------------*/ + +package modelengine.fel.pipeline.huggingface; + +import static modelengine.fel.pipeline.huggingface.type.Constant.LIST_MEDIA_TYPE; + +import lombok.Getter; +import modelengine.fel.pipeline.huggingface.asr.AsrInput; +import modelengine.fel.pipeline.huggingface.asr.AsrOutput; +import modelengine.fel.pipeline.huggingface.asr.AsrPipeline; +import modelengine.fel.pipeline.huggingface.img2img.Image2ImageInput; +import modelengine.fel.pipeline.huggingface.img2img.Image2ImagePipeline; +import modelengine.fel.pipeline.huggingface.text2img.Text2ImageInput; +import modelengine.fel.pipeline.huggingface.text2img.Text2ImagePipeline; +import modelengine.fel.pipeline.huggingface.tts.TtsInput; +import modelengine.fel.pipeline.huggingface.tts.TtsOutput; + +import java.lang.reflect.Type; +import java.util.Arrays; +import java.util.Map; +import java.util.stream.Collectors; + +/** + * 表示 huggingface pipeline 任务类型枚举。 + * + * @author 易文渊 + * @since 2024-06-04 + */ +@Getter +public enum PipelineTask { + AUDIO_CLASSIFICATION("audio-classification", null, null), + + /** + * 音频文本提取。 + * + * @see AsrPipeline + */ + AUTOMATIC_SPEECH_RECOGNITION("automatic-speech-recognition", AsrInput.class, AsrOutput.class), + CONVERSATIONAL("conversational", null, null), + DEPTH_ESTIMATION("depth-estimation", null, null), + DOCUMENT_QUESTION_ANSWERING("document-question-answering", null, null), + FEATURE_EXTRACTION("feature-extraction", null, null), + FILL_MASK("fill-mask", null, null), + IMAGE_CLASSIFICATION("image-classification", null, null), + IMAGE_FEATURE_EXTRACTION("image-feature-extraction", null, null), + IMAGE_SEGMENTATION("image-segmentation", null, null), + IMAGE_TO_TEXT("image-to-text", null, null), + MASK_GENERATION("mask-generation", null, null), + OBJECT_DETECTION("object-detection", null, null), + QUESTION_ANSWERING("question-answering", null, null), + SUMMARIZATION("summarization", null, null), + TABLE_QUESTION_ANSWERING("table-question-answering", null, null), + TEXT2TEXT_GENERATION("text2text-generation", null, null), + TEXT_CLASSIFICATION("text-classification", null, null), + TEXT_GENERATION("text-generation", null, null), + + /** + * 语音合成。 + * + * @see Text2ImagePipeline + */ + TEXT_TO_SPEECH("text-to-speech", TtsInput.class, TtsOutput.class), + TOKEN_CLASSIFICATION("token-classification", null, null), + TRANSLATION("translation", null, null), + TRANSLATION_XX_TO_YY("translation_xx_to_yy", null, null), + VIDEO_CLASSIFICATION("video-classification", null, null), + VISUAL_QUESTION_ANSWERING("visual-question-answering", null, null), + ZERO_SHOT_CLASSIFICATION("zero-shot-classification", null, null), + ZERO_SHOT_IMAGE_CLASSIFICATION("zero-shot-image-classification", null, null), + ZERO_SHOT_AUDIO_CLASSIFICATION("zero-shot-audio-classification", null, null), + ZERO_SHOT_OBJECT_DETECTION("zero-shot-object-detection", null, null), + + /** + * 图生图。 + * + * @see Image2ImagePipeline + */ + IMAGE_TO_IMAGE("image-to-image", Image2ImageInput.class, LIST_MEDIA_TYPE), + + /** + * 文生图。 + * + * @see Text2ImagePipeline + */ + TEXT_TO_IMAGE("text-to-image", Text2ImageInput.class, LIST_MEDIA_TYPE); + + private final String id; + private final Type inputType; + private final Type outputType; + + private static final Map TASK_MAP = Arrays.stream(PipelineTask.values()) + .collect(Collectors.toMap(PipelineTask::getId, p -> p)); + + /** + * 根据任务名获取 {@link PipelineTask}。 + * + * @param task 表示任务名的 {@link String}。 + * @return 表示流水线任务的 {@link PipelineTask}。 + */ + public static PipelineTask get(String task) { + return TASK_MAP.get(task); + } + + /** + * 创建流水线任务枚举实例。 + * + * @param id 表示任务编号的 {@link String}。 + * @param inputType 表示任务输入参数类型的 {@link Type}。 + * @param outputType 表示任务输出参数类型的 {@link Type}。 + */ + PipelineTask(String id, Type inputType, Type outputType) { + this.id = id; + this.inputType = inputType; + this.outputType = outputType; + } +} \ No newline at end of file diff --git a/framework/fel/java/fel-pipeline-core/src/main/java/modelengine/fel/pipeline/huggingface/asr/AsrInput.java b/framework/fel/java/fel-pipeline-core/src/main/java/modelengine/fel/pipeline/huggingface/asr/AsrInput.java new file mode 100644 index 00000000..48b3932e --- /dev/null +++ b/framework/fel/java/fel-pipeline-core/src/main/java/modelengine/fel/pipeline/huggingface/asr/AsrInput.java @@ -0,0 +1,51 @@ +/*--------------------------------------------------------------------------------------------- + * Copyright (c) 2025 Huawei Technologies Co., Ltd. All rights reserved. + * This file is a part of the ModelEngine Project. + * Licensed under the MIT License. See License.txt in the project root for license information. + *--------------------------------------------------------------------------------------------*/ + +package modelengine.fel.pipeline.huggingface.asr; + +import lombok.Data; +import modelengine.fel.pipeline.PipelineInput; +import modelengine.fitframework.annotation.Property; + +import java.util.Map; + +/** + * 表示语音识别任务的输入参数。 + * + * @author 易文渊 + * @since 2024-06-04 + */ +@Data +public class AsrInput implements PipelineInput { + /** + * 表示音频文件的公共 URL 地址的 {@link String}。 + */ + private String inputs; + + /** + * 表示是否返回文本中每个单词的时间戳的 {@link Boolean}。 + *

    仅适用于以下模型,不适用于其他模型: + *

      + *
    • 纯 CTC 模型(Wav2Vec2、HuBERT 等)。
    • + *
    • Whisper 模型。
    • + *
    + *

    + */ + @Property(name = "return_timestamps") + private Boolean returnTimestamps; + + /** + * 表示用于模型生成的超参数的 {@link Map}{@code <}{@link String}{@code , }{@link Object}{@code >}。 + */ + @Property(name = "generate_kwargs") + private Map generateKwargs; + + /** + * 表示生成的最大令牌数的 {@link Integer}。 + */ + @Property(name = "max_new_tokens") + private Integer maxNewTokens; +} \ No newline at end of file diff --git a/framework/fel/java/fel-pipeline-core/src/main/java/modelengine/fel/pipeline/huggingface/asr/AsrOutput.java b/framework/fel/java/fel-pipeline-core/src/main/java/modelengine/fel/pipeline/huggingface/asr/AsrOutput.java new file mode 100644 index 00000000..e34c71a2 --- /dev/null +++ b/framework/fel/java/fel-pipeline-core/src/main/java/modelengine/fel/pipeline/huggingface/asr/AsrOutput.java @@ -0,0 +1,31 @@ +/*--------------------------------------------------------------------------------------------- + * Copyright (c) 2025 Huawei Technologies Co., Ltd. All rights reserved. + * This file is a part of the ModelEngine Project. + * Licensed under the MIT License. See License.txt in the project root for license information. + *--------------------------------------------------------------------------------------------*/ + +package modelengine.fel.pipeline.huggingface.asr; + +import lombok.Data; + +import java.util.List; + +/** + * 表示语音识别任务的输出参数。 + * + * @author 易文渊 + * @since 2024-06-04 + */ +@Data +public class AsrOutput { + /** + * 表示被识别的文本的 {@link String}。 + */ + private String text; + + /** + * 表示包含时间戳的文本片段集合的 {@link List}{@code <}{@link AsrOutputChunk}{@code >}。 + *

    当 {@link AsrInput#getReturnTimestamps()} 为 {@code true} 时生效。

    + */ + private List chunks; +} \ No newline at end of file diff --git a/framework/fel/java/fel-pipeline-core/src/main/java/modelengine/fel/pipeline/huggingface/asr/AsrOutputChunk.java b/framework/fel/java/fel-pipeline-core/src/main/java/modelengine/fel/pipeline/huggingface/asr/AsrOutputChunk.java new file mode 100644 index 00000000..a5869d52 --- /dev/null +++ b/framework/fel/java/fel-pipeline-core/src/main/java/modelengine/fel/pipeline/huggingface/asr/AsrOutputChunk.java @@ -0,0 +1,39 @@ +/*--------------------------------------------------------------------------------------------- + * Copyright (c) 2025 Huawei Technologies Co., Ltd. All rights reserved. + * This file is a part of the ModelEngine Project. + * Licensed under the MIT License. See License.txt in the project root for license information. + *--------------------------------------------------------------------------------------------*/ + +package modelengine.fel.pipeline.huggingface.asr; + +import lombok.Data; + +import java.util.List; + +/** + * 表示包含时间戳的文本片段。 + *
    + *     {
    + *         "text": "hi",
    + *         "timestamp": [
    + *              0.5,
    + *              0.9
    + *         ]
    + *     }
    + * 
    + * + * @author 易文渊 + * @since 2024-06-04 + */ +@Data +public class AsrOutputChunk { + /** + * 表示文本片段的 {@link String}。 + */ + private String text; + + /** + * 表示时间区间的 {@link List}{@code <}{@link Double}{@code >}。 + */ + private List timestamp; +} \ No newline at end of file diff --git a/framework/fel/java/fel-pipeline-core/src/main/java/modelengine/fel/pipeline/huggingface/asr/AsrPipeline.java b/framework/fel/java/fel-pipeline-core/src/main/java/modelengine/fel/pipeline/huggingface/asr/AsrPipeline.java new file mode 100644 index 00000000..041904fe --- /dev/null +++ b/framework/fel/java/fel-pipeline-core/src/main/java/modelengine/fel/pipeline/huggingface/asr/AsrPipeline.java @@ -0,0 +1,29 @@ +/*--------------------------------------------------------------------------------------------- + * Copyright (c) 2025 Huawei Technologies Co., Ltd. All rights reserved. + * This file is a part of the ModelEngine Project. + * Licensed under the MIT License. See License.txt in the project root for license information. + *--------------------------------------------------------------------------------------------*/ + +package modelengine.fel.pipeline.huggingface.asr; + +import modelengine.fel.pipeline.huggingface.ExplicitPipeline; +import modelengine.fel.pipeline.huggingface.PipelineTask; +import modelengine.fel.service.pipeline.HuggingFacePipelineService; + +/** + * 表示 {@link PipelineTask#AUTOMATIC_SPEECH_RECOGNITION} 任务的流水线。 + * + * @author 易文渊 + * @since 2024-06-04 + */ +public class AsrPipeline extends ExplicitPipeline { + /** + * 创建语音识别流水线的实例。 + * + * @param model 表示模型名的 {@link String}。 + * @param service 表示提供 pipeline 服务的 {@link HuggingFacePipelineService}。 + */ + public AsrPipeline(String model, HuggingFacePipelineService service) { + super(PipelineTask.AUTOMATIC_SPEECH_RECOGNITION, model, service); + } +} \ No newline at end of file diff --git a/framework/fel/java/fel-pipeline-core/src/main/java/modelengine/fel/pipeline/huggingface/img2img/Image2ImageInput.java b/framework/fel/java/fel-pipeline-core/src/main/java/modelengine/fel/pipeline/huggingface/img2img/Image2ImageInput.java new file mode 100644 index 00000000..99a5f2dd --- /dev/null +++ b/framework/fel/java/fel-pipeline-core/src/main/java/modelengine/fel/pipeline/huggingface/img2img/Image2ImageInput.java @@ -0,0 +1,35 @@ +/*--------------------------------------------------------------------------------------------- + * Copyright (c) 2025 Huawei Technologies Co., Ltd. All rights reserved. + * This file is a part of the ModelEngine Project. + * Licensed under the MIT License. See License.txt in the project root for license information. + *--------------------------------------------------------------------------------------------*/ + +package modelengine.fel.pipeline.huggingface.img2img; + +import lombok.Data; +import modelengine.fel.pipeline.PipelineInput; +import modelengine.fitframework.annotation.Property; + +/** + * 表示图生图任务的输入参数。 + * + * @author 易文渊 + * @since 2024-06-06 + */ +@Data +public class Image2ImageInput implements PipelineInput { + @Property(required = true) + private String prompt; + + @Property(required = true) + private String image; + + @Property(name = "negative_prompt") + private String negativePrompt; + + @Property(name = "num_images_per_prompt") + private Integer numImagesPerPrompt; + + @Property(name = "num_inference_steps") + private Integer numInferenceSteps; +} \ No newline at end of file diff --git a/framework/fel/java/fel-pipeline-core/src/main/java/modelengine/fel/pipeline/huggingface/img2img/Image2ImagePipeline.java b/framework/fel/java/fel-pipeline-core/src/main/java/modelengine/fel/pipeline/huggingface/img2img/Image2ImagePipeline.java new file mode 100644 index 00000000..d1e86fd5 --- /dev/null +++ b/framework/fel/java/fel-pipeline-core/src/main/java/modelengine/fel/pipeline/huggingface/img2img/Image2ImagePipeline.java @@ -0,0 +1,32 @@ +/*--------------------------------------------------------------------------------------------- + * Copyright (c) 2025 Huawei Technologies Co., Ltd. All rights reserved. + * This file is a part of the ModelEngine Project. + * Licensed under the MIT License. See License.txt in the project root for license information. + *--------------------------------------------------------------------------------------------*/ + +package modelengine.fel.pipeline.huggingface.img2img; + +import modelengine.fel.pipeline.huggingface.ExplicitPipeline; +import modelengine.fel.pipeline.huggingface.PipelineTask; +import modelengine.fel.service.pipeline.HuggingFacePipelineService; +import modelengine.fitframework.resource.web.Media; + +import java.util.List; + +/** + * 表示 {@link PipelineTask#IMAGE_TO_IMAGE} 任务的流水线。 + * + * @author 易文渊 + * @since 2024-06-06 + */ +public class Image2ImagePipeline extends ExplicitPipeline> { + /** + * 创建图生图流水线的实例。 + * + * @param model 表示模型名的 {@link String}。 + * @param service 表示提供 pipeline 服务的 {@link HuggingFacePipelineService}。 + */ + public Image2ImagePipeline(String model, HuggingFacePipelineService service) { + super(PipelineTask.IMAGE_TO_IMAGE, model, service); + } +} \ No newline at end of file diff --git a/framework/fel/java/fel-pipeline-core/src/main/java/modelengine/fel/pipeline/huggingface/text2img/Text2ImageInput.java b/framework/fel/java/fel-pipeline-core/src/main/java/modelengine/fel/pipeline/huggingface/text2img/Text2ImageInput.java new file mode 100644 index 00000000..6d6d3217 --- /dev/null +++ b/framework/fel/java/fel-pipeline-core/src/main/java/modelengine/fel/pipeline/huggingface/text2img/Text2ImageInput.java @@ -0,0 +1,35 @@ +/*--------------------------------------------------------------------------------------------- + * Copyright (c) 2025 Huawei Technologies Co., Ltd. All rights reserved. + * This file is a part of the ModelEngine Project. + * Licensed under the MIT License. See License.txt in the project root for license information. + *--------------------------------------------------------------------------------------------*/ + +package modelengine.fel.pipeline.huggingface.text2img; + +import lombok.Data; +import modelengine.fel.pipeline.PipelineInput; +import modelengine.fitframework.annotation.Property; + +/** + * 表示文生图任务的输入参数。 + * + * @author 易文渊 + * @since 2024-06-06 + */ +@Data +public class Text2ImageInput implements PipelineInput { + private String prompt; + + @Property(name = "negative_prompt") + private String negativePrompt; + + private Integer height; + + private Integer width; + + @Property(name = "num_images_per_prompt") + private Integer numImagesPerPrompt; + + @Property(name = "num_inference_steps") + private Integer numInferenceSteps; +} \ No newline at end of file diff --git a/framework/fel/java/fel-pipeline-core/src/main/java/modelengine/fel/pipeline/huggingface/text2img/Text2ImagePipeline.java b/framework/fel/java/fel-pipeline-core/src/main/java/modelengine/fel/pipeline/huggingface/text2img/Text2ImagePipeline.java new file mode 100644 index 00000000..1072214a --- /dev/null +++ b/framework/fel/java/fel-pipeline-core/src/main/java/modelengine/fel/pipeline/huggingface/text2img/Text2ImagePipeline.java @@ -0,0 +1,32 @@ +/*--------------------------------------------------------------------------------------------- + * Copyright (c) 2025 Huawei Technologies Co., Ltd. All rights reserved. + * This file is a part of the ModelEngine Project. + * Licensed under the MIT License. See License.txt in the project root for license information. + *--------------------------------------------------------------------------------------------*/ + +package modelengine.fel.pipeline.huggingface.text2img; + +import modelengine.fel.pipeline.huggingface.ExplicitPipeline; +import modelengine.fel.pipeline.huggingface.PipelineTask; +import modelengine.fel.service.pipeline.HuggingFacePipelineService; +import modelengine.fitframework.resource.web.Media; + +import java.util.List; + +/** + * 表示 {@link PipelineTask#TEXT_TO_IMAGE} 任务的流水线。 + * + * @author 易文渊 + * @since 2024-06-06 + */ +public class Text2ImagePipeline extends ExplicitPipeline> { + /** + * 创建文生图流水线的实例。 + * + * @param model 表示模型名的 {@link String}。 + * @param service 表示提供 pipeline 服务的 {@link HuggingFacePipelineService}。 + */ + public Text2ImagePipeline(String model, HuggingFacePipelineService service) { + super(PipelineTask.TEXT_TO_IMAGE, model, service); + } +} \ No newline at end of file diff --git a/framework/fel/java/fel-pipeline-core/src/main/java/modelengine/fel/pipeline/huggingface/tts/TtsInput.java b/framework/fel/java/fel-pipeline-core/src/main/java/modelengine/fel/pipeline/huggingface/tts/TtsInput.java new file mode 100644 index 00000000..4c68ba89 --- /dev/null +++ b/framework/fel/java/fel-pipeline-core/src/main/java/modelengine/fel/pipeline/huggingface/tts/TtsInput.java @@ -0,0 +1,40 @@ +/*--------------------------------------------------------------------------------------------- + * Copyright (c) 2025 Huawei Technologies Co., Ltd. All rights reserved. + * This file is a part of the ModelEngine Project. + * Licensed under the MIT License. See License.txt in the project root for license information. + *--------------------------------------------------------------------------------------------*/ + +package modelengine.fel.pipeline.huggingface.tts; + +import lombok.Data; +import modelengine.fel.pipeline.PipelineInput; +import modelengine.fitframework.annotation.Property; + +import java.util.Map; + +/** + * 表示语音合成任务的输入参数。 + * + * @author 易文渊 + * @since 2024-06-05 + */ +@Data +public class TtsInput implements PipelineInput { + /** + * 表示输入文本的 {@link String}。 + */ + @Property(name = "text_inputs") + private String textInputs; + + /** + * 表示底层模型推理参数的 {@link Map}{@code <}{@link String}{@code , }{@link Object}{@code >}。 + */ + @Property(name = "forward_params") + private Map forwardParams; + + /** + * 表示音频模型推理参数的 {@link Map}{@code <}{@link String}{@code , }{@link Object}{@code >}。 + */ + @Property(name = "generate_kwargs") + private Map generateKwargs; +} \ No newline at end of file diff --git a/framework/fel/java/fel-pipeline-core/src/main/java/modelengine/fel/pipeline/huggingface/tts/TtsOutput.java b/framework/fel/java/fel-pipeline-core/src/main/java/modelengine/fel/pipeline/huggingface/tts/TtsOutput.java new file mode 100644 index 00000000..6c2513d5 --- /dev/null +++ b/framework/fel/java/fel-pipeline-core/src/main/java/modelengine/fel/pipeline/huggingface/tts/TtsOutput.java @@ -0,0 +1,28 @@ +/*--------------------------------------------------------------------------------------------- + * Copyright (c) 2025 Huawei Technologies Co., Ltd. All rights reserved. + * This file is a part of the ModelEngine Project. + * Licensed under the MIT License. See License.txt in the project root for license information. + *--------------------------------------------------------------------------------------------*/ + +package modelengine.fel.pipeline.huggingface.tts; + +import lombok.Data; +import modelengine.fitframework.annotation.Property; +import modelengine.fitframework.resource.web.Media; + +/** + * 表示语音合成任务的输出参数。 + * + * @author 易文渊 + * @since 2024-06-05 + */ +@Data +public class TtsOutput { + /** + * 表示输出音频的 {@link Media}。 + */ + private Media audio; + + @Property(name = "sampling_rate") + private Integer samplingRate; +} \ No newline at end of file diff --git a/framework/fel/java/fel-pipeline-core/src/main/java/modelengine/fel/pipeline/huggingface/tts/TtsPipeline.java b/framework/fel/java/fel-pipeline-core/src/main/java/modelengine/fel/pipeline/huggingface/tts/TtsPipeline.java new file mode 100644 index 00000000..ff338881 --- /dev/null +++ b/framework/fel/java/fel-pipeline-core/src/main/java/modelengine/fel/pipeline/huggingface/tts/TtsPipeline.java @@ -0,0 +1,29 @@ +/*--------------------------------------------------------------------------------------------- + * Copyright (c) 2025 Huawei Technologies Co., Ltd. All rights reserved. + * This file is a part of the ModelEngine Project. + * Licensed under the MIT License. See License.txt in the project root for license information. + *--------------------------------------------------------------------------------------------*/ + +package modelengine.fel.pipeline.huggingface.tts; + +import modelengine.fel.pipeline.huggingface.ExplicitPipeline; +import modelengine.fel.pipeline.huggingface.PipelineTask; +import modelengine.fel.service.pipeline.HuggingFacePipelineService; + +/** + * 表示 {@link PipelineTask#TEXT_TO_SPEECH} 任务的流水线。 + * + * @author 易文渊 + * @since 2024-06-05 + */ +public class TtsPipeline extends ExplicitPipeline { + /** + * 创建语音合成流水线的实例。 + * + * @param model 表示模型名的 {@link String}。 + * @param service 表示提供 pipeline 服务的 {@link HuggingFacePipelineService}。 + */ + public TtsPipeline(String model, HuggingFacePipelineService service) { + super(PipelineTask.TEXT_TO_SPEECH, model, service); + } +} \ No newline at end of file diff --git a/framework/fel/java/fel-pipeline-core/src/main/java/modelengine/fel/pipeline/huggingface/type/Constant.java b/framework/fel/java/fel-pipeline-core/src/main/java/modelengine/fel/pipeline/huggingface/type/Constant.java new file mode 100644 index 00000000..b24b16a3 --- /dev/null +++ b/framework/fel/java/fel-pipeline-core/src/main/java/modelengine/fel/pipeline/huggingface/type/Constant.java @@ -0,0 +1,26 @@ +/*--------------------------------------------------------------------------------------------- + * Copyright (c) 2025 Huawei Technologies Co., Ltd. All rights reserved. + * This file is a part of the ModelEngine Project. + * Licensed under the MIT License. See License.txt in the project root for license information. + *--------------------------------------------------------------------------------------------*/ + +package modelengine.fel.pipeline.huggingface.type; + +import modelengine.fitframework.resource.web.Media; +import modelengine.fitframework.util.TypeUtils; + +import java.lang.reflect.Type; +import java.util.List; + +/** + * 表示 huggingface pipeline 的常量集合。 + * + * @author 易文渊 + * @since 2024-06-06 + */ +public interface Constant { + /** + * 表示 {@link List}{@code <}{@link Media}{@code >} 的 {@link Type}。 + */ + Type LIST_MEDIA_TYPE = TypeUtils.parameterized(List.class, new Type[] {Media.class}); +} \ No newline at end of file diff --git a/framework/fel/java/fel-pipeline-core/src/test/java/modelengine/fel/pipeline/huggingface/PipelineFactory.java b/framework/fel/java/fel-pipeline-core/src/test/java/modelengine/fel/pipeline/huggingface/PipelineFactory.java new file mode 100644 index 00000000..93190ece --- /dev/null +++ b/framework/fel/java/fel-pipeline-core/src/test/java/modelengine/fel/pipeline/huggingface/PipelineFactory.java @@ -0,0 +1,52 @@ +/*--------------------------------------------------------------------------------------------- + * Copyright (c) 2025 Huawei Technologies Co., Ltd. All rights reserved. + * This file is a part of the ModelEngine Project. + * Licensed under the MIT License. See License.txt in the project root for license information. + *--------------------------------------------------------------------------------------------*/ + +package modelengine.fel.pipeline.huggingface; + +import modelengine.fel.pipeline.Pipeline; +import modelengine.fel.pipeline.huggingface.asr.AsrPipeline; +import modelengine.fel.pipeline.huggingface.img2img.Image2ImagePipeline; +import modelengine.fel.pipeline.huggingface.text2img.Text2ImagePipeline; +import modelengine.fel.pipeline.huggingface.tts.TtsPipeline; +import modelengine.fel.service.pipeline.HuggingFacePipelineService; +import modelengine.fitframework.inspection.Validation; +import modelengine.fitframework.util.MapBuilder; +import modelengine.fitframework.util.ObjectUtils; +import modelengine.fitframework.util.ReflectionUtils; + +import java.lang.reflect.Constructor; +import java.util.Map; + +/** + * 表示 pipeline 工厂。 + * + * @author 易文渊 + * @since 2024-06-07 + */ +public class PipelineFactory { + private static final Map> PIPELINE_CLAZZ = MapBuilder.>get() + .put(PipelineTask.AUTOMATIC_SPEECH_RECOGNITION.getId(), AsrPipeline.class) + .put(PipelineTask.TEXT_TO_SPEECH.getId(), TtsPipeline.class) + .put(PipelineTask.IMAGE_TO_IMAGE.getId(), Image2ImagePipeline.class) + .put(PipelineTask.TEXT_TO_IMAGE.getId(), Text2ImagePipeline.class) + .build(); + + /** + * 创建 pipeline 实例。 + * + * @param task 表示任务类型的 {@link PipelineTask}。 + * @param model 表示模型名的 {@link String}。 + * @param service 表示提供 pipeline 服务的 {@link HuggingFacePipelineService}。 + * @return 表示创建流水线实例的 {@link Pipeline}。 + */ + public static Pipeline create(String task, String model, HuggingFacePipelineService service) { + Class clazz = PIPELINE_CLAZZ.get(task); + Validation.notNull(clazz, "The task '{0}' class cannot be null.", task); + Constructor constructor = + ReflectionUtils.getDeclaredConstructor(clazz, String.class, HuggingFacePipelineService.class); + return ObjectUtils.cast(ReflectionUtils.instantiate(constructor, model, service)); + } +} \ No newline at end of file diff --git a/framework/fel/java/fel-pipeline-core/src/test/java/modelengine/fel/pipeline/huggingface/PipelineTest.java b/framework/fel/java/fel-pipeline-core/src/test/java/modelengine/fel/pipeline/huggingface/PipelineTest.java new file mode 100644 index 00000000..bc087e93 --- /dev/null +++ b/framework/fel/java/fel-pipeline-core/src/test/java/modelengine/fel/pipeline/huggingface/PipelineTest.java @@ -0,0 +1,73 @@ +/*--------------------------------------------------------------------------------------------- + * Copyright (c) 2025 Huawei Technologies Co., Ltd. All rights reserved. + * This file is a part of the ModelEngine Project. + * Licensed under the MIT License. See License.txt in the project root for license information. + *--------------------------------------------------------------------------------------------*/ + +package modelengine.fel.pipeline.huggingface; + +import static modelengine.fitframework.util.IoUtils.content; +import static org.assertj.core.api.Assertions.assertThat; + +import modelengine.fel.pipeline.Pipeline; +import modelengine.fel.service.pipeline.HuggingFacePipelineService; +import modelengine.fit.serialization.json.jackson.JacksonObjectSerializer; +import modelengine.fitframework.serialization.ObjectSerializer; +import modelengine.fitframework.util.ObjectUtils; +import modelengine.fitframework.util.TypeUtils; + +import org.junit.jupiter.api.DisplayName; +import org.junit.jupiter.api.extension.ExtensionContext; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.Arguments; +import org.junit.jupiter.params.provider.ArgumentsProvider; +import org.junit.jupiter.params.provider.ArgumentsSource; + +import java.io.IOException; +import java.lang.reflect.Type; +import java.util.List; +import java.util.stream.Stream; + +/** + * 表示 huggingface pipeline 的单元测试。 + * + * @author 易文渊 + * @since 2024-06-07 + */ +@DisplayName("测试 huggingface pipeline") +public class PipelineTest { + static class TestCaseProvider implements ArgumentsProvider { + @Override + public Stream provideArguments(ExtensionContext extensionContext) throws IOException { + ObjectSerializer serializer = new JacksonObjectSerializer(null, null, null); + + String resourceName = "/test_case.json"; + String jsonContent = content(TestCaseProvider.class, resourceName); + + List testCase = serializer.deserialize(jsonContent, + TypeUtils.parameterized(List.class, new Type[] {PipelineTestCase.class})); + + return testCase.stream().map(test -> { + PipelineTask task = PipelineTask.get(test.getTask()); + return Arguments.of(task.getId(), + test.getModel(), + ObjectUtils.toCustomObject(test.getInput(), task.getInputType()), + ObjectUtils.toCustomObject(test.getOutput(), task.getOutputType())); + }); + } + } + + @ParameterizedTest + @ArgumentsSource(TestCaseProvider.class) + @DisplayName("测试各种输入参数和输出参数,符合预期") + void shouldReturnOk(String task, String model, Object input, Object output) { + HuggingFacePipelineService service = (t, m, args) -> { + assertThat(t).isEqualTo(task); + assertThat(m).isEqualTo(model); + assertThat(args).isEqualTo(ObjectUtils.toJavaObject(input)); + return output; + }; + Pipeline pipeline = PipelineFactory.create(task, model, service); + assertThat(pipeline.apply(input)).isEqualTo(output); + } +} \ No newline at end of file diff --git a/framework/fel/java/fel-pipeline-core/src/test/java/modelengine/fel/pipeline/huggingface/PipelineTestCase.java b/framework/fel/java/fel-pipeline-core/src/test/java/modelengine/fel/pipeline/huggingface/PipelineTestCase.java new file mode 100644 index 00000000..75d5ab38 --- /dev/null +++ b/framework/fel/java/fel-pipeline-core/src/test/java/modelengine/fel/pipeline/huggingface/PipelineTestCase.java @@ -0,0 +1,23 @@ +/*--------------------------------------------------------------------------------------------- + * Copyright (c) 2025 Huawei Technologies Co., Ltd. All rights reserved. + * This file is a part of the ModelEngine Project. + * Licensed under the MIT License. See License.txt in the project root for license information. + *--------------------------------------------------------------------------------------------*/ + +package modelengine.fel.pipeline.huggingface; + +import lombok.Data; + +/** + * 表示 pipline 测试用例。 + * + * @author 易文渊 + * @since 2024-06-07 + */ +@Data +public class PipelineTestCase { + private String task; + private String model; + private Object input; + private Object output; +} \ No newline at end of file diff --git a/framework/fel/java/fel-pipeline-core/src/test/resources/test_case.json b/framework/fel/java/fel-pipeline-core/src/test/resources/test_case.json new file mode 100644 index 00000000..fa5b9e0c --- /dev/null +++ b/framework/fel/java/fel-pipeline-core/src/test/resources/test_case.json @@ -0,0 +1,114 @@ +[ + { + "task": "automatic-speech-recognition", + "model": "openai/whisper-large-v3", + "input": { + "inputs": "test.wav" + }, + "output": { + "text": "hello" + } + }, + { + "task": "automatic-speech-recognition", + "model": "openai/whisper-large-v3", + "input": { + "inputs": "test.wav", + "return_timestamps": true + }, + "output": { + "text": "hello", + "chunks": [ + { + "text": "hello", + "timestamp": [ + 0.0, + 1.5 + ] + } + ] + } + }, + { + "task": "image-to-image", + "model": "stabilityai/stable-diffusion-2-base", + "input": { + "prompt": "a girl", + "image": "test.png" + }, + "output": [ + { + "mime": "image/png", + "data": "1.png" + } + ] + }, + { + "task": "image-to-image", + "model": "stabilityai/stable-diffusion-2-base", + "input": { + "prompt": "a girl", + "image": "test.png", + "negative_prompt": "ugly", + "num_images_per_prompt": 2 + }, + "output": [ + { + "mime": "image/png", + "data": "1.png" + }, + { + "mime": "image/png", + "data": "2.png" + } + ] + }, + { + "task": "text-to-image", + "model": "stabilityai/stable-diffusion-2-base", + "input": { + "prompt": "a girl" + }, + "output": [ + { + "mime": "image/png", + "data": "1.png" + } + ] + }, + { + "task": "text-to-image", + "model": "stabilityai/stable-diffusion-2-base", + "input": { + "prompt": "a girl", + "negative_prompt": "ugly", + "height": 500, + "width": 500, + "num_images_per_prompt": 2 + }, + "output": [ + { + "mime": "image/png", + "data": "1.png" + }, + { + "mime": "image/png", + "data": "2.png" + } + ] + }, + { + "task": "text-to-speech", + "model": "2Noise/ChatTTS", + "input": { + "text_inputs": "hello" + }, + "output": { + "audio": { + "mime": "audio/x-wav", + "data": "1.wav" + }, + "sampling_rate": 16000 + } + } +] \ No newline at end of file diff --git a/framework/fel/java/plugins/fel-langchain-runnable/pom.xml b/framework/fel/java/plugins/fel-langchain-runnable/pom.xml new file mode 100644 index 00000000..59d0d0ee --- /dev/null +++ b/framework/fel/java/plugins/fel-langchain-runnable/pom.xml @@ -0,0 +1,97 @@ + + + 4.0.0 + + + org.fitframework.fel + fel-plugin-parent + 3.5.0-SNAPSHOT + + + fel-langchain-runnable + + + + + org.fitframework + fit-api + + + + + org.fitframework.fel + fel-langchain-service + ${fel.version} + + + + + + + org.fitframework + fit-build-maven-plugin + ${fit.version} + + + build-plugin + + build-plugin + + + + package-plugin + + package-plugin + + + + + + org.fitframework + fit-dependency-maven-plugin + ${fit.version} + + + dependency + compile + + dependency + + + + + + org.apache.maven.plugins + maven-jar-plugin + ${maven.jar.version} + + + + FIT Lab + + + + + + org.apache.maven.plugins + maven-antrun-plugin + ${maven.antrun.version} + + + install + + + + + + + run + + + + + + + \ No newline at end of file diff --git a/framework/fel/java/plugins/fel-langchain-runnable/src/main/java/modelengine/fel/plugin/langchain/LangChainRunnableServiceImpl.java b/framework/fel/java/plugins/fel-langchain-runnable/src/main/java/modelengine/fel/plugin/langchain/LangChainRunnableServiceImpl.java new file mode 100644 index 00000000..1a08f18b --- /dev/null +++ b/framework/fel/java/plugins/fel-langchain-runnable/src/main/java/modelengine/fel/plugin/langchain/LangChainRunnableServiceImpl.java @@ -0,0 +1,44 @@ +/*--------------------------------------------------------------------------------------------- + * Copyright (c) 2025 Huawei Technologies Co., Ltd. All rights reserved. + * This file is a part of the ModelEngine Project. + * Licensed under the MIT License. See License.txt in the project root for license information. + *--------------------------------------------------------------------------------------------*/ + +package modelengine.fel.plugin.langchain; + +import modelengine.fel.service.langchain.LangChainRunnableService; +import modelengine.fitframework.annotation.Component; +import modelengine.fitframework.annotation.Fitable; +import modelengine.fitframework.broker.client.BrokerClient; +import modelengine.fitframework.broker.client.filter.route.FitableIdFilter; +import modelengine.fitframework.conf.runtime.SerializationFormat; +import modelengine.fitframework.inspection.Validation; + +import java.util.concurrent.TimeUnit; + +/** + * LangChain Runnable 算子服务的实现。 + * + * @author 刘信宏 + * @since 2024-06-12 + */ +@Component +public class LangChainRunnableServiceImpl implements LangChainRunnableService { + private static final int INVOKE_TIMEOUT = 30000; + + private final BrokerClient brokerClient; + + public LangChainRunnableServiceImpl(BrokerClient brokerClient) { + this.brokerClient = Validation.notNull(brokerClient, "The broker client cannot be null."); + } + + @Override + @Fitable("modelengine.fel.plugin.langchain.runnable.invoke") + public Object invoke(String taskId, String fitableId, Object input) { + return this.brokerClient.getRouter(Validation.notBlank(taskId, "The task id cannot be blank.")) + .route(new FitableIdFilter(Validation.notBlank(fitableId, "The fitable id cannot be blank."))) + .format(SerializationFormat.CBOR) + .timeout(INVOKE_TIMEOUT, TimeUnit.MILLISECONDS) + .invoke(Validation.notNull(input, "The input data cannot be null.")); + } +} diff --git a/framework/fel/java/plugins/fel-langchain-runnable/src/main/resources/application.yml b/framework/fel/java/plugins/fel-langchain-runnable/src/main/resources/application.yml new file mode 100644 index 00000000..b39acc76 --- /dev/null +++ b/framework/fel/java/plugins/fel-langchain-runnable/src/main/resources/application.yml @@ -0,0 +1,4 @@ +fit: + beans: + packages: + - 'modelengine.fel.plugin.langchain' diff --git a/framework/fel/java/plugins/pom.xml b/framework/fel/java/plugins/pom.xml index 6bc4994e..5568e326 100644 --- a/framework/fel/java/plugins/pom.xml +++ b/framework/fel/java/plugins/pom.xml @@ -20,5 +20,6 @@ tool-mcp-server tool-mcp-test tool-repository-simple + fel-langchain-runnable \ No newline at end of file diff --git a/framework/fel/java/pom.xml b/framework/fel/java/pom.xml index 08c3107a..e6baf7a1 100644 --- a/framework/fel/java/pom.xml +++ b/framework/fel/java/pom.xml @@ -55,8 +55,11 @@ 3.5.0-SNAPSHOT + 1.18.36 1.17.5 2.18.2 + portable-1.8.4 + 2.3.232 3.27.3 @@ -114,6 +117,11 @@ fit-message-serializer-json-jackson ${fit.version} + + org.fitframework.service + fit-security + ${fit.version} + @@ -165,6 +173,16 @@ jackson-databind ${jackson.version} + + org.projectlombok + lombok + ${lombok.version} + + + com.hankcs + hanlp + ${hanlp.version} + @@ -173,6 +191,12 @@ ${fit.version} test + + com.h2database + h2 + ${h2.version} + test + org.junit.jupiter junit-jupiter diff --git a/framework/fel/java/services/fel-langchain-service/pom.xml b/framework/fel/java/services/fel-langchain-service/pom.xml new file mode 100644 index 00000000..bfce3adf --- /dev/null +++ b/framework/fel/java/services/fel-langchain-service/pom.xml @@ -0,0 +1,65 @@ + + + 4.0.0 + + + org.fitframework.fel + fel-services-parent + 3.5.0-SNAPSHOT + + + fel-langchain-service + + + + + org.fitframework + fit-api + + + + + + + org.fitframework + fit-build-maven-plugin + ${fit.version} + + + build-service + + build-service + + + + + + org.fitframework + fit-dependency-maven-plugin + ${fit.version} + + + dependency + compile + + dependency + + + + + + org.apache.maven.plugins + maven-jar-plugin + ${maven.jar.version} + + + + FIT Lab + + + + + + + \ No newline at end of file diff --git a/framework/fel/java/services/fel-langchain-service/src/main/java/modelengine/fel/service/langchain/LangChainRunnableService.java b/framework/fel/java/services/fel-langchain-service/src/main/java/modelengine/fel/service/langchain/LangChainRunnableService.java new file mode 100644 index 00000000..24d0fb8d --- /dev/null +++ b/framework/fel/java/services/fel-langchain-service/src/main/java/modelengine/fel/service/langchain/LangChainRunnableService.java @@ -0,0 +1,28 @@ +/*--------------------------------------------------------------------------------------------- + * Copyright (c) 2025 Huawei Technologies Co., Ltd. All rights reserved. + * This file is a part of the ModelEngine Project. + * Licensed under the MIT License. See License.txt in the project root for license information. + *--------------------------------------------------------------------------------------------*/ + +package modelengine.fel.service.langchain; + +import modelengine.fitframework.annotation.Genericable; + +/** + * LangChain Runnable 算子服务。 + * + * @author 刘信宏 + * @since 2024-06-11 + */ +public interface LangChainRunnableService { + /** + * LangChain Runnable 算子服务阻塞同步调用接口。 + * + * @param taskId 表示任务名称的 {@link String}。 + * @param fitableId 表示任务实例名称的 {@link String}。 + * @param input 表示输入数据的 {@link Object}。 + * @return 表示输出数据的 {@link Object}。 + */ + @Genericable(id = "modelengine.fel.service.langchain.runnable") + Object invoke(String taskId, String fitableId, Object input); +} diff --git a/framework/fel/java/services/fel-pipeline-service/pom.xml b/framework/fel/java/services/fel-pipeline-service/pom.xml new file mode 100644 index 00000000..8061dbf4 --- /dev/null +++ b/framework/fel/java/services/fel-pipeline-service/pom.xml @@ -0,0 +1,69 @@ + + + 4.0.0 + + + org.fitframework.fel + fel-services-parent + 3.5.0-SNAPSHOT + + + fel-pipeline-service + + + + + org.fitframework + fit-api + + + org.fitframework + fit-util + + + + + + + org.fitframework + fit-build-maven-plugin + ${fit.version} + + + build-service + + build-service + + + + + + org.fitframework + fit-dependency-maven-plugin + ${fit.version} + + + dependency + compile + + dependency + + + + + + org.apache.maven.plugins + maven-jar-plugin + ${maven.jar.version} + + + + FIT Lab + + + + + + + \ No newline at end of file diff --git a/framework/fel/java/services/fel-pipeline-service/src/main/java/modelengine/fel/service/pipeline/HuggingFacePipelineService.java b/framework/fel/java/services/fel-pipeline-service/src/main/java/modelengine/fel/service/pipeline/HuggingFacePipelineService.java new file mode 100644 index 00000000..26378427 --- /dev/null +++ b/framework/fel/java/services/fel-pipeline-service/src/main/java/modelengine/fel/service/pipeline/HuggingFacePipelineService.java @@ -0,0 +1,36 @@ +/*--------------------------------------------------------------------------------------------- + * Copyright (c) 2025 Huawei Technologies Co., Ltd. All rights reserved. + * This file is a part of the ModelEngine Project. + * Licensed under the MIT License. See License.txt in the project root for license information. + *--------------------------------------------------------------------------------------------*/ + +package modelengine.fel.service.pipeline; + +import modelengine.fitframework.annotation.Genericable; + +import java.util.Map; + +/** + * 表示 pipeline 推理服务。 + * + * @author 易文渊 + * @since 2024-06-03 + */ +public interface HuggingFacePipelineService { + /** + * 调用 HuggingFace pipeline 生成结果。 + *

    返回结果取决于任务类型,可能是以下值中的一个: + *

      + *
    • {@link Map}{@code <}{@link String}{@code , }{@link Object}{@code >}。
    • + *
    • {@link java.util.List}{@code <}{@link Map}{@code <}{@link String}{@code , }{@link Object}{@code >>}。
    • + *
    + *

    + * + * @param task 表示任务类型的 {@link String}。 + * @param model 表示模型名的 {@link String}。 + * @param args 表示调用参数的 {@link Map}{@code <}{@link String}{@code , }{@link Object}{@code >}。 + * @return 表示生成结果的 {@link Object}。 + */ + @Genericable("modelengine.fel.pipeline.huggingface") + Object call(String task, String model, Map args); +} \ No newline at end of file diff --git a/framework/fel/java/services/pom.xml b/framework/fel/java/services/pom.xml index 7241c67c..f5cad469 100644 --- a/framework/fel/java/services/pom.xml +++ b/framework/fel/java/services/pom.xml @@ -17,6 +17,8 @@ tool-mcp-client-service tool-mcp-common tool-service + fel-langchain-service + fel-pipeline-service diff --git a/framework/fel/java/services/tool-service/src/main/java/modelengine/fel/tool/model/transfer/ToolData.java b/framework/fel/java/services/tool-service/src/main/java/modelengine/fel/tool/model/transfer/ToolData.java index 8d1b7176..2f37acb7 100644 --- a/framework/fel/java/services/tool-service/src/main/java/modelengine/fel/tool/model/transfer/ToolData.java +++ b/framework/fel/java/services/tool-service/src/main/java/modelengine/fel/tool/model/transfer/ToolData.java @@ -268,7 +268,7 @@ public static Tool.Info convertToInfo(ToolData toolData) { .runnables(toolData.getRunnables()) .extensions(toolData.getExtensions()) .version(toolData.getVersion()) - .isLatest(toolData.getLatest()) + .isLatest(ObjectUtils.nullIf(toolData.getLatest(), true)) .returnConverter(ObjectUtils.cast(toolData.getSchema().get(ToolSchema.RETURN_CONVERTER))) .defaultParameterValues(ToolData.defaultParamValues(toolData.getSchema())) .parameters(toolData.getParameters()) diff --git a/framework/fel/python/fel_core/__init__.py b/framework/fel/python/fel_core/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/framework/fel/python/fel_core/types/__init__.py b/framework/fel/python/fel_core/types/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/framework/fel/python/fel_core/types/document.py b/framework/fel/python/fel_core/types/document.py new file mode 100644 index 00000000..a71b623a --- /dev/null +++ b/framework/fel/python/fel_core/types/document.py @@ -0,0 +1,22 @@ +# -- encoding: utf-8 -- +# Copyright (c) 2024 Huawei Technologies Co., Ltd. All Rights Reserved. +# This file is a part of the ModelEngine Project. +# Licensed under the MIT License. See License.txt in the project root for license information. +# ====================================================================================================================== +import typing + +from fel_core.types.serializable import Serializable +from fel_core.types.media import Media + + +class Document(Serializable): + """ + Document. + """ + content: str + media: Media = None + metadata: typing.Dict[str, object] + + class Config: + frozen = True + smart_union = True diff --git a/framework/fel/python/fel_core/types/media.py b/framework/fel/python/fel_core/types/media.py new file mode 100644 index 00000000..b8890a66 --- /dev/null +++ b/framework/fel/python/fel_core/types/media.py @@ -0,0 +1,18 @@ +# -- encoding: utf-8 -- +# Copyright (c) 2024 Huawei Technologies Co., Ltd. All Rights Reserved. +# This file is a part of the ModelEngine Project. +# Licensed under the MIT License. See License.txt in the project root for license information. +# ====================================================================================================================== +from fel_core.types.serializable import Serializable + + +class Media(Serializable): + """ + Media. + """ + mime: str + data: str + + class Config: + frozen = True + smart_union = True diff --git a/framework/fel/python/fel_core/types/serializable.py b/framework/fel/python/fel_core/types/serializable.py new file mode 100644 index 00000000..4522897f --- /dev/null +++ b/framework/fel/python/fel_core/types/serializable.py @@ -0,0 +1,25 @@ +# -- encoding: utf-8 -- +# Copyright (c) 2024 Huawei Technologies Co., Ltd. All Rights Reserved. +# This file is a part of the ModelEngine Project. +# Licensed under the MIT License. See License.txt in the project root for license information. +# ====================================================================================================================== +import typing + +try: + import pydantic + + if pydantic.__version__.startswith("1."): + raise ImportError + import pydantic.v1 as pydantic +except ImportError: + import pydantic + + +class Serializable(pydantic.BaseModel): + def json(self, **kwargs: typing.Any) -> str: + kwargs_with_defaults: typing.Any = {"by_alias": True, "exclude_unset": True, **kwargs} + return super().json(**kwargs_with_defaults) + + def dict(self, **kwargs: typing.Any) -> typing.Dict[str, typing.Any]: + kwargs_with_defaults: typing.Any = {"by_alias": True, "exclude_unset": True, **kwargs} + return super().dict(**kwargs_with_defaults) \ No newline at end of file diff --git a/framework/fel/python/fel_langchain/__init__.py b/framework/fel/python/fel_langchain/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/framework/fel/python/fel_langchain/langchain_registers.py b/framework/fel/python/fel_langchain/langchain_registers.py new file mode 100644 index 00000000..b7fe0c4d --- /dev/null +++ b/framework/fel/python/fel_langchain/langchain_registers.py @@ -0,0 +1,96 @@ +# -- encoding: utf-8 -- +# Copyright (c) 2024 Huawei Technologies Co., Ltd. All Rights Reserved. +# This file is a part of the ModelEngine Project. +# Licensed under the MIT License. See License.txt in the project root for license information. +# ====================================================================================================================== +import functools +import json +from typing import List, Any, Optional, Callable, Union +from langchain_core.runnables import RunnableConfig +from langchain_core.tools import BaseTool + +from fitframework import fit_logger +from fitframework.core.repo.fitable_register import register_fitable + + +def register_function_tools(tools: List[BaseTool], + config: Optional[RunnableConfig] = None, + **kwargs: Any): + """ + langchain 函数工具注册方法,注册无需 api key 的工具。 + + Args: + tools (List[BaseTool]): 表示 langchain 工具列表。 + config (Optional[RunnableConfig]): 表示 langchain runnable 配置信息。 + **kwargs (Any): 表示额外参数。 + """ + for tool in tools: + register_api_tools(tool, [], tool.name, config, **kwargs) + + +def _pop_api_keys(input_args: dict, api_keys_name: List[str]) -> dict: + if all(key in input_args.keys() for key in api_keys_name): + api_keys_values = [input_args.pop(key) for key in api_keys_name] + return dict(zip(api_keys_name, api_keys_values)) + else: + raise ValueError(f"{input_args} not contain all api keys in {api_keys_name}") + + +def _invoke(input_args: dict, tool_builder: Union[Callable[[dict], BaseTool], BaseTool], + extra_keys: List[str], + config: Optional[RunnableConfig] = None, + **kwargs: Any) -> str: + api_keys = _pop_api_keys(input_args, extra_keys) + if "__arg1" in input_args: + _input_args = input_args["__arg1"] + else: + _input_args = input_args + + if isinstance(tool_builder, BaseTool): + tool = tool_builder + else: + tool = tool_builder(api_keys) + + try: + tool_ans = tool.invoke(_input_args, config, **kwargs) + return _dump_ans_to_str(tool_ans) + except BaseException: + return "" + + +def _dump_ans_to_str(tool_ans): + if not isinstance(tool_ans, str): + try: + content = json.dumps(tool_ans, ensure_ascii=False) + except Exception: + content = str(tool_ans) + else: + content = tool_ans + return content + + +def register_api_tools(tool_builder: Union[Callable[[dict], BaseTool], BaseTool], + extra_keys: List[str], + tool_name: str, + config: Optional[RunnableConfig] = None, + **kwargs: Any): + """ + langchain api 工具注册方法。 + + Args: + tool_builder (Callable[[dict], BaseTool]): 表示 api 工具构造器。 + extra_keys (List[str]): 表示工具的额外参数,如 api key。 + tool_name (str): 工具名称。 + config (Optional[RunnableConfig]): 表示 langchain runnable 配置信息。 + **kwargs (Any): 表示额外参数。 + """ + tool_invoke = functools.partial(_invoke, tool_builder=tool_builder, extra_keys=extra_keys, + config=config, **kwargs) + tool_invoke.__module__ = register_api_tools.__module__ + tool_invoke.__annotations__ = { + 'input_args': dict, + 'return': str + } + generic_id = 'langchain.tool' + register_fitable(generic_id, tool_name, False, [], tool_invoke) + fit_logger.info("register: generic_id = %s, fitable_id = %s", generic_id, tool_name, stacklevel=2) diff --git a/framework/fel/python/fel_langchain/langchain_schema_helper.py b/framework/fel/python/fel_langchain/langchain_schema_helper.py new file mode 100644 index 00000000..dbc84efc --- /dev/null +++ b/framework/fel/python/fel_langchain/langchain_schema_helper.py @@ -0,0 +1,30 @@ +# -- encoding: utf-8 -- +# Copyright (c) 2024 Huawei Technologies Co., Ltd. All Rights Reserved. +# This file is a part of the ModelEngine Project. +# Licensed under the MIT License. See License.txt in the project root for license information. +# ====================================================================================================================== +import json +import os +import stat +from typing import List + +from langchain_core.tools import BaseTool +from langchain_core.utils.function_calling import convert_to_openai_function + + +def dump_schema(function_tools: List[BaseTool], file_path: str): + """ + 导出 langchain 函数工具 schema 的工具方法。 + + Args: + function_tools (List[BaseTool]): 表示 langchain 工具列表。 + file_path (str): 表示 schema 文件的导出路径。 + """ + tools_schema = [{ + "runnables": {"langchain": {"genericableId": "langchain.tool", "fitableId": f"{tool.name}"}}, + "schema": {**convert_to_openai_function(tool), "return": {"type": "string"}} + } for tool in function_tools] + + fd = os.open(file_path, os.O_RDWR | os.O_CREAT, stat.S_IWUSR | stat.S_IRUSR) + with os.fdopen(fd, "w") as file: + json.dump({"tools": tools_schema}, file) \ No newline at end of file diff --git a/framework/fel/python/fel_llama_index/__init__.py b/framework/fel/python/fel_llama_index/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/framework/fel/python/fel_llama_index/llama_schema_helper.py b/framework/fel/python/fel_llama_index/llama_schema_helper.py new file mode 100644 index 00000000..55e668fc --- /dev/null +++ b/framework/fel/python/fel_llama_index/llama_schema_helper.py @@ -0,0 +1,113 @@ +# -- encoding: utf-8 -- +# Copyright (c) 2024 Huawei Technologies Co., Ltd. All Rights Reserved. +# This file is a part of the ModelEngine Project. +# Licensed under the MIT License. See License.txt in the project root for license information. +# ====================================================================================================================== +import json +import os +import re +import stat +from inspect import Parameter +from inspect import signature +from typing import List, Tuple, Any, Callable, Optional + +from llama_index.core.bridge.pydantic import FieldInfo, create_model +from llama_index.core.tools import FunctionTool + + +def __get_ref_item(value: dict, definitions: dict) -> dict: + sub_properties_name = re.findall("^#/definitions/(.+)$", value.get("$ref")) + if len(sub_properties_name) == 0: + raise ValueError(f"Invalid reference properties {value.get('$ref')}.") + ref_item = definitions.get(sub_properties_name[0]) + ref_item["properties"] = __flat_properties(ref_item.get("properties"), definitions) + return ref_item + + +def __flat_properties(properties: dict, definitions: dict) -> dict: + if definitions is None: + return properties + flat_properties = dict() + for key, value in properties.items(): + if value.__contains__("$ref"): + flat_properties[key] = __get_ref_item(value, definitions) + continue + array_item = value.get("items") + if array_item is not None and array_item.__contains__("$ref"): + value["items"] = __get_ref_item(array_item, definitions) + flat_properties[key] = value + continue + else: + flat_properties[key] = value + return flat_properties + + +def __get_return_properties(func: Callable[..., Any], return_description: str) -> dict: + func_signature = signature(func) + param_type = func_signature.return_annotation + if param_type is Parameter.empty: + param_type = Any + + fields = {return_description: (param_type, FieldInfo())} + field_model = create_model(return_description, **fields) + parameters = field_model.schema() + parameters = { + key: value + for key, value in parameters.items() + if key in ["type", "properties", "required", "definitions"] + } + properties = __flat_properties(parameters.get("properties"), parameters.get("definitions")) + if return_description in properties: + return properties[return_description] + else: + return dict() + + +def __get_llama_rag_tool_schema(tool: Tuple[Callable[..., Any], List[str], str]) -> dict: + func = tool[0] + metadata = FunctionTool.from_defaults(fn=func).metadata + parameters_dict = metadata.get_parameters_dict() + property_key = "properties" + parameters_dict.get(property_key).pop("kwargs") + + dynamic_args = tool[1] + dynamic_args_dict = dict() + for arg in dynamic_args: + dynamic_args_dict[arg] = {"type": "string", "description": arg} + + definition = __get_param_definition(parameters_dict) + flat_properties = __flat_properties(parameters_dict.get(property_key), definition) + parameters_dict[property_key] = {**flat_properties, **dynamic_args_dict} + tool_schema = { + "name": metadata.name, + "description": func.__doc__, + "parameters": parameters_dict, + "return": __get_return_properties(func, tool[2]), + } + if len(dynamic_args_dict) != 0: + tool_schema["parameterExtensions"] = {"config": list(dynamic_args_dict.keys())} + return tool_schema + + +def __get_param_definition(parameters_dict: dict) -> Optional[dict]: + if parameters_dict.__contains__("definitions"): + return parameters_dict.pop("definitions") + return None + + +def dump_llama_schema(llama_toolkit: List[Tuple[Callable[..., Any], List[str], str]], file_path: str): + """ + 导出 LlamaIndex 函数工具 schema 的工具方法。 + + Args: + llama_toolkit (List[Tuple[Callable[..., Any], List[str]]]): 表示 llama_index rag 工具列表。 + file_path (str): 表示 schema 文件的导出路径。 + """ + tools_schema = [{ + "runnables": {"LlamaIndex": {"genericableId": "llama_index.rag.toolkit", "fitableId": f"{tool[0].__name__}"}}, + "schema": {**__get_llama_rag_tool_schema(tool)} + } for tool in llama_toolkit] + + fd = os.open(file_path, os.O_RDWR | os.O_CREAT, stat.S_IWUSR | stat.S_IRUSR) + with os.fdopen(fd, "w") as file: + json.dump({"tools": tools_schema}, file) diff --git a/framework/fel/python/fel_llama_index/node_utils.py b/framework/fel/python/fel_llama_index/node_utils.py new file mode 100644 index 00000000..47b503d7 --- /dev/null +++ b/framework/fel/python/fel_llama_index/node_utils.py @@ -0,0 +1,49 @@ +# -- encoding: utf-8 -- +# Copyright (c) 2024 Huawei Technologies Co., Ltd. All Rights Reserved. +# This file is a part of the ModelEngine Project. +# Licensed under the MIT License. See License.txt in the project root for license information. +# ====================================================================================================================== +from llama_index.core.multi_modal_llms.generic_utils import encode_image +from llama_index.core.schema import ImageNode, TextNode, NodeWithScore + +from fel_core.types.document import Document +from fel_core.types.media import Media + + +def document_to_query_node(doc_input: Document): + if isinstance(doc_input, dict): + doc = Document(**doc_input) + else: + doc = doc_input + + if doc.media is not None: + node = ImageNode(image=doc.media.data, image_mimetype=doc.media.mime) + else: + node = TextNode() + node.set_content(doc.content) + node.metadata = doc.metadata + return NodeWithScore(node=node, score=doc.metadata["score"]) + + +def query_node_to_document(node_with_score: NodeWithScore) -> Document: + node = node_with_score.node + metadata = node.metadata or {} + metadata['score'] = node_with_score.score + content = None + image = None + file_path_key = "file_path" + if isinstance(node, ImageNode): + mime = node.image_mimetype or "image/jpeg" + data = None + if node.image and node.image != "": + data = node.image + elif node.image_url and node.image_url != "": + data = node.image_url + elif node.image_path and node.image_path != "": + data = encode_image(node.image_path) + elif file_path_key in node.metadata and node.metadata[file_path_key] != "": + data = encode_image(node.metadata[file_path_key]) + image = Media(mime=mime, data=data) + if isinstance(node, TextNode): + content = node.get_content() + return Document(content=content, media=image, metadata=metadata) \ No newline at end of file diff --git a/framework/fel/python/plugins/fel_langchain_loader_tools/callable_registers.py b/framework/fel/python/plugins/fel_langchain_loader_tools/callable_registers.py new file mode 100644 index 00000000..0cde3122 --- /dev/null +++ b/framework/fel/python/plugins/fel_langchain_loader_tools/callable_registers.py @@ -0,0 +1,29 @@ +# -- encoding: utf-8 -- +# Copyright (c) 2024 Huawei Technologies Co., Ltd. All Rights Reserved. +# This file is a part of the ModelEngine Project. +# Licensed under the MIT License. See License.txt in the project root for license information. +# ====================================================================================================================== +import functools +from inspect import signature +from typing import Callable, Any, Tuple, List + +from fitframework import fit_logger +from fitframework.core.repo.fitable_register import register_fitable + + +def __invoke_tool(input_args: dict, tool_func: Callable[..., Any], **kwargs) -> Any: + return tool_func(**input_args, **kwargs) + + +def register_callable_tool(tool: Tuple[Callable[..., Any], List[str], str], module: str, generic_id: str): + func = tool[0] + fitable_id = f"{func.__name__}" + + tool_invoke = functools.partial(__invoke_tool, tool_func=func) + tool_invoke.__module__ = module + tool_invoke.__annotations__ = { + 'input_args': dict, + 'return': signature(func).return_annotation + } + register_fitable(generic_id, fitable_id, False, [], tool_invoke) + fit_logger.info("register: generic_id = %s, fitable_id = %s", generic_id, fitable_id, stacklevel=2) diff --git a/framework/fel/python/plugins/fel_langchain_loader_tools/document_util.py b/framework/fel/python/plugins/fel_langchain_loader_tools/document_util.py new file mode 100644 index 00000000..b12158a2 --- /dev/null +++ b/framework/fel/python/plugins/fel_langchain_loader_tools/document_util.py @@ -0,0 +1,11 @@ +# -- encoding: utf-8 -- +# Copyright (c) 2024 Huawei Technologies Co., Ltd. All Rights Reserved. +# This file is a part of the ModelEngine Project. +# Licensed under the MIT License. See License.txt in the project root for license information. +# ====================================================================================================================== +import langchain_core.documents +from .types.document import Document + + +def langchain_doc_to_document(doc: langchain_core.documents.Document) -> Document: + return Document(content=doc.page_content, metadata=dict()) \ No newline at end of file diff --git a/framework/fel/python/plugins/fel_langchain_loader_tools/langchain_loader_tools.py b/framework/fel/python/plugins/fel_langchain_loader_tools/langchain_loader_tools.py new file mode 100644 index 00000000..1e36da9a --- /dev/null +++ b/framework/fel/python/plugins/fel_langchain_loader_tools/langchain_loader_tools.py @@ -0,0 +1,103 @@ +# -- encoding: utf-8 -- +# Copyright (c) 2024 Huawei Technologies Co., Ltd. All Rights Reserved. +# This file is a part of the ModelEngine Project. +# Licensed under the MIT License. See License.txt in the project root for license information. +# ====================================================================================================================== +import traceback +from typing import List, Callable, Tuple, Any +from urllib.parse import urlparse, parse_qs + +from fitframework import fit_logger +from langchain_community.document_loaders import PyPDFLoader, PDFPlumberLoader, PyMuPDFLoader, PyPDFDirectoryLoader, \ + PyPDFium2Loader, PDFMinerLoader +from langchain_core.document_loaders import BaseLoader + +from .types.document import Document +from .document_util import langchain_doc_to_document +from .callable_registers import register_callable_tool + + +def py_pdf_loader(file_path: str, **kwargs) -> List[Document]: + """Load PDF using pypdf into list of documents.""" + return __loader_handler(lambda nfs_file_path: PyPDFLoader(nfs_file_path), file_path) + + +def pdfplumber_loader(file_path: str, **kwargs) -> List[Document]: + """Load PDF using pdfplumber into list of documents""" + return __loader_handler(lambda nfs_file_path: PDFPlumberLoader(nfs_file_path), file_path) + + +def py_mupdf_loader(file_path: str, **kwargs) -> List[Document]: + """Load PDF using PyMuPDF into list of documents""" + return __loader_handler(lambda nfs_file_path: PyMuPDFLoader(nfs_file_path), file_path) + + +def py_pdfium2_loader(file_path: str, **kwargs) -> List[Document]: + """Load PDF using pypdfium2 into list of documents""" + return __loader_handler(lambda nfs_file_path: PyPDFium2Loader(nfs_file_path), file_path) + + +def py_miner_loader(file_path: str, **kwargs) -> List[Document]: + """Load PDF using PDFMiner into list of documents""" + return __loader_handler(lambda nfs_file_path: PDFMinerLoader(nfs_file_path), file_path) + + +def py_pdf_directory_loader(directory: str, **kwargs) -> List[Document]: + """Load a directory with `PDF` files using `pypdf` and chunks at character level""" + return __loader_handler(lambda nfs_file_dir: PyPDFDirectoryLoader(nfs_file_dir), directory) + + +def __loader_handler(loader_builder: Callable[[str], BaseLoader], file_url: str) -> List[Document]: + try: + # 解析文件路径 + fit_logger.info("file_url: " + file_url) + nfs_file_path = get_file_path(file_url) + fit_logger.info("nfs_file_path: " + nfs_file_path) + pdf_loader = loader_builder(nfs_file_path) + iterator = pdf_loader.lazy_load() + res = [] + max_page = 300 + for doc in iterator: + if len(res) > max_page: + return res + res.append(langchain_doc_to_document(doc)) + return res + except BaseException: + fit_logger.error("Invoke file loader failed.") + fit_logger.exception("Invoke file loader failed.") + traceback.print_exc() + return [] + + +def get_file_path(file_url: str): + try: + parsed_url = urlparse(file_url) + if not all([parsed_url.scheme, parsed_url.netloc]): + return file_url + file_query_param = parse_qs(parsed_url.query).get('filePath') + if file_query_param is None or len(file_query_param) == 0: + msg = "Invalid file url. missing query parameter [filePath]" + fit_logger.error(msg) + raise ValueError(msg) + else: + return file_query_param[0] + except BaseException: + fit_logger.error("Parse file path failed.") + return file_url + + +DOCUMENT_RETURN_DESC = "a piece of text and associated metadata." + +# 普通callable注册方式 +# Tuple 结构: (tool_func, config_args, return_description) +loader_toolkit: List[Tuple[Callable[..., Any], List[str], str]] = [ + (py_pdf_loader, [], DOCUMENT_RETURN_DESC), + (pdfplumber_loader, [], DOCUMENT_RETURN_DESC), + (py_mupdf_loader, [], DOCUMENT_RETURN_DESC), + (py_pdfium2_loader, [], DOCUMENT_RETURN_DESC), + (py_miner_loader, [], DOCUMENT_RETURN_DESC), + (py_pdf_directory_loader, [], DOCUMENT_RETURN_DESC), +] + +for tool in loader_toolkit: + register_callable_tool(tool, get_file_path.__module__, 'langchain.tool') diff --git a/framework/fel/python/plugins/fel_langchain_loader_tools/tools.json b/framework/fel/python/plugins/fel_langchain_loader_tools/tools.json new file mode 100644 index 00000000..22e93a57 --- /dev/null +++ b/framework/fel/python/plugins/fel_langchain_loader_tools/tools.json @@ -0,0 +1,418 @@ +{ + "tools": [ + { + "tags": [ + "langchain" + ], + "runnables": { + "langchain": { + "genericableId": "langchain.tool", + "fitableId": "py_pdf_loader" + } + }, + "schema": { + "name": "py_pdf_loader", + "description": "Load PDF using pypdf into list of documents.", + "parameters": { + "type": "object", + "properties": { + "file_path": { + "title": "File Path", + "type": "string" + } + }, + "required": [ + "file_path" + ] + }, + "return": { + "title": "A Piece Of Text And Associated Metadata.", + "type": "array", + "items": { + "title": "Document", + "description": "Document.", + "type": "object", + "properties": { + "content": { + "title": "Content", + "type": "string" + }, + "media": { + "title": "Media", + "description": "Media.", + "type": "object", + "properties": { + "mime": { + "title": "Mime", + "type": "string" + }, + "data": { + "title": "Data", + "type": "string" + } + }, + "required": [ + "mime", + "data" + ] + }, + "metadata": { + "title": "Metadata", + "type": "object" + } + }, + "required": [ + "content", + "metadata" + ] + } + } + } + }, + { + "tags": [ + "langchain" + ], + "runnables": { + "langchain": { + "genericableId": "langchain.tool", + "fitableId": "pdfplumber_loader" + } + }, + "schema": { + "name": "pdfplumber_loader", + "description": "Load PDF using pdfplumber into list of documents", + "parameters": { + "type": "object", + "properties": { + "file_path": { + "title": "File Path", + "type": "string" + } + }, + "required": [ + "file_path" + ] + }, + "return": { + "title": "A Piece Of Text And Associated Metadata.", + "type": "array", + "items": { + "title": "Document", + "description": "Document.", + "type": "object", + "properties": { + "content": { + "title": "Content", + "type": "string" + }, + "media": { + "title": "Media", + "description": "Media.", + "type": "object", + "properties": { + "mime": { + "title": "Mime", + "type": "string" + }, + "data": { + "title": "Data", + "type": "string" + } + }, + "required": [ + "mime", + "data" + ] + }, + "metadata": { + "title": "Metadata", + "type": "object" + } + }, + "required": [ + "content", + "metadata" + ] + } + } + } + }, + { + "tags": [ + "langchain" + ], + "runnables": { + "langchain": { + "genericableId": "langchain.tool", + "fitableId": "py_mupdf_loader" + } + }, + "schema": { + "name": "py_mupdf_loader", + "description": "Load PDF using PyMuPDF into list of documents", + "parameters": { + "type": "object", + "properties": { + "file_path": { + "title": "File Path", + "type": "string" + } + }, + "required": [ + "file_path" + ] + }, + "return": { + "title": "A Piece Of Text And Associated Metadata.", + "type": "array", + "items": { + "title": "Document", + "description": "Document.", + "type": "object", + "properties": { + "content": { + "title": "Content", + "type": "string" + }, + "media": { + "title": "Media", + "description": "Media.", + "type": "object", + "properties": { + "mime": { + "title": "Mime", + "type": "string" + }, + "data": { + "title": "Data", + "type": "string" + } + }, + "required": [ + "mime", + "data" + ] + }, + "metadata": { + "title": "Metadata", + "type": "object" + } + }, + "required": [ + "content", + "metadata" + ] + } + } + } + }, + { + "tags": [ + "langchain" + ], + "runnables": { + "langchain": { + "genericableId": "langchain.tool", + "fitableId": "py_pdfium2_loader" + } + }, + "schema": { + "name": "py_pdfium2_loader", + "description": "Load PDF using pypdfium2 into list of documents", + "parameters": { + "type": "object", + "properties": { + "file_path": { + "title": "File Path", + "type": "string" + } + }, + "required": [ + "file_path" + ] + }, + "return": { + "title": "A Piece Of Text And Associated Metadata.", + "type": "array", + "items": { + "title": "Document", + "description": "Document.", + "type": "object", + "properties": { + "content": { + "title": "Content", + "type": "string" + }, + "media": { + "title": "Media", + "description": "Media.", + "type": "object", + "properties": { + "mime": { + "title": "Mime", + "type": "string" + }, + "data": { + "title": "Data", + "type": "string" + } + }, + "required": [ + "mime", + "data" + ] + }, + "metadata": { + "title": "Metadata", + "type": "object" + } + }, + "required": [ + "content", + "metadata" + ] + } + } + } + }, + { + "tags": [ + "langchain" + ], + "runnables": { + "langchain": { + "genericableId": "langchain.tool", + "fitableId": "py_miner_loader" + } + }, + "schema": { + "name": "py_miner_loader", + "description": "Load PDF using PDFMiner into list of documents", + "parameters": { + "type": "object", + "properties": { + "file_path": { + "title": "File Path", + "type": "string" + } + }, + "required": [ + "file_path" + ] + }, + "return": { + "title": "A Piece Of Text And Associated Metadata.", + "type": "array", + "items": { + "title": "Document", + "description": "Document.", + "type": "object", + "properties": { + "content": { + "title": "Content", + "type": "string" + }, + "media": { + "title": "Media", + "description": "Media.", + "type": "object", + "properties": { + "mime": { + "title": "Mime", + "type": "string" + }, + "data": { + "title": "Data", + "type": "string" + } + }, + "required": [ + "mime", + "data" + ] + }, + "metadata": { + "title": "Metadata", + "type": "object" + } + }, + "required": [ + "content", + "metadata" + ] + } + } + } + }, + { + "tags": [ + "langchain" + ], + "runnables": { + "langchain": { + "genericableId": "langchain.tool", + "fitableId": "py_pdf_directory_loader" + } + }, + "schema": { + "name": "py_pdf_directory_loader", + "description": "Load a directory with `PDF` files using `pypdf` and chunks at character level", + "parameters": { + "type": "object", + "properties": { + "directory": { + "title": "Directory", + "type": "string" + } + }, + "required": [ + "directory" + ] + }, + "return": { + "title": "A Piece Of Text And Associated Metadata.", + "type": "array", + "items": { + "title": "Document", + "description": "Document.", + "type": "object", + "properties": { + "content": { + "title": "Content", + "type": "string" + }, + "media": { + "title": "Media", + "description": "Media.", + "type": "object", + "properties": { + "mime": { + "title": "Mime", + "type": "string" + }, + "data": { + "title": "Data", + "type": "string" + } + }, + "required": [ + "mime", + "data" + ] + }, + "metadata": { + "title": "Metadata", + "type": "object" + } + }, + "required": [ + "content", + "metadata" + ] + } + } + } + } + ] +} \ No newline at end of file diff --git a/framework/fel/python/plugins/fel_langchain_loader_tools/types/__init__.py b/framework/fel/python/plugins/fel_langchain_loader_tools/types/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/framework/fel/python/plugins/fel_langchain_loader_tools/types/document.py b/framework/fel/python/plugins/fel_langchain_loader_tools/types/document.py new file mode 100644 index 00000000..4989999f --- /dev/null +++ b/framework/fel/python/plugins/fel_langchain_loader_tools/types/document.py @@ -0,0 +1,22 @@ +# -- encoding: utf-8 -- +# Copyright (c) 2024 Huawei Technologies Co., Ltd. All Rights Reserved. +# This file is a part of the ModelEngine Project. +# Licensed under the MIT License. See License.txt in the project root for license information. +# ====================================================================================================================== +import typing + +from .serializable import Serializable +from .media import Media + + +class Document(Serializable): + """ + Document. + """ + content: str + media: Media = None + metadata: typing.Dict[str, object] + + class Config: + frozen = True + smart_union = True diff --git a/framework/fel/python/plugins/fel_langchain_loader_tools/types/media.py b/framework/fel/python/plugins/fel_langchain_loader_tools/types/media.py new file mode 100644 index 00000000..b1bdb54a --- /dev/null +++ b/framework/fel/python/plugins/fel_langchain_loader_tools/types/media.py @@ -0,0 +1,18 @@ +# -- encoding: utf-8 -- +# Copyright (c) 2024 Huawei Technologies Co., Ltd. All Rights Reserved. +# This file is a part of the ModelEngine Project. +# Licensed under the MIT License. See License.txt in the project root for license information. +# ====================================================================================================================== +from .serializable import Serializable + + +class Media(Serializable): + """ + Media. + """ + mime: str + data: str + + class Config: + frozen = True + smart_union = True diff --git a/framework/fel/python/plugins/fel_langchain_loader_tools/types/serializable.py b/framework/fel/python/plugins/fel_langchain_loader_tools/types/serializable.py new file mode 100644 index 00000000..4522897f --- /dev/null +++ b/framework/fel/python/plugins/fel_langchain_loader_tools/types/serializable.py @@ -0,0 +1,25 @@ +# -- encoding: utf-8 -- +# Copyright (c) 2024 Huawei Technologies Co., Ltd. All Rights Reserved. +# This file is a part of the ModelEngine Project. +# Licensed under the MIT License. See License.txt in the project root for license information. +# ====================================================================================================================== +import typing + +try: + import pydantic + + if pydantic.__version__.startswith("1."): + raise ImportError + import pydantic.v1 as pydantic +except ImportError: + import pydantic + + +class Serializable(pydantic.BaseModel): + def json(self, **kwargs: typing.Any) -> str: + kwargs_with_defaults: typing.Any = {"by_alias": True, "exclude_unset": True, **kwargs} + return super().json(**kwargs_with_defaults) + + def dict(self, **kwargs: typing.Any) -> typing.Dict[str, typing.Any]: + kwargs_with_defaults: typing.Any = {"by_alias": True, "exclude_unset": True, **kwargs} + return super().dict(**kwargs_with_defaults) \ No newline at end of file diff --git a/framework/fel/python/plugins/fel_langchain_network_tools/callable_registers.py b/framework/fel/python/plugins/fel_langchain_network_tools/callable_registers.py new file mode 100644 index 00000000..0cde3122 --- /dev/null +++ b/framework/fel/python/plugins/fel_langchain_network_tools/callable_registers.py @@ -0,0 +1,29 @@ +# -- encoding: utf-8 -- +# Copyright (c) 2024 Huawei Technologies Co., Ltd. All Rights Reserved. +# This file is a part of the ModelEngine Project. +# Licensed under the MIT License. See License.txt in the project root for license information. +# ====================================================================================================================== +import functools +from inspect import signature +from typing import Callable, Any, Tuple, List + +from fitframework import fit_logger +from fitframework.core.repo.fitable_register import register_fitable + + +def __invoke_tool(input_args: dict, tool_func: Callable[..., Any], **kwargs) -> Any: + return tool_func(**input_args, **kwargs) + + +def register_callable_tool(tool: Tuple[Callable[..., Any], List[str], str], module: str, generic_id: str): + func = tool[0] + fitable_id = f"{func.__name__}" + + tool_invoke = functools.partial(__invoke_tool, tool_func=func) + tool_invoke.__module__ = module + tool_invoke.__annotations__ = { + 'input_args': dict, + 'return': signature(func).return_annotation + } + register_fitable(generic_id, fitable_id, False, [], tool_invoke) + fit_logger.info("register: generic_id = %s, fitable_id = %s", generic_id, fitable_id, stacklevel=2) diff --git a/framework/fel/python/plugins/fel_langchain_network_tools/langchain_network_tool.py b/framework/fel/python/plugins/fel_langchain_network_tools/langchain_network_tool.py new file mode 100644 index 00000000..f1867a5a --- /dev/null +++ b/framework/fel/python/plugins/fel_langchain_network_tools/langchain_network_tool.py @@ -0,0 +1,183 @@ +# -- encoding: utf-8 -- +# Copyright (c) 2024 Huawei Technologies Co., Ltd. All Rights Reserved. +# This file is a part of the ModelEngine Project. +# Licensed under the MIT License. See License.txt in the project root for license information. +# ====================================================================================================================== +import os +import time +from typing import List, Any, Callable, Tuple + +from langchain_community.retrievers import ArxivRetriever +from langchain_community.tools import WikipediaQueryRun, DuckDuckGoSearchRun, YouTubeSearchTool, GoogleSearchRun, \ + PubmedQueryRun, GooglePlacesTool, BraveSearch, MojeekSearch +from langchain_community.tools.google_jobs import GoogleJobsQueryRun +from langchain_community.tools.google_scholar import GoogleScholarQueryRun +from langchain_community.tools.reddit_search.tool import RedditSearchRun, RedditSearchSchema +from langchain_community.tools.wikidata.tool import WikidataQueryRun +from langchain_community.utilities import WikipediaAPIWrapper, GoogleSearchAPIWrapper, GoogleSerperAPIWrapper, \ + WolframAlphaAPIWrapper, GoogleJobsAPIWrapper, GoogleScholarAPIWrapper, BingSearchAPIWrapper, \ + GoldenQueryAPIWrapper, SearxSearchWrapper, SerpAPIWrapper, TwilioAPIWrapper +from langchain_community.utilities.reddit_search import RedditSearchAPIWrapper +from langchain_community.utilities.wikidata import WikidataAPIWrapper +from langchain_core.documents import Document + +from .callable_registers import register_callable_tool + + +def langchain_network(**kwargs) -> str: + time.sleep(5) + return "" + + +def arxiv(arxiv_id: str, **kwargs) -> List[str]: + retriever = ArxivRetriever(load_max_docs=2) + docs: List[Document] = retriever.get_relevant_documents(query=arxiv_id) + return [doc.page_content for doc in docs] + + +def bing_search(query: str, bing_subscription_key: str, bing_search_url: str, **kwargs) -> str: + os.environ["BING_SUBSCRIPTION_KEY"] = bing_subscription_key + os.environ["BING_SEARCH_URL"] = bing_search_url + search = BingSearchAPIWrapper() + return search.run(query) + + +def brave_search(query: str, count: int, api_key: str, **kwargs) -> str: + brave_search_ = BraveSearch.from_api_key(api_key=api_key, search_kwargs={"count": count}) + return brave_search_.run(query) + + +def duck_duck_go_search(query: str, **kwargs) -> str: + search = DuckDuckGoSearchRun() + return search.invoke(query) + + +def google_jobs(query: str, serapi_api_key: str, **kwargs) -> str: + os.environ["SERPAPI_API_KEY"] = serapi_api_key + google_job_tool = GoogleJobsQueryRun(api_wrapper=GoogleJobsAPIWrapper()) + return google_job_tool.run(query) + + +def google_places(query: str, gplaces_api_key: str, **kwargs) -> str: + os.environ["GPLACES_API_KEY"] = gplaces_api_key + places = GooglePlacesTool() + return places.run(query) + + +def google_scholar(query: str, serp_api_key: str, **kwargs) -> str: + os.environ["SERP_API_KEY"] = serp_api_key + google_job_tool = GoogleScholarQueryRun(api_wrapper=GoogleScholarAPIWrapper()) + return google_job_tool.run(query) + + +def google_search(query: str, google_api_key: str, google_cse_id: str, k: int, siterestrict: bool, **kwargs) -> str: + wrapper = GoogleSearchAPIWrapper(google_api_key=google_api_key, google_cse_id=google_cse_id, k=k, + siterestrict=siterestrict) + search = GoogleSearchRun(api_wrapper=wrapper) + return search.run(query) + + +def google_serper(query: str, k: int, gl: str, hl: str, serper_api_key: str, **kwargs) -> str: + os.environ["SERPER_API_KEY"] = serper_api_key + search = GoogleSerperAPIWrapper(k=k, gl=gl, hl=hl) + return search.run(query) + + +def golden_query(query: str, golden_api_key: str, **kwargs) -> str: + os.environ["GOLDEN_API_KEY"] = golden_api_key + golden_query_api = GoldenQueryAPIWrapper() + return golden_query_api.run(query) + + +def pub_med(query: str) -> str: + pub_med_tool: PubmedQueryRun = PubmedQueryRun() + return pub_med_tool.invoke(query) + + +def mojeek_query(query: str, api_key: str) -> str: + search = MojeekSearch.config(api_key=api_key) + return search.run(query) + + +def reddit_search(query: str, sort: str, time_filter: str, subreddit: str, limit: str, client_id: str, + client_secret: str, user_agent: str) -> str: + search = RedditSearchRun( + api_wrapper=RedditSearchAPIWrapper( + reddit_client_id=client_id, + reddit_client_secret=client_secret, + reddit_user_agent=user_agent, + ) + ) + search_params = RedditSearchSchema(query=query, sort=sort, time_filter=time_filter, subreddit=subreddit, + limit=limit) + result = search.run(tool_input=search_params.dict()) + return result + + +def searxng_search(query: str, searx_host: str) -> str: + search = SearxSearchWrapper(searx_host=searx_host) + return search.run(query) + + +def serp_api(query: str, serpapi_api_key: str) -> str: + search = SerpAPIWrapper(serpapi_api_key=serpapi_api_key) + return search.run(query) + + +def twilio(body: str, to: str, account_sid: str, auth_token: str, from_number: str) -> str: + twilio_api = TwilioAPIWrapper( + account_sid=account_sid, + auth_token=auth_token, + from_number=from_number + ) + return twilio_api.run(body, to) + + +def wikidata(query: str) -> str: + wikidata_query = WikidataQueryRun(api_wrapper=WikidataAPIWrapper()) + return wikidata_query.run(query) + + +def wikipedia(query: str, **kwargs) -> str: + wikipedia_query_run = WikipediaQueryRun(api_wrapper=WikipediaAPIWrapper()) + return wikipedia_query_run.run(query) + + +def wolfram_alpha(query: str, wolfram_alpha_appid: str) -> str: + wolfram = WolframAlphaAPIWrapper(wolfram_alpha_appid=wolfram_alpha_appid) + return wolfram.run(query) + + +def youtube_search(query: str, **kwargs) -> str: + youtube_search_tool = YouTubeSearchTool() + return youtube_search_tool.run(query) + + +# Tuple 结构: (tool_func, config_args, return_description) +network_toolkit: List[Tuple[Callable[..., Any], List[str], str]] = [ + (langchain_network, ["input"], "Youtube search."), + (arxiv, ["arxiv_id"], "ArXiv search."), + (bing_search, ["query", "bing_subscription_key", "bing_search_url"], "Bing search."), + (brave_search, ["query", "count", "api_key"], "Brave search."), + (duck_duck_go_search, ["query"], "DuckDuckGo Search."), + (google_jobs, ["query", "serapi_api_key"], "Google Jobs."), + (google_places, ["query", "gplaces_api_key"], "Google Places."), + (google_scholar, ["query", "serp_api_key"], "Google Scholar."), + (google_search, ["query", "google_api_key", "google_cse_id", "k", "siterestrict"], "Google Search."), + (google_serper, ["query", "serper_api_key", "k", "gl", "hl"], "Google Serper."), + (golden_query, ["query", "golden_api_key"], "Golden Query."), + (pub_med, ["query"], "PubMed."), + (mojeek_query, ["query", "api_key"], "Mojeek Search."), + (reddit_search, ["query", "sort", "time_filter", "subreddit", "limit", "client_id", "client_secret", "user_agent"], + "Reddit Search."), + (searxng_search, ["query", "searx_host"], "SearxNG Search."), + (serp_api, ["query", "serpapi_api_key"], "SerpAPI."), + (twilio, ["body", "to", "account_sid", "auth_token", "from_number"], "Twilio."), + (wikidata, ["query"], "Wikidata."), + (wikipedia, ["query"], "Wikipedia."), + (wolfram_alpha, ["query", "wolfram_alpha_appid"], "Wolfram Alpha."), + (youtube_search, ["query"], "Youtube Search."), +] + +for tool in network_toolkit: + register_callable_tool(tool, langchain_network.__module__, "langchain.tool") diff --git a/framework/fel/python/plugins/fel_langchain_tools/langchain_registers.py b/framework/fel/python/plugins/fel_langchain_tools/langchain_registers.py new file mode 100644 index 00000000..b7fe0c4d --- /dev/null +++ b/framework/fel/python/plugins/fel_langchain_tools/langchain_registers.py @@ -0,0 +1,96 @@ +# -- encoding: utf-8 -- +# Copyright (c) 2024 Huawei Technologies Co., Ltd. All Rights Reserved. +# This file is a part of the ModelEngine Project. +# Licensed under the MIT License. See License.txt in the project root for license information. +# ====================================================================================================================== +import functools +import json +from typing import List, Any, Optional, Callable, Union +from langchain_core.runnables import RunnableConfig +from langchain_core.tools import BaseTool + +from fitframework import fit_logger +from fitframework.core.repo.fitable_register import register_fitable + + +def register_function_tools(tools: List[BaseTool], + config: Optional[RunnableConfig] = None, + **kwargs: Any): + """ + langchain 函数工具注册方法,注册无需 api key 的工具。 + + Args: + tools (List[BaseTool]): 表示 langchain 工具列表。 + config (Optional[RunnableConfig]): 表示 langchain runnable 配置信息。 + **kwargs (Any): 表示额外参数。 + """ + for tool in tools: + register_api_tools(tool, [], tool.name, config, **kwargs) + + +def _pop_api_keys(input_args: dict, api_keys_name: List[str]) -> dict: + if all(key in input_args.keys() for key in api_keys_name): + api_keys_values = [input_args.pop(key) for key in api_keys_name] + return dict(zip(api_keys_name, api_keys_values)) + else: + raise ValueError(f"{input_args} not contain all api keys in {api_keys_name}") + + +def _invoke(input_args: dict, tool_builder: Union[Callable[[dict], BaseTool], BaseTool], + extra_keys: List[str], + config: Optional[RunnableConfig] = None, + **kwargs: Any) -> str: + api_keys = _pop_api_keys(input_args, extra_keys) + if "__arg1" in input_args: + _input_args = input_args["__arg1"] + else: + _input_args = input_args + + if isinstance(tool_builder, BaseTool): + tool = tool_builder + else: + tool = tool_builder(api_keys) + + try: + tool_ans = tool.invoke(_input_args, config, **kwargs) + return _dump_ans_to_str(tool_ans) + except BaseException: + return "" + + +def _dump_ans_to_str(tool_ans): + if not isinstance(tool_ans, str): + try: + content = json.dumps(tool_ans, ensure_ascii=False) + except Exception: + content = str(tool_ans) + else: + content = tool_ans + return content + + +def register_api_tools(tool_builder: Union[Callable[[dict], BaseTool], BaseTool], + extra_keys: List[str], + tool_name: str, + config: Optional[RunnableConfig] = None, + **kwargs: Any): + """ + langchain api 工具注册方法。 + + Args: + tool_builder (Callable[[dict], BaseTool]): 表示 api 工具构造器。 + extra_keys (List[str]): 表示工具的额外参数,如 api key。 + tool_name (str): 工具名称。 + config (Optional[RunnableConfig]): 表示 langchain runnable 配置信息。 + **kwargs (Any): 表示额外参数。 + """ + tool_invoke = functools.partial(_invoke, tool_builder=tool_builder, extra_keys=extra_keys, + config=config, **kwargs) + tool_invoke.__module__ = register_api_tools.__module__ + tool_invoke.__annotations__ = { + 'input_args': dict, + 'return': str + } + generic_id = 'langchain.tool' + register_fitable(generic_id, tool_name, False, [], tool_invoke) + fit_logger.info("register: generic_id = %s, fitable_id = %s", generic_id, tool_name, stacklevel=2) diff --git a/framework/fel/python/plugins/fel_langchain_tools/langchain_schema_helper.py b/framework/fel/python/plugins/fel_langchain_tools/langchain_schema_helper.py new file mode 100644 index 00000000..dbc84efc --- /dev/null +++ b/framework/fel/python/plugins/fel_langchain_tools/langchain_schema_helper.py @@ -0,0 +1,30 @@ +# -- encoding: utf-8 -- +# Copyright (c) 2024 Huawei Technologies Co., Ltd. All Rights Reserved. +# This file is a part of the ModelEngine Project. +# Licensed under the MIT License. See License.txt in the project root for license information. +# ====================================================================================================================== +import json +import os +import stat +from typing import List + +from langchain_core.tools import BaseTool +from langchain_core.utils.function_calling import convert_to_openai_function + + +def dump_schema(function_tools: List[BaseTool], file_path: str): + """ + 导出 langchain 函数工具 schema 的工具方法。 + + Args: + function_tools (List[BaseTool]): 表示 langchain 工具列表。 + file_path (str): 表示 schema 文件的导出路径。 + """ + tools_schema = [{ + "runnables": {"langchain": {"genericableId": "langchain.tool", "fitableId": f"{tool.name}"}}, + "schema": {**convert_to_openai_function(tool), "return": {"type": "string"}} + } for tool in function_tools] + + fd = os.open(file_path, os.O_RDWR | os.O_CREAT, stat.S_IWUSR | stat.S_IRUSR) + with os.fdopen(fd, "w") as file: + json.dump({"tools": tools_schema}, file) \ No newline at end of file diff --git a/framework/fel/python/plugins/fel_langchain_tools/langchain_tools.py b/framework/fel/python/plugins/fel_langchain_tools/langchain_tools.py new file mode 100644 index 00000000..90ce8c4f --- /dev/null +++ b/framework/fel/python/plugins/fel_langchain_tools/langchain_tools.py @@ -0,0 +1,192 @@ +# -- encoding: utf-8 -- +# Copyright (c) 2024 Huawei Technologies Co., Ltd. All Rights Reserved. +# This file is a part of the ModelEngine Project. +# Licensed under the MIT License. See License.txt in the project root for license information. +# ====================================================================================================================== +import json +from urllib.parse import quote_plus +import psycopg2 +from langchain.agents import AgentExecutor + +from langchain_community.utilities.sql_database import SQLDatabase +from langchain_community.agent_toolkits import create_sql_agent +from langchain_community.tools.sql_database.tool import ( + InfoSQLDatabaseTool, + ListSQLDatabaseTool, + QuerySQLCheckerTool, + QuerySQLDataBaseTool, +) +from langchain_core.tools import BaseTool +from langchain_openai import ChatOpenAI, OpenAI +from langchain_community.utilities.requests import TextRequestsWrapper +from langchain_community.agent_toolkits import JsonToolkit, create_json_agent +from langchain_community.tools.json.tool import JsonSpec +from langchain_community.tools.requests.tool import ( + RequestsDeleteTool, + RequestsGetTool, + RequestsPatchTool, + RequestsPostTool, + RequestsPutTool, +) +from .langchain_registers import register_function_tools, register_api_tools + + +# 从app_engine加密传输敏感信息 +def get_db(sql_url: str, sql_table: str, sql_name: str, sql_pwd: str) -> SQLDatabase: + return SQLDatabase.from_uri( + "postgresql+psycopg2://%s:%s@%s/%s" % (quote_plus(sql_name), quote_plus(sql_pwd), sql_url, + quote_plus(sql_table))) + + +def langchain_sql_query(kwargs) -> BaseTool: + db = get_db(kwargs.get("sql_url"), kwargs.get("sql_table"), kwargs.get("sql_name"), kwargs.get("sql_pwd")) + + query_sql_database_tool_description = ( + "Input to this tool is a detailed and correct SQL query, output is a " + "result from the database. If the query is not correct, an error message " + "will be returned. If an error is returned, rewrite the query, check the " + "query, and try again. If you encounter an issue with Unknown column " + "'xxxx' in 'field list', use sql_db_schema " + "to query the correct table fields." + ) + + query_sql_database_tool = QuerySQLDataBaseTool( + db=db, description=query_sql_database_tool_description + ) + return query_sql_database_tool + + +def langchain_sql_info(kwargs) -> BaseTool: + db = get_db(kwargs.get("sql_url"), kwargs.get("sql_table"), kwargs.get("sql_name"), kwargs.get("sql_pwd")) + + info_sql_database_tool_description = ( + "Input to this tool is a comma-separated list of tables, output is the " + "schema and sample rows for those tables. " + "Be sure that the tables actually exist by calling " + "sql_db_list_tables first! " + "Example Input: table1, table2, table3" + ) + info_sql_database_tool = InfoSQLDatabaseTool( + db=db, description=info_sql_database_tool_description + ) + return info_sql_database_tool + + +def langchain_sql_list(kwargs) -> BaseTool: + db = get_db(kwargs.get("sql_url"), kwargs.get("sql_table"), kwargs.get("sql_name"), kwargs.get("sql_pwd")) + + list_sql_database_tool = ListSQLDatabaseTool(db=db) + return list_sql_database_tool + + +def langchain_sql_checker(kwargs) -> BaseTool: + api_key = kwargs.get("api_key") or "EMPTY" + model_name = kwargs.get("model_name") + api_base = kwargs.get("api_base") + temperature = kwargs.get("temperature") or 0 + + db = get_db(kwargs.get("sql_url"), kwargs.get("sql_table"), kwargs.get("sql_name"), kwargs.get("sql_pwd")) + llm = ChatOpenAI(model_name=model_name, openai_api_base=api_base, openai_api_key=api_key, temperature=temperature) + + query_sql_checker_tool_description = ( + "Use this tool to double check if your query is correct before executing " + "it. Always use this tool before executing a query with " + "sql_db_query!" + ) + query_sql_checker_tool = QuerySQLCheckerTool( + db=db, llm=llm, description=query_sql_checker_tool_description + ) + return query_sql_checker_tool + + +def langchain_sql_agent(kwargs) -> AgentExecutor: + api_key = kwargs.get("api_key") or "EMPTY" + model_name = kwargs.get("model_name") + api_base = kwargs.get("api_base") + temperature = kwargs.get("temperature") or 0 + + db = get_db(kwargs.get("sql_url"), kwargs.get("sql_table"), kwargs.get("sql_name"), kwargs.get("sql_pwd")) + llm = ChatOpenAI(model_name=model_name, openai_api_base=api_base, openai_api_key=api_key, temperature=temperature) + agent_executor = create_sql_agent(llm, db=db) + return agent_executor + + +def langchain_request_get(kwargs) -> BaseTool: + return RequestsGetTool( + requests_wrapper=TextRequestsWrapper(headers={}), + allow_dangerous_requests=True, + ) + + +def langchain_request_post(kwargs) -> BaseTool: + return RequestsPostTool( + requests_wrapper=TextRequestsWrapper(headers={}), + allow_dangerous_requests=True, + ) + + +def langchain_request_patch(kwargs) -> BaseTool: + return RequestsPatchTool( + requests_wrapper=TextRequestsWrapper(headers={}), + allow_dangerous_requests=True, + ) + + +def langchain_request_delete(kwargs) -> BaseTool: + return RequestsDeleteTool( + requests_wrapper=TextRequestsWrapper(headers={}), + allow_dangerous_requests=True, + ) + + +def langchain_request_put(kwargs) -> BaseTool: + return RequestsPutTool( + requests_wrapper=TextRequestsWrapper(headers={}), + allow_dangerous_requests=True, + ) + + +def langchain_json_agent(kwargs) -> AgentExecutor: + json_str = kwargs.get("json_str") + api_key = kwargs.get("api_key") or "EMPTY" + model_name = kwargs.get("model_name") + api_base = kwargs.get("api_base") + temperature = kwargs.get("temperature") or 0 + llm = ChatOpenAI(openai_api_base=api_base, openai_api_key=api_key, + model=model_name, temperature=temperature) + json_spec = JsonSpec(dict_=json.loads(json_str), max_value_length=4000) + json_toolkit = JsonToolkit(spec=json_spec) + json_agent_executor = create_json_agent(llm=llm, toolkit=json_toolkit, verbose=True) + return json_agent_executor + + +# function tools +function_tools = [] +register_function_tools(function_tools) + +api_tools = [ + (langchain_sql_query, ["sql_url", "sql_table", "sql_name", "sql_pwd"], "sql_db_query"), + (langchain_sql_info, ["sql_url", "sql_table", "sql_name", "sql_pwd"], "sql_db_schema"), + (langchain_sql_list, ["sql_url", "sql_table", "sql_name", "sql_pwd"], "sql_db_list_tables"), + (langchain_sql_checker, + ["model_name", "api_key", "api_base", "sql_url", "sql_table", "sql_name", "sql_pwd", "temperature"], + "sql_db_query_checker"), + (langchain_sql_agent, + ["model_name", "api_key", "api_base", "sql_url", "sql_table", "sql_name", "sql_pwd", "temperature"], "sql_agent"), + (langchain_request_get, ["url"], "request_get"), + (langchain_request_put, ["url"], "request_put"), + (langchain_request_post, ["url"], "request_post"), + (langchain_request_delete, ["url"], "request_delete"), + (langchain_request_patch, ["url"], "request_patch"), + (langchain_json_agent, ["model_name", "api_key", "api_base", "temperature", "json_str", "input"], "json_agent") +] +# api tools +for tool in api_tools: + register_api_tools(tool[0], tool[1], tool[2]) + +if __name__ == "__main__": + import time + from .langchain_schema_helper import dump_schema + + current_timestamp = time.strftime('%Y%m%d%H%M%S') + dump_schema(function_tools, f"./tool_schema-{str(current_timestamp)}.json") diff --git a/framework/fel/python/plugins/fel_langchain_tools/tools.json b/framework/fel/python/plugins/fel_langchain_tools/tools.json new file mode 100644 index 00000000..f76dd771 --- /dev/null +++ b/framework/fel/python/plugins/fel_langchain_tools/tools.json @@ -0,0 +1,353 @@ +{ + "tools": [ + { + "tags": [ + "Langchain" + ], + "runnables": { + "langchain": { + "genericableId": "langchain.tool", + "fitableId": "Python_REPL" + } + }, + "schema": { + "name": "Python_REPL", + "description": "A Python shell. Use this to execute python commands. Input should be a valid python command. If you want to see the output of a value, you should print it out with `print(...)`.", + "parameters": { + "properties": { + "__arg1": { + "title": "__arg1", + "type": "string" + } + }, + "required": [ + "__arg1" + ], + "type": "object" + }, + "return": { + "type": "string" + } + } + }, + { + "tags": [ + "Langchain", + "Config" + ], + "runnables": { + "langchain": { + "genericableId": "langchain.tool", + "fitableId": "google_search" + } + }, + "schema": { + "name": "google_search", + "description": "A wrapper around Google Search. Useful for when you need to answer questions about current events. Input should be a search query.", + "parameters": { + "properties": { + "__arg1": { + "title": "__arg1", + "type": "string" + }, + "google_api_key": { + "description": "google search api key", + "type": "string" + }, + "google_cse_id": { + "description": "google search cse id", + "type": "string" + }, + "k": { + "description": "number of search results", + "default": 10, + "type": "integer" + }, + "siterestrict": { + "description": "restricts search results", + "default": false, + "type": "boolean" + } + }, + "required": [ + "__arg1", + "google_api_key", + "google_cse_id" + ], + "type": "object" + }, + "return": { + "type": "string" + }, + "parameterExtensions": { + "config": [ + "google_api_key", + "google_cse_id", + "k", + "siterestrict" + ] + } + } + }, + { + "tags": [ + "Langchain", + "Config" + ], + "runnables": { + "langchain": { + "genericableId": "langchain.tool", + "fitableId": "sql_db_query" + } + }, + "schema": { + "name": "sql_db_query", + "description": "Input to this tool is a detailed and correct SQL query, output is a result from the database. If the query is not correct, an error message will be returned. If an error is returned, rewrite the query, check the query, and try again. If you encounter an issue with Unknown column 'xxxx' in 'field list', use sql_db_schema to query the correct table fields.", + "parameters": { + "type": "object", + "properties": { + "query": { + "description": "A detailed and correct SQL query.", + "type": "string" + }, + "sql_url": { + "type": "string", + "description": "sql_url" + }, + "sql_table": { + "type": "string", + "description": "sql_table" + }, + "sql_name": { + "type": "string", + "description": "sql_name" + }, + "sql_pwd": { + "type": "string", + "description": "sql_pwd" + } + }, + "required": [ + "query", + "sql_url", + "sql_table", + "sql_name", + "sql_pwd" + ] + }, + "return": { + "type": "string" + }, + "parameterExtensions": { + "config": [ + "sql_url", + "sql_table", + "sql_name", + "sql_pwd" + ] + } + } + }, + { + "tags": [ + "Langchain", + "Config" + ], + "runnables": { + "langchain": { + "genericableId": "langchain.tool", + "fitableId": "sql_db_schema" + } + }, + "schema": { + "name": "sql_db_schema", + "description": "Input to this tool is a comma-separated list of tables, output is the schema and sample rows for those tables. Be sure that the tables actually exist by calling sql_db_list_tables first! Example Input: table1, table2, table3", + "parameters": { + "type": "object", + "properties": { + "table_names": { + "description": "A comma-separated list of the table names for which to return the schema. Example input: 'table1, table2, table3'", + "type": "string" + }, + "sql_url": { + "type": "string", + "description": "sql_url" + }, + "sql_table": { + "type": "string", + "description": "sql_table" + }, + "sql_name": { + "type": "string", + "description": "sql_name" + }, + "sql_pwd": { + "type": "string", + "description": "sql_pwd" + } + }, + "required": [ + "table_names", + "sql_url", + "sql_table", + "sql_name", + "sql_pwd" + ] + }, + "return": { + "type": "string" + }, + "parameterExtensions": { + "config": [ + "sql_url", + "sql_table", + "sql_name", + "sql_pwd" + ] + } + } + }, + { + "tags": [ + "Langchain", + "Config" + ], + "runnables": { + "langchain": { + "genericableId": "langchain.tool", + "fitableId": "sql_db_list_tables" + } + }, + "schema": { + "name": "sql_db_list_tables", + "description": "Input is an empty string, output is a comma-separated list of tables in the database.", + "parameters": { + "type": "object", + "properties": { + "tool_input": { + "description": "An empty string", + "default": "", + "type": "string" + }, + "sql_url": { + "type": "string", + "description": "sql_url" + }, + "sql_table": { + "type": "string", + "description": "sql_table" + }, + "sql_name": { + "type": "string", + "description": "sql_name" + }, + "sql_pwd": { + "type": "string", + "description": "sql_pwd" + } + }, + "required": [ + "sql_table", + "sql_url", + "sql_name", + "sql_pwd" + ] + }, + "return": { + "type": "string" + }, + "parameterExtensions": { + "config": [ + "sql_url", + "sql_table", + "sql_name", + "sql_pwd" + ] + } + } + }, + { + "tags": [ + "Langchain", + "Config" + ], + "runnables": { + "langchain": { + "genericableId": "langchain.tool", + "fitableId": "sql_db_query_checker" + } + }, + "schema": { + "name": "sql_db_query_checker", + "description": "Use this tool to double check if your query is correct before executing it. Always use this tool before executing a query with sql_db_query!", + "parameters": { + "type": "object", + "properties": { + "query": { + "description": "A detailed and SQL query to be checked.", + "type": "string" + }, + "model_name": { + "type": "string", + "description": "model_name" + }, + "api_key": { + "type": "string", + "description": "api_key" + }, + "api_base": { + "type": "string", + "description": "api_base" + }, + "sql_url": { + "type": "string", + "description": "sql_url" + }, + "sql_table": { + "type": "string", + "description": "sql_table" + }, + "sql_name": { + "type": "string", + "description": "sql_name" + }, + "sql_pwd": { + "type": "string", + "description": "sql_pwd" + }, + "temperature": { + "type": "string", + "description": "temperature" + } + }, + "required": [ + "query", + "model_name", + "api_key", + "api_base", + "sql_url", + "sql_table", + "sql_name", + "sql_pwd", + "temperature" + ] + }, + "return": { + "type": "string" + }, + "parameterExtensions": { + "config": [ + "model_name", + "api_key", + "api_base", + "sql_url", + "sql_table", + "sql_name", + "sql_pwd", + "temperature" + ] + } + } + } + ] +} \ No newline at end of file diff --git a/framework/fel/python/plugins/fel_llama_index_tools/callable_registers.py b/framework/fel/python/plugins/fel_llama_index_tools/callable_registers.py new file mode 100644 index 00000000..0cde3122 --- /dev/null +++ b/framework/fel/python/plugins/fel_llama_index_tools/callable_registers.py @@ -0,0 +1,29 @@ +# -- encoding: utf-8 -- +# Copyright (c) 2024 Huawei Technologies Co., Ltd. All Rights Reserved. +# This file is a part of the ModelEngine Project. +# Licensed under the MIT License. See License.txt in the project root for license information. +# ====================================================================================================================== +import functools +from inspect import signature +from typing import Callable, Any, Tuple, List + +from fitframework import fit_logger +from fitframework.core.repo.fitable_register import register_fitable + + +def __invoke_tool(input_args: dict, tool_func: Callable[..., Any], **kwargs) -> Any: + return tool_func(**input_args, **kwargs) + + +def register_callable_tool(tool: Tuple[Callable[..., Any], List[str], str], module: str, generic_id: str): + func = tool[0] + fitable_id = f"{func.__name__}" + + tool_invoke = functools.partial(__invoke_tool, tool_func=func) + tool_invoke.__module__ = module + tool_invoke.__annotations__ = { + 'input_args': dict, + 'return': signature(func).return_annotation + } + register_fitable(generic_id, fitable_id, False, [], tool_invoke) + fit_logger.info("register: generic_id = %s, fitable_id = %s", generic_id, fitable_id, stacklevel=2) diff --git a/framework/fel/python/plugins/fel_llama_index_tools/llama_rag_basic_toolkit.py b/framework/fel/python/plugins/fel_llama_index_tools/llama_rag_basic_toolkit.py new file mode 100644 index 00000000..d6084d68 --- /dev/null +++ b/framework/fel/python/plugins/fel_llama_index_tools/llama_rag_basic_toolkit.py @@ -0,0 +1,158 @@ +# -- encoding: utf-8 -- +# Copyright (c) 2024 Huawei Technologies Co., Ltd. All Rights Reserved. +# This file is a part of the ModelEngine Project. +# Licensed under the MIT License. See License.txt in the project root for license information. +# ====================================================================================================================== +import functools +import os +import traceback +from enum import Enum, unique +from inspect import signature +from typing import List, Callable, Any, Tuple + +from fitframework import fit_logger +from fitframework.core.repo.fitable_register import register_fitable +from llama_index.core import PromptTemplate +from llama_index.core.base.base_selector import SingleSelection +from llama_index.core.postprocessor import SimilarityPostprocessor, SentenceEmbeddingOptimizer, LLMRerank, \ + LongContextReorder, FixedRecencyPostprocessor +from llama_index.core.postprocessor.types import BaseNodePostprocessor +from llama_index.core.prompts import PromptType +from llama_index.core.prompts.default_prompts import DEFAULT_CHOICE_SELECT_PROMPT_TMPL +from llama_index.core.selectors import LLMSingleSelector, LLMMultiSelector +from llama_index.core.selectors.prompts import DEFAULT_SINGLE_SELECT_PROMPT_TMPL, DEFAULT_MULTI_SELECT_PROMPT_TMPL +from llama_index.embeddings.openai import OpenAIEmbedding +from llama_index.legacy.llms import OpenAILike + +from .callable_registers import register_callable_tool +from .types.document import Document +from .node_utils import document_to_query_node, query_node_to_document + +os.environ["no_proxy"] = "*" + + +def __invoke_postprocessor(postprocessor: BaseNodePostprocessor, nodes: List[Document], + query_str: str) -> List[Document]: + if len(nodes) == 0: + return [] + try: + postprocess_nodes = postprocessor.postprocess_nodes([document_to_query_node(node) for node in nodes], + query_str=query_str) + return [query_node_to_document(node) for node in postprocess_nodes] + except BaseException: + fit_logger.error("Invoke postprocessor failed.") + traceback.print_exc() + return nodes + + +def similarity_filter(nodes: List[Document], query_str: str, **kwargs) -> List[Document]: + """Remove documents that are below a similarity score threshold.""" + similarity_cutoff = float(kwargs.get("similarity_cutoff") or 0.3) + postprocessor = SimilarityPostprocessor(similarity_cutoff=similarity_cutoff) + return __invoke_postprocessor(postprocessor, nodes, query_str) + + +def sentence_embedding_optimizer(nodes: List[Document], query_str: str, **kwargs) -> List[Document]: + """Optimization of a text chunk given the query by shortening the input text.""" + api_key = kwargs.get("api_key") or "EMPTY" + model_name = kwargs.get("model_name") or "bce-embedding-base_v1" + api_base = kwargs.get("api_base") or ("http://51.36.139.24:8010/v1" if api_key == "EMPTY" else None) + percentile_cutoff = kwargs.get("percentile_cutoff") + threshold_cutoff = kwargs.get("threshold_cutoff") + percentile_cutoff = percentile_cutoff if percentile_cutoff is None else float(percentile_cutoff) + threshold_cutoff = threshold_cutoff if threshold_cutoff is None else float(threshold_cutoff) + + embed_model = OpenAIEmbedding(model_name=model_name, api_base=api_base, api_key=api_key) + optimizer = SentenceEmbeddingOptimizer(embed_model=embed_model, percentile_cutoff=percentile_cutoff, + threshold_cutoff=threshold_cutoff) + return __invoke_postprocessor(optimizer, nodes, query_str) + + +def llm_rerank(nodes: List[Document], query_str: str, **kwargs) -> List[Document]: + """ + Re-order nodes by asking the LLM to return the relevant documents and a score of how relevant they are. + Returns the top N ranked nodes. + """ + api_key = kwargs.get("api_key") or "EMPTY" + model_name = kwargs.get("model_name") or "Qwen1.5-14B-Chat" + api_base = kwargs.get("api_base") or ("http://80.11.128.62:8000/v1" if api_key == "EMPTY" else None) + prompt = kwargs.get("prompt") or DEFAULT_CHOICE_SELECT_PROMPT_TMPL + choice_batch_size = int(kwargs.get("choice_batch_size") or 10) + top_n = int(kwargs.get("top_n") or 10) + + llm = OpenAILike(model=model_name, api_base=api_base, api_key=api_key, max_tokens=4096) + choice_select_prompt = PromptTemplate(prompt, prompt_type=PromptType.CHOICE_SELECT) + llm_rerank_obj = LLMRerank(llm=llm, choice_select_prompt=choice_select_prompt, choice_batch_size=choice_batch_size, + top_n=top_n) + return __invoke_postprocessor(llm_rerank_obj, nodes, query_str) + + +def long_context_rerank(nodes: List[Document], query_str: str, **kwargs) -> List[Document]: + """Re-order the retrieved nodes, which can be helpful in cases where a large top-k is needed.""" + return __invoke_postprocessor(LongContextReorder(), nodes, query_str) + + +@unique +class SelectorMode(Enum): + SINGLE = "single" + MULTI = "multi" + + +def llm_choice_selector(choice: List[str], query_str: str, **kwargs) -> List[SingleSelection]: + """LLM-based selector that chooses one or multiple out of many options.""" + if len(choice) == 0: + return [] + api_key = kwargs.get("api_key") or "EMPTY" + model_name = kwargs.get("model_name") or "Qwen1.5-14B-Chat" + api_base = kwargs.get("api_base") or ("http://80.11.128.62:8000/v1" if api_key == "EMPTY" else None) + prompt = kwargs.get("prompt") + mode = str(kwargs.get("mode") or SelectorMode.SINGLE.value) + if mode.lower() not in [m.value for m in SelectorMode]: + raise ValueError(f"Invalid mode {mode}.") + + llm = OpenAILike(model=model_name, api_base=api_base, api_key=api_key, max_tokens=4096) + if mode.lower() == SelectorMode.SINGLE.value: + selector_prompt = prompt or DEFAULT_SINGLE_SELECT_PROMPT_TMPL + selector = LLMSingleSelector.from_defaults(llm=llm, prompt_template_str=selector_prompt) + else: + multi_selector_prompt = prompt or DEFAULT_MULTI_SELECT_PROMPT_TMPL + selector = LLMMultiSelector.from_defaults(llm=llm, prompt_template_str=multi_selector_prompt) + try: + return selector.select(choice, query_str).selections + except BaseException: + fit_logger.error("Invoke choice selector failed.") + traceback.print_exc() + return [] + + +def fixed_recency(nodes: List[Document], tok_k: int, date_key: str, query_str: str, **kwargs) -> List[Document]: + """This postprocessor returns the top K nodes sorted by date""" + postprocessor = FixedRecencyPostprocessor( + tok_k=tok_k, date_key=date_key if date_key else "date" + ) + return __invoke_postprocessor(postprocessor, nodes, query_str) + + +# Tuple 结构: (tool_func, config_args, return_description) +rag_basic_toolkit: List[Tuple[Callable[..., Any], List[str], str]] = [ + (similarity_filter, ["similarity_cutoff"], "The filtered documents."), + (sentence_embedding_optimizer, ["model_name", "api_key", "api_base", "percentile_cutoff", "threshold_cutoff"], + "The optimized documents."), + (llm_rerank, ["model_name", "api_key", "api_base", "prompt", "choice_batch_size", "top_n"], + "The re-ordered documents."), + (long_context_rerank, [], "The re-ordered documents."), + (llm_choice_selector, ["model_name", "api_key", "api_base", "prompt", "mode"], "The selected choice."), + (fixed_recency, ["nodes", "tok_k", "date_key", "query_str"], "The fixed recency postprocessor") +] + + +for tool in rag_basic_toolkit: + register_callable_tool(tool, llm_choice_selector.__module__, "llama_index.rag.toolkit") + + +if __name__ == '__main__': + import time + from .llama_schema_helper import dump_llama_schema + + current_timestamp = time.strftime('%Y%m%d%H%M%S') + dump_llama_schema(rag_basic_toolkit, f"./llama_tool_schema-{str(current_timestamp)}.json") diff --git a/framework/fel/python/plugins/fel_llama_index_tools/llama_schema_helper.py b/framework/fel/python/plugins/fel_llama_index_tools/llama_schema_helper.py new file mode 100644 index 00000000..0c22bc99 --- /dev/null +++ b/framework/fel/python/plugins/fel_llama_index_tools/llama_schema_helper.py @@ -0,0 +1,126 @@ +# -- encoding: utf-8 -- +# Copyright (c) 2024 Huawei Technologies Co., Ltd. All Rights Reserved. +# This file is a part of the ModelEngine Project. +# Licensed under the MIT License. See License.txt in the project root for license information. +# ====================================================================================================================== +import json +import os +import re +import stat +from inspect import Parameter +from inspect import signature +from typing import List, Tuple, Any, Callable, Optional + +from llama_index.core.bridge.pydantic import FieldInfo, create_model +from llama_index.core.tools import FunctionTool + + +def __get_ref_item(value: dict, definitions: dict) -> dict: + sub_properties_name = re.findall("^#/definitions/(.+)$", value.get("$ref")) + if len(sub_properties_name) == 0: + raise ValueError(f"Invalid reference properties {value.get('$ref')}.") + ref_item = definitions.get(sub_properties_name[0]) + ref_item["properties"] = __flat_properties(ref_item.get("properties"), definitions) + return ref_item + + +def __flat_properties(properties: dict, definitions: dict) -> dict: + if definitions is None: + return properties + flat_properties = dict() + for key, value in properties.items(): + if value.__contains__("$ref"): + flat_properties[key] = __get_ref_item(value, definitions) + continue + array_item = value.get("items") + if array_item is not None and array_item.__contains__("$ref"): + value["items"] = __get_ref_item(array_item, definitions) + flat_properties[key] = value + continue + else: + flat_properties[key] = value + return flat_properties + + +def __get_return_properties(func: Callable[..., Any], return_description: str) -> dict: + func_signature = signature(func) + param_type = func_signature.return_annotation + if param_type is Parameter.empty: + param_type = Any + + fields = {return_description: (param_type, FieldInfo())} + field_model = create_model(return_description, **fields) + parameters = field_model.schema() + parameters = { + key: value + for key, value in parameters.items() + if key in ["type", "properties", "required", "definitions"] + } + properties = __flat_properties(parameters.get("properties"), parameters.get("definitions")) + if return_description in properties: + return properties[return_description] + else: + return dict() + + +def __get_llama_rag_tool_schema(tool: Tuple[Callable[..., Any], List[str], str]) -> dict: + func = tool[0] + metadata = FunctionTool.from_defaults(fn=func).metadata + parameters_dict = metadata.get_parameters_dict() + property_key = "properties" + parameters_dict.get(property_key).pop("kwargs") + + dynamic_args = tool[1] + dynamic_args_dict = dict() + for arg in dynamic_args: + dynamic_args_dict[arg] = {"type": "string", "description": arg} + + definition = __get_param_definition(parameters_dict) + flat_properties = __flat_properties(parameters_dict.get(property_key), definition) + parameters_dict[property_key] = {**flat_properties, **dynamic_args_dict} + tool_schema = { + "name": metadata.name, + "description": func.__doc__, + "parameters": parameters_dict, + "return": __get_return_properties(func, tool[2]), + } + if len(dynamic_args_dict) != 0: + tool_schema["parameterExtensions"] = {"config": list(dynamic_args_dict.keys())} + return tool_schema + + +def __get_param_definition(parameters_dict: dict) -> Optional[dict]: + if parameters_dict.__contains__("definitions"): + return parameters_dict.pop("definitions") + return None + + +def dump_llama_schema(llama_toolkit: List[Tuple[Callable[..., Any], List[str], str]], file_path: str): + """ + 导出 LlamaIndex 函数工具 schema 的工具方法。 + + Args: + llama_toolkit (List[Tuple[Callable[..., Any], List[str]]]): 表示 llama_index rag 工具列表。 + file_path (str): 表示 schema 文件的导出路径。 + """ + dump_callable_schema(llama_toolkit, file_path, "LlamaIndex", "llama_index.rag.toolkit") + + +def dump_callable_schema(callable_toolkit: List[Tuple[Callable[..., Any], List[str], str]], file_path: str, tag: str, + genericable_id: str): + """ + 导出函数工具 schema 的工具方法。 + + Args: + callable_toolkit (List[Tuple[Callable[..., Any], List[str]]]): 表示函数工具列表。 + file_path (str): 表示 schema 文件的导出路径。 + """ + tools_schema = [{ + "tags": [tag], + "runnables": {tag: {"genericableId": genericable_id, "fitableId": f"{tool[0].__name__}"}}, + "schema": {**__get_llama_rag_tool_schema(tool)} + } for tool in callable_toolkit] + + fd = os.open(file_path, os.O_RDWR | os.O_CREAT, stat.S_IWUSR | stat.S_IRUSR) + with os.fdopen(fd, "w") as file: + json.dump({"tools": tools_schema}, file) diff --git a/framework/fel/python/plugins/fel_llama_index_tools/node_utils.py b/framework/fel/python/plugins/fel_llama_index_tools/node_utils.py new file mode 100644 index 00000000..65bcabf9 --- /dev/null +++ b/framework/fel/python/plugins/fel_llama_index_tools/node_utils.py @@ -0,0 +1,58 @@ +# -- encoding: utf-8 -- +# Copyright (c) 2024 Huawei Technologies Co., Ltd. All Rights Reserved. +# This file is a part of the ModelEngine Project. +# Licensed under the MIT License. See License.txt in the project root for license information. +# ====================================================================================================================== +from llama_index.core.multi_modal_llms.generic_utils import encode_image +from llama_index.core.schema import ImageNode, TextNode, NodeWithScore +from llama_index.core import Document as LDocument + +from .types.document import Document +from .types.media import Media + + +def document_to_query_node(doc_input: Document): + if isinstance(doc_input, dict): + doc = Document(**doc_input) + else: + doc = doc_input + + if doc.media is not None: + node = ImageNode(image=doc.media.data, image_mimetype=doc.media.mime) + else: + node = TextNode() + node.set_content(doc.content) + node.metadata = doc.metadata + return NodeWithScore(node=node, score=doc.metadata["score"]) + + +def query_node_to_document(node_with_score: NodeWithScore) -> Document: + node = node_with_score.node + metadata = node.metadata or {} + metadata['score'] = node_with_score.score + content = None + image = None + file_path_key = "file_path" + if isinstance(node, ImageNode): + mime = node.image_mimetype or "image/jpeg" + data = None + if node.image and node.image != "": + data = node.image + elif node.image_url and node.image_url != "": + data = node.image_url + elif node.image_path and node.image_path != "": + data = encode_image(node.image_path) + elif file_path_key in node.metadata and node.metadata[file_path_key] != "": + data = encode_image(node.metadata[file_path_key]) + image = Media(mime=mime, data=data) + if isinstance(node, TextNode): + content = node.get_content() + return Document(content=content, media=image, metadata=metadata) + + +def to_llama_index_document(doc: Document) -> LDocument: + metadata = {} + metadata.update(doc.metadata) + if doc.media is not None: + metadata.update({"mime": doc.media.mime, "data": doc.media.data}) + return LDocument(text=doc.content, metadata=metadata) \ No newline at end of file diff --git a/framework/fel/python/plugins/fel_llama_index_tools/tools.json b/framework/fel/python/plugins/fel_llama_index_tools/tools.json new file mode 100644 index 00000000..62283787 --- /dev/null +++ b/framework/fel/python/plugins/fel_llama_index_tools/tools.json @@ -0,0 +1,685 @@ +{ + "tools": [ + { + "runnables": { + "LlamaIndex": { + "genericableId": "llama_index.rag.toolkit", + "fitableId": "similarity_filter" + } + }, + "schema": { + "name": "similarity_filter", + "description": "Remove documents that are below a similarity score threshold.", + "parameters": { + "type": "object", + "properties": { + "nodes": { + "title": "Nodes", + "type": "array", + "items": { + "title": "Document", + "description": "Document.", + "type": "object", + "properties": { + "content": { + "title": "Content", + "type": "string" + }, + "media": { + "title": "Media", + "description": "Media.", + "type": "object", + "properties": { + "mime": { + "title": "Mime", + "type": "string" + }, + "data": { + "title": "Data", + "type": "string" + } + }, + "required": [ + "mime", + "data" + ] + }, + "metadata": { + "title": "Metadata", + "type": "object" + } + }, + "required": [ + "content", + "metadata" + ] + } + }, + "query_str": { + "title": "Query Str", + "type": "string" + }, + "similarity_cutoff": { + "type": "string", + "description": "similarity_cutoff" + } + }, + "required": [ + "nodes", + "query_str" + ] + }, + "return": { + "title": "The Filtered Documents.", + "type": "array", + "items": { + "title": "Document", + "description": "Document.", + "type": "object", + "properties": { + "content": { + "title": "Content", + "type": "string" + }, + "media": { + "title": "Media", + "description": "Media.", + "type": "object", + "properties": { + "mime": { + "title": "Mime", + "type": "string" + }, + "data": { + "title": "Data", + "type": "string" + } + }, + "required": [ + "mime", + "data" + ] + }, + "metadata": { + "title": "Metadata", + "type": "object" + } + }, + "required": [ + "content", + "metadata" + ] + } + }, + "parameterExtensions": { + "config": [ + "similarity_cutoff" + ] + } + } + }, + { + "runnables": { + "LlamaIndex": { + "genericableId": "llama_index.rag.toolkit", + "fitableId": "sentence_embedding_optimizer" + } + }, + "schema": { + "name": "sentence_embedding_optimizer", + "description": "Optimization of a text chunk given the query by shortening the input text.", + "parameters": { + "type": "object", + "properties": { + "nodes": { + "title": "Nodes", + "type": "array", + "items": { + "title": "Document", + "description": "Document.", + "type": "object", + "properties": { + "content": { + "title": "Content", + "type": "string" + }, + "media": { + "title": "Media", + "description": "Media.", + "type": "object", + "properties": { + "mime": { + "title": "Mime", + "type": "string" + }, + "data": { + "title": "Data", + "type": "string" + } + }, + "required": [ + "mime", + "data" + ] + }, + "metadata": { + "title": "Metadata", + "type": "object" + } + }, + "required": [ + "content", + "metadata" + ] + } + }, + "query_str": { + "title": "Query Str", + "type": "string" + }, + "model_name": { + "type": "string", + "description": "model_name" + }, + "api_key": { + "type": "string", + "description": "api_key" + }, + "api_base": { + "type": "string", + "description": "api_base" + }, + "percentile_cutoff": { + "type": "string", + "description": "percentile_cutoff" + }, + "threshold_cutoff": { + "type": "string", + "description": "threshold_cutoff" + } + }, + "required": [ + "nodes", + "query_str" + ] + }, + "return": { + "title": "The Optimized Documents.", + "type": "array", + "items": { + "title": "Document", + "description": "Document.", + "type": "object", + "properties": { + "content": { + "title": "Content", + "type": "string" + }, + "media": { + "title": "Media", + "description": "Media.", + "type": "object", + "properties": { + "mime": { + "title": "Mime", + "type": "string" + }, + "data": { + "title": "Data", + "type": "string" + } + }, + "required": [ + "mime", + "data" + ] + }, + "metadata": { + "title": "Metadata", + "type": "object" + } + }, + "required": [ + "content", + "metadata" + ] + } + }, + "parameterExtensions": { + "config": [ + "model_name", + "api_key", + "api_base", + "percentile_cutoff", + "threshold_cutoff" + ] + } + } + }, + { + "runnables": { + "LlamaIndex": { + "genericableId": "llama_index.rag.toolkit", + "fitableId": "llm_rerank" + } + }, + "schema": { + "name": "llm_rerank", + "description": "\n Re-order nodes by asking the LLM to return the relevant documents and a score of how relevant they are.\n Returns the top N ranked nodes.\n ", + "parameters": { + "type": "object", + "properties": { + "nodes": { + "title": "Nodes", + "type": "array", + "items": { + "title": "Document", + "description": "Document.", + "type": "object", + "properties": { + "content": { + "title": "Content", + "type": "string" + }, + "media": { + "title": "Media", + "description": "Media.", + "type": "object", + "properties": { + "mime": { + "title": "Mime", + "type": "string" + }, + "data": { + "title": "Data", + "type": "string" + } + }, + "required": [ + "mime", + "data" + ] + }, + "metadata": { + "title": "Metadata", + "type": "object" + } + }, + "required": [ + "content", + "metadata" + ] + } + }, + "query_str": { + "title": "Query Str", + "type": "string" + }, + "model_name": { + "type": "string", + "description": "model_name" + }, + "api_key": { + "type": "string", + "description": "api_key" + }, + "api_base": { + "type": "string", + "description": "api_base" + }, + "prompt": { + "type": "string", + "description": "prompt" + }, + "choice_batch_size": { + "type": "string", + "description": "choice_batch_size" + }, + "top_n": { + "type": "string", + "description": "top_n" + } + }, + "required": [ + "nodes", + "query_str" + ] + }, + "return": { + "title": "The Re-Ordered Documents.", + "type": "array", + "items": { + "title": "Document", + "description": "Document.", + "type": "object", + "properties": { + "content": { + "title": "Content", + "type": "string" + }, + "media": { + "title": "Media", + "description": "Media.", + "type": "object", + "properties": { + "mime": { + "title": "Mime", + "type": "string" + }, + "data": { + "title": "Data", + "type": "string" + } + }, + "required": [ + "mime", + "data" + ] + }, + "metadata": { + "title": "Metadata", + "type": "object" + } + }, + "required": [ + "content", + "metadata" + ] + } + }, + "parameterExtensions": { + "config": [ + "model_name", + "api_key", + "api_base", + "prompt", + "choice_batch_size", + "top_n" + ] + } + } + }, + { + "runnables": { + "LlamaIndex": { + "genericableId": "llama_index.rag.toolkit", + "fitableId": "long_context_rerank" + } + }, + "schema": { + "name": "long_context_rerank", + "description": "Re-order the retrieved nodes, which can be helpful in cases where a large top-k is needed.", + "parameters": { + "type": "object", + "properties": { + "nodes": { + "title": "Nodes", + "type": "array", + "items": { + "title": "Document", + "description": "Document.", + "type": "object", + "properties": { + "content": { + "title": "Content", + "type": "string" + }, + "media": { + "title": "Media", + "description": "Media.", + "type": "object", + "properties": { + "mime": { + "title": "Mime", + "type": "string" + }, + "data": { + "title": "Data", + "type": "string" + } + }, + "required": [ + "mime", + "data" + ] + }, + "metadata": { + "title": "Metadata", + "type": "object" + } + }, + "required": [ + "content", + "metadata" + ] + } + }, + "query_str": { + "title": "Query Str", + "type": "string" + } + }, + "required": [ + "nodes", + "query_str" + ] + }, + "return": { + "title": "The Re-Ordered Documents.", + "type": "array", + "items": { + "title": "Document", + "description": "Document.", + "type": "object", + "properties": { + "content": { + "title": "Content", + "type": "string" + }, + "media": { + "title": "Media", + "description": "Media.", + "type": "object", + "properties": { + "mime": { + "title": "Mime", + "type": "string" + }, + "data": { + "title": "Data", + "type": "string" + } + }, + "required": [ + "mime", + "data" + ] + }, + "metadata": { + "title": "Metadata", + "type": "object" + } + }, + "required": [ + "content", + "metadata" + ] + } + } + } + }, + { + "runnables": { + "LlamaIndex": { + "genericableId": "llama_index.rag.toolkit", + "fitableId": "llm_choice_selector" + } + }, + "schema": { + "name": "llm_choice_selector", + "description": "LLM-based selector that chooses one or multiple out of many options.", + "parameters": { + "type": "object", + "properties": { + "choice": { + "title": "Choice", + "type": "array", + "items": { + "type": "string" + } + }, + "query_str": { + "title": "Query Str", + "type": "string" + }, + "model_name": { + "type": "string", + "description": "model_name" + }, + "api_key": { + "type": "string", + "description": "api_key" + }, + "api_base": { + "type": "string", + "description": "api_base" + }, + "prompt": { + "type": "string", + "description": "prompt" + }, + "mode": { + "type": "string", + "description": "mode" + } + }, + "required": [ + "choice", + "query_str" + ] + }, + "return": { + "title": "The Selected Choice.", + "type": "array", + "items": { + "title": "SingleSelection", + "description": "A single selection of a choice.", + "type": "object", + "properties": { + "index": { + "title": "Index", + "type": "integer" + }, + "reason": { + "title": "Reason", + "type": "string" + } + }, + "required": [ + "index", + "reason" + ] + } + }, + "parameterExtensions": { + "config": [ + "model_name", + "api_key", + "api_base", + "prompt", + "mode" + ] + } + } + }, + { + "runnables": { + "LlamaIndex": { + "genericableId": "llama_index.rag.toolkit", + "fitableId": "fixed_recency" + } + }, + "schema": { + "name": "fixed_recency", + "description": "This postprocessor returns the top K nodes sorted by date", + "parameters": { + "type": "object", + "properties": { + "nodes": { + "type": "string", + "description": "nodes" + }, + "tok_k": { + "type": "string", + "description": "tok_k" + }, + "date_key": { + "type": "string", + "description": "date_key" + }, + "query_str": { + "type": "string", + "description": "query_str" + } + }, + "required": [ + "nodes", + "tok_k", + "date_key", + "query_str" + ] + }, + "return": { + "title": "The Fixed Recency Postprocessor", + "type": "array", + "items": { + "title": "Document", + "description": "Document.", + "type": "object", + "properties": { + "content": { + "title": "Content", + "type": "string" + }, + "media": { + "title": "Media", + "description": "Media.", + "type": "object", + "properties": { + "mime": { + "title": "Mime", + "type": "string" + }, + "data": { + "title": "Data", + "type": "string" + } + }, + "required": [ + "mime", + "data" + ] + }, + "metadata": { + "title": "Metadata", + "type": "object" + } + }, + "required": [ + "content", + "metadata" + ] + } + }, + "parameterExtensions": { + "config": [ + "nodes", + "tok_k", + "date_key", + "query_str" + ] + } + } + } + ] +} \ No newline at end of file diff --git a/framework/fel/python/plugins/fel_llama_index_tools/types/__init__.py b/framework/fel/python/plugins/fel_llama_index_tools/types/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/framework/fel/python/plugins/fel_llama_index_tools/types/document.py b/framework/fel/python/plugins/fel_llama_index_tools/types/document.py new file mode 100644 index 00000000..4989999f --- /dev/null +++ b/framework/fel/python/plugins/fel_llama_index_tools/types/document.py @@ -0,0 +1,22 @@ +# -- encoding: utf-8 -- +# Copyright (c) 2024 Huawei Technologies Co., Ltd. All Rights Reserved. +# This file is a part of the ModelEngine Project. +# Licensed under the MIT License. See License.txt in the project root for license information. +# ====================================================================================================================== +import typing + +from .serializable import Serializable +from .media import Media + + +class Document(Serializable): + """ + Document. + """ + content: str + media: Media = None + metadata: typing.Dict[str, object] + + class Config: + frozen = True + smart_union = True diff --git a/framework/fel/python/plugins/fel_llama_index_tools/types/media.py b/framework/fel/python/plugins/fel_llama_index_tools/types/media.py new file mode 100644 index 00000000..b1bdb54a --- /dev/null +++ b/framework/fel/python/plugins/fel_llama_index_tools/types/media.py @@ -0,0 +1,18 @@ +# -- encoding: utf-8 -- +# Copyright (c) 2024 Huawei Technologies Co., Ltd. All Rights Reserved. +# This file is a part of the ModelEngine Project. +# Licensed under the MIT License. See License.txt in the project root for license information. +# ====================================================================================================================== +from .serializable import Serializable + + +class Media(Serializable): + """ + Media. + """ + mime: str + data: str + + class Config: + frozen = True + smart_union = True diff --git a/framework/fel/python/plugins/fel_llama_index_tools/types/serializable.py b/framework/fel/python/plugins/fel_llama_index_tools/types/serializable.py new file mode 100644 index 00000000..4522897f --- /dev/null +++ b/framework/fel/python/plugins/fel_llama_index_tools/types/serializable.py @@ -0,0 +1,25 @@ +# -- encoding: utf-8 -- +# Copyright (c) 2024 Huawei Technologies Co., Ltd. All Rights Reserved. +# This file is a part of the ModelEngine Project. +# Licensed under the MIT License. See License.txt in the project root for license information. +# ====================================================================================================================== +import typing + +try: + import pydantic + + if pydantic.__version__.startswith("1."): + raise ImportError + import pydantic.v1 as pydantic +except ImportError: + import pydantic + + +class Serializable(pydantic.BaseModel): + def json(self, **kwargs: typing.Any) -> str: + kwargs_with_defaults: typing.Any = {"by_alias": True, "exclude_unset": True, **kwargs} + return super().json(**kwargs_with_defaults) + + def dict(self, **kwargs: typing.Any) -> typing.Dict[str, typing.Any]: + kwargs_with_defaults: typing.Any = {"by_alias": True, "exclude_unset": True, **kwargs} + return super().dict(**kwargs_with_defaults) \ No newline at end of file diff --git a/framework/fel/python/plugins/fel_llama_selector_tools/callable_registers.py b/framework/fel/python/plugins/fel_llama_selector_tools/callable_registers.py new file mode 100644 index 00000000..0cde3122 --- /dev/null +++ b/framework/fel/python/plugins/fel_llama_selector_tools/callable_registers.py @@ -0,0 +1,29 @@ +# -- encoding: utf-8 -- +# Copyright (c) 2024 Huawei Technologies Co., Ltd. All Rights Reserved. +# This file is a part of the ModelEngine Project. +# Licensed under the MIT License. See License.txt in the project root for license information. +# ====================================================================================================================== +import functools +from inspect import signature +from typing import Callable, Any, Tuple, List + +from fitframework import fit_logger +from fitframework.core.repo.fitable_register import register_fitable + + +def __invoke_tool(input_args: dict, tool_func: Callable[..., Any], **kwargs) -> Any: + return tool_func(**input_args, **kwargs) + + +def register_callable_tool(tool: Tuple[Callable[..., Any], List[str], str], module: str, generic_id: str): + func = tool[0] + fitable_id = f"{func.__name__}" + + tool_invoke = functools.partial(__invoke_tool, tool_func=func) + tool_invoke.__module__ = module + tool_invoke.__annotations__ = { + 'input_args': dict, + 'return': signature(func).return_annotation + } + register_fitable(generic_id, fitable_id, False, [], tool_invoke) + fit_logger.info("register: generic_id = %s, fitable_id = %s", generic_id, fitable_id, stacklevel=2) diff --git a/framework/fel/python/plugins/fel_llama_selector_tools/llama_selector.py b/framework/fel/python/plugins/fel_llama_selector_tools/llama_selector.py new file mode 100644 index 00000000..32d57516 --- /dev/null +++ b/framework/fel/python/plugins/fel_llama_selector_tools/llama_selector.py @@ -0,0 +1,48 @@ +# -- encoding: utf-8 -- +# Copyright (c) 2024 Huawei Technologies Co., Ltd. All Rights Reserved. +# This file is a part of the ModelEngine Project. +# Licensed under the MIT License. See License.txt in the project root for license information. +# ====================================================================================================================== +import traceback +from typing import Tuple, List, Any, Callable + +from fitframework import fit_logger +from llama_index.core.base.base_selector import SingleSelection +from llama_index.core.selectors import EmbeddingSingleSelector +from llama_index.embeddings.openai import OpenAIEmbedding + +from .callable_registers import register_callable_tool + + +def embedding_choice_selector(choice: List[str], query_str: str, **kwargs) -> List[SingleSelection]: + """ Embedding selector that chooses one out of many options.""" + if len(choice) == 0: + return [] + api_key = kwargs.get("api_key") or "EMPTY" + model_name = kwargs.get("model_name") or "bge-large-zh" + api_base = kwargs.get("api_base") or None + + embed_model = OpenAIEmbedding(model_name=model_name, api_base=api_base, api_key=api_key) + selector = EmbeddingSingleSelector.from_defaults(embed_model=embed_model) + try: + return selector.select(choice, query_str).selections + except BaseException: + fit_logger.error("Invoke embedding choice selector failed.") + traceback.print_exc() + return [] + + +# Tuple 结构: (tool_func, config_args, return_description) +selector_toolkit: List[Tuple[Callable[..., Any], List[str], str]] = [ + (embedding_choice_selector, ["model_name", "api_key", "api_base", "prompt", "mode"], "The selected choice."), +] + +for tool in selector_toolkit: + register_callable_tool(tool, embedding_choice_selector.__module__, "llama_index.rag.toolkit") + +if __name__ == '__main__': + import time + from .llama_schema_helper import dump_llama_schema + + current_timestamp = time.strftime('%Y%m%d%H%M%S') + dump_llama_schema(selector_toolkit, f"./llama_tool_schema-{str(current_timestamp)}.json") diff --git a/framework/fel/python/plugins/fel_llama_selector_tools/tools.json b/framework/fel/python/plugins/fel_llama_selector_tools/tools.json new file mode 100644 index 00000000..df2b6f70 --- /dev/null +++ b/framework/fel/python/plugins/fel_llama_selector_tools/tools.json @@ -0,0 +1,91 @@ +{ + "tools": [ + { + "tags": [ + "LlamaIndex" + ], + "runnables": { + "LlamaIndex": { + "genericableId": "llama_index.rag.toolkit", + "fitableId": "embedding_choice_selector" + } + }, + "schema": { + "name": "embedding_choice_selector", + "description": " Embedding selector that chooses one out of many options.", + "parameters": { + "type": "object", + "properties": { + "choice": { + "title": "Choice", + "type": "array", + "items": { + "type": "string" + } + }, + "query_str": { + "title": "Query Str", + "type": "string" + }, + "model_name": { + "type": "string", + "description": "model_name" + }, + "api_key": { + "type": "string", + "description": "api_key" + }, + "api_base": { + "type": "string", + "description": "api_base" + }, + "prompt": { + "type": "string", + "description": "prompt" + }, + "mode": { + "type": "string", + "description": "mode" + } + }, + "required": [ + "choice", + "query_str" + ] + }, + "return": { + "title": "The Selected Choice.", + "type": "array", + "items": { + "title": "SingleSelection", + "description": "A single selection of a choice.", + "type": "object", + "properties": { + "index": { + "title": "Index", + "type": "integer" + }, + "reason": { + "title": "Reason", + "type": "string" + } + }, + "required": [ + "index", + "reason" + ] + } + }, + "parameterExtensions": { + "config": [ + "model_name", + "api_key", + "api_base", + "prompt", + "mode" + ] + } + } + } + ] +} \ No newline at end of file diff --git a/framework/fel/python/plugins/fel_llama_splitter_tools/callable_registers.py b/framework/fel/python/plugins/fel_llama_splitter_tools/callable_registers.py new file mode 100644 index 00000000..0cde3122 --- /dev/null +++ b/framework/fel/python/plugins/fel_llama_splitter_tools/callable_registers.py @@ -0,0 +1,29 @@ +# -- encoding: utf-8 -- +# Copyright (c) 2024 Huawei Technologies Co., Ltd. All Rights Reserved. +# This file is a part of the ModelEngine Project. +# Licensed under the MIT License. See License.txt in the project root for license information. +# ====================================================================================================================== +import functools +from inspect import signature +from typing import Callable, Any, Tuple, List + +from fitframework import fit_logger +from fitframework.core.repo.fitable_register import register_fitable + + +def __invoke_tool(input_args: dict, tool_func: Callable[..., Any], **kwargs) -> Any: + return tool_func(**input_args, **kwargs) + + +def register_callable_tool(tool: Tuple[Callable[..., Any], List[str], str], module: str, generic_id: str): + func = tool[0] + fitable_id = f"{func.__name__}" + + tool_invoke = functools.partial(__invoke_tool, tool_func=func) + tool_invoke.__module__ = module + tool_invoke.__annotations__ = { + 'input_args': dict, + 'return': signature(func).return_annotation + } + register_fitable(generic_id, fitable_id, False, [], tool_invoke) + fit_logger.info("register: generic_id = %s, fitable_id = %s", generic_id, fitable_id, stacklevel=2) diff --git a/framework/fel/python/plugins/fel_llama_splitter_tools/llama_splitter_tool.py b/framework/fel/python/plugins/fel_llama_splitter_tools/llama_splitter_tool.py new file mode 100644 index 00000000..9c8fb421 --- /dev/null +++ b/framework/fel/python/plugins/fel_llama_splitter_tools/llama_splitter_tool.py @@ -0,0 +1,121 @@ +# -- encoding: utf-8 -- +# Copyright (c) 2024 Huawei Technologies Co., Ltd. All Rights Reserved. +# This file is a part of the ModelEngine Project. +# Licensed under the MIT License. See License.txt in the project root for license information. +# ====================================================================================================================== +import traceback +from typing import Tuple, List, Any, Callable + +from fitframework import fit_logger +from llama_index.core.node_parser import ( + SentenceSplitter, + TokenTextSplitter, + SemanticSplitterNodeParser, + SentenceWindowNodeParser +) +from llama_index.core.schema import BaseNode +from llama_index.core.schema import Document as LDocument +from llama_index.embeddings.openai import OpenAIEmbedding + +from .callable_registers import register_callable_tool +from .node_utils import to_llama_index_document + + +def sentence_splitter(text: str, separator: str, chunk_size: int, chunk_overlap: int, **kwargs) -> List[str]: + """Parse text with a preference for complete sentences.""" + if len(text) == 0: + return [] + splitter = SentenceSplitter( + separator=separator, + chunk_size=chunk_size, + chunk_overlap=chunk_overlap, + ) + try: + return splitter.split_text(text) + except BaseException: + fit_logger.error("Invoke sentence splitter failed.") + traceback.print_exc() + return [] + + +def token_text_splitter(text: str, separator: str, chunk_size: int, chunk_overlap: int, **kwargs) -> List[str]: + """Splitting text that looks at word tokens.""" + if len(text) == 0: + return [] + splitter = TokenTextSplitter( + separator=separator, + chunk_size=chunk_size, + chunk_overlap=chunk_overlap, + ) + try: + return splitter.split_text(text) + except BaseException: + fit_logger.error("Invoke token text splitter failed.") + traceback.print_exc() + return [] + + +def semantic_splitter(buffer_size: int, breakpoint_percentile_threshold: int, docs: List[LDocument], **kwargs) \ + -> List[BaseNode]: + """Splitting text that looks at word tokens.""" + if len(docs) == 0: + return [] + api_key = kwargs.get("api_key") + model_name = kwargs.get("model_name") + api_base = kwargs.get("api_base") + + embed_model = OpenAIEmbedding(model_name=model_name, api_base=api_base, api_key=api_key, max_tokens=4096) + + splitter = SemanticSplitterNodeParser( + buffer_size=buffer_size, + breakpoint_percentile_threshold=breakpoint_percentile_threshold, + embed_model=embed_model + ) + ldocs = [to_llama_index_document(doc) for doc in docs] + try: + return splitter.build_semantic_nodes_from_documents(documents=ldocs) + except BaseException: + fit_logger.error("Invoke semantic splitter failed.") + traceback.print_exc() + return [] + + +def sentence_window_node_parser(window_size: int, window_metadata_key: str, original_text_metadata_key: str, + docs: List[LDocument], **kwargs) -> List[BaseNode]: + """Splitting text that looks at word tokens.""" + if len(docs) == 0: + return [] + + node_parser = SentenceWindowNodeParser.from_defaults( + window_size=window_size, + window_metadata_key=window_metadata_key, + original_text_metadata_key=original_text_metadata_key, + ) + try: + return node_parser.get_nodes_from_documents(docs) + except BaseException: + fit_logger.error("Invoke semantic splitter failed.") + traceback.print_exc() + return [] + + +# Tuple 结构: (tool_func, config_args, return_description) +splitter_basic_toolkit: List[Tuple[Callable[..., Any], List[str], str]] = [ + (sentence_splitter, ["text", "separator", "chunk_size", "chunk_overlap"], "Split sentences by sentence."), + (token_text_splitter, ["text", "separator", "chunk_size", "chunk_overlap"], "Split sentences by token."), + (semantic_splitter, + ["docs", "buffer_size", "breakpoint_percentile_threshold", "chunk_overlap", "model_name", "api_key", "api_base"], + "Split sentences by semantic."), + (sentence_window_node_parser, ["docs", "window_size", "window_metadata_key", "original_text_metadata_key"], + "Splits all documents into individual sentences") +] + +for tool in splitter_basic_toolkit: + register_callable_tool(tool, sentence_splitter.__module__, "llama_index.rag.toolkit") + +if __name__ == '__main__': + import time + from .llama_schema_helper import dump_llama_schema + + current_timestamp = time.strftime('%Y%m%d%H%M%S') + dump_llama_schema(splitter_basic_toolkit, f"./llama_tool_schema-{str(current_timestamp)}.json") diff --git a/framework/fel/python/plugins/fel_llama_splitter_tools/node_utils.py b/framework/fel/python/plugins/fel_llama_splitter_tools/node_utils.py new file mode 100644 index 00000000..deba86e3 --- /dev/null +++ b/framework/fel/python/plugins/fel_llama_splitter_tools/node_utils.py @@ -0,0 +1,62 @@ +# -- encoding: utf-8 -- +# Copyright (c) 2024 Huawei Technologies Co., Ltd. All Rights Reserved. +# This file is a part of the ModelEngine Project. +# Licensed under the MIT License. See License.txt in the project root for license information. +# ====================================================================================================================== +from typing import Dict + +from llama_index.core.multi_modal_llms.generic_utils import encode_image +from llama_index.core.schema import ImageNode, TextNode, NodeWithScore +from llama_index.core import Document as LDocument + +from .types.document import Document +from .types.media import Media + + +def document_to_query_node(doc_input: Document): + if isinstance(doc_input, dict): + doc = Document(**doc_input) + else: + doc = doc_input + + if doc.media is not None: + node = ImageNode(image=doc.media.data, image_mimetype=doc.media.mime) + else: + node = TextNode() + node.set_content(doc.content) + node.metadata = doc.metadata + return NodeWithScore(node=node, score=doc.metadata["score"]) + + +def query_node_to_document(node_with_score: NodeWithScore) -> Document: + node = node_with_score.node + metadata = node.metadata or {} + metadata['score'] = node_with_score.score + content = None + image = None + file_path_key = "file_path" + if isinstance(node, ImageNode): + mime = node.image_mimetype or "image/jpeg" + data = None + if node.image and node.image != "": + data = node.image + elif node.image_url and node.image_url != "": + data = node.image_url + elif node.image_path and node.image_path != "": + data = encode_image(node.image_path) + elif file_path_key in node.metadata and node.metadata[file_path_key] != "": + data = encode_image(node.metadata[file_path_key]) + image = Media(mime=mime, data=data) + if isinstance(node, TextNode): + content = node.get_content() + return Document(content=content, media=image, metadata=metadata) + + +def to_llama_index_document(doc: Document) -> LDocument: + metadata = {} + if isinstance(doc, Dict): + doc = Document.parse_obj(doc) + metadata.update(doc.metadata) + if doc.media is not None: + metadata.update({"mime": doc.media.mime, "data": doc.media.data}) + return LDocument(text=doc.content, metadata=metadata) \ No newline at end of file diff --git a/framework/fel/python/plugins/fel_llama_splitter_tools/tools.json b/framework/fel/python/plugins/fel_llama_splitter_tools/tools.json new file mode 100644 index 00000000..ad8fad58 --- /dev/null +++ b/framework/fel/python/plugins/fel_llama_splitter_tools/tools.json @@ -0,0 +1,368 @@ +{ + "tools": [ + { + "tags": [ + "LlamaIndex" + ], + "runnables": { + "LlamaIndex": { + "genericableId": "llama_index.rag.toolkit", + "fitableId": "sentence_splitter" + } + }, + "schema": { + "name": "sentence_splitter", + "description": "Parse text with a preference for complete sentences.", + "parameters": { + "type": "object", + "properties": { + "text": { + "type": "string", + "description": "text" + }, + "separator": { + "type": "string", + "description": "separator" + }, + "chunk_size": { + "type": "string", + "description": "chunk_size" + }, + "chunk_overlap": { + "type": "string", + "description": "chunk_overlap" + } + }, + "required": [ + "text", + "separator", + "chunk_size", + "chunk_overlap" + ] + }, + "return": { + "title": "Split Sentences By Sentence.", + "type": "array", + "items": { + "type": "string" + } + }, + "parameterExtensions": { + "config": [ + "text", + "separator", + "chunk_size", + "chunk_overlap" + ] + } + } + }, + { + "tags": [ + "LlamaIndex" + ], + "runnables": { + "LlamaIndex": { + "genericableId": "llama_index.rag.toolkit", + "fitableId": "token_text_splitter" + } + }, + "schema": { + "name": "token_text_splitter", + "description": "Splitting text that looks at word tokens.", + "parameters": { + "type": "object", + "properties": { + "text": { + "type": "string", + "description": "text" + }, + "separator": { + "type": "string", + "description": "separator" + }, + "chunk_size": { + "type": "string", + "description": "chunk_size" + }, + "chunk_overlap": { + "type": "string", + "description": "chunk_overlap" + } + }, + "required": [ + "text", + "separator", + "chunk_size", + "chunk_overlap" + ] + }, + "return": { + "title": "Split Sentences By Token.", + "type": "array", + "items": { + "type": "string" + } + }, + "parameterExtensions": { + "config": [ + "text", + "separator", + "chunk_size", + "chunk_overlap" + ] + } + } + }, + { + "tags": [ + "LlamaIndex" + ], + "runnables": { + "LlamaIndex": { + "genericableId": "llama_index.rag.toolkit", + "fitableId": "semantic_splitter" + } + }, + "schema": { + "name": "semantic_splitter", + "description": "Splitting text that looks at word tokens.", + "parameters": { + "type": "object", + "properties": { + "buffer_size": { + "type": "string", + "description": "buffer_size" + }, + "breakpoint_percentile_threshold": { + "type": "string", + "description": "breakpoint_percentile_threshold" + }, + "docs": { + "type": "string", + "description": "docs" + }, + "chunk_overlap": { + "type": "string", + "description": "chunk_overlap" + }, + "model_name": { + "type": "string", + "description": "model_name" + }, + "api_key": { + "type": "string", + "description": "api_key" + }, + "api_base": { + "type": "string", + "description": "api_base" + } + }, + "required": [ + "buffer_size", + "breakpoint_percentile_threshold", + "docs" + ] + }, + "return": { + "title": "Split Sentences By Semantic.", + "type": "array", + "items": { + "title": "BaseNode", + "description": "Base node Object.\n\nGeneric abstract interface for retrievable nodes", + "type": "object", + "properties": { + "id_": { + "title": "Id ", + "description": "Unique ID of the node.", + "type": "string" + }, + "embedding": { + "title": "Embedding", + "description": "Embedding of the node.", + "type": "array", + "items": { + "type": "number" + } + }, + "extra_info": { + "title": "Extra Info", + "description": "A flat dictionary of metadata fields", + "type": "object" + }, + "excluded_embed_metadata_keys": { + "title": "Excluded Embed Metadata Keys", + "description": "Metadata keys that are excluded from text for the embed model.", + "type": "array", + "items": { + "type": "string" + } + }, + "excluded_llm_metadata_keys": { + "title": "Excluded Llm Metadata Keys", + "description": "Metadata keys that are excluded from text for the LLM.", + "type": "array", + "items": { + "type": "string" + } + }, + "relationships": { + "title": "Relationships", + "description": "A mapping of relationships to other node information.", + "type": "object", + "additionalProperties": { + "anyOf": [ + { + "$ref": "#/definitions/RelatedNodeInfo" + }, + { + "type": "array", + "items": { + "$ref": "#/definitions/RelatedNodeInfo" + } + } + ] + } + }, + "class_name": { + "title": "Class Name", + "type": "string", + "default": "base_component" + } + } + } + }, + "parameterExtensions": { + "config": [ + "docs", + "buffer_size", + "breakpoint_percentile_threshold", + "chunk_overlap", + "model_name", + "api_key", + "api_base" + ] + } + } + }, + { + "tags": [ + "LlamaIndex" + ], + "runnables": { + "LlamaIndex": { + "genericableId": "llama_index.rag.toolkit", + "fitableId": "sentence_window_node_parser" + } + }, + "schema": { + "name": "sentence_window_node_parser", + "description": "Splitting text that looks at word tokens.", + "parameters": { + "type": "object", + "properties": { + "window_size": { + "type": "string", + "description": "window_size" + }, + "window_metadata_key": { + "type": "string", + "description": "window_metadata_key" + }, + "original_text_metadata_key": { + "type": "string", + "description": "original_text_metadata_key" + }, + "docs": { + "type": "string", + "description": "docs" + } + }, + "required": [ + "window_size", + "window_metadata_key", + "original_text_metadata_key", + "docs" + ] + }, + "return": { + "title": "Splits All Documents Into Individual Sentences", + "type": "array", + "items": { + "title": "BaseNode", + "description": "Base node Object.\n\nGeneric abstract interface for retrievable nodes", + "type": "object", + "properties": { + "id_": { + "title": "Id ", + "description": "Unique ID of the node.", + "type": "string" + }, + "embedding": { + "title": "Embedding", + "description": "Embedding of the node.", + "type": "array", + "items": { + "type": "number" + } + }, + "extra_info": { + "title": "Extra Info", + "description": "A flat dictionary of metadata fields", + "type": "object" + }, + "excluded_embed_metadata_keys": { + "title": "Excluded Embed Metadata Keys", + "description": "Metadata keys that are excluded from text for the embed model.", + "type": "array", + "items": { + "type": "string" + } + }, + "excluded_llm_metadata_keys": { + "title": "Excluded Llm Metadata Keys", + "description": "Metadata keys that are excluded from text for the LLM.", + "type": "array", + "items": { + "type": "string" + } + }, + "relationships": { + "title": "Relationships", + "description": "A mapping of relationships to other node information.", + "type": "object", + "additionalProperties": { + "anyOf": [ + { + "$ref": "#/definitions/RelatedNodeInfo" + }, + { + "type": "array", + "items": { + "$ref": "#/definitions/RelatedNodeInfo" + } + } + ] + } + }, + "class_name": { + "title": "Class Name", + "type": "string", + "default": "base_component" + } + } + } + }, + "parameterExtensions": { + "config": [ + "docs", + "window_size", + "window_metadata_key", + "original_text_metadata_key" + ] + } + } + } + ] +} \ No newline at end of file diff --git a/framework/fel/python/plugins/fel_llama_splitter_tools/types/__init__.py b/framework/fel/python/plugins/fel_llama_splitter_tools/types/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/framework/fel/python/plugins/fel_llama_splitter_tools/types/document.py b/framework/fel/python/plugins/fel_llama_splitter_tools/types/document.py new file mode 100644 index 00000000..4989999f --- /dev/null +++ b/framework/fel/python/plugins/fel_llama_splitter_tools/types/document.py @@ -0,0 +1,22 @@ +# -- encoding: utf-8 -- +# Copyright (c) 2024 Huawei Technologies Co., Ltd. All Rights Reserved. +# This file is a part of the ModelEngine Project. +# Licensed under the MIT License. See License.txt in the project root for license information. +# ====================================================================================================================== +import typing + +from .serializable import Serializable +from .media import Media + + +class Document(Serializable): + """ + Document. + """ + content: str + media: Media = None + metadata: typing.Dict[str, object] + + class Config: + frozen = True + smart_union = True diff --git a/framework/fel/python/plugins/fel_llama_splitter_tools/types/media.py b/framework/fel/python/plugins/fel_llama_splitter_tools/types/media.py new file mode 100644 index 00000000..b1bdb54a --- /dev/null +++ b/framework/fel/python/plugins/fel_llama_splitter_tools/types/media.py @@ -0,0 +1,18 @@ +# -- encoding: utf-8 -- +# Copyright (c) 2024 Huawei Technologies Co., Ltd. All Rights Reserved. +# This file is a part of the ModelEngine Project. +# Licensed under the MIT License. See License.txt in the project root for license information. +# ====================================================================================================================== +from .serializable import Serializable + + +class Media(Serializable): + """ + Media. + """ + mime: str + data: str + + class Config: + frozen = True + smart_union = True diff --git a/framework/fel/python/plugins/fel_llama_splitter_tools/types/serializable.py b/framework/fel/python/plugins/fel_llama_splitter_tools/types/serializable.py new file mode 100644 index 00000000..4522897f --- /dev/null +++ b/framework/fel/python/plugins/fel_llama_splitter_tools/types/serializable.py @@ -0,0 +1,25 @@ +# -- encoding: utf-8 -- +# Copyright (c) 2024 Huawei Technologies Co., Ltd. All Rights Reserved. +# This file is a part of the ModelEngine Project. +# Licensed under the MIT License. See License.txt in the project root for license information. +# ====================================================================================================================== +import typing + +try: + import pydantic + + if pydantic.__version__.startswith("1."): + raise ImportError + import pydantic.v1 as pydantic +except ImportError: + import pydantic + + +class Serializable(pydantic.BaseModel): + def json(self, **kwargs: typing.Any) -> str: + kwargs_with_defaults: typing.Any = {"by_alias": True, "exclude_unset": True, **kwargs} + return super().json(**kwargs_with_defaults) + + def dict(self, **kwargs: typing.Any) -> typing.Dict[str, typing.Any]: + kwargs_with_defaults: typing.Any = {"by_alias": True, "exclude_unset": True, **kwargs} + return super().dict(**kwargs_with_defaults) \ No newline at end of file diff --git a/framework/fel/python/plugins/fel_llamaindex_network_tools/callable_registers.py b/framework/fel/python/plugins/fel_llamaindex_network_tools/callable_registers.py new file mode 100644 index 00000000..0cde3122 --- /dev/null +++ b/framework/fel/python/plugins/fel_llamaindex_network_tools/callable_registers.py @@ -0,0 +1,29 @@ +# -- encoding: utf-8 -- +# Copyright (c) 2024 Huawei Technologies Co., Ltd. All Rights Reserved. +# This file is a part of the ModelEngine Project. +# Licensed under the MIT License. See License.txt in the project root for license information. +# ====================================================================================================================== +import functools +from inspect import signature +from typing import Callable, Any, Tuple, List + +from fitframework import fit_logger +from fitframework.core.repo.fitable_register import register_fitable + + +def __invoke_tool(input_args: dict, tool_func: Callable[..., Any], **kwargs) -> Any: + return tool_func(**input_args, **kwargs) + + +def register_callable_tool(tool: Tuple[Callable[..., Any], List[str], str], module: str, generic_id: str): + func = tool[0] + fitable_id = f"{func.__name__}" + + tool_invoke = functools.partial(__invoke_tool, tool_func=func) + tool_invoke.__module__ = module + tool_invoke.__annotations__ = { + 'input_args': dict, + 'return': signature(func).return_annotation + } + register_fitable(generic_id, fitable_id, False, [], tool_invoke) + fit_logger.info("register: generic_id = %s, fitable_id = %s", generic_id, fitable_id, stacklevel=2) diff --git a/framework/fel/python/plugins/fel_llamaindex_network_tools/llamaindex_network_tool.py b/framework/fel/python/plugins/fel_llamaindex_network_tools/llamaindex_network_tool.py new file mode 100644 index 00000000..de2f1202 --- /dev/null +++ b/framework/fel/python/plugins/fel_llamaindex_network_tools/llamaindex_network_tool.py @@ -0,0 +1,26 @@ +# -- encoding: utf-8 -- +# Copyright (c) 2024 Huawei Technologies Co., Ltd. All Rights Reserved. +# This file is a part of the ModelEngine Project. +# Licensed under the MIT License. See License.txt in the project root for license information. +# ====================================================================================================================== + +import time +from typing import List, Any, Optional, Callable, Union, Tuple + +from .callable_registers import register_callable_tool + + +def llamaindex_network(**kwargs) -> str: + time.sleep(5) + return "" + + +# Tuple 结构: (tool_func, config_args, return_description) +network_toolkit: List[Tuple[Callable[..., Any], List[str], str]] = [ + (llamaindex_network, ["input"], "Youtube search.") +] + + +for tool in network_toolkit: + register_callable_tool(tool, llamaindex_network.__module__, "llama_index.rag.toolkit") + diff --git a/framework/fel/python/setup.py b/framework/fel/python/setup.py new file mode 100644 index 00000000..165f8225 --- /dev/null +++ b/framework/fel/python/setup.py @@ -0,0 +1,30 @@ +# -- encoding: utf-8 -- +# Copyright (c) 2024 Huawei Technologies Co., Ltd. All Rights Reserved. +# This file is a part of the ModelEngine Project. +# Licensed under the MIT License. See License.txt in the project root for license information. +# ====================================================================================================================== +""" +功 能:用于打包工程为 wheel 文件的脚本。 +""" +# 打包方式:在 fel/python 目录下执行 python setup.py sdist bdist_wheel +import setuptools + +_FEL_FRAMEWORK_VERSION = "0.0.1.dev" + +setuptools.setup( + name="fel", + version=_FEL_FRAMEWORK_VERSION, + author="fit", + url="https://gitlab.huawei.com/fitlab/fit", + packages=setuptools.find_packages( + exclude=["*.tests", "*.tests.*", "tests.*", "tests", "_test.*", "_test"]), + classifiers=[ + "Programming Language :: Python :: 3", + "License :: Huawei license", + "Operating System :: OS Independent", + ], + install_requires=["langchain==0.2.6", + "llama_index==0.10.47", + "requests==2.31"], + python_requires='==3.9.11' +) diff --git a/framework/waterflow/java/pom.xml b/framework/waterflow/java/pom.xml index 3dc312f3..ee8d0adb 100644 --- a/framework/waterflow/java/pom.xml +++ b/framework/waterflow/java/pom.xml @@ -36,10 +36,10 @@ - waterflow-bridge-fit-reactor + waterflow-common waterflow-core waterflow-dependency - waterflow-genericable + waterflow-eco diff --git a/framework/waterflow/java/waterflow-common/pom.xml b/framework/waterflow/java/waterflow-common/pom.xml new file mode 100644 index 00000000..3d449994 --- /dev/null +++ b/framework/waterflow/java/waterflow-common/pom.xml @@ -0,0 +1,57 @@ + + + 4.0.0 + + org.fitframework.waterflow + waterflow-parent + 3.5.0-SNAPSHOT + + + waterflow-common + 3.5.0-SNAPSHOT + + + 17 + 17 + UTF-8 + + + + + org.fitframework + fit-api + + + + + + + org.apache.maven.plugins + maven-jar-plugin + ${maven.jar.version} + + + + FIT Lab + + + + + + org.fitframework + fit-build-maven-plugin + ${fit.version} + + + build-service + + build-service + + + + + + + \ No newline at end of file diff --git a/framework/waterflow/java/waterflow-common/src/main/java/modelengine/fit/waterflow/ErrorCodes.java b/framework/waterflow/java/waterflow-common/src/main/java/modelengine/fit/waterflow/ErrorCodes.java new file mode 100644 index 00000000..ace33101 --- /dev/null +++ b/framework/waterflow/java/waterflow-common/src/main/java/modelengine/fit/waterflow/ErrorCodes.java @@ -0,0 +1,338 @@ +/*--------------------------------------------------------------------------------------------- + * Copyright (c) 2025 Huawei Technologies Co., Ltd. All rights reserved. + * This file is a part of the ModelEngine Project. + * Licensed under the MIT License. See License.txt in the project root for license information. + *--------------------------------------------------------------------------------------------*/ + +package modelengine.fit.waterflow; + +/** + * 异常类型枚举类。 + * Generic Exception: 10000000-10000999. + * FlowEngines Exception: 10007000-10007999. + * + * @author 陈镕希 + * @since 2025-03-02 + */ +public enum ErrorCodes { + /** + * 入参为空。 + */ + INPUT_PARAM_IS_EMPTY(10000000, "Input param is empty, empty param is {0}."), + + /** + * 枚举类转换异常。 + */ + ENUM_CONVERT_FAILED(10000001, "Cannot convert enum {0} by name: {1}."), + + /** + * 实体对象未找到。 + */ + ENTITY_NOT_FOUND(10000002, "Cannot find entity {0} by id: {1}."), + + /** + * 入参不合法。 + */ + INPUT_PARAM_IS_INVALID(10000003, "Input param is invalid, invalid param is {0}."), + + /** + * 不符合预期。 + */ + UN_EXCEPTED_ERROR(10000006, "unexpected error:{0}"), + + /** + * 分页查询时Offset范围不正确。 + */ + PAGINATION_OFFSET_INVALID(10000008, "The range of offset is incorrect."), + + /** + * 分页查询时Limit范围不正确。 + */ + PAGINATION_LIMIT_INVALID(10000009, "The range of limit is incorrect."), + + /** + * 类型转换失败。 + */ + TYPE_CONVERT_FAILED(10000011, "Cannot convert type."), + + /** + * 流程节点转换不支持操作。 + */ + FLOW_NODE_CREATE_ERROR(10007000, "Processor can not be null during create flowable node."), + + /** + * 流程节点不支持执行操作。 + */ + FLOW_NODE_OPERATOR_NOT_SUPPORT(10007001, "Flow node with id: {0}, type: {1}, for operator [{2}] not supported."), + + /** + * 流程没有开始节点。 + */ + FLOW_HAS_NO_START_NODE(10007002, "Flow definition with id: {0} has no start node."), + + /** + * 流程任务不支持执行操作。 + */ + FLOW_TASK_OPERATOR_NOT_SUPPORT(10007003, "Flow task with name: {0}, type: {1}, are not supported."), + + /** + * 流程执行错误,没有手动执行任务。 + */ + FLOW_ENGINE_INVALID_MANUAL_TASK(10007004, "Flow engine executor error for invalid manual task."), + + /** + * 流程执行错误,非法节点标识。 + */ + FLOW_ENGINE_INVALID_NODE_ID(10007005, "Flow engine executor error for invalid node id: {0}."), + + /** + * 流程定义解析失败。 + */ + FLOW_ENGINE_PARSER_NOT_SUPPORT(10007010, "Flow engine parser not support {0} operator."), + + /** + * 流程启动失败。 + */ + FLOW_START_ERROR(10007011, "Flow status is invalid"), + + /** + * 执行任务失败。 + */ + FLOW_EXECUTE_FITABLE_TASK_FAILED(10007012, + "execute jober failed, jober name: {0}, jober type: {1}, fitables: {2}, errors: {3}"), + + /** + * 流程执行不支持发送事件。 + */ + FLOW_SEND_EVENT_NOT_SUPPORT(100070013, "Flow send event are not supported."), + + /** + * 流程引擎数据库不支持该操作。 + */ + FLOW_ENGINE_DATABASE_NOT_SUPPORT(100070014, "Operation :{0} is not supported."), + + /** + * 流程定义更新失败。 + */ + FLOW_DEFINITION_UPDATE_NOT_SUPPORT(100070015, "Flow status :{0} update not supported."), + + /** + * 通过 eventMetaId 查询 to 节点失败。 + */ + FLOW_FIND_TO_NODE_BY_EVENT_FAILED(100070016, "Find to node by event metaId :{0} failed."), + + /** + * 找不到流程图。 + */ + FLOW_GRAPH_NOT_FOUND(100070017, "Flow graph id: {0} version: {1} not found."), + + /** + * 试图修改已发布的流程图。 + */ + FLOW_MODIFY_PUBLISHED_GRAPH(100070018, + "graph data with id: {0} version: {1} has been published, can not be modified"), + + /** + * 流程标识不匹配。 + */ + FLOW_ID_NOT_MATCH(100070019, "Flow id {0} does not match id {1} in data."), + + /** + * 流程图数据解析失败。 + */ + FLOW_GRAPH_DATA_PARSE_FAILED(100070020, "Parse graph data failed."), + + /** + * 处理智能表单任务失败。 + */ + FLOW_HANDLE_SMART_FORM_FAILED(100070021, "Failed to handle the smart form task."), + + /** + * 终止流程失败。 + */ + FLOW_TERMINATE_FAILED(100070022, + "Failed to terminate flows by trace id {0}, when the flow status is error, archived or terminate."), + + /** + * ElsaFlowsGraphRepo 不支持该操作。 + */ + ELSA_FLOW_GRAPH_NOT_SUPPORT(100070023, "Operation :{0} is not supported."), + + /** + * DbFlowsGraphRepo 不支持该操作。 + */ + NOT_SUPPORT(100070024, "Operation :{0} is not supported."), + + /** + * 流程已存在。 + */ + FLOW_ALREADY_EXIST(100070025, "flow already exist, {0}."), + + /** + * 流程回调函数执行 fitables 失败。 + */ + FLOW_EXECUTE_CALLBACK_FITABLES_FAILED(10007026, + "Failed to execute callback, callback name: {0}, callback type: {1}, fitables: {2}, errors: {3}"), + + /** + * 流程引擎 OhScript 语法错误。 + */ + FLOW_ENGINE_OHSCRIPT_GRAMMAR_ERROR(100070024, "OhScript grammar error. Source Code: {0}"), + + /** + * 流程引擎条件规则变量未找到。 + */ + FLOW_ENGINE_CONDITION_RULE_PARSE_ERROR(100070025, "Condition rule parse error. Condition Rule: {0}"), + + /** + * 找不到对应流程节点。 + */ + FLOW_NODE_NOT_FOUND(100070024, "Flow node id {0} not found, flow meta id {1}, version {2}."), + + /** + * flow节点任务数达到最大值。 + */ + FLOW_NODE_MAX_TASK(100070024, "Flow node id {0} tasks over the limit."), + + /** + * 流程自动任务特定异常重试失败。 + */ + FLOW_RETRY_JOBER_UPDATE_DATABASE_FAILED(10007024, "Failed to update the retry record for retryJober, toBatch: {0}"), + + /** + * 异步 jober 执行失败。 + */ + FLOW_EXECUTE_ASYNC_JOBER_FAILED(10007027, + "execute async jober failed."), + + /** + * 流程执行过程出现异常。 + */ + FLOW_ENGINE_EXECUTOR_ERROR(10007500, "Flow engine executor errors " + + "stream id: {0}, node id: {1}, name: {2}, exception: {3}, errors: {4}."), + + /** + * 流程执行过程通过 ohscript 调用 fitable 出现异常。 + */ + FLOW_OHSCRIPT_INVOKE_FITABLE_ERROR(10007501, + "Error code: 10007501, Flow engine executor ohscript code error when invoke fitable."), + + /** + * 流程定义删除失败。 + */ + FLOW_DEFINITION_DELETE_ERROR(10007502, "Error code: 10007502, Flow definition delete error"), + + /** + * 流程出现系统错误。 + */ + FLOW_SYSTEM_ERROR(10007503, "SYSTEM ERROR"), + + /** + * 流程调用过程出现网络错误。 + */ + FLOW_NETWORK_ERROR(10007504, "Error code: 10007504, Network error when Invoke fitable"), + + /** + * 流程执行过程中不支持处理该类型。 + */ + CONTEXT_TYPE_NOT_SUPPORT(10007505, "Not support this type."), + + /** + * 中间节点连接线不合法。 + */ + INVALID_STATE_NODE_EVENT_SIZE(10007518, "State node event size must be 1, please check config"), + + /** + * 节点对应的 event 个数不合法。 + */ + INVALID_EVENT_SIZE(10007506, "Error code: 10007506, Invalid event size."), + + /** + * 流程 storeJober 调用过程执行出错。 + */ + FLOW_STORE_JOBER_INVOKE_ERROR(10007507, "Flow store jober invoke error, tool id:{0}."), + + /** + * 流程 httpJober 调用过程执行出错。 + */ + FLOW_HTTP_JOBER_INVOKE_ERROR(10007508, "Flow http jober invoke error."), + + /** + * 流程 genericableJober 调用过程执行出错。 + */ + FLOW_GENERICALBE_JOBER_INVOKE_ERROR(10007509, "Flow genericable jober invoke error."), + + /** + * 流程 generalJober 调用过程执行出错。 + */ + FLOW_GENERAL_JOBER_INVOKE_ERROR(100075010, "Flow general jober invoke error."), + + /** + * 条件节点执行出错。 + */ + CONDITION_NODE_EXEC_ERROR(10007511, "Condition node executor error."), + + /** + * 流程图保存失败。 + */ + FLOW_GRAPH_SAVE_ERROR(10007512, "Flow graph save error, flow id: {0}, version: {1}."), + + /** + * 流程图升级失败。 + */ + FLOW_GRAPH_UPGRADE_ERROR(10007513, "Flow graph upgrade error, flow id: {0}, version: {1}."), + + /** + * 流程校验失败。 + */ + FLOW_VALIDATE_ERROR(10007514, "Flow graph validate error, detail: {0}"), + + /** + * 流程节点个数不合法。 + */ + INVALID_FLOW_NODE_SIZE(10007515, "Node size must more than 3, please check config"), + + /** + * 开始节点连接线不合法。 + */ + INVALID_START_NODE_EVENT_SIZE(10007516, "Start node event size must be 1, please check config"), + + /** + * 连接线配置不合法。 + */ + INVALID_EVENT_CONFIG(10007517, "Event config is invalid, event id: {0}"); + + private final Integer errorCode; + + private final String message; + + ErrorCodes(Integer errorCode, String message) { + this.errorCode = errorCode; + this.message = message; + } + + /** + * Retrieves the error code associated with this instance. + * The error code typically represents a specific error condition or status. + * + * @return the numeric error code. + */ + public Integer getErrorCode() { + return errorCode; + } + + /** + * Retrieves the descriptive message associated with this instance. + * The message typically provides human-readable details about the status or error condition. + * + * @return the descriptive message text. + */ + public String getMessage() { + return message; + } + + @Override + public String toString() { + return "err " + this.errorCode + ": " + this.message; + } +} diff --git a/framework/waterflow/java/waterflow-common/src/main/java/modelengine/fit/waterflow/entity/DefaultOperationContext.java b/framework/waterflow/java/waterflow-common/src/main/java/modelengine/fit/waterflow/entity/DefaultOperationContext.java new file mode 100644 index 00000000..24fe3ffa --- /dev/null +++ b/framework/waterflow/java/waterflow-common/src/main/java/modelengine/fit/waterflow/entity/DefaultOperationContext.java @@ -0,0 +1,140 @@ +/*--------------------------------------------------------------------------------------------- + * Copyright (c) 2025 Huawei Technologies Co., Ltd. All rights reserved. + * This file is a part of the ModelEngine Project. + * Licensed under the MIT License. See License.txt in the project root for license information. + *--------------------------------------------------------------------------------------------*/ + +package modelengine.fit.waterflow.entity; + +import modelengine.fitframework.util.StringUtils; + +import java.util.Arrays; + +import static modelengine.fitframework.util.ObjectUtils.nullIf; + +/** + * 操作上下文。 + * + * @author 梁济时 + * @since 2023-10-30 + */ +class DefaultOperationContext implements OperationContext { + static final DefaultOperationContext EMPTY = new DefaultOperationContext("", "", + "0.0.0.0", "", ""); + + private final String tenantId; + + private final String operator; + + private final String operatorIp; + + private final String language; + + private final String sourcePlatform; + + DefaultOperationContext(String tenantId, String operator, String operatorIp, String language, + String sourcePlatform) { + this.tenantId = nullIf(tenantId, ""); + this.operator = nullIf(operator, ""); + this.operatorIp = nullIf(operatorIp, ""); + this.language = nullIf(language, ""); + this.sourcePlatform = nullIf(sourcePlatform, ""); + } + + @Override + public String tenantId() { + return this.tenantId; + } + + @Override + public String operator() { + return this.operator; + } + + @Override + public String operatorIp() { + return this.operatorIp; + } + + @Override + public String language() { + return this.language; + } + + @Override + public String sourcePlatform() { + return this.sourcePlatform; + } + + @Override + public boolean equals(Object obj) { + if (obj == this) { + return true; + } else if (obj instanceof DefaultOperationContext) { + DefaultOperationContext another = (DefaultOperationContext) obj; + return this.tenantId.equals(another.tenantId) && this.operator.equals(another.operator) + && this.operatorIp.equals(another.operatorIp); + } else { + return false; + } + } + + @Override + public int hashCode() { + return Arrays.hashCode(new Object[] {this.getClass(), this.tenantId, this.operator, this.operatorIp}); + } + + @Override + public String toString() { + return StringUtils.format("[tenantId={0}, operator={1}, operatorIp={2}]", this.tenantId, this.operator, + this.operatorIp); + } + + static class Builder implements OperationContext.Builder { + private String tenantId; + + private String operator; + + private String operatorIp; + + private String language; + + private String sourcePlatform; + + @Override + public OperationContext.Builder tenantId(String tenantId) { + this.tenantId = tenantId; + return this; + } + + @Override + public OperationContext.Builder operator(String operator) { + this.operator = operator; + return this; + } + + @Override + public OperationContext.Builder operatorIp(String operatorIp) { + this.operatorIp = operatorIp; + return this; + } + + @Override + public OperationContext.Builder langage(String language) { + this.language = language; + return this; + } + + @Override + public OperationContext.Builder sourcePlatform(String sourcePlatform) { + this.sourcePlatform = sourcePlatform; + return this; + } + + @Override + public OperationContext build() { + return new DefaultOperationContext(this.tenantId, this.operator, this.operatorIp, this.language, + this.sourcePlatform); + } + } +} diff --git a/framework/waterflow/java/waterflow-common/src/main/java/modelengine/fit/waterflow/entity/OperationContext.java b/framework/waterflow/java/waterflow-common/src/main/java/modelengine/fit/waterflow/entity/OperationContext.java new file mode 100644 index 00000000..1b324214 --- /dev/null +++ b/framework/waterflow/java/waterflow-common/src/main/java/modelengine/fit/waterflow/entity/OperationContext.java @@ -0,0 +1,123 @@ +/*--------------------------------------------------------------------------------------------- + * Copyright (c) 2025 Huawei Technologies Co., Ltd. All rights reserved. + * This file is a part of the ModelEngine Project. + * Licensed under the MIT License. See License.txt in the project root for license information. + *--------------------------------------------------------------------------------------------*/ + +package modelengine.fit.waterflow.entity; + +/** + * 操作人相关上下文。 + * + * @author 陈镕希 + * @since 2023-08-28 + */ +public interface OperationContext { + /** + * 获取正在操作的租户的唯一标识。 + * + * @return 表示正在操作的租户的唯一标识的 {@link String}。 + */ + String tenantId(); + + /** + * 获取操作人的名称。 + * + * @return 表示操作人名称的 {@link String}。 + */ + String operator(); + + /** + * 获取操作方的 IP 地址。 + * + * @return 表示 IP 地址的 {@link String}。 + */ + String operatorIp(); + + /** + * 获取操作方的语言。 + * + * @return 表示 语言 {@link String}。 + */ + String language(); + + /** + * 获取操作方的标识。 + * + * @return 表示 操作方标识 {@link String}。 + */ + String sourcePlatform(); + + /** + * 为 {@link OperationContext} 提供构建器。 + * + * @author 梁济时 + * @since 2023-08-08 + */ + interface Builder { + /** + * 设置正在操作的租户的唯一标识。 + * + * @param tenantId 表示正在操作的租户的唯一标识的 {@link String}。 + * @return 表示当前构建器的 {@link Builder}。 + */ + Builder tenantId(String tenantId); + + /** + * 设置操作人的名称。 + * + * @param operator 表示操作人名称的 {@link String}。 + * @return 表示当前构建器的 {@link Builder}。 + */ + Builder operator(String operator); + + /** + * 设置操作方的 IP 地址。 + * + * @param operatorIp 表示 IP 地址的 {@link String}。 + * @return 表示当前构建器的 {@link Builder}。 + */ + Builder operatorIp(String operatorIp); + + /** + * 设置操作方的语言。 + * + * @param langage 表示 语言{@link String}。 + * @return 表示当前构建器的 {@link Builder}。 + */ + Builder langage(String langage); + + /** + * 设置操作方的标识。 + * + * @param sourcePlatform 表示 操作标识{@link String}。 + * @return 表示当前构建器的 {@link Builder}。 + */ + Builder sourcePlatform(String sourcePlatform); + + /** + * 构建操作上下文的新实例。 + * + * @return 表示操作上下文新实例的 {@link OperationContext}。 + */ + OperationContext build(); + } + + /** + * 返回一个构建器,用以构建操作上下文的新实例。 + * + * @return 表示用以构建操作上下文新实例的构建器的 {@link Builder}。 + */ + static Builder custom() { + return new DefaultOperationContext.Builder(); + } + + /** + * 获取空的上下文信息。 + * + * @return 表示空的上下文信息的 {@link OperationContext}。 + */ + static OperationContext empty() { + return DefaultOperationContext.EMPTY; + } +} \ No newline at end of file diff --git a/framework/waterflow/java/waterflow-genericable/src/main/java/modelengine/fit/waterflow/common/exceptions/BadRequestException.java b/framework/waterflow/java/waterflow-common/src/main/java/modelengine/fit/waterflow/exceptions/BadRequestException.java similarity index 90% rename from framework/waterflow/java/waterflow-genericable/src/main/java/modelengine/fit/waterflow/common/exceptions/BadRequestException.java rename to framework/waterflow/java/waterflow-common/src/main/java/modelengine/fit/waterflow/exceptions/BadRequestException.java index 055c729e..b8954017 100644 --- a/framework/waterflow/java/waterflow-genericable/src/main/java/modelengine/fit/waterflow/common/exceptions/BadRequestException.java +++ b/framework/waterflow/java/waterflow-common/src/main/java/modelengine/fit/waterflow/exceptions/BadRequestException.java @@ -4,9 +4,9 @@ * Licensed under the MIT License. See License.txt in the project root for license information. *--------------------------------------------------------------------------------------------*/ -package modelengine.fit.waterflow.common.exceptions; +package modelengine.fit.waterflow.exceptions; -import modelengine.fit.waterflow.common.ErrorCodes; +import modelengine.fit.waterflow.ErrorCodes; /** * 错误请求异常类。 diff --git a/framework/waterflow/java/waterflow-common/src/main/java/modelengine/fit/waterflow/exceptions/ServerInternalException.java b/framework/waterflow/java/waterflow-common/src/main/java/modelengine/fit/waterflow/exceptions/ServerInternalException.java new file mode 100644 index 00000000..222f9e95 --- /dev/null +++ b/framework/waterflow/java/waterflow-common/src/main/java/modelengine/fit/waterflow/exceptions/ServerInternalException.java @@ -0,0 +1,38 @@ +/*--------------------------------------------------------------------------------------------- + * Copyright (c) 2025 Huawei Technologies Co., Ltd. All rights reserved. + * This file is a part of the ModelEngine Project. + * Licensed under the MIT License. See License.txt in the project root for license information. + *--------------------------------------------------------------------------------------------*/ + +package modelengine.fit.waterflow.exceptions; + +import modelengine.fitframework.exception.ErrorCode; +import modelengine.fitframework.exception.FitException; + +/** + * 服务器内部异常,用于服务器内部报错。 + * + * @author 陈镕希 + * @since 2023-07-06 + */ +@ErrorCode(500) +public class ServerInternalException extends FitException { + /** + * 使用异常信息初始化 {@link ServerInternalException} 类的新实例。 + * + * @param message 表示异常信息的 {@link String}。 + */ + public ServerInternalException(String message) { + super(message); + } + + /** + * 使用异常信息和引发异常的原因初始化 {@link ServerInternalException} 类的新实例。 + * + * @param message 表示异常信息的 {@link String}。 + * @param cause 表示引发异常的原因的 {@link Throwable}。 + */ + public ServerInternalException(String message, Throwable cause) { + super(message, cause); + } +} diff --git a/framework/waterflow/java/waterflow-genericable/src/main/java/modelengine/fit/waterflow/common/exceptions/WaterflowException.java b/framework/waterflow/java/waterflow-common/src/main/java/modelengine/fit/waterflow/exceptions/WaterflowException.java similarity index 94% rename from framework/waterflow/java/waterflow-genericable/src/main/java/modelengine/fit/waterflow/common/exceptions/WaterflowException.java rename to framework/waterflow/java/waterflow-common/src/main/java/modelengine/fit/waterflow/exceptions/WaterflowException.java index 3579c22e..1794dd7a 100644 --- a/framework/waterflow/java/waterflow-genericable/src/main/java/modelengine/fit/waterflow/common/exceptions/WaterflowException.java +++ b/framework/waterflow/java/waterflow-common/src/main/java/modelengine/fit/waterflow/exceptions/WaterflowException.java @@ -4,9 +4,9 @@ * Licensed under the MIT License. See License.txt in the project root for license information. *--------------------------------------------------------------------------------------------*/ -package modelengine.fit.waterflow.common.exceptions; +package modelengine.fit.waterflow.exceptions; -import modelengine.fit.waterflow.common.ErrorCodes; +import modelengine.fit.waterflow.ErrorCodes; import modelengine.fitframework.exception.FitException; import java.text.MessageFormat; diff --git a/framework/waterflow/java/waterflow-genericable/src/main/java/modelengine/fit/waterflow/common/exceptions/WaterflowParamException.java b/framework/waterflow/java/waterflow-common/src/main/java/modelengine/fit/waterflow/exceptions/WaterflowParamException.java similarity index 90% rename from framework/waterflow/java/waterflow-genericable/src/main/java/modelengine/fit/waterflow/common/exceptions/WaterflowParamException.java rename to framework/waterflow/java/waterflow-common/src/main/java/modelengine/fit/waterflow/exceptions/WaterflowParamException.java index d83da4d0..dfa0e910 100644 --- a/framework/waterflow/java/waterflow-genericable/src/main/java/modelengine/fit/waterflow/common/exceptions/WaterflowParamException.java +++ b/framework/waterflow/java/waterflow-common/src/main/java/modelengine/fit/waterflow/exceptions/WaterflowParamException.java @@ -4,9 +4,9 @@ * Licensed under the MIT License. See License.txt in the project root for license information. *--------------------------------------------------------------------------------------------*/ -package modelengine.fit.waterflow.common.exceptions; +package modelengine.fit.waterflow.exceptions; -import modelengine.fit.waterflow.common.ErrorCodes; +import modelengine.fit.waterflow.ErrorCodes; /** * 参数错误抛出异常类。 diff --git a/framework/waterflow/java/waterflow-common/src/main/java/modelengine/fit/waterflow/utils/Dates.java b/framework/waterflow/java/waterflow-common/src/main/java/modelengine/fit/waterflow/utils/Dates.java new file mode 100644 index 00000000..56054bf3 --- /dev/null +++ b/framework/waterflow/java/waterflow-common/src/main/java/modelengine/fit/waterflow/utils/Dates.java @@ -0,0 +1,84 @@ +/*--------------------------------------------------------------------------------------------- + * Copyright (c) 2025 Huawei Technologies Co., Ltd. All rights reserved. + * This file is a part of the ModelEngine Project. + * Licensed under the MIT License. See License.txt in the project root for license information. + *--------------------------------------------------------------------------------------------*/ + +package modelengine.fit.waterflow.utils; + +import java.time.LocalDateTime; +import java.time.ZoneId; +import java.time.ZoneOffset; +import java.time.ZonedDateTime; +import java.time.format.DateTimeFormatter; +import java.time.format.DateTimeParseException; +import java.util.Arrays; +import java.util.List; + +/** + * 为日期提供工具方法。 + * + * @author 陈镕希 + * @since 2023-08-07 + */ +public final class Dates { + private static final List PATTERNS = Arrays.asList("yyyy-MM-dd HH:mm:ss.SSS", "yyyy-MM-dd HH:mm:ss", + "yyyy-MM-dd HH:mm:ss.SSSSSS", "yyyy-MM-dd HH:mm:ss.SSSSSSSSS"); + + /** + * 隐藏默认构造方法,避免工具类被实例化。 + */ + private Dates() { + } + + /** + * 返回一个字符串,用以描述指定的日期时间。 + * + * @param value 表示待转为字符串表现形式的日期时间的 {@link LocalDateTime}。 + * @return 表示该日期时间的字符串的 {@link String}。 + */ + public static String toString(LocalDateTime value) { + DateTimeFormatter formatter = DateTimeFormatter.ofPattern("yyyy-MM-dd HH:mm:ss.SSS"); + return value.format(formatter); + } + + /** + * 从字符串中解析日期时间。 + * + * @param text 表示包含日期时间信息的字符串的 {@link String}。 + * @return 从字符串中解析到的日期时间的 {@link LocalDateTime}。 + */ + public static LocalDateTime parse(String text) { + for (String pattern : PATTERNS) { + if (text.length() == pattern.length()) { + DateTimeFormatter formatter = DateTimeFormatter.ofPattern(pattern); + return LocalDateTime.parse(text, formatter); + } + } + throw new DateTimeParseException("Invalid datetime format.", text, 0); + } + + /** + * 当本地时间转为 UTC 时间。 + * + * @param value 表示本地时间的 {@link LocalDateTime}。 + * @return 表示 UTC 时间的 {@link LocalDateTime}。 + */ + public static LocalDateTime toUtc(LocalDateTime value) { + ZonedDateTime zoned = value.atZone(ZoneId.systemDefault()); + ZonedDateTime utc = zoned.withZoneSameInstant(ZoneOffset.UTC); + return utc.toLocalDateTime(); + } + + /** + * 将 UTC 时间转为本地时间。 + * + * @param value 表示 UTC 时间的 {@link LocalDateTime}。 + * @return 表示本地时间的 {@link LocalDateTime}。 + */ + public static LocalDateTime fromUtc(LocalDateTime value) { + ZonedDateTime zoned = value.atZone(ZoneOffset.UTC); + ZonedDateTime local = zoned.withZoneSameInstant(ZoneId.systemDefault()); + return local.toLocalDateTime(); + } +} diff --git a/framework/waterflow/java/waterflow-common/src/main/java/modelengine/fit/waterflow/utils/Entities.java b/framework/waterflow/java/waterflow-common/src/main/java/modelengine/fit/waterflow/utils/Entities.java new file mode 100644 index 00000000..661d644c --- /dev/null +++ b/framework/waterflow/java/waterflow-common/src/main/java/modelengine/fit/waterflow/utils/Entities.java @@ -0,0 +1,307 @@ +/*--------------------------------------------------------------------------------------------- + * Copyright (c) 2025 Huawei Technologies Co., Ltd. All rights reserved. + * This file is a part of the ModelEngine Project. + * Licensed under the MIT License. See License.txt in the project root for license information. + *--------------------------------------------------------------------------------------------*/ + +package modelengine.fit.waterflow.utils; + +import modelengine.fitframework.model.RangedResultSet; +import modelengine.fitframework.util.StringUtils; + +import java.sql.Timestamp; +import java.time.LocalDateTime; +import java.util.Collection; +import java.util.Collections; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Objects; +import java.util.Optional; +import java.util.Set; +import java.util.UUID; +import java.util.function.Supplier; +import java.util.stream.Collectors; +import java.util.stream.Stream; + +/** + * 为数据实体提供工具方法。 + * + * @author 梁济时 + * @since 2023-08-08 + */ +public final class Entities { + private static final String EMPTY_ID = "00000000000000000000000000000000"; + + /** + * 隐藏默认构造方法,避免工具类被实例化。 + */ + private Entities() { + } + + /** + * 生成实体的唯一标识。 + * + * @return 表示实体唯一标识的 {@link String}。 + */ + public static String generateId() { + return UUID.randomUUID().toString().replace("-", ""); + } + + /** + * 获取空的唯一标识。 + * + * @return 表示空的唯一标识的 {@link String}。 + */ + public static String emptyId() { + return EMPTY_ID; + } + + /** + * 校验唯一标识。 + * + * @param id 表示待校验的唯一标识的 {@link String}。 + * @param exceptionSupplier 表示当唯一标识的格式不正确时引发的异常的创建方法的 {@link Supplier}。 + * @return 表示符合校验规则的唯一标识的 {@link String}。 + */ + public static String validateId(String id, Supplier exceptionSupplier) { + if (isId(id)) { + return canonicalizeId(id); + } else { + throw exceptionSupplier.get(); + } + } + + /** + * 规范化唯一标识。 + * + * @param id 表示待规范化的唯一标识的 {@link String}。 + * @return 表示规范化后的唯一标识的 {@link String}。 + */ + public static String canonicalizeId(String id) { + return StringUtils.toLowerCase(id); + } + + /** + * 检查指定的字符串是否包含有效格式的唯一标识信息。 + * + * @param value 表示待检查的字符串的 {@link String}。 + * @return 若包含了有效格式的唯一标识,则为 {@code true},否则为 {@code false}。 + */ + public static boolean isId(String value) { + if (value == null || value.length() != 32) { + return false; + } + for (int i = 0; i < value.length(); i++) { + char ch = value.charAt(i); + if (isInvalidId(ch)) { + return false; + } + } + return true; + } + + private static boolean isInvalidId(char ch) { + return (ch < '0' || ch > '9') && (ch < 'a' || ch > 'f') && (ch < 'A' || ch > 'F'); + } + + /** + * 忽略空的唯一标识。 + * + * @param id 表示唯一标识的 {@link String}。 + * @return 当 {@code id} 为 {@link #emptyId()} 时,返回 {@code null},否则返回输入的唯一标识的 {@link String}。 + */ + public static String ignoreEmpty(String id) { + if (emptyId().equals(id)) { + return null; + } else { + return id; + } + } + + /** + * 设置实体的跟踪信息,包括创建人、创建时间、修改人、修改时间。 + * + * @param entity 表示待填充跟踪信息的实体的 {@link Object}。 + * @param row 表示数据行的 {@link Map}{@code <}{@link String}{@code , }{@link Object}{@code >}。 + */ + public static void fillTraceInfo(Object entity, Map row) { + if (entity instanceof CreationTraceable) { + CreationTraceable traceable = (CreationTraceable) entity; + if (row.get("created_by") instanceof String) { + traceable.setCreator((String) row.get("created_by")); + } + if (row.get("created_at") instanceof Timestamp) { + traceable.setCreationTime(Dates.fromUtc(((Timestamp) row.get("created_at")).toLocalDateTime())); + } + } + if (entity instanceof ModificationTraceable) { + ModificationTraceable traceable = (ModificationTraceable) entity; + if (row.get("updated_by") instanceof String) { + traceable.setLastModifier((String) row.get("updated_by")); + } + if (row.get("updated_at") instanceof Timestamp) { + traceable.setLastModificationTime(Dates.fromUtc(((Timestamp) row.get("updated_at")).toLocalDateTime())); + } + } + } + + /** + * 生成一个空的分页结果集。 + * + * @param offset 表示待查询的分页结果集在全量集中的偏移量的 64 位整数。 + * @param limit 表示期望的分页结果集中包含数据记录的最大数量的 32 位整数。 + * @param 表示结果集中元素的类型。 + * @return 表示空的分页结果集的 {@link RangedResultSet}。 + */ + public static RangedResultSet emptyRangedResultSet(long offset, int limit) { + return RangedResultSet.create(Collections.emptyList(), (int) offset, limit, 0); + } + + /** + * 检查指定的唯一标识是否为空。 + * + * @param id 表示待检查的唯一标识的 {@link String}。 + * @return 若唯一标识为空,则为 {@code true},否则为 {@code false}。 + */ + public static boolean isEmpty(String id) { + return StringUtils.isEmpty(id) || StringUtils.equalsIgnoreCase(id, emptyId()); + } + + /** + * 检查两个唯一标识是否匹配。 + * + * @param expectedId 表示所期望的唯一标识的 {@link String}。 + * @param actualId 表示实际的唯一标识的 {@link String}。 + * @return 若唯一标识匹配成功,则为 {@code true},否则为 {@code false}。 + */ + public static boolean match(String expectedId, String actualId) { + return StringUtils.equalsIgnoreCase(ignoreEmpty(expectedId), ignoreEmpty(actualId)); + } + + /** + * Compares two maps for equality, handling null cases and checking both keys and values. + *

    Two maps are considered equal if they are both null, or if they:

    + *
      + *
    1. Have the same size.
    2. + *
    3. Contain the same keys.
    4. + *
    5. Have equal values for each key (using {@link Objects#equals}).
    6. + *
    + * + * @param The type of keys maintained by the maps. + * @param The type of mapped values。 + * @param map1 The first map to compare (may be null). + * @param map2 The second map to compare (may be null). + * @return {@code true} If the maps are equal according to the specified criteria, + * {@code false} otherwise. + */ + public static boolean equals(Map map1, Map map2) { + if (map1 == null) { + return map2 == null; + } else if (map2 == null || map1.size() != map2.size()) { + return false; + } else { + for (Map.Entry entry : map1.entrySet()) { + V value1 = entry.getValue(); + V value2 = map2.get(entry.getKey()); + if (!Objects.equals(value1, value2)) { + return false; + } + } + return true; + } + } + + /** + * Compares two lists for equality, handling null cases and checking elements regardless of order. + *

    Two lists are considered equal if they are both null, or if they:

    + *
      + *
    1. Have the same size.
    2. + *
    3. Contain the same elements (using set comparison).
    4. + *
    + *

    Note: This implementation considers [1,2,2] and [1,1,2] as equal due to set conversion.

    + * + * @param The type of elements in the lists. + * @param list1 The first list to compare (may be null). + * @param list2 The second list to compare (may be null). + * @return {@code true} If the lists contain the same elements regardless of order, + * {@code false} otherwise. + * @apiNote This method performs a set-based comparison, which means it doesn't preserve + * element ordering or duplicate counts. For strict list equality that considers + * order and duplicates, use {@link List#equals}. + */ + public static boolean equals(List list1, List list2) { + if (list1 == null) { + return list2 == null; + } else if (list2 == null || list1.size() != list2.size()) { + return false; + } else { + Set set1 = new HashSet<>(list1); + Set set2 = new HashSet<>(list2); + if (set1.size() != set2.size()) { + return false; + } + set1.removeAll(set2); + return set1.isEmpty(); + } + } + + /** + * Represents an object that can track its creator and creation time. + * Implementing classes should maintain audit information about when and by whom + * the object was initially created. + */ + public interface CreationTraceable { + /** + * Sets the creator identifier for this object. + * + * @param creator The username or identifier of the creator. + */ + void setCreator(String creator); + + /** + * Sets the creation timestamp for this object. + * + * @param creationTime The date and time when the object was created. + */ + void setCreationTime(LocalDateTime creationTime); + } + + /** + * Represents an object that can track its last modifier and modification time. + * Implementing classes should maintain audit information about when and by whom + * the object was last modified. + * + * @since 2023-09-15 + */ + public interface ModificationTraceable { + /** + * Sets the last modifier identifier for this object. + * + * @param lastModifier The username or identifier of the last modifier. + */ + void setLastModifier(String lastModifier); + + /** + * Sets the last modification timestamp for this object. + * + * @param lastModificationTime The date and time when the object was last modified. + */ + void setLastModificationTime(LocalDateTime lastModificationTime); + } + + /** + * 将字符串列表规范化为标准格式 + * + * @param values 字符串列表 + * @return 标准格式列表 + */ + public static List canonicalizeStringList(List values) { + return Optional.ofNullable(values) + .map(Collection::stream) + .orElseGet(Stream::empty) + .map(StringUtils::trim) + .filter(StringUtils::isNotEmpty) + .collect(Collectors.toList()); + } +} diff --git a/framework/waterflow/java/waterflow-core/pom.xml b/framework/waterflow/java/waterflow-core/pom.xml index b63b9c73..da80900d 100644 --- a/framework/waterflow/java/waterflow-core/pom.xml +++ b/framework/waterflow/java/waterflow-core/pom.xml @@ -15,8 +15,7 @@ org.fitframework.waterflow - waterflow-genericable - 3.5.0-SNAPSHOT + waterflow-common org.projectlombok diff --git a/framework/waterflow/java/waterflow-core/src/main/java/modelengine/fit/waterflow/domain/context/FlatMapSourceWindow.java b/framework/waterflow/java/waterflow-core/src/main/java/modelengine/fit/waterflow/domain/context/FlatMapSourceWindow.java index 036df9e7..0632dd29 100644 --- a/framework/waterflow/java/waterflow-core/src/main/java/modelengine/fit/waterflow/domain/context/FlatMapSourceWindow.java +++ b/framework/waterflow/java/waterflow-core/src/main/java/modelengine/fit/waterflow/domain/context/FlatMapSourceWindow.java @@ -65,13 +65,14 @@ public FlatMapSourceWindow(Window window, FlowContextRepo repo) { /** * 根据输入的原始窗口和上下文仓库创建或获取一个 FlatMapSourceWindow 实例。 * - * @param window 原始窗口 - * @param repo 上下文仓库 - * @param 输入类型 + * @param 输入类型。 + * @param flowId 流标识。 + * @param window 原始窗口。 + * @param repo 上下文仓库。 * @return FlatMapSourceWindow 实例 */ - public static FlatMapSourceWindow from(Window window, FlowContextRepo repo) { - return FlowSessionRepo.getFlatMapSource(window, repo); + public static FlatMapSourceWindow from(String flowId, Window window, FlowContextRepo repo) { + return FlowSessionRepo.getFlatMapSource(flowId, window, repo); } /** diff --git a/framework/waterflow/java/waterflow-core/src/main/java/modelengine/fit/waterflow/domain/context/FlatMapWindow.java b/framework/waterflow/java/waterflow-core/src/main/java/modelengine/fit/waterflow/domain/context/FlatMapWindow.java index 97b5667d..cd3ea729 100644 --- a/framework/waterflow/java/waterflow-core/src/main/java/modelengine/fit/waterflow/domain/context/FlatMapWindow.java +++ b/framework/waterflow/java/waterflow-core/src/main/java/modelengine/fit/waterflow/domain/context/FlatMapWindow.java @@ -7,7 +7,7 @@ package modelengine.fit.waterflow.domain.context; import lombok.Getter; -import lombok.Setter; +import modelengine.fit.waterflow.domain.stream.nodes.To; import modelengine.fit.waterflow.domain.stream.reactive.Publisher; import java.util.UUID; @@ -28,7 +28,6 @@ public class FlatMapWindow extends Window { * from是对应的flatmap节点的整个window * 注意三个window的关系:source,from,this */ - @Setter @Getter private Window source; @@ -69,6 +68,11 @@ public boolean isDone() { return this.from.isDone(); } + @Override + public boolean isComplete() { + return this.from.isComplete(); + } + @Override public Integer tokenCount() { return this.from.tokenCount(); @@ -89,9 +93,6 @@ public boolean accept() { */ @Override public void complete() { - if (this.isComplete()) { - return; - } super.complete(); this.from.complete(); } @@ -115,4 +116,19 @@ public Object acc() { public void setAcc(Object acc) { this.from.setAcc(acc); } + + @Override + public void setCompleteHook(To to, FlowContext context) { + this.from.setCompleteHook(to, context); + } + + /** + * Set source window. + * + * @param source The source window of {@link Window}. + */ + public void setSource(Window source) { + this.source = source; + source.onDone(this.id(), this::complete); + } } diff --git a/framework/waterflow/java/waterflow-core/src/main/java/modelengine/fit/waterflow/domain/context/FlowSession.java b/framework/waterflow/java/waterflow-core/src/main/java/modelengine/fit/waterflow/domain/context/FlowSession.java index 0ef8f9ea..2e833c3c 100644 --- a/framework/waterflow/java/waterflow-core/src/main/java/modelengine/fit/waterflow/domain/context/FlowSession.java +++ b/framework/waterflow/java/waterflow-core/src/main/java/modelengine/fit/waterflow/domain/context/FlowSession.java @@ -117,12 +117,48 @@ public FlowSession(String id, boolean preserved) { public FlowSession(FlowSession session) { this(session.getId(), session.preserved); this.copyState(session); - this.keyBy = Optional.ofNullable(session).map(FlowSession::keyBy).orElse(null); + this.keyBy = session.keyBy; this.begin(); this.window.setFrom(session.getWindow()); session.getWindow().addTo(this.window); } + /** + * Creates a new FlowSession based on an existing session and window configuration. + * The new session will inherit the ID, preserved state, keyBy value, and other state + * from the original session, but with the specified window setting. + * + * @param session the original {@link FlowSession} to copy properties from. + * @param window the {@link Window} configuration to apply to the new session. + * @return a new {@link FlowSession} instance with properties copied from the original session. + * and the specified window configuration applied + */ + public static FlowSession from(FlowSession session, Window window) { + FlowSession newSession = new FlowSession(session.getId(), session.preserved); + newSession.copyState(session); + newSession.keyBy = session.keyBy; + newSession.setWindow(window); + return newSession; + } + + /** + * Creates a new root-level FlowSession based on an existing session with preservation control. + * The new session inherits state and key configuration from the original session, + * initializes as a new session, and incorporates the original session's window settings. + * + * @param session the original {@link FlowSession} to copy state from. + * @param preserved {@code boolean} indicates whether the new session should be created as a preserved session. + * @return a new root-level {@link FlowSession} initialized with the specified preservation state. + * and containing copied state from the original session + */ + public static FlowSession newRootSession(FlowSession session, boolean preserved) { + FlowSession newSession = new FlowSession(preserved); + newSession.copyState(session); + newSession.keyBy = session.keyBy; + newSession.begin(); + return newSession; + } + /** * 将本context设置为accumulator */ @@ -160,6 +196,15 @@ public void setWindow(Window window) { } } + /** + * 获取当前实例的数据是否全部处理完成。 + * + * @return 表示前实例的数据是否全部处理完成的 {@code boolean}。 + */ + public boolean isCompleted() { + return this.window.getRootWindow().isAllDone(); + } + /** * 判断两个会话是否相同。 * diff --git a/framework/waterflow/java/waterflow-core/src/main/java/modelengine/fit/waterflow/domain/context/Window.java b/framework/waterflow/java/waterflow-core/src/main/java/modelengine/fit/waterflow/domain/context/Window.java index 4d93260f..a5aa805b 100644 --- a/framework/waterflow/java/waterflow-core/src/main/java/modelengine/fit/waterflow/domain/context/Window.java +++ b/framework/waterflow/java/waterflow-core/src/main/java/modelengine/fit/waterflow/domain/context/Window.java @@ -18,11 +18,13 @@ import java.time.Duration; import java.time.LocalDateTime; import java.util.ArrayList; -import java.util.HashSet; import java.util.List; +import java.util.Map; import java.util.Optional; import java.util.Set; import java.util.UUID; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.CopyOnWriteArraySet; import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicReference; import java.util.stream.Collectors; @@ -38,12 +40,6 @@ * @since 1.0 */ public class Window implements Completable { - private final UUID id; - - private final List tokens = new ArrayList<>(16); - - private final Set tos = new HashSet<>(); - /** * window最后更新时间 */ @@ -60,6 +56,13 @@ public class Window implements Completable { @Getter protected Window from = null; + private final UUID id; + private final List tokens = new ArrayList<>(16); + @Getter + private final Set tos = new CopyOnWriteArraySet<>(); + private final Map onDoneHandlers = new ConcurrentHashMap<>(); + + private Boolean isFinished = false; /** * accumulator for reduce */ @@ -146,6 +149,43 @@ public boolean isComplete() { return this.isComplete.get(); } + /** + * 监听窗口完成事件。 + * + * @param handlerId 表示监听者的唯一标识的 {@link String}。 + * @param handler 表示监听者接收处理的 {@link Runnable}。 + */ + public synchronized void onDone(String handlerId, Runnable handler) { + synchronized (this) { + if (!this.isDone()) { + this.onDoneHandlers.put(handlerId, handler); + return; + } + } + handler.run(); + } + + /** + * 获取顶层窗口。 + * + * @return 表示顶层窗口的 {@link Window}。 + */ + public Window getRootWindow() { + if (this.from == null) { + return this; + } + return this.from.getRootWindow(); + } + + /** + * 获取该窗口以及后续所有窗口是否全部结束。 + * + * @return 表示该窗口以及后续所有窗口是否全部结束的 {@code boolean}。 + */ + public boolean isAllDone() { + return this.isDone() && this.tos.stream().allMatch(Window::isAllDone); + } + /** * 创建window token * @@ -183,17 +223,19 @@ public void generateIndex(FlowContext context, Publisher publisher) { @Override public void complete() { - if (this.isComplete()) { - return; + synchronized (this) { + if (this.isComplete()) { + return; + } + this.isComplete.set(true); } - this.isComplete.set(true); this.fire(); this.tryFinish(); } private void fire() { // only when all elements are consumed(done), fire the possible reduce - if (completeContext != null && session.isAccumulator() && this.isDone()) { + if (completeContext != null && (session.isAccumulator() || this.acc != null) && this.isDone()) { List> cs = new ArrayList<>(); cs.add(completeContext); List contexts = node.getProcessMode().process(node, cs); @@ -275,9 +317,17 @@ public void setCompleteHook(To to, FlowContext context) { * if this session window is closed and all elements have been consumed, then notify listener stream that i'm totally consumed **/ public void tryFinish() { - if (this.isDone()) { - this.completed(); + synchronized (this) { + if (this.isFinished) { + return; + } + if (!this.isDone()) { + return; + } + this.isFinished = true; } + this.completed(); + this.onDoneHandlers.values().forEach(Runnable::run); } /** diff --git a/framework/waterflow/java/waterflow-core/src/main/java/modelengine/fit/waterflow/domain/context/repo/flowcontext/FlowContextRepo.java b/framework/waterflow/java/waterflow-core/src/main/java/modelengine/fit/waterflow/domain/context/repo/flowcontext/FlowContextRepo.java index c7d58772..ac7bcef3 100644 --- a/framework/waterflow/java/waterflow-core/src/main/java/modelengine/fit/waterflow/domain/context/repo/flowcontext/FlowContextRepo.java +++ b/framework/waterflow/java/waterflow-core/src/main/java/modelengine/fit/waterflow/domain/context/repo/flowcontext/FlowContextRepo.java @@ -6,8 +6,8 @@ package modelengine.fit.waterflow.domain.context.repo.flowcontext; -import modelengine.fit.waterflow.common.ErrorCodes; -import modelengine.fit.waterflow.common.exceptions.WaterflowException; +import modelengine.fit.waterflow.ErrorCodes; +import modelengine.fit.waterflow.exceptions.WaterflowException; import modelengine.fit.waterflow.domain.context.FlowContext; import modelengine.fit.waterflow.domain.context.FlowTrace; import modelengine.fit.waterflow.domain.enums.FlowNodeStatus; diff --git a/framework/waterflow/java/waterflow-core/src/main/java/modelengine/fit/waterflow/domain/context/repo/flowlock/FlowLocks.java b/framework/waterflow/java/waterflow-core/src/main/java/modelengine/fit/waterflow/domain/context/repo/flowlock/FlowLocks.java index d00c3190..323a30b0 100644 --- a/framework/waterflow/java/waterflow-core/src/main/java/modelengine/fit/waterflow/domain/context/repo/flowlock/FlowLocks.java +++ b/framework/waterflow/java/waterflow-core/src/main/java/modelengine/fit/waterflow/domain/context/repo/flowlock/FlowLocks.java @@ -9,11 +9,7 @@ import modelengine.fit.waterflow.domain.common.Constants; import modelengine.fitframework.util.StringUtils; -import java.util.Map; -import java.util.Optional; -import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.locks.Lock; -import java.util.concurrent.locks.ReentrantLock; /** * 流程实例的锁接口 @@ -22,10 +18,6 @@ * @since 1.0 */ public interface FlowLocks { - /** - * 本地锁全局静态对象 - */ - Map locks = new ConcurrentHashMap<>(); /** * 节点分布式锁key前缀 @@ -38,9 +30,7 @@ public interface FlowLocks { * @param key 获取本地锁的key值,一般是流程版本的streamID * @return {@link Lock} 锁对象 */ - default Lock getLocalLock(String key) { - return Optional.ofNullable(locks.putIfAbsent(key, new ReentrantLock())).orElseGet(() -> locks.get(key)); - } + Lock getLocalLock(String key); /** * 获取分布式锁 @@ -50,16 +40,6 @@ default Lock getLocalLock(String key) { */ Lock getDistributeLock(String key); - /** - * 删除本地锁 - * TODO xiangyu 删除流程定义的时候需要删除该定义的本地锁资源 - * - * @param key 删除本地锁的key值,一般是流程版本的streamID - */ - default void removeLocalLock(String key) { - locks.remove(key); - } - /** * 获取节点分布式锁key值 * 获取分布式锁的key值,一般是prefix-streamID-nodeID-type diff --git a/framework/waterflow/java/waterflow-core/src/main/java/modelengine/fit/waterflow/domain/context/repo/flowlock/FlowLocksMemo.java b/framework/waterflow/java/waterflow-core/src/main/java/modelengine/fit/waterflow/domain/context/repo/flowlock/FlowLocksMemo.java index 920b374b..26bde109 100644 --- a/framework/waterflow/java/waterflow-core/src/main/java/modelengine/fit/waterflow/domain/context/repo/flowlock/FlowLocksMemo.java +++ b/framework/waterflow/java/waterflow-core/src/main/java/modelengine/fit/waterflow/domain/context/repo/flowlock/FlowLocksMemo.java @@ -6,7 +6,13 @@ package modelengine.fit.waterflow.domain.context.repo.flowlock; +import java.util.Map; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.concurrent.locks.Condition; import java.util.concurrent.locks.Lock; +import java.util.concurrent.locks.ReentrantLock; /** * 流程锁,内存版本的实现 @@ -15,6 +21,18 @@ * @since 1.0 */ public class FlowLocksMemo implements FlowLocks { + private final Map locks = new ConcurrentHashMap<>(); + + @Override + public Lock getLocalLock(String key) { + return locks.compute(key, (__, value) -> { + if (value == null) { + return new MemLockWrapper(key, new ReentrantLock(), this); + } + return value; + }); + } + /** * 获取分布式锁 * 获取分布式锁的key值,一般是prefix-streamID-nodeID-suffixes @@ -27,4 +45,69 @@ public class FlowLocksMemo implements FlowLocks { public Lock getDistributeLock(String key) { return getLocalLock(key); } + + private void tryCleanLocalLock(String key) { + this.locks.compute(key, (__, value) -> { + if (value == null) { + return null; + } + if (value.getRefCount() == 0) { + return null; + } + return value; + }); + } + + private static class MemLockWrapper implements Lock { + private final String key; + private final AtomicInteger refCount = new AtomicInteger(1); + private final ReentrantLock target; + private final FlowLocksMemo locksMemo; + + private MemLockWrapper(String key, ReentrantLock target, FlowLocksMemo locksMemo) { + this.key = key; + this.target = target; + this.locksMemo = locksMemo; + } + + @Override + public void lock() { + this.target.lock(); + } + + @Override + public void lockInterruptibly() throws InterruptedException { + this.target.lockInterruptibly(); + } + + @Override + public boolean tryLock() { + return this.target.tryLock(); + } + + @Override + public boolean tryLock(long time, TimeUnit unit) throws InterruptedException { + return this.target.tryLock(time, unit); + } + + @Override + public void unlock() { + this.target.unlock(); + this.refCount.decrementAndGet(); + this.locksMemo.tryCleanLocalLock(this.key); + } + + @Override + public Condition newCondition() { + return this.target.newCondition(); + } + + private void addRef() { + this.refCount.incrementAndGet(); + } + + private int getRefCount() { + return this.refCount.get(); + } + } } diff --git a/framework/waterflow/java/waterflow-core/src/main/java/modelengine/fit/waterflow/domain/context/repo/flowsession/FlowSessionRepo.java b/framework/waterflow/java/waterflow-core/src/main/java/modelengine/fit/waterflow/domain/context/repo/flowsession/FlowSessionRepo.java index 51b12a8e..be71e73f 100644 --- a/framework/waterflow/java/waterflow-core/src/main/java/modelengine/fit/waterflow/domain/context/repo/flowsession/FlowSessionRepo.java +++ b/framework/waterflow/java/waterflow-core/src/main/java/modelengine/fit/waterflow/domain/context/repo/flowsession/FlowSessionRepo.java @@ -17,48 +17,109 @@ import java.util.concurrent.ConcurrentHashMap; /** - * 流程运行中 session 相关的数据缓存,用于统一管理这些数据和在 session 完成时统一进行释放。 + * Manages data caching during flow execution and handles unified resource release upon session completion. + * This repository centrally stores session-related data and ensures proper cleanup when sessions finish. * * @author 宋永坦 * @since 2025-02-12 */ public class FlowSessionRepo { - private static final Map cache = new ConcurrentHashMap<>(); + /** + * Stores flow session resources for coordinated management and release. + * The outer map key is flow identifier, inner map key is session unique identifier. + */ + private static final Map> cache = new ConcurrentHashMap<>(); + + /** + * Retrieves the next session for data propagation to downstream nodes. + * + * @param flowId The unique identifier of the flow. + * @param session The current session context. + * @return The next session for data transmission. + */ + public static FlowSession getNextToSession(String flowId, FlowSession session) { + Validation.notNull(flowId, "Flow id cannot be null."); + Validation.notNull(session, "Session cannot be null."); + return getFlowSessionCache(flowId, session) + .getNextToSession(session); + } /** - * 获取该 session 的 window 对应的向下一个节点传递数据使用的 session。 + * Retrieves the session for handling emitter operations in the next processing step. * - * @param session session。 - * @return 下一个 session。 + * @param flowId The unique identifier of the flow. + * @param session The current session context. + * @return The session configured for emitter handling. */ - public static FlowSession getNextSession(FlowSession session) { + public static FlowSession getNextEmitterHandleSession(String flowId, FlowSession session) { + Validation.notNull(flowId, "Flow id cannot be null."); Validation.notNull(session, "Session cannot be null."); - return cache.computeIfAbsent(session.getId(), __ -> new FlowSessionCache()).getNextSession(session); + return getFlowSessionCache(flowId, session) + .getNextEmitterHandleSession(session); } /** - * 获取 flatMap 节点生成的 {@link FlatMapSourceWindow}。 + * Gets the next accumulation order number for the specified node. * - * @param window 进入到 flatMap 节点数据对应的window。 - * @param repo 流程数据上下文的持久化对象。 - * @return 对应的 {@link FlatMapSourceWindow}。 + * @param flowId The unique identifier of the flow. + * @param nodeId The target node identifier. + * @param session The current session context. + * @return The next accumulation sequence number. */ - public static FlatMapSourceWindow getFlatMapSource(Window window, FlowContextRepo repo) { + public static int getNextAccOrder(String flowId, String nodeId, FlowSession session) { + Validation.notNull(flowId, "Flow id cannot be null."); + Validation.notNull(nodeId, "Node id cannot be null."); + Validation.notNull(session, "Session cannot be null."); + return getFlowSessionCache(flowId, session).getNextAccOrder(nodeId); + } + + /** + * Retrieves the {@link FlatMapSourceWindow} generated by a flatMap node operation. + * + * @param flowId The unique identifier of the flow. + * @param window The input window entering the flatMap node. + * @param repo The flow context persistence repository. + * @return The corresponding {@link FlatMapSourceWindow} instance. + */ + public static FlatMapSourceWindow getFlatMapSource(String flowId, Window window, FlowContextRepo repo) { + Validation.notNull(flowId, "Flow id cannot be null."); Validation.notNull(window, "Window cannot be null."); Validation.notNull(window.getSession(), "Session cannot be null."); Validation.notNull(repo, "Repo cannot be null."); - return cache.computeIfAbsent(window.getSession().getId(), __ -> new FlowSessionCache()) + return getFlowSessionCache(flowId, window.getSession()) .getFlatMapSourceWindow(window, repo); } /** - * 释放 session 下的所有资源。 + * Releases all resources associated with a specific flow session. * - * @param session 需要释放资源的 session。 + * @param flowId The unique identifier of the flow. + * @param session The target session for resource cleanup. */ - public static void release(FlowSession session) { + public static void release(String flowId, FlowSession session) { + Validation.notNull(flowId, "Flow id cannot be null."); Validation.notNull(session, "Session cannot be null."); - cache.remove(session.getId()); + cache.compute(flowId, (__, value) -> { + if (value == null) { + return null; + } + value.remove(session.getId()); + if (value.isEmpty()) { + return null; + } + return value; + }); + } + + private static FlowSessionCache getFlowSessionCache(String flowId, FlowSession session) { + return cache.compute(flowId, (__, value) -> { + Map sessionCacheMap = value; + if (sessionCacheMap == null) { + sessionCacheMap = new ConcurrentHashMap<>(); + } + sessionCacheMap.computeIfAbsent(session.getId(), id -> new FlowSessionCache()); + return sessionCacheMap; + }).get(session.getId()); } private static class FlowSessionCache { @@ -66,7 +127,9 @@ private static class FlowSessionCache { * 记录每个节点向下个节点流转数据时,下个节点使用的 session,用于将同一批数据汇聚。 * 其中索引为当前节点正在处理数据的窗口的唯一标识。 */ - private final Map nextSessions = new ConcurrentHashMap<>(); + private final Map nextToSessions = new ConcurrentHashMap<>(); + + private final Map nextEmitterHandleSessions = new ConcurrentHashMap<>(); /** * 记录流程中经过 flatMap 节点产生的窗口信息,用于将同一批数据汇聚。 @@ -74,15 +137,15 @@ private static class FlowSessionCache { */ private final Map flatMapSourceWindows = new ConcurrentHashMap<>(); - /** - * 获取该 session 的 window 对应的向下一个节点传递数据使用的 session。 - * - * @param session session。 - * @return 下一个 session。 - */ - private FlowSession getNextSession(FlowSession session) { - return this.nextSessions.computeIfAbsent(session.getWindow().key(), __ -> { - FlowSession next = new FlowSession(session); + private final Map accOrders = new ConcurrentHashMap<>(); + + private FlowSession getNextToSession(FlowSession session) { + return this.nextToSessions.computeIfAbsent(session.getWindow().key(), __ -> generateNextSession(session)); + } + + private FlowSession getNextEmitterHandleSession(FlowSession session) { + return this.nextEmitterHandleSessions.computeIfAbsent(session.getWindow().key(), __ -> { + FlowSession next = FlowSession.newRootSession(session, session.preserved()); Window nextWindow = next.begin(); // if the processor is not reduce, then inherit previous window condition if (!session.isAccumulator()) { @@ -92,13 +155,6 @@ private FlowSession getNextSession(FlowSession session) { }); } - /** - * 获取 flatMap 节点生成的 {@link FlatMapSourceWindow}。 - * - * @param window 进入到 flatMap 节点数据对应的window。 - * @param repo 流程数据上下文的持久化对象。 - * @return 对应的 {@link FlatMapSourceWindow}。 - */ private FlatMapSourceWindow getFlatMapSourceWindow(Window window, FlowContextRepo repo) { return this.flatMapSourceWindows.computeIfAbsent(window.key(), __ -> { FlatMapSourceWindow newWindow = new FlatMapSourceWindow(window, repo); @@ -108,5 +164,24 @@ private FlatMapSourceWindow getFlatMapSourceWindow(Window window, FlowContextRep return newWindow; }); } + + private int getNextAccOrder(String nodeId) { + return this.accOrders.compute(nodeId, (key, value) -> { + if (value == null) { + return 0; + } + return value + 1; + }); + } + + private static FlowSession generateNextSession(FlowSession session) { + FlowSession next = new FlowSession(session); + Window nextWindow = next.begin(); + // if the processor is not reduce, then inherit previous window condition + if (!session.isAccumulator()) { + nextWindow.setCondition(session.getWindow().getCondition()); + } + return next; + } } } diff --git a/framework/waterflow/java/waterflow-core/src/main/java/modelengine/fit/waterflow/domain/emitters/Emitter.java b/framework/waterflow/java/waterflow-core/src/main/java/modelengine/fit/waterflow/domain/emitters/Emitter.java index 02345376..30b9d028 100644 --- a/framework/waterflow/java/waterflow-core/src/main/java/modelengine/fit/waterflow/domain/emitters/Emitter.java +++ b/framework/waterflow/java/waterflow-core/src/main/java/modelengine/fit/waterflow/domain/emitters/Emitter.java @@ -21,6 +21,13 @@ public interface Emitter extends Completable { */ void register(EmitterListener listener); + /** + * 取消监听 + * + * @param listener 监听器 + */ + void unregister(EmitterListener listener); + /** * 发布一个数据,并制定session * diff --git a/framework/waterflow/java/waterflow-core/src/main/java/modelengine/fit/waterflow/domain/emitters/FlowEmitter.java b/framework/waterflow/java/waterflow-core/src/main/java/modelengine/fit/waterflow/domain/emitters/FlowEmitter.java index 9599c00b..8068c13d 100644 --- a/framework/waterflow/java/waterflow-core/src/main/java/modelengine/fit/waterflow/domain/emitters/FlowEmitter.java +++ b/framework/waterflow/java/waterflow-core/src/main/java/modelengine/fit/waterflow/domain/emitters/FlowEmitter.java @@ -11,7 +11,9 @@ import java.util.ArrayList; import java.util.Arrays; +import java.util.LinkedHashSet; import java.util.List; +import java.util.Set; /** * 流程数据发布器 @@ -23,7 +25,7 @@ public class FlowEmitter implements Emitter { /** * Emitter的监听器 */ - protected List> listeners = new ArrayList<>(); + protected Set> listeners = new LinkedHashSet<>(); /** * 关联的 session 信息 @@ -42,6 +44,11 @@ public class FlowEmitter implements Emitter { private final List data = new ArrayList<>(); + /** + * 构造空数据的发射器,具体数据由用户自己投递。 + */ + public FlowEmitter() {} + /** * 构造单个数据的Emitter * @@ -90,23 +97,29 @@ public static FlowEmitter flux(I... data) { * @return 新的发射器 */ public static FlowEmitter from(Emitter emitter) { - FlowEmitter cachedEmitter = new FlowEmitter<>(); - EmitterListener emitterListener = (data, token) -> { - cachedEmitter.emit(data, token); - }; - emitter.register(emitterListener); + FlowEmitter cachedEmitter = new AutoCompleteEmitter<>(); + emitter.register(cachedEmitter::emit); return cachedEmitter; } @Override public synchronized void register(EmitterListener listener) { + if (listener == null) { + return; + } this.listeners.add(listener); - if (this.isStart) { this.fire(); } } + @Override + public synchronized void unregister(EmitterListener listener) { + if (listener != null) { + this.listeners.remove(listener); + } + } + @Override public synchronized void emit(D data, FlowSession trans) { if (!this.isStart) { @@ -118,6 +131,9 @@ public synchronized void emit(D data, FlowSession trans) { @Override public synchronized void start(FlowSession session) { + if (this.isStart) { + return; + } if (session != null) { session.begin(); } @@ -183,4 +199,33 @@ protected void tryCompleteWindow() { this.flowSession.getWindow().complete(); } } + + /** + * An emitter implementation that automatically completes based on emission conditions. + * This emitter subclass handles automatic completion logic when certain emission + * criteria are met, reducing the need for manual completion management. + * + * @param the type of data processed by this emitter. + */ + public static class AutoCompleteEmitter extends FlowEmitter { + @Override + public synchronized void start(FlowSession session) { + if (session != null) { + session.begin(); + } + this.setFlowSession(session); + this.setStarted(); + this.fire(); + } + + @Override + public synchronized void emit(D data, FlowSession session) { + session.getWindow().onDone(getOnDoneHandlerId(session), this::complete); + this.listeners.forEach(listener -> listener.handle(data, this.flowSession)); + } + + private static String getOnDoneHandlerId(FlowSession session) { + return "AutoCompleteEmitter" + session.getWindow().id(); + } + } } diff --git a/framework/waterflow/java/waterflow-core/src/main/java/modelengine/fit/waterflow/domain/enums/FlowDefinitionStatus.java b/framework/waterflow/java/waterflow-core/src/main/java/modelengine/fit/waterflow/domain/enums/FlowDefinitionStatus.java index 2179f836..98bc1eb6 100644 --- a/framework/waterflow/java/waterflow-core/src/main/java/modelengine/fit/waterflow/domain/enums/FlowDefinitionStatus.java +++ b/framework/waterflow/java/waterflow-core/src/main/java/modelengine/fit/waterflow/domain/enums/FlowDefinitionStatus.java @@ -6,10 +6,10 @@ package modelengine.fit.waterflow.domain.enums; -import static modelengine.fit.waterflow.common.ErrorCodes.ENUM_CONVERT_FAILED; +import static modelengine.fit.waterflow.ErrorCodes.ENUM_CONVERT_FAILED; import lombok.Getter; -import modelengine.fit.waterflow.common.exceptions.WaterflowParamException; +import modelengine.fit.waterflow.exceptions.WaterflowParamException; import java.util.Arrays; diff --git a/framework/waterflow/java/waterflow-core/src/main/java/modelengine/fit/waterflow/domain/enums/FlowNodeType.java b/framework/waterflow/java/waterflow-core/src/main/java/modelengine/fit/waterflow/domain/enums/FlowNodeType.java index 94321e13..c883777a 100644 --- a/framework/waterflow/java/waterflow-core/src/main/java/modelengine/fit/waterflow/domain/enums/FlowNodeType.java +++ b/framework/waterflow/java/waterflow-core/src/main/java/modelengine/fit/waterflow/domain/enums/FlowNodeType.java @@ -7,10 +7,10 @@ package modelengine.fit.waterflow.domain.enums; import static java.util.Locale.ROOT; -import static modelengine.fit.waterflow.common.ErrorCodes.ENUM_CONVERT_FAILED; +import static modelengine.fit.waterflow.ErrorCodes.ENUM_CONVERT_FAILED; import lombok.Getter; -import modelengine.fit.waterflow.common.exceptions.WaterflowParamException; +import modelengine.fit.waterflow.exceptions.WaterflowParamException; import java.util.Arrays; diff --git a/framework/waterflow/java/waterflow-core/src/main/java/modelengine/fit/waterflow/domain/enums/FlowTraceStatus.java b/framework/waterflow/java/waterflow-core/src/main/java/modelengine/fit/waterflow/domain/enums/FlowTraceStatus.java index 414a7bb5..22d9d921 100644 --- a/framework/waterflow/java/waterflow-core/src/main/java/modelengine/fit/waterflow/domain/enums/FlowTraceStatus.java +++ b/framework/waterflow/java/waterflow-core/src/main/java/modelengine/fit/waterflow/domain/enums/FlowTraceStatus.java @@ -6,9 +6,9 @@ package modelengine.fit.waterflow.domain.enums; -import static modelengine.fit.waterflow.common.ErrorCodes.ENUM_CONVERT_FAILED; +import static modelengine.fit.waterflow.ErrorCodes.ENUM_CONVERT_FAILED; -import modelengine.fit.waterflow.common.exceptions.WaterflowParamException; +import modelengine.fit.waterflow.exceptions.WaterflowParamException; import java.util.Arrays; diff --git a/framework/waterflow/java/waterflow-core/src/main/java/modelengine/fit/waterflow/domain/enums/ParallelMode.java b/framework/waterflow/java/waterflow-core/src/main/java/modelengine/fit/waterflow/domain/enums/ParallelMode.java index 75475e95..bc1b6ce6 100644 --- a/framework/waterflow/java/waterflow-core/src/main/java/modelengine/fit/waterflow/domain/enums/ParallelMode.java +++ b/framework/waterflow/java/waterflow-core/src/main/java/modelengine/fit/waterflow/domain/enums/ParallelMode.java @@ -6,10 +6,10 @@ package modelengine.fit.waterflow.domain.enums; -import static modelengine.fit.waterflow.common.ErrorCodes.ENUM_CONVERT_FAILED; +import static modelengine.fit.waterflow.ErrorCodes.ENUM_CONVERT_FAILED; import lombok.Getter; -import modelengine.fit.waterflow.common.exceptions.WaterflowParamException; +import modelengine.fit.waterflow.exceptions.WaterflowParamException; import java.util.Arrays; diff --git a/framework/waterflow/java/waterflow-core/src/main/java/modelengine/fit/waterflow/domain/flow/ProcessFlow.java b/framework/waterflow/java/waterflow-core/src/main/java/modelengine/fit/waterflow/domain/flow/ProcessFlow.java index 16514cd8..2829fce2 100644 --- a/framework/waterflow/java/waterflow-core/src/main/java/modelengine/fit/waterflow/domain/flow/ProcessFlow.java +++ b/framework/waterflow/java/waterflow-core/src/main/java/modelengine/fit/waterflow/domain/flow/ProcessFlow.java @@ -34,8 +34,8 @@ public ProcessFlow(FlowContextRepo repo, FlowContextMessenger messenger, FlowLoc } @Override - public void handle(D data, FlowSession token) { - this.offer(data, token == null ? new FlowSession() : token); + public void handle(D data, FlowSession session) { + this.offer(data, session == null ? new FlowSession() : session); } @Override @@ -43,6 +43,13 @@ public void register(EmitterListener handler) { this.end.register(handler); } + @Override + public void unregister(EmitterListener handler) { + if (handler != null) { + this.end.unregister(handler); + } + } + @Override public void emit(Object data, FlowSession token) { this.end.emit(data, token); diff --git a/framework/waterflow/java/waterflow-core/src/main/java/modelengine/fit/waterflow/domain/states/Start.java b/framework/waterflow/java/waterflow-core/src/main/java/modelengine/fit/waterflow/domain/states/Start.java index cfb6c1f0..bcfd22b4 100644 --- a/framework/waterflow/java/waterflow-core/src/main/java/modelengine/fit/waterflow/domain/states/Start.java +++ b/framework/waterflow/java/waterflow-core/src/main/java/modelengine/fit/waterflow/domain/states/Start.java @@ -10,6 +10,7 @@ import modelengine.fit.waterflow.domain.context.FlowContext; import modelengine.fit.waterflow.domain.context.FlowSession; import modelengine.fit.waterflow.domain.context.Window; +import modelengine.fit.waterflow.domain.context.repo.flowsession.FlowSessionRepo; import modelengine.fit.waterflow.domain.emitters.EmitterListener; import modelengine.fit.waterflow.domain.enums.ParallelMode; import modelengine.fit.waterflow.domain.flow.Flow; @@ -173,7 +174,12 @@ public State map(Operators.Map processor) { public State process(Operators.Process processor) { AtomicReference> wrapper = new AtomicReference<>(); State state = new State<>(this.publisher().map(input -> { - processor.process(input.getData(), input, data -> wrapper.get().from.offer(data, input.getSession())); + FlowSession nextSession = + FlowSessionRepo.getNextToSession(this.publisher().getStreamId(), input.getSession()); + processor.process(input.getData(), nextSession, data -> wrapper.get().from.offer(data, nextSession)); + if (input.getSession().getWindow().isOngoing()) { + nextSession.getWindow().complete(); + } return null; }, null), this.getFlow()); wrapper.set(state); diff --git a/framework/waterflow/java/waterflow-core/src/main/java/modelengine/fit/waterflow/domain/states/State.java b/framework/waterflow/java/waterflow-core/src/main/java/modelengine/fit/waterflow/domain/states/State.java index ad70eac4..6aac4b54 100644 --- a/framework/waterflow/java/waterflow-core/src/main/java/modelengine/fit/waterflow/domain/states/State.java +++ b/framework/waterflow/java/waterflow-core/src/main/java/modelengine/fit/waterflow/domain/states/State.java @@ -76,10 +76,16 @@ public void handle(O data, FlowSession trans) { * * @param handler 表示监听器的 {@link EmitterListener}{@code <}{@link O}{@code ,}{@link FlowSession}{@code >}。 */ + @Override public void register(EmitterListener handler) { this.processor.register(handler); } + @Override + public void unregister(EmitterListener listener) { + this.processor.unregister(listener); + } + @Override public void emit(O data, FlowSession token) { this.processor.emit(data, token); @@ -204,13 +210,13 @@ public F close(Operators.Just>> callback, Operators.Just nodes.stream().filter(node -> !node.subscribed()).forEach(node -> node.subscribe(getFlow().end())); getFlow().end().onComplete((Operators.Just>>) input -> { FlowDebug.log(input.get().getSession(), - "[close] " + this.getFlow().end().getId() + ":" + "end. data:" + input.get().getData()); + "[close] " + this.getFlow().end().getStreamId() + ":" + "end. data:" + input.get().getData()); callback.process(input); input.get().getWindow().peekAndConsume().finishConsume(); - if (input.get().getWindow().isDone()) { - FlowSessionRepo.release(input.get().getSession()); + input.get().getWindow().onDone(this.getFlow().end().getId(), () -> { + FlowSessionRepo.release(this.processor.getStreamId(), input.get().getSession()); this.getFlow().completeSession(input.get().getSession().getId()); - } + }); }); if (sessionComplete != null) { getFlow().end().onSessionComplete(session -> { @@ -247,7 +253,9 @@ public F close(BiConsumer sessionConsumer, Consumer private Operators.ErrorHandler buildGlobalHandler(Operators.ErrorHandler errHandler, FlowContextRepo repo) { return (exception, retryable, contexts) -> { - contexts.stream().findFirst().ifPresent(context -> FlowSessionRepo.release(context.getSession())); + contexts.stream() + .findFirst() + .ifPresent(context -> FlowSessionRepo.release(this.processor.getStreamId(), context.getSession())); contexts.forEach(context -> context.setStatus(FlowNodeStatus.ERROR)); repo.save(contexts); if (errHandler != null) { diff --git a/framework/waterflow/java/waterflow-core/src/main/java/modelengine/fit/waterflow/domain/stream/nodes/ConditionsNode.java b/framework/waterflow/java/waterflow-core/src/main/java/modelengine/fit/waterflow/domain/stream/nodes/ConditionsNode.java index 06fd9ecc..d080cf19 100644 --- a/framework/waterflow/java/waterflow-core/src/main/java/modelengine/fit/waterflow/domain/stream/nodes/ConditionsNode.java +++ b/framework/waterflow/java/waterflow-core/src/main/java/modelengine/fit/waterflow/domain/stream/nodes/ConditionsNode.java @@ -11,15 +11,9 @@ import modelengine.fit.waterflow.domain.context.repo.flowcontext.FlowContextRepo; import modelengine.fit.waterflow.domain.context.repo.flowlock.FlowLocks; import modelengine.fit.waterflow.domain.enums.FlowNodeType; -import modelengine.fit.waterflow.domain.stream.reactive.Subscription; -import modelengine.fit.waterflow.domain.utils.IdGenerator; import modelengine.fit.waterflow.domain.utils.UUIDUtil; -import java.util.LinkedHashMap; import java.util.List; -import java.util.Map; -import java.util.Optional; -import java.util.concurrent.ConcurrentHashMap; import java.util.stream.Collectors; /** @@ -77,44 +71,20 @@ private static From initFrom(String streamId, FlowContextRepo repo, FlowC } private static class ConditionFrom extends From { - private final Map>> sessionSubscription = new ConcurrentHashMap<>(); - public ConditionFrom(String streamId, FlowContextRepo repo, FlowContextMessenger messenger, FlowLocks locks) { super(streamId, repo, messenger, locks); } @Override public void offer(List> contexts) { - this.offerUserContexts(contexts); - } - - private void offerUserContexts(List> contexts) { this.getSubscriptions().forEach(subscription -> { List> matched = contexts.stream() .filter(context -> subscription.getWhether().is(context.getData())) - .peek(context -> { - this.record(subscription, context); - }) .collect(Collectors.toList()); matched.forEach(contexts::remove); subscription.cache(matched); }); } - private void record(Subscription subscription, FlowContext context) { - String sessionId = getSessionId(context); - if (sessionId == null) { - return; - } - this.sessionSubscription.putIfAbsent(sessionId, new LinkedHashMap<>()); - Map> subscriptionMap = this.sessionSubscription.get(sessionId); - if (!subscriptionMap.containsKey(subscription.getId())) { - subscriptionMap.put(subscription.getId(), subscription); - } - } - - private String getSessionId(FlowContext context) { - return Optional.ofNullable(context.getSession()).map(IdGenerator::getId).orElse(null); - } } } diff --git a/framework/waterflow/java/waterflow-core/src/main/java/modelengine/fit/waterflow/domain/stream/nodes/From.java b/framework/waterflow/java/waterflow-core/src/main/java/modelengine/fit/waterflow/domain/stream/nodes/From.java index 663291ad..13416a39 100644 --- a/framework/waterflow/java/waterflow-core/src/main/java/modelengine/fit/waterflow/domain/stream/nodes/From.java +++ b/framework/waterflow/java/waterflow-core/src/main/java/modelengine/fit/waterflow/domain/stream/nodes/From.java @@ -6,9 +6,10 @@ package modelengine.fit.waterflow.domain.stream.nodes; -import static modelengine.fit.waterflow.common.ErrorCodes.FLOW_ENGINE_INVALID_MANUAL_TASK; +import static modelengine.fit.waterflow.ErrorCodes.FLOW_ENGINE_INVALID_MANUAL_TASK; -import modelengine.fit.waterflow.common.exceptions.WaterflowException; +import modelengine.fit.waterflow.domain.context.repo.flowsession.FlowSessionRepo; +import modelengine.fit.waterflow.exceptions.WaterflowException; import modelengine.fit.waterflow.domain.context.FlatMapSourceWindow; import modelengine.fit.waterflow.domain.context.FlatMapWindow; import modelengine.fit.waterflow.domain.context.FlowContext; @@ -178,16 +179,15 @@ public Processor flatMap(Operators.FlatMap, O> processo Validation.notNull(processor, "Flat map processor can not be null."); AtomicReference> processRef = new AtomicReference<>(); Operators.Map, O> wrapper = input -> { - FlatMapSourceWindow fWindow = FlatMapSourceWindow.from(input.getWindow(), this.repo); + FlatMapSourceWindow fWindow = FlatMapSourceWindow.from(this.streamId, input.getWindow(), this.repo); - final FlowSession session = new FlowSession(input.getSession()); FlatMapWindow flatMapWindow = new FlatMapWindow(fWindow); - session.setWindow(flatMapWindow); + final FlowSession session = FlowSession.from(input.getSession(), flatMapWindow); session.begin(); DataStart start = processor.process(input); - FlowSession startSession = new FlowSession(); + FlowSession startSession = new FlowSession(input.getSession().preserved()); flatMapWindow.setSource(startSession.begin()); startSession.onError(exception -> { processRef.get().fail(exception, Collections.singletonList(input)); @@ -207,7 +207,11 @@ public Processor flatMap(Operators.FlatMap, O> processo public Processor process(Operators.Process, O> processor, Operators.Whether whether) { AtomicReference> processRef = new AtomicReference<>(); Operators.Map, O> wrapper = input -> { - processor.process(input, input, data -> processRef.get().offer(data, input.getSession())); + FlowSession nextSession = FlowSessionRepo.getNextToSession(this.streamId, input.getSession()); + processor.process(input, input, data -> processRef.get().offer(data, nextSession)); + if (input.getSession().getWindow().isOngoing()) { + nextSession.getWindow().complete(); + } return null; }; Node node = new Node<>(this.getStreamId(), wrapper, repo, messenger, locks); diff --git a/framework/waterflow/java/waterflow-core/src/main/java/modelengine/fit/waterflow/domain/stream/nodes/To.java b/framework/waterflow/java/waterflow-core/src/main/java/modelengine/fit/waterflow/domain/stream/nodes/To.java index 7a0925b5..e03d75a1 100644 --- a/framework/waterflow/java/waterflow-core/src/main/java/modelengine/fit/waterflow/domain/stream/nodes/To.java +++ b/framework/waterflow/java/waterflow-core/src/main/java/modelengine/fit/waterflow/domain/stream/nodes/To.java @@ -6,11 +6,11 @@ package modelengine.fit.waterflow.domain.stream.nodes; -import static modelengine.fit.waterflow.common.ErrorCodes.FLOW_NODE_CREATE_ERROR; -import static modelengine.fit.waterflow.common.ErrorCodes.FLOW_NODE_MAX_TASK; +import static modelengine.fit.waterflow.ErrorCodes.FLOW_NODE_CREATE_ERROR; +import static modelengine.fit.waterflow.ErrorCodes.FLOW_NODE_MAX_TASK; import lombok.Getter; -import modelengine.fit.waterflow.common.exceptions.WaterflowException; +import modelengine.fit.waterflow.exceptions.WaterflowException; import modelengine.fit.waterflow.domain.common.Constants; import modelengine.fit.waterflow.domain.context.FlowContext; import modelengine.fit.waterflow.domain.context.FlowSession; @@ -30,12 +30,12 @@ import modelengine.fit.waterflow.domain.stream.reactive.Callback; import modelengine.fit.waterflow.domain.stream.reactive.Subscriber; import modelengine.fit.waterflow.domain.stream.reactive.Subscription; -import modelengine.fit.waterflow.domain.utils.FlowDebug; import modelengine.fit.waterflow.domain.utils.FlowExecutors; import modelengine.fit.waterflow.domain.utils.IdGenerator; import modelengine.fit.waterflow.domain.utils.Identity; import modelengine.fit.waterflow.domain.utils.SleepUtil; import modelengine.fit.waterflow.domain.utils.UUIDUtil; +import modelengine.fitframework.inspection.Validation; import modelengine.fitframework.log.Logger; import modelengine.fitframework.schedule.Task; import modelengine.fitframework.util.CollectionUtils; @@ -138,7 +138,7 @@ public class To extends IdGenerator implements Subscriber { @Getter private ProcessMode processMode; - private Map processingSessions = new ConcurrentHashMap<>();//todo:夏斐,确定合适清除,否则有内存泄露风险 + private Map processingSessions = new ConcurrentHashMap<>(); private Operators.Validator validator = (repo, to) -> repo.requestMappingContext(to.streamId, to.froms.stream().map(Identity::getId).collect(Collectors.toList()), to.processingSessions); @@ -192,7 +192,7 @@ public class To extends IdGenerator implements Subscriber { private Thread preProcessT = null; - private final Set listeners = new HashSet<>(); + private final Map> listeners = new ConcurrentHashMap<>(); private final Map nextSessions = new ConcurrentHashMap<>(); @@ -487,6 +487,11 @@ public void onProcess(ProcessType type, List> preList, boolean is return; } List> afterList = this.getProcessMode().process(this, preList); + preList.forEach(context -> { + context.getWindow() + .onDone(getCleanProcessingSessionHandlerId(context), + () -> this.processingSessions.remove(context.getSession().getId())); + }); this.afterProcess(preList, afterList); if (CollectionUtils.isNotEmpty(afterList)) { // 查找一个transaction里的所有数据的都完成了,运行callback给stream外反馈数据 @@ -495,6 +500,12 @@ public void onProcess(ProcessType type, List> preList, boolean is } // 处理好数据后对外送数据,驱动其他flow响应 afterList.forEach(context -> this.emit(context.getData(), context.getSession())); + // keep order + preList.forEach(context -> { + if (context.getIndex() > Constants.NOT_PRESERVED_INDEX && !context.getWindow().isDone()) { + this.processingSessions.put(context.getSession().getId(), context.getIndex() + 1); + } + }); } catch (Exception ex) { LOG.error("Node process exception stream-id: {}, node-id: {}, position-id: {}, traceId: {}. caused by: {}", this.streamId, this.id, preList.get(0).getPosition(), preList.get(0).getTraceId(), @@ -521,8 +532,12 @@ protected void fail(Exception exception, List> preList) { Optional.ofNullable(this.errorHandler).ifPresent(handler -> handler.handle(exception, retryable, preList)); Optional.ofNullable(this.globalErrorHandler) .ifPresent(handler -> handler.handle(exception, retryable, preList)); + preList.forEach(context -> this.processingSessions.remove(context.getSession().getId())); } + private static String getCleanProcessingSessionHandlerId(FlowContext ctx) { + return "ProcessingSession" + ctx.getSession().getId(); + } private List> filterTerminate(List> contexts) { if (CollectionUtils.isEmpty(contexts)) { @@ -554,20 +569,11 @@ public void onNext(String batchId) { private void feedback(List> contexts) { this.callback.process(new ToCallback<>(contexts)); - if (this.sessionCompleteCallback != null) { contexts.forEach(context -> { - FlowDebug.log(String.format("[feedback] nodeId=%s isComplete=%s, sessionId=%s, windowId=%s, data=%s" - + ", tokens=%s", - this.getId() + this.getNodeType(), - context.getSession().getWindow().isComplete(), - context.getSession().getId(), context.getSession().getWindow().id(), - context.getData().toString(), - context.getSession().getWindow().debugTokens() - )); - if (context.getSession().getWindow().isComplete()) { + context.getSession().getWindow().onDone(context.getSession().getWindow().id(), () -> { this.sessionCompleteCallback.process(context.getSession()); - } + }); }); } } @@ -694,17 +700,28 @@ public String getStreamId() { } @Override - public void register(EmitterListener handler) { - this.listeners.add(handler); + public void register(EmitterListener listener) { + Validation.notNull(listener, "The emitter listener should not be null."); + this.listeners.put(listener, listener); } @Override - public void emit(O data, FlowSession trans) { - this.listeners.forEach(listener -> listener.handle(data, trans)); + public void unregister(EmitterListener listener) { + Validation.notNull(listener, "The emitter listener should not be null."); + this.listeners.remove(listener); + } + + @Override + public void emit(O data, FlowSession session) { + this.listeners.values().forEach(listener -> listener.handle(data, session)); } private FlowSession getNextSession(FlowSession session) { - return FlowSessionRepo.getNextSession(session); + return FlowSessionRepo.getNextToSession(this.streamId, session); + } + + private int getNextAccOrder(FlowSession session) { + return FlowSessionRepo.getNextAccOrder(this.streamId, this.id, session); } /** @@ -741,36 +758,29 @@ public List> process(To to, List clonedContext = context.generate(data, to.getId()); clonedContext.setSession(nextSession); if (context.getSession().isAccumulator()) { - Integer index = to.counter.get(context.getSession().getId()); - if (index == null) { - index = 0; - } else { - index++; + if (clonedContext.getIndex() > Constants.NOT_PRESERVED_INDEX) { + clonedContext.setIndex(0); } - to.counter.put(context.getSession().getId(), index); - clonedContext.setIndex(index); } //accept the consumed token, and create a new token for the handled data, meanwhile,consume the peeked nextSession.getWindow().acceptToken(peekedToken); cs.add(clonedContext); + //if previous stream complete, complete this stream + if (context.getSession().getWindow().isDone()) { + nextSession.getWindow().complete(); + } } else { - peekedToken.finishConsume();//consume the peeked - } - //keep order - if (context.getIndex() > Constants.NOT_PRESERVED_INDEX) { - to.processingSessions.put(context.getSession().getId(), context.getIndex() + 1); - } - - //if previous stream complete, complete this stream - if (context.getSession().getWindow().isDone()) { - nextSession.getWindow().complete(); + peekedToken.finishConsume(); + if (window.isDone()) { + window.tryFinish(); + } } } return cs; diff --git a/framework/waterflow/java/waterflow-core/src/main/java/modelengine/fit/waterflow/domain/stream/reactive/Publisher.java b/framework/waterflow/java/waterflow-core/src/main/java/modelengine/fit/waterflow/domain/stream/reactive/Publisher.java index 439ee382..3dfcf621 100644 --- a/framework/waterflow/java/waterflow-core/src/main/java/modelengine/fit/waterflow/domain/stream/reactive/Publisher.java +++ b/framework/waterflow/java/waterflow-core/src/main/java/modelengine/fit/waterflow/domain/stream/reactive/Publisher.java @@ -9,6 +9,7 @@ import modelengine.fit.waterflow.domain.context.FlowContext; import modelengine.fit.waterflow.domain.context.FlowSession; import modelengine.fit.waterflow.domain.context.repo.flowcontext.FlowContextRepo; +import modelengine.fit.waterflow.domain.context.repo.flowsession.FlowSessionRepo; import modelengine.fit.waterflow.domain.emitters.EmitterListener; import modelengine.fit.waterflow.domain.enums.ParallelMode; import modelengine.fit.waterflow.domain.stream.operators.Operators; @@ -37,7 +38,9 @@ default void handle(I data) { @Override default void handle(I data, FlowSession flowSession) { - this.offer(data, new FlowSession(flowSession)); + FlowSession nextSession = FlowSessionRepo.getNextEmitterHandleSession(this.getStreamId(), flowSession); + this.offer(data, nextSession); + flowSession.getWindow().onDone(this.getStreamId(), () -> nextSession.getWindow().complete()); } /** diff --git a/framework/waterflow/java/waterflow-core/src/main/java/modelengine/fit/waterflow/domain/utils/FlowDebug.java b/framework/waterflow/java/waterflow-core/src/main/java/modelengine/fit/waterflow/domain/utils/FlowDebug.java index 60d4555c..828d623e 100644 --- a/framework/waterflow/java/waterflow-core/src/main/java/modelengine/fit/waterflow/domain/utils/FlowDebug.java +++ b/framework/waterflow/java/waterflow-core/src/main/java/modelengine/fit/waterflow/domain/utils/FlowDebug.java @@ -42,8 +42,11 @@ public static void log(FlowSession session, String msg) { if (!isOpen) { return; } - LOG.debug("Thread:{0}. tokenCount:{1}, getTosSize={2}, isComplete={3}. msg={4}", Thread.currentThread().getId(), - session.getWindow().tokenCount(), session.getWindow().getTosSize(), session.getWindow().isComplete(), + LOG.debug("Thread:{0}. tokenCount:{1}, getTosSize={2}, isComplete={3}. msg={4}", + Thread.currentThread().getId(), + session.getWindow().tokenCount(), + session.getWindow().getTosSize(), + session.getWindow().isComplete(), msg); } } diff --git a/framework/waterflow/java/waterflow-core/src/test/java/modelengine/fit/waterflow/domain/WaterFlowsTest.java b/framework/waterflow/java/waterflow-core/src/test/java/modelengine/fit/waterflow/domain/WaterFlowsTest.java index 84a73ef0..1ee27b92 100644 --- a/framework/waterflow/java/waterflow-core/src/test/java/modelengine/fit/waterflow/domain/WaterFlowsTest.java +++ b/framework/waterflow/java/waterflow-core/src/test/java/modelengine/fit/waterflow/domain/WaterFlowsTest.java @@ -750,6 +750,11 @@ public void register(EmitterListener handler) { this.handler = handler; } + @Override + public void unregister(EmitterListener listener) { + this.handler = null; + } + @Override public void emit(Integer data, FlowSession trance) { this.handler.handle(data, trance); diff --git a/framework/waterflow/java/waterflow-dependency/pom.xml b/framework/waterflow/java/waterflow-dependency/pom.xml index 10ebe7e3..23a240df 100644 --- a/framework/waterflow/java/waterflow-dependency/pom.xml +++ b/framework/waterflow/java/waterflow-dependency/pom.xml @@ -69,11 +69,6 @@ fit-api ${fit.version} - - org.fitframework - fit-data-repository-service - ${fit.version} - org.fitframework.service fit-http-classic @@ -109,6 +104,11 @@ fit-util ${fit.version} + + org.fitframework.service + fit-security + ${fit.version} + @@ -123,7 +123,7 @@ org.fitframework.waterflow - waterflow-genericable + waterflow-common ${waterflow.version} @@ -145,6 +145,11 @@ fastjson ${fastjson.version} + + org.projectlombok + lombok + ${lombok.version} + com.fasterxml.jackson.core diff --git a/framework/waterflow/java/waterflow-genericable/pom.xml b/framework/waterflow/java/waterflow-eco/pom.xml similarity index 66% rename from framework/waterflow/java/waterflow-genericable/pom.xml rename to framework/waterflow/java/waterflow-eco/pom.xml index a0776268..c9bc340f 100644 --- a/framework/waterflow/java/waterflow-genericable/pom.xml +++ b/framework/waterflow/java/waterflow-eco/pom.xml @@ -2,19 +2,16 @@ 4.0.0 - org.fitframework.waterflow waterflow-parent 3.5.0-SNAPSHOT - waterflow-genericable + waterflow-eco + pom - - - org.fitframework - fit-api - - - \ No newline at end of file + + waterflow-bridge-fit-reactor + + diff --git a/framework/waterflow/java/waterflow-bridge-fit-reactor/pom.xml b/framework/waterflow/java/waterflow-eco/waterflow-bridge-fit-reactor/pom.xml similarity index 88% rename from framework/waterflow/java/waterflow-bridge-fit-reactor/pom.xml rename to framework/waterflow/java/waterflow-eco/waterflow-bridge-fit-reactor/pom.xml index 92aae4a9..5942fdc2 100644 --- a/framework/waterflow/java/waterflow-bridge-fit-reactor/pom.xml +++ b/framework/waterflow/java/waterflow-eco/waterflow-bridge-fit-reactor/pom.xml @@ -5,7 +5,7 @@ org.fitframework.waterflow - waterflow-parent + waterflow-eco 3.5.0-SNAPSHOT @@ -20,10 +20,6 @@ org.fitframework fit-reactor - - org.projectlombok - lombok - org.junit.jupiter junit-jupiter diff --git a/framework/waterflow/java/waterflow-bridge-fit-reactor/src/main/java/modelengine/fit/waterflow/bridge/fitflow/FitBoundedEmitter.java b/framework/waterflow/java/waterflow-eco/waterflow-bridge-fit-reactor/src/main/java/modelengine/fit/waterflow/bridge/fitflow/FitBoundedEmitter.java similarity index 96% rename from framework/waterflow/java/waterflow-bridge-fit-reactor/src/main/java/modelengine/fit/waterflow/bridge/fitflow/FitBoundedEmitter.java rename to framework/waterflow/java/waterflow-eco/waterflow-bridge-fit-reactor/src/main/java/modelengine/fit/waterflow/bridge/fitflow/FitBoundedEmitter.java index 8cf82770..852aa270 100644 --- a/framework/waterflow/java/waterflow-bridge-fit-reactor/src/main/java/modelengine/fit/waterflow/bridge/fitflow/FitBoundedEmitter.java +++ b/framework/waterflow/java/waterflow-eco/waterflow-bridge-fit-reactor/src/main/java/modelengine/fit/waterflow/bridge/fitflow/FitBoundedEmitter.java @@ -24,8 +24,8 @@ public abstract class FitBoundedEmitter extends FlowEmitter { private final Function dataConverter; private boolean isError = false; - private Exception exception; + private Publisher publisher; /** * 通过数据发布者和有限流数据构造器初始化 {@link FitBoundedEmitter}{@code <}{@link O}{@code , }{@link D}{@code >}。 @@ -35,7 +35,7 @@ public abstract class FitBoundedEmitter extends FlowEmitter { */ public FitBoundedEmitter(Publisher publisher, Function dataConverter) { this.dataConverter = dataConverter; - publisher.subscribe(new FitBoundedEmitter.EmitterSubscriber<>(this)); + this.publisher = publisher; } @Override @@ -52,6 +52,7 @@ public synchronized void start(FlowSession session) { // 启动时先发射缓存的数据,此时可能先缓存了数据,所以开始时发射完数据就可能结束了。 this.fire(); this.tryCompleteWindow(); + this.publisher.subscribe(new EmitterSubscriber<>(this)); } private void doEmit(D data) { diff --git a/framework/waterflow/java/waterflow-bridge-fit-reactor/src/test/java/modelengine/fit/waterflow/bridge/fitflow/FitBoundedEmitterTest.java b/framework/waterflow/java/waterflow-eco/waterflow-bridge-fit-reactor/src/test/java/modelengine/fit/waterflow/bridge/fitflow/FitBoundedEmitterTest.java similarity index 100% rename from framework/waterflow/java/waterflow-bridge-fit-reactor/src/test/java/modelengine/fit/waterflow/bridge/fitflow/FitBoundedEmitterTest.java rename to framework/waterflow/java/waterflow-eco/waterflow-bridge-fit-reactor/src/test/java/modelengine/fit/waterflow/bridge/fitflow/FitBoundedEmitterTest.java diff --git a/framework/waterflow/java/waterflow-genericable/src/main/java/modelengine/fit/waterflow/common/ErrorCodes.java b/framework/waterflow/java/waterflow-genericable/src/main/java/modelengine/fit/waterflow/common/ErrorCodes.java deleted file mode 100644 index 29bf9390..00000000 --- a/framework/waterflow/java/waterflow-genericable/src/main/java/modelengine/fit/waterflow/common/ErrorCodes.java +++ /dev/null @@ -1,119 +0,0 @@ -/*--------------------------------------------------------------------------------------------- - * Copyright (c) 2024 Huawei Technologies Co., Ltd. All rights reserved. - * This file is a part of the ModelEngine Project. - * Licensed under the MIT License. See License.txt in the project root for license information. - *--------------------------------------------------------------------------------------------*/ - -package modelengine.fit.waterflow.common; - -/** - * 异常类型枚举类 - * - * @author 陈镕希 - * @since 1.0 - */ -public enum ErrorCodes { - // /** ------------ Generic Exception. From 10000000 to 10000999 --------------------- */ - /** - * 枚举类转换异常 - */ - ENUM_CONVERT_FAILED(10000001, "Cannot convert enum {0} by name: {1}."), - /** - * 入参不合法 - */ - INPUT_PARAM_IS_INVALID(10000003, "Input param is invalid, invalid param is {0}."), - /** - * 分页查询时Offset范围不正确。 - */ - PAGINATION_OFFSET_INVALID(10000008, "The range of offset is incorrect."), - /** - * 分页查询时Limit范围不正确。 - */ - PAGINATION_LIMIT_INVALID(10000009, "The range of limit is incorrect."), - - /** - * 类型转换失败。 - */ - TYPE_CONVERT_FAILED(10000011, "Cannot convert type."), - - /** ------------ waterflow Exception 10007000-10007999 --------------------- */ - /** - * flow节点任务数达到最大值 - */ - FLOW_NODE_MAX_TASK(100070024, "Flow node id {0} tasks over the limit."), - - /** - * 流程节点转换不支持操作 - */ - FLOW_NODE_CREATE_ERROR(10007000, "Processor can not be null during create flowable node."), - /** - * 流程节点不支持执行操作 - */ - FLOW_NODE_OPERATOR_NOT_SUPPORT(10007001, "Flow node with id: {0}, type: {1}, for operator [{2}] not supported."), - /** - * 流程没有开始节点 - */ - FLOW_HAS_NO_START_NODE(10007002, "Flow definition with id: {0} has no start node."), - /** - * 流程执行错误,没有手动执行任务 - */ - FLOW_ENGINE_INVALID_MANUAL_TASK(10007004, "Flow engine executor error for invalid manual task."), - /** - * 流程定义解析失败 - */ - FLOW_ENGINE_PARSER_NOT_SUPPORT(10007010, "Flow engine parser not support {0} operator."), - FLOW_EXECUTE_FITABLE_TASK_FAILED(10007012, - "execute jober failed, jober name: {0}, jober type: {1}, fitables: {2}, errors: {3}"), - /** - * 流程引擎数据库不支持该操作 - */ - FLOW_ENGINE_DATABASE_NOT_SUPPORT(100070014, "Operation :{0} is not supported."), - /** - * 通过eventMetaId查询to节点失败 - */ - FLOW_FIND_TO_NODE_BY_EVENT_FAILED(100070016, "Find to node by event metaId :{0} failed."), - /** - * 流程回调函数执行fitables失败 - */ - FLOW_EXECUTE_CALLBACK_FITABLES_FAILED(100070023, - "Failed to execute callback, callback name: {0}, callback type: {1}, fitables: {2}, errors: {3}"), - - /** - * 流程引擎OhScript语法错误 - */ - FLOW_ENGINE_OHSCRIPT_GRAMMAR_ERROR(100070024, "OhScript grammar error. Source Code: {0}"), - - /** - * 流程引擎条件规则变量未找到 - */ - FLOW_ENGINE_CONDITION_RULE_PARSE_ERROR(100070025, "Condition rule parse error. Condition Rule: {0}"), - - /** - * 流程执行过程出现异常 - */ - FLOW_ENGINE_EXECUTOR_ERROR(10007500, "Error code: 10007500, Flow engine executor errors " - + "stream id: {0}, node id: {1}, name: {2}, exception: {3}, error message: {4}."), - ; - - private final Integer errorCode; - - private final String message; - - ErrorCodes(Integer errorCode, String message) { - this.errorCode = errorCode; - this.message = message; - } - - public Integer getErrorCode() { - return errorCode; - } - - public String getMessage() { - return message; - } - - @Override - public String toString() { - return "err " + this.errorCode + ": " + this.message; - } -} diff --git a/framework/waterflow/java/waterflow-genericable/src/main/java/modelengine/fit/waterflow/spi/WaterflowExceptionNotify.java b/framework/waterflow/java/waterflow-genericable/src/main/java/modelengine/fit/waterflow/spi/WaterflowExceptionNotify.java deleted file mode 100644 index e7c9ab02..00000000 --- a/framework/waterflow/java/waterflow-genericable/src/main/java/modelengine/fit/waterflow/spi/WaterflowExceptionNotify.java +++ /dev/null @@ -1,35 +0,0 @@ -/*--------------------------------------------------------------------------------------------- - * Copyright (c) 2024 Huawei Technologies Co., Ltd. All rights reserved. - * This file is a part of the ModelEngine Project. - * Licensed under the MIT License. See License.txt in the project root for license information. - *--------------------------------------------------------------------------------------------*/ - -package modelengine.fit.waterflow.spi; - -import modelengine.fitframework.annotation.Genericable; - -import java.util.List; -import java.util.Map; - -/** - * 流程实例异常 Genericable。 - * - * @author 李哲峰 - * @since 1.0 - */ -public interface WaterflowExceptionNotify { - /** - * ON_EXCEPTION_GENERICABLE - */ - String ON_EXCEPTION_GENERICABLE = "1b5ffv4ib16iui8ddizapuejgqtsjj59"; - - /** - * 异常回调实现 - * - * @param nodeId 异常发生的节点Id - * @param contexts 流程上下文 - * @param errorMessage 异常错误信息 - */ - @Genericable(id = ON_EXCEPTION_GENERICABLE) - void onException(String nodeId, List> contexts, String errorMessage); -} diff --git a/framework/waterflow/java/waterflow-genericable/src/main/java/modelengine/fit/waterflow/spi/WaterflowNodeNotify.java b/framework/waterflow/java/waterflow-genericable/src/main/java/modelengine/fit/waterflow/spi/WaterflowNodeNotify.java deleted file mode 100644 index 12d6abd5..00000000 --- a/framework/waterflow/java/waterflow-genericable/src/main/java/modelengine/fit/waterflow/spi/WaterflowNodeNotify.java +++ /dev/null @@ -1,33 +0,0 @@ -/*--------------------------------------------------------------------------------------------- - * Copyright (c) 2024 Huawei Technologies Co., Ltd. All rights reserved. - * This file is a part of the ModelEngine Project. - * Licensed under the MIT License. See License.txt in the project root for license information. - *--------------------------------------------------------------------------------------------*/ - -package modelengine.fit.waterflow.spi; - -import modelengine.fitframework.annotation.Genericable; - -import java.util.List; -import java.util.Map; - -/** - * 流程实例回调函数 Genericable。 - * - * @author 李哲峰 - * @since 1.0 - */ -public interface WaterflowNodeNotify { - /** - * ON_CONTEXT_COMPLETE_GENERICABLE - */ - String ON_CONTEXT_COMPLETE_GENERICABLE = "w8onlgq9xsw13jce4wvbcz3kbmjv3tuw"; - - /** - * 回调函数实现 - * - * @param contexts 流程上下文信息 - */ - @Genericable(id = ON_CONTEXT_COMPLETE_GENERICABLE) - void onContextComplete(List> contexts); -} diff --git a/framework/waterflow/java/waterflow-genericable/src/main/java/modelengine/fit/waterflow/spi/WaterflowTaskHandler.java b/framework/waterflow/java/waterflow-genericable/src/main/java/modelengine/fit/waterflow/spi/WaterflowTaskHandler.java deleted file mode 100644 index 7ee436b3..00000000 --- a/framework/waterflow/java/waterflow-genericable/src/main/java/modelengine/fit/waterflow/spi/WaterflowTaskHandler.java +++ /dev/null @@ -1,29 +0,0 @@ -/*--------------------------------------------------------------------------------------------- - * Copyright (c) 2024 Huawei Technologies Co., Ltd. All rights reserved. - * This file is a part of the ModelEngine Project. - * Licensed under the MIT License. See License.txt in the project root for license information. - *--------------------------------------------------------------------------------------------*/ - -package modelengine.fit.waterflow.spi; - -import modelengine.fitframework.annotation.Genericable; - -import java.util.List; -import java.util.Map; - -/** - * 流程服务的Genericable - * - * @author 晏钰坤 - * @since 1.0 - */ -public interface WaterflowTaskHandler { - /** - * 处理流程中的任务调用 - * - * @param flowData 流程执行上下文数据 - * @return 任务执行返回结果 - */ - @Genericable(id = "b735c87f5e7e408d852d8440d0b2ecdf") - List> handleTask(List> flowData); -}