Tensorflow 2.0 cnn训练cifar10 准确率只有0.1 [已解决]

cifar10 准确率只有0.1

问题描述

如果你看的是北京大学曹健老师的tensorflow2.0,你在class5的部分可能会遇见这个问题

python 复制代码
import matplotlib.pyplot as plt
import tensorflow as tf
from tensorflow.keras.layers import Dense, Dropout,MaxPooling2D,Flatten,Conv2D,BatchNormalization,Activation
from tensorflow.keras import Model
import os
import numpy as np

# np.set_printoptions(threshold=np.inf)


class Baseline(Model):
    def __init__(self):
        super(Baseline, self).__init__()
        self.conv1 = Conv2D(6, (5,5), activation='sigmoid')
        self.pool1 = MaxPooling2D(pool_size=(2,2),strides=2)
        self.conv2 = Conv2D(16, (5,5), activation='sigmoid')
        self.pool2 = MaxPooling2D(pool_size=(2,2),strides=2)

        self.flatten1 = Flatten()
        self.f1=Dense(120,activation='sigmoid')
        self.f2=Dense(84,activation='sigmoid')
        self.f3=Dense(10,activation='softmax')

    def call(self,x):
        x = self.conv1(x)
        x = self.pool1(x)
        x = self.conv2(x)
        x = self.pool2(x)

        x = self.flatten1(x)
        x = self.f1(x)
        x = self.f2(x)
        y = self.f3(x)
        return y


(x_train, y_train), (x_test, y_test) = tf.keras.datasets.cifar10.load_data()
x_train,x_test = x_train/255.0,x_test/255.0


model = Baseline()
model.compile(optimizer=tf.keras.optimizers.Adam(lr=0.001),loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False)
              ,metrics=['sparse_categorical_accuracy'])

checkpoint_save_path="lenet.ckpt"
if os.path.exists(checkpoint_save_path+'.index'):
    model.load_weights(checkpoint_save_path)
    print("---------------------Loaded model---------------")

cp_callback = tf.keras.callbacks.ModelCheckpoint(filepath=checkpoint_save_path, save_weights_only=True
                                              ,save_best_only=True, verbose=1)


history=model.fit(x_train,y_train,batch_size=32, epochs=5, validation_data=(x_test, y_test)
          ,validation_freq=1,callbacks=[cp_callback])
model.summary()

file=open('weights_lenet.txt','w')
for v in model.trainable_variables:
    file.write(str(v.name)+'\n')
    file.write(str(v.shape)+'\n')
    file.write(str(v.numpy())+'\n')
file.close()

train_acc=history.history['sparse_categorical_accuracy']
val_acc=history.history['val_sparse_categorical_accuracy']
loss=history.history['loss']
val_loss=history.history['val_loss']

plt.subplot(1,2,1)
plt.plot(loss,label='train_loss')
plt.plot(val_loss,label='val_loss')
plt.title('model loss')
plt.legend()

plt.subplot(1,2,2)
plt.plot(train_acc,label='train_acc')
plt.plot(val_acc,label='val_acc')
plt.title('model acc')
plt.legend()
plt.show()

代码写的看起来没有问题,但是就是acc一直在0.1,总共10个类,也就是说网络根本没有训练效果,就是瞎蒙的。为什么会这样呢。想知道答案的直接跳到最后。下面是我踩的坑,

踩坑

我尝试升级tensorflow版本,但是我们知道升级tensorflow,对应的cudatoolkit 和cudnn 也要升级,官网版本对应


conda install cudatoolkit==11.2.0

但是我去安装的时候显示PackagesNotFoundError: The following packages are not available from current channels:

搜不到这个版本,conda search cudatoolkit查看可以安装的版本
就是没有11.2,这就很烦人,

我电脑环境是

powershell 复制代码
windows11
cuda 12.3
cudnn 8.9.7

我不能把电脑cuda卸载重新装,因为我pytorch要求的是上面的环境。我尝试去官网再安装一个cuda但是失败了(想试一下windows电脑能不能安装两个cuda)。总之折腾了一下午

解决办法

方法一

cudatoolkit 和cudnn保持不变,直接升级tensorflow
pip install tensorflow==2.4

但是这样就不能用gpu训练了,跑代码的时候用的是cpu,具体原因我也不是很清楚,

方法二

看我之前的文章,卸载电脑上的cuda安装,安装cuda11.2和对应的cudnn8.1
cuda下载地址
cudnn下载地址

然后安装tensoflow 2.10版本
conda install tensorflow_gpu==2.10.0


你windows电脑如果想同时可以跑tensorflow和pytorch,建议电脑的cuda环境就按照tensorflow的安装。

因为pytorch安装比较简单,一般会自带对应的cuda,而tensorflow对cuda要求比较严格,用指令(conda install cudatoolkit==11.2.0 )一般找不到对应的版本,只能去官网下载

windows要是想跑代码就用pytorch吧,tensorflow对windows真的很不友好,tensorflow2.10以上直接不支持了,可以用实验室的服务器跑tensorflow代码

相关推荐
DianSan_ERP1 小时前
电商API接口全链路监控:构建坚不可摧的线上运维防线
大数据·运维·网络·人工智能·git·servlet
在人间耕耘1 小时前
HarmonyOS Vision Kit 视觉AI实战:把官方 Demo 改造成一套能长期复用的组件库
人工智能·深度学习·harmonyos
够快云库1 小时前
能源行业非结构化数据治理实战:从数据沼泽到智能资产
大数据·人工智能·机器学习·企业文件安全
Eloudy1 小时前
CHI 开发备忘 08 记 -- CHI spec 08
人工智能·arch·hpc
homelook2 小时前
Transformer与电池管理系统(BMS)的结合是当前 智能电池管理 的前沿研究方向
人工智能·深度学习·transformer
ZPC82102 小时前
docker 镜像备份
人工智能·算法·fpga开发·机器人
ZPC82102 小时前
docker 使用GUI ROS2
人工智能·算法·fpga开发·机器人
ssshooter2 小时前
免费和付费 AI API 选择指南
人工智能·aigc·openai
掘金酱2 小时前
「寻找年味」 沸点活动|获奖名单公示🎊
前端·人工智能·后端