中药细粒度图像分类

在细粒度图像分类(FGVC)领域,Bilinear CNN(BCNN)模型因其能够捕捉图像中的局部特征交互而受到广泛关注。该模型通过双线性池化操作将两个不同CNN提取的特征进行外积运算,从而获得更加丰富的特征表示,这对于区分外观相似但属于不同子类别的物体尤其有效。然而,BCNN通常计算成本较高,限制了其在移动设备或资源受限环境下的应用。

为了实现轻量化并保持高精度的细粒度分类,可以考虑将MobileNetV2引入到BCNN框架中。MobileNetV2以其深度可分离卷积和倒残差结构著称,能够在减少计算复杂度的同时保证较高的分类性能。此外,MobileNetV2中的线性瓶颈和逐点卷积有助于更有效地处理稀疏数据,进一步提升网络的表达能力。

在此基础上,添加Inception模块是一个值得探索的方向。Inception模块通过并行使用多种尺寸的卷积核,能够同时捕捉不同尺度的特征信息,这对于中药这种形态各异、纹理复杂的对象来说尤为重要。结合Inception模块的多尺度特征提取能力和MobileNetV2的高效架构,可以在不显著增加计算负担的前提下增强模型对细节特征的敏感度。

弱监督学习则允许我们仅依赖图像级别的标签来进行训练,无需精确的边界框或部分注释,这大大降低了标注成本,并使得大规模数据集的应用成为可能。特别是在中药分类这样一个需要大量专业知识才能准确标注的领域,弱监督方法能够显著降低专家标注的工作量。

好的我来讲代码部分

https://github.com/HaoMood/bilinear-cnn.git 基于这个大佬的代码改进

https://github.com/hackerjackL/xilidu.git 这是我的代码仓库

bash 复制代码
pip install torch torchvision pillow tqdm #理论上应该是这些

torch >=2.0  #1.0版本不行,有些函数和方法用不了
python 3.8

models.py

python 复制代码
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models as models
from torchvision.models import MobileNet_V2_Weights

class InceptionModule(nn.Module):
    def __init__(self, in_channels, ch1x1, ch3x3red, ch3x3, ch5x5red, ch5x5, pool_proj):
        super(InceptionModule, self).__init__()
        self.branch1 = nn.Conv2d(in_channels, ch1x1, kernel_size=1)
        self.branch2 = nn.Sequential(
            nn.Conv2d(in_channels, ch3x3red, kernel_size=1),
            nn.Conv2d(ch3x3red, ch3x3, kernel_size=3, padding=1)
        )
        self.branch3 = nn.Sequential(
            nn.Conv2d(in_channels, ch5x5red, kernel_size=1),
            nn.Conv2d(ch5x5red, ch5x5, kernel_size=5, padding=2)
        )
        self.branch4 = nn.Sequential(
            nn.MaxPool2d(kernel_size=3, stride=1, padding=1),
            nn.Conv2d(in_channels, pool_proj, kernel_size=1)
        )

    def forward(self, x):
        branch1 = F.relu(self.branch1(x))
        branch2 = F.relu(self.branch2(x))
        branch3 = F.relu(self.branch3(x))
        branch4 = F.relu(self.branch4(x))
        return torch.cat([branch1, branch2, branch3, branch4], 1)

class MobileNetV2Classifier(nn.Module):
    """Bilinear CNN Model using MobileNetV2"""
    def __init__(self, num_classes):
        super(MobileNetV2Classifier, self).__init__()
        model_urls = {
            'mobilenet_v2': 'https://download.pytorch.org/models/mobilenet_v2-b0353104.pth',
        }
        mobilenet_v2 = models.mobilenet_v2(weights=MobileNet_V2_Weights.IMAGENET1K_V1)
        self.features = mobilenet_v2.features
        
        # Freeze the features layers
        for param in self.features.parameters():
            param.requires_grad = False
        
        # Add a new classifier on top of the features
        self.classifier = nn.Sequential(
            nn.Dropout(0.5),
            nn.Linear(mobilenet_v2.last_channel, num_classes)
        )

    def forward(self, x):
        x = self.features(x)
        x = x.mean([2, 3])  # Global average pooling
        x = self.classifier(x)
        return x

InceptionModule

