Skip to content

Commit 2f57d8a

Browse files
fix(kb-perms): search tool perms to use new system (#786)
* fix(kb): search tool perms * fix tests
1 parent af1c7dc commit 2f57d8a

File tree

3 files changed

+90
-40
lines changed

3 files changed

+90
-40
lines changed

apps/sim/app/api/knowledge/search/route.test.ts

Lines changed: 61 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,11 @@ vi.mock('@/providers/utils', () => ({
5151
}),
5252
}))
5353

54+
const mockCheckKnowledgeBaseAccess = vi.fn()
55+
vi.mock('@/app/api/knowledge/utils', () => ({
56+
checkKnowledgeBaseAccess: mockCheckKnowledgeBaseAccess,
57+
}))
58+
5459
mockConsoleLogger()
5560

5661
describe('Knowledge Search API Route', () => {
@@ -132,7 +137,11 @@ describe('Knowledge Search API Route', () => {
132137
it('should perform search successfully with single knowledge base', async () => {
133138
mockGetUserId.mockResolvedValue('user-123')
134139

135-
mockDbChain.where.mockResolvedValueOnce(mockKnowledgeBases)
140+
// Mock knowledge base access check to return success
141+
mockCheckKnowledgeBaseAccess.mockResolvedValue({
142+
hasAccess: true,
143+
knowledgeBase: mockKnowledgeBases[0],
144+
})
136145

137146
mockDbChain.limit.mockResolvedValueOnce(mockSearchResults)
138147

@@ -149,6 +158,10 @@ describe('Knowledge Search API Route', () => {
149158
const response = await POST(req)
150159
const data = await response.json()
151160

161+
if (response.status !== 200) {
162+
console.log('Test failed with response:', data)
163+
}
164+
152165
expect(response.status).toBe(200)
153166
expect(data.success).toBe(true)
154167
expect(data.data.results).toHaveLength(2)
@@ -171,7 +184,10 @@ describe('Knowledge Search API Route', () => {
171184

172185
mockGetUserId.mockResolvedValue('user-123')
173186

174-
mockDbChain.where.mockResolvedValueOnce(multiKbs)
187+
// Mock knowledge base access check to return success for both KBs
188+
mockCheckKnowledgeBaseAccess
189+
.mockResolvedValueOnce({ hasAccess: true, knowledgeBase: multiKbs[0] })
190+
.mockResolvedValueOnce({ hasAccess: true, knowledgeBase: multiKbs[1] })
175191

176192
mockDbChain.limit.mockResolvedValueOnce(mockSearchResults)
177193

@@ -201,9 +217,13 @@ describe('Knowledge Search API Route', () => {
201217

202218
mockGetUserId.mockResolvedValue('user-123')
203219

204-
mockDbChain.where.mockResolvedValueOnce(mockKnowledgeBases) // First call: get knowledge bases
220+
// Mock knowledge base access check to return success
221+
mockCheckKnowledgeBaseAccess.mockResolvedValue({
222+
hasAccess: true,
223+
knowledgeBase: mockKnowledgeBases[0],
224+
})
205225

206-
mockDbChain.limit.mockResolvedValueOnce(mockSearchResults) // Second call: search results
226+
mockDbChain.limit.mockResolvedValueOnce(mockSearchResults) // Search results
207227

208228
mockFetch.mockResolvedValue({
209229
ok: true,
@@ -255,7 +275,11 @@ describe('Knowledge Search API Route', () => {
255275
it('should return not found for non-existent knowledge base', async () => {
256276
mockGetUserId.mockResolvedValue('user-123')
257277

258-
mockDbChain.where.mockResolvedValueOnce([]) // No knowledge bases found
278+
// Mock knowledge base access check to return no access
279+
mockCheckKnowledgeBaseAccess.mockResolvedValue({
280+
hasAccess: false,
281+
notFound: true,
282+
})
259283

260284
const req = createMockRequest('POST', validSearchData)
261285
const { POST } = await import('./route')
@@ -274,15 +298,18 @@ describe('Knowledge Search API Route', () => {
274298

275299
mockGetUserId.mockResolvedValue('user-123')
276300

277-
mockDbChain.where.mockResolvedValueOnce(mockKnowledgeBases) // Only kb-123 found
301+
// Mock access check: first KB has access, second doesn't
302+
mockCheckKnowledgeBaseAccess
303+
.mockResolvedValueOnce({ hasAccess: true, knowledgeBase: mockKnowledgeBases[0] })
304+
.mockResolvedValueOnce({ hasAccess: false, notFound: true })
278305

279306
const req = createMockRequest('POST', multiKbData)
280307
const { POST } = await import('./route')
281308
const response = await POST(req)
282309
const data = await response.json()
283310

284311
expect(response.status).toBe(404)
285-
expect(data.error).toBe('Knowledge bases not found: kb-missing')
312+
expect(data.error).toBe('Knowledge bases not found or access denied: kb-missing')
286313
})
287314

288315
it.concurrent('should validate search parameters', async () => {
@@ -310,9 +337,13 @@ describe('Knowledge Search API Route', () => {
310337

311338
mockGetUserId.mockResolvedValue('user-123')
312339

313-
mockDbChain.where.mockResolvedValueOnce(mockKnowledgeBases) // First call: get knowledge bases
340+
// Mock knowledge base access check to return success
341+
mockCheckKnowledgeBaseAccess.mockResolvedValue({
342+
hasAccess: true,
343+
knowledgeBase: mockKnowledgeBases[0],
344+
})
314345

315-
mockDbChain.limit.mockResolvedValueOnce(mockSearchResults) // Second call: search results
346+
mockDbChain.limit.mockResolvedValueOnce(mockSearchResults) // Search results
316347

317348
mockFetch.mockResolvedValue({
318349
ok: true,
@@ -416,7 +447,13 @@ describe('Knowledge Search API Route', () => {
416447
describe('Cost tracking', () => {
417448
it.concurrent('should include cost information in successful search response', async () => {
418449
mockGetUserId.mockResolvedValue('user-123')
419-
mockDbChain.where.mockResolvedValueOnce(mockKnowledgeBases)
450+
451+
// Mock knowledge base access check to return success
452+
mockCheckKnowledgeBaseAccess.mockResolvedValue({
453+
hasAccess: true,
454+
knowledgeBase: mockKnowledgeBases[0],
455+
})
456+
420457
mockDbChain.limit.mockResolvedValueOnce(mockSearchResults)
421458

422459
mockFetch.mockResolvedValue({
@@ -458,7 +495,13 @@ describe('Knowledge Search API Route', () => {
458495
const { calculateCost } = await import('@/providers/utils')
459496

460497
mockGetUserId.mockResolvedValue('user-123')
461-
mockDbChain.where.mockResolvedValueOnce(mockKnowledgeBases)
498+
499+
// Mock knowledge base access check to return success
500+
mockCheckKnowledgeBaseAccess.mockResolvedValue({
501+
hasAccess: true,
502+
knowledgeBase: mockKnowledgeBases[0],
503+
})
504+
462505
mockDbChain.limit.mockResolvedValueOnce(mockSearchResults)
463506

464507
mockFetch.mockResolvedValue({
@@ -509,7 +552,13 @@ describe('Knowledge Search API Route', () => {
509552
}
510553

511554
mockGetUserId.mockResolvedValue('user-123')
512-
mockDbChain.where.mockResolvedValueOnce(mockKnowledgeBases)
555+
556+
// Mock knowledge base access check to return success
557+
mockCheckKnowledgeBaseAccess.mockResolvedValue({
558+
hasAccess: true,
559+
knowledgeBase: mockKnowledgeBases[0],
560+
})
561+
513562
mockDbChain.limit.mockResolvedValueOnce(mockSearchResults)
514563

515564
mockFetch.mockResolvedValue({

apps/sim/app/api/knowledge/search/route.ts

Lines changed: 26 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,14 @@
1-
import { and, eq, inArray, isNull, sql } from 'drizzle-orm'
1+
import { and, eq, inArray, sql } from 'drizzle-orm'
22
import { type NextRequest, NextResponse } from 'next/server'
33
import { z } from 'zod'
44
import { retryWithExponentialBackoff } from '@/lib/documents/utils'
55
import { env } from '@/lib/env'
66
import { createLogger } from '@/lib/logs/console-logger'
77
import { estimateTokenCount } from '@/lib/tokenization/estimators'
88
import { getUserId } from '@/app/api/auth/oauth/utils'
9+
import { checkKnowledgeBaseAccess } from '@/app/api/knowledge/utils'
910
import { db } from '@/db'
10-
import { embedding, knowledgeBase } from '@/db/schema'
11+
import { embedding } from '@/db/schema'
1112
import { calculateCost } from '@/providers/utils'
1213

1314
const logger = createLogger('VectorSearchAPI')
@@ -261,47 +262,45 @@ export async function POST(request: NextRequest) {
261262
? validatedData.knowledgeBaseIds
262263
: [validatedData.knowledgeBaseIds]
263264

264-
const [kb, queryEmbedding] = await Promise.all([
265-
db
266-
.select()
267-
.from(knowledgeBase)
268-
.where(
269-
and(
270-
inArray(knowledgeBase.id, knowledgeBaseIds),
271-
eq(knowledgeBase.userId, userId),
272-
isNull(knowledgeBase.deletedAt)
273-
)
274-
),
275-
generateSearchEmbedding(validatedData.query),
276-
])
277-
278-
if (kb.length === 0) {
265+
// Check access permissions for each knowledge base using proper workspace-based permissions
266+
const accessibleKbIds: string[] = []
267+
for (const kbId of knowledgeBaseIds) {
268+
const accessCheck = await checkKnowledgeBaseAccess(kbId, userId)
269+
if (accessCheck.hasAccess) {
270+
accessibleKbIds.push(kbId)
271+
}
272+
}
273+
274+
if (accessibleKbIds.length === 0) {
279275
return NextResponse.json(
280276
{ error: 'Knowledge base not found or access denied' },
281277
{ status: 404 }
282278
)
283279
}
284280

285-
const foundKbIds = kb.map((k) => k.id)
286-
const missingKbIds = knowledgeBaseIds.filter((id) => !foundKbIds.includes(id))
281+
// Generate query embedding in parallel with access checks
282+
const queryEmbedding = await generateSearchEmbedding(validatedData.query)
283+
284+
// Check if any requested knowledge bases were not accessible
285+
const inaccessibleKbIds = knowledgeBaseIds.filter((id) => !accessibleKbIds.includes(id))
287286

288-
if (missingKbIds.length > 0) {
287+
if (inaccessibleKbIds.length > 0) {
289288
return NextResponse.json(
290-
{ error: `Knowledge bases not found: ${missingKbIds.join(', ')}` },
289+
{ error: `Knowledge bases not found or access denied: ${inaccessibleKbIds.join(', ')}` },
291290
{ status: 404 }
292291
)
293292
}
294293

295-
// Adaptive query strategy based on KB count and parameters
296-
const strategy = getQueryStrategy(foundKbIds.length, validatedData.topK)
294+
// Adaptive query strategy based on accessible KB count and parameters
295+
const strategy = getQueryStrategy(accessibleKbIds.length, validatedData.topK)
297296
const queryVector = JSON.stringify(queryEmbedding)
298297

299298
let results: any[]
300299

301300
if (strategy.useParallel) {
302301
// Execute parallel queries for better performance with many KBs
303302
const parallelResults = await executeParallelQueries(
304-
foundKbIds,
303+
accessibleKbIds,
305304
queryVector,
306305
validatedData.topK,
307306
strategy.distanceThreshold,
@@ -311,7 +310,7 @@ export async function POST(request: NextRequest) {
311310
} else {
312311
// Execute single optimized query for fewer KBs
313312
results = await executeSingleQuery(
314-
foundKbIds,
313+
accessibleKbIds,
315314
queryVector,
316315
validatedData.topK,
317316
strategy.distanceThreshold,
@@ -350,8 +349,8 @@ export async function POST(request: NextRequest) {
350349
similarity: 1 - result.distance,
351350
})),
352351
query: validatedData.query,
353-
knowledgeBaseIds: foundKbIds,
354-
knowledgeBaseId: foundKbIds[0],
352+
knowledgeBaseIds: accessibleKbIds,
353+
knowledgeBaseId: accessibleKbIds[0],
355354
topK: validatedData.topK,
356355
totalResults: results.length,
357356
...(cost && tokenCount

apps/sim/executor/handlers/workflow/workflow-handler.ts

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -93,7 +93,9 @@ export class WorkflowBlockHandler implements BlockHandler {
9393
})
9494

9595
const startTime = performance.now()
96-
const result = await subExecutor.execute(executionId)
96+
// Use the actual child workflow ID for authentication, not the execution ID
97+
// This ensures knowledge base and other API calls can properly authenticate
98+
const result = await subExecutor.execute(workflowId)
9799
const duration = performance.now() - startTime
98100

99101
// Remove current execution from stack after completion

0 commit comments

Comments
 (0)