wandb的使用方法,以navrl为例

1.wandb的基础用法

Weights & Biases(简称 wandb)是一个强大的机器学习实验跟踪工具,主要用于记录训练过程中的指标(如 loss、accuracy)、超参数、模型权重、图像/视频等数据。它支持实时可视化、实验比较、团队协作,还能进行超参数搜索和模型版本管理。相比 TensorBoard,wandb 的优势在于云端同步、易于分享和团队管理,免费版对个人和小团队已经足够强大。

python 复制代码
#安装库
pip install wandb

#登录,需要去网站注册账号,并且拿到key
wandb login  #之后输入key就可以登陆了
wandb status #查询登录状态
wandb login --relogin  #重新登录


#支持离线模式,离线训练之后还可以在同步进行查看,免费版是由存储空间的限制的,但是一般也够用了。

wandb 的核心是 "Run"(一次实验运行)。流程大致是:1.用 wandb.init() 初始化一个 Run。2.用 wandb.config 保存超参数。3.在训练循环中用 wandb.log() 记录指标。4.结束时自动上传(或手动 wandb.finish())。先用一个模拟记录数字的代码来感受一下:

python 复制代码
import wandb
import time
import random

# 第1步:启动一个项目实验(叫一个 Run)
wandb.init(
    project="my-first-wandb",   # 项目名字,随便取,第一次运行会自动创建,类似于一个文件夹,里边放着所有的实验。
    name="super-simple-test"    # 这次实验的显示名字,随便取
)

# 第2步:模拟训练10个epoch
for epoch in range(1, 11):  # 1到10
    # 随便生成一点数据
    fake_loss = 2.0 / epoch + random.random() * 0.5   # loss 慢慢下降,random.random()会生成一个 随机浮点数(小数),这个数大于等于 0,小于 1。
    fake_accuracy = 1 - (1.0 / epoch) + random.random() * 0.1  # accuracy 慢慢上升

    # 第3步:把数据发给 wandb 记录,如果设置online就可以在网站上看到结果了。
    wandb.log({
        "loss": fake_loss,
        "accuracy": fake_accuracy,
        "epoch": epoch
    })

    print(f"Epoch {epoch}: loss={fake_loss:.3f}, acc={fake_accuracy:.3f}")
    time.sleep(1)  # 暂停1秒,模拟训练时间

# 第4步:结束实验,记得要关闭。
wandb.finish()

现在来训练一个真正的模型。

python 复制代码
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
import wandb

# 1. 初始化 wandb 项目
wandb.init(
    project="mnist-beginner",          # 项目名
    name="first-real-training",        # 这次实验名
    config={                           # 把超参数保存下来,后面比较实验超方便
        "learning_rate": 0.01,
        "batch_size": 64,
        "epochs": 5,
        "optimizer": "SGD"
    }
)

# 把超参数取出来方便使用
config = wandb.config

# 2. 准备数据(MNIST 手写数字)
transform = transforms.ToTensor()
train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
test_dataset = datasets.MNIST(root='./data', train=False, download=True, transform=transform)

train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=config.batch_size, shuffle=True)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=config.batch_size, shuffle=False)

# 3. 定义一个超级简单的神经网络
class SimpleNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.flatten = nn.Flatten()
        self.fc1 = nn.Linear(28*28, 128)
        self.fc2 = nn.Linear(128, 10)
        self.relu = nn.ReLU()
    
    def forward(self, x):
        x = self.flatten(x)
        x = self.relu(self.fc1(x))
        x = self.fc2(x)
        return x

model = SimpleNet()

# 4. 损失函数和优化器
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=config.learning_rate)

# 5. 训练循环
for epoch in range(1, config.epochs + 1):
    model.train()
    train_loss = 0
    correct = 0
    total = 0
    
    for batch_idx, (data, target) in enumerate(train_loader):
        optimizer.zero_grad()
        output = model(data)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()
        
        train_loss += loss.item()
        pred = output.argmax(dim=1)
        correct += pred.eq(target).sum().item()
        total += target.size(0)
    
    avg_train_loss = train_loss / len(train_loader)
    train_acc = 100. * correct / total
    
    # 6. 测试(验证)准确率
    model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            output = model(data)
            test_loss += criterion(output, target).item()
            pred = output.argmax(dim=1)
            correct += pred.eq(target).sum().item()
    
    avg_test_loss = test_loss / len(test_loader)
    test_acc = 100. * correct / len(test_loader.dataset)
    
    # 7. 最关键:把指标发给 wandb。wandb.log()函数的作用就是告诉wandb要记录的指标
    wandb.log({
        "train_loss": avg_train_loss,
        "train_acc": train_acc,
        "test_loss": avg_test_loss,
        "test_acc": test_acc,
        "epoch": epoch
    })
    
    print(f"Epoch {epoch}: Train Acc {train_acc:.2f}%, Test Acc {test_acc:.2f}%")

