CNN的四维Pytorch张量格式

1 引言

前序学习进程中,已经了解了CNN相关的基本运算,包括卷积运算原理卷积值提取卷积图像扩充

再继续深入的时候,发现一个小知识点比较重要:CNN要求格式标准的四维张量,具体有[样本数,通道数,高(行),宽(列)]。

今天我们就来学习一下CNN的四维张量格式相关的小知识。

代码测试

使用reshape()函数规定张量格式

首先来看一段代码:

python 复制代码
import torch

# 2样本2通道(正确,已验证)
data_2s2c = torch.tensor([
    # 样本1-通道1(两层括号:通道层+像素层)
    [
        [1, 2, 3, 4],
        [5, 6, 7, 8],
        [9, 10, 11, 12],
        [13, 14, 15, 16]
    ],
    # 样本1-通道2
    [
        [17, 18, 19, 20],
        [21, 22, 23, 24],
        [25, 26, 27, 28],
        [29, 30, 31, 32]
    ],
    # 样本2-通道1
    [
        [33, 34, 35, 36],
        [37, 38, 39, 40],
        [41, 42, 43, 44],
        [45, 46, 47, 48]
    ],
    # 样本2-通道2
    [
        [49, 50, 51, 52],
        [53, 54, 55, 56],
        [57, 58, 59, 60],
        [61, 62, 63, 64]
    ]
], dtype=torch.float32).reshape(2, 2, 4, 4)
print("2样本2通道形状:", data_2s2c.shape)  # 输出:(2,2,4,4)

代码运行后会输出:
这里有一个重要定义,data_2s2c不仅是以一个四维pytorch张量,还通过reshape(2,2,4,4)被规定了格式,所以print()函数按照reshape()函数规定的效果输出。

为进行对比,可以尝试删除reshape()函数。

不使用reshape()函数规定张量格式

此时的代码为:

python 复制代码
import torch

# 2样本2通道(正确,已验证)
data_2s2c = torch.tensor([
    # 样本1-通道1(两层括号:通道层+像素层)
    [
        [1, 2, 3, 4],
        [5, 6, 7, 8],
        [9, 10, 11, 12],
        [13, 14, 15, 16]
    ],
    # 样本1-通道2
    [
        [17, 18, 19, 20],
        [21, 22, 23, 24],
        [25, 26, 27, 28],
        [29, 30, 31, 32]
    ],
    # 样本2-通道1
    [
        [33, 34, 35, 36],
        [37, 38, 39, 40],
        [41, 42, 43, 44],
        [45, 46, 47, 48]
    ],
    # 样本2-通道2
    [
        [49, 50, 51, 52],
        [53, 54, 55, 56],
        [57, 58, 59, 60],
        [61, 62, 63, 64]
    ]
], dtype=torch.float32).reshape(2, 2, 4, 4)

# 2样本2通道(正确,已验证)
data_2s2d = torch.tensor([
    # 样本
    [
        [1, 2, 3, 4],
        [5, 6, 7, 8],
        [9, 10, 11, 12],
        [13, 14, 15, 16]
    ],
    # 样本
    [
        [17, 18, 19, 20],
        [21, 22, 23, 24],
        [25, 26, 27, 28],
        [29, 30, 31, 32]
    ],
    # 样本
    [
        [33, 34, 35, 36],
        [37, 38, 39, 40],
        [41, 42, 43, 44],
        [45, 46, 47, 48]
    ],
    # 样本
    [
        [49, 50, 51, 52],
        [53, 54, 55, 56],
        [57, 58, 59, 60],
        [61, 62, 63, 64]
    ]
], dtype=torch.float32)
print("2样本2通道形状:", data_2s2c.shape)  # 输出:(2,2,4,4)
print("样本通道形状:", data_2s2d.shape)  # 输出:(4,4,4)

此时的代码和是上一节的不同在于,先复制data_2s2c改名为data_2s2d,然后删除reshape()函数,最后使用size()函数输出data_2s2c和data_2s2d的格式,输出效果为:

此时会看到data_2s2d的形状格式混淆了样本数和通道数,把它们相乘的结果2X2直接输出。

此时有一个重要的议题,如果想把data_2s2d转换为标准格式的四维Pytorch张量,该如何修改代码?

样本数和通道数判断

有一个简单的办法,如果张量定义过程中没有使用reshape()函数说明通道数和样本数,那就根据实际的代码判断:

data_2s2d = torch.tensor([

样本

\[1, 2, 3, 4\], \[5, 6, 7, 8\], \[9, 10, 11, 12\], \[13, 14, 15, 16

],

样本

\[17, 18, 19, 20\], \[21, 22, 23, 24\], \[25, 26, 27, 28\], \[29, 30, 31, 32

],

样本

\[33, 34, 35, 36\], \[37, 38, 39, 40\], \[41, 42, 43, 44\], \[45, 46, 47, 48

],

样本

\[49, 50, 51, 52\], \[53, 54, 55, 56\], \[57, 58, 59, 60\], \[61, 62, 63, 64

]

], dtype=torch.float32)

