diff --git a/machine_learning/apriori_algorithm.py b/machine_learning/apriori_algorithm.py index 09a89ac236bd..e3c098f4a63f 100644 --- a/machine_learning/apriori_algorithm.py +++ b/machine_learning/apriori_algorithm.py @@ -1,14 +1,17 @@ """ -Apriori Algorithm is a Association rule mining technique, also known as market basket -analysis, aims to discover interesting relationships or associations among a set of -items in a transactional or relational database. +Apriori Algorithm — Association Rule Mining Technique -For example, Apriori Algorithm states: "If a customer buys item A and item B, then they -are likely to buy item C." This rule suggests a relationship between items A, B, and C, -indicating that customers who purchased A and B are more likely to also purchase item C. +The Apriori algorithm is a **classic association rule learning method**, also known as +**Market Basket Analysis**, used to discover interesting relationships or associations +among a set of items in a transactional database. -WIKI: https://en.wikipedia.org/wiki/Apriori_algorithm -Examples: https://www.kaggle.com/code/earthian/apriori-association-rules-mining +For example: +"If a customer buys item A and item B, they are likely to buy item C." +This suggests a relationship between items A, B, and C — indicating that customers +who purchased A and B are more likely to also purchase C. + +📘 WIKI: https://en.wikipedia.org/wiki/Apriori_algorithm +📊 Example Notebook: https://www.kaggle.com/code/earthian/apriori-association-rules-mining """ from itertools import combinations @@ -24,76 +27,50 @@ def load_data() -> list[list[str]]: return [["milk"], ["milk", "butter"], ["milk", "bread"], ["milk", "bread", "chips"]] -def prune(itemset: list, candidates: list, length: int) -> list: - """ - Prune candidate itemsets that are not frequent. - The goal of pruning is to filter out candidate itemsets that are not frequent. This - is done by checking if all the (k-1) subsets of a candidate itemset are present in - the frequent itemsets of the previous iteration (valid subsequences of the frequent - itemsets from the previous iteration). - - Prunes candidate itemsets that are not frequent. - - >>> itemset = ['X', 'Y', 'Z'] - >>> candidates = [['X', 'Y'], ['X', 'Z'], ['Y', 'Z']] - >>> prune(itemset, candidates, 2) - [['X', 'Y'], ['X', 'Z'], ['Y', 'Z']] - - >>> itemset = ['1', '2', '3', '4'] - >>> candidates = ['1', '2', '4'] - >>> prune(itemset, candidates, 3) - [] +def get_item_support(data: list[list[str]], candidates: list[list[str]]) -> dict[tuple[str], int]: """ - pruned = [] - for candidate in candidates: - is_subsequence = True - for item in candidate: - if item not in itemset or itemset.count(item) < length - 1: - is_subsequence = False - break - if is_subsequence: - pruned.append(candidate) - return pruned + Compute the support count for each candidate itemset. + Args: + data: A list of transactions (each transaction is a list of items) + candidates: List of candidate itemsets -def apriori(data: list[list[str]], min_support: int) -> list[tuple[list[str], int]]: + Returns: + Dictionary mapping itemsets to their support count. """ - Returns a list of frequent itemsets and their support counts. + support_count = {tuple(sorted(candidate)): 0 for candidate in candidates} + for transaction in data: + for candidate in candidates: + if all(item in transaction for item in candidate): + support_count[tuple(sorted(candidate))] += 1 + return support_count - >>> data = [['A', 'B', 'C'], ['A', 'B'], ['A', 'C'], ['A', 'D'], ['B', 'C']] - >>> apriori(data, 2) - [(['A', 'B'], 1), (['A', 'C'], 2), (['B', 'C'], 2)] - >>> data = [['1', '2', '3'], ['1', '2'], ['1', '3'], ['1', '4'], ['2', '3']] - >>> apriori(data, 3) - [] +def prune_candidates(prev_freq_itemsets: list[list[str]], k: int) -> list[list[str]]: """ - itemset = [list(transaction) for transaction in data] - frequent_itemsets = [] - length = 1 + Generate and prune candidate itemsets of size k from previous frequent itemsets. - while itemset: - # Count itemset support - counts = [0] * len(itemset) - for transaction in data: - for j, candidate in enumerate(itemset): - if all(item in transaction for item in candidate): - counts[j] += 1 - - # Prune infrequent itemsets - itemset = [item for i, item in enumerate(itemset) if counts[i] >= min_support] - - # Append frequent itemsets (as a list to maintain order) - for i, item in enumerate(itemset): - frequent_itemsets.append((sorted(item), counts[i])) - - length += 1 - itemset = prune(itemset, list(combinations(itemset, length)), length) + Args: + prev_freq_itemsets: Frequent itemsets of size (k-1) + k: Desired size of next candidate itemsets - return frequent_itemsets + Returns: + List of pruned candidate itemsets. + """ + candidates = [] + for i in range(len(prev_freq_itemsets)): + for j in range(i + 1, len(prev_freq_itemsets)): + l1, l2 = sorted(prev_freq_itemsets[i]), sorted(prev_freq_itemsets[j]) + if l1[:-1] == l2[:-1]: + candidate = sorted(list(set(l1) | set(l2))) + # Prune candidates whose subsets are not frequent + subsets = list(combinations(candidate, k - 1)) + if all(list(subset) in prev_freq_itemsets for subset in subsets): + candidates.append(candidate) + return candidates -if __name__ == "__main__": +def apriori(data: list[list[str]], min_support: int) -> list[tuple[list[str], int]]: """ Apriori algorithm for finding frequent itemsets. @@ -103,11 +80,46 @@ def apriori(data: list[list[str]], min_support: int) -> list[tuple[list[str], in Returns: A list of frequent itemsets along with their support counts. + + Example: + >>> data = [['A', 'B', 'C'], ['A', 'B'], ['A', 'C'], ['A', 'D'], ['B', 'C']] + >>> apriori(data, 2) + [(['A'], 4), (['B'], 3), (['C'], 3), (['A', 'B'], 2), (['A', 'C'], 2), (['B', 'C'], 2)] """ - import doctest + # Step 1: Find unique items + single_items = sorted({item for transaction in data for item in transaction}) + current_candidates = [[item] for item in single_items] + frequent_itemsets = [] + k = 1 + while current_candidates: + support_count = get_item_support(data, current_candidates) + # Keep itemsets meeting minimum support + current_freq_itemsets = [ + list(itemset) for itemset, count in support_count.items() if count >= min_support + ] + frequent_itemsets.extend( + [(list(itemset), count) for itemset, count in support_count.items() if count >= min_support] + ) + + # Generate next level of candidates + k += 1 + current_candidates = prune_candidates(current_freq_itemsets, k) + + return frequent_itemsets + + +if __name__ == "__main__": + """ + Run Apriori algorithm on the sample dataset with user-defined minimum support. + """ + import doctest doctest.testmod() - # user-defined threshold or minimum support level - frequent_itemsets = apriori(data=load_data(), min_support=2) - print("\n".join(f"{itemset}: {support}" for itemset, support in frequent_itemsets)) + transactions = load_data() + min_support = 2 + results = apriori(transactions, min_support) + + print("Frequent Itemsets and Support Counts:\n") + for itemset, support in results: + print(f"{itemset}: {support}")