FNN sin predict

import torch

import matplotlib.pyplot as plt

设置随机种子,保证结果可复现

torch.manual_seed(99)

==================== 1. 定义MLP模型 ====================

class MLP:

"""

多层感知机 (Multilayer Perceptron)

结构: 输入层 -> 隐藏层(tanh激活) -> 输出层(线性)

"""

def init(self, input_size, hidden_size, output_size):

"""

初始化网络参数

Args:

input_size: 输入特征维度

hidden_size: 隐藏层神经元数量

output_size: 输出维度

"""

初始化权重和偏置(使用标准正态分布)

self.W1 = torch.randn([hidden_size, input_size], requires_grad=True) # 输入层->隐藏层权重

self.b1 = torch.randn([hidden_size, 1], requires_grad=True) # 隐藏层偏置

self.W2 = torch.randn([output_size, hidden_size], requires_grad=True) # 隐藏层->输出层权重

self.b2 = torch.randn([output_size, 1], requires_grad=True) # 输出层偏置

存储参数列表,方便后续操作

self.parameters = [self.W1, self.b1, self.W2, self.b2]

def forward(self, x):

"""

前向传播

Args:

x: 输入数据,形状为 (input_size, batch_size)

Returns:

网络输出,形状为 (output_size, batch_size)

"""

隐藏层:线性变换 + tanh激活

hidden = torch.tanh(self.W1 @ x + self.b1)

输出层:线性变换

outp = self.W2 @ hidden + self.b2

return outp

def zero_grad(self):

"""清空所有参数的梯度"""

for param in self.parameters:

if param.grad is not None:

param.grad.zero_()

def update_parameters(self, lr):

"""

使用梯度下降更新参数

Args:

lr: 学习率

"""

with torch.no_grad(): # 禁用梯度追踪,仅更新数值

for param in self.parameters:

param.data -= param.grad * lr

==================== 2. 定义损失函数 ====================

def mse_loss(y_true, y_pred):

"""

均方误差损失函数 (Mean Squared Error)

"""

return ((y_true - y_pred) ** 2).mean()

==================== 3. 准备数据 ====================

def generate_data():

"""

生成训练数据:sin函数在[-5, 5]区间的采样

"""

生成20个训练样本

x_train = torch.linspace(-5, 5, 20).reshape(1, -1) # 形状: (1, 20)

y_train = torch.sin(x_train) # 形状: (1, 20)

生成100个测试样本(用于绘制平滑曲线)

x_test = torch.linspace(-5, 5, 100).reshape(1, -1) # 形状: (1, 100)

return x_train, y_train, x_test

==================== 4. 训练流程 ====================

def train(model, x_train, y_train, epochs=5000, lr=0.01, print_interval=500):

"""

训练MLP模型

"""

print("开始训练...")

print("=" * 50)

for epoch in range(epochs):

前向传播

y_pred = model.forward(x_train)

计算损失

loss = mse_loss(y_train, y_pred)

反向传播

loss.backward()

更新参数

model.update_parameters(lr)

清空梯度

model.zero_grad()

打印训练进度

if epoch % print_interval == 0 or epoch == epochs - 1:

print(f"Epoch [{epoch:>4d}/{epochs}], Loss: {loss.item():.6f}")

print("=" * 50)

print("训练完成!")

==================== 5. 可视化结果 ====================

def visualize(model, x_train, y_train, x_test):

"""

可视化训练结果

"""

切换到评估模式(禁用梯度计算)

with torch.no_grad():

y_pred = model.forward(x_test).detach().numpy()

转换数据用于绘图

x_train_np = x_train.numpy().flatten()

y_train_np = y_train.numpy().flatten()

x_test_np = x_test.numpy().flatten()

绘制结果

plt.figure(figsize=(10, 6))

plt.scatter(x_train_np, y_train_np, color='red', label='训练数据', s=50, zorder=5)

plt.plot(x_test_np, y_pred.flatten(), color='blue', linewidth=2, label='MLP拟合曲线')

plt.plot(x_test_np, torch.sin(x_test).numpy().flatten(), '--', color='green',

linewidth=1.5, label='真实sin函数', alpha=0.7)

plt.xlabel('x', fontsize=12)

plt.ylabel('y', fontsize=12)

plt.title('MLP拟合sin函数', fontsize=14)

plt.legend(fontsize=10)

plt.grid(True, alpha=0.3)

plt.tight_layout()

plt.show()

==================== 6. 打印模型参数 ====================

def print_model_params(model):

"""打印模型训练后的参数"""

print("\n模型参数:")

print("-" * 50)

print(f"W1 (输入层->隐藏层权重):\n{model.W1.data}\n")

print(f"b1 (隐藏层偏置):\n{model.b1.data}\n")

print(f"W2 (隐藏层->输出层权重):\n{model.W2.data}\n")

print(f"b2 (输出层偏置):\n{model.b2.data}")

print("-" * 50)

==================== 主程序入口 ====================

if name == "main":

1. 生成数据

x_train, y_train, x_test = generate_data()

2. 创建模型 (输入1维, 隐藏层4个神经元, 输出1维)

model = MLP(input_size=1, hidden_size=100, output_size=1)

3. 训练模型

train(model, x_train, y_train, epochs=5000, lr=0.01)

4. 可视化结果

#visualize(model, x_train, y_train, x_test)

5. 打印最终参数

#print_model_params(model)

x_train = torch.tensor([[4.1]]) #torch.linspace(-5, 5, 20).reshape(1, -1) # 形状: (1, 20)

y_train = torch.sin(x_train)

print(y_train.item())

outp=model.forward(x_train)

print(outp.item())

相关推荐
暮冬-  Gentle°2 小时前
C++中的命令模式实战
开发语言·c++·算法
勾股导航2 小时前
大模型Skill
人工智能·python·机器学习
2501_945423543 小时前
Django全栈开发入门:构建一个博客系统
jvm·数据库·python
FreakStudio5 小时前
保姆级 uPyPi 教程|从 0 到 1:MicroPython 驱动包一键安装 + 分享全攻略
python·嵌入式·电子diy
Volunteer Technology5 小时前
架构面试题(一)
开发语言·架构·php
清水白石0085 小时前
Python 对象序列化深度解析:pickle、JSON 与自定义协议的取舍之道
开发语言·python·json
2401_876907525 小时前
Python机器学习实践指南
开发语言·python·机器学习
努力中的编程者5 小时前
栈和队列(C语言底层实现环形队列)
c语言·开发语言
张张123y5 小时前
RAG从0到1学习:技术架构、项目实践与面试指南
人工智能·python·学习·面试·架构·langchain·transformer