InceptionModule 类实现了一个经典的 Inception 模块,它可以在 GoogLeNet 等模型中找到。这个模块允许网络在多个尺度上并行处理信息。

  • 构造函数 (__init__ 方法):

    • 接受输入通道数以及各个分支的通道数配置。
    • 分支1: 使用一个 1x1 卷积来减少维度。
    • 分支2: 先通过一个 1x1 卷积进行降维,然后使用一个 3x3 卷积(带有 padding 来保持尺寸)。
    • 分支3: 类似于分支2,但使用的是 5x5 卷积。
    • 分支4: 首先进行 3x3 最大池化,然后通过一个 1x1 卷积调整通道数。
  • 前向传播 (forward 方法):

    • 对每个分支应用 ReLU 激活函数,并将它们的结果沿通道维度拼接起来。

MobileNetV2Classifier

MobileNetV2Classifier 类基于 MobileNet V2 模型,用于图像分类任务。它利用了预训练的 MobileNet V2 特征提取器,并在其基础上添加了一个新的分类头。

  • 构造函数 (__init__ 方法):

    • 加载预训练的 MobileNet V2 模型(使用 weights=MobileNet_V2_Weights.IMAGENET1K_V1 参数指定加载 ImageNet 上预训练的权重)。
    • 冻结特征提取层的参数(即设置 requires_grad=False),以便只训练新添加的分类层。
    • 添加了一个由 Dropout 层和全连接层组成的分类头。全连接层的输入大小是 MobileNet V2 的最后一个特征通道数(mobilenet_v2.last_channel),输出大小是类别数。
  • 前向传播 (forward 方法):

    • 输入数据首先通过 MobileNet V2 的特征提取层。
    • 然后对特征图进行全局平均池化(即将每个特征图缩减为单个数值),这一步通常用于将二维特征转换为一维特征向量。
    • 最终通过分类器得到最终的分类结果。

logger.py

这里我写了一些前作者没有的 例如显存 打印网络层 预计时间等 不过打印显存信息有误 请还是nvidia --smi 查看吧

python 复制代码
import time
import torch
from tqdm import tqdm
import os
import csv

def print_model_structure(model):
    print("Model structure:")
    print("----------------------------------------------------------------")
    for name, module in model.named_modules():
        if name == '': continue
        indent = name.count('.')
        print(' ' * (4 * indent) + name + ':', str(module).split('\n')[0])
    print("----------------------------------------------------------------")
    print("Model summary:")


def log_training_info(epoch, train_loss, train_acc, val_acc, epoch_time, remaining_time, gpu_mem_used, writer):
    info = {
        "GPU Memory Used": f"{gpu_mem_used:.2f}MB",
        "Epoch": f"{epoch}",
        "Train Loss": f"{train_loss:.4f}",
        "Train Accuracy": f"{train_acc:.2f}%",
        "Validation Accuracy": f"{val_acc:.2f}%",
        "Epoch Time": f"{epoch_time:.1f}s",
        "Remaining Time": f"{remaining_time:.1f}s"
    }
    print("\t".join(info.values()))
    writer.writerow(info)


class TrainingLogger:
    def __init__(self, epochs, result_file):
        self.pbar = tqdm(total=epochs, desc="Training", unit="epoch")
        self.result_file = result_file
        self.fieldnames = ["GPU Memory Used", "Epoch", "Train Loss", "Train Accuracy", 
                           "Validation Accuracy", "Epoch Time", "Remaining Time"]
        with open(self.result_file, mode='w', newline='') as file:
            self.writer = csv.DictWriter(file, fieldnames=self.fieldnames)
            self.writer.writeheader()

    def update_progress(self):
        self.pbar.update(1)

    def close_progress(self):
        self.pbar.close()

    def log_epoch_info(self, epoch, train_loss, train_acc, val_acc, epoch_time, remaining_time, gpu_mem_used):
        with open(self.result_file, mode='a', newline='') as file:
            writer = csv.DictWriter(file, fieldnames=self.fieldnames)
            log_training_info(epoch, train_loss, train_acc, val_acc, epoch_time, remaining_time, gpu_mem_used, writer)

config.py

这里存放了一些定义的超参 运行代码 例如

python train.py --data_dir /root/images --batch_size 32 --epochs 100

python 复制代码
import argparse

