《动手学深度学习》-48全连接卷积神经网络FCN实现

全连接神经网络通过卷积神经网络CNN实现特征提取,然后通过1x1的卷积将通道数转换为类别个数,最后通过转置卷积层将图像的高宽变换为原输入图的尺寸大小

一、代码

1.构建net

(1)框架

复制代码
pretrained_net=torchvision.models.resnet18(pretrained=True)
# print(list(pretrained_net.children())[-3:])#最后两层为AdaptiveAvgPool2d、Linear去掉
net=nn.Sequential(*list(pretrained_net.children())[:-2])
复制代码
num_classes=21
net.add_module('final_conv',nn.Conv2d(in_features=512, out_features=num_classes,kernel_size=1))
net.add_module('Transposed_conv',nn.ConvTranspose2d(num_classes,num_classes,kernel_size=64,padding=16,stride=32))

(2)初始化

复制代码
def bilinear_kernel(in_channel,out_channel,kernel_size):
    factor=(kernel_size+1)//2 #上采样放大倍数
    if kernel_size %2==1:
        center=factor-1
    else:
        center=factor-0.5
    og=(torch.arange(kernel_size).reshape(-1,1),torch.arange(kernel_size).reshape(1,-1))#og[0]是行向量kx1,ogp[1]列向量1xk,广播之后变成kxk,
    filt=(1-torch.abs(og[0]-center)/factor)*(1-torch.abs(og[1]-center)/factor)#kxk的矩阵,中心大,周围小
    weight=torch.zeros((in_channel,out_channel,kernel_size,kernel_size))
    weight[range(in_channel),range(out_channel),:,:]=filt#让输入通道c只影响同编号C’输出,不进行混合,只改变对角线上的K初始化
    return weight
复制代码
W=bilinear_kernel(num_classes,num_classes,64)
net.Transposed_conv.weight.data.copy_(W)

(3)测试

复制代码
conv_transopsed=nn.ConvTranspose2d(3,3,kernel_size=4,padding=1,stride=2,bias=False)
conv_transopsed.weight.data.copy_(bilinear_kernel(3,3,4))
img=torchvision.transforms.ToTensor()(Image.open('D:/PycharmDocument/limu/data/dogcat.png').convert('RGB'))
X=img.unsqueeze(0)
Y=conv_transopsed(X)
out_img=Y[0].permute(1,2,0).detach()
print('input image shape',img.permute(1,2,0).shape)
print('output image shape',out_img.shape)

d2l.set_figsize()
fig,axes=plt.subplots(1,2)
axes[0].imshow(img.permute(1,2,0))
axes[0].set_title('input image')

axes[1].imshow(out_img)
axes[1].set_title('output image')
d2l.plt.show()

输入一张图,采用conv_transopsed操作,看一下大小,可以看出经过转置卷积,输出图片尺寸大一倍,

2.读取数据

batch_size,crop_size=36,(320,480)

train_iter,test_iter=test46SemanticSegmentation.load_data_voc(batch_size=batch_size,crop_size=crop_size)

voc_dir = 'D:/VOCtrainval_11-May-2012/VOCdevkit/VOC2012'

def read_voc_images(voc_dir, is_train=True):

"""读取所有VOC图像并标注"""

这里代码会自动拼路径:voc_dir + ImageSets + Segmentation + train.txt

txt_fname = os.path.join(voc_dir, 'ImageSets', 'Segmentation',

'train.txt' if is_train else 'val.txt')

mode = torchvision.io.image.ImageReadMode.RGB

with open(txt_fname, 'r') as f:

images = f.read().split()

features, labels = [], []

for i, fname in enumerate(images):

读取原始图片

features.append(torchvision.io.read_image(os.path.join(

voc_dir, 'JPEGImages', f'{fname}.jpg')))

读取语义分割标签图

labels.append(torchvision.io.read_image(os.path.join(

voc_dir, 'SegmentationClass' ,f'{fname}.png'), mode))

return features, labels

3.训练

def loss(inputs,targets):

return F.cross_entropy(inputs,targets,reduction='none').mean(1).mean(1)

num_epochs,lr,wd,device=5,0.01,1e-3,d2l.try_gpu()

trainer=torch.optim.SGD(net.parameters(),lr=lr,weight_decay=wd)

d2l.train_ch3(net,trainer,num_epochs,batch_size,device)

4.预测

def predect(img):

X=test_iter.dataset.normalize_image(img).unsqueeze(0)#(1,3,h,w,)

pred=net(X.to(device)).argmax(dim=1)#(1,h,w)

return pred.reshape(pred.shape[1],pred.shape[2])#(h,w)

#根据类别反向找对应的rgb,将像素点涂对应的颜色

def label2image(pred):

colormap=torch.tensor(test46SemanticSegmentation.VOC_COLORMAP,device=device)

X=pred.long()

return colormap[X,:]

test_images,test_labels=read_voc_images(voc_dir,is_train=False)

n,imags=4,[]

for i in range(n):

crop_rect=(0,0,320,480)

X=torchvision.transforms.functional.crop(test_images[i],*crop_rect)

pred=label2image(predect(X))

imags+=[X.permute(1,2,0),pred.cpu(),torchvision.transforms.functional.crop(test_labels[i],*crop_rect).permute(1,2,0)]

d2l.show_images(imags[::3]+imags[1::3]+imags[2::3],3,n,scale=2)

相关推荐
CoovallyAIHub9 分钟前
Energies | 8版YOLO对8版Transformer实测光伏缺陷检测,RF-DETR-Small综合胜出
深度学习·算法·计算机视觉
Kel36 分钟前
深入剖析 openai-node 源码:一个工业级 TypeScript SDK 的架构之美
javascript·人工智能·架构
岛雨QA1 小时前
Skill学习指南🧑‍💻
人工智能·agent·ai编程
zh路西法1 小时前
【宇树机器人强化学习】(七):复杂地形的生成与训练
python·深度学习·机器学习·机器人
波动几何1 小时前
从人性到无名:一条向内的觉悟之路
人工智能
EllenLiu1 小时前
架构演进与性能压榨:在金融 RAG 中引入条款森林 (FoC)
人工智能·架构
IT_陈寒1 小时前
深入理解JavaScript:核心原理与最佳实践
前端·人工智能·后端
Presto1 小时前
AI 时代 .env 文件不再安全——我试图找到替代方案,然后撞上了一堵墙
人工智能
IT WorryFree1 小时前
OpenClaw-Medical-Skills 仓库介绍
人工智能·skill·openclaw
多年小白2 小时前
今日AI科技简报 | 2026年3月19日
人工智能·科技·ai编程