垃圾分类识别:迁移学习在环保领域的应用

1. 引言

1.1 背景

我们使用计算机视觉创建一个简单的工具,将垃圾分类为有机垃圾或可回收垃圾,以简化垃圾管理流程。

传统的垃圾分类主要依赖人工判断,效率低下且容易出错。随着深度学习技术的发展,我们可以使用迁移学习快速构建垃圾分类模型,为环保事业提供智能化的解决方案。

1.2 应用场景

  • 家庭分类:帮助家庭用户正确分类垃圾
  • 学校教育:作为环保教育的教学工具
  • 社区指导:在社区设置智能分类指导系统
  • 学习案例:作为迁移学习和图像分类的学习示例
  • AI项目周期实践示例:完整展示AI项目从需求到部署的全流程

1.3 价值

本项目通过构建一个完整的垃圾分类识别系统,展示了如何:

  • 使用迁移学习快速构建图像分类模型
  • 处理数据不平衡问题
  • 实现从数据获取到模型部署的完整流程
  • 为环保事业提供智能化的解决方案

2. 概述

2.1 项目目标

使用计算机视觉技术将垃圾分类为有机垃圾(Organic)和可回收垃圾(Recyclable)。

2.2 任务类型

  • 任务类型:监督学习、图像分类(二分类)
  • 目标变量Organic(有机垃圾)或 Recyclable(可回收垃圾)
  • 技术方案:迁移学习(使用MobileNetV2作为基础模型)

2.3 技术栈

  • 数据处理:Pandas、NumPy、PIL
  • 深度学习:TensorFlow/Keras
  • 迁移学习:MobileNetV2(预训练模型)
  • 数据增强:ImageDataGenerator
  • 模型评估:Scikit-learn(混淆矩阵、分类报告)
  • 交互界面:IPython Widgets

2.4 数据集

  • 数据集名称:Waste Classification Data
  • 来源:Kaggle
  • 链接https://www.kaggle.com/techsash/waste-classification-data
  • 许可证:CC0: Public Domain(可自由使用)
  • 数据量:约25,000张图片
  • 类别:2类(Organic, Recyclable)
  • 数据分布
    • 训练集:Organic 10,052张,Recyclable 8,000张
    • 测试集:Organic 1,219张,Recyclable 1,112张
    • 验证集:Organic 2,513张,Recyclable 1,999张

3. AI项目周期6个阶段详解

阶段1:需求界定

3.1.1 问题定义

大多数垃圾最终进入填埋场,导致环境污染。因此,我们希望使用计算机视觉创建一个简单的工具,将垃圾分类为有机垃圾或可回收垃圾,以简化垃圾管理流程。

项目目标

  • 使用计算机视觉将垃圾分类为有机垃圾和可回收垃圾
  • 实现高准确率的分类(≥90%)
  • 提供实时预测功能
  • 创建用户友好的交互界面

应用场景

  • 家庭分类
  • 学校教育
  • 社区指导
  • 学习案例
3.1.2 关键技术:迁移学习

**迁移学习(Transfer Learning)**是一种机器学习技术,它利用在一个任务上训练好的模型来解决另一个相关任务。

迁移学习的优势

  • 快速训练:不需要从零开始训练,可以快速获得好效果
  • 数据需求少:即使数据量较少,也能获得较好的性能
  • 计算资源少:可以使用预训练模型,减少计算资源需求

MobileNetV2

  • 轻量级卷积神经网络
  • 适合移动端和边缘设备
  • 在ImageNet上预训练,特征提取能力强
  • 模型小(约8MB),推理速度快

阶段2:数据获取

3.2.1 环境准备

在开始项目之前,需要安装必要的库:

复制代码
required_libraries = {
    "numpy": None,
    "pillow": None,
    "keras": None,
    "tensorflow": None,
    "tqdm": None
}

from utilities.utils import check_and_install
check_and_install(required_libraries)
3.2.2 数据加载
复制代码
import os
import pandas as pd
from PIL import Image

# 路径配置
project_dir = os.getcwd()
data_path = os.path.join(project_dir, "sample", "data")

# 数据目录结构
train_data_path = os.path.join(data_path, "train")
train_data_path_organic = os.path.join(train_data_path, "O")
train_data_path_recyclable = os.path.join(train_data_path, "R")

