重生学AI第十六集:线性层nn.Linear

nn.Linear

1.背景知识

今天学习nn.Linear(线性层),也叫全连接层,我们在卷积层通过卷积操作获取图像特征,然后在线性层,将图像特征转化为最终的输出结果,在图像分类任务中,就是每个类别的打分,后续在输出层,根据打分来输出每个类别的概率。

比如说我们的目标任务是将一些图片分为十个类别,那么out_features 就是10

参数

  • in_features :输入数据的特征数
  • out_features :输出维度 / 输出类别数
  • bias :偏置

先来看一组代码

python 复制代码
import torchvision
from torch.utils.data import DataLoader
from torchvision import transforms

datasets = torchvision.datasets.CIFAR10("../dataset",train=False,download=True,transform=transforms.ToTensor())
dataloader = DataLoader(datasets,batch_size=64)



for data in dataloader:
    img, target = data
    print(img.shape)

控制台输出是这样的

到这里我们要先理解一下这个张量数据:(64,4,32,32), 通过前面的学习我们已经知道这个张量数据的形状是(batch_size,channels,height,width),宽32,代表一行有32个像素,每个像素代表不同的颜色,高也是一样,那么这张图就有32*32个像素,也可以把他们看成是数据,如下图所示,下面这张图是一个5*5的格子,你也可以看成是宽=5,高=5的一个图像,按照张量的写法就是,(1,1,5,5)

好,我们已经学会理解一张图像了,再说回前面那个数据,(64,3,32,32)代表这张图一个通道有32*32个数据,也就是1024个数据,一共有3个通道,1024*3=3072,就是说一张图有3072个数据

2.展平的多种方式

那么,我们就要通过展平的方式,将这个图像的张量数据转化成形状为(batch_size,features)的数据,以便我们将数据送入线性层,in_features和out_features的形状都是如此,那么展平的方式有哪些呢?

1.reshape

python 复制代码
output1 = img.reshape(img.size(0),-1)

2.view

python 复制代码
output1 = img.view(img.size(0),-1)

3.flatten

python 复制代码
output1 = torch.flatten(img,start_dim=1)

参数:

  • start_dim: 展平开始的维度(默认从零开始,但要保留批次,所以需要设置为1)
  • end_dim:展平到哪一维(默认到最后)

这三种用哪种都可以,区别不大,选择喜欢的就好


img.size()是一个对象,返回的数据和shape属性值一样,img.size(0)等价于img.shape[0],取得是第一个数据64,也就是图像的批次(batch_size)


完整代码:

python 复制代码
import torchvision

from torch.utils.data import DataLoader
from torchvision import transforms

datasets = torchvision.datasets.CIFAR10("../dataset", train=False, download=True, transform=transforms.ToTensor())
dataloader = DataLoader(datasets, batch_size=64)

for data in dataloader:
    img, target = data
    input1 = img.reshape(img.size(0), -1)
    print(input1.size())

我们通过展平得到了能够送入线性层的数据,如下图所示:

3.线性模型

通过上面的打印,我们已经知道了特征数是3072个,所以第一个参数就写3072,然后需要10个结果类别,偏置就不写了,让他默认即可

python 复制代码
import torchvision
from torch import nn
from torch.utils.data import DataLoader
from torchvision import transforms

datasets = torchvision.datasets.CIFAR10("../dataset", train=False, download=True, transform=transforms.ToTensor())
dataloader = DataLoader(datasets, batch_size=64)


class LinearModel(nn.Module):
    def __init__(self):
        super(LinearModel, self).__init__()
        self.linear1 = nn.Linear(3072, 10)

    def forward(self, input1):
        output1 = self.linear1(input1)
        return output1


model = LinearModel()

for data in dataloader:
    img, target = data
    input1 = img.reshape(img.size(0), -1)
    output1 = model(input1)
    print(output1.size())

运行结果:

成功的把原来的每张图片的3072个特征转换成了10个结果类别

相关推荐
ZTLJQ5 小时前
序列化的艺术:Python JSON处理完全解析
开发语言·python·json
H5css�海秀5 小时前
今天是自学大模型的第一天(sanjose)
后端·python·node.js·php
SuniaWang5 小时前
《Spring AI + 大模型全栈实战》学习手册系列 · 专题六:《Vue3 前端开发实战:打造企业级 RAG 问答界面》
java·前端·人工智能·spring boot·后端·spring·架构
阿贵---5 小时前
使用XGBoost赢得Kaggle比赛
jvm·数据库·python
无敌昊哥战神5 小时前
【LeetCode 257】二叉树的所有路径(回溯法/深度优先遍历)- Python/C/C++详细题解
c语言·c++·python·leetcode·深度优先
IDZSY04306 小时前
AI社交平台进阶指南:如何用AI社交提升工作学习效率
人工智能·学习
七七powerful6 小时前
运维养龙虾--AI 驱动的架构图革命:draw.io MCP 让运维画图效率提升 10 倍,使用codebuddy实战
运维·人工智能·draw.io
水星梦月6 小时前
大白话讲解AI/LLM核心概念
人工智能
温九味闻醉7 小时前
关于腾讯广告算法大赛2025项目分析1 - dataset.py
人工智能·算法·机器学习
White-Legend7 小时前
第三波GPT5.4 日400刀
人工智能·ai编程