这段代码很直白,一个大括号内部,有四个大括号,每个大括号内部又有四个大括号,也就是会有三层大括号,不放按照从外到内的顺序,一次定义为第一层、第二层和第三层。

我们看第二层大括号,它约束了好几个第三层,这些第三层大括号内部的数字就是具体的张量,第三层可以直接看出张量的高度,对应行数,以及张量的宽度,对应列数。

由第三层张量组成的第二层,就是一个样本。也就是每一个第二层大括号内部的所有第三层张量组成一个样本。

所以上述代码有4个样本,此时没有定义通道数,应当增加一个通道数定义,通道数在CNN对应的Pytorch张量中属于第1个量(从第0个开始计),此时只需要在data_2s2d后面加一个unsqueeze(1),就会强制将通道数定义为1,代码为:

python 复制代码
import torch

# 2样本2通道(正确,已验证)
data_2s2c = torch.tensor([
    # 样本1-通道1(两层括号:通道层+像素层)
    [
        [1, 2, 3, 4],
        [5, 6, 7, 8],
        [9, 10, 11, 12],
        [13, 14, 15, 16]
    ],
    # 样本1-通道2
    [
        [17, 18, 19, 20],
        [21, 22, 23, 24],
        [25, 26, 27, 28],
        [29, 30, 31, 32]
    ],
    # 样本2-通道1
    [
        [33, 34, 35, 36],
        [37, 38, 39, 40],
        [41, 42, 43, 44],
        [45, 46, 47, 48]
    ],
    # 样本2-通道2
    [
        [49, 50, 51, 52],
        [53, 54, 55, 56],
        [57, 58, 59, 60],
        [61, 62, 63, 64]
    ]
], dtype=torch.float32).reshape(2, 2, 4, 4)

# 2样本2通道(正确,已验证)
data_2s2d = torch.tensor([
    # 样本
    [
        [1, 2, 3, 4],
        [5, 6, 7, 8],
        [9, 10, 11, 12],
        [13, 14, 15, 16]
    ],
    # 样本
    [
        [17, 18, 19, 20],
        [21, 22, 23, 24],
        [25, 26, 27, 28],
        [29, 30, 31, 32]
    ],
    # 样本
    [
        [33, 34, 35, 36],
        [37, 38, 39, 40],
        [41, 42, 43, 44],
        [45, 46, 47, 48]
    ],
    # 样本
    [
        [49, 50, 51, 52],
        [53, 54, 55, 56],
        [57, 58, 59, 60],
        [61, 62, 63, 64]
    ]
], dtype=torch.float32)

# 2样本2通道(正确,已验证)
data_2s2f = torch.tensor([
    # 样本
    [
        [1, 2, 3, 4],
        [5, 6, 7, 8],
        [9, 10, 11, 12],
        [13, 14, 15, 16]
    ],
    # 样本
    [
        [17, 18, 19, 20],
        [21, 22, 23, 24],
        [25, 26, 27, 28],
        [29, 30, 31, 32]
    ],
    # 样本
    [
        [33, 34, 35, 36],
        [37, 38, 39, 40],
        [41, 42, 43, 44],
        [45, 46, 47, 48]
    ],
    # 样本
    [
        [49, 50, 51, 52],
        [53, 54, 55, 56],
        [57, 58, 59, 60],
        [61, 62, 63, 64]
    ]
], dtype=torch.float32).unsqueeze(1)
print("2样本2通道形状:", data_2s2c.shape)  # 输出:(2,2,4,4)
print("样本通道形状:", data_2s2d.shape)  # 输出:(4,4,4)
print("样本通道形状:", data_2s2f.shape)  # 输出:(4,4,4)

此时新增了data_2s2f,相较于2s2d在最后增加了unsqueeze(1),代码输出的效果为:

样本数和通道数判断

实际上,定义张量的过程中,应当明确知道究竟定义了样本还是通道,因为中括号内部也可以代表通道数,比如继续修改上述代码:

python 复制代码
import torch

# 2样本2通道(正确,已验证)
data_2s2c = torch.tensor([
    # 样本1-通道1(两层括号:通道层+像素层)
    [
        [1, 2, 3, 4],
        [5, 6, 7, 8],
        [9, 10, 11, 12],
        [13, 14, 15, 16]
    ],
    # 样本1-通道2
    [
        [17, 18, 19, 20],
        [21, 22, 23, 24],
        [25, 26, 27, 28],
        [29, 30, 31, 32]
    ],
    # 样本2-通道1
    [
        [33, 34, 35, 36],
        [37, 38, 39, 40],
        [41, 42, 43, 44],
        [45, 46, 47, 48]
    ],
    # 样本2-通道2
    [
        [49, 50, 51, 52],
        [53, 54, 55, 56],
        [57, 58, 59, 60],
        [61, 62, 63, 64]
    ]
], dtype=torch.float32).reshape(2, 2, 4, 4)

