Bug description
When using OllamaChatModel in streaming mode, the thinking content returned by Ollama is not propagated to the client. In internalStream, each streamed chunk is mapped to an AssistantMessage from chunk.message().content() and the tool calls only, so the chunk's thinking content is silently dropped.
Environment
Spring AI 1.1.2
Steps to reproduce
1. Call OllamaChatModel.stream(Prompt) against a model that produces thinking output.
2. Consume the streamed ChatResponse objects: only the regular message content (and tool calls) come through; the thinking content is never returned.
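A minimal sketch of the call, assuming a local Ollama instance and the 1.0.x-style builder API; the model name is a placeholder for any model that emits thinking content, and enabling thinking may additionally require the corresponding Ollama option:

import org.springframework.ai.chat.prompt.Prompt;
import org.springframework.ai.ollama.OllamaChatModel;
import org.springframework.ai.ollama.api.OllamaApi;
import org.springframework.ai.ollama.api.OllamaOptions;

OllamaApi ollamaApi = OllamaApi.builder()
        .baseUrl("http://localhost:11434") // assumption: local Ollama instance
        .build();

OllamaChatModel chatModel = OllamaChatModel.builder()
        .ollamaApi(ollamaApi)
        .defaultOptions(OllamaOptions.builder()
                .model("qwen3") // placeholder: any thinking-capable model
                .build())
        .build();

// Print each streamed chunk; the thinking content never appears, neither in
// the text nor anywhere else on the streamed ChatResponse.
chatModel.stream(new Prompt("Why is the sky blue?"))
        .doOnNext(r -> System.out.println(r.getResult().getOutput().getText()))
        .blockLast();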
Expected behavior
The thinking content produced by the model is included in the streamed ChatResponse, just like the regular message content.
Minimal Complete Reproducible example
The relevant code is the streaming path in org.springframework.ai.ollama.OllamaChatModel:
@Override
public Flux<ChatResponse> stream(Prompt prompt) {
    // Before moving any further, build the final request Prompt,
    // merging runtime and default options.
    Prompt requestPrompt = buildRequestPrompt(prompt);
    return this.internalStream(requestPrompt, null);
}
private Flux<ChatResponse> internalStream(Prompt prompt, ChatResponse previousChatResponse) {
    return Flux.deferContextual(contextView -> {
        OllamaApi.ChatRequest request = ollamaChatRequest(prompt, true);

        final ChatModelObservationContext observationContext = ChatModelObservationContext.builder()
            .prompt(prompt)
            .provider(OllamaApiConstants.PROVIDER_NAME)
            .build();

        Observation observation = ChatModelObservationDocumentation.CHAT_MODEL_OPERATION.observation(
                this.observationConvention, DEFAULT_OBSERVATION_CONVENTION, () -> observationContext,
                this.observationRegistry);

        observation.parentObservation(contextView.getOrDefault(ObservationThreadLocalAccessor.KEY, null)).start();

        Flux<OllamaApi.ChatResponse> ollamaResponse = this.chatApi.streamingChat(request);

        Flux<ChatResponse> chatResponse = ollamaResponse.map(chunk -> {
            String content = (chunk.message() != null) ? chunk.message().content() : "";
            List<AssistantMessage.ToolCall> toolCalls = List.of();
            // Added null checks to prevent NPE when accessing tool calls
            if (chunk.message() != null && chunk.message().toolCalls() != null) {
                toolCalls = chunk.message()
                    .toolCalls()
                    .stream()
                    .map(toolCall -> new AssistantMessage.ToolCall("", "function", toolCall.function().name(),
                            ModelOptionsUtils.toJsonString(toolCall.function().arguments())))
                    .toList();
            }
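
            // NOTE (this report): the thinking content on the chunk is never read
            // here, so it is dropped from the AssistantMessage built below.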
            var assistantMessage = AssistantMessage.builder()
                .content(content)
                .properties(Map.of())
                .toolCalls(toolCalls)
                .build();

            // The eval counts are only present on the final streamed chunk,
            // which also carries the done reason.
            ChatGenerationMetadata generationMetadata = ChatGenerationMetadata.NULL;
            if (chunk.promptEvalCount() != null && chunk.evalCount() != null) {
                generationMetadata = ChatGenerationMetadata.builder().finishReason(chunk.doneReason()).build();
            }

            var generator = new Generation(assistantMessage, generationMetadata);
            return new ChatResponse(List.of(generator), from(chunk, previousChatResponse));
        });
        // @formatter:off
        Flux<ChatResponse> chatResponseFlux = chatResponse.flatMap(response -> {
            if (this.toolExecutionEligibilityPredicate.isToolExecutionRequired(prompt.getOptions(), response)) {
                // FIXME: bounded elastic needs to be used since tool calling
                // is currently only synchronous
                return Flux.deferContextual(ctx -> {
                    ToolExecutionResult toolExecutionResult;
                    try {
                        ToolCallReactiveContextHolder.setContext(ctx);
                        toolExecutionResult = this.toolCallingManager.executeToolCalls(prompt, response);
                    }
                    finally {
                        ToolCallReactiveContextHolder.clearContext();
                    }
                    if (toolExecutionResult.returnDirect()) {
                        // Return tool execution result directly to the client.
                        return Flux.just(ChatResponse.builder().from(response)
                                .generations(ToolExecutionResult.buildGenerations(toolExecutionResult))
                                .build());
                    }
                    else {
                        // Send the tool execution result back to the model.
                        return this.internalStream(new Prompt(toolExecutionResult.conversationHistory(), prompt.getOptions()),
                                response);
                    }
                }).subscribeOn(Schedulers.boundedElastic());
            }
            else {
                return Flux.just(response);
            }
        })
        .doOnError(observation::error)
        .doFinally(s -> observation.stop())
        .contextWrite(ctx -> ctx.put(ObservationThreadLocalAccessor.KEY, observation));
        // @formatter:on

        return new MessageAggregator().aggregate(chatResponseFlux, observationContext::setResponse);
    });
}
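
For what it's worth, a sketch of where a fix could land, inside the ollamaResponse.map(chunk -> ...) lambda above. It assumes the Ollama message record exposes the thinking content via a thinking() accessor and surfaces the value on the AssistantMessage properties map; both the accessor name and the properties key are assumptions, not the project's decided API:

// Sketch only: pull the thinking content off the chunk (assumes a
// thinking() accessor on the message record) ...
String thinking = (chunk.message() != null && chunk.message().thinking() != null)
        ? chunk.message().thinking()
        : null;

// ... and surface it on the streamed AssistantMessage; putting it in the
// properties map under a "thinking" key is an assumption about the design.
var assistantMessage = AssistantMessage.builder()
        .content(content)
        .properties(thinking != null ? Map.of("thinking", thinking) : Map.of())
        .toolCalls(toolCalls)
        .build();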