@@ -5,19 +5,38 @@ import {
55 AccuracyRunStatus ,
66 AccuracyRunStatuses ,
77 ExpectedToolCall ,
8+ LLMToolCall ,
89 ModelResponse ,
10+ PromptResult ,
911} from "./result-storage.js" ;
1012
// These ModelResponse fields are stripped before a response is persisted
// because they may contain large chunks of text.
const OMITTED_MODEL_RESPONSE_FIELDS: Array<keyof ModelResponse> = ["messages", "text"];
1315
// The LLMToolCalls and ExpectedToolCalls are expected to contain MongoDB
// operators nested inside their objects. That interferes with the update
// operation used to save model responses on the accuracy result document, so
// both are serialized to JSON strings before saving and deserialized on fetch.

/** A model response as persisted: `llmToolCalls` is stored as a JSON string. */
type SavedModelResponse = Omit<ModelResponse, "llmToolCalls"> & {
    llmToolCalls: string;
};

/**
 * A prompt result as persisted: `expectedToolCalls` is stored as a JSON string
 * and every model response uses the persisted `SavedModelResponse` shape.
 */
type SavedPromptResult = Omit<PromptResult, "expectedToolCalls" | "modelResponses"> & {
    expectedToolCalls: string;
    modelResponses: SavedModelResponse[];
};

