diff --git a/lib/ruby_llm/active_record/acts_as_legacy.rb b/lib/ruby_llm/active_record/acts_as_legacy.rb index 97679c126..1081e3941 100644 --- a/lib/ruby_llm/active_record/acts_as_legacy.rb +++ b/lib/ruby_llm/active_record/acts_as_legacy.rb @@ -151,6 +151,11 @@ def with_schema(...) self end + def with_messages(...) + to_llm.with_messages(...) + self + end + def on_new_message(&block) to_llm diff --git a/lib/ruby_llm/active_record/chat_methods.rb b/lib/ruby_llm/active_record/chat_methods.rb index 41930548c..765e9a922 100644 --- a/lib/ruby_llm/active_record/chat_methods.rb +++ b/lib/ruby_llm/active_record/chat_methods.rb @@ -139,6 +139,11 @@ def with_schema(...) self end + def with_messages(...) + to_llm.with_messages(...) + self + end + def on_new_message(&block) to_llm diff --git a/lib/ruby_llm/chat.rb b/lib/ruby_llm/chat.rb index d03d872ca..0967ed947 100644 --- a/lib/ruby_llm/chat.rb +++ b/lib/ruby_llm/chat.rb @@ -18,6 +18,7 @@ def initialize(model: nil, provider: nil, assume_model_exists: false, context: n with_model(model_id, provider: provider, assume_exists: assume_model_exists) @temperature = nil @messages = [] + @messages_scope = nil @tools = {} @params = {} @headers = {} @@ -97,6 +98,16 @@ def with_schema(schema) self end + def with_messages(&block) + if block.nil? + @messages_scope = nil + return self + end + + @messages_scope = block + self + end + def on_new_message(&block) @on[:new_message] = block self @@ -123,7 +134,7 @@ def each(&) def complete(&) # rubocop:disable Metrics/PerceivedComplexity response = @provider.complete( - messages, + scoped_messages, tools: @tools, temperature: @temperature, model: @model, @@ -169,6 +180,10 @@ def instance_variables private + def scoped_messages + @messages_scope ? @messages_scope.call(messages) : messages + end + def wrap_streaming_block(&block) return nil unless block_given? diff --git a/spec/ruby_llm/chat_messages_scope_spec.rb b/spec/ruby_llm/chat_messages_scope_spec.rb new file mode 100644 index 000000000..c53438cd1 --- /dev/null +++ b/spec/ruby_llm/chat_messages_scope_spec.rb @@ -0,0 +1,88 @@ +# frozen_string_literal: true + +require 'spec_helper' + +RSpec.describe RubyLLM::Chat do + include_context 'with configured RubyLLM' + + describe '#with_messages' do + it 'uses all messages by default when no scope is configured' do + chat = RubyLLM.chat + provider = chat.instance_variable_get(:@provider) + + allow(provider).to receive(:complete).and_return( + RubyLLM::Message.new(role: :assistant, content: 'Test response') + ) + + chat.add_message(role: :user, content: 'one') + chat.add_message(role: :assistant, content: 'two') + + chat.complete + + expect(provider).to have_received(:complete).with( + chat.messages, + hash_including( + tools: chat.tools, + temperature: anything, + model: chat.model, + params: chat.params, + headers: chat.headers, + schema: chat.schema + ) + ) + end + + it 'applies a callable scope to the messages passed to the provider' do + chat = RubyLLM.chat + provider = chat.instance_variable_get(:@provider) + + allow(provider).to receive(:complete).and_return( + RubyLLM::Message.new(role: :assistant, content: 'Scoped response') + ) + + first = chat.add_message(role: :user, content: 'first') + second = chat.add_message(role: :assistant, content: 'second') + + chat.with_messages { |msgs| [msgs.last] } + + chat.complete + + expect(provider).to have_received(:complete) do |messages_arg, **_options| + expect(messages_arg).to eq([second]) + expect(messages_arg).not_to include(first) + end + end + + it 'can clear the scope by calling without a block' do + chat = RubyLLM.chat + provider = chat.instance_variable_get(:@provider) + + allow(provider).to receive(:complete).and_return( + RubyLLM::Message.new(role: :assistant, content: 'Cleared scope response') + ) + + chat.add_message(role: :user, content: 'one') + chat.add_message(role: :assistant, content: 'two') + + chat.with_messages { |msgs| [msgs.last] } + chat.with_messages + + chat.complete + + expect(provider).to have_received(:complete).with( + chat.messages, + hash_including( + tools: chat.tools, + temperature: anything, + model: chat.model, + params: chat.params, + headers: chat.headers, + schema: chat.schema + ) + ) + end + + end +end + +