七。自定义数据集 使用tensorflow框架实现逻辑回归并保存模型,然后保存模型后再加载模型进行预测

import tensorflow as tf

import numpy as np

自定义数据集类

class CustomDataset(tf.data.Dataset):

def init(self, x_data, y_data):

self.x_data = tf.convert_to_tensor(x_data, dtype=tf.float32)

self.y_data = tf.convert_to_tensor(y_data, dtype=tf.float32)

def iter(self):

for i in range(len(self.x_data)):

yield (self.x_data[i], self.y_data[i])

逻辑回归模型

class LogisticRegressionModel(tf.keras.Model):

def init(self, input_dim):

super(LogisticRegressionModel, self).init()

self.linear = tf.keras.layers.Dense(1, input_shape=(input_dim,), activation='sigmoid')

def call(self, x):

return self.linear(x)

创建数据集

x_data = np.array([[1], [2], [3], [4], [5]], dtype=np.float32)

y_data = np.array([[0], [0], [1], [1], [1]], dtype=np.float32)

dataset = CustomDataset(x_data, y_data)

创建数据加载器

dataloader = dataset.batch(2).shuffle(100).repeat()

创建模型、损失函数和优化器

model = LogisticRegressionModel(input_dim=1)

loss_object = tf.keras.losses.BinaryCrossentropy()

optimizer = tf.keras.optimizers.SGD(learning_rate=0.01)

训练模型

epochs = 100

for epoch in range(epochs):

for x_batch, y_batch in dataloader:

with tf.GradientTape() as tape:

predictions = model(x_batch)

loss = loss_object(y_batch, predictions)

gradients = tape.gradient(loss, model.trainable_variables)

optimizer.apply_gradients(zip(gradients, model.trainable_variables))

if (epoch+1) % 10 == 0:

print(f'Epoch [{epoch+1}/{epochs}], Loss: {loss.numpy():.4f}')

保存模型

model.save('logistic_regression_model.h5')

加载模型

model = tf.keras.models.load_model('logistic_regression_model.h5')

进行预测

x_test = np.array([[6], [7], [8]], dtype=np.float32)

y_pred = model.predict(x_test)

print('预测值:', y_pred)

相关推荐
zzzzls~7 小时前
Python 工程化: 用 Copier 打造“自我进化“的项目脚手架
开发语言·python·copier
韶博雅7 小时前
emcc24ai
开发语言·数据库·python
He少年7 小时前
【基础知识、Skill、Rules和MCP案例介绍】
java·前端·python
AI_Claude_code7 小时前
ZLibrary访问困境方案四:利用Cloudflare Workers等边缘计算实现访问
javascript·人工智能·爬虫·python·网络爬虫·边缘计算·爬山算法
jedi-knight8 小时前
AGI时代下的青年教师与学术民主化
人工智能·python·agi
迷藏4948 小时前
**eBPF实战进阶:从零构建网络流量监控与过滤系统**在现代云原生架构中,**网络可观测性**和**安全隔离**已成为
java·网络·python·云原生·架构
迷藏4948 小时前
**发散创新:基于Solid协议的Web3.0去中心化身份认证系统实战解析**在Web3.
java·python·web3·去中心化·区块链
weixin_156241575768 小时前
基于YOLOv8深度学习花卉识别系统摄像头实时图片文件夹多图片等另有其他的识别系统可二开
大数据·人工智能·python·深度学习·yolo
AI_Claude_code8 小时前
ZLibrary访问困境方案三:Web代理与轻量级转发服务的搭建与优化
爬虫·python·web安全·搜索引擎·网络安全·web3·httpx
小陈工8 小时前
2026年4月7日技术资讯洞察:下一代数据库融合、AI基础设施竞赛与异步编程实战
开发语言·前端·数据库·人工智能·python