/** An accuracy result as persisted: its prompt results use the persisted shape. */
type SavedAccuracyResult = Omit<AccuracyResult, "promptResults"> & {
    promptResults: SavedPromptResult[];
};
32+
1433export class MongoDBBasedResultStorage implements AccuracyResultStorage {
1534 private client : MongoClient ;
16- private resultCollection : Collection < AccuracyResult > ;
35+ private resultCollection : Collection < SavedAccuracyResult > ;
1736
/**
 * Builds a storage instance backed by a single MongoDB collection.
 *
 * @param connectionString - Connection string handed to `MongoClient`.
 * @param database - Name of the database that holds the results.
 * @param collection - Name of the collection that holds the results.
 */
constructor(connectionString: string, database: string, collection: string) {
    const client = new MongoClient(connectionString);
    this.client = client;
    // Documents in this collection use the serialized (JSON string) tool call shape.
    this.resultCollection = client.db(database).collection<SavedAccuracyResult>(collection);
}
2241
2342 async getAccuracyResult ( commitSHA : string , runId ?: string ) : Promise < AccuracyResult | null > {
@@ -28,11 +47,14 @@ export class MongoDBBasedResultStorage implements AccuracyResultStorage {
2847 // for commit is when you want the last successful run of that
2948 // particular commit.
3049 { commitSHA, runStatus : AccuracyRunStatus . Done } ;
31- return await this . resultCollection . findOne ( filters , {
50+
51+ const result = await this . resultCollection . findOne ( filters , {
3252 sort : {
3353 createdOn : - 1 ,
3454 } ,
3555 } ) ;
56+
57+ return result ? this . deserializeSavedResult ( result ) : result ;
3658 }
3759
3860 async updateRunStatus ( commitSHA : string , runId : string , status : AccuracyRunStatuses ) : Promise < void > {
@@ -59,130 +81,77 @@ export class MongoDBBasedResultStorage implements AccuracyResultStorage {
5981 expectedToolCalls : ExpectedToolCall [ ] ;
6082 modelResponse : ModelResponse ;
6183 } ) : Promise < void > {
62- const savedModelResponse : ModelResponse = { ...modelResponse } ;
63- for ( const field of OMITTED_MODEL_RESPONSE_FIELDS ) {
64- delete savedModelResponse [ field ] ;
65- }
66-
67- await this . resultCollection . updateOne (
68- { commitSHA, runId } ,
69- {
70- $setOnInsert : {
71- runStatus : AccuracyRunStatus . InProgress ,
72- createdOn : Date . now ( ) ,
73- commitSHA,
74- runId,
75- promptResults : [ ] ,
76- } ,
77- } ,
78- { upsert : true }
79- ) ;
80-
81- await this . resultCollection . updateOne (
82- {
83- commitSHA,
84- runId,
85- "promptResults.prompt" : { $ne : prompt } ,
86- } ,
87- {
88- $push : {
89- promptResults : { prompt, expectedToolCalls, modelResponses : [ ] } ,
90- } ,
91- }
92- ) ;
84+ const expectedToolCallsToSave = JSON . stringify ( expectedToolCalls ) ;
85+ const modelResponseToSave : SavedModelResponse = {
86+ ...modelResponse ,
87+ llmToolCalls : JSON . stringify ( modelResponse . llmToolCalls ) ,
88+ } ;
9389
94- await this . resultCollection . updateOne (
95- { commitSHA, runId } ,
96- {
97- $push : {
98- "promptResults.$[promptElement].modelResponses" : savedModelResponse ,
99- } ,
100- } ,
101- {
102- arrayFilters : [ { "promptElement.prompt" : prompt } ] ,
103- }
104- ) ;
105- }
106-
107- async saveModelResponseForPromptAtomic ( {
108- commitSHA,
109- runId,
110- prompt,
111- expectedToolCalls,
112- modelResponse,
113- } : {
114- commitSHA : string ;
115- runId : string ;
116- prompt : string ;
117- expectedToolCalls : ExpectedToolCall [ ] ;
118- modelResponse : ModelResponse ;
119- } ) : Promise < void > {
120- const savedModelResponse : ModelResponse = { ...modelResponse } ;
12190 for ( const field of OMITTED_MODEL_RESPONSE_FIELDS ) {
122- delete savedModelResponse [ field ] ;
91+ delete modelResponseToSave [ field ] ;
12392 }
12493
12594 await this . resultCollection . updateOne (
12695 { commitSHA, runId } ,
12796 [
12897 {
12998 $set : {
130- runStatus : {
131- $ifNull : [ "$runStatus" , AccuracyRunStatus . InProgress ] ,
132- } ,
133- createdOn : {
134- $ifNull : [ "$createdOn" , Date . now ( ) ] ,
99+ runStatus : { $ifNull : [ "$runStatus" , AccuracyRunStatus . InProgress ] } ,
100+ createdOn : { $ifNull : [ "$createdOn" , Date . now ( ) ] } ,
101+ commitSHA : { $ifNull : [ "$commitSHA" , commitSHA ] } ,
102+ runId : { $ifNull : [ "$runId" , runId ] } ,
103+ promptResults : {
104+ $ifNull : [ "$promptResults" , [ ] ] ,
135105 } ,
136- commitSHA : commitSHA ,
137- runId : runId ,
106+ } ,
107+ } ,
108+ {
109+ $set : {
138110 promptResults : {
139111 $let : {
140112 vars : {
141- existingPrompts : { $ifNull : [ "$promptResults" , [ ] ] } ,
142- promptExists : {
143- $in : [
144- prompt ,
145- {
146- $ifNull : [
147- { $map : { input : "$promptResults" , as : "pr" , in : "$$pr.prompt" } } ,
148- [ ] ,
149- ] ,
150- } ,
151- ] ,
113+ existingPromptIndex : {
114+ $indexOfArray : [ "$promptResults.prompt" , prompt ] ,
152115 } ,
153116 } ,
154117 in : {
155- $map : {
156- input : {
157- $cond : {
158- if : "$$promptExists" ,
159- then : "$$existingPrompts" ,
160- else : {
161- $concatArrays : [
162- "$$existingPrompts" ,
163- [ { prompt, expectedToolCalls, modelResponses : [ ] } ] ,
164- ] ,
165- } ,
166- } ,
167- } ,
168- as : "promptResult" ,
169- in : {
170- $cond : {
171- if : { $eq : [ "$$promptResult.prompt" , prompt ] } ,
172- then : {
173- prompt : "$$promptResult.prompt" ,
174- expectedToolCalls : "$$promptResult.expectedToolCalls" ,
175- modelResponses : {
176- $concatArrays : [
177- "$$promptResult.modelResponses" ,
178- [ savedModelResponse ] ,
179- ] ,
118+ $cond : [
119+ { $eq : [ "$$existingPromptIndex" , - 1 ] } ,
120+ {
121+ $concatArrays : [
122+ "$promptResults" ,
123+ [
124+ {
125+ prompt,
126+ expectedToolCalls : expectedToolCallsToSave ,
127+ modelResponses : [ modelResponseToSave ] ,
180128 } ,
129+ ] ,
130+ ] ,
131+ } ,
132+ {
133+ $map : {
134+ input : "$promptResults" ,
135+ as : "promptResult" ,
136+ in : {
137+ $cond : [
138+ { $eq : [ "$$promptResult.prompt" , prompt ] } ,
139+ {
140+ prompt : "$$promptResult.prompt" ,
141+ expectedToolCalls : expectedToolCallsToSave ,
142+ modelResponses : {
143+ $concatArrays : [
144+ "$$promptResult.modelResponses" ,
145+ [ modelResponseToSave ] ,
146+ ] ,
147+ } ,
148+ } ,
149+ "$$promptResult" ,
150+ ] ,
181151 } ,
182- else : "$$promptResult" ,
183152 } ,
184153 } ,
185- } ,
154+ ] ,
186155 } ,
187156 } ,
188157 } ,
@@ -193,6 +162,24 @@ export class MongoDBBasedResultStorage implements AccuracyResultStorage {
193162 ) ;
194163 }
195164
/**
 * Converts a stored accuracy result back into the in-memory `AccuracyResult`
 * shape.
 *
 * `expectedToolCalls` on every prompt result and `llmToolCalls` on every model
 * response are persisted as JSON strings and are parsed back into arrays here.
 * A value that is not a string is passed through unchanged, so a document that
 * still holds raw arrays (for example one presumably written before the
 * serialized format was introduced) can be read instead of failing inside
 * `JSON.parse`.
 *
 * @param result - The document as read from the result collection.
 * @returns A copy of the result with all tool calls deserialized.
 * @throws SyntaxError when a stored string is not valid JSON.
 */
private deserializeSavedResult(result: SavedAccuracyResult): AccuracyResult {
    // JSON.parse coerces non-string input with String(), which turns an array
    // of objects into "[object Object]" and throws; only parse real strings.
    const parseIfSerialized = <T>(value: unknown): T =>
        (typeof value === "string" ? JSON.parse(value) : value) as T;

    return {
        ...result,
        promptResults: result.promptResults.map<PromptResult>((promptResult) => ({
            ...promptResult,
            expectedToolCalls: parseIfSerialized<ExpectedToolCall[]>(promptResult.expectedToolCalls),
            modelResponses: promptResult.modelResponses.map<ModelResponse>((response) => ({
                ...response,
                llmToolCalls: parseIfSerialized<LLMToolCall[]>(response.llmToolCalls),
            })),
        })),
    };
}
182+
/**
 * Closes the MongoDB client held by this storage.
 *
 * @returns A promise that settles once the client's `close()` call has settled.
 */
async close(): Promise<void> {
    await this.client.close();
}
0 commit comments