FNN sin predict

import torch

import matplotlib.pyplot as plt

设置随机种子,保证结果可复现

torch.manual_seed(99)

==================== 1. 定义MLP模型 ====================

class MLP:

"""

多层感知机 (Multilayer Perceptron)

结构: 输入层 -> 隐藏层(tanh激活) -> 输出层(线性)

"""

def init(self, input_size, hidden_size, output_size):

"""

初始化网络参数

Args:

input_size: 输入特征维度

hidden_size: 隐藏层神经元数量

output_size: 输出维度

"""

初始化权重和偏置(使用标准正态分布)

self.W1 = torch.randn([hidden_size, input_size], requires_grad=True) # 输入层->隐藏层权重

self.b1 = torch.randn([hidden_size, 1], requires_grad=True) # 隐藏层偏置

self.W2 = torch.randn([output_size, hidden_size], requires_grad=True) # 隐藏层->输出层权重

self.b2 = torch.randn([output_size, 1], requires_grad=True) # 输出层偏置

存储参数列表,方便后续操作

self.parameters = [self.W1, self.b1, self.W2, self.b2]

def forward(self, x):

"""

前向传播

Args:

x: 输入数据,形状为 (input_size, batch_size)

Returns:

网络输出,形状为 (output_size, batch_size)

"""

隐藏层:线性变换 + tanh激活

hidden = torch.tanh(self.W1 @ x + self.b1)

输出层:线性变换

outp = self.W2 @ hidden + self.b2

return outp

def zero_grad(self):

"""清空所有参数的梯度"""

for param in self.parameters:

if param.grad is not None:

param.grad.zero_()

def update_parameters(self, lr):

"""

使用梯度下降更新参数

Args:

lr: 学习率

"""

with torch.no_grad(): # 禁用梯度追踪,仅更新数值

for param in self.parameters:

param.data -= param.grad * lr

==================== 2. 定义损失函数 ====================

def mse_loss(y_true, y_pred):

"""

均方误差损失函数 (Mean Squared Error)

"""

return ((y_true - y_pred) ** 2).mean()

==================== 3. 准备数据 ====================

def generate_data():

"""

生成训练数据:sin函数在[-5, 5]区间的采样

"""

生成20个训练样本

x_train = torch.linspace(-5, 5, 20).reshape(1, -1) # 形状: (1, 20)

y_train = torch.sin(x_train) # 形状: (1, 20)

生成100个测试样本(用于绘制平滑曲线)

x_test = torch.linspace(-5, 5, 100).reshape(1, -1) # 形状: (1, 100)

return x_train, y_train, x_test

==================== 4. 训练流程 ====================

def train(model, x_train, y_train, epochs=5000, lr=0.01, print_interval=500):

"""

训练MLP模型

"""

print("开始训练...")

print("=" * 50)

for epoch in range(epochs):

前向传播

y_pred = model.forward(x_train)

计算损失

loss = mse_loss(y_train, y_pred)

反向传播

loss.backward()

更新参数

model.update_parameters(lr)

清空梯度

model.zero_grad()

打印训练进度

if epoch % print_interval == 0 or epoch == epochs - 1:

print(f"Epoch [{epoch:>4d}/{epochs}], Loss: {loss.item():.6f}")

print("=" * 50)

print("训练完成!")

==================== 5. 可视化结果 ====================

def visualize(model, x_train, y_train, x_test):

"""

可视化训练结果

"""

切换到评估模式(禁用梯度计算)

with torch.no_grad():

y_pred = model.forward(x_test).detach().numpy()

转换数据用于绘图

x_train_np = x_train.numpy().flatten()

y_train_np = y_train.numpy().flatten()

x_test_np = x_test.numpy().flatten()

绘制结果

plt.figure(figsize=(10, 6))

plt.scatter(x_train_np, y_train_np, color='red', label='训练数据', s=50, zorder=5)

plt.plot(x_test_np, y_pred.flatten(), color='blue', linewidth=2, label='MLP拟合曲线')

plt.plot(x_test_np, torch.sin(x_test).numpy().flatten(), '--', color='green',

linewidth=1.5, label='真实sin函数', alpha=0.7)

plt.xlabel('x', fontsize=12)

plt.ylabel('y', fontsize=12)

plt.title('MLP拟合sin函数', fontsize=14)

plt.legend(fontsize=10)

plt.grid(True, alpha=0.3)

plt.tight_layout()

plt.show()

==================== 6. 打印模型参数 ====================

def print_model_params(model):

"""打印模型训练后的参数"""

print("\n模型参数:")

print("-" * 50)

print(f"W1 (输入层->隐藏层权重):\n{model.W1.data}\n")

print(f"b1 (隐藏层偏置):\n{model.b1.data}\n")

print(f"W2 (隐藏层->输出层权重):\n{model.W2.data}\n")

print(f"b2 (输出层偏置):\n{model.b2.data}")

print("-" * 50)

==================== 主程序入口 ====================

if name == "main":

1. 生成数据

x_train, y_train, x_test = generate_data()

2. 创建模型 (输入1维, 隐藏层4个神经元, 输出1维)

model = MLP(input_size=1, hidden_size=100, output_size=1)

3. 训练模型

train(model, x_train, y_train, epochs=5000, lr=0.01)

4. 可视化结果

#visualize(model, x_train, y_train, x_test)

5. 打印最终参数

#print_model_params(model)

x_train = torch.tensor([[4.1]]) #torch.linspace(-5, 5, 20).reshape(1, -1) # 形状: (1, 20)

y_train = torch.sin(x_train)

print(y_train.item())

outp=model.forward(x_train)

print(outp.item())

相关推荐
沐知全栈开发1 小时前
C++ 多态
开发语言
zihan03211 小时前
若依(RuoYi)框架核心升级:全面适配 SpringData JPA,替换 MyBatis 持久层方案
java·开发语言·前端框架·mybatis·若依升级springboot
先做个垃圾出来………2 小时前
Python字节串“b“前缀
开发语言·python
无限进步_2 小时前
21. 合并两个有序链表 - 题解与详细分析
c语言·开发语言·数据结构·git·链表·github·visual studio
dreams_dream2 小时前
什么是迭代器和生成器
python
神奇大叔2 小时前
Java 配置文件记录
java·开发语言
三水彡彡彡彡2 小时前
C++拷贝函数:const与引用的高效实践
开发语言·c++
悠闲蜗牛�2 小时前
深入浅出Spring Boot 3.x:新特性全解析与实战指南
开发语言·python