#!/usr/bin/env python
"""A simple implementation of the Apriori algorithm in Python."""
import sys
import csv
import argparse
import json
import os
from collections import namedtuple
from itertools import combinations
from itertools import chain
# Meta information.
__version__ = '1.1.1'  # Shown by the -v/--version command-line flag.
__author__ = 'Yu Mochizuki'
__author_email__ = 'ymoch.dev@gmail.com'
# Data structures.
class TransactionManager(object):
    """
    Transaction manager.

    Keeps, for every item, the set of indexes of the transactions that
    contain it, so that the support of an item set can be computed by
    intersecting those index sets.
    """

    def __init__(self, transactions):
        """
        Initialize.

        Arguments:
            transactions -- A transaction iterable object
                            (eg. [['A', 'B'], ['B', 'C']]).
        """
        self.__num_transaction = 0
        self.__items = []
        self.__transaction_index_map = {}

        for transaction in transactions:
            self.add_transaction(transaction)

    def add_transaction(self, transaction):
        """
        Add a transaction.

        Arguments:
            transaction -- A transaction as an iterable object (eg. ['A', 'B']).
        """
        for item in transaction:
            if item not in self.__transaction_index_map:
                # First occurrence of this item: register it.
                self.__items.append(item)
                self.__transaction_index_map[item] = set()
            # The current transaction count doubles as the index of the
            # transaction being added.
            self.__transaction_index_map[item].add(self.__num_transaction)
        self.__num_transaction += 1

    def calc_support(self, items):
        """
        Returns a support for items.

        The support is the fraction of the transactions that contain
        every one of the given items.

        Arguments:
            items -- Items as an iterable object (eg. ['A', 'B']).
        """
        # Empty items is supported by all transactions.
        if not items:
            return 1.0

        # Empty transactions supports no items.
        if not self.num_transaction:
            return 0.0

        # Create the transaction index intersection.
        sum_indexes = None
        for item in items:
            indexes = self.__transaction_index_map.get(item)
            if indexes is None:
                # No support for any set that contains a not existing item.
                return 0.0

            if sum_indexes is None:
                # Assign the indexes on the first time.
                sum_indexes = indexes
            else:
                # Calculate the intersection on not the first time.
                # intersection() builds a new set, so the stored index
                # sets are never modified.
                sum_indexes = sum_indexes.intersection(indexes)

        # Calculate and return the support.
        return float(len(sum_indexes)) / self.__num_transaction

    def initial_candidates(self):
        """
        Returns the initial candidates (one single-item frozenset per item).
        """
        return [frozenset([item]) for item in self.items]

    @property
    def num_transaction(self):
        """
        Returns the number of transactions.
        """
        return self.__num_transaction

    @property
    def items(self):
        """
        Returns the sorted list of items that the transactions consist of.
        """
        return sorted(self.__items)

    @staticmethod
    def create(transactions):
        """
        Create the TransactionManager with a transaction instance.
        If the given instance is a TransactionManager, this returns itself.
        """
        if isinstance(transactions, TransactionManager):
            return transactions
        return TransactionManager(transactions)
# Record types. They are namedtuples, so they carry class-style names;
# pylint's invalid-name check (C0103) is silenced for each of them.

# An item set together with its support.
SupportRecord = namedtuple(  # pylint: disable=C0103
    'SupportRecord', 'items support')

# A support record extended with the ordered statistics derived from it.
RelationRecord = namedtuple(  # pylint: disable=C0103
    'RelationRecord', list(SupportRecord._fields) + ['ordered_statistics'])

# One split of an item set into base items and added items, with the
# confidence and lift calculated for that split.
OrderedStatistic = namedtuple(  # pylint: disable=C0103
    'OrderedStatistic', 'items_base items_add confidence lift')
# Inner functions.
def create_next_candidates(prev_candidates, length):
    """
    Returns the apriori candidates as a list.

    Arguments:
        prev_candidates -- Previous candidates as a list.
        length -- The lengths of the next candidates.
    """
    # Normalize the previous candidates into a set of frozensets so that
    # the subset checks below are constant-time lookups.
    prev_set = set(frozenset(candidate) for candidate in prev_candidates)

    # Solve the items.
    items = sorted(set(item for candidate in prev_set for item in candidate))

    # Create the temporary candidates. These will be filtered below.
    tmp_next_candidates = (frozenset(x) for x in combinations(items, length))

    # Return all the candidates if the length of the next candidates is 2
    # because their subsets are the same as items.
    if length < 3:
        return list(tmp_next_candidates)

    # Filter candidates that all of their subsets are
    # in the previous candidates.
    return [
        candidate for candidate in tmp_next_candidates
        if all(
            frozenset(subset) in prev_set
            for subset in combinations(candidate, length - 1))
    ]
