TensorFlow 2 来训练一个线性回归模型

本节将通过一个简单的示例,带领大家了解如何使用 TensorFlow 2 来训练一个线性回归模型。这个例子将帮助大家掌握如何从数据处理、模型构建、训练到评估等步骤,逐步实现一个基础的机器学习任务。

下面是代码的详细讲解。

python 复制代码
import tensorflow as tf
import pandas as pd

首先,我们导入了 TensorFlow 和 Pandas 库。TensorFlow 用于构建和训练我们的机器学习模型,Pandas 用于处理和加载数据。Pandas 是非常强大的数据处理工具,它能够帮助我们方便地加载 CSV 文件并进行数据预处理。

python 复制代码
# 读取数据
data = pd.read_csv('./data/line_fit_data.csv').values

这里我们使用 Pandas 的 read_csv 函数来读取存储在 line_fit_data.csv 文件中的数据。.values 方法将数据转换成一个 NumPy 数组,方便后续处理。

python 复制代码
# 划分训练集和测试集
x = data[:-10, 0]
y = data[:-10, 1]
x_test = data[-10:, 0]
y_test = data[-10:, 1]

接着,我们将数据分成了训练集和测试集。x 是输入数据(特征),y 是目标数据(标签)。我们使用数据的前面部分(除了最后 10 个数据点)作为训练集,而剩下的 10 个数据点用作测试集。

python 复制代码
# 构建Sequential网络
model_net = tf.keras.models.Sequential()  # 实例化网络
model_net.add(tf.keras.layers.Dense(1, input_shape=(1, )))  # 添加全连接层
print(model_net.summary())

我们开始构建模型,首先实例化了一个 Sequential 模型。Sequential 是 TensorFlow 中最常用的模型类型,表示模型由一层一层的神经网络层按顺序堆叠而成。

这里我们添加了一个全连接层(Dense),该层只有一个神经元,输入数据的形状为 (1,),意味着每个输入数据只有一个特征。全连接层将把输入数据映射到一个输出值,适用于线性回归问题。

model_net.summary() 会打印出模型的简要信息,包括每一层的类型、输出形状、参数数量等。

python 复制代码
# 构建损失函数
model_net.compile(loss='mse', optimizer=tf.keras.optimizers.SGD(learning_rate=0.5))

接下来,我们编译模型,指定损失函数和优化器。在这个例子中,我们选择了 均方误差(MSE, Mean Squared Error) 作为损失函数,这是回归问题中常用的损失函数,它计算预测值和真实值之间的误差平方和的平均值。优化器使用的是 随机梯度下降(SGD, Stochastic Gradient Descent),并设定学习率为 0.5。

python 复制代码
# 模型训练
model_net.fit(x, y, verbose=1, epochs=20, validation_split=0.2)

fit 函数用于训练模型。我们传入了训练数据 x 和标签 y,并设定训练时要进行 20 个周期(epochs)。validation_split=0.2 表示在训练过程中,自动将 20% 的数据用于验证,以监控模型的训练效果。

通过 verbose=1,我们可以在训练过程中看到每个 epoch 的训练进度。

python 复制代码
pre = model_net.predict(x_test)

训练完成后,我们使用 predict 函数对测试集 x_test 进行预测,得到预测值 pre

python 复制代码
# 利用均方误差进行模型评价
y_test = pd.DataFrame(y_test)
pre = pd.DataFrame(pre)
mse = (sum(y_test - pre) ** 2) / 10
print('均方误差为:', mse)

最后,我们通过计算均方误差来评估模型的表现。首先,将真实的测试标签 y_test 和预测结果 pre 转换为 Pandas 的 DataFrame 格式,便于后续的计算。然后,我们计算每个测试点的误差(真实值与预测值的差),并求得它们的平方和的平均值,这就是均方误差(MSE)。

均方误差越小,说明模型的预测结果越接近真实值,模型的表现越好。

关键点总结:

  • 数据读取与预处理:通过 Pandas 加载数据并划分为训练集和测试集。
  • 模型构建 :使用 TensorFlow 的 Sequential 模型和全连接层来构建一个简单的线性回归模型。
  • 编译与训练:选择合适的损失函数(MSE)和优化器(SGD),并通过训练数据对模型进行训练。
  • 评估与预测:通过计算均方误差来评估模型在测试集上的表现。

总结

这个简单的例子展示了如何使用 TensorFlow 2 来训练一个线性回归模型。虽然这个例子非常基础,但它涵盖了机器学习模型的主要步骤:数据加载、模型构建、训练和评估。学生通过这个例子可以初步理解如何使用 TensorFlow 进行机器学习任务,并为更复杂的模型训练奠定基础。

相关推荐
赋创小助手1 分钟前
NVIDIA B200 GPU 技术解读:Blackwell 架构带来了哪些真实变化?
运维·服务器·人工智能·深度学习·计算机视觉·自然语言处理·架构
newbiai9 分钟前
小白用的AI视频创作软件哪个功能全?
人工智能·python
love530love11 分钟前
【实战经验】解决ComfyUI加载报错:PytorchStreamReader failed reading zip archive: failed finding central directory
人工智能·windows·python·ai作画·aigc·comfyui·攻关
玄同76511 分钟前
LangChain 1.0 框架全面解析:从架构到实践
人工智能·深度学习·自然语言处理·中间件·架构·langchain·rag
资深数据库专家12 分钟前
人大金仓(Kingbase)在 MPP(大规模并行处理)线程设置建议
人工智能·微信公众平台·新浪微博·微信开放平台·人大金仓
AAD5558889912 分钟前
工业纸板加工过程中的检测与识别_Cornernet_Hourglass104模型应用
人工智能·计算机视觉·目标跟踪
觅特科技-互站12 分钟前
陌讯AI视觉赋能政企园区:国家级高新区实现人流超限自动广播与工单闭环
人工智能·排序算法·线性回归
Kevin-anycode19 分钟前
国内安装Claude Code即配置国内大模型(win11环境)
人工智能·语言模型
央链知播20 分钟前
中国移联AI元宇宙产业委与黎阳之光共识“链上诗路-时光门”战略合作
人工智能·业界资讯·数字孪生
格林威20 分钟前
Baumer相机电机转子偏心检测:实现动平衡预判的 5 个核心方法,附 OpenCV+Halcon 实战代码!
人工智能·深度学习·opencv·机器学习·计算机视觉·视觉检测·工业相机