def get_config():
    parser = argparse.ArgumentParser(description="Bilinear CNN Training")
    parser.add_argument("--data_dir", type=str, default="./images",
                       help="Root directory of dataset (contains train/val folders)")
    parser.add_argument("--model_dir", type=str, default="./models",
                       help="Directory to save trained models")
    parser.add_argument("--batch_size", type=int, default=32,
                       help="Input batch size for training")
    parser.add_argument("--epochs", type=int, default=100,
                       help="Number of epochs to train")
    parser.add_argument("--lr", type=float, default=0.001,
                       help="Learning rate")
    parser.add_argument("--weight_decay", type=float, default=5e-4,
                       help="Weight decay")
    parser.add_argument("--workers", type=int, default=4,
                       help="Number of data loading workers")
    parser.add_argument("--optimizer", type=str, default="adam",
                       choices=["sgd", "adam"],
                       help="Optimizer to use (default: sgd)")
    parser.add_argument("--scheduler", type=str, default="reduce_on_plateau",
                       choices=["reduce_on_plateau", "cosine_annealing"],
                       help="Learning rate scheduler to use (default: reduce_on_plateau)")
    parser.add_argument("--patience", type=int, default=3,
                       help="Patience for ReduceLROnPlateau scheduler (default: 3)")
    return parser.parse_args()

下面是一个trainer.py 这个是串联我们其他py的核心组件

1. 初始化方法 (__init__)

Trainer 类的初始化方法首先设置了一些基本参数,如设备类型(CPU 或 GPU)、训练目录等。它还定义了数据预处理的方式,并加载了训练和验证数据集。这里使用了 torchvision.transforms 来进行数据增强,包括随机水平翻转、随机裁剪等,以提高模型的泛化能力。

数据加载
  • 使用 torchvision.datasets.ImageFolder 来加载数据集,该函数假设数据集按类别组织在不同的子文件夹中。
  • 数据预处理步骤包括调整大小、数据增强、转换为张量以及归一化处理。
模型创建
  • 创建了一个 MobileNetV2Classifier 实例,如果存在多个 GPU,则使用 nn.DataParallel 来并行化模型训练。
  • 打印模型结构,方便调试和理解模型架构。
损失函数和优化器
  • 定义了交叉熵损失函数 nn.CrossEntropyLoss(),适用于多分类问题。
  • 根据传入的参数选择合适的优化器(SGD 或 Adam),并且仅对分类头部分的参数进行优化。
学习率调度器
  • 支持两种调度策略:ReduceLROnPlateauCosineAnnealingLR,它们分别根据验证准确率的变化或按照余弦退火方式调整学习率。

2. 训练过程 (train 方法)

train 方法是整个训练流程的核心,它通过循环执行多次迭代(每个 epoch)来进行模型训练。每一轮迭代都包含以下几个步骤:

  • 训练阶段:模型处于训练模式,对每个批次的数据进行前向传播计算损失,然后通过反向传播更新模型权重。
  • 验证阶段:切换到评估模式,不进行梯度计算,仅评估模型性能(准确率)。
  • 学习率调整:根据验证结果调整学习率。
  • 保存最佳模型:如果当前 epoch 的验证准确率优于历史最高值,则保存当前模型状态。

此外,还记录了每次迭代的训练损失、准确率及验证准确率,并通过 TrainingLogger 对象将其写入 CSV 文件以便后续分析。

3. 验证过程 (validate 方法)

validate 方法用于评估模型在验证集上的表现,它与训练阶段的主要区别在于不进行参数更新,而是单纯地计算模型预测的准确性。

4. 结果可视化 (plot_results 方法)

训练完成后,plot_results 方法会生成两张图表,一张展示训练过程中损失值的变化趋势,另一张则显示验证集上准确率的变化情况。这些图表有助于直观地了解模型的学习进程和性能改进情况。

5. 辅助函数

find_next_train_folder 函数用于自动查找下一个可用的训练目录名称,确保每次运行时都有独立的存储空间存放实验结果。

python 复制代码
import os
import time
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torch.optim.lr_scheduler import ReduceLROnPlateau, CosineAnnealingLR
from models import MobileNetV2Classifier
from logger import print_model_structure, TrainingLogger
import matplotlib.pyplot as plt


