Tensorflow2学习卷积神经网络入门

前言

前面写了篇入门的文章Tensorflow2学习之MNIST数据训练和测试,这里的demo属于全连接网络(FCN)。还有一种是卷积网络(CNN)相比较与全连接网络在很多方面都有优势。深入点的理论我是不懂的,这里只说点我能理解的。

  1. 由于CNN是基于卷积核的FCN基于像素点的。所以在某些应用场景中减少参数实现效率(卷积运算)、降低局部特征点影响(局部相关)、降低图像变换的影响、权值共享有很大优势的。
  2. FCN在做图像分割深度学习技术都是基于它实现的网络结构,比如 U-Net。
  3. CNN在做图像分类、目标检测方面的实现效果是非常出色的。

失传已久的记忆 平时矩阵相乘@

  1. <math xmlns="http://www.w3.org/1998/Math/MathML"> A ∗ B A*B </math>A∗B A列数要和B行数相同
  2. <math xmlns="http://www.w3.org/1998/Math/MathML"> 1 ∗ 1 + 3 ∗ 0 + 2 ∗ 5 = 11 1*1+3*0+2*5=11 </math>1∗1+3∗0+2∗5=11
  3. <math xmlns="http://www.w3.org/1998/Math/MathML"> 4 ∗ 1 + 0 ∗ 0 + 1 ∗ 5 = 9 4*1+0*0+1*5=9 </math>4∗1+0∗0+1∗5=9

卷积的矩阵相乘

1.卷积矩阵的对应元素相乘,⨀符号表示哈达马积(Hadamard Product)

  1. 通过设置参数 padding='SAME'、strides=1 可以直接得到输入、输出同大小的
  2. 卷积层,其中 padding 的具体数量由 TensorFlow 自动计算并完成填充操作
  3. 当s > 1 时,设置 padding='SAME'将使得输出高、宽将成1/s倍地减少。
  4. out = tf.nn.conv2d(x,w,strides=1,padding=[[0,0],[0,0],[0,0],[0,0]])
  5. padding=[[0,0],[上,下],[左,右],[0,0]]

卷积网络

Conv2D卷积操作 注意下输入input_shape=(32, 32, 3)需要调整

ini 复制代码
model = models.Sequential()
model.add(
(64, (3, 3), activation='relu', padding='same', input_shape=(32, 32, 3), kernel_regularizer=regularizers.l2(weight_decay)))
model.add(Conv2D(64, (3, 3), activation='relu', padding='same', kernel_regularizer=regularizers.l2(weight_decay)))
model.add(MaxPooling2D((2, 2)))

经典的卷积网络有AlexNet、VGG、GoogLeNet

VGG16直接引用的写法。或者自己写网络模型
  1. 注意下最后输出 这里激活函数用了sigmoid做猫狗的测试。上一次我们对于手写字体用了分类为10的msoftmax。关于激活函数网上可以自己了解下。
  2. model.add(keras.layers.Dense(11, activation='softmax')) # 多分类输出一般用softmax分类器
  3. 下面为VGG模型微调,如不需要微调,则可将conv_base.trainable设置为FALSE
ini 复制代码
def VGG16_for_tf():  # 设置函数来创建模型
    model = models.Sequential()
    # 使用VGG16作为模型的基本框架,加载在ImageNet上训练的权重,并去除输出层
    base_model = tf.keras.applications.VGG16(include_top=False, weights='imagenet', input_shape=(224, 224, 3))
    base_model.trainable = False # 冻结VGG16的权重,只训练输出层
    model.add(base_model)
    model.add(Flatten())
    model.add(Dense(1, activation='sigmoid'))  # 二分类,使用sigmoid映射到[0.0,1.0]
    return model
数据加载

再次提下在前面介绍Tensorflow2学习之MNIST数据训练和测试的时候注意训练用了history = model.fit(x_train, y_train_onehot, epochs=20)分别指定了训练的数据和对应的标签,这里学到了一个新的方法不用再辛苦处理单独拼接提取数据了,主要是语法和方法的封装吧。 这里的模型测试和前面一样用的h5,后面会试下tf模型

python 复制代码
train_image_path = glob.glob(r'dogs-vs-cats\train**.jpg')  # 猫狗数据集存放路径
train_image_label = [int(os.path.basename(p).split('.')[0] == 'cat') for p in train_image_path]  # 文件名字,并编码,cat为1,dog为0
train_image_ds = tf.data.Dataset.from_tensor_slices((train_image_path, train_image_label))
train_image_ds = train_image_ds.map(load_preprocess_image, num_parallel_calls=AUTOTUNE)
#打乱随机提取大小
train_image_ds = train_image_ds.shuffle(train_count).batch(BATCH_SIZE)
相关推荐
IT_陈寒2 分钟前
Vite 凭什么比 Webpack 快50%?揭秘闪电构建背后的黑科技
前端·人工智能·后端
寻见90343 分钟前
救命!RAG检索总跑偏?bge-reranker-large彻底解决「找错文档」痛点
人工智能·langchain
TechFind1 小时前
我用 OpenClaw 搭了一套运营 Agent,每天自动生产内容、分发、追踪数据——独立开发者的运营平替
人工智能·agent
小成C1 小时前
Vibe Coding 时代,研发体系该怎么重新分工
人工智能·架构·全栈
37手游后端团队1 小时前
全网最简单!从零开始,轻松把 openclaw 小龙虾装回家
人工智能·后端·openai
该用户已不存在1 小时前
月薪2w养不起龙虾?试试OpenClaw+Ollama
人工智能·aigc·ai编程
Seeker1 小时前
别盲目跟风“养龙虾”!OpenClaw爆火背后,这些致命安全风险必须警惕
人工智能·安全
golang学习记1 小时前
Claude Code 官宣新 AI 功能!随时随地 AI 为你打工
人工智能·claude
IvanCodes2 小时前
OpenClaw保姆级安装教程:windows&ubuntu
人工智能
Serverless社区3 小时前
AgentRun实践指南:Agent 的宝藏工具—All-In-One Sandbox
人工智能