154 changes: 83 additions & 71 deletions machine_learning/apriori_algorithm.py
@@ -1,14 +1,17 @@
"""
Apriori Algorithm is a Association rule mining technique, also known as market basket
analysis, aims to discover interesting relationships or associations among a set of
items in a transactional or relational database.
Apriori Algorithm — Association Rule Mining Technique

For example, Apriori Algorithm states: "If a customer buys item A and item B, then they
are likely to buy item C." This rule suggests a relationship between items A, B, and C,
indicating that customers who purchased A and B are more likely to also purchase item C.
The Apriori algorithm is a **classic association rule learning method**, also known as
**Market Basket Analysis**, used to discover interesting relationships or associations
among a set of items in a transactional database.

WIKI: https://en.wikipedia.org/wiki/Apriori_algorithm
Examples: https://www.kaggle.com/code/earthian/apriori-association-rules-mining
For example:
"If a customer buys item A and item B, they are likely to buy item C."
This suggests a relationship between items A, B, and C — indicating that customers
who purchased A and B are more likely to also purchase C.

📘 WIKI: https://en.wikipedia.org/wiki/Apriori_algorithm

Check failure (GitHub Actions / ruff): machine_learning/apriori_algorithm.py:13:56: W291 Trailing whitespace
📊 Example Notebook: https://www.kaggle.com/code/earthian/apriori-association-rules-mining
"""

from itertools import combinations
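The rule in the docstring above ("customers who buy A and B are likely to also buy C") is usually scored by confidence, the ratio of the support of the combined itemset to the support of the antecedent. As a rough, illustrative sketch only (not part of this file or this diff; the helper name `rule_confidence` is hypothetical), confidence can be read off the `(itemset, support)` pairs that `apriori` below returns:

```python
# Illustrative sketch, not part of this diff: scoring a rule A -> B from
# the (itemset, support) pairs produced by apriori(). The helper name
# rule_confidence is hypothetical.
def rule_confidence(frequent_itemsets, antecedent, consequent):
    support = {tuple(itemset): count for itemset, count in frequent_itemsets}
    combined = support.get(tuple(sorted(antecedent + consequent)), 0)
    base = support.get(tuple(sorted(antecedent)), 0)
    return combined / base if base else 0.0

# With frequent_itemsets = [(['A'], 4), (['A', 'B'], 2)], the rule A -> B
# has confidence 2 / 4 = 0.5.
```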
@@ -24,76 +27,50 @@
return [["milk"], ["milk", "butter"], ["milk", "bread"], ["milk", "bread", "chips"]]


def prune(itemset: list, candidates: list, length: int) -> list:
"""
Prune candidate itemsets that are not frequent.
The goal of pruning is to filter out candidate itemsets that are not frequent. This
is done by checking if all the (k-1) subsets of a candidate itemset are present in
the frequent itemsets of the previous iteration (valid subsequences of the frequent
itemsets from the previous iteration).

Prunes candidate itemsets that are not frequent.

>>> itemset = ['X', 'Y', 'Z']
>>> candidates = [['X', 'Y'], ['X', 'Z'], ['Y', 'Z']]
>>> prune(itemset, candidates, 2)
[['X', 'Y'], ['X', 'Z'], ['Y', 'Z']]

>>> itemset = ['1', '2', '3', '4']
>>> candidates = ['1', '2', '4']
>>> prune(itemset, candidates, 3)
[]
def get_item_support(data: list[list[str]], candidates: list[list[str]]) -> dict[tuple[str], int]:

Check failure (GitHub Actions / ruff): machine_learning/apriori_algorithm.py:30:89: E501 Line too long (98 > 88)
"""
pruned = []
for candidate in candidates:
is_subsequence = True
for item in candidate:
if item not in itemset or itemset.count(item) < length - 1:
is_subsequence = False
break
if is_subsequence:
pruned.append(candidate)
return pruned
Compute the support count for each candidate itemset.

Args:
data: A list of transactions (each transaction is a list of items)
candidates: List of candidate itemsets

def apriori(data: list[list[str]], min_support: int) -> list[tuple[list[str], int]]:
Returns:
Dictionary mapping itemsets to their support count.
"""
Returns a list of frequent itemsets and their support counts.
support_count = {tuple(sorted(candidate)): 0 for candidate in candidates}
for transaction in data:
for candidate in candidates:
if all(item in transaction for item in candidate):
support_count[tuple(sorted(candidate))] += 1
return support_count
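A small usage sketch of `get_item_support`, hand-checked against the behaviour shown in this diff (a candidate counts toward a transaction only if every one of its items appears there):

```python
# Usage sketch, assuming the get_item_support defined above:
transactions = [["milk"], ["milk", "butter"], ["milk", "bread"]]
candidates = [["milk"], ["bread"], ["milk", "bread"]]
print(get_item_support(transactions, candidates))
# {('milk',): 3, ('bread',): 1, ('bread', 'milk'): 1}
```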

>>> data = [['A', 'B', 'C'], ['A', 'B'], ['A', 'C'], ['A', 'D'], ['B', 'C']]
>>> apriori(data, 2)
[(['A', 'B'], 1), (['A', 'C'], 2), (['B', 'C'], 2)]

>>> data = [['1', '2', '3'], ['1', '2'], ['1', '3'], ['1', '4'], ['2', '3']]
>>> apriori(data, 3)
[]
def prune_candidates(prev_freq_itemsets: list[list[str]], k: int) -> list[list[str]]:
"""
itemset = [list(transaction) for transaction in data]
frequent_itemsets = []
length = 1
Generate and prune candidate itemsets of size k from previous frequent itemsets.

while itemset:
# Count itemset support
counts = [0] * len(itemset)
for transaction in data:
for j, candidate in enumerate(itemset):
if all(item in transaction for item in candidate):
counts[j] += 1

# Prune infrequent itemsets
itemset = [item for i, item in enumerate(itemset) if counts[i] >= min_support]

# Append frequent itemsets (as a list to maintain order)
for i, item in enumerate(itemset):
frequent_itemsets.append((sorted(item), counts[i]))

length += 1
itemset = prune(itemset, list(combinations(itemset, length)), length)
Args:
prev_freq_itemsets: Frequent itemsets of size (k-1)
k: Desired size of next candidate itemsets

return frequent_itemsets
Returns:
List of pruned candidate itemsets.
"""
candidates = []
for i in range(len(prev_freq_itemsets)):
for j in range(i + 1, len(prev_freq_itemsets)):
l1, l2 = sorted(prev_freq_itemsets[i]), sorted(prev_freq_itemsets[j])
if l1[:-1] == l2[:-1]:
candidate = sorted(list(set(l1) | set(l2)))

Check failure (GitHub Actions / ruff): machine_learning/apriori_algorithm.py:65:29: C414 Unnecessary `list()` call within `sorted()`
# Prune candidates whose subsets are not frequent
subsets = list(combinations(candidate, k - 1))
if all(list(subset) in prev_freq_itemsets for subset in subsets):
candidates.append(candidate)
return candidates
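A usage sketch of `prune_candidates`: two frequent (k-1)-itemsets are joined only when they share their first k-2 items, and the joined candidate is kept only if every (k-1)-subset is itself frequent. Hand-checked against the logic in this diff:

```python
# Usage sketch, assuming the prune_candidates defined above:
prev = [["A", "B"], ["A", "C"], ["B", "C"]]
print(prune_candidates(prev, 3))
# [['A', 'B', 'C']]  -- every 2-item subset of {'A', 'B', 'C'} is frequent,
# and no other pair of itemsets shares the same one-item prefix.
```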


if __name__ == "__main__":
def apriori(data: list[list[str]], min_support: int) -> list[tuple[list[str], int]]:
"""
Apriori algorithm for finding frequent itemsets.

@@ -103,11 +80,46 @@

Returns:
A list of frequent itemsets along with their support counts.

Example:
>>> data = [['A', 'B', 'C'], ['A', 'B'], ['A', 'C'], ['A', 'D'], ['B', 'C']]
>>> apriori(data, 2)
[(['A'], 4), (['B'], 3), (['C'], 3), (['A', 'B'], 2), (['A', 'C'], 2), (['B', 'C'], 2)]

Check failure (GitHub Actions / ruff): machine_learning/apriori_algorithm.py:87:89: E501 Line too long (91 > 88)
"""
import doctest
# Step 1: Find unique items
single_items = sorted({item for transaction in data for item in transaction})
current_candidates = [[item] for item in single_items]
frequent_itemsets = []

k = 1
while current_candidates:
support_count = get_item_support(data, current_candidates)
# Keep itemsets meeting minimum support
current_freq_itemsets = [
list(itemset) for itemset, count in support_count.items() if count >= min_support

Check failure (GitHub Actions / ruff): machine_learning/apriori_algorithm.py:99:89: E501 Line too long (93 > 88)
]
frequent_itemsets.extend(
[(list(itemset), count) for itemset, count in support_count.items() if count >= min_support]

Check failure (GitHub Actions / ruff): machine_learning/apriori_algorithm.py:102:89: E501 Line too long (104 > 88)
)

# Generate next level of candidates
k += 1
current_candidates = prune_candidates(current_freq_itemsets, k)

return frequent_itemsets
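A hand-worked trace on the transactions returned by `load_data`, assuming the `apriori` loop shown in this diff:

```python
# Hand-worked sketch, assuming the apriori() defined above:
sample = [["milk"], ["milk", "butter"], ["milk", "bread"],
          ["milk", "bread", "chips"]]
print(apriori(sample, min_support=2))
# k=1: counts are bread=2, butter=1, chips=1, milk=4,
#      so ['bread'] and ['milk'] survive.
# k=2: the only candidate is ['bread', 'milk'], with support 2.
# k=3: no further candidates can be generated, so the loop stops.
# [(['bread'], 2), (['milk'], 4), (['bread', 'milk'], 2)]
```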


if __name__ == "__main__":
"""
Run Apriori algorithm on the sample dataset with user-defined minimum support.
"""
import doctest
doctest.testmod()

# user-defined threshold or minimum support level
frequent_itemsets = apriori(data=load_data(), min_support=2)
print("\n".join(f"{itemset}: {support}" for itemset, support in frequent_itemsets))
transactions = load_data()
min_support = 2
results = apriori(transactions, min_support)

print("Frequent Itemsets and Support Counts:\n")
for itemset, support in results:
print(f"{itemset}: {support}")