Skip to content
This repository was archived by the owner on Apr 22, 2020. It is now read-only.

Commit 6a589a5

Browse files
authored
Similarity refactoring (#803)
* Delegate the Pearson function to the array-based computation
* Delegate the cosine function to the array-based computation
* Delegate the Euclidean function to the array-based computation
* Add a similarity vector aggregation function (algo.similarity.asVector) so that users do not have to build the vector with collect themselves
* Add more tests and push the NaN handling into Intersections
* Use a nicer name for the config switch
* Use a better name for the aggregation function
* Use "category" instead of "id" as the map key
1 parent 025e2e8 commit 6a589a5

File tree

7 files changed

+290
-52
lines changed

7 files changed

+290
-52
lines changed

algo/pom.xml

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -101,6 +101,14 @@
101101
<version>${neo4j.version}</version>
102102
<scope>provided</scope>
103103
</dependency>
104+
105+
<dependency>
106+
<groupId>org.mockito</groupId>
107+
<artifactId>mockito-core</artifactId>
108+
<version>2.23.4</version>
109+
<scope>test</scope>
110+
</dependency>
111+
104112
</dependencies>
105113

106114
<build>

algo/src/main/java/org/neo4j/graphalgo/similarity/Similarities.java

Lines changed: 88 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -18,12 +18,22 @@
1818
*/
1919
package org.neo4j.graphalgo.similarity;
2020

21-
import org.neo4j.procedure.Description;
22-
import org.neo4j.procedure.Name;
23-
import org.neo4j.procedure.UserFunction;
24-
21+
import com.carrotsearch.hppc.LongDoubleHashMap;
22+
import com.carrotsearch.hppc.LongDoubleMap;
23+
import com.carrotsearch.hppc.LongHashSet;
24+
import com.carrotsearch.hppc.LongSet;
25+
import com.carrotsearch.hppc.cursors.LongCursor;
26+
import org.neo4j.graphalgo.core.ProcedureConfiguration;
27+
import org.neo4j.graphalgo.core.utils.Intersections;
28+
import org.neo4j.procedure.*;
29+
30+
import java.util.HashMap;
2531
import java.util.HashSet;
2632
import java.util.List;
33+
import java.util.Map;
34+
35+
import static org.neo4j.graphalgo.similarity.SimilarityVectorAggregator.CATEGORY_KEY;
36+
import static org.neo4j.graphalgo.similarity.SimilarityVectorAggregator.WEIGHT_KEY;
2737

2838
public class Similarities {
2939

@@ -49,51 +59,83 @@ public double cosineSimilarity(@Name("vector1") List<Number> vector1, @Name("vec
4959
throw new RuntimeException("Vectors must be non-empty and of the same size");
5060
}
5161

52-
double dotProduct = 0d;
53-
double xLength = 0d;
54-
double yLength = 0d;
55-
for (int i = 0; i < vector1.size(); i++) {
56-
double weight1 = vector1.get(i).doubleValue();
57-
double weight2 = vector2.get(i).doubleValue();
62+
int len = Math.min(vector1.size(), vector2.size());
63+
double[] weights1 = new double[len];
64+
double[] weights2 = new double[len];
5865

59-
dotProduct += weight1 * weight2;
60-
xLength += weight1 * weight1;
61-
yLength += weight2 * weight2;
66+
for (int i = 0; i < len; i++) {
67+
weights1[i] = vector1.get(i).doubleValue();
68+
weights2[i] = vector2.get(i).doubleValue();
6269
}
6370

64-
xLength = Math.sqrt(xLength);
65-
yLength = Math.sqrt(yLength);
71+
return Math.sqrt(Intersections.cosineSquare(weights1, weights2, len));
72+
}
6673

67-
return dotProduct / (xLength * yLength);
74+
@UserAggregationFunction("algo.similarity.asVector")
75+
@Description("algo.similarity.asVector - builds a vector of maps containing items and weights")
76+
public SimilarityVectorAggregator asVector() {
77+
return new SimilarityVectorAggregator();
6878
}
6979

7080
@UserFunction("algo.similarity.pearson")
7181
@Description("algo.similarity.pearson([vector1], [vector2]) " +
7282
"given two collection vectors, calculate pearson similarity")
73-
public double pearsonSimilarity(@Name("vector1") List<Number> vector1, @Name("vector2") List<Number> vector2) {
74-
if (vector1.size() != vector2.size() || vector1.size() == 0) {
75-
throw new RuntimeException("Vectors must be non-empty and of the same size");
83+
public double pearsonSimilarity(@Name("vector1") Object rawVector1, @Name("vector2") Object rawVector2, @Name(value = "config", defaultValue = "{}") Map<String, Object> config) {
84+
ProcedureConfiguration configuration = ProcedureConfiguration.create(config);
85+
86+
String listType = configuration.get("vectorType", "numbers");
87+
88+
if (listType.equalsIgnoreCase("maps")) {
89+
List<Map<String, Object>> vector1 = (List<Map<String, Object>>) rawVector1;
90+
List<Map<String, Object>> vector2 = (List<Map<String, Object>>) rawVector2;
91+
92+
LongSet ids = new LongHashSet();
93+
94+
LongDoubleMap v1Mappings = new LongDoubleHashMap();
95+
for (Map<String, Object> entry : vector1) {
96+
Long id = (Long) entry.get(CATEGORY_KEY);
97+
ids.add(id);
98+
v1Mappings.put(id, (Double) entry.get(WEIGHT_KEY));
99+
}
100+
101+
LongDoubleMap v2Mappings = new LongDoubleHashMap();
102+
for (Map<String, Object> entry : vector2) {
103+
Long id = (Long) entry.get(CATEGORY_KEY);
104+
ids.add(id);
105+
v2Mappings.put(id, (Double) entry.get(WEIGHT_KEY));
106+
}
107+
108+
double[] weights1 = new double[ids.size()];
109+
double[] weights2 = new double[ids.size()];
110+
111+
double skipValue = Double.NaN;
112+
int index = 0;
113+
for (long id : ids.toArray()) {
114+
weights1[index] = v1Mappings.getOrDefault(id, skipValue);
115+
weights2[index] = v2Mappings.getOrDefault(id, skipValue);
116+
index++;
117+
}
118+
119+
return Intersections.pearsonSkip(weights1, weights2, ids.size(), skipValue);
120+
} else {
121+
List<Number> vector1 = (List<Number>) rawVector1;
122+
List<Number> vector2 = (List<Number>) rawVector2;
123+
124+
if (vector1.size() != vector2.size() || vector1.size() == 0) {
125+
throw new RuntimeException("Vectors must be non-empty and of the same size");
126+
}
127+
128+
int len = vector1.size();
129+
double[] weights1 = new double[len];
130+
double[] weights2 = new double[len];
131+
132+
for (int i = 0; i < len; i++) {
133+
weights1[i] = vector1.get(i).doubleValue();
134+
weights2[i] = vector2.get(i).doubleValue();
135+
}
136+
return Intersections.pearson(weights1, weights2, len);
76137
}
77138

78-
double vector1Mean = vector1.stream().mapToDouble(Number::doubleValue).average().orElse(1);
79-
double vector2Mean = vector2.stream().mapToDouble(Number::doubleValue).average().orElse(1);
80-
81-
double dotProductMinusMean = 0d;
82-
double xLength = 0d;
83-
double yLength = 0d;
84-
for (int i = 0; i < vector1.size(); i++) {
85-
double weight1 = vector1.get(i).doubleValue();
86-
double weight2 = vector2.get(i).doubleValue();
87-
88-
double vector1Delta = weight1 - vector1Mean;
89-
double vector2Delta = weight2 - vector2Mean;
90-
91-
dotProductMinusMean += (vector1Delta * vector2Delta);
92-
xLength += vector1Delta * vector1Delta;
93-
yLength += vector2Delta * vector2Delta;
94-
}
95-
96-
return dotProductMinusMean / (Math.sqrt(xLength * yLength));
97139
}
98140

99141
@UserFunction("algo.similarity.euclideanDistance")
@@ -104,15 +146,16 @@ public double euclideanDistance(@Name("vector1") List<Number> vector1, @Name("ve
104146
throw new RuntimeException("Vectors must be non-empty and of the same size");
105147
}
106148

107-
double distance = 0.0;
108-
for (int i = 0; i < vector1.size(); i++) {
109-
double sqOfDiff = vector1.get(i).doubleValue() - vector2.get(i).doubleValue();
110-
sqOfDiff *= sqOfDiff;
111-
distance += sqOfDiff;
149+
int len = Math.min(vector1.size(), vector2.size());
150+
double[] weights1 = new double[len];
151+
double[] weights2 = new double[len];
152+
153+
for (int i = 0; i < len; i++) {
154+
weights1[i] = vector1.get(i).doubleValue();
155+
weights2[i] = vector2.get(i).doubleValue();
112156
}
113-
distance = Math.sqrt(distance);
114157

115-
return distance;
158+
return Math.sqrt(Intersections.sumSquareDelta(weights1, weights2, len));
116159
}
117160

118161
@UserFunction("algo.similarity.euclidean")
algo/src/main/java/org/neo4j/graphalgo/similarity/SimilarityVectorAggregator.java

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
package org.neo4j.graphalgo.similarity;
2+
3+
import org.neo4j.graphdb.Node;
4+
import org.neo4j.helpers.collection.MapUtil;
5+
import org.neo4j.procedure.Name;
6+
import org.neo4j.procedure.UserAggregationResult;
7+
import org.neo4j.procedure.UserAggregationUpdate;
8+
9+
import java.util.ArrayList;
10+
import java.util.List;
11+
import java.util.Map;
12+
13+
public class SimilarityVectorAggregator {
14+
private List<Map<String, Object>> vector = new ArrayList<>();
15+
public static String CATEGORY_KEY = "category";
16+
public static String WEIGHT_KEY = "weight";
17+
18+
@UserAggregationUpdate
19+
public void next(
20+
@Name("node") Node node, @Name("weight") double weight) {
21+
vector.add(MapUtil.map(CATEGORY_KEY, node.getId(), WEIGHT_KEY, weight));
22+
}
23+
24+
@UserAggregationResult
25+
public List<Map<String, Object>> result() {
26+
return vector;
27+
}
28+
}

algo/src/main/java/org/neo4j/graphalgo/similarity/WeightedInput.java

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -127,7 +127,6 @@ public SimilarityResult pearson(RleDecoder decoder, double similarityCutoff, Wei
127127

128128
int len = Math.min(thisWeights.length, otherWeights.length);
129129
double pearson = Intersections.pearson(thisWeights, otherWeights, len);
130-
pearson = Double.isNaN(pearson) ? 0 : pearson;
131130

132131
if (similarityCutoff >= 0d && (pearson == 0 || pearson < similarityCutoff)) return null;
133132

@@ -145,7 +144,6 @@ public SimilarityResult pearsonSkip(RleDecoder decoder, double similarityCutoff,
145144

146145
int len = Math.min(thisWeights.length, otherWeights.length);
147146
double pearson = Intersections.pearsonSkip(thisWeights, otherWeights, len, skipValue);
148-
pearson = Double.isNaN(pearson) ? 0 : pearson;
149147

150148
if (similarityCutoff >= 0d && (pearson == 0 || pearson < similarityCutoff)) return null;
151149

algo/src/test/java/org/neo4j/graphalgo/similarity/SimilarityVectorAggregatorTest.java

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
package org.neo4j.graphalgo.similarity;
2+
3+
import org.junit.Test;
4+
import org.neo4j.graphdb.Node;
5+
import org.neo4j.helpers.collection.MapUtil;
6+
7+
import java.util.Arrays;
8+
import java.util.Collections;
9+
import java.util.List;
10+
import java.util.Map;
11+
12+
import static org.hamcrest.Matchers.is;
13+
import static org.junit.Assert.*;
14+
import static org.mockito.Mockito.mock;
15+
import static org.mockito.Mockito.when;
16+
import static org.neo4j.graphalgo.similarity.SimilarityVectorAggregator.CATEGORY_KEY;
17+
import static org.neo4j.graphalgo.similarity.SimilarityVectorAggregator.WEIGHT_KEY;
18+
19+
public class SimilarityVectorAggregatorTest {
20+
21+
@Test
22+
public void singleItem() {
23+
SimilarityVectorAggregator aggregator = new SimilarityVectorAggregator();
24+
25+
Node node = mock(Node.class);
26+
when(node.getId()).thenReturn(1L);
27+
28+
aggregator.next(node, 3.0);
29+
30+
List<Map<String, Object>> expected = Collections.singletonList(
31+
MapUtil.map(CATEGORY_KEY, 1L, WEIGHT_KEY, 3.0)
32+
);
33+
34+
assertThat(aggregator.result(), is(expected));
35+
}
36+
37+
@Test
38+
public void multipleItems() {
39+
SimilarityVectorAggregator aggregator = new SimilarityVectorAggregator();
40+
41+
Node node = mock(Node.class);
42+
when(node.getId()).thenReturn(1L, 2L, 3L);
43+
44+
aggregator.next(node, 3.0);
45+
aggregator.next(node, 2.0);
46+
aggregator.next(node, 1.0);
47+
48+
List<Map<String, Object>> expected = Arrays.asList(
49+
MapUtil.map(CATEGORY_KEY, 1L, WEIGHT_KEY, 3.0),
50+
MapUtil.map(CATEGORY_KEY, 2L, WEIGHT_KEY, 2.0),
51+
MapUtil.map(CATEGORY_KEY, 3L, WEIGHT_KEY, 1.0)
52+
);
53+
54+
assertThat(aggregator.result(), is(expected));
55+
}
56+
57+
}

core/src/main/java/org/neo4j/graphalgo/core/utils/Intersections.java

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -200,7 +200,8 @@ public static double pearson(double[] vector1, double[] vector2, int len) {
200200
yLength += vector2Delta * vector2Delta;
201201
}
202202

203-
return dotProductMinusMean / Math.sqrt(xLength * yLength);
203+
double result = dotProductMinusMean / Math.sqrt(xLength * yLength);
204+
return Double.isNaN(result) ? 0 : result;
204205
}
205206

206207
public static double pearsonSkip(double[] vector1, double[] vector2, int len, double skipValue) {
@@ -246,7 +247,8 @@ public static double pearsonSkip(double[] vector1, double[] vector2, int len, do
246247
yLength += vector2Delta * vector2Delta;
247248
}
248249

249-
return dotProductMinusMean / Math.sqrt(xLength * yLength);
250+
double result = dotProductMinusMean / Math.sqrt(xLength * yLength);
251+
return Double.isNaN(result) ? 0 : result;
250252
}
251253

252254
private static boolean shouldSkip(double weight, double skipValue, boolean skipNan) {

0 commit comments

Comments
 (0)