@@ -51,7 +51,7 @@ static Stream<SimilarityResult> topN(Stream<SimilarityResult> stream, int topN)
5151 if (topN > 10000 ) {
5252 return stream .sorted (comparator ).limit (topN );
5353 }
54- return topK (stream ,topN , comparator );
54+ return topK (stream , topN , comparator );
5555 }
5656
5757 private static <T > void put (BlockingQueue <T > queue , T items ) {
@@ -66,8 +66,8 @@ Long getDegreeCutoff(ProcedureConfiguration configuration) {
6666 return configuration .get ("degreeCutoff" , 0L );
6767 }
6868
69- Stream <SimilaritySummaryResult > writeAndAggregateResults (ProcedureConfiguration configuration , Stream <SimilarityResult > stream , int length , boolean write ) {
70- String writeRelationshipType = configuration .get ("writeRelationshipType" , "SIMILAR" );
69+ Stream <SimilaritySummaryResult > writeAndAggregateResults (ProcedureConfiguration configuration , Stream <SimilarityResult > stream , int length , boolean write , String defaultWriteProperty ) {
70+ String writeRelationshipType = configuration .get ("writeRelationshipType" , defaultWriteProperty );
7171 String writeProperty = configuration .getWriteProperty ("score" );
7272
7373 AtomicLong similarityPairs = new AtomicLong ();
@@ -77,7 +77,7 @@ Stream<SimilaritySummaryResult> writeAndAggregateResults(ProcedureConfiguration
7777 similarityPairs .getAndIncrement ();
7878 };
7979
80- if (write ) {
80+ if (write ) {
8181 SimilarityExporter similarityExporter = new SimilarityExporter (api , writeRelationshipType , writeProperty );
8282 similarityExporter .export (stream .peek (recorder ));
8383 } else {
@@ -114,17 +114,15 @@ <T> Stream<SimilarityResult> similarityStream(T[] inputs, SimilarityComputer<T>
// Builds a stream of pairwise similarity results over the first `length` entries of `inputs`.
// Every pair is visited once: sourceId runs over [0, length) and targetId over
// (sourceId, length), so targetId is always greater than sourceId.
// `similiarityCutoff` is passed through unchanged to the computer; pairs for which the
// computer returns null are dropped by the nonNull filter.
// The -/+ pair below differs only in spacing after the commas.
114114 private <T > Stream <SimilarityResult > similarityStream (T [] inputs , int length , double similiarityCutoff , SimilarityComputer <T > computer ) {
115115 return IntStream .range (0 , length )
116116 .boxed ().flatMap (sourceId -> IntStream .range (sourceId + 1 , length )
117- .mapToObj (targetId -> computer .similarity (inputs [sourceId ],inputs [targetId ],similiarityCutoff )).filter (Objects ::nonNull ));
117+ .mapToObj (targetId -> computer .similarity (inputs [sourceId ], inputs [targetId ], similiarityCutoff )).filter (Objects ::nonNull ));
118118 }
119119
// Sequential top-K variant: feeds every non-null pairwise result into an array of
// TopKConsumer instances and returns the concatenation of the consumers' streams.
// The holder array comes from initializeTopKConsumers(length, topK) and is indexed by
// input position (presumably one consumer per input, each keeping topK entries — confirm
// in initializeTopKConsumers).
120120 private <T > Stream <SimilarityResult > similarityStreamTopK (T [] inputs , int length , double cutoff , int topK , SimilarityComputer <T > computer ) {
121121 TopKConsumer <SimilarityResult >[] topKHolder = initializeTopKConsumers (length , topK );
122122
// Removed version: an inline lambda that always recorded the result under the source
// index and its reverse() under the target index.
123- for (int sourceId = 0 ;sourceId < length ;sourceId ++) {
124- computeSimilarityForSourceIndex (sourceId , inputs , length , cutoff , (sourceIndex , targetIndex , similarityResult ) -> {
125- topKHolder [sourceIndex ].accept (similarityResult );
126- topKHolder [targetIndex ].accept (similarityResult .reverse ());
127- }, computer );
// Added version: one consumer built by assignSimilarityPairs is created before the loop
// and reused for every sourceId; that helper picks the slot from the result's `reversed`
// flag and only records the reverse when the result is `bidirectional`.
123+ SimilarityConsumer consumer = assignSimilarityPairs (topKHolder );
124+ for (int sourceId = 0 ; sourceId < length ; sourceId ++) {
125+ computeSimilarityForSourceIndex (sourceId , inputs , length , cutoff , consumer , computer );
128126 }
129127 return Arrays .stream (topKHolder ).flatMap (TopKConsumer ::stream );
130128 }
@@ -176,13 +174,13 @@ private <T> Stream<SimilarityResult> similarityParallelStreamTopK(T[] inputs, in
176174 ParallelUtil .runWithConcurrency (concurrency , tasks , terminationFlag , Pools .DEFAULT );
177175
178176 TopKConsumer <SimilarityResult >[] topKConsumers = initializeTopKConsumers (length , topK );
179- for (Runnable task : tasks ) ((TopKTask )task ).mergeInto (topKConsumers );
177+ for (Runnable task : tasks ) ((TopKTask ) task ).mergeInto (topKConsumers );
180178 return Arrays .stream (topKConsumers ).flatMap (TopKConsumer ::stream );
181179 }
182180
183181 private <T > void computeSimilarityForSourceIndex (int sourceId , T [] inputs , int length , double cutoff , SimilarityConsumer consumer , SimilarityComputer <T > computer ) {
184- for (int targetId = sourceId + 1 ; targetId < length ;targetId ++) {
185- SimilarityResult similarity = computer .similarity (inputs [sourceId ], inputs [targetId ],cutoff );
182+ for (int targetId = sourceId + 1 ; targetId < length ; targetId ++) {
183+ SimilarityResult similarity = computer .similarity (inputs [sourceId ], inputs [targetId ], cutoff );
186184 if (similarity != null ) {
187185 consumer .accept (sourceId , targetId , similarity );
188186 }
@@ -195,11 +193,11 @@ CategoricalInput[] prepareCategories(List<Map<String, Object>> data, long degree
195193 for (Map <String , Object > row : data ) {
196194 List <Number > targetIds = extractValues (row .get ("categories" ));
197195 int size = targetIds .size ();
198- if ( size > degreeCutoff ) {
196+ if (size > degreeCutoff ) {
199197 long [] targets = new long [size ];
200- int i = 0 ;
198+ int i = 0 ;
201199 for (Number id : targetIds ) {
202- targets [i ++]= id .longValue ();
200+ targets [i ++] = id .longValue ();
203201 }
204202 Arrays .sort (targets );
205203 ids [idx ++] = new CategoricalInput ((Long ) row .get ("item" ), targets );
@@ -218,11 +216,11 @@ WeightedInput[] prepareWeights(List<Map<String, Object>> data, long degreeCutoff
218216 List <Number > weightList = extractValues (row .get ("weights" ));
219217
220218 int size = weightList .size ();
221- if ( size > degreeCutoff ) {
219+ if (size > degreeCutoff ) {
222220 double [] weights = new double [size ];
223- int i = 0 ;
221+ int i = 0 ;
224222 for (Number value : weightList ) {
225- weights [i ++]= value .doubleValue ();
223+ weights [i ++] = value .doubleValue ();
226224 }
227225 inputs [idx ++] = new WeightedInput ((Long ) row .get ("item" ), weights );
228226 }
@@ -233,7 +231,7 @@ WeightedInput[] prepareWeights(List<Map<String, Object>> data, long degreeCutoff
233231 }
234232
235233 private List <Number > extractValues (Object rawValues ) {
236- if (rawValues == null ) {
234+ if (rawValues == null ) {
237235 return Collections .emptyList ();
238236 }
239237
@@ -259,24 +257,35 @@ protected int getTopK(ProcedureConfiguration configuration) {
259257 }
260258
// Reads the integer "top" setting from the procedure configuration, passing 0 as the
// second argument (presumably the default used when "top" is absent — confirm against
// ProcedureConfiguration.getInt).
// The -/+ pair below differs only in spacing after the comma.
261259 protected int getTopN (ProcedureConfiguration configuration ) {
262- return configuration .getInt ("top" ,0 );
260+ return configuration .getInt ("top" , 0 );
263261 }
264262
// Strategy callback that scores one pair of inputs of type T.
// Callers in this file treat a null return as "no result for this pair": the stream
// variant filters with Objects::nonNull and computeSimilarityForSourceIndex checks for
// null before forwarding. How `cutoff` is applied is left to the implementation.
265263 interface SimilarityComputer <T > {
266264 SimilarityResult similarity (T source , T target , double cutoff );
267265 }
268266
// Creates the SimilarityConsumer shared by the sequential and task-based top-K paths.
// The returned consumer receives (s, t, result) and writes into `topKConsumers`, which is
// captured by reference and indexed by s / t.
// Unlike the inline lambdas it replaces in this diff (which always stored the result
// under s and its reverse under t), the slot is chosen from the result's `reversed` flag
// and the reverse is stored only for `bidirectional` results.
267+ public static SimilarityConsumer assignSimilarityPairs (TopKConsumer <SimilarityResult >[] topKConsumers ) {
268+ return (s , t , result ) -> {
// A result flagged `reversed` goes to slot t; otherwise to slot s.
269+ topKConsumers [result .reversed ? t : s ].accept (result );
270+
271+ if (result .bidirectional ) {
// Same slot rule, applied to the reversed copy's own flag. Presumably reverse() flips
// `reversed`, so the copy lands in the other slot — confirm in SimilarityResult.
272+ SimilarityResult reverse = result .reverse ();
273+ topKConsumers [reverse .reversed ? t : s ].accept (reverse );
274+ }
275+ };
276+ }
277+
269278 private class TopKTask <T > implements Runnable {
270279 private final int batchSize ;
271280 private final int taskOffset ;
272281 private final int multiplier ;
273282 private final int length ;
274283 private final T [] ids ;
275284 private final double similiarityCutoff ;
276- private final SimilarityComputer computer ;
285+ private final SimilarityComputer < T > computer ;
277286 private final TopKConsumer <SimilarityResult >[] topKConsumers ;
278287
279- TopKTask (int batchSize , int taskOffset , int multiplier , int length , T [] ids , double similiarityCutoff , int topK , SimilarityComputer computer ) {
288+ TopKTask (int batchSize , int taskOffset , int multiplier , int length , T [] ids , double similiarityCutoff , int topK , SimilarityComputer < T > computer ) {
280289 this .batchSize = batchSize ;
281290 this .taskOffset = taskOffset ;
282291 this .multiplier = multiplier ;
@@ -289,16 +298,17 @@ private class TopKTask<T> implements Runnable {
289298
// Task body: handles the source indices taskOffset * multiplier + offset for
// offset in [0, batchSize), skipping any index that is not below `length`.
// Results are accumulated in this task's own `topKConsumers` array.
290299 @ Override
291300 public void run () {
// Added: the consumer is built once per run() and reused for every source index.
301+ SimilarityConsumer consumer = assignSimilarityPairs (topKConsumers );
292302 for (int offset = 0 ; offset < batchSize ; offset ++) {
293303 int sourceId = taskOffset * multiplier + offset ;
294304 if (sourceId < length ) {
// Removed: inline lambda that always stored the result under s and its reverse under t.
295- computeSimilarityForSourceIndex (sourceId , ids , length , similiarityCutoff , (s , t , result ) -> {
296- topKConsumers [s ].accept (result );
297- topKConsumers [t ].accept (result .reverse ());
298- }, computer );
305+
306+ computeSimilarityForSourceIndex (sourceId , ids , length , similiarityCutoff , consumer , computer );
299307 }
300308 }
301309 }
310+
311+
302312 void mergeInto (TopKConsumer <SimilarityResult >[] target ) {
303313 for (int i = 0 ; i < target .length ; i ++) {
304314 target [i ].accept (topKConsumers [i ]);
0 commit comments