Skip to content

Commit 501dad2

Browse files
Enhance kNN with multiple distance metrics
Refactor kNN classifier to support multiple distance metrics including Manhattan and Minkowski. Update distance calculation method and adjust usage in classification.
1 parent e2a78d4 commit 501dad2

File tree

1 file changed

+36
-19
lines changed

1 file changed

+36
-19
lines changed

machine_learning/k_nearest_neighbours.py

Lines changed: 36 additions & 19 deletions
Original file line number | Diff line number | Diff line change
@@ -14,7 +14,6 @@
1414

1515
from collections import Counter
1616
from heapq import nsmallest
17-
1817
import numpy as np
1918
from sklearn import datasets
2019
from sklearn.model_selection import train_test_split
@@ -26,23 +25,36 @@ def __init__(
2625
train_data: np.ndarray[float],
2726
train_target: np.ndarray[int],
2827
class_labels: list[str],
28+
distance_metric: str = "euclidean",
29+
p: int = 2,
2930
) -> None:
3031
"""
31-
Create a kNN classifier using the given training data and class labels
32+
Create a kNN classifier using the given training data and class labels.
33+
34+
Parameters:
35+
-----------
36+
distance_metric : str
37+
Type of distance metric to use ('euclidean', 'manhattan', 'minkowski')
38+
p : int
39+
Power parameter for Minkowski distance (default 2)
3240
"""
33-
self.data = zip(train_data, train_target)
41+
self.data = list(zip(train_data, train_target))
3442
self.labels = class_labels
43+
self.distance_metric = distance_metric
44+
self.p = p
3545

36-
@staticmethod
37-
def _euclidean_distance(a: np.ndarray[float], b: np.ndarray[float]) -> float:
46+
def _calculate_distance(
    self, a: np.ndarray[float], b: np.ndarray[float]
) -> float:
    """
    Calculate the distance between two points using ``self.distance_metric``.

    Supported metrics:
      * ``"euclidean"`` - norm of ``a - b`` as computed by ``np.linalg.norm``
      * ``"manhattan"`` - sum of the absolute coordinate differences
      * ``"minkowski"`` - ``sum(|a - b| ** p) ** (1 / p)`` with ``p = self.p``

    Raises:
        ValueError: if ``self.distance_metric`` is not one of the names above,
            or if the Minkowski metric is selected and ``self.p`` is not
            positive.

    >>> point_a, point_b = np.array([0, 0]), np.array([3, 4])
    >>> train, target = np.array([[0, 0]]), np.array([0])
    >>> KNN(train, target, ["A"])._calculate_distance(point_a, point_b)
    5.0
    >>> KNN(
    ...     train, target, ["A"], distance_metric="manhattan"
    ... )._calculate_distance(point_a, point_b)
    7.0
    >>> KNN(
    ...     train, target, ["A"], distance_metric="minkowski", p=1
    ... )._calculate_distance(point_a, point_b)
    7.0
    >>> KNN(
    ...     train, target, ["A"], distance_metric="minkowski", p=0
    ... )._calculate_distance(point_a, point_b)
    Traceback (most recent call last):
        ...
    ValueError: Minkowski power parameter p must be positive.
    """
    metric = self.distance_metric

    if metric == "euclidean":
        return float(np.linalg.norm(a - b))

    if metric == "manhattan":
        return float(np.sum(np.abs(a - b)))

    if metric == "minkowski":
        # 1 / p is undefined for p == 0 and a negative order does not yield
        # a distance, so reject both instead of failing with a bare
        # ZeroDivisionError or returning a meaningless number.
        if self.p <= 0:
            raise ValueError("Minkowski power parameter p must be positive.")
        return float(np.sum(np.abs(a - b) ** self.p) ** (1 / self.p))

    raise ValueError(
        "Invalid distance metric. Choose 'euclidean', 'manhattan', or 'minkowski'."
    )
4658

4759
def classify(self, pred_point: np.ndarray[float], k: int = 5) -> str:
4860
"""
@@ -57,23 +69,18 @@ def classify(self, pred_point: np.ndarray[float], k: int = 5) -> str:
5769
>>> knn.classify(point)
5870
'A'
5971
"""
60-
# Distances of all points from the point to be classified
6172
distances = (
62-
(self._euclidean_distance(data_point[0], pred_point), data_point[1])
73+
(self._calculate_distance(data_point[0], pred_point), data_point[1])
6374
for data_point in self.data
6475
)
6576

66-
# Choosing k points with the shortest distances
6777
votes = (i[1] for i in nsmallest(k, distances))
68-
69-
# Most commonly occurring class is the one into which the point is classified
7078
result = Counter(votes).most_common(1)[0][0]
7179
return self.labels[result]
7280

7381

7482
if __name__ == "__main__":
7583
import doctest
76-
7784
doctest.testmod()
7885

7986
iris = datasets.load_iris()
@@ -84,5 +91,15 @@ def classify(self, pred_point: np.ndarray[float], k: int = 5) -> str:
8491

8592
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)
8693
iris_point = np.array([4.4, 3.1, 1.3, 1.4])
87-
classifier = KNN(X_train, y_train, iris_classes)
88-
print(classifier.classify(iris_point, k=3))
94+
95+
print("\nUsing Euclidean Distance:")
96+
classifier1 = KNN(X_train, y_train, iris_classes, distance_metric="euclidean")
97+
print(classifier1.classify(iris_point, k=3))
98+
99+
print("\nUsing Manhattan Distance:")
100+
classifier2 = KNN(X_train, y_train, iris_classes, distance_metric="manhattan")
101+
print(classifier2.classify(iris_point, k=3))
102+
103+
print("\nUsing Minkowski Distance (p=3):")
104+
classifier3 = KNN(X_train, y_train, iris_classes, distance_metric="minkowski", p=3)
105+
print(classifier3.classify(iris_point, k=3))

0 commit comments

Comments
 (0)