66of the given point. In effect, the label of the given point is decided by a
77majority vote.
88
9- This implementation uses the commonly used Euclidean distance metric, but other
10- distance metrics can also be used .
9+ This implementation uses the Euclidean distance metric by default, and also
10+ supports Manhattan (L1) and Minkowski (Lp) distances.
1111
1212Reference: https://en.wikipedia.org/wiki/K-nearest_neighbors_algorithm
1313"""
1616from heapq import nsmallest
1717
1818import numpy as np
19- from sklearn import datasets
20- from sklearn .model_selection import train_test_split
2119
2220
class KNN:
    """
    A k-nearest-neighbours classifier.

    Stores the training samples and, when asked to classify a point, polls
    the labels of the k closest training samples and takes a majority vote.
    """

    def __init__(
        self,
        train_data: np.ndarray[float],
        train_target: np.ndarray[int],
        class_labels: list[str],
        *,
        distance_metric: str = "euclidean",
        p: float = 2.0,
    ) -> None:
        """
        Create a kNN classifier from training data and class labels.

        Parameters
        ----------
        train_data : np.ndarray[float]
            Training feature vectors.
        train_target : np.ndarray[int]
            Integer label index for each training sample.
        class_labels : list[str]
            Class names, indexed by label value.
        distance_metric : {"euclidean", "manhattan", "minkowski"}
            Distance used for the neighbour search. Defaults to "euclidean".
        p : float
            Power parameter of the Minkowski (Lp) norm; must be >= 1 when
            distance_metric is "minkowski". Defaults to 2.0.

        Raises
        ------
        ValueError
            If distance_metric is not one of the supported names, or if
            p < 1 while distance_metric is "minkowski".
        """
        # Materialize the pairs into a list: zip() alone is a one-shot
        # iterator and would be exhausted after the first classification.
        self.data = list(zip(train_data, train_target))
        self.labels = class_labels
        self.distance_metric = distance_metric.lower()
        self.p = float(p)

        if self.distance_metric not in ("euclidean", "manhattan", "minkowski"):
            message = (
                "distance_metric must be one of {'euclidean', 'manhattan', 'minkowski'}"
            )
            raise ValueError(message)
        if self.distance_metric == "minkowski" and self.p < 1:
            raise ValueError("For Minkowski distance, p must be >= 1")
3664 @staticmethod
3765 def _euclidean_distance (a : np .ndarray [float ], b : np .ndarray [float ]) -> float :
@@ -44,6 +72,30 @@ def _euclidean_distance(a: np.ndarray[float], b: np.ndarray[float]) -> float:
4472 """
4573 return float (np .linalg .norm (a - b ))
4674
75+ @staticmethod
76+ def _manhattan_distance (a : np .ndarray [float ], b : np .ndarray [float ]) -> float :
77+ """
78+ Calculate the Manhattan (L1) distance between two points
79+ >>> KNN._manhattan_distance(np.array([0, 0]), np.array([3, 4]))
80+ 7.0
81+ >>> KNN._manhattan_distance(np.array([1, 2, 3]), np.array([1, 8, 11]))
82+ 14.0
83+ """
84+ return float (np .linalg .norm (a - b , ord = 1 ))
85+
86+ @staticmethod
87+ def _minkowski_distance (
88+ a : np .ndarray [float ], b : np .ndarray [float ], p : float
89+ ) -> float :
90+ """
91+ Calculate the Minkowski (Lp) distance between two points
92+ >>> KNN._minkowski_distance(np.array([0, 0]), np.array([3, 4]), 2)
93+ 5.0
94+ >>> KNN._minkowski_distance(np.array([0, 0]), np.array([3, 4]), 1)
95+ 7.0
96+ """
97+ return float (np .linalg .norm (a - b , ord = p ))
98+
4799 def classify (self , pred_point : np .ndarray [float ], k : int = 5 ) -> str :
48100 """
49101 Classify a given point using the kNN algorithm
@@ -56,12 +108,42 @@ def classify(self, pred_point: np.ndarray[float], k: int = 5) -> str:
56108 >>> point = np.array([1.2, 1.2])
57109 >>> knn.classify(point)
58110 'A'
111+ >>> # Manhattan distance yields the same class here
112+ >>> knn_l1 = KNN(train_X, train_y, classes, distance_metric='manhattan')
113+ >>> knn_l1.classify(point)
114+ 'A'
115+ >>> # Minkowski with p=2 equals Euclidean
116+ >>> knn_lp = KNN(train_X, train_y, classes, distance_metric='minkowski', p=2)
117+ >>> knn_lp.classify(point)
118+ 'A'
119+ >>> # Invalid distance metric
120+ >>> try:
121+ ... _ = KNN(train_X, train_y, classes, distance_metric='chebyshev')
122+ ... except ValueError as e:
123+ ... 'distance_metric' in str(e)
124+ True
125+ >>> # Invalid Minkowski power
126+ >>> try:
127+ ... _ = KNN(train_X, train_y, classes, distance_metric='minkowski', p=0.5)
128+ ... except ValueError as e:
129+ ... 'p must be >=' in str(e)
130+ True
59131 """
132+ # Choose the distance function once
133+ if self .distance_metric == "euclidean" :
134+ def dist_fn (a : np .ndarray [float ]) -> float :
135+ return self ._euclidean_distance (a , pred_point )
136+ elif self .distance_metric == "manhattan" :
137+ def dist_fn (a : np .ndarray [float ]) -> float :
138+ return self ._manhattan_distance (a , pred_point )
139+ else : # minkowski
140+ p = self .p
141+
142+ def dist_fn (a : np .ndarray [float ]) -> float :
143+ return self ._minkowski_distance (a , pred_point , p )
144+
60145 # Distances of all points from the point to be classified
61- distances = (
62- (self ._euclidean_distance (data_point [0 ], pred_point ), data_point [1 ])
63- for data_point in self .data
64- )
146+ distances = ((dist_fn (dp ), lbl ) for dp , lbl in self .data )
65147
66148 # Choosing k points with the shortest distances
67149 votes = (i [1 ] for i in nsmallest (k , distances ))
@@ -76,6 +158,11 @@ def classify(self, pred_point: np.ndarray[float], k: int = 5) -> str:
76158
77159 doctest .testmod ()
78160
161+ # Optional demo using scikit-learn's iris dataset. Kept under __main__ to
162+ # avoid making scikit-learn a hard dependency for importing this module.
163+ from sklearn import datasets
164+ from sklearn .model_selection import train_test_split
165+
79166 iris = datasets .load_iris ()
80167
81168 X = np .array (iris ["data" ])
0 commit comments