图像分类模型 传统训练VS迁移学习训练

相同的数据集,训练结果对比。传统的只有70%左右,迁移学习方式有95%的准确率。

传统训练方式:

训练模型:

python 复制代码
import json
import tensorflow as tf
import matplotlib.pyplot as plt
import numpy as np
import pathlib
from tensorflow import keras
from tensorflow.keras import layers
from tensorflow.keras.models import Sequential
import os


class FlowerClassifierTrainer:
    """花卉分类模型训练器"""
    
    def __init__(self, img_height=180, img_width=180, batch_size=32):
        self.img_height = img_height
        self.img_width = img_width
        self.batch_size = batch_size
        self.model = None
        self.history = None
        self.class_names = None
        self.train_ds = None
        self.val_ds = None
        
    def setup_gpu(self):
        """配置GPU设置"""
        gpus = tf.config.experimental.list_physical_devices('GPU')
        if gpus:
            try:
                # 设置内存增长
                for gpu in gpus:
                    tf.config.experimental.set_memory_growth(gpu, True)
                # 设置逻辑GPU
                logical_gpus = tf.config.experimental.list_logical_devices('GPU')
                print(f"{len(gpus)} Physical GPU, {len(logical_gpus)} Logical GPU")
                return True
            except RuntimeError as e:
                print(e)
        return False
    
    def load_data(self, data_dir=None):
        """加载数据集"""
        if data_dir is None:
            # 下载默认数据集
            dataset_url = "https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz"
            data_dir = tf.keras.utils.get_file("flower_photos.tgz", origin=dataset_url, extract=True)
            data_dir = pathlib.Path(data_dir).parent / "flower_photos_extracted" / "flower_photos"
        
        # 创建训练和验证数据集
        self.train_ds = tf.keras.utils.image_dataset_from_directory(
            data_dir,
            validation_split=0.2,
            subset="training",
            seed=123,
            image_size=(self.img_height, self.img_width),
            batch_size=self.batch_size
        )
        
        self.val_ds = tf.keras.utils.image_dataset_from_directory(
            data_dir,
            validation_split=0.2,
            subset="validation",
            seed=123,
            image_size=(self.img_height, self.img_width),
            batch_size=self.batch_size
        )
        
        self.class_names = self.train_ds.class_names
        print(f"类别名称: {self.class_names}")
        print(f"类别数量: {len(self.class_names)}")
        
        return self.train_ds, self.val_ds
    
    def build_model(self):
        """构建模型"""
        data_augmentation = keras.Sequential([
            layers.RandomFlip("horizontal", input_shape=(self.img_height, self.img_width, 3)),
            layers.RandomRotation(0.1),
            layers.RandomZoom(0.1),
        ])
        
        num_classes = len(self.class_names)
        
        self.model = Sequential([
            data_augmentation,
            layers.Rescaling(1./255),
            layers.Conv2D(16, 3, padding='same', activation='relu'),
            layers.MaxPooling2D(),
            layers.Conv2D(32, 3, padding='same', activation='relu'),
            layers.MaxPooling2D(),
            layers.Conv2D(64, 3, padding='same', activation='relu'),
            layers.MaxPooling2D(),
            layers.Dropout(0.2),
            layers.Flatten(),
            layers.Dense(128, activation='relu'),
            layers.Dense(num_classes, name="outputs")
        ])
        
        self.model.compile(
            optimizer='adam',
            loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
            metrics=['accuracy']
        )
        
        return self.model
    
    def train(self, epochs=15, verbose=1):
        """训练模型"""
        if self.model is None:
            self.build_model()
        
        print("开始训练模型...")
        self.history = self.model.fit(
            self.train_ds,
            validation_data=self.val_ds,
            epochs=epochs,
            verbose=verbose
        )
        
        return self.history
    
    def evaluate(self):
        """评估模型"""
        if self.model is None:
            raise ValueError("模型尚未训练,请先调用 train() 方法")
        
        print("评估模型在验证集上的表现:")
        val_loss, val_accuracy = self.model.evaluate(self.val_ds, verbose=0)
        print(f"验证损失: {val_loss:.4f}")
        print(f"验证准确率: {val_accuracy:.4f}")
        
        return val_loss, val_accuracy
    
    def plot_training_history(self):
        """绘制训练历史"""
        if self.history is None:
            raise ValueError("没有训练历史数据,请先训练模型")
        
        acc = self.history.history['accuracy']
        val_acc = self.history.history['val_accuracy']
        loss = self.history.history['loss']
        val_loss = self.history.history['val_loss']
        
        epochs_range = range(len(acc))
        
        plt.figure(figsize=(12, 4))
        
        plt.subplot(1, 2, 1)
        plt.plot(epochs_range, acc, label='Training Accuracy')
        plt.plot(epochs_range, val_acc, label='Validation Accuracy')
        plt.legend(loc='lower right')
        plt.title('Training and Validation Accuracy')
        
        plt.subplot(1, 2, 2)
        plt.plot(epochs_range, loss, label='Training Loss')
        plt.plot(epochs_range, val_loss, label='Validation Loss')
        plt.legend(loc='upper right')
        plt.title('Training and Validation Loss')
        
        plt.tight_layout()
        plt.show()
    
    def save_model_simple(self, model_path='flower_classification_model.keras'):
        """简化版保存方法 - 推荐使用"""
        if self.model is None:
            raise ValueError("没有可保存的模型")
        
        # 确保文件扩展名正确
        if not model_path.endswith(('.keras', '.h5')):
            model_path += '.keras'  # 默认使用新的Keras格式
        
        self.model.save(model_path)
        print(f"模型已保存到: {model_path}")
        
        # 保存类别名称
        class_names_path = model_path.rsplit('.', 1)[0] + '_class_names.json'
        with open(class_names_path, 'w', encoding='utf-8') as f:
            json.dump(self.class_names, f, ensure_ascii=False, indent=2)
        print(f"类别名称已保存到: {class_names_path}")

