Skip to content

Commit e135374

Browse files
committed
Add KNN Manhattan and Minkowski distances with tests
1 parent c79034c commit e135374

File tree

2 files changed

+171
-9
lines changed

2 files changed

+171
-9
lines changed

machine_learning/k_nearest_neighbours.py

Lines changed: 96 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,8 @@
66
of the given point. In effect, the label of the given point is decided by a
77
majority vote.
88
9-
This implementation uses the commonly used Euclidean distance metric, but other
10-
distance metrics can also be used.
9+
This implementation uses the Euclidean distance metric by default, and also
10+
supports Manhattan (L1) and Minkowski (Lp) distances.
1111
1212
Reference: https://en.wikipedia.org/wiki/K-nearest_neighbors_algorithm
1313
"""
@@ -16,8 +16,6 @@
1616
from heapq import nsmallest
1717

1818
import numpy as np
19-
from sklearn import datasets
20-
from sklearn.model_selection import train_test_split
2119

2220

2321
class KNN:
def __init__(
    self,
    train_data: np.ndarray[float],
    train_target: np.ndarray[int],
    class_labels: list[str],
    *,
    distance_metric: str = "euclidean",
    p: float = 2.0,
) -> None:
    """
    Create a kNN classifier using the given training data and class labels

    Parameters
    ----------
    train_data : np.ndarray[float]
        Training features.
    train_target : np.ndarray[int]
        Training labels as integer indices.
    class_labels : list[str]
        Mapping from label index to label name.
    distance_metric : {"euclidean", "manhattan", "minkowski"}
        Distance to use for neighbour search. Defaults to "euclidean".
    p : float
        Power parameter for Minkowski distance (Lp norm). Must be >= 1 when
        distance_metric is "minkowski". Defaults to 2.0.
    """
    # Materialise the (point, label) pairs up front: a bare zip() iterator
    # would be consumed by the first classification, leaving later calls
    # with no training data.
    self.data = list(zip(train_data, train_target))
    self.labels = class_labels
    self.distance_metric = distance_metric.lower()
    self.p = float(p)

    supported_metrics = {"euclidean", "manhattan", "minkowski"}
    if self.distance_metric not in supported_metrics:
        msg = (
            "distance_metric must be one of {'euclidean', 'manhattan', 'minkowski'}"
        )
        raise ValueError(msg)
    # p < 1 does not satisfy the triangle inequality, so reject it.
    if self.distance_metric == "minkowski" and self.p < 1:
        msg = "For Minkowski distance, p must be >= 1"
        raise ValueError(msg)
3563

3664
@staticmethod
3765
def _euclidean_distance(a: np.ndarray[float], b: np.ndarray[float]) -> float:
@@ -44,6 +72,30 @@ def _euclidean_distance(a: np.ndarray[float], b: np.ndarray[float]) -> float:
4472
"""
4573
return float(np.linalg.norm(a - b))
4674

75+
@staticmethod
76+
def _manhattan_distance(a: np.ndarray[float], b: np.ndarray[float]) -> float:
77+
"""
78+
Calculate the Manhattan (L1) distance between two points
79+
>>> KNN._manhattan_distance(np.array([0, 0]), np.array([3, 4]))
80+
7.0
81+
>>> KNN._manhattan_distance(np.array([1, 2, 3]), np.array([1, 8, 11]))
82+
14.0
83+
"""
84+
return float(np.linalg.norm(a - b, ord=1))
85+
86+
@staticmethod
87+
def _minkowski_distance(
88+
a: np.ndarray[float], b: np.ndarray[float], p: float
89+
) -> float:
90+
"""
91+
Calculate the Minkowski (Lp) distance between two points
92+
>>> KNN._minkowski_distance(np.array([0, 0]), np.array([3, 4]), 2)
93+
5.0
94+
>>> KNN._minkowski_distance(np.array([0, 0]), np.array([3, 4]), 1)
95+
7.0
96+
"""
97+
return float(np.linalg.norm(a - b, ord=p))
98+
4799
def classify(self, pred_point: np.ndarray[float], k: int = 5) -> str:
48100
"""
49101
Classify a given point using the kNN algorithm
@@ -56,12 +108,42 @@ def classify(self, pred_point: np.ndarray[float], k: int = 5) -> str:
56108
>>> point = np.array([1.2, 1.2])
57109
>>> knn.classify(point)
58110
'A'
111+
>>> # Manhattan distance yields the same class here
112+
>>> knn_l1 = KNN(train_X, train_y, classes, distance_metric='manhattan')
113+
>>> knn_l1.classify(point)
114+
'A'
115+
>>> # Minkowski with p=2 equals Euclidean
116+
>>> knn_lp = KNN(train_X, train_y, classes, distance_metric='minkowski', p=2)
117+
>>> knn_lp.classify(point)
118+
'A'
119+
>>> # Invalid distance metric
120+
>>> try:
121+
... _ = KNN(train_X, train_y, classes, distance_metric='chebyshev')
122+
... except ValueError as e:
123+
... 'distance_metric' in str(e)
124+
True
125+
>>> # Invalid Minkowski power
126+
>>> try:
127+
... _ = KNN(train_X, train_y, classes, distance_metric='minkowski', p=0.5)
128+
... except ValueError as e:
129+
... 'p must be >=' in str(e)
130+
True
59131
"""
132+
# Choose the distance function once
133+
if self.distance_metric == "euclidean":
134+
def dist_fn(a: np.ndarray[float]) -> float:
135+
return self._euclidean_distance(a, pred_point)
136+
elif self.distance_metric == "manhattan":
137+
def dist_fn(a: np.ndarray[float]) -> float:
138+
return self._manhattan_distance(a, pred_point)
139+
else: # minkowski
140+
p = self.p
141+
142+
def dist_fn(a: np.ndarray[float]) -> float:
143+
return self._minkowski_distance(a, pred_point, p)
144+
60145
# Distances of all points from the point to be classified
61-
distances = (
62-
(self._euclidean_distance(data_point[0], pred_point), data_point[1])
63-
for data_point in self.data
64-
)
146+
distances = ((dist_fn(dp), lbl) for dp, lbl in self.data)
65147

