Skip to content

Commit a97dd89

Browse files
Enhance documentation and support count logic in Apriori
Refactor the Apriori algorithm implementation with improved documentation and support count calculation.
1 parent c79034c commit a97dd89

File tree

1 file changed

+83
-71
lines changed

1 file changed

+83
-71
lines changed
Lines changed: 83 additions & 71 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,17 @@
11
"""
2-
Apriori Algorithm is an association rule mining technique, also known as market basket
3-
analysis, aims to discover interesting relationships or associations among a set of
4-
items in a transactional or relational database.
2+
Apriori Algorithm — Association Rule Mining Technique
53
6-
For example, Apriori Algorithm states: "If a customer buys item A and item B, then they
7-
are likely to buy item C." This rule suggests a relationship between items A, B, and C,
8-
indicating that customers who purchased A and B are more likely to also purchase item C.
4+
The Apriori algorithm is a **classic association rule learning method**, also known as
5+
**Market Basket Analysis**, used to discover interesting relationships or associations
6+
among a set of items in a transactional database.
97
10-
WIKI: https://en.wikipedia.org/wiki/Apriori_algorithm
11-
Examples: https://www.kaggle.com/code/earthian/apriori-association-rules-mining
8+
For example:
9+
"If a customer buys item A and item B, they are likely to buy item C."
10+
This suggests a relationship between items A, B, and C — indicating that customers
11+
who purchased A and B are more likely to also purchase C.
12+
13+
📘 WIKI: https://en.wikipedia.org/wiki/Apriori_algorithm
14+
📊 Example Notebook: https://www.kaggle.com/code/earthian/apriori-association-rules-mining
1215
"""
1316

1417
from itertools import combinations
@@ -24,76 +27,50 @@ def load_data() -> list[list[str]]:
2427
return [["milk"], ["milk", "butter"], ["milk", "bread"], ["milk", "bread", "chips"]]
2528

2629

27-
def prune(itemset: list, candidates: list, length: int) -> list:
28-
"""
29-
Prune candidate itemsets that are not frequent.
30-
The goal of pruning is to filter out candidate itemsets that are not frequent. This
31-
is done by checking if all the (k-1) subsets of a candidate itemset are present in
32-
the frequent itemsets of the previous iteration (valid subsequences of the frequent
33-
itemsets from the previous iteration).
34-
35-
Prunes candidate itemsets that are not frequent.
36-
37-
>>> itemset = ['X', 'Y', 'Z']
38-
>>> candidates = [['X', 'Y'], ['X', 'Z'], ['Y', 'Z']]
39-
>>> prune(itemset, candidates, 2)
40-
[['X', 'Y'], ['X', 'Z'], ['Y', 'Z']]
41-
42-
>>> itemset = ['1', '2', '3', '4']
43-
>>> candidates = ['1', '2', '4']
44-
>>> prune(itemset, candidates, 3)
45-
[]
30+
def get_item_support(data: list[list[str]], candidates: list[list[str]]) -> dict[tuple[str], int]:
4631
"""
47-
pruned = []
48-
for candidate in candidates:
49-
is_subsequence = True
50-
for item in candidate:
51-
if item not in itemset or itemset.count(item) < length - 1:
52-
is_subsequence = False
53-
break
54-
if is_subsequence:
55-
pruned.append(candidate)
56-
return pruned
32+
Compute the support count for each candidate itemset.
5733
34+
Args:
35+
data: A list of transactions (each transaction is a list of items)
36+
candidates: List of candidate itemsets
5837
59-
def apriori(data: list[list[str]], min_support: int) -> list[tuple[list[str], int]]:
38+
Returns:
39+
Dictionary mapping itemsets to their support count.
6040
"""
61-
Returns a list of frequent itemsets and their support counts.
41+
support_count = {tuple(sorted(candidate)): 0 for candidate in candidates}
42+
for transaction in data:
43+
for candidate in candidates:
44+
if all(item in transaction for item in candidate):
45+
support_count[tuple(sorted(candidate))] += 1
46+
return support_count
def prune_candidates(
    prev_freq_itemsets: list[list[str]], k: int
) -> list[list[str]]:
    """
    Generate and prune candidate itemsets of size k from previous frequent itemsets.

    Two frequent (k-1)-itemsets are joined when they share their first
    k-2 items; a joined candidate is kept only if every one of its
    (k-1)-sized subsets is itself frequent (the Apriori property).

    Args:
        prev_freq_itemsets: Frequent itemsets of size (k-1).
        k: Desired size of the next candidate itemsets.

    Returns:
        List of pruned candidate itemsets, each sorted.

    >>> prune_candidates([["A", "B"], ["A", "C"], ["B", "C"]], 3)
    [['A', 'B', 'C']]
    """
    # Precompute a set of tuples for O(1) subset lookups instead of an
    # O(n) list scan for every subset of every candidate.
    freq_set = {tuple(itemset) for itemset in prev_freq_itemsets}
    candidates = []
    for i in range(len(prev_freq_itemsets)):
        for j in range(i + 1, len(prev_freq_itemsets)):
            l1, l2 = sorted(prev_freq_itemsets[i]), sorted(prev_freq_itemsets[j])
            # Join step: merge itemsets sharing the first k-2 items.
            if l1[:-1] == l2[:-1]:
                candidate = sorted(set(l1) | set(l2))
                # Prune step: drop candidates with any infrequent subset.
                subsets = combinations(candidate, k - 1)
                if all(subset in freq_set for subset in subsets):
                    candidates.append(candidate)
    return candidates
def apriori(data: list[list[str]], min_support: int) -> list[tuple[list[str], int]]:
    """
    Apriori algorithm for finding frequent itemsets.

    Args:
        data: A list of transactions (each transaction is a list of items).
        min_support: Minimum support count for an itemset to be frequent.

    Returns:
        A list of frequent itemsets along with their support counts.

    Example:
        >>> data = [['A', 'B', 'C'], ['A', 'B'], ['A', 'C'], ['A', 'D'], ['B', 'C']]
        >>> apriori(data, 2)
        [(['A'], 4), (['B'], 3), (['C'], 3), (['A', 'B'], 2), (['A', 'C'], 2), (['B', 'C'], 2)]
    """
    # Step 1: find the unique single items; they seed the level-1 candidates.
    single_items = sorted({item for transaction in data for item in transaction})
    current_candidates = [[item] for item in single_items]
    frequent_itemsets: list[tuple[list[str], int]] = []

    k = 1
    while current_candidates:
        support_count = get_item_support(data, current_candidates)
        # Filter by minimum support ONCE (the original scanned the same
        # dict twice: once for itemsets, once for (itemset, count) pairs).
        frequent_with_counts = [
            (list(itemset), count)
            for itemset, count in support_count.items()
            if count >= min_support
        ]
        frequent_itemsets.extend(frequent_with_counts)
        current_freq_itemsets = [itemset for itemset, _ in frequent_with_counts]

        # Generate the next level of candidates from the frequent k-itemsets.
        k += 1
        current_candidates = prune_candidates(current_freq_itemsets, k)

    return frequent_itemsets
if __name__ == "__main__":
    # Verify the module's doctests, then run the Apriori algorithm on the
    # sample dataset with a user-defined minimum support threshold.
    import doctest

    doctest.testmod()

    transactions = load_data()
    minimum_support = 2
    results = apriori(transactions, minimum_support)

    print("Frequent Itemsets and Support Counts:\n")
    for itemset, support in results:
        print(f"{itemset}: {support}")

0 commit comments

Comments
 (0)