test_data_path = os.path.join(data_path, "test")
test_data_path_organic = os.path.join(test_data_path, "O")
test_data_path_recyclable = os.path.join(test_data_path, "R")

# 统计数据量
train_o_count = len([f for f in os.listdir(train_data_path_organic) 
                     if f.lower().endswith(('.png', '.jpg', '.jpeg'))])
train_r_count = len([f for f in os.listdir(train_data_path_recyclable) 
                     if f.lower().endswith(('.png', '.jpg', '.jpeg'))])

print(f"训练集 - Organic: {train_o_count} 张, Recyclable: {train_r_count} 张")

知识点

  • 数据集按类别组织在子目录中(O=Organic, R=Recyclable)
  • 使用 os.listdir() 统计文件数量

阶段3:数据分析

3.3.1 数据探索和可视化
复制代码
import matplotlib.pyplot as plt
import numpy as np

def explore_dataset():
    """探索数据集:分布、样本展示"""
    
    # 1. 数据分布统计
    train_o_count = len([f for f in os.listdir(train_data_path_organic) 
                         if f.lower().endswith(('.png', '.jpg', '.jpeg'))])
    train_r_count = len([f for f in os.listdir(train_data_path_recyclable) 
                         if f.lower().endswith(('.png', '.jpg', '.jpeg'))])
    
    # 创建数据分布图
    fig, axes = plt.subplots(1, 2, figsize=(12, 5))
    
    train_data = {'Organic': train_o_count, 'Recyclable': train_r_count}
    axes[0].bar(train_data.keys(), train_data.values(), color=['green', 'blue'])
    axes[0].set_title('训练集数据分布', fontsize=14, fontweight='bold')
    axes[0].set_ylabel('图片数量')
    for i, (k, v) in enumerate(train_data.items()):
        axes[0].text(i, v, str(v), ha='center', va='bottom')
    
    plt.tight_layout()
    plt.show()
    
    # 计算数据不平衡比例
    if train_r_count > 0:
        imbalance_ratio = train_o_count / train_r_count
        print(f"数据不平衡比例: {imbalance_ratio:.2f} (Organic:Recyclable)")
        if imbalance_ratio > 1.2 or imbalance_ratio < 0.8:
            print("⚠️ 数据存在不平衡,需要使用类别权重处理")
    
    # 2. 样本图片展示
    def show_samples(class_name, class_path, num_samples=4):
        """展示某个类别的样本图片"""
        files = [f for f in os.listdir(class_path) 
                if f.lower().endswith(('.png', '.jpg', '.jpeg'))]
        sample_files = np.random.choice(files, min(num_samples, len(files)), replace=False)
        
        fig, axes = plt.subplots(1, num_samples, figsize=(15, 4))
        if num_samples == 1:
            axes = [axes]
        
        for i, filename in enumerate(sample_files):
            img_path = os.path.join(class_path, filename)
            img = Image.open(img_path)
            axes[i].imshow(img)
            axes[i].set_title(f'{class_name}\n{filename}', fontsize=10)
            axes[i].axis('off')
        
        plt.tight_layout()
        plt.show()
    
    print("Organic类别样本:")
    show_samples('Organic', train_data_path_organic, 4)
    
    print("Recyclable类别样本:")
    show_samples('Recyclable', train_data_path_recyclable, 4)

# 执行数据探索
explore_dataset()

知识点

  • 数据不平衡:Organic样本(10,052)多于Recyclable样本(8,000),需要使用类别权重处理
  • 样本展示:通过可视化了解数据特点
3.3.2 数据预处理
复制代码
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from sklearn.utils import class_weight

# 创建验证集(从训练集划分20%)
validation_data_path = os.path.join(data_path, "validation")
validation_ratio = 0.2

# 数据增强配置
train_datagen = ImageDataGenerator(
    rescale=1./255,  # 归一化到[0,1]
    rotation_range=20,  # 随机旋转±20度
    width_shift_range=0.2,  # 水平平移±20%
    height_shift_range=0.2,  # 垂直平移±20%
    shear_range=0.2,  # 剪切变换±20%
    zoom_range=0.2,  # 缩放±20%
    horizontal_flip=True,  # 水平翻转
    fill_mode='nearest'  # 填充模式
)

