《动手学深度学习》-48全连接卷积神经网络FCN实现

全连接神经网络通过卷积神经网络CNN实现特征提取,然后通过1x1的卷积将通道数转换为类别个数,最后通过转置卷积层将图像的高宽变换为原输入图的尺寸大小

一、代码

1.构建net

(1)框架

复制代码
pretrained_net=torchvision.models.resnet18(pretrained=True)
# print(list(pretrained_net.children())[-3:])#最后两层为AdaptiveAvgPool2d、Linear去掉
net=nn.Sequential(*list(pretrained_net.children())[:-2])
复制代码
num_classes=21
net.add_module('final_conv',nn.Conv2d(in_features=512, out_features=num_classes,kernel_size=1))
net.add_module('Transposed_conv',nn.ConvTranspose2d(num_classes,num_classes,kernel_size=64,padding=16,stride=32))

(2)初始化

复制代码
def bilinear_kernel(in_channel,out_channel,kernel_size):
    factor=(kernel_size+1)//2 #上采样放大倍数
    if kernel_size %2==1:
        center=factor-1
    else:
        center=factor-0.5
    og=(torch.arange(kernel_size).reshape(-1,1),torch.arange(kernel_size).reshape(1,-1))#og[0]是行向量kx1,ogp[1]列向量1xk,广播之后变成kxk,
    filt=(1-torch.abs(og[0]-center)/factor)*(1-torch.abs(og[1]-center)/factor)#kxk的矩阵,中心大,周围小
    weight=torch.zeros((in_channel,out_channel,kernel_size,kernel_size))
    weight[range(in_channel),range(out_channel),:,:]=filt#让输入通道c只影响同编号C’输出,不进行混合,只改变对角线上的K初始化
    return weight
复制代码
W=bilinear_kernel(num_classes,num_classes,64)
net.Transposed_conv.weight.data.copy_(W)

(3)测试

复制代码
conv_transopsed=nn.ConvTranspose2d(3,3,kernel_size=4,padding=1,stride=2,bias=False)
conv_transopsed.weight.data.copy_(bilinear_kernel(3,3,4))
img=torchvision.transforms.ToTensor()(Image.open('D:/PycharmDocument/limu/data/dogcat.png').convert('RGB'))
X=img.unsqueeze(0)
Y=conv_transopsed(X)
out_img=Y[0].permute(1,2,0).detach()
print('input image shape',img.permute(1,2,0).shape)
print('output image shape',out_img.shape)

d2l.set_figsize()
fig,axes=plt.subplots(1,2)
axes[0].imshow(img.permute(1,2,0))
axes[0].set_title('input image')

axes[1].imshow(out_img)
axes[1].set_title('output image')
d2l.plt.show()

输入一张图,采用conv_transopsed操作,看一下大小,可以看出经过转置卷积,输出图片尺寸大一倍,

2.读取数据

batch_size,crop_size=36,(320,480)

train_iter,test_iter=test46SemanticSegmentation.load_data_voc(batch_size=batch_size,crop_size=crop_size)

voc_dir = 'D:/VOCtrainval_11-May-2012/VOCdevkit/VOC2012'

def read_voc_images(voc_dir, is_train=True):

"""读取所有VOC图像并标注"""

这里代码会自动拼路径:voc_dir + ImageSets + Segmentation + train.txt

txt_fname = os.path.join(voc_dir, 'ImageSets', 'Segmentation',

'train.txt' if is_train else 'val.txt')

mode = torchvision.io.image.ImageReadMode.RGB

with open(txt_fname, 'r') as f:

images = f.read().split()

features, labels = [], []

for i, fname in enumerate(images):

读取原始图片

features.append(torchvision.io.read_image(os.path.join(

voc_dir, 'JPEGImages', f'{fname}.jpg')))

读取语义分割标签图

labels.append(torchvision.io.read_image(os.path.join(

voc_dir, 'SegmentationClass' ,f'{fname}.png'), mode))

return features, labels

3.训练

def loss(inputs,targets):

return F.cross_entropy(inputs,targets,reduction='none').mean(1).mean(1)

num_epochs,lr,wd,device=5,0.01,1e-3,d2l.try_gpu()

trainer=torch.optim.SGD(net.parameters(),lr=lr,weight_decay=wd)

d2l.train_ch3(net,trainer,num_epochs,batch_size,device)

4.预测

def predect(img):

X=test_iter.dataset.normalize_image(img).unsqueeze(0)#(1,3,h,w,)

pred=net(X.to(device)).argmax(dim=1)#(1,h,w)

return pred.reshape(pred.shape[1],pred.shape[2])#(h,w)

#根据类别反向找对应的rgb,将像素点涂对应的颜色

def label2image(pred):

colormap=torch.tensor(test46SemanticSegmentation.VOC_COLORMAP,device=device)

X=pred.long()

return colormap[X,:]

test_images,test_labels=read_voc_images(voc_dir,is_train=False)

n,imags=4,[]

for i in range(n):

crop_rect=(0,0,320,480)

X=torchvision.transforms.functional.crop(test_images[i],*crop_rect)

pred=label2image(predect(X))

imags+=[X.permute(1,2,0),pred.cpu(),torchvision.transforms.functional.crop(test_labels[i],*crop_rect).permute(1,2,0)]

d2l.show_images(imags[::3]+imags[1::3]+imags[2::3],3,n,scale=2)

相关推荐
摘星编程几秒前
RAG重塑搜索:如何用检索增强生成打造企业级AI问答系统
人工智能
啊阿狸不会拉杆1 分钟前
《机器学习导论》第 9 章-决策树
人工智能·python·算法·决策树·机器学习·数据挖掘·剪枝
曦月逸霜6 分钟前
机器学习——个人笔记(持续更新中~)
人工智能·机器学习
新缸中之脑8 分钟前
30个最好的3D相关AI代理技能
人工智能·3d
Pyeako9 分钟前
opencv计算机视觉--LBPH&EigenFace&FisherFace人脸识别
人工智能·python·opencv·计算机视觉·lbph·eigenface·fisherface
工程师老罗11 分钟前
举例说明YOLOv1 输出坐标到原图像素的映射关系
人工智能·yolo·计算机视觉
猫头虎12 分钟前
手动部署开源OpenClaw汉化中文版过程中常见问题排查手册
人工智能·langchain·开源·github·aigc·agi·openclaw
多恩Stone14 分钟前
【3D AICG 系列-9】Trellis2 推理流程图超详细介绍
人工智能·python·算法·3d·aigc·流程图
整得咔咔响15 分钟前
贝尔曼最优公式(BOE)
人工智能·算法·机器学习
2501_9469614718 分钟前
极简大气创业融资 PPT 模板,适合路演、项目宣讲
人工智能·排序算法