Tensorflow—第四讲网络八股扩展

本讲概述

一、自制数据集

我们用六万张数字图片自制训练集,一万张数字图片制作测试集

代码(注释已经很清楚了,就不解释了):

python 复制代码
def generateds(path, txt):
    f = open(txt, 'r')  # 以只读形式打开txt文件
    contents = f.readlines()  # 读取文件中所有行
    f.close()  # 关闭txt文件
    x, y_ = [], []  # 建立空列表
    for content in contents:  # 逐行取出
        value = content.split()  # 以空格分开,图片路径为value[0] , 标签为value[1] , 存入列
        img_path = path + value[0]  # 拼出图片路径和文件名
        img = Image.open(img_path)  # 读入图片
        img = np.array(img.convert('L'))  # 图片变为8位宽灰度值的np.array格式
        img = img / 255.  # 数据归一化 (实现预处理)
        x.append(img)  # 归一化后的数据,贴到列表x
        y_.append(value[1])  # 标签贴到列表y_
        print('loading : ' + content)  # 打印状态提示

    x = np.array(x)  # 变为np.array格式
    y_ = np.array(y_)  # 变为np.array格式
    y_ = y_.astype(np.int64)  # 变为64位整型
    return x, y_  # 返回输入特征x,返回标签y_

二、数据增强

对图像数据的增强,就是对图像进行简单形变,用来应对因拍照角度不同引起的图片变形。

python 复制代码
x_train = x_train.reshape(x_train.shape[0], 28, 28, 1)  # 给数据增加一个维度,从(60000, 28, 28)reshape为(60000, 28,

image_gen_train = ImageDataGenerator(
    rescale=1. / 1.,  # 如为图像,分母为255时,可归至0~1
    rotation_range=45,  # 随机45度旋转
    width_shift_range=.15,  # 宽度偏移
    height_shift_range=.15,  # 高度偏移
    horizontal_flip=False,  # 水平翻转
    zoom_range=0.5  # 将图像随机缩放阈量50%
)
image_gen_train.fit(x_train)

fit时需要4维,所以先给数据增加了一个维度

跟之前比还有一处改变 :

flow****方法通常用于生成批次(batch)数据

三、断点续训

下次再训练时会加载上次保存的模型

save_weights_only:是否只保留文件参数;save_best_only:是否只保留最优结果;在fit函数中加入回调选项callbacks返回到history中

实现代码:

python 复制代码
checkpoint_save_path = "./checkpoint/mnist.ckpt"
if os.path.exists(checkpoint_save_path + '.index'):
    print('-------------load the model-----------------')
    model.load_weights(checkpoint_save_path)

cp_callback = tf.keras.callbacks.ModelCheckpoint(filepath=checkpoint_save_path,
                                                 save_weights_only=True,
                                                 save_best_only=True)

history = model.fit(x_train, y_train, batch_size=32, epochs=5, validation_data=(x_test, y_test), validation_freq=1,
                    callbacks=[cp_callback])

四、参数提取

np.set_printoptions(threshold=np.inf) 这个设置是全局的,会影响到之后所有NumPy数组的打印行为。如果你想恢复默认的打印选项,可以再次调用 np.set_printoptions() 而不传递任何参数。

v.name: 这是变量(权重或偏置)的名称。在模型中,每个变量通常都有一个唯一的名字,这个名字有助于你识别模型中的不同参数。

v.shape: 这是变量的形状。在神经网络中,权重和偏置通常具有特定的形状,这对应于它们在网络中的组织方式。记录形状有助于了解每个参数的维度结构。

v.numpy() : 这是将变量的值转换为NumPy数组。由于深度学习框架(如TensorFlow或PyTorch)中的变量可能是特殊类型的张量,使用.numpy()方法可以将它们的值以NumPy数组的形式提取出来。记录这些值有助于分析或保存模型的当前状态。

实现代码:

python 复制代码
np.set_printoptions(threshold=np.inf)
print(model.trainable_variables)
file = open('./weights.txt', 'w')
for v in model.trainable_variables:
    file.write(str(v.name) + '\n')
    file.write(str(v.shape) + '\n')
    file.write(str(v.numpy()) + '\n')
file.close()

五、 acc/loss可视化

从history 中提取acc,val_acc,loss,val_loss,再用matplotlib画图

