使用TensorFlow实现逻辑回归:从训练到模型保存与加载

1. 引入必要的库

首先,需要引入必要的库。TensorFlow用于构建和训练模型,pandas和numpy用于数据处理,matplotlib用于结果的可视化。

python 复制代码
import tensorflow as tf
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense
from tensorflow.keras.optimizers import SGD
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt

2. 加载自定义数据集

假设有一个CSV文件custom_dataset.csv,其中包含特征(自变量)和标签(因变量)。使用pandas来加载数据,并进行预处理。

python 复制代码
# 加载自定义数据集
data = pd.read_csv('custom_dataset.csv')

# 假设数据集中有多列特征和一个二分类标签
X = data.iloc[:, :-1].values.astype(np.float32)  # 特征
y = data.iloc[:, -1].values.astype(np.float32)   # 标签

# 将标签转换为0和1
y = np.where(y == 'positive', 1, 0)

3. 构建逻辑回归模型

使用TensorFlow的Keras接口来构建逻辑回归模型。

python 复制代码
# 构建逻辑回归模型
model = Sequential([
    Dense(1, activation='sigmoid', input_shape=(X.shape[1],))
])

# 编译模型
model.compile(optimizer=SGD(learning_rate=0.01), loss='binary_crossentropy', metrics=['accuracy'])

4. 训练模型

使用自定义数据集训练模型。

python 复制代码
# 训练模型
history = model.fit(X, y, epochs=100, batch_size=32, verbose=1)

5. 保存模型

训练完成后,可以使用TensorFlow的save方法保存模型。

python 复制代码
# 保存模型
model.save('logistic_regression_model.h5')

6. 加载模型并进行预测

在需要时,可以使用TensorFlow的load_model方法加载模型,并进行预测。

python 复制代码
# 加载模型
from tensorflow.keras.models import load_model

loaded_model = load_model('logistic_regression_model.h5')

# 进行预测
predictions = loaded_model.predict(X[:5])
predicted_labels = (predictions > 0.5).astype(int)

print("Predicted Labels:", predicted_labels.flatten())

7. 结果可视化

可以绘制训练过程中的损失和准确率变化曲线,以帮助理解模型的性能。

python 复制代码
# 绘制训练和验证的损失曲线
plt.plot(history.history['loss'], label='Loss')
plt.title('Model Loss')
plt.xlabel('Epochs')
plt.ylabel('Loss')
plt.legend()
plt.show()

# 绘制训练和验证的准确率曲线
plt.plot(history.history['accuracy'], label='Accuracy')
plt.title('Model Accuracy')
plt.xlabel('Epochs')
plt.ylabel('Accuracy')
plt.legend()
plt.show()
相关推荐
小鸡吃米…2 分钟前
机器学习 - 贝叶斯定理
人工智能·python·机器学习
esmap4 分钟前
技术解构:ESMAP AI数字孪生赋能传统行业转型的全链路技术方案
人工智能·低代码·ai·架构·编辑器·智慧城市
不懒不懒7 分钟前
【逻辑回归从原理到实战:正则化、参数调优与过拟合处理】
人工智能·算法·机器学习
喜欢吃豆10 分钟前
对象存储架构演进与AI大模型时代的深度融合:从S3基础到万亿参数训练的技术全景
人工智能·架构
ba_pi14 分钟前
每天写点什么2026-02-2(1.5)数字化转型和元宇宙
大数据·人工智能
vlln18 分钟前
【论文速读】MUSE: 层次记忆和自我反思提升的 Agent
人工智能·语言模型·自然语言处理·ai agent
Funny_AI_LAB22 分钟前
RAD基准重新定义多视角异常检测,传统2D方法为何战胜前沿3D与VLM?
人工智能·目标检测·3d·ai
星河队长23 分钟前
人工智能的自我认知
人工智能
无人装备硬件开发爱好者27 分钟前
AI 赋能航天造物:LEAP71 式火箭发动机计算工程软件开发全解析 1
人工智能·商业火箭发动机·增材加工·leap71
数智联AI团队30 分钟前
AI搜索引领行业变革:2023年GEO优化服务市场深度洞察与专业机构选择指南
人工智能