
Commit a06bf27

hangfei authored and copybara-github committed
feat: Adding the ContextFilterPlugin
This commit introduces a new ContextFilterPlugin which allows filtering of the LlmRequest contents before they are sent to the LLM. This helps in managing and potentially reducing the size of the LLM context.

The plugin provides two primary filtering mechanisms:
- num_invocations_to_keep: Keeps only the specified number of the most recent user-model invocations. An invocation is defined as one or more user messages followed by a model response.
- custom_filter: Allows a user-defined callable to be applied to the contents for more flexible filtering.

Unit tests have been added to cover the different filtering scenarios, including:
- Filtering by the last N invocations.
- Filtering using a custom function.
- Combining both filtering methods.
- Handling cases with multiple user turns in a single invocation.
- Ensuring no filtering occurs when options are not provided.
- Gracefully handling exceptions from custom filter functions.

For example, when num_invocations_to_keep=2:
-----------------------------------------------------------
Contents:
{"parts":[{"text":"9"}],"role":"user"}
{"parts":[{"text":"I am sorry, I cannot fulfill this request. I need more information on what you would like me to do. I can roll a die or check prime numbers.\n"}],"role":"model"}
{"parts":[{"text":"1"}],"role":"user"}
{"parts":[{"text":"I am sorry, I cannot fulfill this request. I need more information on what you would like me to do. I can roll a die or check prime numbers.\n"}],"role":"model"}
{"parts":[{"text":"10"}],"role":"user"}
-----------------------------------------------------------

PiperOrigin-RevId: 808355316
1 parent 10cf377 commit a06bf27
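
For reference, a minimal usage sketch that exercises the plugin directly, mirroring how the new unit tests drive it. The drop_empty_turns helper and the sample contents are illustrative assumptions, not part of this commit:

import asyncio
from unittest.mock import Mock

from google.adk.agents.callback_context import CallbackContext
from google.adk.models.llm_request import LlmRequest
from google.adk.plugins.context_filter_plugin import ContextFilterPlugin
from google.genai import types


def drop_empty_turns(contents):
  # Illustrative custom filter: keep only contents that carry some text.
  return [c for c in contents if any(p.text for p in c.parts)]


async def main():
  plugin = ContextFilterPlugin(
      num_invocations_to_keep=2, custom_filter=drop_empty_turns
  )
  llm_request = LlmRequest(contents=[
      types.Content(role="user", parts=[types.Part(text="roll a die")]),
      types.Content(role="model", parts=[types.Part(text="You rolled a 4.")]),
  ])
  # The callback trims llm_request.contents in place and returns None.
  await plugin.before_model_callback(
      callback_context=Mock(spec=CallbackContext), llm_request=llm_request
  )
  # Prints 2: only one model turn exists, so nothing is dropped here.
  print(len(llm_request.contents))


asyncio.run(main())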

2 files changed: +273, -0 lines changed

Lines changed: 88 additions & 0 deletions
@@ -0,0 +1,88 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

import logging
from typing import Callable
from typing import List
from typing import Optional

from ..agents.callback_context import CallbackContext
from ..events.event import Event
from ..models.llm_request import LlmRequest
from ..models.llm_response import LlmResponse
from .base_plugin import BasePlugin

logger = logging.getLogger("google_adk." + __name__)


class ContextFilterPlugin(BasePlugin):
  """A plugin that filters the LLM context to reduce its size."""

  def __init__(
      self,
      num_invocations_to_keep: Optional[int] = None,
      custom_filter: Optional[Callable[[List[Event]], List[Event]]] = None,
      name: str = "context_filter_plugin",
  ):
    """Initializes the context management plugin.

    Args:
      num_invocations_to_keep: The number of last invocations to keep. An
        invocation is defined as one or more consecutive user messages followed
        by a model response.
      custom_filter: A function to filter the context.
      name: The name of the plugin instance.
    """
    super().__init__(name)
    self._num_invocations_to_keep = num_invocations_to_keep
    self._custom_filter = custom_filter

  async def before_model_callback(
      self, *, callback_context: CallbackContext, llm_request: LlmRequest
  ) -> Optional[LlmResponse]:
    """Filters the LLM request's context before it is sent to the model."""
    try:
      contents = llm_request.contents

      if (
          self._num_invocations_to_keep is not None
          and self._num_invocations_to_keep > 0
      ):
        num_model_turns = sum(1 for c in contents if c.role == "model")
        if num_model_turns >= self._num_invocations_to_keep:
          model_turns_to_find = self._num_invocations_to_keep
          split_index = 0
          for i in range(len(contents) - 1, -1, -1):
            if contents[i].role == "model":
              model_turns_to_find -= 1
              if model_turns_to_find == 0:
                start_index = i
                while (
                    start_index > 0 and contents[start_index - 1].role == "user"
                ):
                  start_index -= 1
                split_index = start_index
                break
          contents = contents[split_index:]

      if self._custom_filter:
        contents = self._custom_filter(contents)

      llm_request.contents = contents
    except Exception as e:
      logger.error(f"Failed to reduce context for request: {e}")

    return None
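
Note that in before_model_callback above, the custom filter runs after the invocation-based truncation, so it receives the already-trimmed list. A minimal sketch of a filter that additionally caps the context at a fixed number of contents follows; the cap_recent_contents helper and the cap of 20 are illustrative assumptions, not part of this commit:

from functools import partial
from typing import List

from google.adk.plugins.context_filter_plugin import ContextFilterPlugin
from google.genai import types


def cap_recent_contents(
    contents: List[types.Content], max_items: int
) -> List[types.Content]:
  # Keep only the most recent max_items contents, regardless of role.
  return contents[-max_items:]


plugin = ContextFilterPlugin(
    num_invocations_to_keep=3,
    custom_filter=partial(cap_recent_contents, max_items=20),
)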
Lines changed: 185 additions & 0 deletions
@@ -0,0 +1,185 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Unit tests for the ContextFilterPlugin."""

from unittest.mock import Mock

from google.adk.agents.callback_context import CallbackContext
from google.adk.models.llm_request import LlmRequest
from google.adk.plugins.context_filter_plugin import ContextFilterPlugin
from google.genai import types
import pytest


def _create_content(role: str, text: str) -> types.Content:
  return types.Content(parts=[types.Part(text=text)], role=role)


@pytest.mark.asyncio
async def test_filter_last_n_invocations():
  """Tests that the context is truncated to the last N invocations."""
  plugin = ContextFilterPlugin(num_invocations_to_keep=1)
  contents = [
      _create_content("user", "user_prompt_1"),
      _create_content("model", "model_response_1"),
      _create_content("user", "user_prompt_2"),
      _create_content("model", "model_response_2"),
  ]
  llm_request = LlmRequest(contents=contents)

  await plugin.before_model_callback(
      callback_context=Mock(spec=CallbackContext), llm_request=llm_request
  )

  assert len(llm_request.contents) == 2
  assert llm_request.contents[0].parts[0].text == "user_prompt_2"
  assert llm_request.contents[1].parts[0].text == "model_response_2"


@pytest.mark.asyncio
async def test_filter_with_function():
  """Tests that a custom filter function is applied to the context."""

  def remove_model_responses(contents):
    return [c for c in contents if c.role != "model"]

  plugin = ContextFilterPlugin(custom_filter=remove_model_responses)
  contents = [
      _create_content("user", "user_prompt_1"),
      _create_content("model", "model_response_1"),
      _create_content("user", "user_prompt_2"),
      _create_content("model", "model_response_2"),
  ]
  llm_request = LlmRequest(contents=contents)

  await plugin.before_model_callback(
      callback_context=Mock(spec=CallbackContext), llm_request=llm_request
  )

  assert len(llm_request.contents) == 2
  assert all(c.role == "user" for c in llm_request.contents)


@pytest.mark.asyncio
async def test_filter_with_function_and_last_n_invocations():
  """Tests that both filtering methods are applied correctly."""

  def remove_first_invocation(contents):
    return contents[2:]

  plugin = ContextFilterPlugin(
      num_invocations_to_keep=1, custom_filter=remove_first_invocation
  )
  contents = [
      _create_content("user", "user_prompt_1"),
      _create_content("model", "model_response_1"),
      _create_content("user", "user_prompt_2"),
      _create_content("model", "model_response_2"),
      _create_content("user", "user_prompt_3"),
      _create_content("model", "model_response_3"),
  ]
  llm_request = LlmRequest(contents=contents)

  await plugin.before_model_callback(
      callback_context=Mock(spec=CallbackContext), llm_request=llm_request
  )

  assert len(llm_request.contents) == 0


@pytest.mark.asyncio
async def test_no_filtering_when_no_options_provided():
  """Tests that no filtering occurs when no options are provided."""
  plugin = ContextFilterPlugin()
  contents = [
      _create_content("user", "user_prompt_1"),
      _create_content("model", "model_response_1"),
  ]
  llm_request = LlmRequest(contents=contents)
  original_contents = list(llm_request.contents)

  await plugin.before_model_callback(
      callback_context=Mock(spec=CallbackContext), llm_request=llm_request
  )

  assert llm_request.contents == original_contents


@pytest.mark.asyncio
async def test_last_n_invocations_with_multiple_user_turns():
  """Tests filtering with multiple user turns in a single invocation."""
  plugin = ContextFilterPlugin(num_invocations_to_keep=1)
  contents = [
      _create_content("user", "user_prompt_1"),
      _create_content("model", "model_response_1"),
      _create_content("user", "user_prompt_2a"),
      _create_content("user", "user_prompt_2b"),
      _create_content("model", "model_response_2"),
  ]
  llm_request = LlmRequest(contents=contents)

  await plugin.before_model_callback(
      callback_context=Mock(spec=CallbackContext), llm_request=llm_request
  )

  assert len(llm_request.contents) == 3
  assert llm_request.contents[0].parts[0].text == "user_prompt_2a"
  assert llm_request.contents[1].parts[0].text == "user_prompt_2b"
  assert llm_request.contents[2].parts[0].text == "model_response_2"


@pytest.mark.asyncio
async def test_last_n_invocations_more_than_existing_invocations():
  """Tests that no filtering occurs if last_n_invocations is greater than

  the number of invocations.
  """
  plugin = ContextFilterPlugin(num_invocations_to_keep=3)
  contents = [
      _create_content("user", "user_prompt_1"),
      _create_content("model", "model_response_1"),
      _create_content("user", "user_prompt_2"),
      _create_content("model", "model_response_2"),
  ]
  llm_request = LlmRequest(contents=contents)
  original_contents = list(llm_request.contents)

  await plugin.before_model_callback(
      callback_context=Mock(spec=CallbackContext), llm_request=llm_request
  )

  assert llm_request.contents == original_contents


@pytest.mark.asyncio
async def test_filter_function_raises_exception():
  """Tests that the plugin handles exceptions from the filter function."""

  def faulty_filter(contents):
    raise ValueError("Filter error")

  plugin = ContextFilterPlugin(custom_filter=faulty_filter)
  contents = [
      _create_content("user", "user_prompt_1"),
      _create_content("model", "model_response_1"),
  ]
  llm_request = LlmRequest(contents=contents)
  original_contents = list(llm_request.contents)

  await plugin.before_model_callback(
      callback_context=Mock(spec=CallbackContext), llm_request=llm_request
  )

  assert llm_request.contents == original_contents