使用模型:

python 复制代码
import json
import tensorflow as tf
import matplotlib.pyplot as plt
import numpy as np
import pathlib
from tensorflow import keras
from tensorflow.keras import layers
from tensorflow.keras.models import Sequential
import os


class FlowerClassifier:
    """花卉分类预测器"""
    
    def __init__(self, model_path='flower_classification_model.keras'):
        self.model = None
        self.class_names = None
        self.img_height = 180
        self.img_width = 180
        self.load_model(model_path)
    
    def load_model(self, model_path):
        """加载已训练的模型"""
        if not os.path.exists(model_path):
            # 如果指定的路径不存在,尝试其他可能的格式
            possible_paths = [
                model_path,
                model_path.replace('.keras', '.h5'),
                model_path.replace('.h5', '.keras'),
                model_path + '.keras',
                model_path + '.h5'
            ]
            
            for path in possible_paths:
                if os.path.exists(path):
                    model_path = path
                    break
            else:
                raise FileNotFoundError(f"模型文件不存在: {model_path}")
        
        # 加载模型
        self.model = tf.keras.models.load_model(model_path)
        print(f"模型已从 {model_path} 加载")
        
        # 加载类别名称
        base_path = model_path.rsplit('.', 1)[0]
        class_names_path = base_path + '_class_names.json'
        
        if os.path.exists(class_names_path):
            with open(class_names_path, 'r', encoding='utf-8') as f:
                self.class_names = json.load(f)
            print(f"类别名称已加载: {self.class_names}")
        else:
            print("警告: 未找到类别名称文件,将使用默认类别索引")
            self.class_names = None
        
        # 尝试加载训练配置
        config_path = base_path + '_training_config.json'
        if os.path.exists(config_path):
            with open(config_path, 'r', encoding='utf-8') as f:
                config = json.load(f)
                self.img_height = config.get('img_height', 180)
                self.img_width = config.get('img_width', 180)
            print(f"图像尺寸配置: {self.img_height}x{self.img_width}")
    
    def predict_image(self, image_path):
        """预测单张图片"""
        if self.model is None:
            raise ValueError("模型未加载")
        
        # 加载和预处理图片
        img = tf.keras.utils.load_img(
            image_path, target_size=(self.img_height, self.img_width)
        )
        img_array = tf.keras.utils.img_to_array(img)
        img_array = tf.expand_dims(img_array, 0)  # 创建批次
        
        # 进行预测
        predictions = self.model.predict(img_array, verbose=0)
        scores = tf.nn.softmax(predictions[0])
        
        # 获取预测结果
        predicted_class_idx = np.argmax(scores)
        confidence = 100 * np.max(scores)
        
        if self.class_names is not None:
            predicted_class = self.class_names[predicted_class_idx]
        else:
            predicted_class = f"Class_{predicted_class_idx}"
        
        return {
            'predicted_class': predicted_class,
            'confidence': confidence,
            'all_scores': scores.numpy(),
            'class_index': predicted_class_idx
        }
    
    def predict_image_from_url(self, image_url, image_name='temp_image'):
        """从URL下载图片并进行预测"""
        try:
            image_path = tf.keras.utils.get_file(image_name, origin=image_url)
            return self.predict_image(image_path)
        except Exception as e:
            print(f"从URL加载图片时出错: {e}")
            return None
    
    def predict_batch(self, image_dir):
        """批量预测目录中的图片"""
        if self.model is None:
            raise ValueError("模型未加载")
        
        results = []
        image_extensions = ['.jpg', '.jpeg', '.png', '.bmp']
        
        for file_path in pathlib.Path(image_dir).glob('*'):
            if file_path.suffix.lower() in image_extensions:
                try:
                    result = self.predict_image(str(file_path))
                    result['file_path'] = str(file_path)
                    results.append(result)
                except Exception as e:
                    print(f"处理图片 {file_path} 时出错: {e}")
        
        return results
    
    def display_prediction(self, prediction_result, show_image=True):
        """显示预测结果"""
        if prediction_result is None:
            print("没有预测结果")
            return
        
        print(
            "预测结果: 该图片最可能属于 '{}', 置信度: {:.2f}%"
            .format(prediction_result['predicted_class'], prediction_result['confidence'])
        )
        
        # 显示所有类别的置信度
        if self.class_names is not None:
            print("\n所有类别置信度:")
            for i, class_name in enumerate(self.class_names):
                score = prediction_result['all_scores'][i] * 100
                print(f"  {class_name}: {score:.2f}%")