# 测试集和验证集只做归一化(不做增强)
test_datagen = ImageDataGenerator(rescale=1./255)
validation_datagen = ImageDataGenerator(rescale=1./255)

# 设置批次大小和图片尺寸
batch_size = 32
target_size = (128, 128)  # 模型输入尺寸

# 创建数据生成器
train_generator = train_datagen.flow_from_directory(
    train_data_path,
    target_size=target_size,
    batch_size=batch_size,
    class_mode='binary',  # 二分类
    shuffle=True
)

validation_generator = validation_datagen.flow_from_directory(
    validation_data_path,
    target_size=target_size,
    batch_size=batch_size,
    class_mode='binary',
    shuffle=False
)

# 计算类别权重(处理数据不平衡)
train_classes = train_generator.classes
class_weights = class_weight.compute_class_weight(
    'balanced',  # 使用平衡策略
    classes=np.unique(train_classes),
    y=train_classes
)
class_weights = dict(enumerate(class_weights))

print(f"✅ 数据生成器创建完成")
print(f"   - 训练集: {train_generator.samples} 张图片")
print(f"   - 验证集: {validation_generator.samples} 张图片")
print(f"   - 类别权重: {class_weights}")

知识点

  • 数据增强:通过旋转、翻转、缩放等变换增加数据多样性
  • 类别权重:处理数据不平衡问题,给少数类别更高的权重
  • 数据生成器:批量加载和预处理图片数据,提高训练效率

阶段4:模型构建

3.4.1 创建迁移学习模型
复制代码
from tensorflow.keras.applications import MobileNetV2
from tensorflow.keras.layers import Dense, Dropout, GlobalAveragePooling2D, Input
from tensorflow.keras.models import Model
from tensorflow.keras.optimizers import Adam

def create_transfer_model(input_shape):
    """创建基于MobileNetV2的迁移学习模型"""
    # 加载预训练模型(不包括顶层)
    base_model = MobileNetV2(
        input_shape=input_shape,
        include_top=False,
        weights='imagenet',
        pooling='avg'
    )
    
    # 冻结基础模型
    base_model.trainable = False
    
    # 添加自定义顶层
    inputs = Input(shape=input_shape)
    x = base_model(inputs, training=False)
    x = Dropout(0.3)(x)
    outputs = Dense(1, activation='sigmoid')(x)  # 二分类,使用sigmoid激活
    
    # 创建完整模型
    model = Model(inputs, outputs)
    
    # 编译模型
    model.compile(
        optimizer=Adam(learning_rate=0.0001),
        loss='binary_crossentropy',
        metrics=['accuracy', 
                 tf.keras.metrics.Precision(name='precision'),
                 tf.keras.metrics.Recall(name='recall')]
    )
    
    return model

# 创建模型
input_shape = (target_size[0], target_size[1], 3)
model = create_transfer_model(input_shape)
model.summary()

输出示例

复制代码
Model: "model"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
=================================================================
 input_2 (InputLayer)        [(None, 128, 128, 3)]     0         
                                                                 
 mobilenetv2_1.00_128 (Func  (None, 1280)              2257984   
 tional)                                                         
                                                                 
 dropout (Dropout)           (None, 1280)              0         
                                                                 
 dense (Dense)               (None, 1)                 1281      
                                                                 
=================================================================
Total params: 2259265 (8.62 MB)
Trainable params: 1281 (5.00 KB)
Non-trainable params: 2257984 (8.61 MB)

知识点

  • 迁移学习:使用预训练的MobileNetV2作为特征提取器
  • 冻结基础模型:只训练顶层,加快训练速度
  • 模型大小:约8.62 MB,满足≤20MB要求
3.4.2 训练模型
复制代码
from tensorflow.keras.callbacks import ModelCheckpoint, EarlyStopping, ReduceLROnPlateau

# 回调函数
callbacks = [
    ModelCheckpoint(
        filepath=best_model_file,
        monitor='val_accuracy',
        save_best_only=True,
        verbose=1
    ),
    EarlyStopping(
        monitor='val_loss',
        patience=3,
        restore_best_weights=True,
        verbose=1
    ),
    ReduceLROnPlateau(
        monitor='val_loss',
        factor=0.2,
        patience=2,
        verbose=1
    )
]

