pytorch张量创建、张量复制

pytorch张量创建、张量复制

首先注意一点:在torch中,可导张量计算出的新张量也是可导的,新张量与原张量具有可导连接,那么原张量就不是叶子张量,新张量成了叶子张量。

创建方式一:torch.tensor()

torch.tensor(data , *, dtype=None , device=None , requires_grad=False , pin_memory=False ) → Tensor

torch.tensor只能从指定的数据创建,但是可以指定数据属性,是否可微分等属性。pin_memory是将张量放置到锁业内存中,所以这个张量只能被cpu使用。

python 复制代码
import torch
a = [1, 2, 3]
b = torch.tensor(a, requires_grad=True, dtype=torch.float64)
创建方式二:torch.Tensor

按照形状创建,如果输入列表,就按照指定数据创建。

整数:torch.ShortTensor 16位,torch.IntTensor 32位,torch.LongTensor 64位

浮点:torch.FloatTensor=torch.Tensor 32位,torch.DoubleTensor 64位

注意:torch.Tensor(int1, int2,int3)会创建[int1, int2,int3]形状的张量,如果传入列表元组等,就会返回该列表元组张量。

python 复制代码
import torch
torch.Tensor(3) 
'''tensor([-2.6853e+05,  1.9983e-42,  2.3694e-38])'''
torch.Tensor(3, 1) 
'''
tensor([[3.2842e-15],
        [3.1714e+00],
        [2.3694e-38]])
'''
torch.Tensor([3, 1])
'''
tensor([3., 1.])
'''
同设备内复制 - tensor.data /tensor.detach()/tensor.clone的区别

这三个单独会用都会和原张量有牵扯:

  1. tensor.data和tensor.detach():随着原张量的数值变化而变化。剥离开了原张量的微分图。
  2. tensor.clone() : 还处于原张量的微分图中。复制了原张量的数值。也就是tesnor.clone().bachward()后,原张量的微分图会进行一次反向传导。
  3. 完全没牵扯:tensor.clone().detach()

举例:

python 复制代码
import torch

a = torch.tensor(1, requires_grad=True, dtype=torch.float32)
b = a * 2

b_data = b.data
b_detach = b.detach()
b_clone = b.clone()
print(b, b_data, b_detach, b_clone)
'''
tensor(2., grad_fn=<MulBackward0>) tensor(2.) tensor(2.) tensor(2., grad_fn=<CloneBackward0>)
'''
# 当其中一个改变时,tensor.data, tensor.detach也会改变。tensor.clone不会改变。
b_detach.zero_()
print(b, b_data, b_detach, b_clone)
'''
tensor(0., grad_fn=<MulBackward0>) tensor(0.) tensor(0.) tensor(2., grad_fn=<CloneBackward0>)
'''

当tensor.detach或者tensor.data改变数值时,并不会影响原张量的微分传导结果。

python 复制代码
import torch

a = torch.tensor(1, requires_grad=True, dtype=torch.float32)
b = a * 2

b_data = b.data
b_detach = b.detach()
b_clone = b.clone()

# a的微分结果不受影响
b_detach.zero_()
b.backward(retain_graph=True)
print(a.grad)

# 如果原张量本身变化,则会受到影响。
b.zero_()
a.grad.zero_()
b.backward()
print(a.grad)
'''
tensor(2.)
tensor(0.)
'''

tensor.clone会保持原张量的微分传导图,并会叠加到结果上。

python 复制代码
import torch

a = torch.tensor(1, requires_grad=True, dtype=torch.float32)
b = a * 2

b_clone = b.clone()

b.backward(retain_graph=True)
print(a.grad)
b_clone.backward()
print(a.grad)
'''
tensor(2.)
tensor(4.)
'''
跨设备复制

方法很多,实际使用就用以下这种:

python 复制代码
device = "cuda:0" if torch.cuda.is_available() else "cpu"
temp = torch.tensor(2)
temp.to(deivce) # 如果有gpu就放到gpu.
temp = temp.cpu() # 复制到cpu上
相关推荐
Yolanda9417 分钟前
【人工智能】《从零搭建AI问答助手项目(九):Prompt优化》
人工智能·prompt
wj30558537820 分钟前
课程 9:模型测试记录与 Prompt 策略
linux·人工智能·python·comfyui
小和尚同志20 分钟前
深入使用 skill-creator:结合真实生产级实践
人工智能·aigc
DevSecOps选型指南21 分钟前
安全419专访悬镜安全 | 穿越周期在 AI 浪潮中定义数字供应链安全新范式
人工智能
沪漂阿龙35 分钟前
面试题详解:GraphRAG 全面解析——知识图谱增强 RAG、Local Search、Global Search、社区摘要、工程落地与评估指标一次讲透
人工智能·知识图谱
WangN235 分钟前
Unitree RL Lab 学习笔记【通识】
人工智能·机器学习
haina201940 分钟前
海纳AI亮相《科创中国》,解码招聘“智”变之路
人工智能·ai面试·ai招聘
星寂樱易李1 小时前
iperf3 + Python-- 网络带宽、网速、网络稳定性
开发语言·网络·python
阿星AI工作室1 小时前
刘润年中大课笔记:一句话说清AI落地之战的本质
大数据·人工智能·创业创新·商业
qingfeng154151 小时前
企业微信机器人开发:如何实现自动化与智能运营?
人工智能·python·机器人·自动化·企业微信