[paddle] 非线性拟合问题的训练

利用paddlepaddle建立神经网络,模拟有限个数据的非线性拟合

本文仍然考虑 f ( x ) = sin ⁡ ( x ) x f(x)=\frac{\sin(x)}{x} f(x)=xsin(x) 函数在区间 -10,10 上固定数据的拟合。

python 复制代码
import paddle
import paddle.nn as nn
import numpy as np
import matplotlib.pyplot as plt

# 设置随机种子以确保结果的可重复性
paddle.seed(1)

# 生成数据集
x_data = (np.random.rand(500) * 20 - 10).astype('float32')  # 生成500个随机x值,范围在-10到10之间
y_data = np.sin(x_data) / x_data  # 生成y值
y_data = y_data.reshape(-1, 1)  # 将y_data转换为二维数组

# 定义模型,一个具有2个隐藏层的多层感知器
class MyModel(nn.Layer):
    def __init__(self):
        super(MyModel, self).__init__()
        self.hidden1 = nn.Linear(in_features=1, out_features=50)
        self.bn = nn.BatchNorm1D(num_features=50)
        self.hidden2 = nn.Linear(in_features=50, out_features=1)

    def forward(self, x):
        x = paddle.tanh(self.hidden1(x))
        x = self.bn(x)
        x = self.hidden2(x)
        return x

model = MyModel()

# 定义损失函数
loss_fn = nn.MSELoss()

# 设置优化器
optimizer = paddle.optimizer.Adam(learning_rate=0.01, parameters=model.parameters())

# 训练数据
train_data = paddle.to_tensor(x_data).unsqueeze(-1), paddle.to_tensor(y_data)

# 训练模型
epochs = 1000
for epoch in range(1, epochs + 1):
    loss = loss_fn(model(train_data[0]), train_data[1])
    loss.backward()
    optimizer.step()
    optimizer.clear_grad()
    if epoch % 100 == 0:
        print(f'Epoch {epoch}: Loss = {loss.numpy()}')

# 使用训练好的模型进行预测
y_pred = model(train_data[0]).numpy()

# 可视化结果
plt.scatter(x_data, y_data, label='True')
plt.scatter(x_data, y_pred, label='Predicted')
plt.legend()
plt.show()
相关推荐
如烟花的信页2 分钟前
易盾滑块逆向分析
javascript·爬虫·python·js逆向
Hali_Botebie6 分钟前
为什么静态3DGS+轨迹回放,可以通过强化学习训练端到端自动驾驶?
人工智能·机器学习·自动驾驶
常常有13 分钟前
Redis:哨兵模式 (Sentinel)
redis·python·sentinel
计算机安禾13 分钟前
【算法分析与设计】第44篇:随机化复杂度类:RP、BPP与去随机化猜想
java·数据结构·数据库·算法·机器学习
丨白色风车丨18 分钟前
机器学习数据预处理:6 种缺失值填充方法完整实现(CCA / 均值 / 中位数 / 众数 / 线性回归 / 随机森林)
机器学习·均值算法·线性回归
程序员三藏20 分钟前
接口测试用例设计
自动化测试·软件测试·python·测试工具·职场和发展·测试用例·接口测试
再玩一会儿看代码25 分钟前
Java抽象类和接口区别_场景理解
java·开发语言·经验分享·笔记·python
大蚂蚁2号26 分钟前
Python迭代器与生成器深度剖析:从底层协议到工程实战
python
专注搞钱28 分钟前
AI编程实战:我用Python+LangChain搭建了一个半导体FAB智能运维Agent
python·langchain·ai编程
财经资讯数据_灵砚智能34 分钟前
基于全球经济类多源新闻的NLP情感分析与数据可视化(夜间-次晨)2026年6月3日
大数据·人工智能·python·信息可视化·自然语言处理·灵砚智能