class Trainer:
    def __init__(self, args):
        self.args = args
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        # 创建新的训练目录
        self.train_dir = find_next_train_folder(self.args.model_dir)
        os.makedirs(self.train_dir, exist_ok=True)

        # 数据预处理
        self.train_transform = transforms.Compose([
            transforms.Resize(256),
            transforms.RandomHorizontalFlip(),
            transforms.RandomCrop(224),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])
        ])

        self.val_transform = transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])
        ])

        data_dir = args.data_dir  # 修改这个基础路径
        train_dir = os.path.join(data_dir, "train")
        val_dir = os.path.join(data_dir, "val")

        # 检查数据集是否存在
        if not os.path.exists(train_dir) or not os.path.exists(val_dir):
            raise ValueError(f"数据集不存在,请检查{data_dir}文件夹是否正确。"
                             f"\nExpected directories: {train_dir} and {val_dir}")

        self.train_dataset = torchvision.datasets.ImageFolder(
            root=train_dir,
            transform=self.train_transform
        )
        self.val_dataset = torchvision.datasets.ImageFolder(
            root=val_dir,
            transform=self.val_transform
        )

        # 打印类别数量
        self.num_classes = len(self.train_dataset.classes)
        print(f"Number of classes: {self.num_classes}")
        if self.num_classes <= 1:
            raise ValueError(f"数据集中必须包含至少两个类别,当前只有 {self.num_classes} 个类别。")

        # 创建模型
        self.model = MobileNetV2Classifier(self.num_classes).to(self.device)
        if torch.cuda.device_count() > 1:
            self.model = nn.DataParallel(self.model)

        # 打印模型结构
        print_model_structure(self.model)

        # 定义损失函数和优化器
        self.criterion = nn.CrossEntropyLoss()

        if args.optimizer == "sgd":
            self.optimizer = optim.SGD(
                self.model.module.classifier.parameters() if isinstance(self.model, nn.DataParallel)
                else self.model.classifier.parameters(),
                lr=args.lr,
                momentum=0.9,
                weight_decay=args.weight_decay
            )
        elif args.optimizer == "adam":
            self.optimizer = optim.Adam(
                self.model.module.classifier.parameters() if isinstance(self.model, nn.DataParallel)
                else self.model.classifier.parameters(),
                lr=args.lr,
                weight_decay=args.weight_decay
            )
        else:
            raise ValueError(f"Unsupported optimizer: {args.optimizer}")

        # 定义学习率调度器
        if args.scheduler == "reduce_on_plateau":
            self.scheduler = ReduceLROnPlateau(
                self.optimizer, mode='max', factor=0.1, patience=args.patience, verbose=True
            )
        elif args.scheduler == "cosine_annealing":
            self.scheduler = CosineAnnealingLR(
                self.optimizer, T_max=args.epochs, eta_min=0, last_epoch=-1, verbose=True
            )
        else:
            raise ValueError(f"Unsupported scheduler: {args.scheduler}")

        # 数据加载器
        self.train_loader = DataLoader(
            self.train_dataset,
            batch_size=args.batch_size,
            shuffle=True,
            num_workers=args.workers,
            pin_memory=True
        )
        self.val_loader = DataLoader(
            self.val_dataset,
            batch_size=args.batch_size,
            shuffle=False,
            num_workers=args.workers,
            pin_memory=True
        )

        # 初始化训练日志记录器
        result_csv_path = os.path.join(self.train_dir, "result.csv")
        self.logger = TrainingLogger(self.args.epochs, result_csv_path)

        # 记录每个epoch的损失和准确率
        self.train_losses = []
        self.val_accuracies = []

    def train(self):
        best_acc = 0.0
        print(f"Starting training with {self.num_classes} classes...")
        print(f"GPU Rem\tEpoch\tTrain Loss\tTrain Acc\tVal Acc\tTime\tRemaining")

        for epoch in range(1, self.args.epochs + 1):  # 从1开始计数
            start_time = time.time()

            # 训练阶段
            self.model.train()
            train_loss = 0.0
            correct = 0
            total = 0

            for inputs, labels in self.train_loader:
                inputs = inputs.to(self.device)
                labels = labels.to(self.device)

                self.optimizer.zero_grad()
                outputs = self.model(inputs)
                loss = self.criterion(outputs, labels)
                loss.backward()
                self.optimizer.step()

                train_loss += loss.item()
                _, predicted = outputs.max(1)
                total += labels.size(0)
                correct += predicted.eq(labels).sum().item()

            train_loss /= len(self.train_loader)
            train_acc = 100.0 * correct / total

            # 验证阶段
            val_acc = self.validate()

            # 学习率调整
            if isinstance(self.scheduler, ReduceLROnPlateau):
                self.scheduler.step(val_acc)
            else:
                self.scheduler.step()

            # 保存最佳模型
            if val_acc > best_acc:
                best_acc = val_acc
                model_path = os.path.join(
                    self.train_dir,
                    f"best_model_{val_acc:.2f}.pth"
                )
                torch.save(self.model.state_dict(), model_path)

            # 获取GPU内存使用情况
            gpu_mem_used = torch.cuda.memory_allocated(self.device) / 1e6 if torch.cuda.is_available() else 0

            # 打印统计信息
            epoch_time = time.time() - start_time
            remaining_time = (self.args.epochs - epoch) * epoch_time
            self.logger.log_epoch_info(epoch, train_loss, train_acc, val_acc, epoch_time, remaining_time, gpu_mem_used)

            # 更新tqdm进度条
            self.logger.update_progress()

            # 记录损失和准确率
            self.train_losses.append(train_loss)
            self.val_accuracies.append(val_acc)

        self.logger.close_progress()
        print(f"Best validation accuracy: {best_acc:.2f}%")

        # 绘制损失和准确率图
        self.plot_results()

    def validate(self):
        self.model.eval()
        correct = 0
        total = 0

        with torch.no_grad():
            for inputs, labels in self.val_loader:
                inputs = inputs.to(self.device)
                labels = labels.to(self.device)

                outputs = self.model(inputs)
                _, predicted = outputs.max(1)
                total += labels.size(0)
                correct += predicted.eq(labels).sum().item()

        acc = 100.0 * correct / total
        self.model.train()
        return acc

    def plot_results(self):
        epochs = list(range(1, self.args.epochs + 1))

        plt.figure(figsize=(12, 6))

        # 绘制训练损失
        plt.subplot(1, 2, 1)
        plt.plot(epochs, self.train_losses, label='Training Loss')
        plt.xlabel('Epochs')
        plt.ylabel('Loss')
        plt.title('Training Loss over Epochs')
        plt.legend()

        # 绘制验证准确率
        plt.subplot(1, 2, 2)
        plt.plot(epochs, self.val_accuracies, label='Validation Accuracy', color='orange')
        plt.xlabel('Epochs')
        plt.ylabel('Accuracy (%)')
        plt.title('Validation Accuracy over Epochs')
        plt.legend()

        # 保存图片
        plot_path = os.path.join(self.train_dir, "training_results.png")
        plt.savefig(plot_path)
        plt.show()


