Supermarket Market-Basket Association Analysis and Shelf Optimization (Apriori Algorithm)

To tackle the supermarket's pain points of coarse shelf layout (product placement driven by intuition), inefficient cross-selling (co-purchase rate below 8%), and slow inventory turnover (slow movers at 12% of stock), we apply the Apriori algorithm to mine association rules from basket data, establishing a precise mapping from product co-occurrence patterns to shelf optimization strategies, and supporting adjacent display of associated products, promotion-bundle recommendation, and linked inventory management.

Highlights: deep integration of Apriori theory with the shelf scenario (support/confidence/lift quantify association strength), distributed transaction-data processing (Spark), a dynamic rule-update mechanism (weekly iteration), and integration with the business team's shelf optimization system.

  • Transaction: the set of products in a single order (e.g. {milk, bread, eggs})
  • Itemset: a set of products (e.g. {milk, bread})
  • Support: the fraction of all transactions that contain the itemset (measures prevalence)
  • Confidence: for a rule A→B, the probability of B given A (measures reliability)
  • Lift: the ratio of rule A→B's observed probability to the probability expected at random (Lift > 1 indicates positive correlation); a worked numeric check follows this list
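
To ground the three metrics, here is a minimal hand-check in Python (the five transactions and product IDs are made up): support({A,B}) is the co-occurrence fraction, confidence(A→B) = support({A,B}) / support({A}), and lift = confidence / support({B}).

python
# Toy check of support / confidence / lift; data is illustrative only.
transactions = [
    {"P001", "P002"},          # e.g. milk, bread
    {"P001", "P002", "P003"},  # milk, bread, eggs
    {"P001"},                  # milk
    {"P002", "P003"},          # bread, eggs
    {"P001", "P002"},          # milk, bread
]
n = len(transactions)

def support(itemset):
    return sum(itemset <= t for t in transactions) / n  # subset test per transaction

s_a = support({"P001"})              # 0.8
s_b = support({"P002"})              # 0.8
s_ab = support({"P001", "P002"})     # 0.6
confidence = s_ab / s_a              # 0.75 = P(P002 | P001)
lift = confidence / s_b              # 0.9375 < 1: slightly negative correlation here
print(s_ab, confidence, lift)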

Business pain points: a supermarket chain (50 stores, over 10,000 SKUs) has three major problems:

  • Inefficient shelf layout: highly associated products (e.g. the classic "beer and diapers" pair) are shelved far apart; the co-purchase rate is only 8%, against an industry benchmark of 15%
  • Blind promotion bundling: promo packs are assembled by intuition (e.g. "toothpaste + toothbrush"), converting at under 5%, with slow movers making up 12% of stock
  • Weak inventory coordination: associated products are stocked independently, so situations like "bread selling fast while jam is out of stock" recur, losing an estimated 30 million CNY in annual sales

Scenario mapping (a rule-routing sketch follows this list):

  • Shelf optimization: product pairs with high support and high confidence (e.g. {beer, diapers}) should be displayed adjacently (to lift co-purchases)
  • Promotion bundles: high-lift rules (e.g. {coffee, sugar cubes}, Lift = 2.5) become bundled promo packs (replacing underperforming combinations)
  • Inventory coordination: associated products share a joint safety stock (e.g. bread sales rise → pre-position jam inventory)
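
A minimal routing sketch for dispatching mined rules to these three actions; the cutoffs below are illustrative assumptions, not tuned values.

python
# Route mined rules to shelf / promotion / inventory actions.
# Thresholds are illustrative assumptions, not tuned values.
def route_rules(rules):
    actions = {"shelf_adjacency": [], "promo_bundle": [], "joint_safety_stock": []}
    for r in rules:
        pair = r["antecedent"] + r["consequent"]
        if r["support"] >= 0.02 and r["confidence"] >= 0.6:
            actions["shelf_adjacency"].append(pair)       # display close together
        if r["lift"] >= 2.0:
            actions["promo_bundle"].append(pair)          # bundle in promotions
        if r["confidence"] >= 0.7:
            actions["joint_safety_stock"].append(pair)    # link replenishment
    return actions

rules = [{"antecedent": ["P001"], "consequent": ["P002"],
          "support": 0.025, "confidence": 0.65, "lift": 2.1}]
print(route_rules(rules))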

Data Preparation and Feature Transformation

(1) Raw data structure (transactional data from the POS system)

① POS transaction log (pos_transactions, Hive table)

② Product master data (product_master, MySQL table)
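
The source names the two upstream tables without listing their columns. As a reference, here is an assumed pos_transactions layout reconstructed from the fields the cleaning job below actually reads (order_id, product_id, store_id, quantity, purchase_time); the types are a guess, not the real Hive DDL.

python
# Assumed pos_transactions layout, inferred from DataCleaning.scala; not the actual DDL.
from pyspark.sql.types import (StructType, StructField, StringType,
                               IntegerType)

pos_transactions_schema = StructType([
    StructField("order_id", StringType(), False),
    StructField("product_id", StringType(), False),    # e.g. "P001"
    StructField("store_id", StringType(), True),       # e.g. "S001"
    StructField("quantity", IntegerType(), True),      # units purchased
    StructField("purchase_time", StringType(), True),  # "yyyy-MM-dd HH:mm:ss", parsed downstream
])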

(2) Data cleaning (full code, owned by the algorithm team)

Code file: data_processing/src/main/scala/com/supermarket/DataCleaning.scala

scala
import org.apache.spark.sql.{DataFrame, SparkSession}
import org.apache.spark.sql.functions._
import org.slf4j.LoggerFactory

object DataCleaning {
  private val logger = LoggerFactory.getLogger(getClass)

  def cleanTransactions(spark: SparkSession, inputPath: String, outputPath: String): Unit = {
    // 1. Read raw data (Parquet, stored in the data lake)
    val rawDF = spark.read.parquet(inputPath)
    logger.info(s"Raw transaction rows: ${rawDF.count()}, columns: ${rawDF.columns.mkString(",")}")

    // 2. Clean the data
    val cleanedDF = rawDF
      .filter(col("order_id").isNotNull && col("product_id").isNotNull) // drop null IDs
      .filter(col("product_id").rlike("^P[0-9]{3}$")) // keep valid product IDs ("P" + 3 digits)
      .filter(col("quantity").isNotNull && col("quantity") > 0) // drop invalid quantities
      .withColumn("purchase_time", to_timestamp(col("purchase_time"), "yyyy-MM-dd HH:mm:ss")) // normalize timestamps
      .dropDuplicates("order_id", "product_id") // dedupe same product within an order
      .select("order_id", "product_id", "store_id", "purchase_time")

    logger.info(s"Cleaned transaction rows: ${cleanedDF.count()}")
    cleanedDF.show(3) // show() returns Unit, so call it separately instead of inside the log message

    // 3. Persist cleaned data (consumed by feature engineering)
    cleanedDF.write.mode("overwrite").parquet(outputPath)
    logger.info(s"Cleaned data written to: $outputPath")
  }

  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .appName("SupermarketTransactionCleaning")
      .master("yarn") // run on the cluster
      .getOrCreate()

    val inputPath = "s3://supermarket-data-lake/raw/pos_transactions.parquet"
    val outputPath = "s3://supermarket-data-lake/cleaned/transactions_cleaned.parquet"
    cleanTransactions(spark, inputPath, outputPath)
    spark.stop()
  }
}

(3) Feature engineering and feature-data generation (with explicit feature_path/label_path)

Code files: feature_engineering/transaction_generator.py (transaction conversion), feature_engineering/generate_feature_data.py (feature-data generation)

① Transaction conversion (transaction_generator.py)

python
import logging

import pandas as pd
from pyspark.sql import SparkSession
from pyspark.sql.functions import col, collect_list

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

def generate_transactions(cleaned_data_path: str, min_sales: int = 100) -> pd.DataFrame:
    """Convert cleaned rows into a transaction dataset (group by order_id, collect product_id lists)."""
    spark = SparkSession.builder.appName("TransactionGenerator").getOrCreate()
    df = spark.read.parquet(cleaned_data_path)

    # 1. Keep frequently sold products (monthly sales > min_sales)
    product_sales = df.groupBy("product_id").count().filter(col("count") > min_sales)
    df_filtered = df.join(product_sales.select("product_id"), on="product_id", how="inner")

    # 2. Group by order_id to build transactions (lists of product_id)
    transactions_df = (df_filtered.groupBy("order_id", "store_id")
                       .agg(collect_list("product_id").alias("product_list"))
                       .select("order_id", "store_id", "product_list"))

    # 3. Convert to pandas (fine at validation scale; at full scale write Parquet directly)
    transactions_pd = transactions_df.toPandas()
    logger.info(f"Transaction dataset: {len(transactions_pd)} orders, sample:\n{transactions_pd.head(2)}")
    return transactions_pd

② Feature-data generation (generate_feature_data.py, documenting feature_path/label_path)

python
import pandas as pd
import json
from transaction_generator import generate_transactions
import logging

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

def generate_feature_data(cleaned_data_path: str) -> tuple:
    """Produce the transaction dataset (feature_path) and the rule-result target (label_path)."""
    # 1. Build the transaction dataset (the file feature_path points to)
    transactions_df = generate_transactions(cleaned_data_path, min_sales=100)
    feature_path = "s3://supermarket-data-lake/processed/transactions_dataset.parquet"
    transactions_df.to_parquet(feature_path, index=False)

    # 2. Declare label_path (rule-result file, filled in after training)
    label_path = "s3://supermarket-data-lake/processed/association_rules.json"

    logger.info(f"""
    [Algorithm-team feature-data storage notes]
    - feature_path: {feature_path}
      Contents: transaction dataset (Parquet) with columns:
        order_id (order ID), store_id (store ID), product_list (product ID list, e.g. ["P001", "P002"])
      Sample rows (first 2):\n{transactions_df.head(2)}

    - label_path: {label_path}
      Contents: association-rule results (JSON), for example:
        [
          {{"antecedent": ["P001"], "consequent": ["P002"], "support": 0.025, "confidence": 0.65, "lift": 2.1}},
          ...
        ]
      Written by train_apriori.py once the Apriori run finishes
    """)

    return feature_path, label_path, transactions_df

Code structure (algorithm-team and business-team repository layout)

Algorithm-team repository

text
algorithm-basket-analysis/
├── data_processing/                # Data cleaning (Spark/Scala)
│   ├── src/main/scala/com/supermarket/DataCleaning.scala  # cleaning job (fully commented)
│   └── build.sbt                   # deps: spark-core, spark-sql
├── feature_engineering/            # Feature engineering (transaction conversion)
│   ├── transaction_generator.py    # builds the transaction dataset
│   ├── generate_feature_data.py    # feature-data generation (documents feature_path/label_path)
│   └── requirements.txt            # deps: pyspark, pandas
├── model_training/                 # Apriori model training (core)
│   ├── apriori.py                  # Apriori implementation (with theory comments)
│   ├── train_apriori.py            # training entry point (tuning / rule generation)
│   ├── rule_evaluation.py          # rule evaluation (business-value scoring)
│   └── apriori_params.yaml         # tuning record (min_support=0.02, min_conf=0.6)
├── rule_storage/                   # Rule storage (MinIO)
│   ├── upload_rules.py             # upload rules to MinIO
│   └── download_rules.py           # download rules (called by the business team)
└── mlflow_tracking/                # MLflow experiment tracking
    └── run_apriori_experiment.py   # logs params / metrics / model

(1) Algorithm team: Apriori implementation (model_training/apriori.py, with theory comments)

  • Model artifact: the Apriori rule set (label_path=s3://supermarket-data-lake/processed/association_rules.json)
  • Feature data: feature_path=s3://supermarket-data-lake/processed/transactions_dataset.parquet (transaction dataset)
  • Tooling: rule-evaluation script (rule_evaluation.py), MinIO upload/download utilities
python
import itertools
import logging
from collections import defaultdict

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

class Apriori:
    def __init__(self, min_support=0.02, min_confidence=0.6, min_lift=1.5):
        """
        Apriori initialization.
        :param min_support: minimum support (default 2%: itemset appears in >=2% of orders)
        :param min_confidence: minimum confidence (default 60%: reliability of rule A->B)
        :param min_lift: minimum lift (default 1.5: given A, B is >=1.5x as likely as at random)
        """
        self.min_support = min_support
        self.min_confidence = min_confidence
        self.min_lift = min_lift
        self.transactions = None      # transaction dataset (list of lists, e.g. [["P001","P002"], ["P003","P004"]])
        self.item_counts = None       # 1-itemset counts and supports
        self.frequent_itemsets = {}   # frequent itemsets (key: itemset size, value: list of itemsets)
        self.support_cache = {}       # sorted itemset tuple -> support, used during rule generation
        self.rules = []               # mined association rules

    def fit(self, transactions: list):
        """Train Apriori: generate frequent itemsets, then mine association rules."""
        self.transactions = transactions
        num_transactions = len(transactions)
        logger.info(f"Training Apriori on {num_transactions} transactions")

        # Step 1: frequent 1-itemsets (L1)
        self._generate_frequent_1_itemsets(num_transactions)

        # Step 2: iteratively build frequent k-itemsets (k>=2): generate Ck from Lk-1, prune infrequent ones
        k = 2
        while True:
            Ck = self._generate_candidate_itemsets(k)   # candidate k-itemsets
            if not Ck:
                break
            Lk = self._prune_infrequent_itemsets(Ck, num_transactions, k)  # prune by support
            if not Lk:
                break
            self.frequent_itemsets[k] = Lk
            k += 1

        # Step 3: mine association rules from the frequent itemsets
        self._generate_association_rules()
        logger.info(f"Training done, rules mined: {len(self.rules)}")
        return self

    def _generate_frequent_1_itemsets(self, num_transactions):
        """Frequent 1-itemsets (L1): count each product and compute its support."""
        item_counts = defaultdict(int)
        for trans in self.transactions:
            for item in trans:
                item_counts[item] += 1
        # keep items with support >= min_support
        self.item_counts = {item: (count, count / num_transactions)
                            for item, count in item_counts.items()
                            if count / num_transactions >= self.min_support}
        self.frequent_itemsets[1] = [(item,) for item in self.item_counts]
        for item, (_, support) in self.item_counts.items():
            self.support_cache[(item,)] = support  # cache singleton supports for rule generation
        logger.info(f"Frequent 1-itemsets: {len(self.frequent_itemsets[1])}")

    def _generate_candidate_itemsets(self, k):
        """Candidate k-itemsets (Ck): merge pairs from Lk-1 that share their first k-2 elements."""
        Lk_1 = self.frequent_itemsets.get(k - 1, [])
        if not Lk_1:
            return []
        # sort itemset elements so merging is consistent
        Lk_1_sorted = [tuple(sorted(itemset)) for itemset in Lk_1]
        candidates = set()
        for i in range(len(Lk_1_sorted)):
            for j in range(i + 1, len(Lk_1_sorted)):
                itemset1 = Lk_1_sorted[i]
                itemset2 = Lk_1_sorted[j]
                # merge itemsets whose first k-2 elements match and last element differs
                if itemset1[:-1] == itemset2[:-1]:
                    candidate = tuple(sorted(set(itemset1) | set(itemset2)))
                    if len(candidate) == k:
                        candidates.add(candidate)
        return list(candidates)

    def _prune_infrequent_itemsets(self, Ck, num_transactions, k):
        """Prune infrequent k-itemsets: compute support for each candidate, keep those >= min_support."""
        itemset_counts = defaultdict(int)
        for trans in self.transactions:
            trans_set = set(trans)
            for itemset in Ck:
                if set(itemset).issubset(trans_set):
                    itemset_counts[itemset] += 1
        Lk = []
        for itemset, count in itemset_counts.items():
            support = count / num_transactions
            if support >= self.min_support:
                Lk.append(itemset)
                self.support_cache[itemset] = support  # cached for rule generation
        logger.info(f"Frequent {k}-itemsets: {len(Lk)}")
        return Lk

    def _generate_association_rules(self):
        """Mine rules A->B from frequent itemsets (k>=2), computing confidence and lift."""
        for k, itemsets in self.frequent_itemsets.items():
            if k < 2:
                continue  # itemsets of size < 2 yield no rules
            for itemset in itemsets:
                # every non-empty proper subset can serve as the antecedent A
                for A in self._get_non_empty_subsets(itemset):
                    B = tuple(sorted(set(itemset) - set(A)))
                    if not B:
                        continue
                    # support / confidence / lift; subsets of a frequent itemset are
                    # themselves frequent, so the cache lookups normally hit
                    support_A = self.support_cache.get(A) or self._calculate_support(A)
                    support_B = self.support_cache.get(B) or self._calculate_support(B)
                    support_AB = self.support_cache.get(itemset, 0.0)
                    confidence = support_AB / support_A if support_A > 0 else 0
                    lift = confidence / support_B if support_B > 0 else 0
                    # keep only strong rules
                    if confidence >= self.min_confidence and lift >= self.min_lift:
                        self.rules.append({
                            "antecedent": list(A), "consequent": list(B),
                            "support": round(support_AB, 4),
                            "confidence": round(confidence, 4),
                            "lift": round(lift, 4),
                        })

    def _get_non_empty_subsets(self, itemset):
        """All non-empty proper subsets of an itemset (candidates for the antecedent A)."""
        subsets = []
        n = len(itemset)
        for i in range(1, n):  # proper subsets only (never the full itemset)
            for combo in itertools.combinations(itemset, i):
                subsets.append(tuple(sorted(combo)))
        return subsets

    def _calculate_support(self, itemset):
        """Fallback support computation via a full scan (for itemsets missing from the cache)."""
        count = 0
        for trans in self.transactions:
            if set(itemset).issubset(set(trans)):
                count += 1
        return count / len(self.transactions)
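
A quick smoke test of the class on toy data (thresholds are loosened so the tiny dataset actually yields rules; product IDs are made up):

python
# Toy run; at production scale, transactions come from the Parquet dataset above.
transactions = [
    ["P001", "P002"],
    ["P001", "P002"],
    ["P001", "P002", "P003"],
    ["P003"],
    ["P002", "P003"],
]
model = Apriori(min_support=0.4, min_confidence=0.6, min_lift=1.0).fit(transactions)
for rule in model.rules:
    print(rule)
# Expected rules (order may vary):
# {'antecedent': ['P001'], 'consequent': ['P002'], 'support': 0.6, 'confidence': 1.0, 'lift': 1.25}
# {'antecedent': ['P002'], 'consequent': ['P001'], 'support': 0.6, 'confidence': 0.75, 'lift': 1.25}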

(2) Algorithm team: model training and rule storage (model_training/train_apriori.py)

python
import json
from apriori import Apriori
from feature_engineering.generate_feature_data import generate_feature_data
from rule_storage.upload_rules import upload_to_minio
import logging

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

def train_and_save_rules():
    # 1. Generate feature data (obtain feature_path and label_path)
    feature_path, label_path, transactions_df = generate_feature_data(
        cleaned_data_path="s3://supermarket-data-lake/cleaned/transactions_cleaned.parquet"
    )

    # 2. Load the transaction dataset (the feature_path file)
    transactions = transactions_df["product_list"].tolist()  # lists of product IDs
    logger.info(f"Loaded transaction dataset: {len(transactions)} orders")

    # 3. Train Apriori (parameters tuned to the business targets)
    apriori = Apriori(min_support=0.02, min_confidence=0.6, min_lift=1.5)
    apriori.fit(transactions)

    # 4. Save rules to a local temp file (JSON) before upload
    with open("temp_rules.json", "w") as f:
        json.dump(apriori.rules, f, indent=2)
    logger.info(f"Rules saved to temp file, {len(apriori.rules)} in total")

    # 5. Upload rules to MinIO (the business team reads them via label_path)
    upload_to_minio(local_path="temp_rules.json", minio_path=label_path)
    logger.info(f"Rules uploaded to MinIO: {label_path}")

if __name__ == "__main__":
    train_and_save_rules()
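
rule_storage/upload_rules.py is referenced above but not shown in the source. A minimal sketch using the MinIO Python SDK could look like this; the endpoint, credentials, and the way the bucket is parsed out of the path are assumptions.

python
# Hypothetical rule_storage/upload_rules.py; endpoint and credentials are placeholders.
from minio import Minio

def _client() -> Minio:
    return Minio("minio.internal:9000",          # assumed endpoint
                 access_key="YOUR_ACCESS_KEY",   # injected via env/secrets in practice
                 secret_key="YOUR_SECRET_KEY",
                 secure=False)

def upload_to_minio(local_path: str, minio_path: str) -> None:
    # minio_path like "s3://supermarket-data-lake/processed/association_rules.json"
    bucket, _, obj = minio_path.removeprefix("s3://").partition("/")
    _client().fput_object(bucket, obj, local_path)

def download_from_minio(minio_path: str, local_path: str) -> None:
    bucket, _, obj = minio_path.removeprefix("s3://").partition("/")
    _client().fget_object(bucket, obj, local_path)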

Business-team repository

text
business-basket-optimization/
├── api_gateway/                    # API gateway (Kong config)
├── rule_query_service/             # Association-rule query service (Go)
│   ├── main.go                     # Go HTTP service (reads rules from MinIO)
│   └── Dockerfile                  # container config
├── shelf_optimization_system/      # Shelf optimization system (Java + React)
│   ├── backend/                    # Spring Boot backend (rule parsing / layout suggestions)
│   ├── frontend/                   # React frontend (interactive shelf-map visualization)
│   └── sql/                        # PostgreSQL schemas (store layout / rule metadata)
├── promotion_engine/               # Promotion-bundle engine (builds promo packs)
├── inventory_sync/                 # Inventory coordination (associated-product stock alerts)
└── monitoring/                     # Monitoring & alerting (Prometheus config)

Business team: shelf optimization system (Java backend, rule parsing)

  • Systems: shelf optimization system (web UI with store-layout visualization), promotion-bundle engine (builds high-lift promo packs), inventory coordination system (associated-product stock alerts)
  • API services: rule query API (/api/rules?category=dairy), shelf suggestion API (/api/shelf-suggestions?store_id=S001)
java
import com.fasterxml.jackson.databind.ObjectMapper;
import java.util.*;
import java.nio.file.Files;
import java.nio.file.Paths;

public class ShelfOptimizer {
    private List<AssociationRule> rules;  // loaded association rules

    // Association-rule POJO (mirrors the JSON structure)
    static class AssociationRule {
        public List<String> antecedent;  // e.g. ["P001"]
        public List<String> consequent;  // e.g. ["P002"]
        public double support;
        public double confidence;
        public double lift;
    }

    // Load the association rules (the label_path file; in production, fetched via the MinIO SDK)
    public void loadRules(String minioPath) throws Exception {
        ObjectMapper mapper = new ObjectMapper();
        String content = new String(Files.readAllBytes(Paths.get(minioPath)));
        this.rules = Arrays.asList(mapper.readValue(content, AssociationRule[].class));
        System.out.println("Rules loaded: " + rules.size());
    }

    // Build shelf suggestions (place highly associated products adjacently).
    // Note: the rule JSON carries no store_id, so storeId is not yet used for filtering here.
    public Map<String, List<String>> generateShelfSuggestions(String storeId) {
        Map<String, List<String>> suggestions = new HashMap<>();
        // keep high-lift rules (lift >= 2.0)
        List<AssociationRule> highLiftRules = rules.stream()
            .filter(r -> r.lift >= 2.0)
            .sorted((r1, r2) -> Double.compare(r2.lift, r1.lift))  // descending by lift
            .limit(10)  // top 10 rules
            .toList();

        // suggestion: assign antecedent and consequent products to adjacent shelf zones
        for (AssociationRule rule : highLiftRules) {
            String key = String.join(",", rule.antecedent);
            List<String> adjacentProducts = new ArrayList<>(rule.antecedent);
            adjacentProducts.addAll(rule.consequent);
            suggestions.put(key, adjacentProducts);
        }
        return suggestions;
    }
}
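
The rule query API above (rule_query_service/main.go) is referenced but not shown in the source. For illustration only, here is the equivalent query logic sketched in Python with FastAPI rather than the repo's Go; the local rules file and the product-to-category map are assumptions (in production the categories would come from product_master).

python
# Hypothetical rule-query endpoint; the repo's real service is written in Go.
import json
from fastapi import FastAPI

app = FastAPI()

with open("association_rules.json") as f:   # assumed local copy of label_path
    RULES = json.load(f)

# placeholder product -> category map; the real mapping lives in product_master
PRODUCT_CATEGORY = {"P001": "dairy", "P002": "bakery"}

@app.get("/api/rules")
def query_rules(category: str | None = None, min_lift: float = 1.5):
    hits = [r for r in RULES if r["lift"] >= min_lift]
    if category:
        hits = [r for r in hits
                if any(PRODUCT_CATEGORY.get(p) == category
                       for p in r["antecedent"] + r["consequent"])]
    return {"count": len(hits), "rules": hits}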