pytorch 图像的卷积操作

目录

1.卷积核基本参数说明

2.卷积相关操作说明

3.卷积操作示例


1.卷积核基本参数说明

pytorch进行图像卷积操作之前,需要把图像素格式进行分离,比如一个图像为rgb格式,把R,G,B取出来作为一个ndarray,前文讲过,在pytorch中进行图像转Tensor,大小变换,相关处理的库,基本都放在 from torchvision import transforms里面,对于把正常的图像转换为单独的RGB的ndarray,并且归一化,使用 transforms.ToTensor即可一次性完成转换。在训练图像相关模型的时候,主要是训练卷积核的参数,一般的3*3的卷积核结构如代码所示:

python 复制代码
import cv2
import os

import numpy as np
import torch
import torchvision
from torchvision import transforms
from PIL import Image
from torch import nn
from matplotlib import pyplot as plt
from torchvision import transforms
#定义卷积核心,bias为False则不要偏置参数
#输入通道为3,输出通道为1,卷积核大小为3*3,偏置为真
cov = nn.Conv2d(3,1,3,bias=True)
print(cov.state_dict())

'''
OrderedDict([('weight', tensor([[[[ 0.1062,  0.0600, -0.0675],
          [-0.0303,  0.0045, -0.0276],
          [ 0.0114,  0.1434, -0.1323]],

         [[-0.0622, -0.0029, -0.0695],
          [-0.0282, -0.0664, -0.0157],
          [ 0.0037, -0.0900, -0.0588]],

         [[-0.1231, -0.1717,  0.1089],
          [ 0.0051,  0.1269, -0.0846],
          [-0.0662,  0.0817,  0.1689]]]])), ('bias', tensor([0.0631]))])

进程已结束,退出代码为 0
'''
2.卷积相关操作说明

用transforms.ToTensor把图像分为RGB单独通道且归一化后,就可以对图像进行卷积操作,示例代码如图:

python 复制代码
import cv2
import os
import numpy as np
import torch
import torchvision
from torchvision import transforms
from PIL import Image
from torch import nn
from matplotlib import pyplot as plt
from torchvision import transforms

cov = nn.Conv2d(3,1,3,bias=True)
# print(cov.state_dict())
#初始化卷积核所以参数为0.5
for x in cov.parameters():
    nn.init.constant_(x,0.5)

print(cov.state_dict())
d = torch.ones(3,6,6)
d = torch.unsqueeze(d,0)
print(d)
c = cov(d)
print(c)

'''
OrderedDict([('weight', tensor([[[[0.5000, 0.5000, 0.5000],
          [0.5000, 0.5000, 0.5000],
          [0.5000, 0.5000, 0.5000]],

         [[0.5000, 0.5000, 0.5000],
          [0.5000, 0.5000, 0.5000],
          [0.5000, 0.5000, 0.5000]],

         [[0.5000, 0.5000, 0.5000],
          [0.5000, 0.5000, 0.5000],
          [0.5000, 0.5000, 0.5000]]]])), ('bias', tensor([0.5000]))])
tensor([[[[1., 1., 1., 1., 1., 1.],
          [1., 1., 1., 1., 1., 1.],
          [1., 1., 1., 1., 1., 1.],
          [1., 1., 1., 1., 1., 1.],
          [1., 1., 1., 1., 1., 1.],
          [1., 1., 1., 1., 1., 1.]],

         [[1., 1., 1., 1., 1., 1.],
          [1., 1., 1., 1., 1., 1.],
          [1., 1., 1., 1., 1., 1.],
          [1., 1., 1., 1., 1., 1.],
          [1., 1., 1., 1., 1., 1.],
          [1., 1., 1., 1., 1., 1.]],

         [[1., 1., 1., 1., 1., 1.],
          [1., 1., 1., 1., 1., 1.],
          [1., 1., 1., 1., 1., 1.],
          [1., 1., 1., 1., 1., 1.],
          [1., 1., 1., 1., 1., 1.],
          [1., 1., 1., 1., 1., 1.]]]])
tensor([[[[14., 14., 14., 14.],
          [14., 14., 14., 14.],
          [14., 14., 14., 14.],
          [14., 14., 14., 14.]]]], grad_fn=<ConvolutionBackward0>)
'''

从示例代码可以看出,因为我们定义的3通道输入的3*3卷积核心,就生成了3个3*3的核心,3个核心分比对3个通道进行卷积((对应位置直接相乘)然后求和加偏置),得出输出,同理如果定义卷积核输出为三,那么就会定义3*3=9个卷积核每三个卷积核分别对图像进行卷积操作,得出三个输出通道。

3.卷积操作示例

以一张图像为例打开图像,定义卷积核进行卷积操作:

python 复制代码
import cv2
import os
import numpy as np
import torch
import torchvision
from torchvision import transforms
from PIL import Image
from torch import nn
from matplotlib import pyplot as plt
from torchvision import transforms

cov = nn.Conv2d(3,3,3,bias=True)
for x in cov.parameters():
    nn.init.constant_(x,0.05)
