【小白笔记】PyTorch 和 Python 基础的这些问题

1. self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 这句话是固定的吗?

  • 人话解释: 这句话是 PyTorch 代码中用于实现设备无关性 (Device Agnostic)的标准写法,它几乎是固定的模板。

  • 功能拆解:

    1. torch.cuda.is_available():询问 PyTorch:"机器上有没有英伟达的 GPU?"
    2. "cuda" if ... else "cpu":这是一个 Python 的三元表达式
      • 如果 GPU 可用,选择 "cuda"
      • 如果 GPU 不可用,选择 "cpu"
    3. torch.device(...):根据上一步的结果,创建一个 PyTorch 设备对象
  • 固定性: 这行代码是 PyTorch 社区高度推荐的,因为它能让你的代码在有 GPU 的机器上自动加速 ,在没有 GPU 的机器上自动回退到 CPU ,无需修改代码。因此,在 PyTorch 项目中,这是应该背诵和使用的固定写法


2. y_train = torch.from_numpy(y_train).long() 为什么不用 int

  • 人话解释: long()int() 在 PyTorch 中都代表整数类型,但它们分别对应不同的位宽(bit width)。

    • long() (即 torch.int64): 这是一个 64 位的整数类型。它是 PyTorch 中处理索引标签 、以及大型计数默认和推荐类型。
    • int() (即 torch.int32): 这是一个 32 位的整数类型。
  • 主要原因 (惯例和兼容性):

    1. 标签类型要求: 在 PyTorch 的许多内置函数中(例如损失函数 nn.CrossEntropyLoss),要求输入的目标标签张量必须是 torch.long (即 64 位整数) 类型。
    2. 安全范围: 64 位整数可以表示更大的数字,尽管像鸢尾花这种简单的分类任务用 32 位足够,但使用 long() 更安全、更符合 PyTorch 的习惯。
  • 记忆点: 在 PyTorch 中,特征 (X) 用 float()标签/索引 (Y) 用 long()


3. .to(self.device): 数据上设备。这个是什么用法?to 这个前面没有定义这个功能啊?

  • 人话解释: .to() 是 PyTorch 张量(Tensor)对象自带的"移动"能力。它不是一个需要您在类中定义的方法,而是 PyTorch 库已经给所有张量写好的内置方法

  • 功能: .to(目标) 方法用于将一个张量移动到指定的设备(如 CPU 或 CUDA/GPU),或转换为指定的数据类型。

  • 用法:

    python 复制代码
    my_tensor = torch.tensor([1, 2, 3])
    # 移动到 GPU
    my_tensor_on_gpu = my_tensor.to('cuda') 
    
    # 移动到我们在 __init__ 中设置好的 self.device 上
    self.X_train = X_train.to(self.device)
  • 记忆点: 张量. to(device) 是 PyTorch 中数据上/下设备(GPU/CPU)的标准动作。


4. predictions.append(pred_label) 这个 append 是啥?经常见这个用法,为什么不用 add

  • 人话解释: append 是 Python 列表(List)对象的一个内置方法 ,意思是"在列表的末尾添加一个新元素"。

  • 为什么不用 add

    • 在 Python 中,加法运算 + 具有不同的语义,例如:
      • 数值: 1 + 2 得到 3
      • 集合 (Set): 没有 add 方法,使用 set.add(element)
      • 列表: [1] + [2] 会得到 [1, 2](这是连接两个列表)。
    • 为了明确"向列表中添加一个元素 "这个操作,Python 的设计者选择了 append 这个词。add 通常用于集合 (set) 或用于表示数值相加。
  • 记忆点:

    • List (列表)的末尾添加元素用:.append()
    • Set (集合)中添加元素用:.add()

