《动手学深度学习》-48全连接卷积神经网络FCN实现

全连接神经网络通过卷积神经网络CNN实现特征提取,然后通过1x1的卷积将通道数转换为类别个数,最后通过转置卷积层将图像的高宽变换为原输入图的尺寸大小

一、代码

1.构建net

(1)框架

复制代码
pretrained_net=torchvision.models.resnet18(pretrained=True)
# print(list(pretrained_net.children())[-3:])#最后两层为AdaptiveAvgPool2d、Linear去掉
net=nn.Sequential(*list(pretrained_net.children())[:-2])
复制代码
num_classes=21
net.add_module('final_conv',nn.Conv2d(in_features=512, out_features=num_classes,kernel_size=1))
net.add_module('Transposed_conv',nn.ConvTranspose2d(num_classes,num_classes,kernel_size=64,padding=16,stride=32))

(2)初始化

复制代码
def bilinear_kernel(in_channel,out_channel,kernel_size):
    factor=(kernel_size+1)//2 #上采样放大倍数
    if kernel_size %2==1:
        center=factor-1
    else:
        center=factor-0.5
    og=(torch.arange(kernel_size).reshape(-1,1),torch.arange(kernel_size).reshape(1,-1))#og[0]是行向量kx1,ogp[1]列向量1xk,广播之后变成kxk,
    filt=(1-torch.abs(og[0]-center)/factor)*(1-torch.abs(og[1]-center)/factor)#kxk的矩阵,中心大,周围小
    weight=torch.zeros((in_channel,out_channel,kernel_size,kernel_size))
    weight[range(in_channel),range(out_channel),:,:]=filt#让输入通道c只影响同编号C’输出,不进行混合,只改变对角线上的K初始化
    return weight
复制代码
W=bilinear_kernel(num_classes,num_classes,64)
net.Transposed_conv.weight.data.copy_(W)

(3)测试

复制代码
conv_transopsed=nn.ConvTranspose2d(3,3,kernel_size=4,padding=1,stride=2,bias=False)
conv_transopsed.weight.data.copy_(bilinear_kernel(3,3,4))
img=torchvision.transforms.ToTensor()(Image.open('D:/PycharmDocument/limu/data/dogcat.png').convert('RGB'))
X=img.unsqueeze(0)
Y=conv_transopsed(X)
out_img=Y[0].permute(1,2,0).detach()
print('input image shape',img.permute(1,2,0).shape)
print('output image shape',out_img.shape)

d2l.set_figsize()
fig,axes=plt.subplots(1,2)
axes[0].imshow(img.permute(1,2,0))
axes[0].set_title('input image')

axes[1].imshow(out_img)
axes[1].set_title('output image')
d2l.plt.show()

输入一张图,采用conv_transopsed操作,看一下大小,可以看出经过转置卷积,输出图片尺寸大一倍,

2.读取数据

batch_size,crop_size=36,(320,480)

train_iter,test_iter=test46SemanticSegmentation.load_data_voc(batch_size=batch_size,crop_size=crop_size)

voc_dir = 'D:/VOCtrainval_11-May-2012/VOCdevkit/VOC2012'

def read_voc_images(voc_dir, is_train=True):

"""读取所有VOC图像并标注"""

这里代码会自动拼路径:voc_dir + ImageSets + Segmentation + train.txt

txt_fname = os.path.join(voc_dir, 'ImageSets', 'Segmentation',

'train.txt' if is_train else 'val.txt')

mode = torchvision.io.image.ImageReadMode.RGB

with open(txt_fname, 'r') as f:

images = f.read().split()

features, labels = [], []

for i, fname in enumerate(images):

读取原始图片

features.append(torchvision.io.read_image(os.path.join(

voc_dir, 'JPEGImages', f'{fname}.jpg')))

读取语义分割标签图

labels.append(torchvision.io.read_image(os.path.join(

voc_dir, 'SegmentationClass' ,f'{fname}.png'), mode))

return features, labels

3.训练

def loss(inputs,targets):

return F.cross_entropy(inputs,targets,reduction='none').mean(1).mean(1)

num_epochs,lr,wd,device=5,0.01,1e-3,d2l.try_gpu()

trainer=torch.optim.SGD(net.parameters(),lr=lr,weight_decay=wd)

d2l.train_ch3(net,trainer,num_epochs,batch_size,device)

4.预测

def predect(img):

X=test_iter.dataset.normalize_image(img).unsqueeze(0)#(1,3,h,w,)

pred=net(X.to(device)).argmax(dim=1)#(1,h,w)

return pred.reshape(pred.shape[1],pred.shape[2])#(h,w)

#根据类别反向找对应的rgb,将像素点涂对应的颜色

def label2image(pred):

colormap=torch.tensor(test46SemanticSegmentation.VOC_COLORMAP,device=device)

X=pred.long()

return colormap[X,:]

test_images,test_labels=read_voc_images(voc_dir,is_train=False)

n,imags=4,[]

for i in range(n):

crop_rect=(0,0,320,480)

X=torchvision.transforms.functional.crop(test_images[i],*crop_rect)

pred=label2image(predect(X))

imags+=[X.permute(1,2,0),pred.cpu(),torchvision.transforms.functional.crop(test_labels[i],*crop_rect).permute(1,2,0)]

d2l.show_images(imags[::3]+imags[1::3]+imags[2::3],3,n,scale=2)

相关推荐
好奇龙猫2 小时前
【人工智能学习-AI入试相关题目练习-第六次】
人工智能·学习
Piar1231sdafa2 小时前
【深度学习】YOLOv8-SPDConv筷子部件识别与分类系统实战
深度学习·yolo·分类
咚咚王者2 小时前
人工智能之核心基础 机器学习 第十七章 Scikit-learn工具全解析
人工智能·机器学习·scikit-learn
向上的车轮2 小时前
VS Code在AI编辑器关键问题上处理如何?
人工智能·编辑器
沛沛老爹2 小时前
Web开发者进阶AI:企业级Agent Skills安全策略与合规架构实战
前端·人工智能·架构
说私域2 小时前
基于AI客服链动2+1模式商城小程序的社群运营策略研究——以千人社群活跃度提升为例
人工智能·微信·小程序·私域运营
大猫子的技术日记2 小时前
从DALL·E到Seedream:AI文生图技术全景速览与实战指南
人工智能
无bug代码搬运工2 小时前
文献阅读:Class-incremental Learning for Time Series:Benchmark and Evaluation
人工智能·深度学习·transformer
乾元2 小时前
智能化侦察:利用 LLM 进行自动化资产暴露面识别与关联
运维·网络·人工智能·网络协议·安全·自动化