Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 8 additions & 1 deletion src/app/endpoints/conversations_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -240,7 +240,14 @@ def check_conversation_existence(user_id: str, conversation_id: str) -> None:

def transform_chat_message(entry: CacheEntry) -> dict[str, Any]:
"""Transform the message read from cache into format used by response payload."""
user_message = {"content": entry.query, "type": "user"}
user_message: dict[str, Any] = {"content": entry.query, "type": "user"}

# Add attachments to user message if present
if entry.attachments:
user_message["attachments"] = [
att.model_dump(mode="json") for att in entry.attachments
]

assistant_message: dict[str, Any] = {"content": entry.response, "type": "assistant"}

# If referenced_documents exist on the entry, add them to the assistant message
Expand Down
1 change: 1 addition & 0 deletions src/app/endpoints/query.py
Original file line number Diff line number Diff line change
Expand Up @@ -384,6 +384,7 @@ async def query_endpoint_handler_base( # pylint: disable=R0914
referenced_documents=referenced_documents if referenced_documents else None,
tool_calls=summary.tool_calls if summary.tool_calls else None,
tool_results=summary.tool_results if summary.tool_results else None,
attachments=query_request.attachments or None,
)

consume_tokens(
Expand Down
41 changes: 38 additions & 3 deletions src/cache/postgres_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from cache.cache_error import CacheError
from models.cache_entry import CacheEntry
from models.config import PostgreSQLDatabaseConfiguration
from models.requests import Attachment
from models.responses import ConversationData, ReferencedDocument
from utils.connection_decorator import connection
from utils.types import ToolCallSummary, ToolResultSummary
Expand Down Expand Up @@ -36,6 +37,7 @@ class PostgresCache(Cache):
referenced_documents | jsonb | |
tool_calls | jsonb | |
tool_results | jsonb | |
attachments | jsonb | |
Indexes:
"cache_pkey" PRIMARY KEY, btree (user_id, conversation_id, created_at)
"timestamps" btree (created_at)
Expand All @@ -60,6 +62,7 @@ class PostgresCache(Cache):
referenced_documents jsonb,
tool_calls jsonb,
tool_results jsonb,
attachments jsonb,
PRIMARY KEY(user_id, conversation_id, created_at)
);
"""
Expand All @@ -81,7 +84,7 @@ class PostgresCache(Cache):

SELECT_CONVERSATION_HISTORY_STATEMENT = """
SELECT query, response, provider, model, started_at, completed_at,
referenced_documents, tool_calls, tool_results
referenced_documents, tool_calls, tool_results, attachments
FROM cache
WHERE user_id=%s AND conversation_id=%s
ORDER BY created_at
Expand All @@ -90,8 +93,8 @@ class PostgresCache(Cache):
INSERT_CONVERSATION_HISTORY_STATEMENT = """
INSERT INTO cache(user_id, conversation_id, created_at, started_at, completed_at,
query, response, provider, model, referenced_documents,
tool_calls, tool_results)
VALUES (%s, %s, CURRENT_TIMESTAMP, %s, %s, %s, %s, %s, %s, %s, %s, %s)
tool_calls, tool_results, attachments)
VALUES (%s, %s, CURRENT_TIMESTAMP, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s)
"""

QUERY_CACHE_SIZE = """
Expand Down Expand Up @@ -301,6 +304,22 @@ def get( # pylint: disable=R0914
e,
)

# Parse attachments back into Attachment objects
attachments_data = conversation_entry[9]
attachments_obj = None
if attachments_data:
try:
attachments_obj = [
Attachment.model_validate(att) for att in attachments_data
]
except (ValueError, TypeError) as e:
logger.warning(
"Failed to deserialize attachments for "
"conversation %s: %s",
conversation_id,
e,
)

cache_entry = CacheEntry(
query=conversation_entry[0],
response=conversation_entry[1],
Expand All @@ -311,6 +330,7 @@ def get( # pylint: disable=R0914
referenced_documents=docs_obj,
tool_calls=tool_calls_obj,
tool_results=tool_results_obj,
attachments=attachments_obj,
)
result.append(cache_entry)

Expand Down Expand Up @@ -382,6 +402,20 @@ def insert_or_append(
e,
)

attachments_json = None
if cache_entry.attachments:
try:
attachments_as_dicts = [
att.model_dump(mode="json") for att in cache_entry.attachments
]
attachments_json = json.dumps(attachments_as_dicts)
except (TypeError, ValueError) as e:
logger.warning(
"Failed to serialize attachments for conversation %s: %s",
conversation_id,
e,
)

# the whole operation is run in one transaction
with self.connection.cursor() as cursor:
cursor.execute(
Expand All @@ -398,6 +432,7 @@ def insert_or_append(
referenced_documents_json,
tool_calls_json,
tool_results_json,
attachments_json,
),
)

Expand Down
108 changes: 62 additions & 46 deletions src/cache/sqlite_cache.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""Cache that uses SQLite to store cached values."""

from time import time
from typing import Any