def gen_support_records(transaction_manager, min_support, **kwargs):
    """
    Returns a generator of support records with given transactions.

    Candidates are evaluated level by level (item sets of length 1, then
    2, ...). Only the item sets that reach the minimum support are
    yielded, and only those are used to build the next level.

    Arguments:
        transaction_manager -- Transactions as a TransactionManager instance.
        min_support -- A minimum support (float).

    Keyword arguments:
        max_length -- The maximum length of relations (integer).
    """
    # Parse arguments.
    max_length = kwargs.get('max_length')

    # For testing.
    _create_next_candidates = kwargs.get(
        '_create_next_candidates', create_next_candidates)

    # Process.
    candidates = transaction_manager.initial_candidates()
    length = 1
    while candidates:
        relations = set()
        for relation_candidate in candidates:
            support = transaction_manager.calc_support(relation_candidate)
            if support < min_support:
                # Not frequent enough: neither reported nor extended.
                continue
            candidate_set = frozenset(relation_candidate)
            relations.add(candidate_set)
            yield SupportRecord(candidate_set, support)
        length += 1
        # A falsy max_length (None or 0) means no length limit.
        if max_length and length > max_length:
            break
        candidates = _create_next_candidates(relations, length)
def gen_ordered_statistics(transaction_manager, record):
    """
    Returns a generator of ordered statistics as OrderedStatistic instances.

    Every split of the record's items into a base of len(items) - 1 items
    and the single remaining item is evaluated.

    Arguments:
        transaction_manager -- Transactions as a TransactionManager instance.
        record -- A support record as a SupportRecord instance.
    """
    items = record.items
    # Sorting makes the order of the generated statistics deterministic.
    for combination_set in combinations(sorted(items), len(items) - 1):
        items_base = frozenset(combination_set)
        items_add = frozenset(items.difference(items_base))
        confidence = (
            record.support / transaction_manager.calc_support(items_base))
        lift = confidence / transaction_manager.calc_support(items_add)
        yield OrderedStatistic(items_base, items_add, confidence, lift)
def filter_ordered_statistics(ordered_statistics, **kwargs):
    """
    Filter OrderedStatistic objects.

    Yields, in their original order, the statistics whose confidence and
    lift both reach the given minimums (the bounds are inclusive).

    Arguments:
        ordered_statistics -- A OrderedStatistic iterable object.

    Keyword arguments:
        min_confidence -- The minimum confidence of relations (float).
        min_lift -- The minimum lift of relations (float).
    """
    min_confidence = kwargs.get('min_confidence', 0.0)
    min_lift = kwargs.get('min_lift', 0.0)

    for ordered_statistic in ordered_statistics:
        if ordered_statistic.confidence < min_confidence:
            continue
        if ordered_statistic.lift < min_lift:
            continue
        yield ordered_statistic
# API function.
def apriori(transactions, **kwargs):
    """
    Executes Apriori algorithm and returns a RelationRecord generator.

    Arguments:
        transactions -- A transaction iterable object
                        (eg. [['A', 'B'], ['B', 'C']]).

    Keyword arguments:
        min_support -- The minimum support of relations (float).
        min_confidence -- The minimum confidence of relations (float).
        min_lift -- The minimum lift of relations (float).
        max_length -- The maximum length of the relation (integer).

    Raises:
        ValueError -- If min_support is not greater than 0. Because this
                      function is a generator, the error is raised when
                      the result is first iterated, not at call time.
    """
    # Parse the arguments.
    min_support = kwargs.get('min_support', 0.1)
    min_confidence = kwargs.get('min_confidence', 0.0)
    min_lift = kwargs.get('min_lift', 0.0)
    max_length = kwargs.get('max_length', None)

    # Check arguments.
    if min_support <= 0:
        raise ValueError('minimum support must be > 0')

    # For testing.
    _gen_support_records = kwargs.get(
        '_gen_support_records', gen_support_records)
    _gen_ordered_statistics = kwargs.get(
        '_gen_ordered_statistics', gen_ordered_statistics)
    _filter_ordered_statistics = kwargs.get(
        '_filter_ordered_statistics', filter_ordered_statistics)

    # Calculate supports.
    transaction_manager = TransactionManager.create(transactions)
    support_records = _gen_support_records(
        transaction_manager, min_support, max_length=max_length)

    # Calculate ordered stats.
    for support_record in support_records:
        ordered_statistics = list(
            _filter_ordered_statistics(
                _gen_ordered_statistics(transaction_manager, support_record),
                min_confidence=min_confidence,
                min_lift=min_lift,
            )
        )
        if not ordered_statistics:
            # No statistic passed the thresholds: nothing to report.
            continue
        yield RelationRecord(
            support_record.items, support_record.support, ordered_statistics)