# 8. 保存模型(可选,但推荐)
wandb.save("model.pth")  # 会自动上传到云端,告诉 wandb 把本地文件 "model.pth" 保存并上传到你当前的这个实验对应的云端页面上。

# 9. 结束
wandb.finish()

2.以NavRL项目为例

最近在尝试复现修改NavRL项目的代码,有需求要分析一下奖励函数,因此就需要把一些新的变量放在wandb上,所以就有了这个笔记,之前自己并没有用过wandb,也算是一个总结。

python 复制代码
#train.py里边,做了一个wandb的初始化
if (cfg.wandb.run_id is None):
    run = wandb.init(
        project=cfg.wandb.project,
        name=f"{cfg.wandb.name}/{datetime.datetime.now().strftime('%m-%d_%H-%M')}",
        entity=cfg.wandb.entity,
        config=cfg,                     # 把整个配置文件保存到 wandb
        mode=cfg.wandb.mode,
        id=wandb.util.generate_id(),
    )
else:  #如果手动制定了wandb.run_id就从这里恢复训练,resume=must是最严格的恢复模式,确保你真的在继续同一个实验。
    run = wandb.init(..., resume="must")  # 支持从某个 run_id 恢复吗,这是wandb的一个断点重新训练的功能,能够从之前中断的是眼睛却恢复。
    
python 复制代码
#注册,也就是定义一个浮点数的容器,给每个无人机发一个这样的容器
#env.py里边,定义要统计的量的结构,告诉整个系统(目前是torchrl框架),这个环境会返回那些统计指标,每个指标是什么形状,类型,像一个字典。
#定义表格
    # compositesepc :可以理解为一个字典,前边红色为键,后边UnboundedContinuousTensorSpec(1)是torchRL的特有对象,用来描述数据的形状和类型。
    # Unbounded (无界):意思是这个数值没有上下限(不像图片像素必须在0-255)。比如速度奖励可以是 0.5,也可以是 100.0,也可以是 -50.0。
	# Continuous (连续):意思是数据是浮点数 (float),比如 3.14159,而不是整数(离散的)。
	# TensorSpec (张量规格):说明这是为 PyTorch Tensor 准备的。
	# (1):这是形状 (Shape)。表示对于每一个无人机,这个数据只有一个数字(标量)。
    
stats_spec = CompositeSpec({
    "return": UnboundedContinuousTensorSpec(1),
    "episode_len": UnboundedContinuousTensorSpec(1),
    "reach_goal": UnboundedContinuousTensorSpec(1),
    "debug_reward_vel": UnboundedContinuousTensorSpec(1),
    "debug_reward_facing": UnboundedContinuousTensorSpec(1),
    "debug_penalty_smooth": UnboundedContinuousTensorSpec(1),
    "debug_penalty_yaw": UnboundedContinuousTensorSpec(1),
    "debug_heading_error": UnboundedContinuousTensorSpec(1),
    "debug_current_speed": UnboundedContinuousTensorSpec(1),
    "collision": UnboundedContinuousTensorSpec(1),
    "truncated": UnboundedContinuousTensorSpec(1),
}).expand(self.num_envs).to(self.device) #把每个指标扩展成 [num_envs, 1] 形状,支持并行环境(比如 50 个无人机同时跑)。

self.observation_spec["stats"] = stats_spec
self.stats = stats_spec.zero()  # 创建一个全零的 TensorDict,用 stats_spec.zero() 初始化一个全零的 TensorDict,作为后续累加统计量的容器。
python 复制代码
#填数,获取数据,并且调整形状,然后填入这个字典里边。
#env.py里边,计算原始的统计量,也就是统计量在训练环境中的一个填充过程。
#累加型指标(+=):如 return、各种 debug_reward、debug_penalty、debug_heading_error、debug_current_speed ------ 这些在一个 episode 内不断累加,episode 结束时反映平均表现。
#赋值型指标(=):如 reach_goal、collision、truncated、episode_len ------ 这些是当前步的状态,episode 结束时取最后值或是否为 True。

#在_compute_state_and_obs中
self.stats["return"] += self.reward
self.stats["episode_len"][:] = self.progress_buf.unsqueeze(1)
self.stats["reach_goal"] = reach_goal.float().unsqueeze(-1)
self.stats["collision"] = collision_flat.float().unsqueeze(-1)
self.stats["truncated"] = self.truncated.float()

self.stats["debug_reward_vel"] += reward_vel
self.stats["debug_reward_facing"] += reward_facing
self.stats["debug_penalty_smooth"] += penalty_smooth * 0.1
self.stats["debug_penalty_yaw"] += penalty_yaw_rate
self.stats["debug_heading_error"] += torch.abs(heading_error).reshape(self.num_envs, 1)
self.stats["debug_current_speed"] += vel_w.norm(dim=-1).reshape(self.num_envs, 1)


