Skip to content

Commit 0a2ed73

Browse files
committed
refactor(strings): Improve KMP implementation and tests
1 parent e2a78d4 commit 0a2ed73

File tree

1 file changed

+105
-84
lines changed

1 file changed

+105
-84
lines changed

strings/knuth_morris_pratt.py

Lines changed: 105 additions & 84 deletions
Original file line numberDiff line numberDiff line change
@@ -1,101 +1,122 @@
1+
"""
2+
Implementation of the Knuth-Morris-Pratt (KMP) string searching algorithm.
3+
The KMP algorithm searches for all occurrences of a "pattern" within a main "text"
4+
by employing the observation that when a mismatch occurs, the pattern itself
5+
embodies sufficient information to determine where the next match could begin,
6+
thus bypassing re-examination of previously matched characters.
7+
8+
This results in an optimal time complexity of O(n + m), where n is the length
9+
of the text and m is the length of the pattern.
10+
11+
Source: https://en.wikipedia.org/wiki/Knuth–Morris–Pratt_algorithm
12+
"""
113
from __future__ import annotations
214

315

4-
def knuth_morris_pratt(text: str, pattern: str) -> int:
16+
def _compute_lps_array(pattern: str) -> list[int]:
517
"""
6-
The Knuth-Morris-Pratt Algorithm for finding a pattern within a piece of text
7-
with complexity O(n + m)
8-
9-
1) Preprocess pattern to identify any suffixes that are identical to prefixes
10-
11-
This tells us where to continue from if we get a mismatch between a character
12-
in our pattern and the text.
18+
Computes the Longest Proper Prefix Suffix (LPS) array for the KMP algorithm.
19+
The LPS array for a pattern of length m is an array lps of size m where lps[i]
20+
is the length of the longest proper prefix of pattern[0...i] that is also a
21+
suffix of pattern[0...i].
22+
23+
A "proper prefix" is a prefix of the string, but not the whole string.
24+
A "proper suffix" is a suffix of the string, but not the whole string.
25+
26+
:param pattern: The pattern string to compute the LPS array for.
27+
:return: The LPS array, which is used to guide the search.
28+
29+
>>> _compute_lps_array("aabaabaaa")
30+
[0, 1, 0, 1, 2, 3, 4, 5, 2]
31+
>>> _compute_lps_array("ababaca")
32+
[0, 0, 1, 2, 3, 0, 1]
33+
>>> _compute_lps_array("AAAA")
34+
[0, 1, 2, 3]
35+
>>> _compute_lps_array("abcde")
36+
[0, 0, 0, 0, 0]
37+
"""
38+
m = len(pattern)
39+
lps = [0] * m
40+
length = 0 # Length of the previous longest prefix suffix
41+
i = 1
42+
43+
while i < m:
44+
if pattern[i] == pattern[length]:
45+
length += 1
46+
lps[i] = length
47+
i += 1
48+
else:
49+
if length != 0:
50+
length = lps[length - 1]
51+
else:
52+
lps[i] = 0
53+
i += 1
54+
return lps
1355

14-
2) Step through the text one character at a time and compare it to a character in
15-
the pattern updating our location within the pattern if necessary
1656

