diff --git a/Examples/Examples/Chat/ChatExampleToolsSimple.cs b/Examples/Examples/Chat/ChatExampleToolsSimple.cs index d3a6e37..9681976 100644 --- a/Examples/Examples/Chat/ChatExampleToolsSimple.cs +++ b/Examples/Examples/Chat/ChatExampleToolsSimple.cs @@ -9,9 +9,9 @@ public class ChatExampleToolsSimple : IExample public async Task Start() { OpenAiExample.Setup(); //We need to provide OpenAi API key - - Console.WriteLine("(OpenAi) ChatExample with tools is running!"); - + + Console.WriteLine("(OpenAi) ChatExample with tools is running!"); + await AIHub.Chat() .WithModel("gpt-5-nano") .WithMessage("What time is it right now?") @@ -24,4 +24,4 @@ await AIHub.Chat() .Build()) .CompleteAsync(interactive: true); } -} \ No newline at end of file +} diff --git a/Examples/Examples/Chat/ChatExampleToolsSimpleLocalLLM.cs b/Examples/Examples/Chat/ChatExampleToolsSimpleLocalLLM.cs new file mode 100644 index 0000000..9694c1a --- /dev/null +++ b/Examples/Examples/Chat/ChatExampleToolsSimpleLocalLLM.cs @@ -0,0 +1,25 @@ +using Examples.Utils; +using MaIN.Core.Hub; +using MaIN.Core.Hub.Utils; + +namespace Examples.Chat; + +public class ChatExampleToolsSimpleLocalLLM : IExample +{ + public async Task Start() + { + Console.WriteLine("Local LLM ChatExample with tools is running!"); + + await AIHub.Chat() + .WithModel("gemma3:4b") + .WithMessage("What time is it right now?") + .WithTools(new ToolsConfigurationBuilder() + .AddTool( + name: "get_current_time", + description: "Get the current date and time", + execute: Tools.GetCurrentTime) + .WithToolChoice("auto") + .Build()) + .CompleteAsync(interactive: true); + } +} \ No newline at end of file diff --git a/Examples/Examples/Program.cs b/Examples/Examples/Program.cs index f1b2bdc..627261c 100644 --- a/Examples/Examples/Program.cs +++ b/Examples/Examples/Program.cs @@ -51,6 +51,7 @@ static void RegisterExamples(IServiceCollection services) services.AddTransient(); services.AddTransient(); services.AddTransient(); + services.AddTransient(); services.AddTransient(); services.AddTransient(); services.AddTransient(); @@ -161,6 +162,7 @@ public class ExampleRegistry(IServiceProvider serviceProvider) ("\u25a0 Chat with Files from stream", serviceProvider.GetRequiredService()), ("\u25a0 Chat with Vision", serviceProvider.GetRequiredService()), ("\u25a0 Chat with Tools (simple)", serviceProvider.GetRequiredService()), + ("\u25a0 Chat with Tools (simple Local LLM)", serviceProvider.GetRequiredService()), ("\u25a0 Chat with Image Generation", serviceProvider.GetRequiredService()), ("\u25a0 Chat from Existing", serviceProvider.GetRequiredService()), ("\u25a0 Chat with reasoning", serviceProvider.GetRequiredService()), @@ -197,4 +199,4 @@ public class ExampleRegistry(IServiceProvider serviceProvider) ]; } }; -} \ No newline at end of file +} diff --git a/Releases/0.9.0.md b/Releases/0.9.0.md new file mode 100644 index 0000000..dc98b6e --- /dev/null +++ b/Releases/0.9.0.md @@ -0,0 +1,3 @@ +# 0.9.0 release + +- Add tool calling to local models \ No newline at end of file diff --git a/src/MaIN.Core/.nuspec b/src/MaIN.Core/.nuspec index 5135042..c4b8816 100644 --- a/src/MaIN.Core/.nuspec +++ b/src/MaIN.Core/.nuspec @@ -2,7 +2,7 @@ MaIN.NET - 0.8.1 + 0.9.0 Wisedev Wisedev favicon.png diff --git a/src/MaIN.Domain/Entities/Tools/FunctionCall.cs b/src/MaIN.Domain/Entities/Tools/FunctionCall.cs index b3ecead..065c173 100644 --- a/src/MaIN.Domain/Entities/Tools/FunctionCall.cs +++ b/src/MaIN.Domain/Entities/Tools/FunctionCall.cs @@ -1,7 +1,12 @@ -namespace MaIN.Domain.Entities.Tools; +using System.Text.Json.Serialization; -public class FunctionCall +namespace MaIN.Domain.Entities.Tools; + +public sealed record FunctionCall { - public string Name { get; set; } = null!; - public string Arguments { get; set; } = null!; -} \ No newline at end of file + [JsonPropertyName("name")] + public string Name { get; init; } = string.Empty; + + [JsonPropertyName("arguments")] + public string Arguments { get; init; } = "{}"; +} diff --git a/src/MaIN.Domain/Entities/Tools/ToolCall.cs b/src/MaIN.Domain/Entities/Tools/ToolCall.cs new file mode 100644 index 0000000..8eb5776 --- /dev/null +++ b/src/MaIN.Domain/Entities/Tools/ToolCall.cs @@ -0,0 +1,15 @@ +using System.Text.Json.Serialization; + +namespace MaIN.Domain.Entities.Tools; + +public sealed record ToolCall +{ + [JsonPropertyName("id")] + public string Id { get; init; } = string.Empty; + + [JsonPropertyName("type")] + public string Type { get; init; } = "function"; + + [JsonPropertyName("function")] + public FunctionCall Function { get; init; } = new(); +} diff --git a/src/MaIN.Domain/Entities/Tools/ToolDefinition.cs b/src/MaIN.Domain/Entities/Tools/ToolDefinition.cs index e6c68f9..cfd73fd 100644 --- a/src/MaIN.Domain/Entities/Tools/ToolDefinition.cs +++ b/src/MaIN.Domain/Entities/Tools/ToolDefinition.cs @@ -1,8 +1,12 @@ -namespace MaIN.Domain.Entities.Tools; +using System.Text.Json.Serialization; + +namespace MaIN.Domain.Entities.Tools; public class ToolDefinition { public string Type { get; set; } = "function"; public FunctionDefinition? Function { get; set; } + + [JsonIgnore] public Func>? Execute { get; set; } } \ No newline at end of file diff --git a/src/MaIN.Services/Services/LLMService/LLMService.cs b/src/MaIN.Services/Services/LLMService/LLMService.cs index ceb539e..deb441f 100644 --- a/src/MaIN.Services/Services/LLMService/LLMService.cs +++ b/src/MaIN.Services/Services/LLMService/LLMService.cs @@ -1,5 +1,6 @@ -using System.Collections.Concurrent; +using System.Collections.Concurrent; using System.Text; +using System.Text.Json; using LLama; using LLama.Batched; using LLama.Common; @@ -8,6 +9,7 @@ using MaIN.Domain.Configuration; using MaIN.Domain.Entities; using MaIN.Domain.Exceptions.Models; +using MaIN.Domain.Entities.Tools; using MaIN.Domain.Models; using MaIN.Services.Constants; using MaIN.Services.Services.Abstract; @@ -26,6 +28,7 @@ public class LLMService : ILLMService { private const string DEFAULT_MODEL_ENV_PATH = "MaIN_ModelsPath"; private static readonly ConcurrentDictionary _sessionCache = new(); + private const int MaxToolIterations = 5; private readonly MaINSettings options; private readonly INotificationService notificationService; @@ -52,7 +55,9 @@ public LLMService( CancellationToken cancellationToken = default) { if (chat.Messages.Count == 0) + { return null; + } var lastMsg = chat.Messages.Last(); @@ -62,6 +67,11 @@ public LLMService( return await AskMemory(chat, memoryOptions, requestOptions, cancellationToken); } + if (chat.ToolsConfiguration?.Tools != null && chat.ToolsConfiguration.Tools.Any()) + { + return await ProcessWithToolsAsync(chat, requestOptions, cancellationToken); + } + var model = KnownModels.GetModel(chat.Model); var tokens = await ProcessChatRequest(chat, model, lastMsg, requestOptions, cancellationToken); lastMsg.MarkProcessed(); @@ -82,7 +92,9 @@ public Task GetCurrentModels() public Task CleanSessionCache(string? id) { if (string.IsNullOrEmpty(id) || !_sessionCache.TryRemove(id, out var session)) + { return Task.CompletedTask; + } session.Executor.Context.Dispose(); return Task.CompletedTask; @@ -302,7 +314,9 @@ private static async Task ProcessImageMessage(Conversation conversation, conversation.Prompt(imageEmbeddings!); while (executor.BatchedTokenCount > 0) + { await executor.Infer(cancellationToken); + } var prompt = llmModel.Tokenize($"USER: {lastMsg.Content}\nASSISTANT:", true, false, Encoding.UTF8); conversation.Prompt(prompt); @@ -319,16 +333,26 @@ private static void ProcessTextMessage(Conversation conversation, var template = new LLamaTemplate(llmModel); var finalPrompt = ChatHelper.GetFinalPrompt(lastMsg, model, isNewConversation); + var hasTools = chat.ToolsConfiguration?.Tools != null && chat.ToolsConfiguration.Tools.Any(); + if (isNewConversation) { - foreach (var messageToProcess in chat.Messages - .Where(x => x.Properties.ContainsKey(Message.UnprocessedMessageProperty)) - .SkipLast(1)) + var messagesToProcess = hasTools + ? chat.Messages.SkipLast(1) + : chat.Messages.Where(x => x.Properties.ContainsKey(Message.UnprocessedMessageProperty)).SkipLast(1); + + foreach (var messageToProcess in messagesToProcess) { template.Add(messageToProcess.Role, messageToProcess.Content); } } + if (hasTools && isNewConversation) + { + var toolsPrompt = FormatToolsForPrompt(chat.ToolsConfiguration!); + finalPrompt = $"{toolsPrompt}\n\n{finalPrompt}"; + } + template.Add(ServiceConstants.Roles.User, finalPrompt); template.AddAssistant = true; @@ -340,6 +364,35 @@ private static void ProcessTextMessage(Conversation conversation, conversation.Prompt(tokens); } + private static string FormatToolsForPrompt(ToolsConfiguration toolsConfig) + { + var toolsList = new StringBuilder(); + foreach (var tool in toolsConfig.Tools) + { + if (tool.Function == null) + { + continue; + } + + toolsList.AppendLine($"- {tool.Function.Name}: {tool.Function.Description}"); + toolsList.AppendLine($" Parameters: {JsonSerializer.Serialize(tool.Function.Parameters)}"); + } + + return $$$""" + ## TOOLS + You can call these tools if needed. To call a tool, respond with a JSON object inside tags. + + {{{toolsList}}} + + ## RESPONSE FORMAT (YOU HAVE TO CHOOSE ONE FORMAT AND CANNOT MIX THEM)## + 1. For normal conversation, just respond with plain text. + 2. For tool calls, use this format. You cannot respond with plain text before or after format. If you want to call multiple functions, you have to combine them into one array. Your response MUST contain only one tool call block: + + {"tool_calls": [{"id": "call_1", "type": "function", "function": {"name": "tool_name", "arguments": "{\"param\":\"value\"}"}},{"id": "call_2", "type": "function", "function": {"name": "tool2_name", "arguments": "{\"param1\":\"value1\",\"param2\":\"value2\"}"}}]} + + """; + } + private async Task<(List Tokens, bool IsComplete, bool HasFailed)> ProcessTokens( Chat chat, Conversation conversation, @@ -477,4 +530,181 @@ await notificationService.DispatchNotification( NotificationMessageBuilder.CreateChatCompletion(chatId, token, isComplete), ServiceConstants.Notifications.ReceiveMessageUpdate); } -} \ No newline at end of file + + private async Task ProcessWithToolsAsync( + Chat chat, + ChatRequestOptions requestOptions, + CancellationToken cancellationToken) + { + NativeLogConfig.llama_log_set((level, message) => { + if (level == LLamaLogLevel.Error) + { + Console.Error.Write(message); + } + }); // Remove llama native logging + + var model = KnownModels.GetModel(chat.Model); + var iterations = 0; + var lastResponseTokens = new List(); + var lastResponse = string.Empty; + + while (iterations < MaxToolIterations) + { + var lastMsg = chat.Messages.Last(); + await SendNotification(chat.Id, new LLMTokenValue + { + Type = TokenType.FullAnswer, + Text = $"Processing with tools... iteration {iterations + 1}\n\n" + }, false); + requestOptions.InteractiveUpdates = false; + lastResponseTokens = await ProcessChatRequest(chat, model, lastMsg, requestOptions, cancellationToken); + lastMsg.MarkProcessed(); + lastResponse = string.Concat(lastResponseTokens.Select(x => x.Text)); + var responseMessage = new Message + { + Content = lastResponse, + Role = AuthorRole.Assistant.ToString(), + Type = MessageType.LocalLLM, + }; + chat.Messages.Add(responseMessage.MarkProcessed()); + + var parseResult = ToolCallParser.ParseToolCalls(lastResponse); + + // Tool not found or invalid JSON + if (!parseResult.IsSuccess) + { + if (parseResult.ErrorMessage is not null) // Invalid JSON, self correction + { + var errorMsg = new Message + { + Content = $"System Error: The tool call JSON was invalid. {parseResult.ErrorMessage}. Please correct the JSON format.", + Role = ServiceConstants.Roles.Tool, + Type = MessageType.LocalLLM, + Tool = true + }; + chat.Messages.Add(errorMsg.MarkProcessed()); + + iterations++; + continue; + } + else // Final response + { + requestOptions.InteractiveUpdates = true; + await SendNotification(chat.Id, new LLMTokenValue + { + Type = TokenType.FullAnswer, + Text = lastResponse + }, false); + break; + } + } + + var toolCalls = parseResult.ToolCalls!; + responseMessage.Properties[ToolCallsProperty] = JsonSerializer.Serialize(toolCalls); + + foreach (var toolCall in toolCalls) + { + if (chat.Properties.CheckProperty(ServiceConstants.Properties.AgentIdProperty)) + { + await notificationService.DispatchNotification( + NotificationMessageBuilder.ProcessingTools( + chat.Properties[ServiceConstants.Properties.AgentIdProperty], + string.Empty, + toolCall.Function.Name), + ServiceConstants.Notifications.ReceiveAgentUpdate); + } + + var executor = chat.ToolsConfiguration?.GetExecutor(toolCall.Function.Name); + + if (executor == null) + { + var errorMessage = $"No executor found for tool: {toolCall.Function.Name}"; + throw new InvalidOperationException(errorMessage); + } + + + try + { + if (requestOptions.ToolCallback is not null) + { + await requestOptions.ToolCallback.Invoke(new ToolInvocation + { + ToolName = toolCall.Function.Name, + Arguments = toolCall.Function.Arguments, + Done = false + }); + } + + var toolResult = await executor(toolCall.Function.Arguments); + + if (requestOptions.ToolCallback is not null) + { + await requestOptions.ToolCallback.Invoke(new ToolInvocation + { + ToolName = toolCall.Function.Name, + Arguments = toolCall.Function.Arguments, + Done = true + }); + } + + var toolMessage = new Message + { + Content = $"Tool result for {toolCall.Function.Name}: {toolResult}", + Role = ServiceConstants.Roles.Tool, + Type = MessageType.LocalLLM, + Tool = true + }; + toolMessage.Properties[ToolCallIdProperty] = toolCall.Id; + toolMessage.Properties[ToolNameProperty] = toolCall.Function.Name; + chat.Messages.Add(toolMessage.MarkProcessed()); + } + catch (Exception ex) + { + var errorResult = JsonSerializer.Serialize(new { error = ex.Message }); + var toolMessage = new Message + { + Content = $"Tool error for {toolCall.Function.Name}: {errorResult}", + Role = ServiceConstants.Roles.Tool, + Type = MessageType.LocalLLM, + Tool = true + }; + toolMessage.Properties[ToolCallIdProperty] = toolCall.Id; + toolMessage.Properties[ToolNameProperty] = toolCall.Function.Name; + chat.Messages.Add(toolMessage.MarkProcessed()); + } + } + + iterations++; + } + + if (iterations >= MaxToolIterations) + { + var errorMessage = "Maximum tool invocation iterations reached. Ending the tool-loop prematurely."; + var iterationMessage = new Message + { + Content = errorMessage, + Role = AuthorRole.System.ToString(), + Type = MessageType.LocalLLM, + }; + chat.Messages.Add(iterationMessage.MarkProcessed()); + + await SendNotification(chat.Id, new LLMTokenValue + { + Type = TokenType.FullAnswer, + Text = errorMessage + }, false); + } + + return new ChatResult + { + Done = true, + CreatedAt = DateTime.Now, + Model = chat.Model, + Message = chat.Messages.Last() + }; + } + + private const string ToolCallsProperty = "ToolCalls"; + private const string ToolCallIdProperty = "ToolCallId"; + private const string ToolNameProperty = "ToolName"; +} diff --git a/src/MaIN.Services/Services/LLMService/OpenAiCompatibleService.cs b/src/MaIN.Services/Services/LLMService/OpenAiCompatibleService.cs index 1c3fc35..b0ae2d1 100644 --- a/src/MaIN.Services/Services/LLMService/OpenAiCompatibleService.cs +++ b/src/MaIN.Services/Services/LLMService/OpenAiCompatibleService.cs @@ -1001,41 +1001,6 @@ private static string DetectImageMimeType(byte[] imageBytes) } } - -public class ToolDefinition -{ - public string Type { get; set; } = "function"; - public FunctionDefinition Function { get; set; } = null!; - - [System.Text.Json.Serialization.JsonIgnore] - public Func>? Execute { get; set; } -} - -public class FunctionDefinition -{ - public string Name { get; set; } = null!; - public string? Description { get; set; } - public object Parameters { get; set; } = null!; -} - -public class ToolCall -{ - [JsonPropertyName("id")] - public string Id { get; set; } = null!; - [JsonPropertyName("type")] - public string Type { get; set; } = "function"; - [JsonPropertyName("function")] - public FunctionCall Function { get; set; } = null!; -} - -public class FunctionCall -{ - [JsonPropertyName("name")] - public string Name { get; set; } = null!; - [JsonPropertyName("arguments")] - public string Arguments { get; set; } = null!; -} - internal class ChatMessage { public string Role { get; set; } diff --git a/src/MaIN.Services/Services/LLMService/Utils/ToolCallsHelper.cs b/src/MaIN.Services/Services/LLMService/Utils/ToolCallsHelper.cs new file mode 100644 index 0000000..41df73f --- /dev/null +++ b/src/MaIN.Services/Services/LLMService/Utils/ToolCallsHelper.cs @@ -0,0 +1,94 @@ +using System.Text.Json; +using System.Text.Json.Serialization; +using MaIN.Domain.Entities.Tools; + +namespace MaIN.Services.Services.LLMService.Utils; + +public static class ToolCallParser +{ + private static readonly JsonSerializerOptions JsonOptions = new() + { + PropertyNameCaseInsensitive = true, + }; + + public static ToolParseResult ParseToolCalls(string response) + { + if (string.IsNullOrWhiteSpace(response)) + return ToolParseResult.Failure("Response is empty."); + + var jsonContent = ExtractJsonContent(response); + + if (string.IsNullOrEmpty(jsonContent)) + return ToolParseResult.ToolNotFound(); + + try + { + var wrapper = JsonSerializer.Deserialize(jsonContent, JsonOptions); + + if (wrapper?.ToolCalls is not null && wrapper.ToolCalls.Count != 0) + return ToolParseResult.Success(NormalizeToolCalls(wrapper.ToolCalls)); + + return ToolParseResult.Failure("JSON parsed correctly but 'tool_calls' property is missing or empty."); + } + catch (JsonException ex) + { + return ToolParseResult.Failure($"Invalid JSON format: {ex.Message}"); + } + } + + private static string? ExtractJsonContent(string text) + { + text = text.Trim(); + + var firstBrace = text.IndexOf('{'); + var firstBracket = text.IndexOf('['); + var startIndex = (firstBrace >= 0 && firstBracket >= 0) + ? Math.Min(firstBrace, firstBracket) + : Math.Max(firstBrace, firstBracket); + + var lastBrace = text.LastIndexOf('}'); + var lastBracket = text.LastIndexOf(']'); + var endIndex = Math.Max(lastBrace, lastBracket); + + if (startIndex >= 0 && endIndex > startIndex) + return text.Substring(startIndex, endIndex - startIndex + 1); + + return null; + } + + private static List NormalizeToolCalls(List? calls) + { + if (calls is null) + return []; + + var normalizedCalls = new List(); + + foreach (var call in calls) + { + var id = string.IsNullOrEmpty(call.Id) ? Guid.NewGuid().ToString()[..8] : call.Id; + var type = string.IsNullOrEmpty(call.Type) ? "function" : call.Type; + var function = call.Function ?? new FunctionCall(); + + normalizedCalls.Add(call with { Id = id, Type = type, Function = function }); + } + + return normalizedCalls; + } + + private sealed record ToolResponseWrapper + { + [JsonPropertyName("tool_calls")] + public List? ToolCalls { get; init; } + } +} + +public record ToolParseResult +{ + public bool IsSuccess { get; init; } + public List? ToolCalls { get; init; } + public string? ErrorMessage { get; init; } + + public static ToolParseResult Success(List calls) => new() { IsSuccess = true, ToolCalls = calls }; + public static ToolParseResult Failure(string error) => new() { IsSuccess = false, ErrorMessage = error }; + public static ToolParseResult ToolNotFound() => new() { IsSuccess = false }; +}