Skip to content
This repository was archived by the owner on Apr 22, 2020. It is now read-only.

Commit 3aac6f6

Browse files
authored
3.4 overlap similarity (#726)
* wip: Overlap similarity - need to figure out how it should work for the topK variant * handle returning the smaller set first for topK * overlap docs * more examples * more examples * link overlap similarity
1 parent dd13a39 commit 3aac6f6

File tree

13 files changed

+805
-33
lines changed

13 files changed

+805
-33
lines changed

algo/src/main/java/org/neo4j/graphalgo/similarity/CategoricalInput.java

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,4 +26,21 @@ SimilarityResult jaccard(double similarityCutoff, CategoricalInput e2) {
2626
if (jaccard < similarityCutoff) return null;
2727
return new SimilarityResult(id, e2.id, count1, count2, intersection, jaccard);
2828
}
29+
30+
SimilarityResult overlap(double similarityCutoff, CategoricalInput e2) {
31+
long intersection = Intersections.intersection3(targets, e2.targets);
32+
if (similarityCutoff >= 0d && intersection == 0) return null;
33+
int count1 = targets.length;
34+
int count2 = e2.targets.length;
35+
long denominator = Math.min(count1, count2);
36+
double overlap = denominator == 0 ? 0 : (double)intersection / denominator;
37+
if (overlap < similarityCutoff) return null;
38+
39+
if(count1 <= count2) {
40+
return new SimilarityResult(id, e2.id, count1, count2, intersection, overlap, false, false);
41+
} else {
42+
return new SimilarityResult(e2.id, id, count2, count1, intersection, overlap, false, true);
43+
}
44+
45+
}
2946
}

algo/src/main/java/org/neo4j/graphalgo/similarity/CosineProc.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,7 @@ public Stream<SimilaritySummaryResult> cosine(
8181

8282

8383
boolean write = configuration.isWriteFlag(false) && similarityCutoff > 0.0;
84-
return writeAndAggregateResults(configuration, stream, inputs.length, write);
84+
return writeAndAggregateResults(configuration, stream, inputs.length, write, "SIMILAR");
8585
}
8686

8787

algo/src/main/java/org/neo4j/graphalgo/similarity/EuclideanProc.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,7 @@ public Stream<SimilaritySummaryResult> euclidean(
8080
.map(SimilarityResult::squareRooted);
8181

8282
boolean write = configuration.isWriteFlag(false); // && similarityCutoff != 0.0;
83-
return writeAndAggregateResults(configuration, stream, inputs.length, write);
83+
return writeAndAggregateResults(configuration, stream, inputs.length, write, "SIMILAR");
8484
}
8585

8686

algo/src/main/java/org/neo4j/graphalgo/similarity/JaccardProc.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,7 @@ public Stream<SimilaritySummaryResult> jaccard(
6161
Stream<SimilarityResult> stream = topN(similarityStream(inputs, computer, configuration, similarityCutoff, getTopK(configuration)), getTopN(configuration));
6262

6363
boolean write = configuration.isWriteFlag(false) && similarityCutoff > 0.0;
64-
return writeAndAggregateResults(configuration, stream, inputs.length, write);
64+
return writeAndAggregateResults(configuration, stream, inputs.length, write, "SIMILAR");
6565
}
6666

