深度学习(17)卷积层里的多输入多输出通道

1. 多个输入通道

① 核的通道数与输入的通道数一样。

2. 多个输出通道

① 每个输出通道可以匹配图片里面特定的模式。

② 把每个通道里面识别出来的模式组合起来,就得到组合模式识别。

3. 1X1卷积层

4. 二维卷积层

5. 总结

1. 输入与输出(使用自定义)

python 复制代码
# 多输入通道互相关运算
import torch
from d2l import torch as d2l
from torch import nn

# 多通道输入运算
def corr2d_multi_in(X,K):
    return sum(d2l.corr2d(x,k) for x,k in zip(X,K)) # X,K为3通道矩阵,for使得对最外面通道进行遍历        

X = torch.tensor([[[0.0,1.0,2.0],[3.0,4.0,5.0],[6.0,7.0,8.0]],
                  [[1.0,2.0,3.0],[4.0,5.0,6.0],[7.0,8.0,9.0]]])
K = torch.tensor([[[0.0,1.0],[2.0,3.0]],[[1.0,2.0],[3.0,4.0]]])
print(corr2d_multi_in(X,K))

# 多输出通道运算
def corr2d_multi_in_out(X,K):  # X为3通道矩阵,K为4通道矩阵,最外面维为输出通道      
    return torch.stack([corr2d_multi_in(X,k) for k in K],0) # 大k中每个小k是一个3D的Tensor。0表示stack堆叠函数里面在0这个维度堆叠。           

print(K.shape)
print((K+1).shape)
print((K+2).shape)
print(K)
print(K+1)
K = torch.stack((K, K+1, K+2),0) # K与K+1之间的区别为K的每个元素加1
print(K.shape)
print(corr2d_multi_in_out(X,K))
复制代码
tensor([[ 56.,  72.],
        [104., 120.]])
torch.Size([2, 2, 2])
torch.Size([2, 2, 2])
torch.Size([2, 2, 2])
tensor([[[0., 1.],
         [2., 3.]],

        [[1., 2.],
         [3., 4.]]])
tensor([[[1., 2.],
         [3., 4.]],

        [[2., 3.],
         [4., 5.]]])
torch.Size([3, 2, 2, 2])
tensor([[[ 56.,  72.],
         [104., 120.]],

        [[ 76., 100.],
         [148., 172.]],

        [[ 96., 128.],
         [192., 224.]]])
python 复制代码
help(torch.stack)
复制代码
Help on built-in function stack:

stack(...)
    stack(tensors, dim=0, *, out=None) -> Tensor
    
    Concatenates a sequence of tensors along a new dimension.
    
    All tensors need to be of the same size.
    
    Arguments:
        tensors (sequence of Tensors): sequence of tensors to concatenate
        dim (int): dimension to insert. Has to be between 0 and the number
            of dimensions of concatenated tensors (inclusive)
    
    Keyword args:
        out (Tensor, optional): the output tensor.

2. 1X1卷积(使用自定义)

方法1:corr2d_multi_in_out_1x1(矩阵乘法方式)

怎么做:

  1. 把3个通道的3×3图片,每个通道拉平成9个数字

  2. 变成3行9列的矩阵

  3. 用矩阵乘法一次性算出所有结果

比喻: 像做表格,一次性用公式算出所有格子


方法2:corr2d_multi_in_out(普通滑动窗口方式)

怎么做:

  1. 把1×1的卷积核在每个位置滑动

  2. 用一个一个嵌套循环计算

  3. 一个位置算完,再算下一个位置

比喻: 像手动计算,一个格子一个格子算

python 复制代码
# 1×1卷积的多输入、多输出通道运算
def corr2d_multi_in_out_1x1(X,K):
    c_i, h, w = X.shape # 输入的通道数、宽、高
    c_o = K.shape[0]    # 输出的通道数
    X = X.reshape((c_i, h * w)) # 拉平操作,每一行表示一个通道的特征
    K = K.reshape((c_o,c_i)) 
    Y = torch.matmul(K,X) 
    return Y.reshape((c_o, h, w))

X = torch.normal(0,1,(3,3,3))   # norm函数生成0到1之间的(3,3,3)矩阵 
K = torch.normal(0,1,(2,3,1,1)) # 输出通道是2,输入通道是3,核是1X1

Y1 = corr2d_multi_in_out_1x1(X,K)
Y2 = corr2d_multi_in_out(X,K)
assert float(torch.abs(Y1-Y2).sum()) < 1e-6
print(float(torch.abs(Y1-Y2).sum()))
复制代码
0.0

3. 1X1卷积(使用框架)

把 PyTorch 的卷积层(只能吃4D输入)包装成可以直接处理2D图片的函数

python 复制代码
def comp_conv2d(conv2d, X): # conv2d 作为传参传进去,在内部使用
    #(1, 1) + (8, 8) = (1, 1, 8, 8)(批量大小, 通道数, 高, 宽)
    #因为 PyTorch 的 Conv2d 要求输入必须是4维:
    X = X.reshape((1,1)+X.shape) # 在维度前面加入一个通道数和批量大小数
    #输出 Y 的形状(批量, 输出通道, 输出高, 输出宽)
    Y = conv2d(X)  # 卷积处理是一个四维的矩阵
    return Y.reshape(Y.shape[2:]) # 将前面两个维度拿掉

X = torch.rand(size=(8,8))
conv2d = nn.Conv2d(1,1,kernel_size=3,padding=1,stride=2) # Pytorch里面卷积函数的第一个参数为输入通道,第二个参数为输出通道   
print(comp_conv2d(conv2d,X).shape) 

conv2d = nn.Conv2d(1,1,kernel_size=(3,5),padding=(0,1),stride=(3,4)) # 一个稍微复杂的例子
print(comp_conv2d(conv2d,X).shape)
复制代码
torch.Size([4, 4])
torch.Size([2, 2])
相关推荐
Lee川3 小时前
Milvus 实战:当 RAG 遇上向量数据库,从"玩具 Demo"到"生产可用的"那一步
前端·数据库·人工智能
晚烛4 小时前
CANN 调试工具与性能剖析:从日志分析到 NPU 行为追踪的完整调试体系
开发语言·windows·python·深度学习·缓存
小a彤4 小时前
elec-ops-inspection:电力巡检缺陷检测,NPU推理速度提升3倍
人工智能·cann
ZhengEnCi5 小时前
09aaa-LayerNorm是什么?
人工智能
这是谁的博客?5 小时前
AI Agent 安全架构设计:漏洞分析与防护策略深度解析
人工智能·安全·网络安全·ai·agent·安全架构·架构设计
人月神话-Lee5 小时前
【图像处理】Sobel 边缘检测——让机器“看见“轮廓
图像处理·人工智能·计算机视觉·ios·ai编程·swift
冬奇Lab6 小时前
Agent系列(四):工具调用深度解析——Agent 的手和眼
人工智能·llm
Black蜡笔小新6 小时前
自动化AI算法训练服务器DLTM助力医学影像分析进入AI智能分析新时代
人工智能·算法·自动化
冬奇Lab6 小时前
一天一个开源项目(第111篇):Understand Anything - 把代码库变成可探索知识图谱的 AI 引擎
人工智能·开源·llm
猿饵块6 小时前
git--github
人工智能