#然后送给train.py
   return TensorDict({
            "agents": TensorDict(
                {
                    "observation": obs,
                }, 
                [self.num_envs]
            ),
            "stats": self.stats.clone(),
            "info": self.info
        }, self.batch_size)
python 复制代码
#tran.py去收集这些stats
#transformed_env.observation_spec.keys(True, True) 会递归遍历所有 Spec。
'''
因为在 Env 里声明了 self.observation_spec["stats"] = stats_spec,所以这里能自动发现所有键:
("stats", "return")
("stats", "reach_goal")
("stats", "debug_heading_error")
'''
episode_stats_keys = [
    k for k in transformed_env.observation_spec.keys(True, True)
    if isinstance(k, tuple) and k[0]=="stats"
]
episode_stats = EpisodeStats(episode_stats_keys)
python 复制代码
#训练中收集要记录的信息-info字典,上边的信息通过tran.py里边的这个循环来收集
#info 字典最终包含两类关键信息:
#PPO 算法本身的训练损失(来自 policy.train(data))
#环境自定义的统计量(来自 self.stats,通过 episode_stats 聚合)

for i, data in enumerate(collector):
    info = {"env_frames": collector._frames, "rollout_fps": collector._fps} #collector._frames:到目前为止,总共采集了多少帧(环境步数)。collector._fps:数据采集的实时速度(frames per second),反映仿真+策略推理的整体速度

    # PPO 训练损失
    train_loss_stats = policy.train(data)
    info.update(train_loss_stats)          # 加入 actor_loss, critic_loss, entropy 等

    # 环境统计量,来自 env 的 self.stats
    episode_stats.add(data)  #把当前环境的data["stats"](来自 NavigationEnv 的 self.stats)加入收集器
    if len(episode_stats) >= transformed_env.num_envs:  #确保只有当所有并行环境都完成至少一个 episode 时才记录一次平均值
        stats = {   #字典推到式,episode_stats.pop():取出积累的所有已结束 episode 的统计数据,计算平均值(torch.mean),转成 Python 标量(.item()),加前缀 "train/" 并拼接键名(如 "train/stats.return")
            "train/" + (".".join(k) if isinstance(k, tuple) else k): torch.mean(v.float()).item()
            for k, v in episode_stats.pop().items(True, True)
        }
        info.update(stats)                 # 加入 return, reach_goal, collision, debug_xxx 等信息
#之后,最终上传,wandb 收到后画出所有曲线        
         run.log(info)
python 复制代码
总的来说,换汤不换药也就是形式上复杂一点,股价还是一样的,用grok总结一下整体思路,感觉写的可以:
NavigationEnv._set_specs()
    ↓ 定义 stats_spec(声明我要记录哪些指标)
    ↓ self.stats = stats_spec.zero()(创建容器)

NavigationEnv._compute_state_and_obs()(每步执行)
    ↓ self.stats[...] += / = ...(填充数值)
    ↓ return TensorDict(..., "stats": self.stats.clone())(返回给 collector)

train.py 数据采集(collector)
    ↓ data 包含 "stats" 子字典

train.py EpisodeStats 初始化
    ↓ 自动从 observation_spec 发现所有 ("stats", xxx) 键
    ↓ episode_stats = EpisodeStats(episode_stats_keys)

train.py 训练循环
    ↓ episode_stats.add(data)(收集已结束 episode 的 stats)
    ↓ 当所有环境完成一轮 episode
        ↓ 计算平均值 + 加 "train/" 前缀
        ↓ info.update(stats)
    ↓ run.log(info)(上传 wandb)

wandb 网页
    ↓ 显示曲线:train/stats.return、train/stats.reach_goal、train/stats.collision 等
相关推荐
木头左18 小时前
贝叶斯深度学习在指数期权风险价值VaR估计中的实现与应用
人工智能·深度学习
编程大师哥18 小时前
Java 常见异常(按「运行时 / 编译时」分类)
java·开发语言
rgeshfgreh18 小时前
解决Windows系统Python命令无效问题
python
MF_AI18 小时前
苹果病害检测识别数据集:1w+图像,5类,yolo标注
图像处理·人工智能·深度学习·yolo·计算机视觉
jinglong.zha18 小时前
AScript游戏进阶课程 - 实战课表(0基础小白从入门到精通系列课程)
python·自动化·懒人精灵·ascript·游戏脚本
xiao5kou4chang6kai418 小时前
面向自然科学领域机器学习与深度学习(高维数据预处理—可解释ML/DL—时空建模—不确定性量化-全程AI+Python)
人工智能·深度学习·机器学习·不确定性量化·时空建模·高维数据预处理·可解释ml/dl
bybitq18 小时前
Leetcode131题解 -Python-回溯+cache缓存
开发语言·python
SunnyDays101118 小时前
如何使用 Python 合并多个 Excel 文件
python·合并excel文件·合并excel表格
lixzest18 小时前
PyTorch张量(Tensor)简介
python