main:

python 复制代码
from FlowerClassifier import FlowerClassifier
from FlowerClassifierTrainer import FlowerClassifierTrainer



def trainerModel():
     # 示例1: 训练模型
    print("=== 训练模型 ===")
    trainer = FlowerClassifierTrainer()
    
    # 设置GPU
    gpu_available = trainer.setup_gpu()
    if gpu_available:
        print("🎉 使用GPU进行训练!")
    else:
        print("⚠️ 没有可用的GPU,使用CPU")
    
    # 加载数据
    trainer.load_data()
    
    # 构建和训练模型
    trainer.build_model()
    trainer.model.summary()
    
    # 训练模型
    history = trainer.train(epochs=15)
    
    # 评估模型
    trainer.evaluate()
    
    # 绘制训练历史
    trainer.plot_training_history()
    
    # 保存模型
    trainer.save_model_simple('flower_classification_model.keras')



# 使用示例
if __name__ == "__main__":
   
    
    print("\n" + "="*50 + "\n")
    
    # 示例2: 使用训练好的模型进行预测
    print("=== 使用模型进行预测 ===")
    classifier = FlowerClassifier()
    
    # 示例图片URL预测
    sunflower_url = "https://storage.googleapis.com/download.tensorflow.org/example_images/592px-Red_sunflower.jpg"
    
    print("从URL预测图片:")
    result = classifier.predict_image_from_url(sunflower_url, 'Red_sunflower')
    classifier.display_prediction(result)
    
    print("\n" + "="*50 + "\n")

结果:75%左右

迁移学习方式:

训练模型:

python 复制代码
import json
import tensorflow as tf
import matplotlib.pyplot as plt
import numpy as np
import pathlib
from tensorflow import keras
from tensorflow.keras import layers
from tensorflow.keras.models import Sequential
import os


