PyTorch入门学习(十一):神经网络-线性层及其他层介绍

一、简介

神经网络是由多个层组成的,每一层都包含了一组权重和一个激活函数。每层的作用是将输入数据进行变换,从而最终生成输出。线性层是神经网络中的基本层之一,它执行的操作是线性变换,通常表示为:

css 复制代码
y = Wx + b

其中,y 是输出,x 是输入,W 是权重矩阵,b 是偏置。线性层将输入数据与权重矩阵相乘,然后加上偏置,得到输出。线性层的主要作用是进行特征提取和数据的线性组合。

二、PyTorch 中的线性层

在 PyTorch 中,线性层可以通过 torch.nn.Linear 类来实现。下面是一个示例,演示如何创建一个简单的线性层:

python 复制代码
import torch
from torch.nn import Linear

# 创建一个线性层,输入特征数为 3,输出特征数为 2
linear_layer = Linear(3, 2)

在上面的示例中,首先导入 PyTorch 库,然后创建一个线性层 linear_layer,指定输入特征数为 3,输出特征数为 2。该线性层将对输入数据执行一个线性变换。

三、示例:使用线性层构建神经网络

现在,接下来看一个示例,如何使用线性层构建一个简单的神经网络,并将其应用于图像数据。我们使用 PyTorch 和 CIFAR-10 数据集,这是一个广泛使用的图像分类数据集。

python 复制代码
import torch
from torch import nn
from torch.nn import Linear
from torch.utils.data import DataLoader
import torchvision.datasets

# 加载 CIFAR-10 数据集
dataset = torchvision.datasets.CIFAR10("D:\\Python_Project\\pytorch\\dataset2", train=False, transform=torchvision.transforms.ToTensor(), download=True)
dataloader = DataLoader(dataset, batch_size=64)

# 定义一个简单的神经网络
class MyModel(nn.Module):
    def __init__(self):
        super(MyModel, self).__init()
        self.linear1 = Linear(196608, 10)

    def forward(self, x):
        x = x.view(x.size(0), -1)  # 将输入数据展平
        x = self.linear1(x)
        return x

# 创建模型实例
model = MyModel()

# 遍历数据集并应用模型
for data in dataloader:
    imgs, targets = data
    outputs = model(imgs)
    print(outputs.shape)

在上面的示例中,首先加载 CIFAR-10 数据集,然后定义了一个简单的神经网络 MyModel,其中包含一个线性层。我们遍历数据集并将输入数据传递给模型,然后打印输出的形状。

四、常见的其他层

除了线性层,神经网络中还有许多其他常见的层,例如卷积层(Convolutional Layers)、池化层(Pooling Layers)、循环层(Recurrent Layers)等。这些层在不同类型的神经网络中起到关键作用。例如,卷积层在处理图像数据时非常重要,循环层用于处理序列数据,池化层用于减小数据维度。在 PyTorch 中,这些层都有相应的实现,可以轻松地构建不同类型的神经网络。

参考资料:

视频教程:PyTorch深度学习快速入门教程(绝对通俗易懂!)【小土堆】

相关推荐
adjusttraining12 分钟前
毁掉孩子视力不是电视和手机,两个隐藏很深因素,很多家长并不知
深度学习·其他
降临-max1 小时前
JavaSE---网络编程
java·开发语言·网络·笔记·学习
操练起来2 小时前
【昇腾CANN训练营·第八期】Ascend C生态兼容:基于PyTorch Adapter的自定义算子注册与自动微分实现
人工智能·pytorch·acl·昇腾·cann
大白的编程日记.2 小时前
【计算网络学习笔记】MySql的多版本控制MVCC和Read View
网络·笔记·学习·mysql
ziwu4 小时前
【宠物识别系统】Python+TensorFlow+Vue3+Django+人工智能+深度学习+卷积神经网络算法
人工智能·深度学习·图像识别
ziwu4 小时前
海洋生物识别系统【最新版】Python+TensorFlow+Vue3+Django+人工智能+深度学习+卷积神经网络算法
人工智能·深度学习·图像识别
WWZZ20254 小时前
快速上手大模型:深度学习12(目标检测、语义分割、序列模型)
深度学习·算法·目标检测·计算机视觉·机器人·大模型·具身智能
u***42074 小时前
Golang 构建学习
开发语言·学习·golang
车载测试工程师5 小时前
CAPL学习-IP API函数-2
网络·学习·tcp/ip·capl·canoe
YJlio6 小时前
进程和诊断工具学习笔记(8.29):ListDLLs——一眼看清进程里加载了哪些 DLL,谁在偷偷注入
android·笔记·学习