**前置知识:
nn:neural network神经网络
1、torch.nn:与神经网络有关的库
Containers:torch.nn中的一个模块
Module:所有神经网络模型的基础类(Base class for all neural network modules)
注意:任何自定义的模型都应该继承自 nn.Module
,并实现 __init__
和 forward
方法,以定义模型的结构和前向传播的逻辑
2、x=torch.tensor(1.0):将输入数据转为张量,因为模型期望收到张量(tensor)类型的输入
因为模型的操作(如加法、矩阵乘法等)都是在张量上进行的
张量可以简单理解为一种多维数组,用于表示数据
-
标量(0维张量) :一个单一的数字,比如
5
-
向量(1维张量) :一组数字,比如
[1, 2, 3],
可以看作是一条线上的点 -
矩阵(2维张量):一个数字的表格,比如[ [1, 2], [3, 4] ]
-
更高维的张量(3维及以上) :想象一下一个立方体,里面有许多数字,比如颜色的RGB值。更高维的张量可以表示更复杂的数据结构,比如视频帧、三维图像等
3、forward方法:如何由input计算得到output
forward和__call__的联系:
forward
是你定义的前向传播逻辑,用于计算输出的方法
__call__
是一个特殊方法,用于使得模型实例可以像函数一样被调用,并负责调用 forward
以及处理其他一些功能
所以能让实例像函数一样被调用的实际上是__call__而不是forward
**代码:
自定义新模型:
继承nn.Module基类------>重写__init__方法和forward方法
python
import torch
from torch import nn
class Xigua(nn.Module):
def __init__(self):
super().__init__()
def forward(self,input):
output=input+1
return output
xigua1=Xigua() #先实例化新模型类,才能把它作为工具(一般有__call__方法的都这样做)
x=torch.tensor(1.0)
output=xigua1(x)
print(output)