class FlowerClassifierTrainer:
    """花卉分类模型训练器 - 迁移学习版本"""
    
    def __init__(self, img_height=224, img_width=224, batch_size=32):
        self.img_height = img_height
        self.img_width = img_width
        self.batch_size = batch_size
        self.model = None
        self.history = None
        self.class_names = None
        self.train_ds = None
        self.val_ds = None
        self.base_model = None
        
    def setup_gpu(self):
        """配置GPU设置"""
        gpus = tf.config.experimental.list_physical_devices('GPU')
        if gpus:
            try:
                # 设置内存增长
                for gpu in gpus:
                    tf.config.experimental.set_memory_growth(gpu, True)
                # 设置逻辑GPU
                logical_gpus = tf.config.experimental.list_logical_devices('GPU')
                print(f"{len(gpus)} Physical GPU, {len(logical_gpus)} Logical GPU")
                return True
            except RuntimeError as e:
                print(e)
        return False
    
    def load_data(self, data_dir=None):
        """加载数据集"""
        if data_dir is None:
            # 下载默认数据集
            dataset_url = "https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz"
            data_dir = tf.keras.utils.get_file("flower_photos.tgz", origin=dataset_url, extract=True)
            data_dir = pathlib.Path(data_dir).parent / "flower_photos_extracted" / "flower_photos"
        
        # 创建训练和验证数据集
        self.train_ds = tf.keras.utils.image_dataset_from_directory(
            data_dir,
            validation_split=0.2,
            subset="training",
            seed=123,
            image_size=(self.img_height, self.img_width),
            batch_size=self.batch_size
        )
        
        self.val_ds = tf.keras.utils.image_dataset_from_directory(
            data_dir,
            validation_split=0.2,
            subset="validation",
            seed=123,
            image_size=(self.img_height, self.img_width),
            batch_size=self.batch_size
        )
        
        self.class_names = self.train_ds.class_names
        print(f"类别名称: {self.class_names}")
        print(f"类别数量: {len(self.class_names)}")
        
        # 优化数据管道
        AUTOTUNE = tf.data.AUTOTUNE
        self.train_ds = self.train_ds.prefetch(buffer_size=AUTOTUNE)
        self.val_ds = self.val_ds.prefetch(buffer_size=AUTOTUNE)
        
        return self.train_ds, self.val_ds
    
    def build_model_with_transfer_learning(self, base_model_name='MobileNetV2', fine_tune=False):
        """使用迁移学习构建模型
        
        Args:
            base_model_name: 预训练模型名称 ('MobileNetV2', 'EfficientNetB0', 'ResNet50')
            fine_tune: 是否进行微调(解冻部分层)
        """
        # 选择预训练模型
        if base_model_name == 'MobileNetV2':
            self.base_model = tf.keras.applications.MobileNetV2(
                input_shape=(self.img_height, self.img_width, 3),
                include_top=False,
                weights='imagenet'
            )
        elif base_model_name == 'EfficientNetB0':
            self.base_model = tf.keras.applications.EfficientNetB0(
                input_shape=(self.img_height, self.img_width, 3),
                include_top=False,
                weights='imagenet'
            )
        elif base_model_name == 'ResNet50':
            self.base_model = tf.keras.applications.ResNet50(
                input_shape=(self.img_height, self.img_width, 3),
                include_top=False,
                weights='imagenet'
            )
        else:
            raise ValueError(f"不支持的模型: {base_model_name}")
        
        print(f"使用预训练模型: {base_model_name}")
        
        # 第一阶段:冻结基础模型,只训练顶层
        self.base_model.trainable = False
        
        num_classes = len(self.class_names)
        
        # 数据增强
        data_augmentation = keras.Sequential([
            layers.RandomFlip("horizontal_and_vertical"),
            layers.RandomRotation(0.2),
            layers.RandomZoom(0.2),
            layers.RandomContrast(0.2),
        ])
        
        # 构建完整模型
        self.model = Sequential([
            data_augmentation,
            layers.Rescaling(1./127.5, offset=-1),  # 归一化到[-1, 1]
            self.base_model,
            layers.GlobalAveragePooling2D(),
            layers.Dropout(0.3),
            layers.Dense(128, activation='relu'),
            layers.BatchNormalization(),
            layers.Dropout(0.3),
            layers.Dense(num_classes, activation='softmax')
        ])
        
        # 编译模型
        self.model.compile(
            optimizer=tf.keras.optimizers.Adam(learning_rate=0.001),
            loss='sparse_categorical_crossentropy',
            metrics=['accuracy']
        )
        
        print("迁移学习模型构建完成(基础模型冻结)")
        return self.model
    
    def fine_tune_model(self, fine_tune_layers=100):
        """微调模型 - 解冻部分基础模型层
        
        Args:
            fine_tune_layers: 要解冻的层数
        """
        if self.base_model is None:
            raise ValueError("请先调用 build_model_with_transfer_learning()")
        
        # 解冻基础模型的部分层
        self.base_model.trainable = True
        
        # 冻结前面的层,只训练后面的层
        for layer in self.base_model.layers[:-fine_tune_layers]:
            layer.trainable = False
        
        # 重新编译模型,使用更小的学习率
        self.model.compile(
            optimizer=tf.keras.optimizers.Adam(learning_rate=0.0001),  # 更小的学习率
            loss='sparse_categorical_crossentropy',
            metrics=['accuracy']
        )
        
        print(f"模型已设置为微调模式,解冻了最后 {fine_tune_layers} 层")
        print(f"使用学习率: 0.0001")
    
    def train(self, epochs_initial=10, epochs_fine_tune=10, verbose=1):
        """训练模型(两阶段训练)
        
        Args:
            epochs_initial: 初始训练轮数(基础模型冻结)
            epochs_fine_tune: 微调训练轮数(基础模型部分解冻)
        """
        if self.model is None:
            self.build_model_with_transfer_learning()
        
        # 定义回调函数
        callbacks = [
            tf.keras.callbacks.EarlyStopping(
                monitor='val_accuracy',
                patience=5,
                restore_best_weights=True,
                mode='max'
            ),
            tf.keras.callbacks.ReduceLROnPlateau(
                monitor='val_loss',
                factor=0.2,
                patience=3,
                min_lr=1e-7
            )
        ]
        
        print("=" * 50)
        print("第一阶段:基础模型冻结训练")
        print("=" * 50)
        
        # 第一阶段训练:基础模型冻结
        history_initial = self.model.fit(
            self.train_ds,
            validation_data=self.val_ds,
            epochs=epochs_initial,
            callbacks=callbacks,
            verbose=verbose
        )
        
        # 第二阶段:微调
        print("=" * 50)
        print("第二阶段:模型微调")
        print("=" * 50)
        
        self.fine_tune_model(fine_tune_layers=100)
        
        history_fine_tune = self.model.fit(
            self.train_ds,
            validation_data=self.val_ds,
            epochs=epochs_fine_tune,
            callbacks=callbacks,
            verbose=verbose
        )
        
        # 合并训练历史
        self.history = {
            'accuracy': history_initial.history['accuracy'] + history_fine_tune.history['accuracy'],
            'val_accuracy': history_initial.history['val_accuracy'] + history_fine_tune.history['val_accuracy'],
            'loss': history_initial.history['loss'] + history_fine_tune.history['loss'],
            'val_loss': history_initial.history['val_loss'] + history_fine_tune.history['val_loss']
        }
        
        return self.history
    
    def evaluate(self):
        """评估模型"""
        if self.model is None:
            raise ValueError("模型尚未训练,请先调用 train() 方法")
        
        print("评估模型在验证集上的表现:")
        val_loss, val_accuracy = self.model.evaluate(self.val_ds, verbose=0)
        print(f"验证损失: {val_loss:.4f}")
        print(f"验证准确率: {val_accuracy:.4f}")
        
        return val_loss, val_accuracy
    
    def plot_training_history(self):
        """绘制训练历史"""
        if self.history is None:
            raise ValueError("没有训练历史数据,请先训练模型")
        
        acc = self.history['accuracy']
        val_acc = self.history['val_accuracy']
        loss = self.history['loss']
        val_loss = self.history['val_loss']
        
        epochs_range = range(len(acc))
        
        plt.figure(figsize=(12, 4))
        
        plt.subplot(1, 2, 1)
        plt.plot(epochs_range, acc, label='Training Accuracy')
        plt.plot(epochs_range, val_acc, label='Validation Accuracy')
        plt.axvline(x=len(self.history['accuracy']) - len([x for x in self.history['accuracy'] if x > 0])//2, 
                   color='r', linestyle='--', alpha=0.7, label='开始微调')
        plt.legend(loc='lower right')
        plt.title('Training and Validation Accuracy')
        
        plt.subplot(1, 2, 2)
        plt.plot(epochs_range, loss, label='Training Loss')
        plt.plot(epochs_range, val_loss, label='Validation Loss')
        plt.axvline(x=len(self.history['loss']) - len([x for x in self.history['loss'] if x > 0])//2, 
                   color='r', linestyle='--', alpha=0.7, label='开始微调')
        plt.legend(loc='upper right')
        plt.title('Training and Validation Loss')
        
        plt.tight_layout()
        plt.show()
    
    def save_model(self, model_path='flower_classification_model_transfer.keras'):
        """保存模型"""
        if self.model is None:
            raise ValueError("没有可保存的模型")
        
        # 确保文件扩展名正确
        if not model_path.endswith(('.keras', '.h5')):
            model_path += '.keras'
        
        self.model.save(model_path)
        print(f"模型已保存到: {model_path}")
        
        # 保存类别名称
        class_names_path = model_path.rsplit('.', 1)[0] + '_class_names.json'
        with open(class_names_path, 'w', encoding='utf-8') as f:
            json.dump(self.class_names, f, ensure_ascii=False, indent=2)
        print(f"类别名称已保存到: {class_names_path}")
        
        # 保存训练配置
        config_path = model_path.rsplit('.', 1)[0] + '_training_config.json'
        config = {
            'img_height': self.img_height,
            'img_width': self.img_width,
            'batch_size': self.batch_size,
            'model_type': 'transfer_learning'
        }
        with open(config_path, 'w', encoding='utf-8') as f:
            json.dump(config, f, ensure_ascii=False, indent=2)
        print(f"训练配置已保存到: {config_path}")

