卷积神经网络CNN

一、今天学了什么?

一套完整的 CNN 图像分类项目,分两大块:

  1. 训练模型读文件夹图片 → 搭建卷积网络 → 训练 → 保存 model.h5
  2. 预测单张图片加载模型 → 处理图片 → 扩维度 → 输出类别和概率

训练代码核心模板

复制代码
# 1. 导包
import numpy as np
import matplotlib.pyplot as plt
import keras
from keras.layers import Dense, Flatten, Conv2D, MaxPool2D

# 2. 路径 & 类别
data_train = 'data/train'
data_val   = 'data/val'
CLASS_NAMES = ['Cr','In','Pa','PS','Rs','Sc']

# 3. 图片参数
BATCH_SIZE = 64
IMG_H = 32
IMG_W = 32

# 4. 归一化
datagen = keras.preprocessing.image.ImageDataGenerator(rescale=1./255)

# 5. 自动读图(固定5个参数)
train_gen = datagen.flow_from_directory(
    directory=data_train,
    batch_size=BATCH_SIZE,
    shuffle=True,
    target_size=(IMG_H, IMG_W),
    classes=CLASS_NAMES
)

val_gen = datagen.flow_from_directory(...)  # 同上改val

# 6. 搭建CNN
model = keras.Sequential([
    Conv2D(6, 5, activation='relu', input_shape=(32,32,3)),
    MaxPool2D(2,2),
    Conv2D(16,5, activation='relu'),
    MaxPool2D(2,2),
    Conv2D(120,2, activation='relu'),
    Flatten(),
    Dense(84, activation='relu'),
    Dense(6, activation='softmax')  # 6类写6
])

# 7. 编译(多分类必用这个loss)
model.compile(
    loss='categorical_crossentropy',
    optimizer='adam',
    metrics=['accuracy']
)

# 8. 训练(必须写history=才能画图)
history = model.fit(train_gen, validation_data=val_gen, epochs=50)

# 9. 保存模型
model.save('model.h5')

预测代码核心模板(直接抄)

复制代码
import cv2
import numpy as np
from keras.models import load_model

CLASS_NAMES = ['Cr','In','Pa','PS','Rs','Sc']
model = load_model('model.h5')

# 1. 读图 + 预处理
img = cv2.imread('test.bmp')
img = cv2.resize(img, (32,32))
img = img / 255.0  # 必须归一化

# 2. 扩维度(必写!模型要4维)
test_img = np.expand_dims(img, axis=0)

# 3. 预测
pred = model.predict(test_img)
idx = np.argmax(pred)   # 找最大概率下标
name = CLASS_NAMES[idx]
prob = np.max(pred)

print(f"结果:{name},概率:{prob:.2f}")

二、核心重点(必须记住)

1. 文件夹结构是关键

复制代码
data/train/类别1/xxx.jpg
data/train/类别2/xxx.jpg
data/val/类别1/xxx.jpg
...

flow_from_directory 是自动按文件夹分类

2. 图片必须统一处理(三步固定)

  • 缩放:resize(32, 32)
  • 归一化:/255(0~255 → 0~1)
  • 预测时扩维度:np.expand_dims(..., axis=0)

3. CNN 网络固定套路

复制代码
卷积 Conv2D → 池化 MaxPooling2D
→ 卷积 → 池化
→ 卷积
→ Flatten 展平
→ Dense 全连接
→ Dense(类别数, softmax)

4. 多分类必须用

复制代码
loss='categorical_crossentropy'

5. 训练时要接收 history 才能画图

复制代码
history = model.fit(...)

6. 预测两步神操作

复制代码
history = model.fit(...)

三、小白高频易错点(最容易报错的地方)

1. loss 函数写错

  • 二分类才用 binary_crossentropy
  • **6 分类必须用 categorical_crossentropy**写错直接不收敛、准确率极低。

2. 类别顺序乱填

CLASS_NAMES 顺序必须和文件夹顺序一致,否则识别全错。

3. 池化层参数写在括号外面 ❌

复制代码
model.add(MaxPooling2D(...), strides=...)  # 错
model.add(MaxPooling2D(..., strides=...))  # 对

4. 预测时没扩维度

模型要 4 维:(1,32,32,3)图片只有 3 维:(32,32,3)不加 expand_dims 直接报错。

5. 训练没写 history =

后面画图 history.history 会报错。

6. 图片预处理和训练不一致

训练时 32×32、/255预测时也要 完全一样,不然模型看不懂。

四、一句话记住今天所有内容

文件夹放好图 → 统一缩放到 32×32 → 归一化 → 搭 CNN 训练 → 保存模型 → 预测时扩维度 → argmax 出结果

相关推荐
管二狗赶快去工作!3 小时前
体系结构论文(九十九):Large Language Models (LLMs) for Electronic Design Automation (EDA)
人工智能·语言模型·自然语言处理
Rubin智造社3 小时前
04月09日AI每日参考:Anthropic Mythos限制公开,Meta发布首款超级智能模型
人工智能·开源大模型·ai安全·anthropic·claude mythos·meta muse spark·google gemma 4
沪漂阿龙3 小时前
PyTorch 张量与自动微分完全指南:从核心概念到实战训练
人工智能·pytorch·python
LaughingZhu3 小时前
Product Hunt 每日热榜 | 2026-04-09
人工智能·经验分享·深度学习·神经网络·产品运营
roman_日积跬步-终至千里3 小时前
【系统架构师-案例题-Web应用系统架构设计】22年(4)基于边缘计算的智能门禁系统
人工智能·系统架构·边缘计算
星纬智联技术3 小时前
微信小程序72小时交付:从“营销噱头”到“标准服务”的拐点已至
人工智能·aigc·搜索引擎优化
小真zzz3 小时前
搜极星:你的免费“AI内容验真器”
大数据·人工智能·ai·chatgpt·seo·geo
格林黄3 小时前
【无标题】
人工智能·python
奇思智算3 小时前
LLaMA/Bert/扩散模型微调GPU选型及租用指南
人工智能·bert·llama