PyTorch 中 functional.py 文件介绍

PyTorch

PyTorch 是一个开源的机器学习库,广泛用于计算机视觉和自然语言处理等应用。它由 Facebook 的人工智能研究团队开发,并得到了许多研究机构和企业的支持。PyTorch 以其易用性、灵活性和强大的社区支持而受到欢迎。一些特点如下:

  1. 动态计算图(Dynamic Computation Graphs):PyTorch 使用动态计算图,这意味着图在运行时构建,允许更直观和灵活的模型设计。这与传统的静态图框架(如 TensorFlow 1.x)不同,后者需要在执行前定义整个图。

  2. 自动微分(Automatic Differentiation):PyTorch 提供了自动微分功能,可以自动计算梯度,这对于训练深度学习模型至关重要。

  3. 丰富的API:PyTorch 提供了大量的预定义层、优化器和损失函数,使得模型的构建和训练变得简单。

  4. 多GPU支持:PyTorch 支持多GPU训练,可以有效地利用硬件资源,加速模型训练。

  5. 序列化和模型共享:PyTorch 允许轻松地保存和加载模型,便于模型的共享和部署。

  6. 社区支持:PyTorch 拥有一个活跃的社区,提供了大量的教程、文档和预训练模型。

functional.py

功能介绍

  1. 在 PyTorch 中,torch.nn.functional.py 是一个包含多种函数的模块,这些函数提供了不包含可学习参数的层的实现。这意味着,与 torch.nn.Module 中的层不同,torch.nn.functional.py 中的函数不跟踪梯度或在训练过程中更新参数。这些函数通常用于执行各种操作,如卷积、池化、激活函数、损失函数等。

常用函数

  1. 卷积函数

    • torch.nn.functional.conv1d: 一维卷积函数。
    • torch.nn.functional.conv2d: 二维卷积函数。
    • torch.nn.functional.conv3d: 三维卷积函数。
    • torch.nn.functional.conv_transpose1d, torch.nn.functional.conv_transpose2d, torch.nn.functional.conv_transpose3d: 转置卷积(也称为去卷积)函数。
  2. 池化函数

    • torch.nn.functional.avg_pool1d, torch.nn.functional.avg_pool2d, torch.nn.functional.avg_pool3d: 平均池化函数。
    • torch.nn.functional.max_pool1d, torch.nn.functional.max_pool2d, torch.nn.functional.max_pool3d: 最大池化函数。
    • torch.nn.functional.adaptive_max_pool1d, torch.nn.functional.adaptive_max_pool2d: 自适应最大池化函数。
    • torch.nn.functional.adaptive_avg_pool1d, torch.nn.functional.adaptive_avg_pool2d: 自适应平均池化函数。
  3. 非线性激活函数

    • torch.nn.functional.relu: 修正线性单元(ReLU)激活函数。
    • torch.nn.functional.sigmoid: Sigmoid 激活函数。
    • torch.nn.functional.tanh: 双曲正切激活函数。
  4. 归一化函数

    • torch.nn.functional.batch_norm: 批量归一化函数。
  5. 线性函数

    • torch.nn.functional.linear: 线性变换函数。
  6. Dropout 函数

    • torch.nn.functional.dropout: Dropout 函数。
  7. 距离函数

    • torch.nn.functional.pairwise_distance: 计算两个张量之间的成对距离。
  8. 损失函数

    • torch.nn.functional.cross_entropy: 交叉熵损失函数。
    • torch.nn.functional.binary_cross_entropy: 二进制交叉熵损失函数。
    • torch.nn.functional.nll_loss: 负对数似然损失函数。
  9. 视觉函数

    • torch.nn.functional.pixel_shuffle: 用于将张量重新排列以增加空间分辨率的函数。
    • torch.nn.functional.pad: 用于填充张量的函数。

使用示例

  1. 卷积函数示例
python 复制代码
import torch
import torch.nn.functional as F

# 创建一个输入张量,假设是一个单通道的28x28图像
input = torch.randn(1, 1, 28, 28)

# 定义卷积核的权重和偏置
weight = torch.randn(1, 1, 3, 3)
bias = torch.randn(1)

# 使用 F.conv2d 进行卷积操作
output = F.conv2d(input, weight, bias)

print(output.shape)  # 输出张量的形状
  1. 池化函数示例
python 复制代码
# 使用 F.max_pool2d 进行最大池化,
# 池化(Pooling)是卷积神经网络(CNN)中常用的一种技术,它用于降低特征的空间维度(高和宽),
# 同时保留最重要的信息。池化操作通常在卷积层之后应用,可以减少模型的参数数量和计算量,
# 提高模型的抽象能力,并且有助于提取更具有泛化性的特征。
pooled = F.max_pool2d(input, kernel_size=2)

print(pooled.shape)  # 输出张量的形状
  1. 激活函数示例
python 复制代码
# 使用 F.relu 作为激活函数
activated = F.relu(input)

print(activated.shape)  # 输出张量的形状
  1. 损失函数示例
python 复制代码
# 假设我们有一些预测和目标标签
predictions = torch.randn(10)
targets = torch.tensor([1, 2, 3, 4, 5, 6, 7, 8, 9, 10])

# 使用 F.cross_entropy 计算交叉熵损失
loss = F.cross_entropy(predictions, targets)

print(loss.item())  # 输出损失值
  1. 归一化函数示例
python 复制代码
# 假设我们有一个批量的输入数据
inputs = torch.randn(20, 10)

# 使用 F.batch_norm 进行批量归一化
output = F.batch_norm(inputs, running_mean=torch.zeros(10), running_var=torch.ones(10))

print(output.shape)  # 输出张量的形状
  1. Dropout 函数示例
python 复制代码
# 使用 F.dropout 进行dropout操作,正则化技术,Dropout 通过在训练过程中随机"丢弃"(即将输出设置为零)
# 一部分神经元的输出,来减少神经元之间复杂的共适应关系。
dropped = F.dropout(input, p=0.2)

print(dropped.shape)  # 输出张量的形状

相关源码

  1. GitHub地址:https://github.com/pytorch/pytorch/blob/main/torch/nn/functional.py
相关推荐
Sxiaocai4 分钟前
使用 PyTorch 实现并训练 VGGNet 用于 MNIST 分类
pytorch·深度学习·分类
GL_Rain5 分钟前
【OpenCV】Could NOT find TIFF (missing: TIFF_LIBRARY TIFF_INCLUDE_DIR)
人工智能·opencv·计算机视觉
shansjqun10 分钟前
教学内容全覆盖:航拍杂草检测与分类
人工智能·分类·数据挖掘
狸克先生12 分钟前
如何用AI写小说(二):Gradio 超简单的网页前端交互
前端·人工智能·chatgpt·交互
baiduopenmap27 分钟前
百度世界2024精选公开课:基于地图智能体的导航出行AI应用创新实践
前端·人工智能·百度地图
小任同学Alex30 分钟前
浦语提示词工程实践(LangGPT版,服务器上部署internlm2-chat-1_8b,踩坑很多才完成的详细教程,)
人工智能·自然语言处理·大模型
新加坡内哥谈技术36 分钟前
微软 Ignite 2024 大会
人工智能
江瀚视野1 小时前
Q3净利增长超预期,文心大模型调用量大增,百度未来如何分析?
人工智能
陪学1 小时前
百度遭初创企业指控抄袭,维权还是碰瓷?
人工智能·百度·面试·职场和发展·产品运营
QCN_1 小时前
湘潭大学人工智能考试复习1(软件工程)
人工智能