Tensorflow 2.0 cnn训练cifar10 准确率只有0.1 [已解决]

cifar10 准确率只有0.1

问题描述

如果你看的是北京大学曹健老师的tensorflow2.0,你在class5的部分可能会遇见这个问题

python 复制代码
import matplotlib.pyplot as plt
import tensorflow as tf
from tensorflow.keras.layers import Dense, Dropout,MaxPooling2D,Flatten,Conv2D,BatchNormalization,Activation
from tensorflow.keras import Model
import os
import numpy as np

# np.set_printoptions(threshold=np.inf)


class Baseline(Model):
    def __init__(self):
        super(Baseline, self).__init__()
        self.conv1 = Conv2D(6, (5,5), activation='sigmoid')
        self.pool1 = MaxPooling2D(pool_size=(2,2),strides=2)
        self.conv2 = Conv2D(16, (5,5), activation='sigmoid')
        self.pool2 = MaxPooling2D(pool_size=(2,2),strides=2)

        self.flatten1 = Flatten()
        self.f1=Dense(120,activation='sigmoid')
        self.f2=Dense(84,activation='sigmoid')
        self.f3=Dense(10,activation='softmax')

    def call(self,x):
        x = self.conv1(x)
        x = self.pool1(x)
        x = self.conv2(x)
        x = self.pool2(x)

        x = self.flatten1(x)
        x = self.f1(x)
        x = self.f2(x)
        y = self.f3(x)
        return y


(x_train, y_train), (x_test, y_test) = tf.keras.datasets.cifar10.load_data()
x_train,x_test = x_train/255.0,x_test/255.0


model = Baseline()
model.compile(optimizer=tf.keras.optimizers.Adam(lr=0.001),loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False)
              ,metrics=['sparse_categorical_accuracy'])

checkpoint_save_path="lenet.ckpt"
if os.path.exists(checkpoint_save_path+'.index'):
    model.load_weights(checkpoint_save_path)
    print("---------------------Loaded model---------------")

cp_callback = tf.keras.callbacks.ModelCheckpoint(filepath=checkpoint_save_path, save_weights_only=True
                                              ,save_best_only=True, verbose=1)


history=model.fit(x_train,y_train,batch_size=32, epochs=5, validation_data=(x_test, y_test)
          ,validation_freq=1,callbacks=[cp_callback])
model.summary()

file=open('weights_lenet.txt','w')
for v in model.trainable_variables:
    file.write(str(v.name)+'\n')
    file.write(str(v.shape)+'\n')
    file.write(str(v.numpy())+'\n')
file.close()

train_acc=history.history['sparse_categorical_accuracy']
val_acc=history.history['val_sparse_categorical_accuracy']
loss=history.history['loss']
val_loss=history.history['val_loss']

plt.subplot(1,2,1)
plt.plot(loss,label='train_loss')
plt.plot(val_loss,label='val_loss')
plt.title('model loss')
plt.legend()

plt.subplot(1,2,2)
plt.plot(train_acc,label='train_acc')
plt.plot(val_acc,label='val_acc')
plt.title('model acc')
plt.legend()
plt.show()

代码写的看起来没有问题,但是就是acc一直在0.1,总共10个类,也就是说网络根本没有训练效果,就是瞎蒙的。为什么会这样呢。想知道答案的直接跳到最后。下面是我踩的坑,

踩坑

我尝试升级tensorflow版本,但是我们知道升级tensorflow,对应的cudatoolkit 和cudnn 也要升级,官网版本对应


conda install cudatoolkit==11.2.0

但是我去安装的时候显示PackagesNotFoundError: The following packages are not available from current channels:

搜不到这个版本,conda search cudatoolkit查看可以安装的版本
就是没有11.2,这就很烦人,

我电脑环境是

powershell 复制代码
windows11
cuda 12.3
cudnn 8.9.7

我不能把电脑cuda卸载重新装,因为我pytorch要求的是上面的环境。我尝试去官网再安装一个cuda但是失败了(想试一下windows电脑能不能安装两个cuda)。总之折腾了一下午

解决办法

方法一

cudatoolkit 和cudnn保持不变,直接升级tensorflow
pip install tensorflow==2.4

但是这样就不能用gpu训练了,跑代码的时候用的是cpu,具体原因我也不是很清楚,

方法二

看我之前的文章,卸载电脑上的cuda安装,安装cuda11.2和对应的cudnn8.1
cuda下载地址
cudnn下载地址

然后安装tensoflow 2.10版本
conda install tensorflow_gpu==2.10.0


你windows电脑如果想同时可以跑tensorflow和pytorch,建议电脑的cuda环境就按照tensorflow的安装。

因为pytorch安装比较简单,一般会自带对应的cuda,而tensorflow对cuda要求比较严格,用指令(conda install cudatoolkit==11.2.0 )一般找不到对应的版本,只能去官网下载

windows要是想跑代码就用pytorch吧,tensorflow对windows真的很不友好,tensorflow2.10以上直接不支持了,可以用实验室的服务器跑tensorflow代码

相关推荐
NAGNIP1 天前
一文搞懂深度学习中的通用逼近定理!
人工智能·算法·面试
冬奇Lab1 天前
一天一个开源项目(第36篇):EverMemOS - 跨 LLM 与平台的长时记忆 OS,让 Agent 会记忆更会推理
人工智能·开源·资讯
冬奇Lab1 天前
OpenClaw 源码深度解析(一):Gateway——为什么需要一个"中枢"
人工智能·开源·源码阅读
AngelPP1 天前
OpenClaw 架构深度解析:如何把 AI 助手搬到你的个人设备上
人工智能
宅小年1 天前
Claude Code 换成了Kimi K2.5后,我再也回不去了
人工智能·ai编程·claude
九狼1 天前
Flutter URL Scheme 跨平台跳转
人工智能·flutter·github
ZFSS1 天前
Kimi Chat Completion API 申请及使用
前端·人工智能
天翼云开发者社区1 天前
春节复工福利就位!天翼云息壤2500万Tokens免费送,全品类大模型一键畅玩!
人工智能·算力服务·息壤
知识浅谈1 天前
教你如何用 Gemini 将课本图片一键转为精美 PPT
人工智能
Ray Liang1 天前
被低估的量化版模型,小身材也能干大事
人工智能·ai·ai助手·mindx