5. unsqueeze(0): 增加维度进行广播。是啥意思?

  • 人话解释: unsqueeze(0) 的意思是**"在第 0 个位置(最前面)增加一个维度,把这个向量变成一个矩阵"**。

  • 目的: PyTorch/NumPy 中的广播机制要求参与运算的张量维度能匹配。

  • 举例:

    • 原始样本 x_new 是一个特征向量 ,比如 [3.0, 3.0]。它的维度是 (2,)
    • 训练集 X_train 是一个特征矩阵 ,比如 6 个样本,维度是 (6, 2)

    如果直接相减 X_train - x_new,PyTorch 不知道怎么对齐。

    • x_new.unsqueeze(0) 后: [3.0, 3.0] 变成了 [[3.0, 3.0]]。维度从 (2,) 变成了 (1, 2)
    • 现在: 一个 (6, 2) 的矩阵和一个 (1, 2) 的行向量就可以使用广播机制进行减法了。
  • 记忆点: unsqueeze 是在不改变数据的前提下,增加一个维度(通常是为了满足广播或函数输入的要求)。


6. unsqueeze(0): 广播关键步骤。是啥意思?

  • 人话解释: 这里的"关键步骤"指的是,unsqueeze(0)激活 PyTorch 广播机制的关键。

  • 原因: 正如上一点所说,如果不增加这个维度,PyTorch 不会知道如何将 x_new 的值与 X_train 中的所有行(样本)进行匹配。一旦维度变成 (1, 2),PyTorch 就理解了:"哦,需要把这个 (1, 2) 的向量复制 6 次,然后进行逐元素相减。"

  • 记忆点: unsqueeze(0) 是我们手动调整维度 ,以便让 PyTorch 的自动广播机制 能够工作的前置条件


7. 广播机制是啥?广播机制,计算新样本与所有训练样本的特征差。?

  • 人话解释: 广播机制 (Broadcasting) 是 PyTorch 和 NumPy 中一种聪明地处理不同形状数组之间运算 的机制。它的核心思想是:在不实际复制数据的情况下,让维度较小的数组"伸展"到和维度较大的数组一样大,然后进行运算。

  • 在 KNN 中的应用:

    • 目标:计算 X_train(所有样本)与 x_new(新样本)之间的差。
    • 没有广播: 你必须写一个循环,对 X_train 中的每一行都减去 x_new,或者手动创建一个和 X_train 一样大的新样本矩阵。这效率很低。
    • 使用广播:
      • 原始:X_train (6, 2)x_new_expanded (1, 2)
      • 广播过程:PyTorch 发现第 0 维不匹配 (6 和 1),但可以扩展。它将 x_new_expanded 逻辑上复制 6 次,变成一个 (6, 2) 的张量。
      • 最终效果:differences = X_train - x_new_expanded 等价于:
        KaTeX parse error: Expected 'EOF', got '&' at position 49: ...\text{new}, 1} &̲ x\{1,2} - x\...

  • 记忆点: 广播机制 就是让你可以用一个小尺寸 的张量(如单个样本)直接对一个大尺寸 的张量(如整个数据集)进行运算(如加减乘除)。它是实现向量化计算的关键。

相关推荐
摇滚侠32 分钟前
Spring Boot3零基础教程,响应式编程的模型,笔记109
java·spring boot·笔记
工业互联网专业1 小时前
基于协同过滤算法的小说推荐系统_django+spider
python·django·毕业设计·源码·课程设计·spider·协同过滤算法
星星的月亮叫太阳1 小时前
large-scale-DRL-exploration 代码阅读 总结
python·算法
Q_Q19632884751 小时前
python+django/flask基于Echarts+Python的图书零售监测系统设计与实现(带大屏)
spring boot·python·django·flask·node.js·php
深度学习lover2 小时前
<数据集>yolo航拍交通目标识别数据集<目标检测>
人工智能·python·yolo·目标检测·计算机视觉·航拍交通目标识别
YuanDaima20482 小时前
[CrewAI] 第5课|基于多智能体构建一个 AI 客服支持系统
人工智能·笔记·多智能体·智能体·crewai
程序猿20232 小时前
Python每日一练---第二天:合并两个有序数组
开发语言·python
权泽谦2 小时前
用 Flask + OpenAI API 打造一个智能聊天机器人(附完整源码与部署教程)
python·机器人·flask
njxiejing2 小时前
Numpy一维、二维、三维数组切片实例
开发语言·python·numpy
许长安2 小时前
c/c++ static关键字详解
c语言·c++·经验分享·笔记