训练一个线性模型

复制代码
import tensorflow as tf
import pandas as pd

# 读取数据
data = pd.read_csv('../data/line_fit_data.csv').values
# 划分训练集和测试集
x = data[:-10, 0]   #第一列排除后10行
y = data[:-10, 1]   #第二列排除后10行
x_test = data[-10:, 0] #第一列后10行
y_test = data[-10:, 1] #第二列后10行

# 构建Sequential网络
model_net = tf.keras.models.Sequential()  # 实例化网络
model_net.add(tf.keras.layers.Dense(1, input_shape=(1, )))  # 添加全连接层
print(model_net.summary())

# 构建损失函数
model_net.compile(loss='mse', optimizer=tf.keras.optimizers.SGD(learning_rate=0.5))

# 模型训练
model_net.fit(x, y, verbose=1, epochs=20, validation_split=0.2)
pre = model_net.predict(x_test)

# 利用均方误差进行模型评价
y_test = pd.DataFrame(y_test)
pre = pd.DataFrame(pre)
mse = (sum(y_test - pre) ** 2) / 10
print('均方误差为:', mse)

总结

model_net.add() :向模型中添加层,第一层需指定 `input_shape` |

Dense(units=1) :定义全连接层 ,`units` 决定输出维度 |

`input_shape=(1,)` : 指定输入数据的形状 ,仅第一层需要,元组格式 |

model.summary(): 查看模型结构和参数数量

**`units=1`**:输出维度为1(即该层只有1个神经元)。

  • **`input_shape=(1,)`**:指定输入数据的形状为 `(1,)`(即每个样本是1个数值)。

**1. `model_net.compile()`:配置模型训练参数**

  • **作用**:定义模型的损失函数、优化器和评估指标。

  • **参数解析**:

  • **`loss='mse'`**:使用均方误差(Mean Squared Error)作为损失函数,适用于**回归任务**(如预测房价、温度等连续值)。

  • **`optimizer=tf.keras.optimizers.SGD(learning_rate=0.5)`**:

  • 优化器:随机梯度下降(Stochastic Gradient Descent, SGD)。

  • 学习率:`0.5`(较高的学习率,可能导致训练不稳定,需根据任务调整)。

  • **未显式指定 `metrics`**:如需要监控准确率等指标,可添加 `metrics=['mae']`(平均绝对误差)。


**2. `model_net.fit()`:模型训练**

  • **作用**:用训练数据拟合模型,更新权重参数。

  • **参数解析**:

  • **`x, y`**:输入数据和标签(假设 `x` 是特征,`y` 是目标值)。

  • **`verbose=1`**:显示训练进度条(`0`=不显示,`1`=显示进度条,`2`=仅显示轮次结果)。

  • **`epochs=20`**:训练20轮(所有数据完整遍历一次为一轮)。

  • **`validation_split=0.2`**:从训练数据中自动划分20%作为验证集(例如,若 `x` 有100个样本,则80个用于训练,20个用于验证)。

**`pd.DataFrame()`** 是 Pandas 库中用于创建或转换数据为 **二维表格结构**(DataFrame)的函数。

  • 这行代码的目的是将 `y_test`(可能是列表、NumPy 数组或其他格式)转换为 DataFrame,以便后续使用 Pandas 的功能(如数据操作、保存到文件、与其他 DataFrame 合并等)。
相关推荐
缘友一世33 分钟前
PyTorch可视化工具——使用Visdom进行深度学习可视化
人工智能·pytorch·深度学习
熊猫在哪3 小时前
野火鲁班猫(arrch64架构debian)从零实现用MobileFaceNet算法进行实时人脸识别(四)安装RKNN Toolkit2
人工智能·python·嵌入式硬件·深度学习·神经网络·目标检测·机器学习
老唐7773 小时前
PyTorch的基本操作
人工智能·pytorch·python·深度学习·神经网络·机器学习·计算机视觉
Panesle4 小时前
谷歌medgemma-27b-text-it医疗大模型论文速读:多语言大型语言模型医学问答基准测试MedExpQA
人工智能·深度学习·语言模型·自然语言处理·开源·大模型
攻城狮7号4 小时前
Strands Agents:AWS开源Agent框架的技术与应用全景
人工智能·深度学习·云计算·aws·智能体·strands agents
WSSWWWSSW4 小时前
解释:神经网络
人工智能·深度学习·神经网络
九章云极AladdinEdu5 小时前
光子神经网络加速器编程范式研究:光子矩阵乘法的误差传播模型构建
开发语言·人工智能·深度学习·神经网络·矩阵·负载均衡·transformer
我不是小upper5 小时前
深度学习之-目标检测算法汇总(超全面)
深度学习·算法·目标检测
Blossom.1186 小时前
Web3.0:下一代互联网的变革与机遇
人工智能·深度学习·物联网·机器学习·web3·区块链·边缘计算
Studying 开龙wu7 小时前
深度学习模型部署:使用Flask将图像分类(5类)模型部署在服务器上,然后在本地GUI调用。(全网模型部署项目步骤详解:从模型训练到部署再到调用)
深度学习·分类·flask