HIT 模式识别 手写汉字分类 Python实现

训练集数据 TrainSamples-400.csv,含 100 个不同汉字,每个汉字 400 个实例,每个实例均为 64*64 的二值图像;

训练集标注TrainSamples-400.csv,为 40000 个 0 到 99 间的整数,表示训练集中每个实例所属汉字类别;

测试集数据 TestSamples-300.csv,为 30000 个实例,每个实例格式同训练集。

要求标注测试集,输出 Result.csv。

python 复制代码
import numpy as np
import pandas as pd
from tensorflow.keras.utils import to_categorical
from tensorflow.keras import models, layers

def train():
    data = pd.read_csv("TrainSamples-400.csv", header=None)
    train_image = data.to_numpy()
    data = pd.read_csv("TrainLabels-400.csv", header=None)
    train_label = data.to_numpy()
    train_label = to_categorical(train_label)
    network = models.Sequential()
    network.add(layers.Input(shape = (64, 64, 1)))
    network.add(layers.Conv2D(64, (5, 5), activation = 'relu'))
    network.add(layers.MaxPooling2D((2, 2)))
    network.add(layers.Conv2D(96, (3, 3), activation = 'relu'))
    network.add(layers.MaxPooling2D((2, 2)))
    network.add(layers.Conv2D(48, (3, 3), activation = 'relu'))
    network.add(layers.Flatten())
    network.add(layers.Dense(768, activation = 'relu'))
    network.add(layers.Dense(100, activation = 'softmax'))
    network.summary()
    network.compile(optimizer = 'rmsprop', loss = 'categorical_crossentropy', metrics = ['accuracy'])
    network.fit(train_image.reshape(40000, 64, 64, 1), train_label, epochs = 5, batch_size = 64, validation_split = 0.1, validation_freq = 1)
    network.save('saved_model/my_model')
    
def test():
    data = pd.read_csv("TestSamples-300.csv", header = None)
    test_image = data.to_numpy()
    network = models.load_model('saved_model/my_model')
    network.summary()
    test_label = network.predict(test_image.reshape(30000, 64, 64, 1))
    test_label = np.array([np.argmax(i) for i in test_label])
    pd.DataFrame(test_label).to_csv('Result.csv', header = None, index = False)

if __name__ == '__main__':
    train()
    test()
相关推荐
抠头专注python环境配置12 分钟前
基于Python与深度学习的智能垃圾分类系统设计与实现
pytorch·python·深度学习·分类·垃圾分类·vgg·densenet
愈努力俞幸运29 分钟前
flask 入门 token, headers,cookie
后端·python·flask
梦想是成为算法高手42 分钟前
带你从入门到精通——知识图谱(一. 知识图谱入门)
人工智能·pytorch·python·深度学习·神经网络·知识图谱
用什么都重名42 分钟前
Conda 虚拟环境安装配置路径详解
windows·python·conda
阿也在北京1 小时前
基于Neo4j和TuGraph的知识图谱与问答系统搭建——胡歌的导演演员人际圈
python·阿里云·知识图谱·neo4j
计算机徐师兄1 小时前
Python基于知识图谱的胆囊炎医疗问答系统(附源码,文档说明)
python·知识图谱·胆囊炎医疗问答系统·python胆囊炎医疗问答系统·知识图谱的胆囊炎医疗问答系统·python知识图谱·医疗问答系统
北冥码鲲1 小时前
【保姆级教程】从零入手:Python + Neo4j 构建你的第一个知识图谱
python·知识图谱·neo4j
B站计算机毕业设计超人1 小时前
计算机毕业设计Python+大模型音乐推荐系统 音乐数据分析 音乐可视化 音乐爬虫 知识图谱 大数据毕业设计
人工智能·hadoop·爬虫·python·数据分析·知识图谱·课程设计
喵手1 小时前
Python爬虫零基础入门【第三章:Requests 静态爬取入门·第5节】限速与礼貌爬取:并发、延迟、频率控制!
爬虫·python·python爬虫实战·python爬虫工程化实战·python爬虫零基础入门·requests静态爬取·限速与爬取