使用TensorFlow实现逻辑回归:从训练到模型保存与加载

1. 引入必要的库

首先,需要引入必要的库。TensorFlow用于构建和训练模型,pandas和numpy用于数据处理,matplotlib用于结果的可视化。

python 复制代码
import tensorflow as tf
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense
from tensorflow.keras.optimizers import SGD
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt

2. 加载自定义数据集

假设有一个CSV文件custom_dataset.csv,其中包含特征(自变量)和标签(因变量)。使用pandas来加载数据,并进行预处理。

python 复制代码
# 加载自定义数据集
data = pd.read_csv('custom_dataset.csv')

# 假设数据集中有多列特征和一个二分类标签
X = data.iloc[:, :-1].values.astype(np.float32)  # 特征
y = data.iloc[:, -1].values.astype(np.float32)   # 标签

# 将标签转换为0和1
y = np.where(y == 'positive', 1, 0)

3. 构建逻辑回归模型

使用TensorFlow的Keras接口来构建逻辑回归模型。

python 复制代码
# 构建逻辑回归模型
model = Sequential([
    Dense(1, activation='sigmoid', input_shape=(X.shape[1],))
])

# 编译模型
model.compile(optimizer=SGD(learning_rate=0.01), loss='binary_crossentropy', metrics=['accuracy'])

4. 训练模型

使用自定义数据集训练模型。

python 复制代码
# 训练模型
history = model.fit(X, y, epochs=100, batch_size=32, verbose=1)

5. 保存模型

训练完成后,可以使用TensorFlow的save方法保存模型。

python 复制代码
# 保存模型
model.save('logistic_regression_model.h5')

6. 加载模型并进行预测

在需要时,可以使用TensorFlow的load_model方法加载模型,并进行预测。

python 复制代码
# 加载模型
from tensorflow.keras.models import load_model

loaded_model = load_model('logistic_regression_model.h5')

# 进行预测
predictions = loaded_model.predict(X[:5])
predicted_labels = (predictions > 0.5).astype(int)

print("Predicted Labels:", predicted_labels.flatten())

7. 结果可视化

可以绘制训练过程中的损失和准确率变化曲线,以帮助理解模型的性能。

python 复制代码
# 绘制训练和验证的损失曲线
plt.plot(history.history['loss'], label='Loss')
plt.title('Model Loss')
plt.xlabel('Epochs')
plt.ylabel('Loss')
plt.legend()
plt.show()

# 绘制训练和验证的准确率曲线
plt.plot(history.history['accuracy'], label='Accuracy')
plt.title('Model Accuracy')
plt.xlabel('Epochs')
plt.ylabel('Accuracy')
plt.legend()
plt.show()
相关推荐
yueyuebaobaoxinx21 小时前
2025 AI 落地元年:从技术突破到行业重构的实践图景
人工智能·重构
说私域21 小时前
私域整体结构的顶层设计:基于“开源AI智能名片链动2+1模式S2B2C商城小程序”的体系重构
人工智能·小程序·开源
yunyun18863581 天前
AI - 自然语言处理(NLP) - part 1
人工智能·自然语言处理
星期天要睡觉1 天前
计算机视觉(opencv)——疲劳检测
人工智能·opencv·计算机视觉
zxsz_com_cn1 天前
基于AI的设备健康诊断:工业设备智能运维的破局之钥
运维·人工智能
MoRanzhi12031 天前
12. Pandas 数据合并与拼接(concat 与 merge)
数据库·人工智能·python·数学建模·矩阵·数据分析·pandas
杜子不疼.1 天前
【Linux】进程的初步探险:基本概念与基本操作
linux·人工智能·ai
可触的未来,发芽的智生1 天前
触摸未来2025.10.04:当神经网络拥有了内在记忆……
人工智能·python·神经网络·算法·架构
PKNLP1 天前
深度学习之神经网络2(Neural Network)
人工智能·深度学习·神经网络
格林威1 天前
常规的变焦镜头有哪些类型?能做什么?
人工智能·数码相机·opencv·计算机视觉·视觉检测·机器视觉·工业镜头