Pytorch个人学习记录总结 05

目录

神经网络的基本骨架

[卷积操作 torch.nn.functional.conv2d](#卷积操作 torch.nn.functional.conv2d)


神经网络的基本骨架

搭建Neural Network骨架主要用到的包是torch.nn,官方文档网址:torch.nn --- PyTorch 2.0 documentation,其中torch.nn.Module很重要,是所有所有神经网络模块的基类(即自己搭建的网络必须继承torch.nn.Module 基类),官方文档地址:Module --- PyTorch 2.0 documentation

搭建模型时,集成torch.nn.Module后必须要重写两个函数:__init__()forward()

python 复制代码
import torch.nn as nn
import torch.nn.functional as F

class Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 20, 5)
        self.conv2 = nn.Conv2d(20, 20, 5)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        return F.relu(self.conv2(x))

卷积操作 torch.nn.functional.conv2d

torch.nn包含了torch.nn.functional,两者中都包含了Conv、Pool等层操作,且用法和效果都是一样的(但是具体的输入参数有所不同)。用的torch.nn.functional.conv2d举例,但其实在以后使用中,torch.nn.Conv2d更常用。

python 复制代码
torch.nn.functional.conv2d(input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1) → Tensor
python 复制代码
CLASS torch.nn.Conv2d(in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True, padding_mode='zeros', device=None, dtype=None)

torch.nn.functional.conv2d中的Input、weight(也就是kernel)都必须是4维张量,每维的含义是[batch_size, C, H, W],必要的时候,可用reshape()或unsqueeze()对张量进行扩维。

(1) reshape是对改变tensor的形状,各维度的乘积与原本保持一致。

(2) unsqueeze是在指定维度上扩充一个1维。

python 复制代码
import torch

x = torch.arange(15)
x2 = torch.reshape(x, [3, 5])	# 用list或tuple表示形状都可以
y1_reshape = torch.reshape(x, [1, 1, 3, 5])  # reshape:只要所有维度乘在一起的积不变,就可以任意扩充多个维度
y2_unsqueeze = torch.unsqueeze(x2, 2)	# unsequeeze:第二个参数的数据类型是int,所以只能在指定维度上扩充一个1维(升维)
c_squeeze = torch.squeeze(y1_reshape)	# sequeeze:只传入一个tensor参数,然后将tensor的所有1维删掉(降维)

print('x.shape:{}'.format(x.shape))
print('x2.shape:{}'.format(x2.shape))
print('y1_reshape.shape:{}'.format(y1_reshape.shape))
print('y2_unsqueeze.shape:{}'.format(y2_unsqueeze.shape))
print('c_squeeze.shape:{}'.format(c_squeeze.shape))
python 复制代码
import torch
import torch.nn.functional as F

input = torch.tensor([[1, 2, 0, 3, 1],
                      [0, 1, 2, 3, 1],
                      [1, 2, 1, 0, 0],
                      [5, 2, 3, 1, 1],
                      [2, 1, 0, 1, 1]])
kernel = torch.tensor([[1, 2, 1],
                       [0, 1, 0],
                       [2, 1, 0]])

print(input.shape)
print(kernel.shape)

# input、kernel都扩充到4维
input = torch.reshape(input, (1, 1, 5, 5))
kernel = torch.reshape(kernel, (1, 1, 3, 3))

out = F.conv2d(input, kernel, stride=1)
print('out={}'.format(out))

out2 = F.conv2d(input, kernel, stride=2)
print('out2={}'.format(out2))

out3 = F.conv2d(input, kernel, stride=1, padding=1)
print('out3={}'.format(out3))
相关推荐
张铁铁是个小胖子14 分钟前
微服务学习
java·学习·微服务
sp_fyf_20241 小时前
【大语言模型】ACL2024论文-35 WAV2GLOSS:从语音生成插值注解文本
人工智能·深度学习·神经网络·机器学习·语言模型·自然语言处理·数据挖掘
AITIME论道1 小时前
论文解读 | EMNLP2024 一种用于大语言模型版本更新的学习率路径切换训练范式
人工智能·深度学习·学习·机器学习·语言模型
Dovir多多2 小时前
Python数据处理——re库与pydantic的使用总结与实战,处理采集到的思科ASA防火墙设备信息
网络·python·计算机网络·安全·网络安全·数据分析
明明真系叻2 小时前
第二十六周机器学习笔记:PINN求正反解求PDE文献阅读——正问题
人工智能·笔记·深度学习·机器学习·1024程序员节
XianxinMao3 小时前
Transformer 架构对比:Dense、MoE 与 Hybrid-MoE 的优劣分析
深度学习·架构·transformer
88号技师4 小时前
2024年12月一区SCI-加权平均优化算法Weighted average algorithm-附Matlab免费代码
人工智能·算法·matlab·优化算法
IT猿手4 小时前
多目标应用(一):多目标麋鹿优化算法(MOEHO)求解10个工程应用,提供完整MATLAB代码
开发语言·人工智能·算法·机器学习·matlab
青春男大4 小时前
java栈--数据结构
java·开发语言·数据结构·学习·eclipse
88号技师4 小时前
几款性能优秀的差分进化算法DE(SaDE、JADE,SHADE,LSHADE、LSHADE_SPACMA、LSHADE_EpSin)-附Matlab免费代码
开发语言·人工智能·算法·matlab·优化算法