pytorch小记(六):pytorch中的clone和detach操作:克隆/复制数据 vs 共享相同数据但 与计算图断开联系

pytorch小记(六):pytorch中的clone和detach操作:克隆/复制数据 vs 共享相同数据但 与计算图断开联系


以下代码片段:

python 复制代码
self.x = x.clone().detach()  # 或 torch.tensor(x).float()

用于处理和复制张量 x,并根据需要使其与原始计算图断开联系或改变其数据类型。下面是逐部分详细解释。


1. x.clone()

  • 作用 :对张量 x 进行深拷贝,生成一个新的张量。
    • 新的张量和原始张量具有相同的数据,但存储在不同的内存空间。
    • 修改 clone() 的返回值不会影响原始张量。

示例:

python 复制代码
x = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)
y = x.clone()

y[0] = 99.0
print(x)  # tensor([1., 2., 3.], grad_fn=<CloneBackward>)
print(y)  # tensor([99.,  2.,  3.])

2. x.detach()

  • 作用 :返回一个与 x 共享相同数据但 与计算图断开联系 的张量。
    • 通常用于阻止梯度计算。
    • 在神经网络中,如果你不希望某些操作影响反向传播时,会用到 detach()

示例:

python 复制代码
x = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)
y = x.detach()

y[0] = 99.0  # y 的数据更改不会影响 x
print(x)  # tensor([1., 2., 3.], requires_grad=True)
print(y)  # tensor([99.,  2.,  3.])

使用场景:

detach() 在以下场景中非常有用:

  1. 阻止梯度传播:

    python 复制代码
    z = x.clone().detach()
    # z 不会参与反向传播,x 的梯度也不会受 z 的影响
  2. 保存模型状态或生成推断结果:

    python 复制代码
    with torch.no_grad():
        output = model(x)  # 临时禁用梯度计算

3. torch.tensor(x).float()

  • 作用 :将输入 x 转换为 PyTorch 张量,并将其数据类型强制为 torch.float32(默认浮点类型)。
  • 适用场景:
    • 输入可能是一个 Python 列表或 NumPy 数组时,用于将其转换为 PyTorch 张量。
    • 确保张量数据类型一致(某些模型或操作对数据类型有严格要求)。

示例:

python 复制代码
x = [[1, 2, 3], [4, 5, 6]]  # Python 列表
y = torch.tensor(x).float()  # 转为 torch.float32 类型的张量
print(y)
# tensor([[1., 2., 3.],
#         [4., 5., 6.]])

4. 两者的对比与结合

  • x.clone().detach()torch.tensor(x).float() 是不同的操作:

    1. x.clone().detach()
      • 复制一个现有张量,且与原始计算图断开。
      • 适用于 PyTorch 张量 x,不适用于列表或其他数据类型。
    2. torch.tensor(x).float()
      • 将输入转换为新的 PyTorch 张量,适用于从非张量对象(如列表、NumPy 数组)构造张量。
      • 转换过程中可以指定数据类型(如 .float())。
  • 结合使用

    如果需要复制一个张量、改变数据类型,并断开计算图,可以将两者结合:

    python 复制代码
    self.x = torch.tensor(x.clone().detach()).float()

使用场景

x.clone().detach()

  • x 是一个 PyTorch 张量,且需要:
    • 复制数据。
    • 与原始计算图断开。

torch.tensor(x).float()

  • x 是一个非 PyTorch 张量对象(如列表或 NumPy 数组),且需要:
    • 转换为 PyTorch 张量。
    • 确保数据类型为浮点型。

完整示例:

python 复制代码
import torch

# 输入张量
x = torch.tensor([[2.0, -1.0], [1.0, 1.0]], requires_grad=True)

# 使用 clone().detach()
y = x.clone().detach()
y[0, 0] = 99.0
print("x:", x)  # 原始张量不会改变
print("y:", y)  # 新张量修改了

# 使用 torch.tensor()
z = torch.tensor([[1, 2], [3, 4]]).float()
print("z:", z)  # 转换为浮点张量

总结

  • clone():深拷贝一个张量。
  • detach():断开张量与计算图的连接。
  • torch.tensor(x).float():将非张量数据转换为浮点型 PyTorch 张量。
  • 它们在不同场景下各有用途,可以单独使用或结合使用。
相关推荐
代码AI弗森1 小时前
从 IDE 到 CLI:AI 编程代理工具全景与落地指南(附对比矩阵与脚本化示例)
ide·人工智能·矩阵
xchenhao2 小时前
SciKit-Learn 全面分析分类任务 breast_cancer 数据集
python·机器学习·分类·数据集·scikit-learn·svm
007tg4 小时前
从ChatGPT家长控制功能看AI合规与技术应对策略
人工智能·chatgpt·企业数据安全
Memene摸鱼日报4 小时前
「Memene 摸鱼日报 2025.9.11」腾讯推出命令行编程工具 CodeBuddy Code, ChatGPT 开发者模式迎来 MCP 全面支持
人工智能·chatgpt·agi
linjoe995 小时前
【Deep Learning】Ubuntu配置深度学习环境
人工智能·深度学习·ubuntu
独行soc5 小时前
2025年渗透测试面试题总结-66(题目+回答)
java·网络·python·安全·web安全·adb·渗透测试
先做个垃圾出来………6 小时前
残差连接的概念与作用
人工智能·算法·机器学习·语言模型·自然语言处理
AI小书房6 小时前
【人工智能通识专栏】第十三讲:图像处理
人工智能
fanstuck6 小时前
基于大模型的个性化推荐系统实现探索与应用
大数据·人工智能·语言模型·数据挖掘
多看书少吃饭8 小时前
基于 OpenCV 的眼球识别算法以及青光眼算法识别
人工智能·opencv·计算机视觉