张量的形状操作以及拼接

张量的形状操作函数概括

张量的形状变换操作函数

reshape()

squeeze()

unsqueeze()

transpose()

permute()

view()

contiguous()

需要掌握的函数

reshape()、unsqueeze()、permute()、view()

reshape()

在不改变内容的前提下,对其形状做改变。

注意:转换后元素总的个数不能变

python 复制代码
torch.random.manual_seed(10)
t1 = torch.randint(0,11,[2,3])
print(f"t1 = {t1}")
print(f"t1.shape = {t1.shape}")
t2  = t1.reshape(3,2)
print(f"t2 = {t2}")
print(f"t2.shape = {t2.shape}")

unsqueeze()

在指定的轴上增加一个(1)维度

python 复制代码
t1 = torch.randint(0, 11, [2, 3])
    t2 = t1.unsqueeze(0)
    print(f"t2 = {t2}")
    print(f"t2.shape = {t2.shape}")
    t3 = t1.unsqueeze(1)
    print(f"t3 = {t3}")
    print(f"t3.shape = {t3.shape}")
    t4 = t1.unsqueeze(2)
    print(f"t4 = {t4}")
    print(f"t4.shape = {t4.shape}")

squeeze()

删除所有为1的维度,等价于降维

python 复制代码
t1 = torch.randint(0, 11, [2,1,3,1,1])
    print(f"t1 = {t1}")
    print(f"t1.shape = {t1.shape}")
    t2 = t1.squeeze()
    print(f"t2 = {t2}")
    print(f"t2.shape = {t2.shape}")

transpose()和permute()

transpose() 一次只能交换2个维度

permute() 一次可以同时交换多个维度

python 复制代码
t1 = torch.randint(0, 11, [2,3,4])
    print(f"t1.shape = {t1.shape}")
    t2 = t1.transpose(0,1)
    print(f"t2.shape = {t2.shape}")
    t3 = t1.permute(2,0,1)
    print(f"t3.shape = {t3.shape}")

view()和contiguous()

view只修改连续的张量的形状(连续指的是内存的连续)

view可以改变原来的张量比如t1.view(),t1的形状也发生了改变

is_contiguous() 判断张量是否连续

contiguous() 将不连续的张量变成连续的

python 复制代码
t1 = torch.randint(0, 11, [2,3])
    t2 = t1.view(3,2)
    print(f"t2.shape = {t2.shape}")
    #通过transpose将张量变为不连续的
    t1 = t1.transpose(1,0)
    # print(f"t1.is_contiguous() = {t1.is_contiguous()}")
    # t3 = t1.view(2,3)
    # print(f"t3.shape = {t3.shape}")
    #通过contiguous()变为连续的然后再转换
    t1 = t1.contiguous()
    print(f"t1.shape = {t1.shape}")
    t4 = t1.view(2,3)
    print(f"t4.shape = {t4.shape}")

张量的拼接

cat() 不改变维度数拼接张量,除了拼接的那个维度外其它的维度必须保持一致

stack() 会改变维度,拼接张量,所有的维度都必须保持一致

拼接张量可以是新维度,但是无论新旧维度,所有维度都必须保持一致

cat()

python 复制代码
t1 = torch.randint(0,5,[3,4])
t2 = torch.randint(0,5,[2,4])
t3 = torch.cat([t1,t2],dim=0)
print(f"t3.shape = {t3.shape}")

stack()

python 复制代码
t1 = torch.randint(0,5,[2,3])
t2 = torch.randint(0,5,[2,3])
t3 = torch.stack([t1,t2],dim=0)
print(f"t3.shape:{t3.shape}")
t4 = torch.stack([t1,t2],dim=1)
print(f"t4.shape:{t4.shape}")
t5 = torch.stack([t1,t2],dim=2)
print(f"t5.shape:{t5.shape}")
相关推荐
一次旅行16 小时前
HyperTool:突破传统工具调用限制,让Agent更高效执行复杂任务
人工智能
陈天伟教授17 小时前
图解人工智能(58)人工智能应用-围棋国手
人工智能·语音识别·机器翻译
闻道参看17 小时前
2026年AI优质企业培训系统综合测评:合规管控/数据量化
人工智能
老虾头17 小时前
科技贴近烟火:本地化 AI,赋能各行各业日常经营
人工智能
毒爪的小新17 小时前
Linux 环境极速部署 vLLM:从零搭建生产级大模型推理服务
linux·人工智能·ai·语言模型·vllm
老大白菜17 小时前
25美元,DIY开源可穿戴智能AI眼镜:Arduino+乐鑫ESP32+DeepSeek项目
人工智能
岁月宁静18 小时前
RAG 文档摄入全链路,从原理到生产落地
vue.js·人工智能·python
小和尚同志18 小时前
AI 自动化测试探索(一):Playwright MCP
前端·人工智能·aigc
硅谷秋水18 小时前
面向长上下文自动驾驶的规划对齐Token压缩
人工智能·深度学习·机器学习·计算机视觉·自动驾驶