17-
>>> kmp = "knuth_morris_pratt"
18-
>>> all(
19-
... knuth_morris_pratt(kmp, s) == kmp.find(s)
20-
... for s in ("kn", "h_m", "rr", "tt", "not there")
21-
... )
22-
True
57+
def knuth_morris_pratt_search(text: str, pattern: str) -> list[int]:
58+
"""
59+
Finds all occurrences of a pattern in a text using the KMP algorithm.
60+
61+
:param text: The text to search in.
62+
:param pattern: The pattern to search for.
63+
:return: A list of starting indices of all occurrences of the pattern.
64+
Returns an empty list if the pattern is not found or is empty.
65+
66+
>>> # Test cases from the original file
67+
>>> knuth_morris_pratt_search("alskfjaldsabc1abc1abc12k23adsfabcabc", "abc1abc12")
68+
[10]
69+
>>> knuth_morris_pratt_search("alskfjaldsk23adsfabcabc", "abc1abc12")
70+
[]
71+
>>> knuth_morris_pratt_search("ABABZABABYABABX", "ABABX")
72+
[10]
73+
>>> knuth_morris_pratt_search("ABAAAAAB", "AAAB")
74+
[4]
75+
>>> knuth_morris_pratt_search("abcxabcdabxabcdabcdabcy", "abcdabcy")
76+
[15]
77+
>>> # More comprehensive test cases
78+
>>> knuth_morris_pratt_search("AABAACAADAABAABA", "AABA")
79+
[0, 9, 12]
80+
>>> knuth_morris_pratt_search("knuth_morris_pratt", "kn")
81+
[0]
82+
>>> knuth_morris_pratt_search("knuth_morris_pratt", "h_m")
83+
[4]
84+
>>> knuth_morris_pratt_search("knuth_morris_pratt", "rr")
85+
[12]
86+
>>> knuth_morris_pratt_search("knuth_morris_pratt", "tt")
87+
[16]
88+
>>> knuth_morris_pratt_search("knuth_morris_pratt", "not there")
89+
[]
90+
>>> knuth_morris_pratt_search("test", "")
91+
[]
2392
"""
93+
n = len(text)
94+
m = len(pattern)
95+
if m == 0:
96+
return []
2497

25-
# 1) Construct the failure array
26-
failure = get_failure_array(pattern)
98+
lps = _compute_lps_array(pattern)
99+
found_indices = []
100+
i = 0 # index for text
101+
j = 0 # index for pattern
27102

28-
# 2) Step through text searching for pattern
29-
i, j = 0, 0 # index into text, pattern
30-
while i < len(text):
103+
while i < n:
31104
if pattern[j] == text[i]:
32-
if j == (len(pattern) - 1):
33-
return i - j
105+
i += 1
34106
j += 1
35107

36-
# if this is a prefix in our pattern
37-
# just go back far enough to continue
38-
elif j > 0:
39-
j = failure[j - 1]
40-
continue
41-
i += 1
42-
return -1
43-
44-
45-
def get_failure_array(pattern: str) -> list[int]:
46-
"""
47-
Calculates the new index we should go to if we fail a comparison
48-
:param pattern:
49-
:return:
50-
"""
51-
failure = [0]
52-
i = 0
53-
j = 1
54-
while j < len(pattern):
55-
if pattern[i] == pattern[j]:
56-
i += 1
57-
elif i > 0:
58-
i = failure[i - 1]
59-
continue
60-
j += 1
61-
failure.append(i)
62-
return failure
108+
if j == m:
109+
found_indices.append(i - j)
110+
j = lps[j - 1]
111+
elif i < n and pattern[j] != text[i]:
112+
if j != 0:
113+
j = lps[j - 1]
114+
else:
115+
i += 1
116+
return found_indices
63117

64118

65119
if __name__ == "__main__":
66120
import doctest
67121

68-
doctest.testmod()
69-
70-
# Test 1)
71-
pattern = "abc1abc12"
72-
text1 = "alskfjaldsabc1abc1abc12k23adsfabcabc"
73-
text2 = "alskfjaldsk23adsfabcabc"
74-
assert knuth_morris_pratt(text1, pattern)
75-
assert knuth_morris_pratt(text2, pattern)
76-
77-
# Test 2)
78-
pattern = "ABABX"
79-
text = "ABABZABABYABABX"
80-
assert knuth_morris_pratt(text, pattern)
81-
82-
# Test 3)
83-
pattern = "AAAB"
84-
text = "ABAAAAAB"
85-
assert knuth_morris_pratt(text, pattern)
86-
87-
# Test 4)
88-
pattern = "abcdabcy"
89-
text = "abcxabcdabxabcdabcdabcy"
90-
assert knuth_morris_pratt(text, pattern)
91-
92-
# Test 5) -> Doctests
93-
kmp = "knuth_morris_pratt"
94-
assert all(
95-
knuth_morris_pratt(kmp, s) == kmp.find(s)
96-
for s in ("kn", "h_m", "rr", "tt", "not there")
97-
)
98-
99-
# Test 6)
100-
pattern = "aabaabaaa"
101-
assert get_failure_array(pattern) == [0, 1, 0, 1, 2, 3, 4, 5, 2]
122+
doctest.testmod()

0 commit comments

Comments
 (0)