自定义层和读写文件

自定义层

自定义一个没有任何参数的层

python 复制代码
import torch
import torch.nn.functional as F
from torch import nn

class CenteredLayer(nn.Module):
	def __init__(self):
		super().__init__()
	
	def forward(self, X):
		return X - X.mean()


layer = CenteredLayer()
layer(torch.FloatTensor([1, 2, 3, 4, 5]))

将层作为组件和冰岛构建更复杂的模型中

python 复制代码
net = nn.Sequential(nn.Linear(8, 128), CenteredLayer())

Y = net(torch.rand(4m 8))
Y.mean()

带参数的层

python 复制代码
class MyLinear(nn.Module):
	def __init__(self, in_units, units):
		super().__init__()
		self.weight = nn.Parameter(torch.randn(in_units, units))
		self.bias = nn.Parameter(torch.randn(units,))
	
	def forward(self, X):
		linear = torch.matmul(X, self.weight.data) + self.bias.data
		return F.relu(linear)

dense = MyLinear(5, 3)
dense.weight

使用自定义的层执行传播计算

python 复制代码
dense(torch.rand(2, 5))

读写文件

python 复制代码
import torch
from torch import nn
from torch.nn import functional as F

x = torch.arange(4)
torch.save(x, 'x-file')
x2 = torch.load('x-file')
x2 == x

存储一个张量列表

python 复制代码
y = torch.zeros(4)
torch.save([x, y], 'x-files')
x2, y2 = torch.load('x-files')

写入或读取字典

python 复制代码
mydict = {'x': x, 'y': y}
torch.save(mydict, 'mydict')
mydict2 = torch.load('mydict')

加载和保存模型参数

python 复制代码
class MLP(nn.Module):
	def __init__(self):
		super().__init__()
		self.hidden = nn.Linear(20, 256)
		self.output = nn.Linear(256, 10)
	
	def forward(self, x):
		return self.output(F.relu(self.hidden(x)))

net = MLP()
X = torch.randn(size=(2, 20))
Y = net(X)

将模型存储为文件

python 复制代码
torch.save(net.state_dict(), 'mlp.params')

# 保存参数后需要我们自己保存MLP的定义, 需要有定义才能加载
clone = MLP()
clone.load_state_dict(torch.load('mlp.params'))
clone.eval()
相关推荐
Li emily5 小时前
港股api接入指南:实时行情与历史数据获取
python·api·fastapi
AI技术增长5 小时前
Pytorch图像去噪实战(十三):DDIM加速扩散模型采样,让去噪从1000步降到50步
人工智能·pytorch·python
刀法如飞5 小时前
Python列表去重:从新手三连到高阶特技,20种解法全收录
python·算法·编程语言
小糖学代码5 小时前
LLM系列:1.python入门:16.正则表达式与文本处理 (re)
人工智能·pytorch·python·深度学习·神经网络·正则表达式
清水白石0086 小时前
从“类型体操”到工程设计:用 Python 解释协变、逆变与不变
网络·windows·python
Ai173163915796 小时前
10大算力芯片某某XXU全解析:CPU/GPU/TPU/NPU/LPU/FPGA/RPU/BPU/DPU/GPGPU
大数据·图像处理·人工智能·深度学习·计算机视觉·自动驾驶·知识图谱
我是大聪明.6 小时前
大模型Tokenizer原理:深入理解BPE与WordPiece子词编码技术
人工智能·深度学习·机器学习
人工智能培训6 小时前
工程科研中的AI应用:结构力学分析技巧
人工智能·深度学习·机器学习·docker·容器
hrhcode6 小时前
【LangGraph】四.持久化:保存和恢复执行状态
python·ai·langchain·agent·langgraph
xxyy8887 小时前
关于labelimg安装后在标注过程中闪退和死机的问题处理
开发语言·python