基础的贝叶斯神经网络(BNN)回归

下面是一个最基础的贝叶斯神经网络(BNN)回归 示例,采用PyTorch实现,适合入门理解。

这个例子用BNN拟合 y = x + 噪声 的一维回归问题,输出均值和不确定性(方差)。

复制代码
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import matplotlib.pyplot as plt

# 1. 生成数据
np.random.seed(0)
x = np.linspace(-3, 3, 100)
y = x + np.random.normal(0, 0.5, size=x.shape)

# 转为torch tensor
x_train = torch.tensor(x, dtype=torch.float32).unsqueeze(1)
y_train = torch.tensor(y, dtype=torch.float32).unsqueeze(1)

# 2. 定义贝叶斯回归网络(输出均值和log方差)
class BayesianRegressor(nn.Module):
    def __init__(self):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(1, 32), nn.ReLU(),
            nn.Linear(32, 32), nn.ReLU(),
            nn.Linear(32, 2) # 输出均值和log方差
        )
    def forward(self, x):
        out = self.net(x)
        mean = out[:, 0:1]
        logvar = out[:, 1:2]
        return mean, logvar

# 3. 贝叶斯损失函数(负对数似然)
def bayesian_loss(mean, logvar, target):
    # 对应N(y|mean, exp(logvar))
    return (0.5 * torch.exp(-logvar) * (target - mean) ** 2 + 0.5 * logvar).mean()

# 4. 训练网络
model = BayesianRegressor()
optimizer = optim.Adam(model.parameters(), lr=0.01)

for epoch in range(2000):
    mean, logvar = model(x_train)
    loss = bayesian_loss(mean, logvar, y_train)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    if (epoch+1) % 200 == 0:
        print(f"Epoch {epoch+1}, Loss: {loss.item():.4f}")

# 5. 预测与可视化
x_test = torch.linspace(-3, 3, 100).unsqueeze(1)
mean_pred, logvar_pred = model(x_test)
mean_pred = mean_pred.detach().numpy().flatten()
std_pred = torch.exp(0.5 * logvar_pred).detach().numpy().flatten()

plt.figure(figsize=(8, 5))
plt.scatter(x, y, label='Data', color='gray', s=10)
plt.plot(x, x, 'g--', label='True function')
plt.plot(x_test, mean_pred, 'b-', label='BNN mean')
plt.fill_between(x_test.flatten(), mean_pred-2*std_pred, mean_pred+2*std_pred, color='orange', alpha=0.3, label='BNN ±2std')
plt.legend()
plt.title("Simple Bayesian Neural Network Regression")
plt.show()
相关推荐
看到我,请让我去学习2 分钟前
OpenCV编程- (图像基础处理:噪声、滤波、直方图与边缘检测)
c语言·c++·人工智能·opencv·计算机视觉
码字的字节4 分钟前
深度解析Computer-Using Agent:AI如何像人类一样操作计算机
人工智能·computer-using·ai操作计算机·cua
说私域1 小时前
互联网生态下赢家群体的崛起与“开源AI智能名片链动2+1模式S2B2C商城小程序“的赋能效应
人工智能·小程序·开源
董厂长5 小时前
langchain :记忆组件混淆概念澄清 & 创建Conversational ReAct后显示指定 记忆组件
人工智能·深度学习·langchain·llm
G皮T8 小时前
【人工智能】ChatGPT、DeepSeek-R1、DeepSeek-V3 辨析
人工智能·chatgpt·llm·大语言模型·deepseek·deepseek-v3·deepseek-r1
九年义务漏网鲨鱼8 小时前
【大模型学习 | MINIGPT-4原理】
人工智能·深度学习·学习·语言模型·多模态
元宇宙时间8 小时前
Playfun即将开启大型Web3线上活动,打造沉浸式GameFi体验生态
人工智能·去中心化·区块链
开发者工具分享8 小时前
文本音频违规识别工具排行榜(12选)
人工智能·音视频
产品经理独孤虾9 小时前
人工智能大模型如何助力电商产品经理打造高效的商品工业属性画像
人工智能·机器学习·ai·大模型·产品经理·商品画像·商品工业属性
老任与码9 小时前
Spring AI Alibaba(1)——基本使用
java·人工智能·后端·springaialibaba