# 2样本2通道(正确,已验证)
data_2s2d = torch.tensor([
    # 样本
    [
        [1, 2, 3, 4],
        [5, 6, 7, 8],
        [9, 10, 11, 12],
        [13, 14, 15, 16]
    ],
    # 样本
    [
        [17, 18, 19, 20],
        [21, 22, 23, 24],
        [25, 26, 27, 28],
        [29, 30, 31, 32]
    ],
    # 样本
    [
        [33, 34, 35, 36],
        [37, 38, 39, 40],
        [41, 42, 43, 44],
        [45, 46, 47, 48]
    ],
    # 样本
    [
        [49, 50, 51, 52],
        [53, 54, 55, 56],
        [57, 58, 59, 60],
        [61, 62, 63, 64]
    ]
], dtype=torch.float32)

# 2样本2通道(正确,已验证)
data_2s2f = torch.tensor([
    # 样本
    [
        [1, 2, 3, 4],
        [5, 6, 7, 8],
        [9, 10, 11, 12],
        [13, 14, 15, 16]
    ],
    # 样本
    [
        [17, 18, 19, 20],
        [21, 22, 23, 24],
        [25, 26, 27, 28],
        [29, 30, 31, 32]
    ],
    # 样本
    [
        [33, 34, 35, 36],
        [37, 38, 39, 40],
        [41, 42, 43, 44],
        [45, 46, 47, 48]
    ],
    # 样本
    [
        [49, 50, 51, 52],
        [53, 54, 55, 56],
        [57, 58, 59, 60],
        [61, 62, 63, 64]
    ]
], dtype=torch.float32).unsqueeze(1)

# 2样本2通道(正确,已验证)
data_2s2g = torch.tensor([
    # 样本
    [
        [1, 2, 3, 4],
        [5, 6, 7, 8],
        [9, 10, 11, 12],
        [13, 14, 15, 16]
    ],
    # 样本
    [
        [17, 18, 19, 20],
        [21, 22, 23, 24],
        [25, 26, 27, 28],
        [29, 30, 31, 32]
    ],
    # 样本
    [
        [33, 34, 35, 36],
        [37, 38, 39, 40],
        [41, 42, 43, 44],
        [45, 46, 47, 48]
    ],
    # 样本
    [
        [49, 50, 51, 52],
        [53, 54, 55, 56],
        [57, 58, 59, 60],
        [61, 62, 63, 64]
    ]
], dtype=torch.float32).unsqueeze(0)
print("2样本2通道形状:", data_2s2c.shape)  # 输出:(2,2,4,4)
print("样本通道形状:", data_2s2d.shape)  # 输出:(4,4,4)
print("样本通道形状:", data_2s2f.shape)  # 输出:(4,4,4)
print("样本通道形状:", data_2s2g.shape)  # 输出:(4,4,4)

此时新增了data_2s2g,这个张量使用unsqueeze(0)补充了样本维,输出效果为:

很显然,data_2s2f和data_2s2g的输出完全不一样。

总结

初步学习了CNN的四维Pytorch张量格式的基础知识。

相关推荐
文心快码BaiduComate28 分钟前
百度云与光本位签署战略合作:用AI Agent 重构芯片研发流程
前端·人工智能·架构
风象南1 小时前
Claude Code这个隐藏技能,让我告别PPT焦虑
人工智能·后端
曲幽2 小时前
FastAPI压力测试实战:Locust模拟真实用户并发及优化建议
python·fastapi·web·locust·asyncio·test·uvicorn·workers
Mintopia2 小时前
OpenClaw 对软件行业产生的影响
人工智能
陈广亮2 小时前
构建具有长期记忆的 AI Agent:从设计模式到生产实践
人工智能
会写代码的柯基犬3 小时前
DeepSeek vs Kimi vs Qwen —— AI 生成俄罗斯方块代码效果横评
人工智能·llm
Mintopia3 小时前
OpenClaw 是什么?为什么节后热度如此之高?
人工智能
爱可生开源社区3 小时前
DBA 的未来?八位行业先锋的年度圆桌讨论
人工智能·dba
叁两6 小时前
用opencode打造全自动公众号写作流水线,AI 代笔太香了!
前端·人工智能·agent
敏编程6 小时前
一天一个Python库:jsonschema - JSON 数据验证利器
python