PyTorch:6-可视化
注:所有资料来源且归属于thorough-pytorch(https://datawhalechina.github.io/thorough-pytorch/),下文仅为学习记录
6.1:可视化网络结构
Keras中可以调用model.summary()
的API进行模型参数可视化
torchinfo
是由torchsummary
和torchsummaryX
重构出的库,用于可视化网络结构
6.1.1:使用print函数,打印模型基础信息
【案例:resnet18】
模型构建:
python
import torchvision.models as models
model = models.resnet18()
直接print模型:只能得出基础构件的信息
ResNet(
(conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
(bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
(maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
(layer1): Sequential(
(0): Bottleneck(
(conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
(downsample): Sequential(
(0): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
(1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
)
... ...
)
(avgpool): AdaptiveAvgPool2d(output_size=(1, 1))
(fc): Linear(in_features=2048, out_features=1000, bias=True)
)
结果:既不能显示出每一层的shape,也不能显示对应参数量的大小。
6.1.2:使用torchinfo,可视化网络结构
安装:
bash
# 安装方法一
pip install torchinfo
# 安装方法二
conda install -c conda-forge torchinfo
使用:
使用torchinfo.summary()
函数,必需的参数分别是model,input_size[batch_size,channel,h,w]。
python
import torchvision.models as models
from torchinfo import summary
resnet18 = models.resnet18()
# 实例化模型
summary(resnet18, (1, 3, 224, 224))
# 1:batch_size 3:图片的通道数 224: 图片的高宽
结构化输出:
bash
=========================================================================================
Layer (type:depth-idx) Output Shape Param #
=========================================================================================
ResNet -- --
├─Conv2d: 1-1 [1, 64, 112, 112] 9,408
├─BatchNorm2d: 1-2 [1, 64, 112, 112] 128
├─ReLU: 1-3 [1, 64, 112, 112] --
├─MaxPool2d: 1-4 [1, 64, 56, 56] --
├─Sequential: 1-5 [1, 64, 56, 56] --
│ └─BasicBlock: 2-1 [1, 64, 56, 56] --
│ │ └─Conv2d: 3-1 [1, 64, 56, 56] 36,864
│ │ └─BatchNorm2d: 3-2 [1, 64, 56, 56] 128
│ │ └─ReLU: 3-3 [1, 64, 56, 56] --
│ │ └─Conv2d: 3-4 [1, 64, 56, 56] 36,864
│ │ └─BatchNorm2d: 3-5 [1, 64, 56, 56] 128
│ │ └─ReLU: 3-6 [1, 64, 56, 56] --
│ └─BasicBlock: 2-2 [1, 64, 56, 56] --
│ │ └─Conv2d: 3-7 [1, 64, 56, 56] 36,864
│ │ └─BatchNorm2d: 3-8 [1, 64, 56, 56] 128
│ │ └─ReLU: 3-9 [1, 64, 56, 56] --
│ │ └─Conv2d: 3-10 [1, 64, 56, 56] 36,864
│ │ └─BatchNorm2d: 3-11 [1, 64, 56, 56] 128
│ │ └─ReLU: 3-12 [1, 64, 56, 56] --
├─Sequential: 1-6 [1, 128, 28, 28] --
│ └─BasicBlock: 2-3 [1, 128, 28, 28] --
│ │ └─Conv2d: 3-13 [1, 128, 28, 28] 73,728
│ │ └─BatchNorm2d: 3-14 [1, 128, 28, 28] 256
│ │ └─ReLU: 3-15 [1, 128, 28, 28] --
│ │ └─Conv2d: 3-16 [1, 128, 28, 28] 147,456
│ │ └─BatchNorm2d: 3-17 [1, 128, 28, 28] 256
│ │ └─Sequential: 3-18 [1, 128, 28, 28] 8,448
│ │ └─ReLU: 3-19 [1, 128, 28, 28] --
│ └─BasicBlock: 2-4 [1, 128, 28, 28] --
│ │ └─Conv2d: 3-20 [1, 128, 28, 28] 147,456
│ │ └─BatchNorm2d: 3-21 [1, 128, 28, 28] 256
│ │ └─ReLU: 3-22 [1, 128, 28, 28] --
│ │ └─Conv2d: 3-23 [1, 128, 28, 28] 147,456
│ │ └─BatchNorm2d: 3-24 [1, 128, 28, 28] 256
│ │ └─ReLU: 3-25 [1, 128, 28, 28] --
├─Sequential: 1-7 [1, 256, 14, 14] --
│ └─BasicBlock: 2-5 [1, 256, 14, 14] --
│ │ └─Conv2d: 3-26 [1, 256, 14, 14] 294,912
│ │ └─BatchNorm2d: 3-27 [1, 256, 14, 14] 512
│ │ └─ReLU: 3-28 [1, 256, 14, 14] --
│ │ └─Conv2d: 3-29 [1, 256, 14, 14] 589,824
│ │ └─BatchNorm2d: 3-30 [1, 256, 14, 14] 512
│ │ └─Sequential: 3-31 [1, 256, 14, 14] 33,280
│ │ └─ReLU: 3-32 [1, 256, 14, 14] --
│ └─BasicBlock: 2-6 [1, 256, 14, 14] --
│ │ └─Conv2d: 3-33 [1, 256, 14, 14] 589,824
│ │ └─BatchNorm2d: 3-34 [1, 256, 14, 14] 512
│ │ └─ReLU: 3-35 [1, 256, 14, 14] --
│ │ └─Conv2d: 3-36 [1, 256, 14, 14] 589,824
│ │ └─BatchNorm2d: 3-37 [1, 256, 14, 14] 512
│ │ └─ReLU: 3-38 [1, 256, 14, 14] --
├─Sequential: 1-8 [1, 512, 7, 7] --
│ └─BasicBlock: 2-7 [1, 512, 7, 7] --
│ │ └─Conv2d: 3-39 [1, 512, 7, 7] 1,179,648
│ │ └─BatchNorm2d: 3-40 [1, 512, 7, 7] 1,024
│ │ └─ReLU: 3-41 [1, 512, 7, 7] --
│ │ └─Conv2d: 3-42 [1, 512, 7, 7] 2,359,296
│ │ └─BatchNorm2d: 3-43 [1, 512, 7, 7] 1,024
│ │ └─Sequential: 3-44 [1, 512, 7, 7] 132,096
│ │ └─ReLU: 3-45 [1, 512, 7, 7] --
│ └─BasicBlock: 2-8 [1, 512, 7, 7] --
│ │ └─Conv2d: 3-46 [1, 512, 7, 7] 2,359,296
│ │ └─BatchNorm2d: 3-47 [1, 512, 7, 7] 1,024
│ │ └─ReLU: 3-48 [1, 512, 7, 7] --
│ │ └─Conv2d: 3-49 [1, 512, 7, 7] 2,359,296
│ │ └─BatchNorm2d: 3-50 [1, 512, 7, 7] 1,024
│ │ └─ReLU: 3-51 [1, 512, 7, 7] --
├─AdaptiveAvgPool2d: 1-9 [1, 512, 1, 1] --
├─Linear: 1-10 [1, 1000] 513,000
=========================================================================================
Total params: 11,689,512
Trainable params: 11,689,512
Non-trainable params: 0
Total mult-adds (G): 1.81
=========================================================================================
Input size (MB): 0.60
Forward/backward pass size (MB): 39.75
Params size (MB): 46.76
Estimated Total Size (MB): 87.11
=========================================================================================
注意:使用colab或者jupyter notebook时,想要实现该方法,summary()
一定是该单元(即notebook中的cell)的返回值,否则就需要使用print(summary(...))
来可视化。
6.2:CNN可视化
可视化内容:可视化特征是如何提取的、提取到的特征的形式、模型在输入数据上的关注点
6.2.1:CNN卷积核可视化
卷积核在CNN中负责提取特征------可视化特征是如何提取的
靠近输入的层提取的特征是相对简单的结构,靠近输出的层提取的特征和图中的实体形状相近
kernel可视化的核心:特定层的卷积核即特定层的模型权重,可视化卷积核即可视化对应的权重矩阵
【案例:VGG11】
【1】加载模型,确定层信息
python
import torch
from torchvision.models import vgg11
model = vgg11(pretrained=True)
print(dict(model.features.named_children()))
"""
{'0': Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
'1': ReLU(inplace=True),
'2': MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False),
'3': Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
'4': ReLU(inplace=True),
'5': MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False),
'6': Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
'7': ReLU(inplace=True),
'8': Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
'9': ReLU(inplace=True),
'10': MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False),
'11': Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
'12': ReLU(inplace=True),
'13': Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
'14': ReLU(inplace=True),
'15': MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False),
'16': Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
'17': ReLU(inplace=True),
'18': Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
'19': ReLU(inplace=True),
'20': MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)}
"""
【2】可视化卷积层的对应参数(第3层)
卷积核对应的应为卷积层(Conv2d)
python
conv1 = dict(model.features.named_children())['3']
kernel_set = conv1.weight.detach()
num = len(conv1.weight.detach())
print(kernel_set.shape)
"""
torch.Size([128, 64, 3, 3])
"""
for i in range(0,num):
i_kernel = kernel_set[i]
plt.figure(figsize=(20, 17))
if (len(i_kernel)) > 1:
for idx, filer in enumerate(i_kernel):
plt.subplot(9, 9, idx+1)
plt.axis('off')
plt.imshow(filer[ :, :].detach(),cmap='bwr')
由于第3层的特征图由64维变为128维,因此共有128*64个卷积核
6.2.2:CNN特征图可视化
特征图:输入的原始图像经过每次卷积层得到的数据
可视化卷积核是为了看模型提取哪些特征,可视化特征图则是为了看模型提取到的特征是什么样子的。
PyTorch提供了一个专用的接口,使得网络在前向传播过程中能够获取到特征图,接口的名称叫hook。
实现过程:
python
class Hook(object):
def __init__(self):
self.module_name = []
self.features_in_hook = []
self.features_out_hook = []
def __call__(self,module, fea_in, fea_out):
print("hooker working", self)
self.module_name.append(module.__class__)
self.features_in_hook.append(fea_in)
self.features_out_hook.append(fea_out)
return None
def plot_feature(model, idx, inputs):
hh = Hook()
model.features[idx].register_forward_hook(hh)
# forward_model(model,False)
model.eval()
_ = model(inputs)
print(hh.module_name)
print((hh.features_in_hook[0][0].shape))
print((hh.features_out_hook[0].shape))
out1 = hh.features_out_hook[0]
total_ft = out1.shape[1]
first_item = out1[0].cpu().clone()
plt.figure(figsize=(20, 17))
for ftidx in range(total_ft):
if ftidx > 99:
break
ft = first_item[ftidx]
plt.subplot(10, 10, ftidx+1)
plt.axis('off')
#plt.imshow(ft[ :, :].detach(),cmap='gray')
plt.imshow(ft[ :, :].detach())
首先实现了一个hook类,之后在plot_feature函数中,将该hook类的对象注册到要进行可视化的网络的某层中。
model在进行前向传播的时候会调用hook的__call__函数,Hook类在此处存储了当前层的输入和输出。
Hook类种的hook(输入为in,输出为out)是一个list,每次前向传播一次,都是调用一次,即 hook 长度会增加1。
6.2.3:CNN class activation map可视化
class activation map (CAM)的作用是判断哪些变量对模型来说是重要的。
在CNN可视化的场景下,即判断图像中哪些像素点对预测结果是重要的。
CAM系列操作的实现可以通过开源工具包pytorch-grad-cam
来实现。
- 安装:
bash
pip install grad-cam
- 案例:
加载图片
python
import torch
from torchvision.models import vgg11,resnet18,resnet101,resnext101_32x8d
import matplotlib.pyplot as plt
from PIL import Image
import numpy as np
model = vgg11(pretrained=True)
img_path = './dog.png'
# resize操作是为了和传入神经网络训练图片大小一致
img = Image.open(img_path).resize((224,224))
# 需要将原始图片转为np.float32格式并且在0-1之间
rgb_img = np.float32(img)/255
plt.imshow(img)
CAM可视化
python
from pytorch_grad_cam import GradCAM,ScoreCAM,GradCAMPlusPlus,AblationCAM,XGradCAM,EigenCAM,FullGrad
from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
from pytorch_grad_cam.utils.image import show_cam_on_image
# 将图片转为tensor
img_tensor = torch.from_numpy(rgb_img).permute(2,0,1).unsqueeze(0)
target_layers = [model.features[-1]]
# 选取合适的类激活图,但是ScoreCAM和AblationCAM需要batch_size
cam = GradCAM(model=model,target_layers=target_layers)
targets = [ClassifierOutputTarget(preds)]
# 上方preds需要设定,比如ImageNet有1000类,这里可以设为200
grayscale_cam = cam(input_tensor=img_tensor, targets=targets)
grayscale_cam = grayscale_cam[0, :]
cam_img = show_cam_on_image(rgb_img, grayscale_cam, use_rgb=True)
print(type(cam_img))
Image.fromarray(cam_img)
6.2.4:FlashTorch快速实现CNN可视化
https://github.com/MisaOgura/flashtorch
- 安装
bash
pip install flashtorch
- 可视化梯度
python
import matplotlib.pyplot as plt
import torchvision.models as models
from flashtorch.utils import apply_transforms, load_image
from flashtorch.saliency import Backprop
model = models.alexnet(pretrained=True)
backprop = Backprop(model)
image = load_image('/content/images/great_grey_owl.jpg')
owl = apply_transforms(image)
target_class = 24
backprop.visualize(owl, target_class, guided=True, use_gpu=True)
- 可视化卷积核
python
import torchvision.models as models
from flashtorch.activmax import GradientAscent
model = models.vgg16(pretrained=True)
g_ascent = GradientAscent(model.features)
# specify layer and filter info
conv5_1 = model.features[24]
conv5_1_filters = [45, 271, 363, 489]
g_ascent.visualize(conv5_1, conv5_1_filters, title="VGG16: conv5_1")
6.3:使用TensorBoard可视化训练过程
6.3.1:安装
使用pip安装:
pip install tensorboardX
6.3.2:TensorBoard可视化的基本逻辑
可将TensorBoard看做一个记录员,记录我们指定的数据,包括模型每一层的feature map,权重,训练loss等。
TensorBoard将记录下来的内容保存在一个用户指定的文件夹里,程序不断运行中TensorBoard会不断记录,记录下的内容可以通过网页的形式加以可视化。
6.3.3:TensorBoard的配置和启动
【1】指定保存记录数据的文件夹,调用tensorboard中的SummaryWriter作为记录员
python
from tensorboardX import SummaryWriter
# from torch.utils.tensorboard import SummaryWriter
# 使用PyTorch自带的tensorboard
writer = SummaryWriter('./runs')
上面的操作实例化SummaryWritter为变量writer,并指定writer的输出目录为当前目录下的"runs"目录。
【2】启动tensorboard
bash
tensorboard --logdir=/path/to/logs/ --port=xxxx
"path/to/logs/"是指定的保存tensorboard记录结果的文件路径
--port是外部访问TensorBoard的端口号,可以通过访问ip:port访问tensorboard
6.3.4:TensorBoard模型结构可视化
【1】定义模型
python
import torch.nn as nn
class Net(nn.Module):
def __init__(self):
super(Net, self).__init__()
self.conv1 = nn.Conv2d(in_channels=3,out_channels=32,kernel_size = 3)
self.pool = nn.MaxPool2d(kernel_size = 2,stride = 2)
self.conv2 = nn.Conv2d(in_channels=32,out_channels=64,kernel_size = 5)
self.adaptive_pool = nn.AdaptiveMaxPool2d((1,1))
self.flatten = nn.Flatten()
self.linear1 = nn.Linear(64,32)
self.relu = nn.ReLU()
self.linear2 = nn.Linear(32,1)
self.sigmoid = nn.Sigmoid()
def forward(self,x):
x = self.conv1(x)
x = self.pool(x)
x = self.conv2(x)
x = self.pool(x)
x = self.adaptive_pool(x)
x = self.flatten(x)
x = self.linear1(x)
x = self.relu(x)
x = self.linear2(x)
y = self.sigmoid(x)
return y
model = Net()
print(model)
"""
Net(
(conv1): Conv2d(3, 32, kernel_size=(3, 3), stride=(1, 1))
(pool): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
(conv2): Conv2d(32, 64, kernel_size=(5, 5), stride=(1, 1))
(adaptive_pool): AdaptiveMaxPool2d(output_size=(1, 1))
(flatten): Flatten(start_dim=1, end_dim=-1)
(linear1): Linear(in_features=64, out_features=32, bias=True)
(relu): ReLU()
(linear2): Linear(in_features=32, out_features=1, bias=True)
(sigmoid): Sigmoid()
)
"""
可视化模型的思路:给定一个输入数据,前向传播后得到模型的结构,再通过TensorBoard进行可视化
【2】使用add_graph
python
writer.add_graph(model, input_to_model = torch.rand(1, 3, 224, 224))
writer.close()
6.3.5:TensorBoard图像可视化
- 对于单张图片的显示使用
add_image
- 对于多张图片的显示使用
add_images
- 有时需要使用
torchvision.utils.make_grid
将多张图片拼成一张图片后,用writer.add_image
显示
【案例:CIFAR10】
python
import torchvision
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
transform_train = transforms.Compose(
[transforms.ToTensor()])
transform_test = transforms.Compose(
[transforms.ToTensor()])
train_data = datasets.CIFAR10(".", train=True, download=True, transform=transform_train)
test_data = datasets.CIFAR10(".", train=False, download=True, transform=transform_test)
train_loader = DataLoader(train_data, batch_size=64, shuffle=True)
test_loader = DataLoader(test_data, batch_size=64)
images, labels = next(iter(train_loader))
# 仅查看一张图片
writer = SummaryWriter('./pytorch_tb')
writer.add_image('images[0]', images[0])
writer.close()
# 将多张图片拼接成一张图片,中间用黑色网格分割
# create grid of images
writer = SummaryWriter('./pytorch_tb')
img_grid = torchvision.utils.make_grid(images)
writer.add_image('image_grid', img_grid)
writer.close()
# 将多张图片直接写入
writer = SummaryWriter('./pytorch_tb')
writer.add_images("images",images,global_step = 0)
writer.close()
6.3.6:TensorBoard连续变量可视化
可视化连续变量(或时序变量)的变化过程,通过add_scalar
实现
python
writer = SummaryWriter('./pytorch_tb')
for i in range(500):
x = i
y = x**2
writer.add_scalar("x", x, i) #日志中记录x在第step i 的值
writer.add_scalar("y", y, i) #日志中记录y在第step i 的值
writer.close()
如果想在同一张图中显示多个曲线,则需要分别建立存放子路径(使用SummaryWriter指定路径即可自动创建,但需要在tensorboard运行目录下),同时在add_scalar中修改曲线的标签使其一致即可。
python
writer1 = SummaryWriter('./pytorch_tb/x')
writer2 = SummaryWriter('./pytorch_tb/y')
for i in range(500):
x = i
y = x*2
writer1.add_scalar("same", x, i) #日志中记录x在第step i 的值
writer2.add_scalar("same", y, i) #日志中记录y在第step i 的值
writer1.close()
writer2.close()
6.3.7:TensorBoard参数分布可视化
对参数(或向量)的变化,或者对其分布进行研究时,可通过add_histogram
实现。
python
import torch
import numpy as np
# 创建正态分布的张量模拟参数矩阵
def norm(mean, std):
t = std * torch.randn((100, 20)) + mean
return t
writer = SummaryWriter('./pytorch_tb/')
for step, mean in enumerate(range(-10, 10, 1)):
w = norm(mean, 1)
writer.add_histogram("w", w, step)
writer.flush()
writer.close()
6.3.8:服务器端使用TensorBoard
由于服务器端没有浏览器(纯命令模式),因此需要进行相应的配置,才可以在本地浏览器,使用tensorboard查看服务器运行的训练过程。
方法【1】【2】都是建立SSH隧道,实现远程端口到本机端口的转发。
【1】MobaXterm
- 在MobaXterm点击Tunneling。
- 选择New SSH tunnel。
- 对新建的SSH通道做以下设置,第一栏选择
Local port forwarding
,< Remote Server>
处填写localhost ,< Remote port>
处填写6006,tensorboard默认会在6006端口进行显示。也可以根据 tensorboard --logdir=/path/to/logs/ --port=xxxx 的命令中的port进行修改,< SSH server>
填写连接服务器的ip地址,<SSH login>
填写连接的服务器的用户名,<SSH port>
填写端口号(通常为22),< forwarded port>
填写本地的一个端口号,以便后续进行访问。 - 设定好之后,点击Save,然后Start。再次启动tensorboard,在本地的浏览器输入
http://localhost:6006/
对其进行访问。
【2】Xshell
- 连接上服务器后,打开当前会话属性,选择隧道,点击添加。
- 目标主机代表的是服务器,源主机代表的是本地,端口的选择根据实际情况而定。
- 启动tensorboard,在本地127.0.0.1:6006 或者 localhost:6006进行访问。
6.4:使用wandb可视化训练过程
wandb是Weights & Biases的缩写,能自动记录模型训练过程中的超参数和输出指标,然后可视化和比较结果,并快速与其他人共享结果。
6.4.1:安装
【1】使用pip安装
bash
pip install wandb
【2】在官网注册账号并复制API keys:https://wandb.ai/
【3】在本地使用命令登录
bash
wandb login
【4】粘贴API keys
6.4.2:使用
python
import wandb
wandb.init(project='my-project', entity='my-name')
Quickstart | Weights & Biases Documentation (wandb.ai)
project和entity是在wandb上创建的项目名称和用户名
6.4.3:demo演示
【案例:CIFAR10的图像分类】
【1】导入库
python
import random # to set the python random seed
import numpy # to set the numpy random seed
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
from torchvision.models import resnet18
import warnings
warnings.filterwarnings('ignore')
【2】初始化wandb
python
# 初始化wandb
import wandb
wandb.init(project="thorough-pytorch",
name="wandb_demo",)
【3】设置超参数
使用wandb.config来设置超参数,这样就可以在wandb的界面上看到超参数的变化。
wandb.config的使用方法和字典类似,可以使用config.key的方式来设置超参数。
python
# 超参数设置
config = wandb.config # config的初始化
config.batch_size = 64
config.test_batch_size = 10
config.epochs = 5
config.lr = 0.01
config.momentum = 0.1
config.use_cuda = True
config.seed = 2043
config.log_interval = 10
# 设置随机数
def set_seed(seed):
random.seed(config.seed)
torch.manual_seed(config.seed)
numpy.random.seed(config.seed)
【4】构建train和test的pipeline
python
def train(model, device, train_loader, optimizer):
model.train()
for batch_id, (data, target) in enumerate(train_loader):
data, target = data.to(device), target.to(device)
optimizer.zero_grad()
output = model(data)
criterion = nn.CrossEntropyLoss()
loss = criterion(output, target)
loss.backward()
optimizer.step()
# wandb.log用来记录一些日志(accuracy,loss and epoch), 便于随时查看网路的性能
def test(model, device, test_loader, classes):
model.eval()
test_loss = 0
correct = 0
example_images = []
with torch.no_grad():
for data, target in test_loader:
data, target = data.to(device), target.to(device)
output = model(data)
criterion = nn.CrossEntropyLoss()
test_loss += criterion(output, target).item()
pred = output.max(1, keepdim=True)[1]
correct += pred.eq(target.view_as(pred)).sum().item()
example_images.append(wandb.Image(
data[0], caption="Pred:{} Truth:{}".format(classes[pred[0].item()], classes[target[0]])))
# 使用wandb.log 记录你想记录的指标
wandb.log({
"Examples": example_images,
"Test Accuracy": 100. * correct / len(test_loader.dataset),
"Test Loss": test_loss
})
wandb.watch_called = False
def main():
use_cuda = config.use_cuda and torch.cuda.is_available()
device = torch.device("cuda:0" if use_cuda else "cpu")
kwargs = {'num_workers': 1, 'pin_memory': True} if use_cuda else {}
# 设置随机数
set_seed(config.seed)
torch.backends.cudnn.deterministic = True
# 数据预处理
transform = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])
# 加载数据
train_loader = DataLoader(datasets.CIFAR10(
root='dataset',
train=True,
download=True,
transform=transform
), batch_size=config.batch_size, shuffle=True, **kwargs)
test_loader = DataLoader(datasets.CIFAR10(
root='dataset',
train=False,
download=True,
transform=transform
), batch_size=config.batch_size, shuffle=False, **kwargs)
classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')
model = resnet18(pretrained=True).to(device)
optimizer = optim.SGD(model.parameters(), lr=config.lr, momentum=config.momentum)
wandb.watch(model, log="all")
for epoch in range(1, config.epochs + 1):
train(model, device, train_loader, optimizer)
test(model, device, test_loader, classes)
# 本地和云端模型保存
torch.save(model.state_dict(), 'model.pth')
wandb.save('model.pth')
if __name__ == '__main__':
main()
其他提供的功能:模型的超参数搜索,模型的版本控制,模型的部署等。