import sqlite3
import json
Expand All @@ -9,6 +10,7 @@
from cache.cache_error import CacheError
from models.cache_entry import CacheEntry
from models.config import SQLiteDatabaseConfiguration
from models.requests import Attachment
from models.responses import ConversationData, ReferencedDocument
from utils.connection_decorator import connection
from utils.types import ToolCallSummary, ToolResultSummary
Expand Down Expand Up @@ -37,6 +39,7 @@ class SQLiteCache(Cache):
referenced_documents | text | |
tool_calls | text | |
tool_results | text | |
attachments | text | |
Indexes:
"cache_pkey" PRIMARY KEY, btree (user_id, conversation_id, created_at)
"cache_key_key" UNIQUE CONSTRAINT, btree (key)
Expand All @@ -59,6 +62,7 @@ class SQLiteCache(Cache):
referenced_documents text,
tool_calls text,
tool_results text,
attachments text,
PRIMARY KEY(user_id, conversation_id, created_at)
);
"""
Expand All @@ -80,7 +84,7 @@ class SQLiteCache(Cache):

SELECT_CONVERSATION_HISTORY_STATEMENT = """
SELECT query, response, provider, model, started_at, completed_at,
referenced_documents, tool_calls, tool_results
referenced_documents, tool_calls, tool_results, attachments
FROM cache
WHERE user_id=? AND conversation_id=?
ORDER BY created_at
Expand All @@ -89,8 +93,8 @@ class SQLiteCache(Cache):
INSERT_CONVERSATION_HISTORY_STATEMENT = """
INSERT INTO cache(user_id, conversation_id, created_at, started_at, completed_at,
query, response, provider, model, referenced_documents,
tool_calls, tool_results)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
tool_calls, tool_results, attachments)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
"""

QUERY_CACHE_SIZE = """
Expand Down Expand Up @@ -268,6 +272,21 @@ def get( # pylint: disable=R0914
e,
)

attachments_json_str = conversation_entry[9]
attachments_obj = None
if attachments_json_str:
try:
attachments_data = json.loads(attachments_json_str)
attachments_obj = [
Attachment.model_validate(att) for att in attachments_data
]
except (json.JSONDecodeError, ValueError) as e:
logger.warning(
"Failed to deserialize attachments for conversation %s: %s",
conversation_id,
e,
)

cache_entry = CacheEntry(
query=conversation_entry[0],
response=conversation_entry[1],
Expand All @@ -278,11 +297,38 @@ def get( # pylint: disable=R0914
referenced_documents=docs_obj,
tool_calls=tool_calls_obj,
tool_results=tool_results_obj,
attachments=attachments_obj,
)
result.append(cache_entry)

return result

def _serialize_json_field(self, obj: Any, field_name: str) -> str | None:
"""Serialize a Pydantic model or list to JSON.

Args:
obj: The object or list to serialize.
field_name: The name of the field (for logging).

Returns:
JSON string or None if serialization fails.
"""
if not obj:
return None
try:
if isinstance(obj, list):
obj_as_dicts = [item.model_dump(mode="json") for item in obj]
else:
obj_as_dicts = obj.model_dump(mode="json")
return json.dumps(obj_as_dicts)
except (TypeError, ValueError) as e:
logger.warning(
"Failed to serialize %s: %s",
field_name,
e,
)
return None

@connection
def insert_or_append(
self,
Expand All @@ -307,49 +353,18 @@ def insert_or_append(
cursor = self.connection.cursor()
current_time = time()

referenced_documents_json = None
if cache_entry.referenced_documents:
try:
docs_as_dicts = [
doc.model_dump(mode="json")
for doc in cache_entry.referenced_documents
]
referenced_documents_json = json.dumps(docs_as_dicts)
except (TypeError, ValueError) as e:
logger.warning(
"Failed to serialize referenced_documents for "
"conversation %s: %s",
conversation_id,
e,
)

tool_calls_json = None
if cache_entry.tool_calls:
try:
tool_calls_as_dicts = [
tc.model_dump(mode="json") for tc in cache_entry.tool_calls
]
tool_calls_json = json.dumps(tool_calls_as_dicts)
except (TypeError, ValueError) as e:
logger.warning(
"Failed to serialize tool_calls for conversation %s: %s",
conversation_id,
e,
)

tool_results_json = None
if cache_entry.tool_results:
try:
tool_results_as_dicts = [
tr.model_dump(mode="json") for tr in cache_entry.tool_results
]
tool_results_json = json.dumps(tool_results_as_dicts)
except (TypeError, ValueError) as e:
logger.warning(
"Failed to serialize tool_results for conversation %s: %s",
conversation_id,
e,
)
referenced_documents_json = self._serialize_json_field(
cache_entry.referenced_documents, "referenced_documents"
)
tool_calls_json = self._serialize_json_field(
cache_entry.tool_calls, "tool_calls"
)
tool_results_json = self._serialize_json_field(
cache_entry.tool_results, "tool_results"
)
attachments_json = self._serialize_json_field(
cache_entry.attachments, "attachments"
)

cursor.execute(
self.INSERT_CONVERSATION_HISTORY_STATEMENT,
Expand All @@ -366,6 +381,7 @@ def insert_or_append(
referenced_documents_json,
tool_calls_json,
tool_results_json,
attachments_json,
),
)

Expand Down
5 changes: 5 additions & 0 deletions src/models/cache_entry.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
"""Model for conversation history cache entry."""

from typing import Optional

from pydantic import BaseModel

from models.requests import Attachment
from models.responses import ReferencedDocument
from utils.types import ToolCallSummary, ToolResultSummary

Expand All @@ -17,6 +20,7 @@ class CacheEntry(BaseModel):
referenced_documents: List of documents referenced by the response
tool_calls: List of tool calls made during response generation
tool_results: List of tool results from tool calls
attachments: List of attachments included with the user query
"""

query: str
Expand All @@ -28,3 +32,4 @@ class CacheEntry(BaseModel):
referenced_documents: Optional[list[ReferencedDocument]] = None
tool_calls: Optional[list[ToolCallSummary]] = None
tool_results: Optional[list[ToolResultSummary]] = None
attachments: list[Attachment] | None = None
Loading
Loading