6767

Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,70 @@
1+
/**
2+
* Copyright (c) 2017 "Neo4j, Inc." <http://neo4j.com>
3+
*
4+
* This file is part of Neo4j Graph Algorithms <http://github.com/neo4j-contrib/neo4j-graph-algorithms>.
5+
*
6+
* Neo4j Graph Algorithms is free software: you can redistribute it and/or modify
7+
* it under the terms of the GNU General Public License as published by
8+
* the Free Software Foundation, either version 3 of the License, or
9+
* (at your option) any later version.
10+
*
11+
* This program is distributed in the hope that it will be useful,
12+
* but WITHOUT ANY WARRANTY; without even the implied warranty of
13+
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
14+
* GNU General Public License for more details.
15+
*
16+
* You should have received a copy of the GNU General Public License
17+
* along with this program. If not, see <http://www.gnu.org/licenses/>.
18+
*/
19+
package org.neo4j.graphalgo.similarity;
20+
21+
import org.neo4j.graphalgo.core.ProcedureConfiguration;
22+
import org.neo4j.procedure.Description;
23+
import org.neo4j.procedure.Mode;
24+
import org.neo4j.procedure.Name;
25+
import org.neo4j.procedure.Procedure;
26+
27+
import java.util.List;
28+
import java.util.Map;
29+
import java.util.stream.Stream;
30+
31+
public class OverlapProc extends SimilarityProc {
32+
33+
@Procedure(name = "algo.similarity.overlap.stream", mode = Mode.READ)
34+
@Description("CALL algo.similarity.overlap.stream([{source:id, targets:[ids]}], {similarityCutoff:-1,degreeCutoff:0}) " +
35+
"YIELD item1, item2, count1, count2, intersection, similarity - computes jaccard similarities")
36+
public Stream<SimilarityResult> similarityStream(
37+
@Name(value = "data", defaultValue = "null") List<Map<String,Object>> data,
38+
@Name(value = "config", defaultValue = "{}") Map<String, Object> config) {
39+
40+
SimilarityComputer<CategoricalInput> computer = (s, t, cutoff) -> s.overlap(cutoff, t);
41+
42+
ProcedureConfiguration configuration = ProcedureConfiguration.create(config);
43+
44+
CategoricalInput[] inputs = prepareCategories(data, getDegreeCutoff(configuration));
45+
46+
return topN(similarityStream(inputs, computer, configuration, getSimilarityCutoff(configuration), getTopK(configuration)), getTopN(configuration));
47+
}
48+
49+
@Procedure(name = "algo.similarity.overlap", mode = Mode.WRITE)
50+
@Description("CALL algo.similarity.overlap([{source:id, targets:[ids]}], {similarityCutoff:-1,degreeCutoff:0}) " +
51+
"YIELD p50, p75, p90, p99, p999, p100 - computes jaccard similarities")
52+
public Stream<SimilaritySummaryResult> overlap(
53+
@Name(value = "data", defaultValue = "null") List<Map<String, Object>> data,
54+
@Name(value = "config", defaultValue = "{}") Map<String, Object> config) {
55+
56+
SimilarityComputer<CategoricalInput> computer = (s,t,cutoff) -> s.overlap(cutoff, t);
57+
58+
ProcedureConfiguration configuration = ProcedureConfiguration.create(config);
59+
60+
CategoricalInput[] inputs = prepareCategories(data, getDegreeCutoff(configuration));
61+
62+
double similarityCutoff = getSimilarityCutoff(configuration);
63+
Stream<SimilarityResult> stream = topN(similarityStream(inputs, computer, configuration, similarityCutoff, getTopK(configuration)), getTopN(configuration));
64+
65+
boolean write = configuration.isWriteFlag(false) && similarityCutoff > 0.0;
66+
return writeAndAggregateResults(configuration, stream, inputs.length, write, "NARROWER_THAN");
67+
}
68+
69+
70+
}

algo/src/main/java/org/neo4j/graphalgo/similarity/SimilarityProc.java

Lines changed: 37 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ static Stream<SimilarityResult> topN(Stream<SimilarityResult> stream, int topN)
5151
if (topN > 10000) {
5252
return stream.sorted(comparator).limit(topN);
5353
}
54-
return topK(stream,topN, comparator);
54+
return topK(stream, topN, comparator);
5555
}
5656

