七。自定义数据集 使用tensorflow框架实现逻辑回归并保存模型,然后保存模型后再加载模型进行预测

import tensorflow as tf

import numpy as np

自定义数据集类

class CustomDataset(tf.data.Dataset):

def init(self, x_data, y_data):

self.x_data = tf.convert_to_tensor(x_data, dtype=tf.float32)

self.y_data = tf.convert_to_tensor(y_data, dtype=tf.float32)

def iter(self):

for i in range(len(self.x_data)):

yield (self.x_datai, self.y_datai)

逻辑回归模型

class LogisticRegressionModel(tf.keras.Model):

def init(self, input_dim):

super(LogisticRegressionModel, self).init()

self.linear = tf.keras.layers.Dense(1, input_shape=(input_dim,), activation='sigmoid')

def call(self, x):

return self.linear(x)

创建数据集

x_data = np.array(\[1, 2, 3, 4, 5], dtype=np.float32)

y_data = np.array(\[0, 0, 1, 1, 1], dtype=np.float32)

dataset = CustomDataset(x_data, y_data)

创建数据加载器

dataloader = dataset.batch(2).shuffle(100).repeat()

创建模型、损失函数和优化器

model = LogisticRegressionModel(input_dim=1)

loss_object = tf.keras.losses.BinaryCrossentropy()

optimizer = tf.keras.optimizers.SGD(learning_rate=0.01)

训练模型

epochs = 100

for epoch in range(epochs):

for x_batch, y_batch in dataloader:

with tf.GradientTape() as tape:

predictions = model(x_batch)

loss = loss_object(y_batch, predictions)

gradients = tape.gradient(loss, model.trainable_variables)

optimizer.apply_gradients(zip(gradients, model.trainable_variables))

if (epoch+1) % 10 == 0:

print(f'Epoch {epoch+1}/{epochs}, Loss: {loss.numpy():.4f}')

保存模型

model.save('logistic_regression_model.h5')

加载模型

model = tf.keras.models.load_model('logistic_regression_model.h5')

进行预测

x_test = np.array(\[6, 7, 8], dtype=np.float32)

y_pred = model.predict(x_test)

print('预测值:', y_pred)

相关推荐
珺毅同学8 小时前
YOLO生成预测json标签迁移问题
python·yolo·json
骑士雄师8 小时前
18.4 长期记忆可修改版
python
~小先生~8 小时前
Python从入门到放弃(一)
开发语言·python
天佑木枫8 小时前
第2天:变量与数据类型 —— 让程序记住信息
python
Dust-Chasing9 小时前
Claude Code源码剖析 - Claude Code 上下文压缩机制
人工智能·python·ai
Cloud_Shy61810 小时前
解读《Effective Python 3rd Edition》:从练气到老魔(第五章 Item 33 - 35)
开发语言·人工智能·笔记·python·学习方法
abcy07121311 小时前
python pandas csv异步后台清洗前端优先返回成功信息
前端·python·pandas
颜酱11 小时前
LangChain使用RAG 入门:让大模型读懂你的私有文档
python·langchain
天天进步201512 小时前
Python全栈项目--校园智能宿舍管理系统
开发语言·python
测试员周周12 小时前
【AI测试智能体-面试】AI测试面试60题(附回答思路)
人工智能·python·功能测试·测试工具·单元测试·自动化·测试用例