16.CNN——猫狗二分类识别

网络训练------猫狗二分类

逻辑梳理

  • 起手式:检查显卡状态
  • 数据增强策略
  • 数据读取
  • 定义损失函数
  • 定义优化器
  • 编写骨干网络
  • 开始训练+打印信息
  • 模型验证
  • 输出参数文件
bash 复制代码
import os
import torch
import torch as py
import torch.nn as nn
import torch.nn.functional as F
from torchvision import datasets, transforms


use_cuda = py.cuda.is_available()

if use_cuda:
    device = py.device("cuda")
else:
    device = py.device("cpu")

print(f"device is {device}\n")

# 构建网络模型,写需要训练参数的层
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(3,16, 3,1) # conv2d (通道数,卷积核个数,卷积核大小,步长1)
        self.conv2 = nn.Conv2d(16, 32, 3, 1)
        self.conv3 = nn.Conv2d(32, 64, 3, 1)
        self.fc1 = nn.Linear(18496, 512)#输入节点数,输出节点数,先确认输出,输入随便写让他报错
        self.fc2 = nn.Linear(512, 2)#1和2都可以,1是Sigmoid+二分类交叉熵,2是Softmax+多分类交叉熵
        
    def forward(self, x):
        x = self.conv1(x)
        x = F.relu(x)
        x = F.max_pool2d(x,2)

        x = self.conv2(x)
        x = F.relu(x)
        x = F.max_pool2d(x,2)

        x = self.conv3(x)
        x = F.relu(x)
        x = F.max_pool2d(x,2)
        
        x = torch.flatten(x, 1)
        x = self.fc1(x)
        x = F.relu(x)
        x = self.fc2(x)
        #最后返回的是对数几率,如果不做预测,就不用接softmax算概率
        #如果做预测,就将对数几率取最大,对应索引就是预测的类别

        return x

# 创建模型实例
model = Net().to(device)

#compose是组合,组合进行一些数据增强
transforms_for_train = transforms.Compose(
    [
        transforms.Resize((150, 150)),

        # 常用
        #transforms.RandomRotation(degrees = (15, 170)), # 随机旋转角度范围从15到170
        transforms.ColorJitter(0.5,0.5,0.5,0.5), # 随机改变图片亮度,对比度,饱和度
        
        # 不常用
        transforms.RandomHorizontalFlip(p=0.5),# 随机水平翻转
        transforms.RandomVerticalFlip(p=0.5),# 随机竖直翻转
        #transforms.RandomCrop(150), # 随机裁剪,大小是150
        #transforms.RandomResizedCrop(size=150), # 随机裁剪,缩放到150
        #transforms.RandomAffine(degrees = (30,60), translate=(0.2, 0.2), scale=(1.5, 0.5), shear = (-3, 3)),# 仿射变换 旋转30到60,平移0.2,缩放,伸缩

        transforms.ToTensor(),
    ])

transforms_for_test = transforms.Compose(
    [
        transforms.Resize((150, 150)),
        transforms.ToTensor()
    ])

base_dir = '.data/cat_dog_data' # 用于保存数据的目录
train_dir = os.path.join(base_dir, 'train') # 用于训练
test_dir = os.path.join(base_dir, 'test') # 用于测试

train_datasets = datasets.ImageFolder(train_dir, transform=transforms_for_train)
test_datasets = datasets.ImageFolder(test_dir, transform=transforms_for_test)

# 验证数据集目录是否正确
print(train_datasets.classes)
print(train_datasets.class_to_idx)

# 测试集、训练集加载器,取一个批次一个批次的数据,num_workers表示用几个线程加载数据
train_loader = torch.utils.data.DataLoader(train_datasets,batch_size=64, shuffle=True,num_workers = 6,pin_memory=True)
test_loader = torch.utils.data.DataLoader(test_datasets,batch_size=64, shuffle=False,pin_memory=True, num_workers = 6   )


# 之所以Net最后输出的是对数几率而不是概率,是为了使用多分类交叉熵这个函数,函数本身要求的为归一化的对数几率
# 对于二分类任务有三种写法
# Linear输出层一个节点 + Sigmoid 非线性变换得概率+ BceLoss 输出层逻辑回归
# Linear输出层一个节点 + BCEWithLogitsLoss 二分类交叉熵(内部集成了Sigmoid)
# Linear输出层两个节点 + CrossEntropyLoss 多分类交叉熵(内部集成了Softmax)
loss_f = torch.nn.CrossEntropyLoss()

#优化器
optimizer = torch.optim.RMSprop(model.parameters(), lr=0.001)

