引言
在深度学习领域,卷积神经网络(Convolutional Neural Network, CNN)是一种广泛应用于图像识别任务的神经网络结构。LeNet是一种经典的CNN结构,被广泛应用于基础的图像分类任务。本文将介绍如何使用LeNet卷积神经网络实现手写数字识别,并使用Pytorch实现LeNet手写数字识别,使用PyQt5实现手写板GUI界面,使用户能够通过手写板输入数字并进行识别。
完整代码下载:Python手写数字识别带手写板GUI界面 Pytorch代码 含训练模型 (付费资源,如果你觉得这篇博客对你有帮助,欢迎购买支持~)
1. LeNet卷积神经网络
LeNet是由Yann LeCun等人于1998年提出的卷积神经网络结构,主要用于手写字符识别。在本文中,我们将使用LeNet结构构建一个用于手写数字识别的神经网络模型。以下是LeNet的基本结构:
Layer 1: Convolutional Layer
- Input: 28x28x1 (灰度图像)
- Filter: 5x5, Stride: 1, Depth: 6
- Activation: Sigmoid
- Output: 28x28x6
Layer 2: Average Pooling Layer
- Input: 28x28x6
- Pooling: 2x2, Stride: 2
- Output: 14x14x6
Layer 3: Convolutional Layer
- Input: 14x14x6
- Filter: 5x5, Stride: 1, Depth: 16
- Activation: Sigmoid
- Output: 10x10x16
Layer 4: Average Pooling Layer
- Input: 10x10x16
- Pooling: 2x2, Stride: 2
- Output: 5x5x16
Layer 5: Fully Connected Layer
- Input: 5x5x16
- Output: 120
- Activation: Sigmoid
Layer 6: Fully Connected Layer
- Input: 120
- Output: 84
- Activation: Sigmoid
Layer 7: Output Layer
- Input: 84
- Output: 10 (对应0-9的数字)
- Activation: Softmax
2. 手写数字识别实现
使用深度学习框架(例如Pytorch)构建LeNet模型:
python
import torch
import torch.nn as nn
import torch.nn.functional as F
class LeNet(nn.Module):
def __init__(self):
super(LeNet, self).__init__()
self.conv1 = nn.Conv2d(1, 6, kernel_size=5)
self.pool1 = nn.AvgPool2d(kernel_size=2, stride=2)
self.conv2 = nn.Conv2d(6, 16, kernel_size=5)
self.pool2 = nn.AvgPool2d(kernel_size=2, stride=2)
self.fc1 = nn.Linear(16 * 5 * 5, 120)
self.fc2 = nn.Linear(120, 84)
self.fc3 = nn.Linear(84, 10)
def forward(self, x):
x = F.sigmoid(self.conv1(x))
x = self.pool1(x)
x = F.sigmoid(self.conv2(x))
x = self.pool2(x)
x = x.view(-1, 16 * 5 * 5)
x = F.sigmoid(self.fc1(x))
x = F.sigmoid(self.fc2(x))
x = self.fc3(x)
return F.log_softmax(x, dim=1)
并使用手写数字数据集MNIST进行训练。确保正确实现数据预处理和模型训练过程:
python
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.autograd import Variable
from net import Net
if __name__ == "__main__":
# 设置训练参数
batch_size = 64
epochs = 140
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# 数据集
transform = transforms.Compose([transforms.ToTensor(),
transforms.Normalize((0.5,), (0.5,))])
trainset = datasets.MNIST('~/.pytorch/MNIST_data/', download=True, train=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size, shuffle=True)
# 输出提示信息
print("batch_size:", batch_size)
print("data_batches:", len(trainloader))
print("epochs:", epochs)
# 神经网络
net = Net().to(device)
net.load_state_dict(torch.load('model.pth'))
# 损失函数和优化器
criterion = nn.NLLLoss()
optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)
# 训练网络
for epoch in range(epochs):
running_loss = 0.0
for i, data in enumerate(trainloader, 0):
inputs, labels = data
inputs, labels = Variable(inputs).to(device), Variable(labels).to(device)
# 反向传播优化参数
optimizer.zero_grad()
outputs = net(inputs)
loss = criterion(outputs, labels)
loss.backward()
optimizer.step()
running_loss += loss.item()
if i % 938 == 937: # 每轮输出损失值
print('[epoch: %d, batches: %d] loss: %.5f' %
(epoch + 1, i + 1, running_loss / 2000))
running_loss = 0.0
torch.save(net.state_dict(), './model.pth') # 每轮保存模型参数
print('Finished Training')
3. 手写板GUI界面开发
模型训练完成后,为了让用户通过手写板输入数字,我们将开发一个简单直观的GUI界面。使用GUI库(例如PyQt5),创建一个窗口,包含一个手写板区域,用户可以在上面写数字。添加一个识别按钮,点击后将手写板上的数字送入LeNet模型进行识别,并在界面上显示识别结果。
以下是PyQt5代码示例:
python
from PyQt5.QtWidgets import *
from PyQt5.QtGui import *
from PyQt5.QtCore import *
import sys
import torch
from utils import *
from net import Net
class MainWindow(QMainWindow):
def __init__(self):
super().__init__()
self.title = '手写数字识别'
self.initUI()
def initUI(self):
self.setWindowTitle(self.title)
self.setMinimumSize(500, 400)
self.main_widget = QWidget()
self.main_layout = QGridLayout()
self.main_widget.setLayout(self.main_layout)
self.setCentralWidget(self.main_widget)
self.canvas = Canvas()
self.canvas.setFixedSize(300,300)
self.label = QLabel()
self.label.setFixedSize(100,100)
self.label.setText('识别结果')
self.label.setStyleSheet("font-size:15px;color:red")
self.clear_button = QPushButton('清除')
self.clear_button.setFixedSize(100,50)
self.clear_button.clicked.connect(self.canvas.clear)
self.recognize_button = QPushButton('识别')
self.recognize_button.setFixedSize(100,50)
self.recognize_button.clicked.connect(self.recognize)
self.main_layout.addWidget(self.canvas,0,0,3,1)
self.main_layout.addWidget(self.label,0,1)
self.main_layout.addWidget(self.clear_button,1,1)
self.main_layout.addWidget(self.recognize_button,2,1)
def recognize(self):
self.canvas.recognize()
self.label.setText('识别结果: ' + str(self.canvas.recognize()))
class Canvas(QLabel):
x0=-10; y0=-10; x1=-10; y1=-10
def __init__(self):
super(Canvas,self).__init__()
self.pixmap = QPixmap(300, 300)
self.pixmap.fill(Qt.white)
self.Color=Qt.blue
self.penwidth=10
def paintEvent(self,event):
painter=QPainter(self.pixmap)
painter.setPen(QPen(self.Color,self.penwidth,Qt.SolidLine))
painter.drawLine(self.x0,self.y0,self.x1,self.y1)
Label_painter=QPainter(self)
Label_painter.drawPixmap(2,2,self.pixmap)
def mousePressEvent(self, event):
self.x1=event.x()
self.y1=event.y()
def mouseMoveEvent(self, event):
self.x0 = self.x1
self.y0 = self.y1
self.x1 = event.x()
self.y1 = event.y()
self.update()
def clear(self):
self.x0=-10; self.y0=-10; self.x1=-10; self.y1=-10
self.pixmap.fill(Qt.white)
self.update()
def recognize(self):
arr = pixmap2np(self.pixmap)
arr = 255 - arr[:,:,2]
arr = clip_image(arr)
arr = resize_image(arr)
arr = np.expand_dims(arr, axis=0)
arr_batch = np.expand_dims(arr, axis=0)
tensor = torch.FloatTensor(arr_batch)
tensor = (tensor/255 - 0.5) * 2
possibles = net(tensor).detach().numpy()
result = np.argmax(possibles)
return result
if __name__ == '__main__':
net = Net()
net.load_state_dict(torch.load('model.pth'))
app = QApplication(sys.argv)
win = MainWindow()
win.show()
sys.exit(app.exec_())
这个例子中,用户可以在手写板上写数字,点击识别按钮后,程序将手写板上的数字送入LeNet模型进行识别,并在界面上显示识别结果。
通过本文的实践,你可以学到如何使用LeNet卷积神经网络实现手写数字识别,以及如何结合GUI开发一个手写板界面,更直观地进行数字识别交互。希望这篇博客对有所帮助。
完整代码下载:Python手写数字识别带手写板GUI界面 Pytorch代码 含训练模型 (付费资源,如果你觉得这篇博客对你有帮助,欢迎购买支持~)