猫狗识别大模型——基于python语言

目录

1.猫狗识别

2.数据集介绍

3.猫狗识别核心原理

4.程序思路

4.1数据文件框架

[4.2 训练模型](#4.2 训练模型)

[4.3 模型使用](#4.3 模型使用)

[4.4 识别结果](#4.4 识别结果)

5.总结


1.猫狗识别

人可以直接分辨出图片里的动物是猫还是狗,但是电脑不可以,要想让电脑也分辨出图片里的动物是猫还是小狗,就要使用到深度学习,电脑学习提取图片特征,进而学习区分图片里的是猫还是狗。

2.数据集介绍

程序用到的训练数据集是猫狗图像数据集,数据格式jpg格式,猫狗数据集:

复制代码
https://www.kaggle.com/datasets/shaunthesheep/microsoft-catsvsdogs-dataset

3.猫狗识别核心原理

猫狗识别大模型是一种深度学习架构,主要用于图像分类任务,用来区分猫和狗这两种常见的宠物动物。

该模型基于卷积神经网络(CNN),它们通过学习大量的猫和狗图像数据集中的特征来进行训练,使其能够识别出输入图片中动物的种类。

训练过程中,模型会对猫的特有纹理、颜色模式、耳朵形状等特征进行学习,并形成区分猫狗的关键特征模板。一旦模型经过充分训练并优化,它可以准确地判断新的未知图片是属于猫还是狗。

应用此类模型的方式通常是将其部署到移动设备或者云端服务器上,用户上传一张照片后,模型会返回一个预测结果,指示图像中动物的类别。

4.程序思路

基于tensorflow模型框架以及卷积神经网络还有其他各种模块,划分训练集,微调集和测试机,对猫狗图片文件进行训练。

4.1数据文件框架

4.2 训练模型

复制代码
import tensorflow as tf
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Conv2D, MaxPooling2D, Flatten, Dense, Dropout
from tensorflow.keras.optimizers import Adam
import matplotlib.pyplot as plt
import os

# 获取所有的GPU设备
gpus = tf.config.list_physical_devices('GPU')

# 检查是否有两个以上的GPU
if gpus and len(gpus) > 1:
    try:
        # 假设GPU1是独立GPU,设置可见设备为GPU1
        tf.config.set_visible_devices(gpus[1], 'GPU')
        tf.config.experimental.set_memory_growth(gpus[1], True)
    except RuntimeError as e:
        print(e)
else:
    print("没有检测到多个GPU,或者系统只存在一个GPU。")

# 定义数据目录
data_dir = './pythonProject/ai_modle_win/cats vs dogs/dataset'  # 请替换为你的数据集路径
train_dir = os.path.join(data_dir, 'train')
validation_dir = os.path.join(data_dir, 'validation')
test_dir = os.path.join(data_dir, 'test')

# 图像数据生成器
train_datagen = ImageDataGenerator(
    rescale=1./255,
    rotation_range=20,
    width_shift_range=0.2,
    height_shift_range=0.2,
    shear_range=0.2,
    zoom_range=0.2,
    horizontal_flip=True
)

validation_datagen = ImageDataGenerator(rescale=1./255)
test_datagen = ImageDataGenerator(rescale=1./255)

# 计算样本数量
def count_files(directory):
    total_files = 0
    for root, dirs, files in os.walk(directory):
        total_files += len(files)
    return total_files

train_samples = count_files(train_dir)
validation_samples = count_files(validation_dir)
test_samples = count_files(test_dir)

# 数据生成器
def create_generator(datagen, directory, target_size, batch_size, class_mode):
    generator = datagen.flow_from_directory(
        directory,
        target_size=target_size,
        batch_size=batch_size,
        class_mode=class_mode
    )
    # 包装生成器以处理损坏的图像文件
    while True:
        try:
            yield next(generator)
        except (OSError, StopIteration) as e:
            print(f"跳过无法读取的图像文件:{e}")
            continue

train_generator = create_generator(train_datagen, train_dir, (150, 150), 32, 'binary')
validation_generator = create_generator(validation_datagen, validation_dir, (150, 150), 32, 'binary')
test_generator = create_generator(test_datagen, test_dir, (150, 150), 32, 'binary')

# 定义模型
model = Sequential([
    Conv2D(32, (3, 3), activation='relu', input_shape=(150, 150, 3)),
    MaxPooling2D(2, 2),
    Conv2D(64, (3, 3), activation='relu'),
    MaxPooling2D(2, 2),
    Conv2D(128, (3, 3), activation='relu'),
    MaxPooling2D(2, 2),
    Conv2D(128, (3, 3), activation='relu'),
    MaxPooling2D(2, 2),
    Flatten(),
    Dropout(0.5),
    Dense(512, activation='relu'),
    Dense(1, activation='sigmoid')
])

model.compile(loss='binary_crossentropy',
              optimizer=Adam(learning_rate=0.001),
              metrics=['accuracy'])

# 训练模型
history = model.fit(
    train_generator,
    steps_per_epoch=train_samples // 32,  # 将结果转换为整数
    validation_data=validation_generator,
    validation_steps=validation_samples // 32,  # 将结果转换为整数
    epochs=5
)

# 保存模型
model.save('./pythonProject/ai_modle_win/cats vs dogs/cat_dog.h5')

# 评估模型
test_loss, test_acc = model.evaluate(test_generator, steps=test_samples // 32)
print(f'Test accuracy: {test_acc:.2f}')

# 可视化训练结果
acc = history.history['accuracy']
val_acc = history.history['val_accuracy']
loss = history.history['loss']
val_loss = history.history['val_loss']

epochs = range(len(acc))

plt.figure(figsize=(12, 9))

plt.subplot(1, 2, 1)
plt.plot(epochs, acc, 'b', label='Training accuracy')
plt.plot(epochs, val_acc, 'r', label='Validation accuracy')
plt.title('Training and validation accuracy')
plt.legend()

plt.subplot(1, 2, 2)
plt.plot(epochs, loss, 'b', label='Training loss')
plt.plot(epochs, val_loss, 'r', label='Validation loss')
plt.title('Training and validation loss')
plt.legend()

plt.show()

注意更改文件路径!!!

4.3 模型使用

复制代码
import tensorflow as tf
from tensorflow.keras.models import load_model
from tensorflow.keras.preprocessing import image
import numpy as np
import os

# 加载已保存的模型
model = load_model('./pythonProject/ai_modle_win/cats vs dogs/cat_dog.h5')

# 预测函数
def predict_image(img_path):
    img = image.load_img(img_path, target_size=(150, 150))
    img_array = image.img_to_array(img)
    img_array = np.expand_dims(img_array, axis=0)
    img_array /= 255.0

    prediction = model.predict(img_array)
    if prediction[0] > 0.5:
        print(f"The image at {img_path} is a Dog")
    else:
        print(f"The image at {img_path} is a Cat")

# 示例用法
test_image_path = './pythonProject/ai_modle_win/cats vs dogs/30.jpg'  # 替换为你的测试图片路径
predict_image(test_image_path)

使用上述训练的模型进行图片识别,注意文件路径。

4.4 识别结果

5.总结

通过构造猫狗图片数据集,然后使用深度学习训练一个猫狗识别大模型,你也快来试一试吧。

相关推荐
WeiXiao_Hyy几秒前
成为 Top 1% 的工程师
java·开发语言·javascript·经验分享·后端
ZH1545589131几秒前
Flutter for OpenHarmony Python学习助手实战:面向对象编程实战的实现
python·学习·flutter
玄同7651 分钟前
SQLite + LLM:大模型应用落地的轻量级数据存储方案
jvm·数据库·人工智能·python·语言模型·sqlite·知识图谱
User_芊芊君子6 分钟前
CANN010:PyASC Python编程接口—简化AI算子开发的Python框架
开发语言·人工智能·python
Max_uuc16 分钟前
【C++ 硬核】打破嵌入式 STL 禁忌:利用 std::pmr 在“栈”上运行 std::vector
开发语言·jvm·c++
白日做梦Q17 分钟前
Anchor-free检测器全解析:CenterNet vs FCOS
python·深度学习·神经网络·目标检测·机器学习
故事不长丨17 分钟前
C#线程同步:lock、Monitor、Mutex原理+用法+实战全解析
开发语言·算法·c#
牵牛老人20 分钟前
【Qt 开发后台服务避坑指南:从库存管理系统开发出现的问题来看后台开发常见问题与解决方案】
开发语言·qt·系统架构
froginwe1128 分钟前
Python3与MySQL的连接:使用mysql-connector
开发语言
喵手31 分钟前
Python爬虫实战:公共自行车站点智能采集系统 - 从零构建生产级爬虫的完整实战(附CSV导出 + SQLite持久化存储)!
爬虫·python·爬虫实战·零基础python爬虫教学·采集公共自行车站点·公共自行车站点智能采集系统·采集公共自行车站点导出csv