FNN sin predict

import torch

import matplotlib.pyplot as plt

设置随机种子,保证结果可复现

torch.manual_seed(99)

==================== 1. 定义MLP模型 ====================

class MLP:

"""

多层感知机 (Multilayer Perceptron)

结构: 输入层 -> 隐藏层(tanh激活) -> 输出层(线性)

"""

def init(self, input_size, hidden_size, output_size):

"""

初始化网络参数

Args:

input_size: 输入特征维度

hidden_size: 隐藏层神经元数量

output_size: 输出维度

"""

初始化权重和偏置(使用标准正态分布)

self.W1 = torch.randn(hidden_size, input_size, requires_grad=True) # 输入层->隐藏层权重

self.b1 = torch.randn(hidden_size, 1, requires_grad=True) # 隐藏层偏置

self.W2 = torch.randn(output_size, hidden_size, requires_grad=True) # 隐藏层->输出层权重

self.b2 = torch.randn(output_size, 1, requires_grad=True) # 输出层偏置

存储参数列表,方便后续操作

self.parameters = self.W1, self.b1, self.W2, self.b2

def forward(self, x):

"""

前向传播

Args:

x: 输入数据,形状为 (input_size, batch_size)

Returns:

网络输出,形状为 (output_size, batch_size)

"""

隐藏层:线性变换 + tanh激活

hidden = torch.tanh(self.W1 @ x + self.b1)

输出层:线性变换

outp = self.W2 @ hidden + self.b2

return outp

def zero_grad(self):

"""清空所有参数的梯度"""

for param in self.parameters:

if param.grad is not None:

param.grad.zero_()

def update_parameters(self, lr):

"""

使用梯度下降更新参数

Args:

lr: 学习率

"""

with torch.no_grad(): # 禁用梯度追踪,仅更新数值

for param in self.parameters:

param.data -= param.grad * lr

==================== 2. 定义损失函数 ====================

def mse_loss(y_true, y_pred):

"""

均方误差损失函数 (Mean Squared Error)

"""

return ((y_true - y_pred) ** 2).mean()

==================== 3. 准备数据 ====================

def generate_data():

"""

生成训练数据:sin函数在-5, 5区间的采样

"""

生成20个训练样本

x_train = torch.linspace(-5, 5, 20).reshape(1, -1) # 形状: (1, 20)

y_train = torch.sin(x_train) # 形状: (1, 20)

生成100个测试样本(用于绘制平滑曲线)

x_test = torch.linspace(-5, 5, 100).reshape(1, -1) # 形状: (1, 100)

return x_train, y_train, x_test

==================== 4. 训练流程 ====================

def train(model, x_train, y_train, epochs=5000, lr=0.01, print_interval=500):

"""

训练MLP模型

"""

print("开始训练...")

print("=" * 50)

for epoch in range(epochs):

前向传播

y_pred = model.forward(x_train)

计算损失

loss = mse_loss(y_train, y_pred)

反向传播

loss.backward()

更新参数

model.update_parameters(lr)

清空梯度

model.zero_grad()

打印训练进度

if epoch % print_interval == 0 or epoch == epochs - 1:

print(f"Epoch {epoch:\>4d}/{epochs}, Loss: {loss.item():.6f}")

print("=" * 50)

print("训练完成!")

==================== 5. 可视化结果 ====================

def visualize(model, x_train, y_train, x_test):

"""

可视化训练结果

"""

切换到评估模式(禁用梯度计算)

with torch.no_grad():

y_pred = model.forward(x_test).detach().numpy()

转换数据用于绘图

x_train_np = x_train.numpy().flatten()

y_train_np = y_train.numpy().flatten()

x_test_np = x_test.numpy().flatten()

绘制结果

plt.figure(figsize=(10, 6))

plt.scatter(x_train_np, y_train_np, color='red', label='训练数据', s=50, zorder=5)

plt.plot(x_test_np, y_pred.flatten(), color='blue', linewidth=2, label='MLP拟合曲线')

plt.plot(x_test_np, torch.sin(x_test).numpy().flatten(), '--', color='green',

linewidth=1.5, label='真实sin函数', alpha=0.7)

plt.xlabel('x', fontsize=12)

plt.ylabel('y', fontsize=12)

plt.title('MLP拟合sin函数', fontsize=14)

plt.legend(fontsize=10)

plt.grid(True, alpha=0.3)

plt.tight_layout()

plt.show()

==================== 6. 打印模型参数 ====================

def print_model_params(model):

"""打印模型训练后的参数"""

print("\n模型参数:")

print("-" * 50)

print(f"W1 (输入层->隐藏层权重):\n{model.W1.data}\n")

print(f"b1 (隐藏层偏置):\n{model.b1.data}\n")

print(f"W2 (隐藏层->输出层权重):\n{model.W2.data}\n")

print(f"b2 (输出层偏置):\n{model.b2.data}")

print("-" * 50)

==================== 主程序入口 ====================

if name == "main":

1. 生成数据

x_train, y_train, x_test = generate_data()

2. 创建模型 (输入1维, 隐藏层4个神经元, 输出1维)

model = MLP(input_size=1, hidden_size=100, output_size=1)

3. 训练模型

train(model, x_train, y_train, epochs=5000, lr=0.01)

4. 可视化结果

#visualize(model, x_train, y_train, x_test)

5. 打印最终参数

#print_model_params(model)

x_train = torch.tensor(\[4.1]) #torch.linspace(-5, 5, 20).reshape(1, -1) # 形状: (1, 20)

y_train = torch.sin(x_train)

print(y_train.item())

outp=model.forward(x_train)

print(outp.item())

相关推荐
早起CaiCai16 分钟前
【Pytorch 实践1】手写数字
人工智能·pytorch·python
周杰伦的稻香23 分钟前
Go + Redis:本地部署高性能图片主色调提取服务
开发语言·redis·golang
吴梓穆28 分钟前
Python 语法基础 函数
开发语言·python
不负岁月无痕31 分钟前
C++ 模板核心内容与高频面试题汇总
java·开发语言·c++
Kobebryant-Manba35 分钟前
学习文本处理
开发语言·python
m0_6174939440 分钟前
PaddleOCR报错:OneDnnContext does not have the input Filter 解决方案汇总
python
李可以量化42 分钟前
量化迅投 QMT vs 聚宽 (JoinQuant)全面分析
python·量化·qmt·ptrade·聚宽
福大大架构师每日一题1 小时前
2026年6月TIOBE编程语言排行榜,Go语言排名第13,Rust语言排名12。关于Rust已进入平台期的报道似乎为时过早。
开发语言·golang·rust
无限进步_1 小时前
从零实现一个迷你Shell——深入理解Linux命令行解释器
linux·运维·服务器·开发语言·c++·chrome
旅僧1 小时前
运行UMI镜像
python