@@ -24,6 +24,7 @@ def _emb_to_ndarray(embedding: list[float]) -> NDArray[np.float32]:
2424@dataclasses .dataclass
2525class _Candidate :
2626 id : str
27+ similarity : float
2728 weighted_similarity : float
2829 weighted_redundancy : float
2930 score : float = dataclasses .field (init = False )
@@ -69,6 +70,13 @@ class MmrHelper:
6970
7071 selected_ids : list [str ]
7172 """List of selected IDs (in selection order)."""
73+
74+ selected_mmr_scores : list [float ]
75+ """List of MMR scores, each recorded at the time its document was selected."""
76+
77+ selected_similarity_scores : list [float ]
78+ """List of similarity scores, one for each selected document."""
79+
7280 selected_embeddings : NDArray [np .float32 ]
7381 """(N, dim) ndarray with a row for each selected node."""
7482
@@ -100,6 +108,8 @@ def __init__(
100108 self .score_threshold = score_threshold
101109
102110 self .selected_ids = []
111+ self .selected_similarity_scores = []
112+ self .selected_mmr_scores = []
103113
104114 # List of selected embeddings (in selection order).
105115 self .selected_embeddings = np .ndarray ((k , self .dimensions ), dtype = np .float32 )
@@ -123,11 +133,11 @@ def _already_selected_embeddings(self) -> NDArray[np.float32]:
123133 selected = len (self .selected_ids )
124134 return np .vsplit (self .selected_embeddings , [selected ])[0 ]
125135
126- def _pop_candidate (self , candidate_id : str ) -> NDArray [np .float32 ]:
136+ def _pop_candidate (self , candidate_id : str ) -> tuple [ float , NDArray [np .float32 ] ]:
127137 """Pop the candidate with the given ID.
128138
129139 Returns:
130- The embedding of the candidate.
140+ The similarity score and embedding of the candidate.
131141 """
132142 # Get the embedding for the id.
133143 index = self .candidate_id_to_index .pop (candidate_id )
@@ -143,12 +153,15 @@ def _pop_candidate(self, candidate_id: str) -> NDArray[np.float32]:
143153 # candidate_embeddings.
144154 last_index = self .candidate_embeddings .shape [0 ] - 1
145155
156+ similarity = 0.0
146157 if index == last_index :
147158 # Already the last item. We don't need to swap.
148- self .candidates .pop ()
159+ similarity = self .candidates .pop (). similarity
149160 else :
150161 self .candidate_embeddings [index ] = self .candidate_embeddings [last_index ]
151162
163+ similarity = self .candidates [index ].similarity
164+
152165 old_last = self .candidates .pop ()
153166 self .candidates [index ] = old_last
154167 self .candidate_id_to_index [old_last .id ] = index
@@ -157,7 +170,7 @@ def _pop_candidate(self, candidate_id: str) -> NDArray[np.float32]:
157170 0
158171 ]
159172
160- return embedding
173+ return similarity , embedding
161174
162175 def pop_best (self ) -> str | None :
163176 """Select and pop the best item being considered.
@@ -172,11 +185,13 @@ def pop_best(self) -> str | None:
172185
173186 # Get the selection and remove from candidates.
174187 selected_id = self .best_id
175- selected_embedding = self ._pop_candidate (selected_id )
188+ selected_similarity , selected_embedding = self ._pop_candidate (selected_id )
176189
177190 # Add the ID and embedding to the selected information.
178191 selection_index = len (self .selected_ids )
179192 self .selected_ids .append (selected_id )
193+ self .selected_mmr_scores .append (self .best_score )
194+ self .selected_similarity_scores .append (selected_similarity )
180195 self .selected_embeddings [selection_index ] = selected_embedding
181196
182197 # Reset the best score / best ID.
@@ -232,6 +247,7 @@ def add_candidates(self, candidates: dict[str, list[float]]) -> None:
232247 max_redundancy = redundancy [index ].max ()
233248 candidate = _Candidate (
234249 id = candidate_id ,
250+ similarity = similarity [index ][0 ],
235251 weighted_similarity = self .lambda_mult * similarity [index ][0 ],
236252 weighted_redundancy = self .lambda_mult_complement * max_redundancy ,
237253 )
0 commit comments