最近看代码,发现代码中有wandb有关的内容,搜索了一下发现是一个模型训练工具,然后学习了一下,这里记录一下使用过程,方便以后查阅。
WandB使用笔记
- [登录WandB 并 创建团队](#登录WandB 并 创建团队)
- [安装 WandB 并 登录](#安装 WandB 并 登录)
- 模型训练过程跟踪
- 模型版本管理
- 自动调参
- 不同的模型训练工具对比
- 参考资料
作者自注:之前训练模型一直使用的是Visdom,感觉非常好用,然后现在学习了一下WandB,发现先各有优劣。Visdom的曲线实时跟踪效果好,但是功能简单。WandB曲线实时跟踪效果差(可能是我的网的问题),但是功能强大,可以保存每次模型调优的参数,这样就不用手动再记录了;可以实现模型的版本管理,这样就可以随便改代码,不用担心改坏了;可以进行参数分析,这样就可以有目的的进行参数调优;可以进行自动调参,这样可在完成粗调制后进行局部的参数寻优。感觉以后两个可以同时使用,提高模型调优的效率
登录WandB 并 创建团队
点击下面的网站进入WandB:https://wandb.ai/site,然后点击界面中的 LOGIN
进行登录。
如下需要选择登录的方式,这里我选择的是 GitHub 。
完成登陆后进入如下初始界面,点击图片中红框中的内容,创建一个新的 team
。
之后进入如下界面,输入团队名称,并点击 Create team
,完成团队的创建。
团队创建成功后出现如下界面,选择是否把自己的 runs
更新到 team
,这里选择 Update
。
如此就完成了登录和团建创建过程!
如果想要删除创建的团队,则在主界面点击创建的团队,如下图所示:
进入团队后,点击 Team settings
,如下图所示:
接着滑动到最下面,点击 Delete team
:
接着需要你输入 团队的名称 进行删除,这里的逻辑跟GitHub删除项目一样。
安装 WandB 并 登录
使用 pip
安装 WandB:
bash
pip install wandb
验证安装是否成功:
bash
wandb --version
首次使用 WandB 时,需要登录账户:
bash
wandb login
登录后,WandB 会提示输入 API 密钥。可以从 WandB 的 API 密钥页面 获取密钥,点击图片中的红框部分,复制密钥,然后粘贴到上图的 3 标识的地方,并点击回车,如此就完成了登录过程。
如果你之前已经登陆过了,则会出现如下的内容:
然后在终端输入如下的命令即可重新登录:
bash
wandb login --relogin
模型训练过程跟踪
将如下代码复制到PyCharm中,进行实验。
python
import wandb
import torch
from torch import nn
import torchvision
from torchvision import transforms
import datetime
from argparse import Namespace
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
config = Namespace(
project_name='wandb_demo',
batch_size=512,
hidden_layer_width=64,
dropout_p=0.1,
lr=1e-4,
optim_type='Adam',
epochs=150,
ckpt_path='checkpoint.pt'
)
def create_dataloaders(config):
transform = transforms.Compose([transforms.ToTensor()])
ds_train = torchvision.datasets.MNIST(root="./mnist/", train=True, download=True, transform=transform)
ds_val = torchvision.datasets.MNIST(root="./mnist/", train=False, download=True, transform=transform)
ds_train_sub = torch.utils.data.Subset(ds_train, indices=range(0, len(ds_train), 5))
dl_train = torch.utils.data.DataLoader(ds_train_sub, batch_size=config.batch_size, shuffle=True, drop_last=True)
dl_val = torch.utils.data.DataLoader(ds_val, batch_size=config.batch_size, shuffle=False, drop_last=True)
return dl_train, dl_val
def create_net(config):
net = nn.Sequential()
net.add_module("conv1", nn.Conv2d(in_channels=1, out_channels=config.hidden_layer_width, kernel_size=3))
net.add_module("pool1", nn.MaxPool2d(kernel_size=2, stride=2))
net.add_module("conv2", nn.Conv2d(in_channels=config.hidden_layer_width,
out_channels=config.hidden_layer_width, kernel_size=5))
net.add_module("pool2", nn.MaxPool2d(kernel_size=2, stride=2))
net.add_module("dropout", nn.Dropout2d(p=config.dropout_p))
net.add_module("adaptive_pool", nn.AdaptiveMaxPool2d((1, 1)))
net.add_module("flatten", nn.Flatten())
net.add_module("linear1", nn.Linear(config.hidden_layer_width, config.hidden_layer_width))
net.add_module("relu", nn.ReLU())
net.add_module("linear2", nn.Linear(config.hidden_layer_width, 10))
net.to(device)
return net
def train_epoch(model, dl_train, optimizer):
model.train()
for step, batch in enumerate(dl_train):
features, labels = batch
features, labels = features.to(device), labels.to(device)
preds = model(features)
loss = nn.CrossEntropyLoss()(preds, labels)
loss.backward()
optimizer.step()
optimizer.zero_grad()
return model
def eval_epoch(model, dl_val):
model.eval()
accurate = 0
num_elems = 0
for batch in dl_val:
features, labels = batch
features, labels = features.to(device), labels.to(device)
with torch.no_grad():
preds = model(features)
predictions = preds.argmax(dim=-1)
accurate_preds = (predictions == labels)
num_elems += accurate_preds.shape[0]
accurate += accurate_preds.long().sum()
val_acc = accurate.item() / num_elems
return val_acc
python
def train(config=config):
dl_train, dl_val = create_dataloaders(config)
model = create_net(config);
optimizer = torch.optim.__dict__[config.optim_type](params=model.parameters(), lr=config.lr)
# ======================================================================
nowtime = datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')
wandb.init(project=config.project_name, config=config.__dict__, name=nowtime, save_code=True)
model.run_id = wandb.run.id
# ======================================================================
model.best_metric = -1.0
for epoch in range(1, config.epochs + 1):
model = train_epoch(model, dl_train, optimizer)
val_acc = eval_epoch(model, dl_val)
if val_acc > model.best_metric:
model.best_metric = val_acc
torch.save(model.state_dict(), config.ckpt_path)
nowtime = datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')
print(f"epoch【{epoch}】@{nowtime} --> val_acc= {100 * val_acc:.2f}%")
# ======================================================================
wandb.log({'epoch': epoch, 'val_acc': val_acc, 'best_val_acc': model.best_metric})
# ======================================================================
# ======================================================================
wandb.finish()
# ======================================================================
return model
上述代码最关键的就是如下三个部分:
- 初始化部分:
python
wandb.init(project=config.project_name, config=config.__dict__, name=nowtime, save_code=True)
- 模型训练参数上传
python
wandb.log({'epoch': epoch, 'val_acc': val_acc, 'best_val_acc': model.best_metric})
- 模型训练完成关闭wandb:
python
wandb.finish()
最后在PyCharm中输入如下代码,即可运行上述代码:
python
model = train(config)
代码运行成功,即可出现如下的界面,点击下图中红框中的部分,即可跳转到曲线监视界面。
模型训练过程监视界面如下图所示:
点击下图中的红框部分,更改曲线的横坐标值。
如下图所示,将横坐标值更改为 epoch。
然后我们还可以增加一个 section
。
在新的 section
中添加新的显示模块,如下图所示:
此处我们添加了验证集的准确率,实现实时的监控。
模型训练结束,我们可以点击 runs
查看历史记录。
如下图可以看到,我们刚才监视的曲线,如图中的长方形红框所示。然后点击小红框中的 runs
,查看每一次训练过程的模型参数。
每一次模型训练的参数如下图所示,可以选择图中红框中的内容,选择需要的参数进行显示。
可选择的指标如下图所示:
对于某些我们比较关注的指标,我们可以将其固定显示:
固定后,我们回到 Workspace
界面,即可看到固定的参数。
模型版本管理
除了可以记录实验日志传递到 wandb 网站的云端服务器 并进行可视化分析。wandb还能够将实验关联的数据集,代码和模型 保存到 wandb 服务器。我们可以通过 wandb.log_artifact的方法来保存任务的关联的重要成果。例如 dataset, code,和 model,并进行版本管理。
当我们跑出一个相对不错的结果时,我们希望把这个结果给保存下来,此时我们就可以使用该功能。
我们先使用run_id 恢复 run任务,以便继续记录。
python
import wandb
# resume the run
run = wandb.init(project='wandb_demo', id='6h5xkv16', resume='allow')
上述代码中的 id
是用来关联我们训练的 runs
的,参数的值来自下图红框中的内容,想搞关联某一次的训练过程,就把某一次训练的 ID
写入上述代码。
保存数据集的代码:
python
# save dataset
arti_dataset = wandb.Artifact(name='mnist', type='dataset')
arti_dataset.add_dir('mnist/')
wandb.log_artifact(arti_dataset)
保存模型文件的代码:
python
# save code
arti_code = wandb.Artifact(name='py', type='code')
arti_code.add_file('./wandb_test.py')
wandb.log_artifact(arti_code)
保存模型权重的代码:
python
# save model
arti_model = wandb.Artifact(name='cnn', type='model')
arti_model.add_file(config.ckpt_path)
wandb.log_artifact(arti_model)
最后结束时要使用一下代码:
python
# finish时会提交保存
wandb.finish()
上传后的效果如图所示:
自动调参
sweep采用类似master-workers的controller-agents架构,controller在wandb的服务器机器上运行,agents在用户机器上运行,controller和agents之间通过互联网进行通信。同时启动多个agents即可轻松实现分布式超参搜索。
使用Sweep的3步骤:
- 配置 sweep_config
python
# 配置 Sweep config
sweep_config = {
'method': 'random', # 选择调优算法,超参数搜索方法:随机搜索
'metric': { # 定义调优目标
'name': 'val_acc',
'goal': 'maximize'
},
'parameters': { # 定义超参空间
'project_name': {'value': 'wandb_demo'}, # 固定不变的超参
'epochs': {'value': 10},
'ckpt_path': {'value': 'checkpoint.pt'},
'optim_type': { # 离散型分布超参
'values': ['Adam', 'SGD', 'AdamW']
},
'hidden_layer_width': {
'values': [16, 32, 48, 64, 80, 96, 112, 128]
},
'lr': { # 连续型分布超参
'distribution': 'log_uniform_values',
'min': 1e-6,
'max': 0.1
},
'batch_size': {
'distribution': 'q_uniform',
'q': 8,
'min': 32,
'max': 256,
},
'dropout_p': {
'distribution': 'uniform',
'min': 0,
'max': 0.6,
}
},
# 'early_terminate': { # 定义剪枝策略 (可选)
# 'type': 'hyperband', # 使用 HyperBand 作为早停策略
# 'min_iter': 3, # 最小评估迭代次数(第 3 次迭代后开始考虑剪枝)
# 'eta': 2, # 成倍增长的资源分配比例(每次迭代中仅保留约 1/eta 的实验)
# 's': 3 # HyperBand 的最大阶数,影响资源分配的层级
# }
}
python
from pprint import pprint
pprint(sweep_config)
Sweep支持如下3种调优算法:
(1)网格搜索:grid. 遍历所有可能得超参组合,只在超参空间不大的时候使用,否则会非常慢。
(2)随机搜索:random. 每个超参数都选择一个随机值,非常有效,一般情况下建议使用。
(3)贝叶斯搜索:bayes.
创建一个概率模型估计不同超参数组合的效果,采样有更高概率提升优化目标的超参数组合。对连续型的超参数特别有效,但扩展到非常高维度的超参数时效果不好。
- 初始化 sweep controller
python
# 初始化 sweep controller
sweep_id = wandb.sweep(sweep_config, project=config.project_name)
- 启动 sweep agents
python
# 启动 Sweep agent
# 该agent 随机搜索 尝试5次
wandb.agent(sweep_id, train, count=5)
等代码跑完我们就有了一个 sweep
,如下图所示:
进入 sweep
之后就可以添加 Parallel coordinates
和 Parameter importance
进行参数分析。
不同的模型训练工具对比
工具 | 实验管理 | 数据版本控制 | 模型部署 | 团队协作 | 离线支持 | 特点 |
---|---|---|---|---|---|---|
TensorBoard | ✅ | ❌ | ❌ | ❌ | ✅ | 轻量级工具,适合快速原型开发 |
WandB | ✅ | ✅ | ✅ | ✅ | ✅ | 功能全面,支持超参数调优和实时协作 |
Comet | ✅ | ❌ | ❌ | ✅ | ✅ | 简单易用,支持离线模式 |
MLflow | ✅ | ✅ | ✅ | ✅ | ✅ | 实验管理与模型部署一体化 |
Neptune | ✅ | ❌ | ❌ | ✅ | ❌ | 强大的可视化功能 |
Sacred | ✅ | ❌ | ❌ | ❌ | ✅ | 极简实验管理工具 |
Polyaxon | ✅ | ✅ | ✅ | ✅ | ❌ | 分布式训练与大规模实验管理支持 |
DVC | ✅ | ✅ | ❌ | ❌ | ✅ | 专注于数据和模型版本控制 |
ClearML | ✅ | ✅ | ✅ | ✅ | ✅ | 全面的 MLOps 功能 |