import torch
import torch.nn as nn
import torch.nn.functional as F
class ToyModel(nn.Module):
    """A single fully-connected layer with Xavier-uniform weights and zero bias.

    Args:
        inChannels: ``in_features`` of the wrapped ``nn.Linear``.
        outChannels: ``out_features`` of the wrapped ``nn.Linear``.
    """

    def __init__(self, inChannels, outChannels):
        super().__init__()
        # Plain Python attributes: they are neither parameters nor buffers, so
        # they do not appear in state_dict() and are not saved to a checkpoint.
        self.a1 = 1
        self.a2 = 2
        self.linear = nn.Linear(inChannels, outChannels)
        self.init_weights()

    def init_weights(self):
        """Re-initialise every ``nn.Linear`` in the module tree (Xavier-uniform weight, zero bias)."""
        linear_layers = (mod for mod in self.modules() if isinstance(mod, nn.Linear))
        for layer in linear_layers:
            nn.init.xavier_uniform_(layer.weight)
            nn.init.zeros_(layer.bias)

    def forward(self, x):
        """Return ``self.linear(x)``."""
        return self.linear(x)
class ToyModel2(nn.Module):
    """A fully-connected layer followed by an extra learnable bias vector ``b1``.

    Args:
        inChannels: ``in_features`` of the wrapped ``nn.Linear``.
        outChannels: ``out_features`` of the wrapped ``nn.Linear`` and the length of ``b1``.
    """

    def __init__(self, inChannels, outChannels):
        super().__init__()
        # Plain Python attributes: not in state_dict(), so not checkpointed.
        self.a1 = 1
        self.a2 = 2
        self.linear = nn.Linear(inChannels, outChannels)
        self.init_weights()
        # Wrapping in nn.Parameter registers b1 as a model parameter
        # (requires_grad=True), so it is trained and saved in state_dict().
        # init_weights() only touches nn.Linear modules, so b1 keeps its
        # standard-normal torch.randn initialisation.
        self.b1 = nn.Parameter(torch.randn(outChannels))

    def init_weights(self):
        """Re-initialise every ``nn.Linear`` in the module tree (Xavier-uniform weight, zero bias)."""
        for mod in self.modules():
            if not isinstance(mod, nn.Linear):
                continue
            nn.init.xavier_uniform_(mod.weight)
            nn.init.zeros_(mod.bias)

    def forward(self, x):
        """Return ``self.linear(x) + self.b1`` (``b1`` is added in place)."""
        result = self.linear(x)
        result.add_(self.b1)  # same as ``result += self.b1``
        return result
class ToyModel3(nn.Module):
    """A fully-connected layer plus a learnable bias ``b1`` and a constant buffer ``c1`` of ones.

    Args:
        inChannels: ``in_features`` of the wrapped ``nn.Linear``.
        outChannels: ``out_features`` of the wrapped ``nn.Linear`` and the length of ``b1``/``c1``.
    """

    def __init__(self, inChannels, outChannels):
        super().__init__()
        # Plain Python attributes: not in state_dict(), so not checkpointed.
        self.a1 = 1
        self.a2 = 2
        self.linear = nn.Linear(inChannels, outChannels)
        self.init_weights()
        # Learnable parameter (requires_grad=True), saved in state_dict().
        self.b1 = nn.Parameter(torch.randn(outChannels))
        # Per-instance buffer, not a parameter: torch.ones_like gives a tensor
        # with requires_grad=False, so it is not trained. persistent=True makes
        # it part of state_dict(), so it is saved in the checkpoint. It is meant
        # for constants used in the computation and is read as self.c1.
        self.register_buffer("c1", torch.ones_like(self.b1), persistent=True)

    def init_weights(self):
        """Re-initialise every ``nn.Linear`` in the module tree (Xavier-uniform weight, zero bias)."""
        linears = filter(lambda mod: isinstance(mod, nn.Linear), self.modules())
        for layer in linears:
            nn.init.xavier_uniform_(layer.weight)
            nn.init.zeros_(layer.bias)

    def forward(self, x):
        """Return ``self.linear(x) + self.b1 + self.c1`` (both added in place)."""
        out = self.linear(x)
        out.add_(self.b1).add_(self.c1)  # add_ returns self, so the calls chain
        return out
# ---- next file (Python): build ToyModel3, log its structure, save its state_dict ----
import torch
import torch.nn as nn
import torch.nn.functional as F
import logging
from pathlib import Path
from models import ToyModel2, ToyModel, ToyModel3
# Root logging for this script: INFO and above, with timestamp, logger name,
# level and source line number on every record.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(lineno)s - %(message)s',
)
if __name__ == "__main__":
    savePath = Path("toymodel3.pth")
    logger = logging.getLogger(__name__)

    # Dummy batch: 3 samples x 5 features; the model maps 5 -> 10 channels.
    inp = torch.randn(3, 5)
    features = inp.size(1)
    model = ToyModel3(features, features * 2)

    # NOTE: the f-string "=" specifiers print the expression text itself, so
    # the variable names pred/name/param/buffer are part of the log output.
    pred = model(inp)
    logger.info(f"{pred.size()=}")

    # Log the module tree: the model itself first, then its submodules.
    for module in model.modules():
        logger.info(module)
    for name, param in model.named_parameters():
        logger.info(f"{name = }, {param.size() = }, {param.requires_grad=}")
    for name, buffer in model.named_buffers():
        logger.info(f"{name = }, {buffer.size() = }")

    # Persist only the tensors in state_dict(), not the module object.
    torch.save(model.state_dict(), savePath)
# ---- next file (Python): rebuild ToyModel3 and load the saved state_dict ----
import torch
import torch.nn as nn
import torch.nn.functional as F
from pathlib import Path
from models import ToyModel, ToyModel2, ToyModel3
if __name__ == "__main__":
    savePath = Path("toymodel3.pth")
    inp = torch.randn(3, 5)
    in_features = inp.size(1)

    # Rebuild the same architecture as the save script, then fill in the stored tensors.
    model = ToyModel3(in_features, in_features * 2)
    # weights_only=True restricts unpickling to tensors and plain containers;
    # map_location="cpu" loads every stored tensor onto the CPU.
    state = torch.load(savePath, map_location="cpu", weights_only=True)
    # Strict by default: raises if the checkpoint keys don't match the model's.
    model.load_state_dict(state)

    # NOTE: the f-string "=" specifiers print the expression text itself, so
    # the variable names pred/name/param/buffer are part of the output.
    pred = model(inp)
    print(f"{pred.size()=}")
    for module in model.modules():
        print(module)
    for name, param in model.named_parameters():
        print(f"{name = }, {param.size() = }, {param.requires_grad=}")
    for name, buffer in model.named_buffers():
        print(f"{name = }, {buffer.size() = }")