python机器学习(手写数字识别)

导包

import matplotlib.pyplot as plt

import pandas as pd

from sklearn.model_selection import train_test_split

from sklearn.neighbors import KNeighborsClassifier

import joblib

from collections import Counter

1. 定义函数 show_digit(idx), 用于查看: 数字图片.

def show_digit(idx):

idx: 行索引, 即: 要哪行的数据.

1. 读取数据, 获取df对象.

data = pd.read_csv('data/手写数字识别.csv')

细节: 非法值校验.

if idx < 0 or idx > len(data) - 1 :

return

2. 获取数据, 即: 特征 + 标签

x = data.iloc[:, 1:]

y = data.iloc[:, 0]

3. 查看下数据集.

print(f'x的维度: {x.shape}') # (42000, 784)

print(f'y的各分类数量: {Counter(y)}') # Counter({1: 4684, 7: 4401, 3: 4351, 9: 4188, 2: 4177, 6: 4137, 0: 4132, 4: 4072, 8: 4063, 5: 3795})

4. 获取具体的 某张图片, 即: 某行数据 => 样本数据.

step1: 把图片转成 28 * 28像素(二维数组)

digit = x.iloc[idx].values.reshape(28, 28)

step2: 显示图片.

plt.imshow(digit, cmap='gray') # 灰色显示 => 灰度图

step3: 取消坐标显示.

plt.axis('off')

step4: 显示图片

plt.show()

2. 定义函数 train_model(), 用于训练模型.

def train_model():

1. 读取数据, 获取df对象.

data = pd.read_csv('data/手写数字识别.csv')

2. 获取数据, 即: 特征 + 标签

x = data.iloc[:, 1:]

y = data.iloc[:, 0]

3. 数据的预处理.

step1: x轴(像素点)的 归一化处理.

x = x / 255

step2: 区分训练集和测试集. stratify: 按照y的类别比例进行分割

x_train, x_test, y_train, y_test = train_test_split(x, y, test_size=0.2, stratify=y, random_state=21)

4. 训练模型.

estimator = KNeighborsClassifier(n_neighbors=3)

estimator.fit(x_train, y_train)

5. 模型预估, 评测正确率.

my_score = estimator.score(x_test, y_test)

print(f'模型预估正确率: {my_score}') # 0.9657142857142857

6. 保存模型.

joblib.dump(estimator, 'model/knn.pth')

3. 定义use_model()函数, 用于: 测试模型.

def use_model(): # pytest

1. 加载图片.

img = plt.imread('data/demo.png') # 28 * 28像素

plt.imshow(img, cmap='gray') # 灰度图

plt.show()

2. 加载模型.

estimator = joblib.load('model/knn.pth')

3. 预测图片.

img = img.reshape(1, -1) # 效果等价于: reshape(1, 784)

y_test = estimator.predict(img)

print(f'预测的数字是: {y_test}')

4. 在main函数中测试.

if name == 'main':

显示数字图片.

show_digit(10)

show_digit(21)

训练模型

train_model()

测试模型

use_model()

相关推荐
井底哇哇14 分钟前
ChatGPT是强人工智能吗?
人工智能·chatgpt
Coovally AI模型快速验证18 分钟前
MMYOLO:打破单一模式限制,多模态目标检测的革命性突破!
人工智能·算法·yolo·目标检测·机器学习·计算机视觉·目标跟踪
AI浩43 分钟前
【面试总结】FFN(前馈神经网络)在Transformer模型中先升维再降维的原因
人工智能·深度学习·计算机视觉·transformer
可为测控1 小时前
图像处理基础(4):高斯滤波器详解
人工智能·算法·计算机视觉
ℳ₯㎕ddzོꦿ࿐1 小时前
解决Python 在 Flask 开发模式下定时任务启动两次的问题
开发语言·python·flask
CodeClimb1 小时前
【华为OD-E卷 - 第k个排列 100分(python、java、c++、js、c)】
java·javascript·c++·python·华为od
一水鉴天1 小时前
为AI聊天工具添加一个知识系统 之63 详细设计 之4:AI操作系统 之2 智能合约
开发语言·人工智能·python
Channing Lewis2 小时前
什么是 Flask 的蓝图(Blueprint)
后端·python·flask
倔强的石头1062 小时前
解锁辅助驾驶新境界:基于昇腾 AI 异构计算架构 CANN 的应用探秘
人工智能·架构
B站计算机毕业设计超人2 小时前
计算机毕业设计hadoop+spark股票基金推荐系统 股票基金预测系统 股票基金可视化系统 股票基金数据分析 股票基金大数据 股票基金爬虫
大数据·hadoop·python·spark·课程设计·数据可视化·推荐算法