使用模型:

python 复制代码
import json
import tensorflow as tf
import matplotlib.pyplot as plt
import numpy as np
import pathlib
from tensorflow import keras
from tensorflow.keras import layers
from tensorflow.keras.models import Sequential
import os

class FlowerClassifier:
    """花卉分类预测器 - 兼容迁移学习模型"""
    
    def __init__(self, model_path='flower_classification_model_transfer.keras'):
        self.model = None
        self.class_names = None
        self.img_height = 224  # 默认使用迁移学习的标准尺寸
        self.img_width = 224
        self.load_model(model_path)
    
    def load_model(self, model_path):
        """加载已训练的模型"""
        if not os.path.exists(model_path):
            # 如果指定的路径不存在,尝试其他可能的格式
            possible_paths = [
                model_path,
                model_path.replace('.keras', '.h5'),
                model_path.replace('.h5', '.keras'),
                model_path + '.keras',
                model_path + '.h5'
            ]
            
            for path in possible_paths:
                if os.path.exists(path):
                    model_path = path
                    break
            else:
                raise FileNotFoundError(f"模型文件不存在: {model_path}")
        
        # 加载模型
        self.model = tf.keras.models.load_model(model_path)
        print(f"模型已从 {model_path} 加载")
        
        # 加载类别名称
        base_path = model_path.rsplit('.', 1)[0]
        class_names_path = base_path + '_class_names.json'
        
        if os.path.exists(class_names_path):
            with open(class_names_path, 'r', encoding='utf-8') as f:
                self.class_names = json.load(f)
            print(f"类别名称已加载: {self.class_names}")
        else:
            print("警告: 未找到类别名称文件,将使用默认类别索引")
            self.class_names = None
        
        # 尝试加载训练配置
        config_path = base_path + '_training_config.json'
        if os.path.exists(config_path):
            with open(config_path, 'r', encoding='utf-8') as f:
                config = json.load(f)
                self.img_height = config.get('img_height', 224)
                self.img_width = config.get('img_width', 224)
            print(f"图像尺寸配置: {self.img_height}x{self.img_width}")
    
    def predict_image(self, image_path):
        """预测单张图片"""
        if self.model is None:
            raise ValueError("模型未加载")
        
        # 加载和预处理图片
        img = tf.keras.utils.load_img(
            image_path, target_size=(self.img_height, self.img_width)
        )
        img_array = tf.keras.utils.img_to_array(img)
        img_array = tf.expand_dims(img_array, 0)  # 创建批次
        
        # 进行预测
        predictions = self.model.predict(img_array, verbose=0)
        scores = tf.nn.softmax(predictions[0])
        
        # 获取预测结果
        predicted_class_idx = np.argmax(scores)
        confidence = 100 * np.max(scores)
        
        if self.class_names is not None:
            predicted_class = self.class_names[predicted_class_idx]
        else:
            predicted_class = f"Class_{predicted_class_idx}"
        
        return {
            'predicted_class': predicted_class,
            'confidence': confidence,
            'all_scores': scores.numpy(),
            'class_index': predicted_class_idx
        }
    
    def predict_image_from_url(self, image_url, image_name='temp_image'):
        """从URL下载图片并进行预测"""
        try:
            image_path = tf.keras.utils.get_file(image_name, origin=image_url)
            return self.predict_image(image_path)
        except Exception as e:
            print(f"从URL加载图片时出错: {e}")
            return None
    
    def predict_batch(self, image_dir):
        """批量预测目录中的图片"""
        if self.model is None:
            raise ValueError("模型未加载")
        
        results = []
        image_extensions = ['.jpg', '.jpeg', '.png', '.bmp']
        
        for file_path in pathlib.Path(image_dir).glob('*'):
            if file_path.suffix.lower() in image_extensions:
                try:
                    result = self.predict_image(str(file_path))
                    result['file_path'] = str(file_path)
                    results.append(result)
                except Exception as e:
                    print(f"处理图片 {file_path} 时出错: {e}")
        
        return results
    
    def display_prediction(self, prediction_result, show_image=True):
        """显示预测结果"""
        if prediction_result is None:
            print("没有预测结果")
            return
        
        print(
            "预测结果: 该图片最可能属于 '{}', 置信度: {:.2f}%"
            .format(prediction_result['predicted_class'], prediction_result['confidence'])
        )
        
        # 显示所有类别的置信度
        if self.class_names is not None:
            print("\n所有类别置信度:")
            for i, class_name in enumerate(self.class_names):
                score = prediction_result['all_scores'][i] * 100
                print(f"  {class_name}: {score:.2f}%")

