Skip to content

Commit 5b1f948

Browse files
feat(kb-tags): natural language pre-filter tag system for knowledge base searches (#800)
* fix lint
* checkpoint
* works
* simplify
* checkpoint
* works
* fix lint
* checkpoint - create doc ui
* working block
* fix import conflicts
* fix tests
* add blockers to going past max tag slots
* remove console logs
* forgot a few
* Update apps/sim/tools/knowledge/search.ts
  Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com>
* remove console.warn
* Update apps/sim/hooks/use-tag-definitions.ts
  Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com>
* use tag slots consts in more places
* remove duplicate title
---------
Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com>
1 parent cb17691 commit 5b1f948

File tree

27 files changed

+8021
-404
lines changed

27 files changed

+8021
-404
lines changed

apps/sim/app/api/knowledge/[id]/documents/[documentId]/route.test.ts

Lines changed: 21 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -218,8 +218,15 @@ describe('Document By ID API Route', () => {
218218
}),
219219
}
220220

221+
// Mock transaction
222+
mockDbChain.transaction.mockImplementation(async (callback) => {
223+
const mockTx = {
224+
update: vi.fn().mockReturnValue(updateChain),
225+
}
226+
await callback(mockTx)
227+
})
228+
221229
// Mock db operations in sequence
222-
mockDbChain.update.mockReturnValue(updateChain)
223230
mockDbChain.select.mockReturnValue(selectChain)
224231

225232
const req = createMockRequest('PUT', validUpdateData)
@@ -231,7 +238,7 @@ describe('Document By ID API Route', () => {
231238
expect(data.success).toBe(true)
232239
expect(data.data.filename).toBe('updated-document.pdf')
233240
expect(data.data.enabled).toBe(false)
234-
expect(mockDbChain.update).toHaveBeenCalled()
241+
expect(mockDbChain.transaction).toHaveBeenCalled()
235242
expect(mockDbChain.select).toHaveBeenCalled()
236243
})
237244

@@ -298,8 +305,15 @@ describe('Document By ID API Route', () => {
298305
}),
299306
}
300307

308+
// Mock transaction
309+
mockDbChain.transaction.mockImplementation(async (callback) => {
310+
const mockTx = {
311+
update: vi.fn().mockReturnValue(updateChain),
312+
}
313+
await callback(mockTx)
314+
})
315+
301316
// Mock db operations in sequence
302-
mockDbChain.update.mockReturnValue(updateChain)
303317
mockDbChain.select.mockReturnValue(selectChain)
304318

305319
const req = createMockRequest('PUT', { markFailedDueToTimeout: true })
@@ -309,7 +323,7 @@ describe('Document By ID API Route', () => {
309323

310324
expect(response.status).toBe(200)
311325
expect(data.success).toBe(true)
312-
expect(mockDbChain.update).toHaveBeenCalled()
326+
expect(mockDbChain.transaction).toHaveBeenCalled()
313327
expect(updateChain.set).toHaveBeenCalledWith(
314328
expect.objectContaining({
315329
processingStatus: 'failed',
@@ -479,7 +493,9 @@ describe('Document By ID API Route', () => {
479493
document: mockDocument,
480494
knowledgeBase: { id: 'kb-123', userId: 'user-123' },
481495
})
482-
mockDbChain.set.mockRejectedValue(new Error('Database error'))
496+
497+
// Mock transaction to throw an error
498+
mockDbChain.transaction.mockRejectedValue(new Error('Database error'))
483499

484500
const req = createMockRequest('PUT', validUpdateData)
485501
const { PUT } = await import('@/app/api/knowledge/[id]/documents/[documentId]/route')

apps/sim/app/api/knowledge/[id]/documents/[documentId]/route.ts

Lines changed: 37 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -2,6 +2,7 @@ import { eq } from 'drizzle-orm'
22
import { type NextRequest, NextResponse } from 'next/server'
33
import { z } from 'zod'
44
import { getSession } from '@/lib/auth'
5+
import { TAG_SLOTS } from '@/lib/constants/knowledge'
56
import { createLogger } from '@/lib/logs/console/logger'
67

78
export const dynamic = 'force-dynamic'
@@ -26,6 +27,14 @@ const UpdateDocumentSchema = z.object({
2627
processingError: z.string().optional(),
2728
markFailedDueToTimeout: z.boolean().optional(),
2829
retryProcessing: z.boolean().optional(),
30+
// Tag fields
31+
tag1: z.string().optional(),
32+
tag2: z.string().optional(),
33+
tag3: z.string().optional(),
34+
tag4: z.string().optional(),
35+
tag5: z.string().optional(),
36+
tag6: z.string().optional(),
37+
tag7: z.string().optional(),
2938
})
3039

3140
export async function GET(
@@ -213,9 +222,36 @@ export async function PUT(
213222
updateData.processingStatus = validatedData.processingStatus
214223
if (validatedData.processingError !== undefined)
215224
updateData.processingError = validatedData.processingError
225+
226+
// Tag field updates
227+
TAG_SLOTS.forEach((slot) => {
228+
if ((validatedData as any)[slot] !== undefined) {
229+
;(updateData as any)[slot] = (validatedData as any)[slot]
230+
}
231+
})
216232
}
217233

218-
await db.update(document).set(updateData).where(eq(document.id, documentId))
234+
await db.transaction(async (tx) => {
235+
// Update the document
236+
await tx.update(document).set(updateData).where(eq(document.id, documentId))
237+
238+
// If any tag fields were updated, also update the embeddings
239+
const hasTagUpdates = TAG_SLOTS.some((field) => (validatedData as any)[field] !== undefined)
240+
241+
if (hasTagUpdates) {
242+
const embeddingUpdateData: Record<string, string | null> = {}
243+
TAG_SLOTS.forEach((field) => {
244+
if ((validatedData as any)[field] !== undefined) {
245+
embeddingUpdateData[field] = (validatedData as any)[field] || null
246+
}
247+
})
248+
249+
await tx
250+
.update(embedding)
251+
.set(embeddingUpdateData)
252+
.where(eq(embedding.documentId, documentId))
253+
}
254+
})
219255

220256
// Fetch the updated document
221257
const updatedDocument = await db

0 commit comments

Comments
 (0)