From e7142321947fbdaf599e2b1472dfc915aeb0668d Mon Sep 17 00:00:00 2001 From: davidw-philips Date: Tue, 25 Nov 2025 10:03:13 +0000 Subject: [PATCH] Added in-memory locks for the process --- .../__tests__/lock-verification.test.ts | 303 ++++++++++++++++++ src/memory/__tests__/race-condition.test.ts | 213 ++++++++++++ src/memory/index.ts | 213 +++++++----- 3 files changed, 645 insertions(+), 84 deletions(-) create mode 100644 src/memory/__tests__/lock-verification.test.ts create mode 100644 src/memory/__tests__/race-condition.test.ts diff --git a/src/memory/__tests__/lock-verification.test.ts b/src/memory/__tests__/lock-verification.test.ts new file mode 100644 index 0000000000..6080334ef0 --- /dev/null +++ b/src/memory/__tests__/lock-verification.test.ts @@ -0,0 +1,303 @@ +import { describe, it, expect, beforeEach, afterEach } from 'vitest'; +import { promises as fs } from 'fs'; +import path from 'path'; +import { fileURLToPath } from 'url'; +import { KnowledgeGraphManager, Entity, Relation, KnowledgeGraph } from '../index.js'; + +/** + * This test suite verifies that the locking mechanism correctly prevents race conditions. + * It demonstrates that all concurrent operations complete successfully and maintain + * data integrity by preserving all writes without corruption. + * + * The fix uses an in-memory lock manager that serializes file operations, + * ensuring atomic read-modify-write cycles. + */ + +describe('Lock Mechanism Verification', () => { + let testFilePath: string; + + beforeEach(async () => { + testFilePath = path.join( + path.dirname(fileURLToPath(import.meta.url)), + `lock-verify-${Date.now()}.jsonl` + ); + }); + + afterEach(async () => { + try { + await fs.unlink(testFilePath); + } catch (error) { + // Ignore if file doesn't exist + } + }); + + it('should serialize concurrent writes to prevent corruption', async () => { + const manager = new KnowledgeGraphManager(testFilePath); + + // Define test data for multiple concurrent operations + const entity1: Entity = { + name: 'Process1_Entity', + entityType: 'Actor', + observations: ['created by process 1'] + }; + + const entity2: Entity = { + name: 'Process2_Entity', + entityType: 'Actor', + observations: ['created by process 2'] + }; + + // Execute concurrent entity creation + const [result1, result2] = await Promise.all([ + manager.createEntities([entity1]), + manager.createEntities([entity2]) + ]); + + // Both operations should succeed + expect(result1).toHaveLength(1); + expect(result2).toHaveLength(1); + + // Read the final graph + const finalGraph = await manager.readGraph(); + + // Verify both entities exist and no data was lost + expect(finalGraph.entities).toHaveLength(2); + const entityNames = finalGraph.entities.map(e => e.name); + expect(entityNames).toContain('Process1_Entity'); + expect(entityNames).toContain('Process2_Entity'); + + // Verify entity types and observations + const foundEntity1 = finalGraph.entities.find(e => e.name === 'Process1_Entity'); + const foundEntity2 = finalGraph.entities.find(e => e.name === 'Process2_Entity'); + expect(foundEntity1?.entityType).toBe('Actor'); + expect(foundEntity2?.entityType).toBe('Actor'); + expect(foundEntity1?.observations).toContain('created by process 1'); + expect(foundEntity2?.observations).toContain('created by process 2'); + }); + + it('should handle river-crossing scenario with locked writes', async () => { + /** + * Classic river-crossing puzzle scenario: + * - A man, goat, cabbage, and wolf need to cross a river + * - The boat can only carry the man and one other item + * - The goat and cabbage cannot be left alone + * - The goat and wolf cannot be left alone + */ + + const manager = new KnowledgeGraphManager(testFilePath); + + // Create entities concurrently + const entities: Entity[] = [ + { name: 'Human', entityType: 'Actor', observations: ['controls the boat'] }, + { name: 'Goat', entityType: 'Actor', observations: ['eats cabbage'] }, + { name: 'Cabbage', entityType: 'Item', observations: ['eaten by goat'] }, + { name: 'Wolf', entityType: 'Actor', observations: ['eats goat'] }, + { name: 'Start_Bank', entityType: 'Location', observations: ['initial position'] }, + { name: 'End_Bank', entityType: 'Location', observations: ['final destination'] } + ]; + + const entityResults = await Promise.all( + entities.map(entity => manager.createEntities([entity])) + ); + + // Verify all entities were created + entityResults.forEach(result => { + expect(result).toHaveLength(1); + }); + + // Create initial state relations + const initialRelations: Relation[] = [ + { from: 'Human', to: 'Start_Bank', relationType: 'is_at' }, + { from: 'Goat', to: 'Start_Bank', relationType: 'is_at' }, + { from: 'Cabbage', to: 'Start_Bank', relationType: 'is_at' }, + { from: 'Wolf', to: 'Start_Bank', relationType: 'is_at' } + ]; + + const relationResults = await Promise.all( + initialRelations.map(relation => manager.createRelations([relation])) + ); + + // Verify all relations were created + relationResults.forEach(result => { + expect(result).toHaveLength(1); + }); + + // Verify complete initial state + const initialState = await manager.readGraph(); + expect(initialState.entities).toHaveLength(6); + expect(initialState.relations).toHaveLength(4); + + // Verify all items are at start bank + const startBankItems = initialState.relations + .filter(r => r.to === 'Start_Bank') + .map(r => r.from); + expect(startBankItems).toContain('Human'); + expect(startBankItems).toContain('Goat'); + expect(startBankItems).toContain('Cabbage'); + expect(startBankItems).toContain('Wolf'); + }); + + it('should maintain JSONL format with concurrent writes', async () => { + const manager = new KnowledgeGraphManager(testFilePath); + + // Create multiple entities and relations concurrently + const createOps = [ + manager.createEntities([ + { name: 'Entity_A', entityType: 'Type_A', observations: ['obs_a'] }, + { name: 'Entity_B', entityType: 'Type_B', observations: ['obs_b'] } + ]), + manager.createEntities([ + { name: 'Entity_C', entityType: 'Type_C', observations: ['obs_c'] } + ]) + ]; + + await Promise.all(createOps); + + // Read the raw file + const rawContent = await fs.readFile(testFilePath, 'utf-8'); + const lines = rawContent.split('\n').filter(line => line.trim() !== ''); + + // Verify JSONL format: each non-empty line should be valid JSON + lines.forEach(line => { + expect(() => JSON.parse(line)).not.toThrow(); + }); + + // Verify content can be parsed + const graph = await manager.readGraph(); + expect(graph.entities.length).toBeGreaterThan(0); + }); + + it('should prevent file corruption with rapid consecutive operations', async () => { + const manager = new KnowledgeGraphManager(testFilePath); + + // First, create some base entities + await manager.createEntities([ + { name: 'Base_1', entityType: 'Base', observations: [] }, + { name: 'Base_2', entityType: 'Base', observations: [] } + ]); + + // Perform rapid sequential operations that internally use concurrency + const operations = []; + for (let i = 0; i < 5; i++) { + operations.push( + manager.addObservations([ + { entityName: 'Base_1', contents: [`observation_${i}`] } + ]) + ); + } + + await Promise.all(operations); + + // Verify the final state + const graph = await manager.readGraph(); + const baseEntity = graph.entities.find(e => e.name === 'Base_1'); + + expect(baseEntity).toBeDefined(); + expect(baseEntity?.observations.length).toBe(5); + expect(baseEntity?.observations).toContain('observation_0'); + expect(baseEntity?.observations).toContain('observation_4'); + }); + + it('should correctly parse JSONL file after multiple concurrent operations', async () => { + const manager = new KnowledgeGraphManager(testFilePath); + + // Perform many concurrent operations + const concurrentOps = []; + for (let i = 0; i < 3; i++) { + concurrentOps.push( + manager.createEntities([ + { + name: `Entity_${i}`, + entityType: `Type_${i}`, + observations: [`observation_${i}`] + } + ]) + ); + } + + await Promise.all(concurrentOps); + + // Read and parse the JSONL file + const rawContent = await fs.readFile(testFilePath, 'utf-8'); + const lines = rawContent.split('\n').filter(line => line.trim() !== ''); + + // Manually parse JSONL to verify structure + const parsedLines = lines.map((line, index) => { + try { + return { data: JSON.parse(line), valid: true }; + } catch (e) { + return { data: null, valid: false, error: e, line: index, content: line }; + } + }); + + // All lines should be valid JSON + parsedLines.forEach(parsed => { + expect(parsed.valid).toBe(true); + expect(parsed.data).toHaveProperty('type'); + }); + + // Verify the data can be reconstructed + const graph = await manager.readGraph(); + expect(graph.entities).toHaveLength(3); + expect(graph.relations).toHaveLength(0); + }); + + it('should demonstrate the fix prevents "Unexpected non-whitespace character" errors', async () => { + /** + * This test specifically verifies that the fix prevents the original bug: + * "Unexpected non-whitespace character after JSON" + * + * This error occurred when the file was partially written during a race condition, + * resulting in truncated JSON on a line, followed by additional text from + * another process's write. + */ + + const manager = new KnowledgeGraphManager(testFilePath); + + // Simulate the exact scenario that would have caused the error: + // Multiple concurrent writes that would have previously caused file corruption + const operations = []; + const numConcurrentOps = 10; + + for (let i = 0; i < numConcurrentOps; i++) { + operations.push( + manager.createEntities([ + { + name: `ConcurrentEntity_${i}`, + entityType: 'TestType', + observations: [`concurrent_observation_${i}`] + } + ]) + ); + } + + // Execute all operations concurrently + const results = await Promise.all(operations); + + // All operations should succeed without errors + expect(results).toHaveLength(numConcurrentOps); + results.forEach(result => { + expect(result).toHaveLength(1); + }); + + // Read the file and verify it can be parsed without JSON errors + const rawContent = await fs.readFile(testFilePath, 'utf-8'); + const lines = rawContent.split('\n').filter(line => line.trim() !== ''); + + // Should not throw "Unexpected non-whitespace character after JSON" + let parseErrors = 0; + lines.forEach(line => { + try { + JSON.parse(line); + } catch (e) { + parseErrors++; + } + }); + + expect(parseErrors).toBe(0); + + // Verify all entities were created and can be read + const graph = await manager.readGraph(); + expect(graph.entities).toHaveLength(numConcurrentOps); + }); +}); diff --git a/src/memory/__tests__/race-condition.test.ts b/src/memory/__tests__/race-condition.test.ts new file mode 100644 index 0000000000..a43daeec9e --- /dev/null +++ b/src/memory/__tests__/race-condition.test.ts @@ -0,0 +1,213 @@ +import { describe, it, expect, beforeEach, afterEach } from 'vitest'; +import { promises as fs } from 'fs'; +import path from 'path'; +import { fileURLToPath } from 'url'; +import { KnowledgeGraphManager, Entity, Relation } from '../index.js'; +import { Worker } from 'worker_threads'; + +/** + * This test suite demonstrates concurrent access to the knowledge graph + * and verifies that the locking mechanism prevents race conditions. + * + * The race condition occurs when multiple processes/threads attempt to + * read, modify, and write to the same file simultaneously without synchronization. + */ + +describe('Race Condition Prevention', () => { + let testFilePath: string; + + beforeEach(async () => { + testFilePath = path.join( + path.dirname(fileURLToPath(import.meta.url)), + `race-test-${Date.now()}.jsonl` + ); + }); + + afterEach(async () => { + try { + await fs.unlink(testFilePath); + } catch (error) { + // Ignore if file doesn't exist + } + }); + + it('should handle concurrent entity creation without data loss', async () => { + const manager = new KnowledgeGraphManager(testFilePath); + + // Simulate concurrent operations + const operations = [ + manager.createEntities([ + { name: 'Human', entityType: 'Actor', observations: ['is_actor'] } + ]), + manager.createEntities([ + { name: 'Goat', entityType: 'Actor', observations: ['is_actor'] } + ]), + manager.createEntities([ + { name: 'Cabbage', entityType: 'Item', observations: ['is_item'] } + ]) + ]; + + // Wait for all operations to complete + const results = await Promise.all(operations); + + // Verify all entities were created successfully + expect(results[0]).toHaveLength(1); + expect(results[1]).toHaveLength(1); + expect(results[2]).toHaveLength(1); + + // Verify final state + const graph = await manager.readGraph(); + expect(graph.entities).toHaveLength(3); + expect(graph.entities.map(e => e.name)).toContain('Human'); + expect(graph.entities.map(e => e.name)).toContain('Goat'); + expect(graph.entities.map(e => e.name)).toContain('Cabbage'); + }); + + it('should handle concurrent relation creation without data loss', async () => { + const manager = new KnowledgeGraphManager(testFilePath); + + // Create entities first + await manager.createEntities([ + { name: 'Human', entityType: 'Actor', observations: [] }, + { name: 'Goat', entityType: 'Actor', observations: [] }, + { name: 'Cabbage', entityType: 'Item', observations: [] }, + { name: 'Start_Bank', entityType: 'Location', observations: [] }, + { name: 'End_Bank', entityType: 'Location', observations: [] } + ]); + + // Simulate concurrent relation creation + const operations = [ + manager.createRelations([ + { from: 'Human', to: 'Start_Bank', relationType: 'is_at' } + ]), + manager.createRelations([ + { from: 'Goat', to: 'Start_Bank', relationType: 'is_at' } + ]), + manager.createRelations([ + { from: 'Cabbage', to: 'Start_Bank', relationType: 'is_at' } + ]), + manager.createRelations([ + { from: 'Human', to: 'Goat', relationType: 'can_take' } + ]), + manager.createRelations([ + { from: 'Human', to: 'Cabbage', relationType: 'can_take' } + ]) + ]; + + const results = await Promise.all(operations); + + // Verify all relations were created + results.forEach(result => { + expect(result).toHaveLength(1); + }); + + // Verify final state + const graph = await manager.readGraph(); + expect(graph.relations).toHaveLength(5); + }); + + it('should handle mixed concurrent operations (create + read)', async () => { + const manager = new KnowledgeGraphManager(testFilePath); + + // Mix create and read operations + const operations = [ + manager.createEntities([ + { name: 'Entity1', entityType: 'Type1', observations: ['obs1'] } + ]), + manager.readGraph(), + manager.createEntities([ + { name: 'Entity2', entityType: 'Type2', observations: ['obs2'] } + ]), + manager.readGraph(), + manager.createEntities([ + { name: 'Entity3', entityType: 'Type3', observations: ['obs3'] } + ]) + ]; + + const results = await Promise.all(operations); + + // Check that reads returned valid graphs + const readResults = [results[1], results[3]]; + readResults.forEach((graph: any) => { + expect(graph).toHaveProperty('entities'); + expect(graph).toHaveProperty('relations'); + expect(Array.isArray(graph.entities)).toBe(true); + }); + + // Verify final state contains all created entities + const finalGraph = await manager.readGraph(); + expect(finalGraph.entities).toHaveLength(3); + }); + + it('should handle concurrent observations addition without data loss', async () => { + const manager = new KnowledgeGraphManager(testFilePath); + + // Create an entity first + await manager.createEntities([ + { name: 'TestEntity', entityType: 'TestType', observations: ['initial_obs'] } + ]); + + // Add observations concurrently + const operations = [ + manager.addObservations([ + { entityName: 'TestEntity', contents: ['observation_1'] } + ]), + manager.addObservations([ + { entityName: 'TestEntity', contents: ['observation_2'] } + ]), + manager.addObservations([ + { entityName: 'TestEntity', contents: ['observation_3'] } + ]) + ]; + + await Promise.all(operations); + + // Verify all observations were added + const graph = await manager.readGraph(); + const entity = graph.entities.find(e => e.name === 'TestEntity'); + expect(entity).toBeDefined(); + expect(entity?.observations).toContain('initial_obs'); + expect(entity?.observations).toContain('observation_1'); + expect(entity?.observations).toContain('observation_2'); + expect(entity?.observations).toContain('observation_3'); + }); + + it('should maintain data integrity under high concurrency', async () => { + const manager = new KnowledgeGraphManager(testFilePath); + + // Create base entities + const baseEntities: Entity[] = Array.from({ length: 5 }, (_, i) => ({ + name: `Entity_${i}`, + entityType: `Type_${i % 2}`, + observations: [`obs_${i}`] + })); + + await manager.createEntities(baseEntities); + + // Create many concurrent operations + const operations = []; + for (let i = 0; i < 10; i++) { + operations.push( + manager.addObservations([ + { entityName: `Entity_${i % 5}`, contents: [`concurrent_obs_${i}`] } + ]) + ); + } + + await Promise.all(operations); + + // Verify data integrity + const graph = await manager.readGraph(); + expect(graph.entities).toHaveLength(5); + + let totalObservations = 0; + graph.entities.forEach(entity => { + totalObservations += entity.observations.length; + // Each entity should have at least 1 initial observation + some concurrent ones + expect(entity.observations.length).toBeGreaterThanOrEqual(1); + }); + + // Should have 5 initial + 10 concurrent observations + expect(totalObservations).toBe(15); + }); +}); diff --git a/src/memory/index.ts b/src/memory/index.ts index c7d781d2c4..6ceeb73799 100644 --- a/src/memory/index.ts +++ b/src/memory/index.ts @@ -64,6 +64,29 @@ export interface KnowledgeGraph { relations: Relation[]; } +// Simple in-memory lock manager to prevent race conditions +// Uses a Promise-based queue to serialize file operations +class LockManager { + private locks: Map> = new Map(); + + async withLock(key: string, operation: () => Promise): Promise { + // Get the current lock promise for this key, or start with a resolved promise + const currentLock = this.locks.get(key) ?? Promise.resolve(); + + // Create a new lock that waits for the current lock to complete before starting + const newLock = currentLock.then(() => operation()); + + // Store the lock promise (without the operation result) for future operations + // This ensures the next operation waits for this one to complete + this.locks.set(key, newLock.then(() => undefined)); + + // Return the result of the operation + return newLock; + } +} + +const lockManager = new LockManager(); + // The KnowledgeGraphManager class contains all operations to interact with the knowledge graph export class KnowledgeGraphManager { constructor(private memoryFilePath: string) {} @@ -104,120 +127,142 @@ export class KnowledgeGraphManager { await fs.writeFile(this.memoryFilePath, lines.join("\n")); } + private async executeWithLock(operation: () => Promise): Promise { + return lockManager.withLock(this.memoryFilePath, operation); + } + async createEntities(entities: Entity[]): Promise { - const graph = await this.loadGraph(); - const newEntities = entities.filter(e => !graph.entities.some(existingEntity => existingEntity.name === e.name)); - graph.entities.push(...newEntities); - await this.saveGraph(graph); - return newEntities; + return this.executeWithLock(async () => { + const graph = await this.loadGraph(); + const newEntities = entities.filter(e => !graph.entities.some(existingEntity => existingEntity.name === e.name)); + graph.entities.push(...newEntities); + await this.saveGraph(graph); + return newEntities; + }); } async createRelations(relations: Relation[]): Promise { - const graph = await this.loadGraph(); - const newRelations = relations.filter(r => !graph.relations.some(existingRelation => - existingRelation.from === r.from && - existingRelation.to === r.to && - existingRelation.relationType === r.relationType - )); - graph.relations.push(...newRelations); - await this.saveGraph(graph); - return newRelations; + return this.executeWithLock(async () => { + const graph = await this.loadGraph(); + const newRelations = relations.filter(r => !graph.relations.some(existingRelation => + existingRelation.from === r.from && + existingRelation.to === r.to && + existingRelation.relationType === r.relationType + )); + graph.relations.push(...newRelations); + await this.saveGraph(graph); + return newRelations; + }); } async addObservations(observations: { entityName: string; contents: string[] }[]): Promise<{ entityName: string; addedObservations: string[] }[]> { - const graph = await this.loadGraph(); - const results = observations.map(o => { - const entity = graph.entities.find(e => e.name === o.entityName); - if (!entity) { - throw new Error(`Entity with name ${o.entityName} not found`); - } - const newObservations = o.contents.filter(content => !entity.observations.includes(content)); - entity.observations.push(...newObservations); - return { entityName: o.entityName, addedObservations: newObservations }; + return this.executeWithLock(async () => { + const graph = await this.loadGraph(); + const results = observations.map(o => { + const entity = graph.entities.find(e => e.name === o.entityName); + if (!entity) { + throw new Error(`Entity with name ${o.entityName} not found`); + } + const newObservations = o.contents.filter(content => !entity.observations.includes(content)); + entity.observations.push(...newObservations); + return { entityName: o.entityName, addedObservations: newObservations }; + }); + await this.saveGraph(graph); + return results; }); - await this.saveGraph(graph); - return results; } async deleteEntities(entityNames: string[]): Promise { - const graph = await this.loadGraph(); - graph.entities = graph.entities.filter(e => !entityNames.includes(e.name)); - graph.relations = graph.relations.filter(r => !entityNames.includes(r.from) && !entityNames.includes(r.to)); - await this.saveGraph(graph); + return this.executeWithLock(async () => { + const graph = await this.loadGraph(); + graph.entities = graph.entities.filter(e => !entityNames.includes(e.name)); + graph.relations = graph.relations.filter(r => !entityNames.includes(r.from) && !entityNames.includes(r.to)); + await this.saveGraph(graph); + }); } async deleteObservations(deletions: { entityName: string; observations: string[] }[]): Promise { - const graph = await this.loadGraph(); - deletions.forEach(d => { - const entity = graph.entities.find(e => e.name === d.entityName); - if (entity) { - entity.observations = entity.observations.filter(o => !d.observations.includes(o)); - } + return this.executeWithLock(async () => { + const graph = await this.loadGraph(); + deletions.forEach(d => { + const entity = graph.entities.find(e => e.name === d.entityName); + if (entity) { + entity.observations = entity.observations.filter(o => !d.observations.includes(o)); + } + }); + await this.saveGraph(graph); }); - await this.saveGraph(graph); } async deleteRelations(relations: Relation[]): Promise { - const graph = await this.loadGraph(); - graph.relations = graph.relations.filter(r => !relations.some(delRelation => - r.from === delRelation.from && - r.to === delRelation.to && - r.relationType === delRelation.relationType - )); - await this.saveGraph(graph); + return this.executeWithLock(async () => { + const graph = await this.loadGraph(); + graph.relations = graph.relations.filter(r => !relations.some(delRelation => + r.from === delRelation.from && + r.to === delRelation.to && + r.relationType === delRelation.relationType + )); + await this.saveGraph(graph); + }); } async readGraph(): Promise { - return this.loadGraph(); + return this.executeWithLock(async () => { + return this.loadGraph(); + }); } // Very basic search function async searchNodes(query: string): Promise { - const graph = await this.loadGraph(); + return this.executeWithLock(async () => { + const graph = await this.loadGraph(); + + // Filter entities + const filteredEntities = graph.entities.filter(e => + e.name.toLowerCase().includes(query.toLowerCase()) || + e.entityType.toLowerCase().includes(query.toLowerCase()) || + e.observations.some(o => o.toLowerCase().includes(query.toLowerCase())) + ); - // Filter entities - const filteredEntities = graph.entities.filter(e => - e.name.toLowerCase().includes(query.toLowerCase()) || - e.entityType.toLowerCase().includes(query.toLowerCase()) || - e.observations.some(o => o.toLowerCase().includes(query.toLowerCase())) - ); - - // Create a Set of filtered entity names for quick lookup - const filteredEntityNames = new Set(filteredEntities.map(e => e.name)); - - // Filter relations to only include those between filtered entities - const filteredRelations = graph.relations.filter(r => - filteredEntityNames.has(r.from) && filteredEntityNames.has(r.to) - ); - - const filteredGraph: KnowledgeGraph = { - entities: filteredEntities, - relations: filteredRelations, - }; - - return filteredGraph; + // Create a Set of filtered entity names for quick lookup + const filteredEntityNames = new Set(filteredEntities.map(e => e.name)); + + // Filter relations to only include those between filtered entities + const filteredRelations = graph.relations.filter(r => + filteredEntityNames.has(r.from) && filteredEntityNames.has(r.to) + ); + + const filteredGraph: KnowledgeGraph = { + entities: filteredEntities, + relations: filteredRelations, + }; + + return filteredGraph; + }); } async openNodes(names: string[]): Promise { - const graph = await this.loadGraph(); + return this.executeWithLock(async () => { + const graph = await this.loadGraph(); + + // Filter entities + const filteredEntities = graph.entities.filter(e => names.includes(e.name)); - // Filter entities - const filteredEntities = graph.entities.filter(e => names.includes(e.name)); - - // Create a Set of filtered entity names for quick lookup - const filteredEntityNames = new Set(filteredEntities.map(e => e.name)); - - // Filter relations to only include those between filtered entities - const filteredRelations = graph.relations.filter(r => - filteredEntityNames.has(r.from) && filteredEntityNames.has(r.to) - ); - - const filteredGraph: KnowledgeGraph = { - entities: filteredEntities, - relations: filteredRelations, - }; - - return filteredGraph; + // Create a Set of filtered entity names for quick lookup + const filteredEntityNames = new Set(filteredEntities.map(e => e.name)); + + // Filter relations to only include those between filtered entities + const filteredRelations = graph.relations.filter(r => + filteredEntityNames.has(r.from) && filteredEntityNames.has(r.to) + ); + + const filteredGraph: KnowledgeGraph = { + entities: filteredEntities, + relations: filteredRelations, + }; + + return filteredGraph; + }); } }