main:

python 复制代码
from FlowerClassifier import FlowerClassifier
from FlowerClassifierTrainer import FlowerClassifierTrainer


# 使用示例
def train_flower_classifier():
    """训练迁移学习花卉分类器"""
    trainer = FlowerClassifierTrainer(img_height=224, img_width=224, batch_size=32)
    trainer.setup_gpu()
    trainer.load_data()
    
    # 使用迁移学习构建模型
    trainer.build_model_with_transfer_learning('MobileNetV2')
    
    # 训练模型(两阶段)
    history = trainer.train(epochs_initial=10, epochs_fine_tune=10)
    
    # 评估模型
    trainer.evaluate()
    
    # 绘制训练历史
    trainer.plot_training_history()
    
    # 保存模型
    trainer.save_model('flower_model_transfer_learning.keras')
    
    return trainer


if __name__ == "__main__":
    # 训练模型
    trainer = train_flower_classifier()
    
    # 使用训练好的模型进行预测
    classifier = FlowerClassifier('flower_model_transfer_learning.keras')
    
    # 示例预测(需要实际图片路径)
    # result = classifier.predict_image('path_to_your_flower_image.jpg')
    # classifier.display_prediction(result)

结果:准确率接近95%

相关推荐
碧海银沙音频科技研究院1 小时前
CLIP(对比语言-图像预训练)在长尾图像分类应用
python·深度学习·分类
极客BIM工作室3 小时前
详解 KL 散度的反向传播计算:以三分类神经网络为例
神经网络·机器学习·分类
自然语3 小时前
数字生已经进化到一个分水岭面临选择?先实现“动态识别“还是先实现“特征信息归纳分类“,文中给出以给出答案,大家选哪个方向?
人工智能·分类·数据挖掘
RickyWasYoung4 小时前
【聚类算法】高维数据的聚类
算法·数据挖掘·聚类
R-G-B5 小时前
【P19 机器学习-分类算法及应用实践】手写数字识别(KNN)
python·机器学习·分类·手写数字识别·knn算法
我是哈哈hh16 小时前
【Python数据分析】Numpy总结
开发语言·python·数据挖掘·数据分析·numpy·python数据分析
小飞象—木兮21 小时前
【产品运营必备】数据分析实战宝典:从入门到精通,驱动业务增长(附相关材料下载)
大数据·数据挖掘·数据分析·产品运营
kong79069281 天前
大数据的特征和数据分析
大数据·数据挖掘·数据分析
weixin_457760001 天前
EIOU (Efficient IoU): 高效边界框回归损失的解析
人工智能·数据挖掘·回归
源于花海1 天前
迁移学习基础知识——总体思路和度量准则(距离和相似度)
人工智能·机器学习·迁移学习