def find_next_train_folder(base_dir):
    i = 1
    while True:
        folder_name = os.path.join(base_dir, f"train{i}")
        if not os.path.exists(folder_name):
            return folder_name
        i += 1

最后就是运行脚本 主函数 train.py

python 复制代码
from config import get_config
from trainer import Trainer

def main():
    args = get_config()
    trainer = Trainer(args)
    trainer.train()

if __name__ == "__main__":
    main()

好这就是我们的代码部分

请注意我们的数据集结构

然后无需标注 只需划分train val即可

以上就是内容部分 如有问题请评论区或私信指正谢谢 !! 本科小白一枚 感谢观看!

相关推荐
带娃的IT创业者31 分钟前
机器学习实战(8):降维技术——主成分分析(PCA)
人工智能·机器学习·分类·聚类
调皮的芋头1 小时前
iOS各个证书生成细节
人工智能·ios·app·aigc
flying robot3 小时前
人工智能基础之数学基础:01高等数学基础
人工智能·机器学习
Moutai码农3 小时前
机器学习-生命周期
人工智能·python·机器学习·数据挖掘
188_djh4 小时前
# 10分钟了解DeepSeek,保姆级部署DeepSeek到WPS,实现AI赋能
人工智能·大语言模型·wps·ai技术·ai应用·deepseek·ai知识
Jackilina_Stone4 小时前
【DL】浅谈深度学习中的知识蒸馏 | 输出层知识蒸馏
人工智能·深度学习·机器学习·蒸馏
bug404_4 小时前
分布式大语言模型服务引擎vLLM论文解读
人工智能·分布式·语言模型
Logout:4 小时前
[AI]docker封装包含cuda cudnn的paddlepaddle PaddleOCR
人工智能·docker·paddlepaddle
OJAC近屿智能5 小时前
苹果新品今日发布,AI手机市场竞争加剧,近屿智能专注AI人才培养
大数据·人工智能·ai·智能手机·aigc·近屿智能
代码猪猪傻瓜coding5 小时前
关于 形状信息提取的说明
人工智能·python·深度学习