print(cov.state_dict())

img = cv2.imread("E:/test/pythonProject/test.jpg")
img = cv2.resize(img,dsize=(320,240))
print('img.shape',img.shape)
trans = transforms.ToTensor()
timg = trans(img)
print('timg.shape',timg.shape)
cimg = cov(timg)
print('cimg.shape',cimg.shape)

timg = timg.permute(1,2,0)
ta = timg.numpy()

cimg = cimg.permute(1,2,0)
ca = cimg.data.numpy()

cv2.imshow("test",img)
cv2.imshow("ta",ta)
cv2.imshow("cimg",ca)

cv2.waitKey()

'''
OrderedDict([('weight', tensor([[[[0.0500, 0.0500, 0.0500],
          [0.0500, 0.0500, 0.0500],
          [0.0500, 0.0500, 0.0500]],

         [[0.0500, 0.0500, 0.0500],
          [0.0500, 0.0500, 0.0500],
          [0.0500, 0.0500, 0.0500]],

         [[0.0500, 0.0500, 0.0500],
          [0.0500, 0.0500, 0.0500],
          [0.0500, 0.0500, 0.0500]]],


        [[[0.0500, 0.0500, 0.0500],
          [0.0500, 0.0500, 0.0500],
          [0.0500, 0.0500, 0.0500]],

         [[0.0500, 0.0500, 0.0500],
          [0.0500, 0.0500, 0.0500],
          [0.0500, 0.0500, 0.0500]],

         [[0.0500, 0.0500, 0.0500],
          [0.0500, 0.0500, 0.0500],
          [0.0500, 0.0500, 0.0500]]],


        [[[0.0500, 0.0500, 0.0500],
          [0.0500, 0.0500, 0.0500],
          [0.0500, 0.0500, 0.0500]],

         [[0.0500, 0.0500, 0.0500],
          [0.0500, 0.0500, 0.0500],
          [0.0500, 0.0500, 0.0500]],

         [[0.0500, 0.0500, 0.0500],
          [0.0500, 0.0500, 0.0500],
          [0.0500, 0.0500, 0.0500]]]])), ('bias', tensor([0.0500, 0.0500, 0.0500]))])
img.shape (240, 320, 3)
timg.shape torch.Size([3, 240, 320])
cimg.shape torch.Size([3, 238, 318])

进程已结束,退出代码为 0
'''

这里定义的卷积核输入为3通道,输出为3通道,这里三组卷积核,每组卷积核包含三个卷积核,三个卷积核分别对三个通道进行卷积,最后每组输出一个通道,三组输出三个通道图像,因为卷积核参数一样,所以最后卷积输出的RGB值相等,输出灰色图像。

这里注意:

复制代码
cimg = cimg.permute(1,2,0)

这个函数是进行维度调换,理解不了,可以先把他转为numpy,再用cv2.merge((r,g,b))函数进行融合,cv2.split(imgt) 可以把图像重新分为 r g b 的numpy.ndarray结构,如代码所示:

python 复制代码
t = cimg.data.numpy()

r = t[0]
g = t[1]
b = t[2]

imgt = cv2.merge((r,g,b))
r,g,b = cv2.split(imgt)
print(r.shape,g.shape,b.shape)

cv2.imshow("imgt",imgt)
cv2.waitKey()

'''
(238, 318) (238, 318) (238, 318)
'''
相关推荐
春末的南方城市23 分钟前
FLUX的ID保持项目也来了! 字节开源PuLID-FLUX-v0.9.0,开启一致性风格写真新纪元!
人工智能·计算机视觉·stable diffusion·aigc·图像生成
zmjia11125 分钟前
AI大语言模型进阶应用及模型优化、本地化部署、从0-1搭建、智能体构建技术
人工智能·语言模型·自然语言处理
jndingxin39 分钟前
OpenCV视频I/O(14)创建和写入视频文件的类:VideoWriter介绍
人工智能·opencv·音视频
_.Switch41 分钟前
Python Web 应用中的 API 网关集成与优化
开发语言·前端·后端·python·架构·log4j
一个闪现必杀技1 小时前
Python入门--函数
开发语言·python·青少年编程·pycharm
AI完全体1 小时前
【AI知识点】偏差-方差权衡(Bias-Variance Tradeoff)
人工智能·深度学习·神经网络·机器学习·过拟合·模型复杂度·偏差-方差
GZ_TOGOGO1 小时前
【2024最新】华为HCIE认证考试流程
大数据·人工智能·网络协议·网络安全·华为
sp_fyf_20241 小时前
计算机前沿技术-人工智能算法-大语言模型-最新研究进展-2024-10-02
人工智能·神经网络·算法·计算机视觉·语言模型·自然语言处理·数据挖掘
新缸中之脑1 小时前
Ollama 运行视觉语言模型LLaVA
人工智能·语言模型·自然语言处理
小鹿( ﹡ˆoˆ﹡ )1 小时前
探索IP协议的神秘面纱:Python中的网络通信
python·tcp/ip·php