# 训练模型
epochs = 10
steps_per_epoch = max(1, train_generator.samples // batch_size)
validation_steps = max(1, validation_generator.samples // batch_size)

history = model.fit(
    train_generator,
    steps_per_epoch=steps_per_epoch,
    epochs=epochs,
    validation_data=validation_generator,
    validation_steps=validation_steps,
    class_weight=class_weights,
    callbacks=callbacks,
    verbose=1
)

# 保存模型
model.save(os.path.join(model_path, 'garbage_classifier.keras'))
print(f"✅ 模型训练完成并已保存")

知识点

  • 回调函数:ModelCheckpoint保存最佳模型,EarlyStopping防止过拟合,ReduceLROnPlateau自动调整学习率
  • 类别权重:使用class_weight处理数据不平衡问题
  • 模型保存:保存为Keras格式,便于后续加载和使用

阶段5:效果评估

3.5.1 基础评估
复制代码
from tqdm import tqdm

def predict_and_evaluate(model, dir_path, class_names=None):
    """评估整个目录下的所有图片并返回结果"""
    results = []
    if class_names is None:
        class_names = {0: 'Organic', 1: 'Recyclable'}
    
    # 获取所有图片文件路径
    image_paths = []
    for root, _, files in os.walk(dir_path):
        for file in files:
            if file.lower().endswith(('.png', '.jpg', '.jpeg')):
                image_paths.append(os.path.join(root, file))
    
    total_images = len(image_paths)
    print(f"开始评估 {total_images} 张图片...")
    
    # 使用进度条
    for img_path in tqdm(image_paths, desc="预测进度"):
        # 提取真实类别(基于目录名)
        true_class = 'Organic' if 'O' in img_path else 'Recyclable'
        
        # 加载和预处理图像
        img = Image.open(img_path)
        img = img.convert('RGB').resize(target_size)
        img_array = np.array(img) / 255.0
        img_array = np.expand_dims(img_array, axis=0)
        
        # 预测
        prediction = model.predict(img_array, verbose=0)
        pred_class = class_names[1] if prediction[0][0] > 0.5 else class_names[0]
        confidence = prediction[0][0] if pred_class == 'Recyclable' else 1 - prediction[0][0]
        
        results.append({
            'filename': os.path.basename(img_path),
            'true_class': true_class,
            'predicted_class': pred_class,
            'confidence': float(confidence),
            'correct': true_class == pred_class
        })
    
    return results

# 执行批量测试
test_results = predict_and_evaluate(model, validation_data_path)

# 计算准确率
correct_count = sum(1 for res in test_results if res['correct'])
accuracy = correct_count / len(test_results) * 100

print(f"\n测试集评估结果:")
print(f"总样本数: {len(test_results)}")
print(f"正确分类数: {correct_count}")
print(f"准确率: {accuracy:.2f}%")

输出示例

复制代码
测试集评估结果:
总样本数: 4512
正确分类数: 4131
准确率: 91.56%
3.5.2 详细性能评估
复制代码
from sklearn.metrics import confusion_matrix, classification_report
import seaborn as sns

def detailed_evaluation(model, test_data_path, class_names=None):
    """详细的模型评估:混淆矩阵、精确率、召回率、F1等"""
    
    if class_names is None:
        class_names = {0: 'Organic', 1: 'Recyclable'}
    
    # 收集所有预测结果
    y_true = []
    y_pred = []
    
    # 获取所有测试图片
    image_paths = []
    for root, _, files in os.walk(test_data_path):
        for file in files:
            if file.lower().endswith(('.png', '.jpg', '.jpeg')):
                image_paths.append(os.path.join(root, file))
    
    for img_path in tqdm(image_paths, desc="预测进度"):
        # 提取真实标签
        true_label = 0 if 'O' in img_path else 1
        y_true.append(true_label)
        
        # 加载和预处理图像
        img = Image.open(img_path)
        img = img.convert('RGB').resize(target_size)
        img_array = np.array(img) / 255.0
        img_array = np.expand_dims(img_array, axis=0)
        
        # 预测
        prediction = model.predict(img_array, verbose=0)
        pred_label = 1 if prediction[0][0] > 0.5 else 0
        y_pred.append(pred_label)
    
    # 混淆矩阵
    cm = confusion_matrix(y_true, y_pred)
    
    # 可视化混淆矩阵
    plt.figure(figsize=(8, 6))
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', 
                xticklabels=['Organic', 'Recyclable'],
                yticklabels=['Organic', 'Recyclable'])
    plt.title('混淆矩阵 (Confusion Matrix)', fontsize=14, fontweight='bold')
    plt.ylabel('真实类别', fontsize=12)
    plt.xlabel('预测类别', fontsize=12)
    plt.tight_layout()
    plt.show()
    
    # 分类报告
    report = classification_report(y_true, y_pred, 
                                   target_names=['Organic', 'Recyclable'],
                                   output_dict=True)
    
    print("\n详细性能指标:")
    print(f"{'类别':<15} {'精确率':<10} {'召回率':<10} {'F1分数':<10} {'样本数':<10}")
    print("-" * 60)
    for class_name in ['Organic', 'Recyclable']:
        metrics = report[class_name]
        print(f"{class_name:<15} {metrics['precision']:<10.4f} "
              f"{metrics['recall']:<10.4f} {metrics['f1-score']:<10.4f} "
              f"{int(metrics['support']):<10}")
    
    print("-" * 60)
    print(f"{'总体准确率':<15} {report['accuracy']:<10.4f}")
    
    return {
        'confusion_matrix': cm,
        'classification_report': report,
        'accuracy': report['accuracy']
    }

# 执行详细评估
evaluation_results = detailed_evaluation(model, validation_data_path)

知识点

  • 混淆矩阵:可视化分类错误情况
  • 精确率、召回率、F1分数:更全面的性能指标
  • 分类报告:提供每个类别的详细性能指标

阶段6:部署应用

3.6.1 单张图片预测
复制代码
def predict_custom_image(model, img_path, class_names=None, target_size=target_size):
    """
    预测自定义图片的类别
    
    参数:
    model -- 训练好的模型
    img_path -- 图片路径
    class_names -- 类别映射字典
    target_size -- 模型所需的输入尺寸
    
    返回:
    预测结果和置信度
    """
    if class_names is None:
        class_names = {0: 'Organic', 1: 'Recyclable'}
    
    try:
        # 加载图片
        img = Image.open(img_path)
        img = img.convert('RGB')
        
        # 预处理
        img = img.resize(target_size)
        img_array = np.array(img) / 255.0
        img_array = np.expand_dims(img_array, axis=0)
        
        # 预测
        prediction = model.predict(img_array, verbose=0)
        pred_class = class_names[1] if prediction[0][0] > 0.5 else class_names[0]
        confidence = prediction[0][0] if pred_class == 'Recyclable' else 1 - prediction[0][0]
        confidence_percent = confidence * 100
        
        return pred_class, float(confidence)
    
    except Exception as e:
        print(f"预测失败: {str(e)}")
        return None, None

# 测试预测函数
sample_image = os.path.join(input_images_path, "sample.jpg")
if os.path.exists(sample_image):
    pred_class, confidence = predict_custom_image(model, sample_image)
    print(f"预测结果: {pred_class} (置信度: {confidence*100:.2f}%)")
3.6.2 交互式预测界面
复制代码
import ipywidgets as widgets
from IPython.display import display, clear_output
import io

# 创建UI组件
upload = widgets.FileUpload(description="上传图片", accept='image/*', multiple=False)
predict_button = widgets.Button(description="预测")
output = widgets.Output()

# 处理预测
def on_predict_button_clicked(b):
    with output:
        clear_output()
        if not upload.value:
            print("请先上传图片!")
            return
            
        try:
            # 获取上传的图片
            uploaded_file_name = list(upload.value.keys())[0]
            img_data = upload.value[uploaded_file_name]['content']
            
            # 将字节数据转换为PIL图像
            img = Image.open(io.BytesIO(img_data))
            img = img.convert('RGB').resize(target_size)
            
            # 转换为模型输入格式
            img_array = np.array(img) / 255.0
            img_array = np.expand_dims(img_array, axis=0)
            
            # 预测
            prediction = model.predict(img_array, verbose=0)
            pred_class = "Recyclable" if prediction[0][0] > 0.5 else "Organic"
            confidence = prediction[0][0] if pred_class == "Recyclable" else 1 - prediction[0][0]
            confidence_percent = confidence * 100
            
            # 显示结果
            print(f"预测结果: {pred_class} (置信度: {confidence_percent:.2f}%)")
            
            # 显示图片
            plt.figure(figsize=(8, 6))
            plt.imshow(img)
            plt.title(f"预测: {pred_class}\n(置信度: {confidence_percent:.2f}%)")
            plt.axis('off')
            plt.tight_layout()
            plt.show()
                
        except Exception as e:
            print(f"处理图片时出错: {str(e)}")

predict_button.on_click(on_predict_button_clicked)

# 显示UI
display(upload)
display(predict_button)
display(output)

知识点

  • 交互式界面:使用IPython Widgets创建用户友好的界面
  • 实时预测:用户可以上传图片并立即获得预测结果

4. 关键技术点总结

4.1 迁移学习

  • 预训练模型:使用MobileNetV2作为特征提取器
  • 冻结基础模型:只训练顶层,加快训练速度
  • 快速训练:不需要从零开始训练,可以快速获得好效果

4.2 数据增强

  • 旋转、翻转、缩放:增加数据多样性
  • 提高泛化能力:帮助模型适应不同的输入变化

4.3 数据不平衡处理

  • 类别权重:给少数类别更高的权重
  • 平衡策略 :使用class_weight='balanced'自动计算权重

4.4 模型评估

  • 准确率:整体分类正确率
  • 混淆矩阵:可视化分类错误情况
  • 精确率、召回率、F1分数:更全面的性能指标

5. 总结与扩展

5.1 主要发现

  • 迁移学习效果优秀:使用MobileNetV2迁移学习,准确率达到91.56%,超过90%目标
  • 数据增强有效:数据增强显著提高了模型泛化能力
  • 类别权重重要:使用类别权重处理数据不平衡问题,提高了模型性能
  • 模型大小合适:模型大小约8.62 MB,满足≤20MB要求

5.2 后续改进方向

  1. 收集更多数据

    • 收集更多真实场景数据
    • 增加数据多样性
    • 处理边缘案例
  2. 模型优化

    • 尝试更先进的模型架构
    • 微调预训练模型
    • 超参数调优
  3. 实际应用

    • 部署为Web应用
    • 开发移动端应用
    • 集成到现有系统中
  4. 功能扩展

    • 支持更多垃圾类别
    • 添加垃圾识别说明
    • 提供环保建议

6. 参考资料

  1. 数据集

  2. 技术文档

  3. 相关论文

    • 迁移学习在图像分类中的应用
    • 垃圾分类的计算机视觉方法研究
  4. 代码仓库(在建中)

    • 项目代码可在GitHub上查看
    • Jupyter Notebook文件包含完整的实现代码

结语

本项目完整展示了从需求界定到模型部署的AI项目周期,通过迁移学习,我们成功构建了一个高准确率的垃圾分类识别系统。在实际应用中,可以根据具体需求扩展功能,如收集更多数据、模型优化、实际应用部署等。

希望本文能够帮助读者理解迁移学习在环保领域的应用,并为实际项目提供参考。如有问题或建议,欢迎交流讨论!


作者 :Testopia
日期 :2026年2月
标签:#迁移学习 #图像分类 #垃圾分类 #MobileNetV2 #环保 #AI项目周期 #Python

相关推荐
智算菩萨3 小时前
Claude Sonnet 4.6:大语言模型架构演进与前沿性能评估
人工智能·ai编程·ai写作
deepdata_cn3 小时前
聚类用于人群标签的实操思路
机器学习·数据挖掘·聚类
weixin_427179283 小时前
cursor新版本
ai·ai编程
想用offer打牌19 小时前
一站式了解Agent Skills
人工智能·后端·ai编程
一切尽在,你来19 小时前
LangGraph快速入门
人工智能·python·langchain·ai编程
Faker66363aaa21 小时前
Mask R-CNN实现植物存在性检测与分类详解_基于R50-FPN-GRoIE_1x_COCO模型分析
人工智能·分类·cnn
icestone20001 天前
使用Cursor开发大型项目的技巧
前端·人工智能·ai编程
rainstop_31 天前
为 Claude Code 开发自定义 Skill:解决中国地图坐标系转换痛点
gis·ai编程·claude