卷积神经网络CNN

一、今天学了什么?

一套完整的 CNN 图像分类项目,分两大块:

  1. 训练模型读文件夹图片 → 搭建卷积网络 → 训练 → 保存 model.h5
  2. 预测单张图片加载模型 → 处理图片 → 扩维度 → 输出类别和概率

训练代码核心模板

复制代码
# 1. 导包
import numpy as np
import matplotlib.pyplot as plt
import keras
from keras.layers import Dense, Flatten, Conv2D, MaxPool2D

# 2. 路径 & 类别
data_train = 'data/train'
data_val   = 'data/val'
CLASS_NAMES = ['Cr','In','Pa','PS','Rs','Sc']

# 3. 图片参数
BATCH_SIZE = 64
IMG_H = 32
IMG_W = 32

# 4. 归一化
datagen = keras.preprocessing.image.ImageDataGenerator(rescale=1./255)

# 5. 自动读图(固定5个参数)
train_gen = datagen.flow_from_directory(
    directory=data_train,
    batch_size=BATCH_SIZE,
    shuffle=True,
    target_size=(IMG_H, IMG_W),
    classes=CLASS_NAMES
)

val_gen = datagen.flow_from_directory(...)  # 同上改val

# 6. 搭建CNN
model = keras.Sequential([
    Conv2D(6, 5, activation='relu', input_shape=(32,32,3)),
    MaxPool2D(2,2),
    Conv2D(16,5, activation='relu'),
    MaxPool2D(2,2),
    Conv2D(120,2, activation='relu'),
    Flatten(),
    Dense(84, activation='relu'),
    Dense(6, activation='softmax')  # 6类写6
])

# 7. 编译(多分类必用这个loss)
model.compile(
    loss='categorical_crossentropy',
    optimizer='adam',
    metrics=['accuracy']
)

# 8. 训练(必须写history=才能画图)
history = model.fit(train_gen, validation_data=val_gen, epochs=50)

# 9. 保存模型
model.save('model.h5')

预测代码核心模板(直接抄)

复制代码
import cv2
import numpy as np
from keras.models import load_model

CLASS_NAMES = ['Cr','In','Pa','PS','Rs','Sc']
model = load_model('model.h5')

# 1. 读图 + 预处理
img = cv2.imread('test.bmp')
img = cv2.resize(img, (32,32))
img = img / 255.0  # 必须归一化

# 2. 扩维度(必写!模型要4维)
test_img = np.expand_dims(img, axis=0)

# 3. 预测
pred = model.predict(test_img)
idx = np.argmax(pred)   # 找最大概率下标
name = CLASS_NAMES[idx]
prob = np.max(pred)

print(f"结果:{name},概率:{prob:.2f}")

二、核心重点(必须记住)

1. 文件夹结构是关键

复制代码
data/train/类别1/xxx.jpg
data/train/类别2/xxx.jpg
data/val/类别1/xxx.jpg
...

flow_from_directory 是自动按文件夹分类

2. 图片必须统一处理(三步固定)

  • 缩放:resize(32, 32)
  • 归一化:/255(0~255 → 0~1)
  • 预测时扩维度:np.expand_dims(..., axis=0)

3. CNN 网络固定套路

复制代码
卷积 Conv2D → 池化 MaxPooling2D
→ 卷积 → 池化
→ 卷积
→ Flatten 展平
→ Dense 全连接
→ Dense(类别数, softmax)

4. 多分类必须用

复制代码
loss='categorical_crossentropy'

5. 训练时要接收 history 才能画图

复制代码
history = model.fit(...)

6. 预测两步神操作

复制代码
history = model.fit(...)

三、小白高频易错点(最容易报错的地方)

1. loss 函数写错

  • 二分类才用 binary_crossentropy
  • **6 分类必须用 categorical_crossentropy**写错直接不收敛、准确率极低。

2. 类别顺序乱填

CLASS_NAMES 顺序必须和文件夹顺序一致,否则识别全错。

3. 池化层参数写在括号外面 ❌

复制代码
model.add(MaxPooling2D(...), strides=...)  # 错
model.add(MaxPooling2D(..., strides=...))  # 对

4. 预测时没扩维度

模型要 4 维:(1,32,32,3)图片只有 3 维:(32,32,3)不加 expand_dims 直接报错。

5. 训练没写 history =

后面画图 history.history 会报错。

6. 图片预处理和训练不一致

训练时 32×32、/255预测时也要 完全一样,不然模型看不懂。

四、一句话记住今天所有内容

文件夹放好图 → 统一缩放到 32×32 → 归一化 → 搭 CNN 训练 → 保存模型 → 预测时扩维度 → argmax 出结果

相关推荐
东坡肘子1 小时前
SPI 加入 Apple,Swift 迈向自举 -- 肘子的 Swift 周报 #142
人工智能·swiftui·swift
小和尚同志9 小时前
AI 自动化测试探索(二):Chrome-devtools MCP
人工智能·e2e·aigc
冬奇Lab11 小时前
Workflow 系列(02):设计范式——四层架构、三种 Context 传递模式与确认门设计
人工智能·agent·工作流引擎
冬奇Lab12 小时前
每日一个开源项目(第145篇):Trellis - 把项目记忆、规范和任务上下文持久化进代码仓库
人工智能·开源·资讯
有道AI情报局12 小时前
Harness即产品
人工智能·agent
罗西的思考13 小时前
机器人 / 强化学习】HIL-SERL:人类在环驱动的具身智能进化框架
人工智能·算法·机器学习
IT_陈寒14 小时前
SpringBoot自动配置的坑,我的API突然就404了
前端·人工智能·后端
笃行35014 小时前
从零到上线:用 EdgeOne Makers + CodeBuddy 搭一个「对账核对员」AI Agent
人工智能
用户68563262086915 小时前
Claude Code 乱猜字段名?我给它写了一个"数据库查询约束 Skill"
人工智能
你_好15 小时前
# 给你的产品嵌入一个「会操作界面的 AI 助手」
人工智能