Tensorflow 2.0 cnn训练cifar10 准确率只有0.1 [已解决]

cifar10 准确率只有0.1

问题描述

如果你看的是北京大学曹健老师的tensorflow2.0,你在class5的部分可能会遇见这个问题

python 复制代码
import matplotlib.pyplot as plt
import tensorflow as tf
from tensorflow.keras.layers import Dense, Dropout,MaxPooling2D,Flatten,Conv2D,BatchNormalization,Activation
from tensorflow.keras import Model
import os
import numpy as np

# np.set_printoptions(threshold=np.inf)


class Baseline(Model):
    def __init__(self):
        super(Baseline, self).__init__()
        self.conv1 = Conv2D(6, (5,5), activation='sigmoid')
        self.pool1 = MaxPooling2D(pool_size=(2,2),strides=2)
        self.conv2 = Conv2D(16, (5,5), activation='sigmoid')
        self.pool2 = MaxPooling2D(pool_size=(2,2),strides=2)

        self.flatten1 = Flatten()
        self.f1=Dense(120,activation='sigmoid')
        self.f2=Dense(84,activation='sigmoid')
        self.f3=Dense(10,activation='softmax')

    def call(self,x):
        x = self.conv1(x)
        x = self.pool1(x)
        x = self.conv2(x)
        x = self.pool2(x)

        x = self.flatten1(x)
        x = self.f1(x)
        x = self.f2(x)
        y = self.f3(x)
        return y


(x_train, y_train), (x_test, y_test) = tf.keras.datasets.cifar10.load_data()
x_train,x_test = x_train/255.0,x_test/255.0


model = Baseline()
model.compile(optimizer=tf.keras.optimizers.Adam(lr=0.001),loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False)
              ,metrics=['sparse_categorical_accuracy'])

checkpoint_save_path="lenet.ckpt"
if os.path.exists(checkpoint_save_path+'.index'):
    model.load_weights(checkpoint_save_path)
    print("---------------------Loaded model---------------")

cp_callback = tf.keras.callbacks.ModelCheckpoint(filepath=checkpoint_save_path, save_weights_only=True
                                              ,save_best_only=True, verbose=1)


history=model.fit(x_train,y_train,batch_size=32, epochs=5, validation_data=(x_test, y_test)
          ,validation_freq=1,callbacks=[cp_callback])
model.summary()

file=open('weights_lenet.txt','w')
for v in model.trainable_variables:
    file.write(str(v.name)+'\n')
    file.write(str(v.shape)+'\n')
    file.write(str(v.numpy())+'\n')
file.close()

train_acc=history.history['sparse_categorical_accuracy']
val_acc=history.history['val_sparse_categorical_accuracy']
loss=history.history['loss']
val_loss=history.history['val_loss']

plt.subplot(1,2,1)
plt.plot(loss,label='train_loss')
plt.plot(val_loss,label='val_loss')
plt.title('model loss')
plt.legend()

plt.subplot(1,2,2)
plt.plot(train_acc,label='train_acc')
plt.plot(val_acc,label='val_acc')
plt.title('model acc')
plt.legend()
plt.show()

代码写的看起来没有问题,但是就是acc一直在0.1,总共10个类,也就是说网络根本没有训练效果,就是瞎蒙的。为什么会这样呢。想知道答案的直接跳到最后。下面是我踩的坑,

踩坑

我尝试升级tensorflow版本,但是我们知道升级tensorflow,对应的cudatoolkit 和cudnn 也要升级,官网版本对应


conda install cudatoolkit==11.2.0

但是我去安装的时候显示PackagesNotFoundError: The following packages are not available from current channels:

搜不到这个版本,conda search cudatoolkit查看可以安装的版本
就是没有11.2,这就很烦人,

我电脑环境是

powershell 复制代码
windows11
cuda 12.3
cudnn 8.9.7

我不能把电脑cuda卸载重新装,因为我pytorch要求的是上面的环境。我尝试去官网再安装一个cuda但是失败了(想试一下windows电脑能不能安装两个cuda)。总之折腾了一下午

解决办法

方法一

cudatoolkit 和cudnn保持不变,直接升级tensorflow
pip install tensorflow==2.4

但是这样就不能用gpu训练了,跑代码的时候用的是cpu,具体原因我也不是很清楚,

方法二

看我之前的文章,卸载电脑上的cuda安装,安装cuda11.2和对应的cudnn8.1
cuda下载地址
cudnn下载地址

然后安装tensoflow 2.10版本
conda install tensorflow_gpu==2.10.0


你windows电脑如果想同时可以跑tensorflow和pytorch,建议电脑的cuda环境就按照tensorflow的安装。

因为pytorch安装比较简单,一般会自带对应的cuda,而tensorflow对cuda要求比较严格,用指令(conda install cudatoolkit==11.2.0 )一般找不到对应的版本,只能去官网下载

windows要是想跑代码就用pytorch吧,tensorflow对windows真的很不友好,tensorflow2.10以上直接不支持了,可以用实验室的服务器跑tensorflow代码

相关推荐
晨曦_子画1 分钟前
编程语言之战:AI 之后的 Kotlin 与 Java
android·java·开发语言·人工智能·kotlin
道可云2 分钟前
道可云人工智能&元宇宙每日资讯|2024国际虚拟现实创新大会将在青岛举办
大数据·人工智能·3d·机器人·ar·vr
人工智能培训咨询叶梓12 分钟前
探索开放资源上指令微调语言模型的现状
人工智能·语言模型·自然语言处理·性能优化·调优·大模型微调·指令微调
zzZ_CMing12 分钟前
大语言模型训练的全过程:预训练、微调、RLHF
人工智能·自然语言处理·aigc
newxtc13 分钟前
【旷视科技-注册/登录安全分析报告】
人工智能·科技·安全·ddddocr
成都古河云14 分钟前
智慧场馆:安全、节能与智能化管理的未来
大数据·运维·人工智能·安全·智慧城市
UCloud_TShare17 分钟前
浅谈语言模型推理框架 vLLM 0.6.0性能优化
人工智能
软工菜鸡21 分钟前
预训练语言模型BERT——PaddleNLP中的预训练模型
大数据·人工智能·深度学习·算法·语言模型·自然语言处理·bert
vivid_blog27 分钟前
大语言模型(LLM)入门级选手初学教程 III
人工智能·语言模型·自然语言处理
AI视觉网奇1 小时前
sklearn 安装使用笔记
人工智能·算法·sklearn