使用用tensorflow实现鸢尾花的分类

python 复制代码
import tensorflow as tf
from tensorflow import keras
import numpy as np
from sklearn import datasets
from sklearn.model_selection import train_test_split

# Load the Iris dataset
iris = datasets.load_iris()
X = iris.data
y = iris.target

# Split the data into training and testing sets
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2)

# Create the model
model = keras.Sequential([
    keras.layers.Dense(10, activation='relu', input_shape=(X.shape[1],)),
    keras.layers.Dense(3, activation='softmax')
])

# Compile the model
model.compile(optimizer='adam',
              loss='sparse_categorical_crossentropy',
              metrics=['accuracy'])

# Train the model
model.fit(X_train, y_train, epochs=100)

# Evaluate the model
test_loss, test_acc = model.evaluate(X_test, y_test)
print('Test accuracy:', test_acc)
复制代码
运行结果
epoch 100/100
120/120 [==============================] - 0s 44us/sample - loss: 0.3806 - accuracy: 0.9667
30/30 [==============================] - 0s 1ms/sample - loss: 0.3763 - accuracy: 0.9667
Test accuracy: 0.96666664

此代码将首先加载鸢尾花数据集。然后,它将数据拆分为训练集和测试集。接下来,它将定义一个简单的模型,该模型包含一个具有10个隐藏单元的隐藏层和一个具有3个输出单元的输出层(对应于三个鸢尾花物种)。然后,它将编译模型并使用训练数据训练它。最后,它将在测试集上评估模型并打印准确性分数。

以下是一些有关该代码的更多详细信息:

  • datasets.load_iris() 函数用于加载鸢尾花数据集。它返回一个包含数据和目标变量的字典。
  • train_test_split() 函数用于将数据拆分为训练集和测试集。它随机选择测试集中的数据,其余数据用于训练。
  • keras.Sequential() 类用于创建顺序模型。顺序模型是一系列层,其中每个层的输出馈送到下一个层。
  • keras.layers.Dense() 类用于创建密集层。密集层是完全连接的,这意味着每个输入单元都连接到每个输出单元。
  • activation='relu' 告诉ReLU激活函数应用于该层的输出。ReLU函数将负值输出替换为零。
  • activation='softmax' 告诉softmax激活函数应用于该层的输出。softmax函数将每个输出转换为概率分布,其中概率之和为1。
  • optimizer='adam' 告诉模型使用Adam优化器进行训练。Adam是一种流行的优化器,通常在实践中表现良好。
  • loss='sparse_categorical_crossentropy' 告诉模型使用稀疏类别交叉熵损失函数。此损失函数适用于分类问题,其中目标变量是类别。
  • metrics=['accuracy'] 告诉模型跟踪准确性度量。准确性是正确分类的示例的百分比。
  • model.fit() 方法用于训练模型。它接受训练数据和目标变量作为参数。
  • epochs=10 告诉模型训练10个纪元。纪元是模型看到所有训练数据的一次完整传递。
  • model.evaluate() 方法用于评估模型。它接受测试数据和目标变量作为参数,并返回损失和度量值(例如准确性)的列表。

您可以通过更改模型架构、训练参数和超参数来实验此代码。例如,您可以尝试添加更多隐藏层或单元,使用不同的激活函数,或使用不同的优化器或损失函数。

以下是一些有关如何改进此代码的提示:

  • 尝试使用不同的模型架构,例如卷积神经网络或循环神经网络。
  • 使用正则化技术(例如L1或L2正则化)来防止过拟合。
  • 使用数据增强技术来人工增加训练数据量。
  • 在更大的数据集上训练模型。
相关推荐
kovlistudio23 分钟前
机器学习第三讲:监督学习 → 带答案的学习册,如预测房价时需要历史价格数据
人工智能·机器学习
嵌入式仿真实验教学平台27 分钟前
「国产嵌入式仿真平台:高精度虚实融合如何终结Proteus时代?」——从教学实验到低空经济,揭秘新一代AI赋能的产业级教学工具
人工智能·学习·proteus·无人机·低空经济·嵌入式仿真·实验教学
正在走向自律1 小时前
Python 数据分析与可视化:开启数据洞察之旅(5/10)
开发语言·人工智能·python·数据挖掘·数据分析
LuvMyLife1 小时前
基于Win在VSCode部署运行OpenVINO模型
人工智能·深度学习·计算机视觉·openvino
fancy1661661 小时前
力扣top100 矩阵置零
人工智能·算法·矩阵
gaosushexiangji1 小时前
基于千眼狼高速摄像机与三色掩模的体三维粒子图像测速PIV技术
人工智能·数码相机·计算机视觉
中电金信2 小时前
重构金融数智化产业版图:中电金信“链主”之道
大数据·人工智能
奋斗者1号2 小时前
Docker 部署 - Crawl4AI 文档 (v0.5.x)
人工智能·爬虫·机器学习
陈奕昆2 小时前
五、【LLaMA-Factory实战】模型部署与监控:从实验室到生产的全链路实践
开发语言·人工智能·python·llama·大模型微调
多巴胺与内啡肽.2 小时前
OpenCV进阶操作:光流估计
人工智能·opencv·计算机视觉