# 开始训练
EPOCHS = 20 # 训练轮数
for epoch in range(EPOCHS):
    print(epoch)

    model.train()
    running_loss = 0.0
    running_corrects = 0.0
    for batch_idx, (data, target) in enumerate(train_loader):
        # 接收一个批次的图片x和标签y
        X,y = data.to(device), target.to(device)
        # 开始每次正向、反向传播之前,初始化梯度
        optimizer.zero_grad()

        # 正向传播
        y_pred = model(X)
        # 计算平均loss
        loss = loss_f(y_pred, y)

        #反向传播
        loss.backward()
        #更新模型参数
        optimizer.step()
        # 记录损失和,item()取出该张量的元素
        running_loss += loss.item() 
        
        # 获取预测的类别号
        pred = y_pred.argmax(dim = 1, keepdim = True)
        # 统计当前批次预测正确的样本数,view_as是改变y的形状使其和y_pred一样
        running_corrects += pred.eq(y.view_as(pred)).sum().item() 

    # 打印每一轮数据信息
    # 假设100张图片,batch_size为20,那么需要5个batch来处理这100张图片
    # loss: 计算的是一个batch的20张图片训练的网络的loss,这是个平均值
    # 继续循环遍历求和,running_loss计算的是5个batch的平均loss和
    # running_loss * batch_size是当前轮,用各批次平均loss算出来的总loss
    # len(train_loader)当前一轮训练的所有图片数量,是100张
    # 得出每张图片的平均loss
    epoch_loss = running_loss * 64 / len(train_datasets)

    # epoch_accuracy = 累加每张图片的正确数,/总的图片数,就是正确的比例
    epoch_acc = running_corrects / len(train_datasets)

    print(f"Loss: {epoch_loss:.4f} Acc: {epoch_acc:.4f}")
    

model.eval() # 设置为验证模式,固定模型,不让其继续优化

test_loss = 0.0
test_corrects = 0.0

# 声明以下行为和梯度无关
with torch.no_grad():
    for data, target in test_loader:   
        X, y = data.to(device), target.to(device)

        y_pred = model(X)
        loss = loss_f(y_pred, y)
        test_loss += loss.item() 
        
        pred = y_pred.argmax(dim = 1, keepdim = True)
        test_corrects += pred.eq(y.view_as(pred)).sum().item()

epoch_loss = test_loss * 64 / len(test_datasets)
epoch_acc = test_corrects / len(test_datasets)
print(f"Test Loss: {epoch_loss:.4f} Test Acc: {epoch_acc:.4f}")

# 保存模型权重
torch.save(model.state_dict(), "/workspace/model_weights.pth")
print("模型权重已保存到 /workspace/model_weights.pth")

保存权重

训练结束后,会生成权重文件,一般为.pt或.pth,可用于后续直接进行预测

预测文件

bash 复制代码
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import transforms
from PIL import Image
import os

# 定义与训练时相同的模型结构
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(3, 16, 3, 1)
        self.conv2 = nn.Conv2d(16, 32, 3, 1)
        self.conv3 = nn.Conv2d(32, 64, 3, 1)
        self.fc1 = nn.Linear(18496, 512)
        self.fc2 = nn.Linear(512, 2)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = F.max_pool2d(x, 2)
        x = F.relu(self.conv2(x))
        x = F.max_pool2d(x, 2)
        x = F.relu(self.conv3(x))
        x = F.max_pool2d(x, 2)
        x = torch.flatten(x, 1)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return x

# 初始化模型并加载权重
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = Net().to(device)
model.load_state_dict(torch.load("/workspace/model_weights.pth"))
model.eval()

# 定义与训练时相同的图像变换
transform = transforms.Compose([
    transforms.Resize((150, 150)),
    transforms.ToTensor(),
])

# 单张图片预测
def predict_image(image_path):
    img = Image.open(image_path).convert('RGB')
    img_tensor = transform(img).unsqueeze(0).to(device)
    
    with torch.no_grad():
        output = model(img_tensor)
        _, pred = torch.max(output, 1)
        prob = torch.nn.functional.softmax(output, dim=1)[0] * 100
    
    classes = ['cat', 'dog']
    return classes[pred.item()], prob[pred.item()].item()

# 批量预测
def predict_batch(image_dir):
    if not os.path.isdir(image_dir):
        raise ValueError("输入路径不是有效目录")
        
    image_files = [f for f in os.listdir(image_dir) 
                  if f.lower().endswith(('.png', '.jpg', '.jpeg'))]
    
    if not image_files:
        print("警告: 目录中没有图片文件")
        return
        
    total = len(image_files)
    correct = 0
    
    for filename in image_files:
        filepath = os.path.join(image_dir, filename)
        try:
            # 从文件名提取真实标签(假设文件名格式: label_xxx.jpg)
            true_label = filename.split('_')[0].lower()
            pred_label, confidence = predict_image(filepath)
            
            # 验证预测结果
            is_correct = (pred_label == true_label)
            if is_correct:
                correct += 1
                
            print(f"{filename}: 预测={pred_label}({confidence:.1f}%) | 实际={true_label} | {'✓' if is_correct else '✗'}")
        except Exception as e:
            print(f"{filename}: 处理失败 - {str(e)}")
    
    print(f"\n验证完成: 准确率 {correct}/{total} ({correct/total*100:.1f}%)")

# 示例用法
if __name__ == "__main__":
    mode = input("请选择模式 [1]单张图片 [2]批量验证: ")
    
    if mode == "1":
        image_path = input("请输入图片路径: ")
        # image_path = "/path/to/your/image.jpg"
        if os.path.exists(image_path):
            class_name, confidence = predict_image(image_path)
            print(f"预测结果: {class_name} (置信度: {confidence:.2f}%)")
        else:
            print("错误: 文件不存在")
    elif mode == "2":
        dir_path = input("请输入图片目录路径: ")
        predict_batch(dir_path)
    else:
        print("错误: 无效模式选择")