Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 29 additions & 5 deletions src/sqlite-ai.c
Original file line number Diff line number Diff line change
Expand Up @@ -1957,20 +1957,44 @@ static void llm_chat_save (sqlite3_context *context, int argc, sqlite3_value **a
// start transaction
sqlite_db_write_simple(context, db, "BEGIN;");

// save chat
const char *sql = "INSERT INTO ai_chat_history (uuid, title, metadata) VALUES (?, ?, ?);";
// save chat, the ON CONFLICT allows saving multiple times
const char *sql = "INSERT INTO ai_chat_history (uuid, title, metadata) VALUES (?, ?, ?) "
"ON CONFLICT(uuid) DO UPDATE SET "
" title = excluded.title, "
" metadata = excluded.metadata, "
" created_at = CURRENT_TIMESTAMP;";
const char *values[] = {ai->chat.uuid, title, meta};
int types[] = {SQLITE_TEXT, SQLITE_TEXT, SQLITE_TEXT};
int lens[] = {-1, -1, -1};

int rc = sqlite_db_write(context, db, sql, values, types, lens, 3);
if (rc != SQLITE_OK) goto abort_save;
// loop to save messages (the context)

// get the rowid, cannot use sqlite3_last_insert_rowid for the CONFLICT case
char rowid_s[256];
sqlite3_int64 rowid = sqlite3_last_insert_rowid(db);
sqlite3_stmt *pstmt = NULL;
sql = "SELECT id FROM ai_chat_history WHERE uuid = ?;";
rc = sqlite3_prepare_v2(db, sql, -1, &pstmt, NULL);
if (rc != SQLITE_OK) goto abort_save;
rc = sqlite3_bind_text(pstmt, 1, ai->chat.uuid, -1, SQLITE_STATIC);
rc = sqlite3_step(pstmt);
if (rc != SQLITE_ROW) {
sqlite3_finalize(pstmt);
goto abort_save;
}
sqlite3_int64 rowid = sqlite3_column_int64(pstmt, 0);
sqlite3_finalize(pstmt);
snprintf(rowid_s, sizeof(rowid_s), "%lld", (long long)rowid);

// delete all messages for this chat id, if any
sql = "DELETE FROM ai_chat_messages WHERE chat_id = ?;";
const char *values3[] = {rowid_s};
int types3[] = {SQLITE_INTEGER};
int lens3[] = {-1};
rc = sqlite_db_write(context, db, sql, values3, types3, lens3, 1);
if (rc != SQLITE_OK) goto abort_save;

// loop to save messages (the context)
sql = "INSERT INTO ai_chat_messages (chat_id, role, content) VALUES (?, ?, ?);";
int types2[] = {SQLITE_INTEGER, SQLITE_TEXT, SQLITE_TEXT};

Expand Down
2 changes: 1 addition & 1 deletion src/sqlite-ai.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
extern "C" {
#endif

#define SQLITE_AI_VERSION "0.7.58"
#define SQLITE_AI_VERSION "0.7.59"

SQLITE_AI_API int sqlite3_ai_init (sqlite3 *db, char **pzErrMsg, const sqlite3_api_routines *pApi);

Expand Down
115 changes: 115 additions & 0 deletions tests/c/unittest.c
Original file line number Diff line number Diff line change
Expand Up @@ -1008,6 +1008,120 @@ static int test_chat_system_prompt_after_first_response(const test_env *env) {
return status;
}

static int test_llm_chat_double_save(const test_env *env) {
sqlite3 *db = NULL;
bool model_loaded = false;
bool context_created = false;
bool chat_created = false;
int status = 1;

if (open_db_and_load(env, &db) != SQLITE_OK) {
goto done;
}

const char *model = env->model_path ? env->model_path : DEFAULT_MODEL_PATH;
char sqlbuf[512];
snprintf(sqlbuf, sizeof(sqlbuf), "SELECT llm_model_load('%s');", model);
if (exec_expect_ok(env, db, sqlbuf) != 0)
goto done;
model_loaded = true;

if (exec_expect_ok(env, db,
"SELECT llm_context_create('context_size=1000');") != 0)
goto done;
context_created = true;

if (exec_expect_ok(env, db, "SELECT llm_chat_create();") != 0)
goto done;
chat_created = true;

// First prompt
const char *prompt1 = "First prompt";
if (exec_expect_ok(env, db, "SELECT llm_chat_respond('First prompt');") != 0)
goto done;

// First save
if (exec_expect_ok(env, db, "SELECT llm_chat_save();") != 0)
goto done;

// Second prompt
const char *prompt2 = "Second prompt";
if (exec_expect_ok(env, db, "SELECT llm_chat_respond('Second prompt');") != 0)
goto done;

// Second save
if (exec_expect_ok(env, db, "SELECT llm_chat_save();") != 0)
goto done;

ai_chat_message_row rows[8];
int count = 0;
// We expect 4 messages: User1, Assistant1, User2, Assistant2
if (fetch_ai_chat_messages(env, db, rows, 8, &count) != 0)
goto done;

if (count != 5) {
fprintf(stderr,
"[test_llm_chat_double_save] expected 4 message rows, got %d\n",
count);
goto done;
}

// Verify order and roles
if (strcmp(rows[0].role, "system") != 0 ||
strcmp(rows[0].content, "") != 0) {
fprintf(stderr,
"[test_llm_chat_double_save] row 0 mismatch (expected system/'%s', "
"got %s/'%s')\n",
"", rows[0].role, rows[0].content);
goto done;
}
if (strcmp(rows[1].role, "user") != 0 ||
strcmp(rows[1].content, prompt1) != 0) {
fprintf(stderr,
"[test_llm_chat_double_save] row 0 mismatch (expected user/'%s', "
"got %s/'%s')\n",
prompt1, rows[1].role, rows[1].content);
goto done;
}
if (strcmp(rows[2].role, "assistant") != 0) {
fprintf(stderr,
"[test_llm_chat_double_save] row 1 mismatch (expected assistant, "
"got %s)\n",
rows[2].role);
goto done;
}
if (strcmp(rows[3].role, "user") != 0 ||
strcmp(rows[3].content, prompt2) != 0) {
fprintf(stderr,
"[test_llm_chat_double_save] row 2 mismatch (expected user/'%s', "
"got %s/'%s')\n",
prompt2, rows[3].role, rows[3].content);
goto done;
}
if (strcmp(rows[4].role, "assistant") != 0) {
fprintf(stderr,
"[test_llm_chat_double_save] row 3 mismatch (expected assistant, "
"got %s)\n",
rows[4].role);
goto done;
}

status = 0;

done:
if (chat_created)
exec_expect_ok(env, db, "SELECT llm_chat_free();");
if (context_created)
exec_expect_ok(env, db, "SELECT llm_context_free();");
if (model_loaded)
exec_expect_ok(env, db, "SELECT llm_model_free();");
if (db)
sqlite3_close(db);
if (status == 0)
status = assert_sqlite_memory_clean("llm_chat_double_save", env);
return status;
}

static const test_case TESTS[] = {
{"issue15_llm_chat_without_context", test_issue15_chat_without_context},
{"llm_chat_respond_repeated", test_llm_chat_respond_repeated},
Expand All @@ -1026,6 +1140,7 @@ static const test_case TESTS[] = {
{"chat_system_prompt_new_chat", test_chat_system_prompt_new_chat},
{"chat_system_prompt_replace_previous_prompt", test_chat_system_prompt_replace_previous_prompt},
{"chat_system_prompt_after_first_response", test_chat_system_prompt_after_first_response},
{"llm_chat_double_save", test_llm_chat_double_save},
};

int main(int argc, char **argv) {
Expand Down
Loading