66148
# Choosing k points with the shortest distances
67149
votes = (i[1] for i in nsmallest(k, distances))
@@ -76,6 +158,11 @@ def classify(self, pred_point: np.ndarray[float], k: int = 5) -> str:
76158

77159
doctest.testmod()
78160

161+
# Optional demo using scikit-learn's iris dataset. Kept under __main__ to
162+
# avoid making scikit-learn a hard dependency for importing this module.
163+
from sklearn import datasets
164+
from sklearn.model_selection import train_test_split
165+
79166
iris = datasets.load_iris()
80167

81168
X = np.array(iris["data"])
Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,75 @@
1+
import numpy as np
2+
import pytest
3+
4+
from machine_learning.k_nearest_neighbours import KNN
5+
6+
7+
def test_distance_functions():
    """The 3-4-5 right triangle pins down all three static distance helpers."""
    origin = np.array([0, 0])
    corner = np.array([3, 4])
    assert KNN._euclidean_distance(origin, corner) == 5.0
    assert KNN._manhattan_distance(origin, corner) == 7.0
    # Minkowski specialises to the other two metrics at p=2 and p=1.
    assert KNN._minkowski_distance(origin, corner, 2) == 5.0
    assert KNN._minkowski_distance(origin, corner, 1) == 7.0
14+
15+
16+
@pytest.mark.parametrize(
    ("distance_metric", "p"),
    [
        ("euclidean", None),
        ("manhattan", None),
        ("minkowski", 2),  # p=2 -> Euclidean
        ("minkowski", 3),  # another valid p
    ],
)
def test_classify_with_different_metrics(distance_metric: str, p: float | None):
    """Every supported metric should place the query point in class 'A'."""
    features = np.array(
        [[0, 0], [1, 0], [0, 1], [0.5, 0.5], [3, 3], [2, 3], [3, 2]]
    )
    targets = np.array([0, 0, 0, 0, 1, 1, 1])
    names = ["A", "B"]

    # Only forward p when the parametrisation supplies one, so the
    # default-p code path is exercised as well.
    extra: dict[str, object] = {"distance_metric": distance_metric}
    if p is not None:
        extra["p"] = float(p)

    classifier = KNN(features, targets, names, **extra)
    query = np.array([1.2, 1.2])
    # For this dataset/point, the class should be 'A' regardless of metric
    assert classifier.classify(query) == "A"
40+
41+
42+
def test_invalid_distance_metric_raises():
    """An unsupported metric name must be rejected at construction time."""
    features = np.array([[0.0, 0.0]])
    targets = np.array([0])
    with pytest.raises(ValueError):
        KNN(features, targets, ["A"], distance_metric="chebyshev")
48+
49+
50+
def test_invalid_minkowski_p_raises():
    """Minkowski with p < 1 is not a valid metric and must be rejected."""
    features = np.array([[0.0, 0.0]])
    targets = np.array([0])
    with pytest.raises(ValueError):
        KNN(features, targets, ["A"], distance_metric="minkowski", p=0.5)
56+
57+
58+
def test_multiple_classify_calls_with_same_instance():
    """Regression: classify() must be repeatable (zip-iterator exhaustion bug)."""
    classifier = KNN(
        np.array([[0, 0], [1, 1], [2, 2]]),
        np.array([0, 0, 1]),
        ["A", "B"],
    )

    # Ensure we can call classify multiple times (zip exhaustion bug regression)
    first = classifier.classify(np.array([0.1, 0.2]))
    second = classifier.classify(np.array([1.9, 2.0]))
    assert first == "A"
    assert second in {"A", "B"}
70+
71+
72+
if __name__ == "__main__":
    # Allow running this test module directly as a script. pytest.main()
    # RETURNS its exit status rather than raising; the original discarded
    # it (and re-imported pytest under an alias despite the module-level
    # import), so a direct run always exited 0 even when tests failed.
    # Propagate the status so shells and CI see failures.
    raise SystemExit(pytest.main([__file__]))

0 commit comments

Comments
 (0)