实现代码:

python 复制代码
acc = history.history['sparse_categorical_accuracy']
val_acc = history.history['val_sparse_categorical_accuracy']
loss = history.history['loss']
val_loss = history.history['val_loss']

plt.subplot(1, 2, 1)
plt.plot(acc, label='Training Accuracy')
plt.plot(val_acc, label='Validation Accuracy')
plt.title('Training and Validation Accuracy')
plt.legend()

plt.subplot(1, 2, 2)
plt.plot(loss, label='Training Loss')
plt.plot(val_loss, label='Validation Loss')
plt.title('Training and Validation Loss')
plt.legend()
plt.show()

六、应用---给图识物

使用predict应用预测 :

实现代码 :

python 复制代码
from PIL import Image
import numpy as np
import tensorflow as tf

model_save_path = './checkpoint/mnist.ckpt'

model = tf.keras.models.Sequential([
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(128, activation='relu'),
    tf.keras.layers.Dense(10, activation='softmax')])

model.load_weights(model_save_path)

preNum = int(input("input the number of test pictures:"))

for i in range(preNum):
    image_path = input("the path of test picture:")
    img = Image.open(image_path)
    img = img.resize((28, 28), Image.ANTIALIAS)
    img_arr = np.array(img.convert('L'))

    for i in range(28):
        for j in range(28):
            if img_arr[i][j] < 200:
                img_arr[i][j] = 255
            else:
                img_arr[i][j] = 0

    img_arr = img_arr / 255.0
    x_predict = img_arr[tf.newaxis, ...]
    result = model.predict(x_predict)

    pred = tf.argmax(result, axis=1)

    print('\n')
    tf.print(pred)

img = img.resize((28, 28), Image.ANTIALIAS):将其大小调整为28x28像素,因为训练的数据输入的图片为28x28像素。Image.ANTIALIAS是一个高级滤波器,用于在缩放过程中平滑图像,减少锯齿效应。

**img_arr = np.array(img.convert('L')):**将PIL图像对象转换为NumPy数组,并使用convert('L')方法将图像转换为灰度(即单通道)。

for循环:对图像进行阈值处理,将所有像素值小于200的设置为255(白色),大于等于200的设置为0(黑色)。这是一种简单的二值化方法。二值化处理特别适用于处理灰度图像,尤其是当图像是手写数字识别时,这种方法可以帮助模型更容易地区分数字的笔画和背景。

**img_arr = img_arr / 255.0:**将图像数组的像素值归一化到0到1的范围内,这是许多神经网络模型所期望的输入格式。

**x_predict = img_arr[tf.newaxis, ...]:**将归一化后的图像数组增加一个维度,从(28, 28)变为(1, 28, 28),以匹配模型的输入要求。

**pred = tf.argmax(result, axis=1):**使用tf.argmax函数从预测结果中获取最大概率对应的索引,这代表了模型预测的类别。

相关推荐
Aileen_0v02 小时前
【玩转OCR | 腾讯云智能结构化OCR在图像增强与发票识别中的应用实践】
android·java·人工智能·云计算·ocr·腾讯云·玩转腾讯云ocr
阿正的梦工坊3 小时前
深入理解 PyTorch 的 view() 函数:以多头注意力机制(Multi-Head Attention)为例 (中英双语)
人工智能·pytorch·python
Ainnle3 小时前
GPT-O3:简单介绍
人工智能
OceanBase数据库官方博客3 小时前
向量检索+大语言模型,免费搭建基于专属知识库的 RAG 智能助手
人工智能·oceanbase·分布式数据库·向量数据库·rag
测试者家园3 小时前
ChatGPT助力数据可视化与数据分析效率的提升(一)
软件测试·人工智能·信息可视化·chatgpt·数据挖掘·数据分析·用chatgpt做软件测试
Loving_enjoy5 小时前
ChatGPT详解
人工智能·自然语言处理
人类群星闪耀时5 小时前
深度学习在灾难恢复中的作用:智能运维的新时代
运维·人工智能·深度学习
图王大胜5 小时前
模型 确认偏误(关键决策)
人工智能·职业发展·管理·心理·认知·决策
机器懒得学习6 小时前
从随机生成到深度学习:使用DCGAN和CycleGAN生成图像的实战教程
人工智能·深度学习