卷积神经网络CNN

一、今天学了什么?

一套完整的 CNN 图像分类项目,分两大块:

  1. 训练模型读文件夹图片 → 搭建卷积网络 → 训练 → 保存 model.h5
  2. 预测单张图片加载模型 → 处理图片 → 扩维度 → 输出类别和概率

训练代码核心模板

复制代码
# 1. 导包
import numpy as np
import matplotlib.pyplot as plt
import keras
from keras.layers import Dense, Flatten, Conv2D, MaxPool2D

# 2. 路径 & 类别
data_train = 'data/train'
data_val   = 'data/val'
CLASS_NAMES = ['Cr','In','Pa','PS','Rs','Sc']

# 3. 图片参数
BATCH_SIZE = 64
IMG_H = 32
IMG_W = 32

# 4. 归一化
datagen = keras.preprocessing.image.ImageDataGenerator(rescale=1./255)

# 5. 自动读图(固定5个参数)
train_gen = datagen.flow_from_directory(
    directory=data_train,
    batch_size=BATCH_SIZE,
    shuffle=True,
    target_size=(IMG_H, IMG_W),
    classes=CLASS_NAMES
)

val_gen = datagen.flow_from_directory(...)  # 同上改val

# 6. 搭建CNN
model = keras.Sequential([
    Conv2D(6, 5, activation='relu', input_shape=(32,32,3)),
    MaxPool2D(2,2),
    Conv2D(16,5, activation='relu'),
    MaxPool2D(2,2),
    Conv2D(120,2, activation='relu'),
    Flatten(),
    Dense(84, activation='relu'),
    Dense(6, activation='softmax')  # 6类写6
])

# 7. 编译(多分类必用这个loss)
model.compile(
    loss='categorical_crossentropy',
    optimizer='adam',
    metrics=['accuracy']
)

# 8. 训练(必须写history=才能画图)
history = model.fit(train_gen, validation_data=val_gen, epochs=50)

# 9. 保存模型
model.save('model.h5')

预测代码核心模板(直接抄)

复制代码
import cv2
import numpy as np
from keras.models import load_model

CLASS_NAMES = ['Cr','In','Pa','PS','Rs','Sc']
model = load_model('model.h5')

# 1. 读图 + 预处理
img = cv2.imread('test.bmp')
img = cv2.resize(img, (32,32))
img = img / 255.0  # 必须归一化

# 2. 扩维度(必写!模型要4维)
test_img = np.expand_dims(img, axis=0)

# 3. 预测
pred = model.predict(test_img)
idx = np.argmax(pred)   # 找最大概率下标
name = CLASS_NAMES[idx]
prob = np.max(pred)

print(f"结果:{name},概率:{prob:.2f}")

二、核心重点(必须记住)

1. 文件夹结构是关键

复制代码
data/train/类别1/xxx.jpg
data/train/类别2/xxx.jpg
data/val/类别1/xxx.jpg
...

flow_from_directory 是自动按文件夹分类

2. 图片必须统一处理(三步固定)

  • 缩放:resize(32, 32)
  • 归一化:/255(0~255 → 0~1)
  • 预测时扩维度:np.expand_dims(..., axis=0)

3. CNN 网络固定套路

复制代码
卷积 Conv2D → 池化 MaxPooling2D
→ 卷积 → 池化
→ 卷积
→ Flatten 展平
→ Dense 全连接
→ Dense(类别数, softmax)

4. 多分类必须用

复制代码
loss='categorical_crossentropy'

5. 训练时要接收 history 才能画图

复制代码
history = model.fit(...)

6. 预测两步神操作

复制代码
history = model.fit(...)

三、小白高频易错点(最容易报错的地方)

1. loss 函数写错

  • 二分类才用 binary_crossentropy
  • **6 分类必须用 categorical_crossentropy**写错直接不收敛、准确率极低。

2. 类别顺序乱填

CLASS_NAMES 顺序必须和文件夹顺序一致,否则识别全错。

3. 池化层参数写在括号外面 ❌

复制代码
model.add(MaxPooling2D(...), strides=...)  # 错
model.add(MaxPooling2D(..., strides=...))  # 对

4. 预测时没扩维度

模型要 4 维:(1,32,32,3)图片只有 3 维:(32,32,3)不加 expand_dims 直接报错。

5. 训练没写 history =

后面画图 history.history 会报错。

6. 图片预处理和训练不一致

训练时 32×32、/255预测时也要 完全一样,不然模型看不懂。

四、一句话记住今天所有内容

文件夹放好图 → 统一缩放到 32×32 → 归一化 → 搭 CNN 训练 → 保存模型 → 预测时扩维度 → argmax 出结果

相关推荐
tzc_fly7 分钟前
AnisoAlign:各向异性模态对齐
人工智能·深度学习·机器学习
极客老王说Agent15 分钟前
2026供应链智变:实在Agent供应链库存预测助手核心能力与配置深度教程
人工智能·机器学习·ai·chatgpt
刘一说16 分钟前
AI热点资讯日报 - 2026年5月15日
人工智能
冬奇Lab21 分钟前
RAG 系列(十七):Agentic RAG——让 Agent 主导检索过程
人工智能·llm·源码
结构化知识课堂1 小时前
AI产品经理入门实战:如何理解计算机视觉?
人工智能·计算机视觉·产品经理·ai产品经理·ai产品设计
我没胡说八道1 小时前
2026论文工具选购指南:降重、降AI率、排版一站式筛选
人工智能·经验分享·深度学习·考研·aigc·学习方法
初心未改HD1 小时前
深度学习之MLP与反向传播算法详解
人工智能·深度学习·算法
刀法如飞1 小时前
【Go 字符串查找的 20 种实现方式,用不同思路解决问题】
人工智能·算法·go
阿正的梦工坊1 小时前
ALiBi:让大语言模型“免训练“外推到更长序列的位置编码方法
人工智能·语言模型·自然语言处理