5757
private static <T> void put(BlockingQueue<T> queue, T items) {
@@ -66,8 +66,8 @@ Long getDegreeCutoff(ProcedureConfiguration configuration) {
6666
return configuration.get("degreeCutoff", 0L);
6767
}
6868

69-
Stream<SimilaritySummaryResult> writeAndAggregateResults(ProcedureConfiguration configuration, Stream<SimilarityResult> stream, int length, boolean write) {
70-
String writeRelationshipType = configuration.get("writeRelationshipType", "SIMILAR");
69+
Stream<SimilaritySummaryResult> writeAndAggregateResults(ProcedureConfiguration configuration, Stream<SimilarityResult> stream, int length, boolean write, String defaultWriteProperty) {
70+
String writeRelationshipType = configuration.get("writeRelationshipType", defaultWriteProperty);
7171
String writeProperty = configuration.getWriteProperty("score");
7272

7373
AtomicLong similarityPairs = new AtomicLong();
@@ -77,7 +77,7 @@ Stream<SimilaritySummaryResult> writeAndAggregateResults(ProcedureConfiguration
7777
similarityPairs.getAndIncrement();
7878
};
7979

80-
if(write) {
80+
if (write) {
8181
SimilarityExporter similarityExporter = new SimilarityExporter(api, writeRelationshipType, writeProperty);
8282
similarityExporter.export(stream.peek(recorder));
8383
} else {
@@ -114,17 +114,15 @@ <T> Stream<SimilarityResult> similarityStream(T[] inputs, SimilarityComputer<T>
114114
private <T> Stream<SimilarityResult> similarityStream(T[] inputs, int length, double similiarityCutoff, SimilarityComputer<T> computer) {
115115
return IntStream.range(0, length)
116116
.boxed().flatMap(sourceId -> IntStream.range(sourceId + 1, length)
117-
.mapToObj(targetId -> computer.similarity(inputs[sourceId],inputs[targetId],similiarityCutoff)).filter(Objects::nonNull));
117+
.mapToObj(targetId -> computer.similarity(inputs[sourceId], inputs[targetId], similiarityCutoff)).filter(Objects::nonNull));
118118
}
119119

120120
private <T> Stream<SimilarityResult> similarityStreamTopK(T[] inputs, int length, double cutoff, int topK, SimilarityComputer<T> computer) {
121121
TopKConsumer<SimilarityResult>[] topKHolder = initializeTopKConsumers(length, topK);
122122

123-
for (int sourceId = 0;sourceId < length;sourceId++) {
124-
computeSimilarityForSourceIndex(sourceId, inputs, length, cutoff, (sourceIndex, targetIndex, similarityResult) -> {
125-
topKHolder[sourceIndex].accept(similarityResult);
126-
topKHolder[targetIndex].accept(similarityResult.reverse());
127-
}, computer);
123+
SimilarityConsumer consumer = assignSimilarityPairs(topKHolder);
124+
for (int sourceId = 0; sourceId < length; sourceId++) {
125+
computeSimilarityForSourceIndex(sourceId, inputs, length, cutoff, consumer, computer);
128126
}
129127
return Arrays.stream(topKHolder).flatMap(TopKConsumer::stream);
130128
}
@@ -176,13 +174,13 @@ private <T> Stream<SimilarityResult> similarityParallelStreamTopK(T[] inputs, in
176174
ParallelUtil.runWithConcurrency(concurrency, tasks, terminationFlag, Pools.DEFAULT);
177175

178176
TopKConsumer<SimilarityResult>[] topKConsumers = initializeTopKConsumers(length, topK);
179-
for (Runnable task : tasks) ((TopKTask)task).mergeInto(topKConsumers);
177+
for (Runnable task : tasks) ((TopKTask) task).mergeInto(topKConsumers);
180178
return Arrays.stream(topKConsumers).flatMap(TopKConsumer::stream);
181179
}
182180

183181
private <T> void computeSimilarityForSourceIndex(int sourceId, T[] inputs, int length, double cutoff, SimilarityConsumer consumer, SimilarityComputer<T> computer) {
184-
for (int targetId=sourceId+1;targetId<length;targetId++) {
185-
SimilarityResult similarity = computer.similarity(inputs[sourceId], inputs[targetId],cutoff);
182+
for (int targetId = sourceId + 1; targetId < length; targetId++) {
183+
SimilarityResult similarity = computer.similarity(inputs[sourceId], inputs[targetId], cutoff);
186184
if (similarity != null) {
187185
consumer.accept(sourceId, targetId, similarity);
188186
}
@@ -195,11 +193,11 @@ CategoricalInput[] prepareCategories(List<Map<String, Object>> data, long degree
195193
for (Map<String, Object> row : data) {
196194
List<Number> targetIds = extractValues(row.get("categories"));
197195
int size = targetIds.size();
198-
if ( size > degreeCutoff) {
196+
if (size > degreeCutoff) {
199197
long[] targets = new long[size];
200-
int i=0;
198+
int i = 0;
201199
for (Number id : targetIds) {
202-
targets[i++]=id.longValue();
200+
targets[i++] = id.longValue();
203201
}
204202
Arrays.sort(targets);
205203
ids[idx++] = new CategoricalInput((Long) row.get("item"), targets);
@@ -218,11 +216,11 @@ WeightedInput[] prepareWeights(List<Map<String, Object>> data, long degreeCutoff
218216
List<Number> weightList = extractValues(row.get("weights"));
219217

220218
int size = weightList.size();
221-
if ( size > degreeCutoff) {
219+
if (size > degreeCutoff) {
222220
double[] weights = new double[size];
223-
int i=0;
221+
int i = 0;
224222
for (Number value : weightList) {
225-
weights[i++]=value.doubleValue();
223+
weights[i++] = value.doubleValue();
226224
}
227225
inputs[idx++] = new WeightedInput((Long) row.get("item"), weights);
228226
}
@@ -233,7 +231,7 @@ WeightedInput[] prepareWeights(List<Map<String, Object>> data, long degreeCutoff
233231
}
234232

235233
private List<Number> extractValues(Object rawValues) {
236-
if(rawValues == null) {
234+
if (rawValues == null) {
237235
return Collections.emptyList();
238236
}
239237

@@ -259,24 +257,35 @@ protected int getTopK(ProcedureConfiguration configuration) {
259257
}
260258

261259
protected int getTopN(ProcedureConfiguration configuration) {
262-
return configuration.getInt("top",0);
260+
return configuration.getInt("top", 0);
263261
}
264262

265263
interface SimilarityComputer<T> {
266264
SimilarityResult similarity(T source, T target, double cutoff);
267265
}
268266

267+
public static SimilarityConsumer assignSimilarityPairs(TopKConsumer<SimilarityResult>[] topKConsumers) {
268+
return (s, t, result) -> {
269+
topKConsumers[result.reversed ? t : s].accept(result);
270+
271+
if (result.bidirectional) {
272+
SimilarityResult reverse = result.reverse();
273+
topKConsumers[reverse.reversed ? t : s].accept(reverse);
274+
}
275+
};
276+
}
277+
269278
private class TopKTask<T> implements Runnable {
270279
private final int batchSize;
271280
private final int taskOffset;
272281
private final int multiplier;
273282
private final int length;
274283
private final T[] ids;
275284
private final double similiarityCutoff;
276-
private final SimilarityComputer computer;
285+
private final SimilarityComputer<T> computer;
277286
private final TopKConsumer<SimilarityResult>[] topKConsumers;
278287

279-
TopKTask(int batchSize, int taskOffset, int multiplier, int length, T[] ids, double similiarityCutoff, int topK, SimilarityComputer computer) {
288+
TopKTask(int batchSize, int taskOffset, int multiplier, int length, T[] ids, double similiarityCutoff, int topK, SimilarityComputer<T> computer) {
280289
this.batchSize = batchSize;
281290
this.taskOffset = taskOffset;
282291
this.multiplier = multiplier;
@@ -289,16 +298,17 @@ private class TopKTask<T> implements Runnable {
289298

290299
@Override
291300
public void run() {
301+
SimilarityConsumer consumer = assignSimilarityPairs(topKConsumers);
292302
for (int offset = 0; offset < batchSize; offset++) {
293303
int sourceId = taskOffset * multiplier + offset;
294304
if (sourceId < length) {
295-
computeSimilarityForSourceIndex(sourceId, ids, length, similiarityCutoff, (s, t, result) -> {
296-
topKConsumers[s].accept(result);
297-
topKConsumers[t].accept(result.reverse());
298-
}, computer);
305+
306+
computeSimilarityForSourceIndex(sourceId, ids, length, similiarityCutoff, consumer, computer);
299307
}
300308
}
301309
}
310+
311+
302312
void mergeInto(TopKConsumer<SimilarityResult>[] target) {
303313
for (int i = 0; i < target.length; i++) {
304314
target[i].accept(topKConsumers[i]);

algo/src/main/java/org/neo4j/graphalgo/similarity/SimilarityResult.java

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -30,16 +30,23 @@ public class SimilarityResult implements Comparable<SimilarityResult> {
3030
public final long count2;
3131
public final long intersection;
3232
public double similarity;
33+
public final boolean bidirectional;
34+
public final boolean reversed;
3335

3436
public static SimilarityResult TOMB = new SimilarityResult(-1, -1, -1, -1, -1, -1);
3537

36-
public SimilarityResult(long item1, long item2, long count1, long count2, long intersection, double similarity) {
38+
public SimilarityResult(long item1, long item2, long count1, long count2, long intersection, double similarity, boolean bidirectional, boolean reversed) {
3739
this.item1 = item1;
3840
this.item2 = item2;
3941
this.count1 = count1;
4042
this.count2 = count2;
4143
this.intersection = intersection;
4244
this.similarity = similarity;
45+
this.bidirectional = bidirectional;
46+
this.reversed = reversed;
47+
}
48+
public SimilarityResult(long item1, long item2, long count1, long count2, long intersection, double similarity) {
49+
this(item1,item2, count1,count2,intersection,similarity, true, false);
4350
}
4451

4552
@Override
@@ -70,7 +77,7 @@ public int compareTo(SimilarityResult o) {
7077
}
7178

7279
public SimilarityResult reverse() {
73-
return new SimilarityResult(item2, item1,count2,count1,intersection,similarity);
80+
return new SimilarityResult(item2, item1,count2,count1,intersection,similarity,bidirectional,!reversed);
7481
}
7582

7683
public SimilarityResult squareRooted() {

doc/asciidoc/algorithms-similarity.adoc

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,9 @@ These algorithms help calculate the similarity of nodes:
1313
* <<algorithms-similarity-jaccard, Jaccard Similarity>> (`algo.similarity.jaccard`)
1414
* <<algorithms-similarity-cosine, Cosine Similarity>> (`algo.similarity.cosine`)
1515
* <<algorithms-similarity-euclidean, Euclidean Distance>> (`algo.similarity.euclidean`)
16+
* <<algorithms-similarity-overlap, Overlap Similarity>> (`algo.similarity.overlap`)
1617

1718
include::similarity-jaccard.adoc[leveloffset=2]
1819
include::similarity-cosine.adoc[leveloffset=2]
1920
include::similarity-euclidean.adoc[leveloffset=2]
21+
include::similarity-overlap.adoc[leveloffset=2]

0 commit comments

Comments
 (0)