使用TensorFlow实现逻辑回归:从训练到模型保存与加载

1. 引入必要的库

首先,需要引入必要的库。TensorFlow用于构建和训练模型,pandas和numpy用于数据处理,matplotlib用于结果的可视化。

python 复制代码
import tensorflow as tf
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense
from tensorflow.keras.optimizers import SGD
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt

2. 加载自定义数据集

假设有一个CSV文件custom_dataset.csv,其中包含特征(自变量)和标签(因变量)。使用pandas来加载数据,并进行预处理。

python 复制代码
# 加载自定义数据集
data = pd.read_csv('custom_dataset.csv')

# 假设数据集中有多列特征和一个二分类标签
X = data.iloc[:, :-1].values.astype(np.float32)  # 特征
y = data.iloc[:, -1].values.astype(np.float32)   # 标签

# 将标签转换为0和1
y = np.where(y == 'positive', 1, 0)

3. 构建逻辑回归模型

使用TensorFlow的Keras接口来构建逻辑回归模型。

python 复制代码
# 构建逻辑回归模型
model = Sequential([
    Dense(1, activation='sigmoid', input_shape=(X.shape[1],))
])

# 编译模型
model.compile(optimizer=SGD(learning_rate=0.01), loss='binary_crossentropy', metrics=['accuracy'])

4. 训练模型

使用自定义数据集训练模型。

python 复制代码
# 训练模型
history = model.fit(X, y, epochs=100, batch_size=32, verbose=1)

5. 保存模型

训练完成后,可以使用TensorFlow的save方法保存模型。

python 复制代码
# 保存模型
model.save('logistic_regression_model.h5')

6. 加载模型并进行预测

在需要时,可以使用TensorFlow的load_model方法加载模型,并进行预测。

python 复制代码
# 加载模型
from tensorflow.keras.models import load_model

loaded_model = load_model('logistic_regression_model.h5')

# 进行预测
predictions = loaded_model.predict(X[:5])
predicted_labels = (predictions > 0.5).astype(int)

print("Predicted Labels:", predicted_labels.flatten())

7. 结果可视化

可以绘制训练过程中的损失和准确率变化曲线,以帮助理解模型的性能。

python 复制代码
# 绘制训练和验证的损失曲线
plt.plot(history.history['loss'], label='Loss')
plt.title('Model Loss')
plt.xlabel('Epochs')
plt.ylabel('Loss')
plt.legend()
plt.show()

# 绘制训练和验证的准确率曲线
plt.plot(history.history['accuracy'], label='Accuracy')
plt.title('Model Accuracy')
plt.xlabel('Epochs')
plt.ylabel('Accuracy')
plt.legend()
plt.show()
相关推荐
Dxy123931021614 小时前
如何给AI提问:让机器高效理解你的需求
人工智能
少林码僧14 小时前
2.31 机器学习神器项目实战:如何在真实项目中应用XGBoost等算法
人工智能·python·算法·机器学习·ai·数据挖掘
钱彬 (Qian Bin)14 小时前
项目实践15—全球证件智能识别系统(切换为Qwen3-VL-8B-Instruct图文多模态大模型)
人工智能·算法·机器学习·多模态·全球证件识别
没学上了15 小时前
CNNMNIST
人工智能·深度学习
宝贝儿好15 小时前
【强化学习】第六章:无模型控制:在轨MC控制、在轨时序差分学习(Sarsa)、离轨学习(Q-learning)
人工智能·python·深度学习·学习·机器学习·机器人
智驱力人工智能15 小时前
守护流动的规则 基于视觉分析的穿越导流线区检测技术工程实践 交通路口导流区穿越实时预警技术 智慧交通部署指南
人工智能·opencv·安全·目标检测·计算机视觉·cnn·边缘计算
AI产品备案15 小时前
生成式人工智能大模型备案制度与发展要求
人工智能·深度学习·大模型备案·算法备案·大模型登记
AC赳赳老秦16 小时前
DeepSeek 私有化部署避坑指南:敏感数据本地化处理与合规性检测详解
大数据·开发语言·数据库·人工智能·自动化·php·deepseek
wm104316 小时前
机器学习之线性回归
人工智能·机器学习·线性回归
通义灵码16 小时前
Qoder 支持通过 DeepLink 添加 MCP Server
人工智能·github·mcp