# Application functions.
def parse_args(argv):
    """
    Parse commandline arguments.

    Returns the argparse namespace, extended with an `output_func`
    attribute holding the dump function selected by --out-format.

    Arguments:
        argv -- An argument list without the program name.
    """
    output_funcs = {
        'json': dump_as_json,
        'tsv': dump_as_two_item_tsv,
    }
    default_output_func_key = 'json'
    # Sorted so that the help text and the choices have a stable order.
    output_func_keys = sorted(output_funcs)

    parser = argparse.ArgumentParser()
    parser.add_argument(
        '-v', '--version', action='version',
        version='%(prog)s {0}'.format(__version__))
    parser.add_argument(
        'input', metavar='inpath', nargs='*',
        help='Input transaction file (default: stdin).',
        type=argparse.FileType('r'), default=[sys.stdin])
    parser.add_argument(
        '-o', '--output', metavar='outpath',
        help='Output file (default: stdout).',
        type=argparse.FileType('w'), default=sys.stdout)
    parser.add_argument(
        '-l', '--max-length', metavar='int',
        help='Max length of relations (default: infinite).',
        type=int, default=None)
    parser.add_argument(
        '-s', '--min-support', metavar='float',
        help='Minimum support ratio (must be > 0, default: 0.1).',
        type=float, default=0.1)
    parser.add_argument(
        '-c', '--min-confidence', metavar='float',
        help='Minimum confidence (default: 0.5).',
        type=float, default=0.5)
    parser.add_argument(
        '-t', '--min-lift', metavar='float',
        help='Minimum lift (default: 0.0).',
        type=float, default=0.0)
    parser.add_argument(
        '-d', '--delimiter', metavar='str',
        help='Delimiter for items of transactions (default: tab).',
        type=str, default='\t')
    parser.add_argument(
        '-f', '--out-format', metavar='str',
        help='Output format ({0}; default: {1}).'.format(
            ', '.join(output_func_keys), default_output_func_key),
        type=str, choices=output_func_keys, default=default_output_func_key)
    args = parser.parse_args(argv)

    args.output_func = output_funcs[args.out_format]
    return args
def load_transactions(input_file, **kwargs):
    """
    Load transactions and returns a generator for transactions.

    Each row read by csv.reader becomes one transaction (a list of item
    strings). An empty row is turned into [''] instead of an empty list.

    Arguments:
        input_file -- An input file.

    Keyword arguments:
        delimiter -- The delimiter of the transaction.
    """
    delimiter = kwargs.get('delimiter', '\t')
    for transaction in csv.reader(input_file, delimiter=delimiter):
        yield transaction if transaction else ['']
def dump_as_json(record, output_file):
    """
    Dump an relation record as a json value.

    The record is written as one JSON object followed by a line
    separator; frozenset values are written as sorted lists.

    Arguments:
        record -- A RelationRecord instance to dump.
        output_file -- A file to output.
    """
    def default_func(value):
        """
        Default conversion for JSON value.
        """
        if isinstance(value, frozenset):
            return sorted(value)
        raise TypeError(repr(value) + " is not JSON serializable")

    # The nested statistics are namedtuples; convert them to dicts so
    # that they are dumped with their field names instead of as arrays.
    converted_record = record._replace(
        ordered_statistics=[x._asdict() for x in record.ordered_statistics])
    json.dump(
        converted_record._asdict(), output_file,
        default=default_func, ensure_ascii=False)
    output_file.write(os.linesep)
def dump_as_two_item_tsv(record, output_file):
    """
    Dump a relation record as TSV only for 2 item relations.

    One line is written per ordered statistic that has exactly one base
    item and one added item; all other statistics are skipped. The
    columns are: base item, added item, support, confidence, lift.

    Arguments:
        record -- A RelationRecord instance to dump.
        output_file -- A file to output.
    """
    for ordered_stats in record.ordered_statistics:
        if len(ordered_stats.items_base) != 1:
            continue
        if len(ordered_stats.items_add) != 1:
            continue
        output_file.write('{0}\t{1}\t{2:.8f}\t{3:.8f}\t{4:.8f}{5}'.format(
            list(ordered_stats.items_base)[0], list(ordered_stats.items_add)[0],
            record.support, ordered_stats.confidence, ordered_stats.lift,
            os.linesep))
def main(**kwargs):
    """
    Executes Apriori algorithm and print its result.

    Reads the command-line arguments from sys.argv, loads the
    transactions from all input files chained together, and writes every
    resulting record with the selected output function.
    """
    # For tests.
    _parse_args = kwargs.get('_parse_args', parse_args)
    _load_transactions = kwargs.get('_load_transactions', load_transactions)
    _apriori = kwargs.get('_apriori', apriori)

    args = _parse_args(sys.argv[1:])
    transactions = _load_transactions(
        chain(*args.input), delimiter=args.delimiter)
    result = _apriori(
        transactions,
        max_length=args.max_length,
        min_support=args.min_support,
        min_confidence=args.min_confidence,
        # Forward the parsed --min-lift option as well, so that it
        # actually takes effect.
        min_lift=args.min_lift)
    for record in result:
        args.output_func(record, args.output)
# Run the command-line interface only when executed as a script.
if __name__ == '__main__':
    main()