《动手学深度学习》-48全连接卷积神经网络FCN实现

全连接神经网络通过卷积神经网络CNN实现特征提取,然后通过1x1的卷积将通道数转换为类别个数,最后通过转置卷积层将图像的高宽变换为原输入图的尺寸大小

一、代码

1.构建net

(1)框架

复制代码
pretrained_net=torchvision.models.resnet18(pretrained=True)
# print(list(pretrained_net.children())[-3:])#最后两层为AdaptiveAvgPool2d、Linear去掉
net=nn.Sequential(*list(pretrained_net.children())[:-2])
复制代码
num_classes=21
net.add_module('final_conv',nn.Conv2d(in_features=512, out_features=num_classes,kernel_size=1))
net.add_module('Transposed_conv',nn.ConvTranspose2d(num_classes,num_classes,kernel_size=64,padding=16,stride=32))

(2)初始化

复制代码
def bilinear_kernel(in_channel,out_channel,kernel_size):
    factor=(kernel_size+1)//2 #上采样放大倍数
    if kernel_size %2==1:
        center=factor-1
    else:
        center=factor-0.5
    og=(torch.arange(kernel_size).reshape(-1,1),torch.arange(kernel_size).reshape(1,-1))#og[0]是行向量kx1,ogp[1]列向量1xk,广播之后变成kxk,
    filt=(1-torch.abs(og[0]-center)/factor)*(1-torch.abs(og[1]-center)/factor)#kxk的矩阵,中心大,周围小
    weight=torch.zeros((in_channel,out_channel,kernel_size,kernel_size))
    weight[range(in_channel),range(out_channel),:,:]=filt#让输入通道c只影响同编号C’输出,不进行混合,只改变对角线上的K初始化
    return weight
复制代码
W=bilinear_kernel(num_classes,num_classes,64)
net.Transposed_conv.weight.data.copy_(W)

(3)测试

复制代码
conv_transopsed=nn.ConvTranspose2d(3,3,kernel_size=4,padding=1,stride=2,bias=False)
conv_transopsed.weight.data.copy_(bilinear_kernel(3,3,4))
img=torchvision.transforms.ToTensor()(Image.open('D:/PycharmDocument/limu/data/dogcat.png').convert('RGB'))
X=img.unsqueeze(0)
Y=conv_transopsed(X)
out_img=Y[0].permute(1,2,0).detach()
print('input image shape',img.permute(1,2,0).shape)
print('output image shape',out_img.shape)

d2l.set_figsize()
fig,axes=plt.subplots(1,2)
axes[0].imshow(img.permute(1,2,0))
axes[0].set_title('input image')

axes[1].imshow(out_img)
axes[1].set_title('output image')
d2l.plt.show()

输入一张图,采用conv_transopsed操作,看一下大小,可以看出经过转置卷积,输出图片尺寸大一倍,

2.读取数据

batch_size,crop_size=36,(320,480)

train_iter,test_iter=test46SemanticSegmentation.load_data_voc(batch_size=batch_size,crop_size=crop_size)

voc_dir = 'D:/VOCtrainval_11-May-2012/VOCdevkit/VOC2012'

def read_voc_images(voc_dir, is_train=True):

"""读取所有VOC图像并标注"""

这里代码会自动拼路径:voc_dir + ImageSets + Segmentation + train.txt

txt_fname = os.path.join(voc_dir, 'ImageSets', 'Segmentation',

'train.txt' if is_train else 'val.txt')

mode = torchvision.io.image.ImageReadMode.RGB

with open(txt_fname, 'r') as f:

images = f.read().split()

features, labels = [], []

for i, fname in enumerate(images):

读取原始图片

features.append(torchvision.io.read_image(os.path.join(

voc_dir, 'JPEGImages', f'{fname}.jpg')))

读取语义分割标签图

labels.append(torchvision.io.read_image(os.path.join(

voc_dir, 'SegmentationClass' ,f'{fname}.png'), mode))

return features, labels

3.训练

def loss(inputs,targets):

return F.cross_entropy(inputs,targets,reduction='none').mean(1).mean(1)

num_epochs,lr,wd,device=5,0.01,1e-3,d2l.try_gpu()

trainer=torch.optim.SGD(net.parameters(),lr=lr,weight_decay=wd)

d2l.train_ch3(net,trainer,num_epochs,batch_size,device)

4.预测

def predect(img):

X=test_iter.dataset.normalize_image(img).unsqueeze(0)#(1,3,h,w,)

pred=net(X.to(device)).argmax(dim=1)#(1,h,w)

return pred.reshape(pred.shape[1],pred.shape[2])#(h,w)

#根据类别反向找对应的rgb,将像素点涂对应的颜色

def label2image(pred):

colormap=torch.tensor(test46SemanticSegmentation.VOC_COLORMAP,device=device)

X=pred.long()

return colormap[X,:]

test_images,test_labels=read_voc_images(voc_dir,is_train=False)

n,imags=4,[]

for i in range(n):

crop_rect=(0,0,320,480)

X=torchvision.transforms.functional.crop(test_images[i],*crop_rect)

pred=label2image(predect(X))

imags+=[X.permute(1,2,0),pred.cpu(),torchvision.transforms.functional.crop(test_labels[i],*crop_rect).permute(1,2,0)]

d2l.show_images(imags[::3]+imags[1::3]+imags[2::3],3,n,scale=2)

相关推荐
掘金安东尼2 小时前
如何为 AI 编码代理配置 Next.js 项目
人工智能
aircrushin2 小时前
轻量化大模型架构演进
人工智能·架构
文心快码BaiduComate3 小时前
百度云与光本位签署战略合作:用AI Agent 重构芯片研发流程
前端·人工智能·架构
风象南4 小时前
Claude Code这个隐藏技能,让我告别PPT焦虑
人工智能·后端
Mintopia5 小时前
OpenClaw 对软件行业产生的影响
人工智能
陈广亮5 小时前
构建具有长期记忆的 AI Agent:从设计模式到生产实践
人工智能
会写代码的柯基犬5 小时前
DeepSeek vs Kimi vs Qwen —— AI 生成俄罗斯方块代码效果横评
人工智能·llm
Mintopia6 小时前
OpenClaw 是什么?为什么节后热度如此之高?
人工智能
爱可生开源社区6 小时前
DBA 的未来?八位行业先锋的年度圆桌讨论
人工智能·dba
叁两9 小时前
用opencode打造全自动公众号写作流水线,AI 代笔太香了!
前端·人工智能·agent