图像分类模型 传统训练VS迁移学习训练

相同的数据集,训练结果对比。传统的只有70%左右,迁移学习方式有95%的准确率。

传统训练方式:

训练模型:

python 复制代码
import json
import tensorflow as tf
import matplotlib.pyplot as plt
import numpy as np
import pathlib
from tensorflow import keras
from tensorflow.keras import layers
from tensorflow.keras.models import Sequential
import os


class FlowerClassifierTrainer:
    """花卉分类模型训练器"""
    
    def __init__(self, img_height=180, img_width=180, batch_size=32):
        self.img_height = img_height
        self.img_width = img_width
        self.batch_size = batch_size
        self.model = None
        self.history = None
        self.class_names = None
        self.train_ds = None
        self.val_ds = None
        
    def setup_gpu(self):
        """配置GPU设置"""
        gpus = tf.config.experimental.list_physical_devices('GPU')
        if gpus:
            try:
                # 设置内存增长
                for gpu in gpus:
                    tf.config.experimental.set_memory_growth(gpu, True)
                # 设置逻辑GPU
                logical_gpus = tf.config.experimental.list_logical_devices('GPU')
                print(f"{len(gpus)} Physical GPU, {len(logical_gpus)} Logical GPU")
                return True
            except RuntimeError as e:
                print(e)
        return False
    
    def load_data(self, data_dir=None):
        """加载数据集"""
        if data_dir is None:
            # 下载默认数据集
            dataset_url = "https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz"
            data_dir = tf.keras.utils.get_file("flower_photos.tgz", origin=dataset_url, extract=True)
            data_dir = pathlib.Path(data_dir).parent / "flower_photos_extracted" / "flower_photos"
        
        # 创建训练和验证数据集
        self.train_ds = tf.keras.utils.image_dataset_from_directory(
            data_dir,
            validation_split=0.2,
            subset="training",
            seed=123,
            image_size=(self.img_height, self.img_width),
            batch_size=self.batch_size
        )
        
        self.val_ds = tf.keras.utils.image_dataset_from_directory(
            data_dir,
            validation_split=0.2,
            subset="validation",
            seed=123,
            image_size=(self.img_height, self.img_width),
            batch_size=self.batch_size
        )
        
        self.class_names = self.train_ds.class_names
        print(f"类别名称: {self.class_names}")
        print(f"类别数量: {len(self.class_names)}")
        
        return self.train_ds, self.val_ds
    
    def build_model(self):
        """构建模型"""
        data_augmentation = keras.Sequential([
            layers.RandomFlip("horizontal", input_shape=(self.img_height, self.img_width, 3)),
            layers.RandomRotation(0.1),
            layers.RandomZoom(0.1),
        ])
        
        num_classes = len(self.class_names)
        
        self.model = Sequential([
            data_augmentation,
            layers.Rescaling(1./255),
            layers.Conv2D(16, 3, padding='same', activation='relu'),
            layers.MaxPooling2D(),
            layers.Conv2D(32, 3, padding='same', activation='relu'),
            layers.MaxPooling2D(),
            layers.Conv2D(64, 3, padding='same', activation='relu'),
            layers.MaxPooling2D(),
            layers.Dropout(0.2),
            layers.Flatten(),
            layers.Dense(128, activation='relu'),
            layers.Dense(num_classes, name="outputs")
        ])
        
        self.model.compile(
            optimizer='adam',
            loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
            metrics=['accuracy']
        )
        
        return self.model
    
    def train(self, epochs=15, verbose=1):
        """训练模型"""
        if self.model is None:
            self.build_model()
        
        print("开始训练模型...")
        self.history = self.model.fit(
            self.train_ds,
            validation_data=self.val_ds,
            epochs=epochs,
            verbose=verbose
        )
        
        return self.history
    
    def evaluate(self):
        """评估模型"""
        if self.model is None:
            raise ValueError("模型尚未训练,请先调用 train() 方法")
        
        print("评估模型在验证集上的表现:")
        val_loss, val_accuracy = self.model.evaluate(self.val_ds, verbose=0)
        print(f"验证损失: {val_loss:.4f}")
        print(f"验证准确率: {val_accuracy:.4f}")
        
        return val_loss, val_accuracy
    
    def plot_training_history(self):
        """绘制训练历史"""
        if self.history is None:
            raise ValueError("没有训练历史数据,请先训练模型")
        
        acc = self.history.history['accuracy']
        val_acc = self.history.history['val_accuracy']
        loss = self.history.history['loss']
        val_loss = self.history.history['val_loss']
        
        epochs_range = range(len(acc))
        
        plt.figure(figsize=(12, 4))
        
        plt.subplot(1, 2, 1)
        plt.plot(epochs_range, acc, label='Training Accuracy')
        plt.plot(epochs_range, val_acc, label='Validation Accuracy')
        plt.legend(loc='lower right')
        plt.title('Training and Validation Accuracy')
        
        plt.subplot(1, 2, 2)
        plt.plot(epochs_range, loss, label='Training Loss')
        plt.plot(epochs_range, val_loss, label='Validation Loss')
        plt.legend(loc='upper right')
        plt.title('Training and Validation Loss')
        
        plt.tight_layout()
        plt.show()
    
    def save_model_simple(self, model_path='flower_classification_model.keras'):
        """简化版保存方法 - 推荐使用"""
        if self.model is None:
            raise ValueError("没有可保存的模型")
        
        # 确保文件扩展名正确
        if not model_path.endswith(('.keras', '.h5')):
            model_path += '.keras'  # 默认使用新的Keras格式
        
        self.model.save(model_path)
        print(f"模型已保存到: {model_path}")
        
        # 保存类别名称
        class_names_path = model_path.rsplit('.', 1)[0] + '_class_names.json'
        with open(class_names_path, 'w', encoding='utf-8') as f:
            json.dump(self.class_names, f, ensure_ascii=False, indent=2)
        print(f"类别名称已保存到: {class_names_path}")

使用模型:

python 复制代码
import json
import tensorflow as tf
import matplotlib.pyplot as plt
import numpy as np
import pathlib
from tensorflow import keras
from tensorflow.keras import layers
from tensorflow.keras.models import Sequential
import os


class FlowerClassifier:
    """花卉分类预测器"""
    
    def __init__(self, model_path='flower_classification_model.keras'):
        self.model = None
        self.class_names = None
        self.img_height = 180
        self.img_width = 180
        self.load_model(model_path)
    
    def load_model(self, model_path):
        """加载已训练的模型"""
        if not os.path.exists(model_path):
            # 如果指定的路径不存在,尝试其他可能的格式
            possible_paths = [
                model_path,
                model_path.replace('.keras', '.h5'),
                model_path.replace('.h5', '.keras'),
                model_path + '.keras',
                model_path + '.h5'
            ]
            
            for path in possible_paths:
                if os.path.exists(path):
                    model_path = path
                    break
            else:
                raise FileNotFoundError(f"模型文件不存在: {model_path}")
        
        # 加载模型
        self.model = tf.keras.models.load_model(model_path)
        print(f"模型已从 {model_path} 加载")
        
        # 加载类别名称
        base_path = model_path.rsplit('.', 1)[0]
        class_names_path = base_path + '_class_names.json'
        
        if os.path.exists(class_names_path):
            with open(class_names_path, 'r', encoding='utf-8') as f:
                self.class_names = json.load(f)
            print(f"类别名称已加载: {self.class_names}")
        else:
            print("警告: 未找到类别名称文件,将使用默认类别索引")
            self.class_names = None
        
        # 尝试加载训练配置
        config_path = base_path + '_training_config.json'
        if os.path.exists(config_path):
            with open(config_path, 'r', encoding='utf-8') as f:
                config = json.load(f)
                self.img_height = config.get('img_height', 180)
                self.img_width = config.get('img_width', 180)
            print(f"图像尺寸配置: {self.img_height}x{self.img_width}")
    
    def predict_image(self, image_path):
        """预测单张图片"""
        if self.model is None:
            raise ValueError("模型未加载")
        
        # 加载和预处理图片
        img = tf.keras.utils.load_img(
            image_path, target_size=(self.img_height, self.img_width)
        )
        img_array = tf.keras.utils.img_to_array(img)
        img_array = tf.expand_dims(img_array, 0)  # 创建批次
        
        # 进行预测
        predictions = self.model.predict(img_array, verbose=0)
        scores = tf.nn.softmax(predictions[0])
        
        # 获取预测结果
        predicted_class_idx = np.argmax(scores)
        confidence = 100 * np.max(scores)
        
        if self.class_names is not None:
            predicted_class = self.class_names[predicted_class_idx]
        else:
            predicted_class = f"Class_{predicted_class_idx}"
        
        return {
            'predicted_class': predicted_class,
            'confidence': confidence,
            'all_scores': scores.numpy(),
            'class_index': predicted_class_idx
        }
    
    def predict_image_from_url(self, image_url, image_name='temp_image'):
        """从URL下载图片并进行预测"""
        try:
            image_path = tf.keras.utils.get_file(image_name, origin=image_url)
            return self.predict_image(image_path)
        except Exception as e:
            print(f"从URL加载图片时出错: {e}")
            return None
    
    def predict_batch(self, image_dir):
        """批量预测目录中的图片"""
        if self.model is None:
            raise ValueError("模型未加载")
        
        results = []
        image_extensions = ['.jpg', '.jpeg', '.png', '.bmp']
        
        for file_path in pathlib.Path(image_dir).glob('*'):
            if file_path.suffix.lower() in image_extensions:
                try:
                    result = self.predict_image(str(file_path))
                    result['file_path'] = str(file_path)
                    results.append(result)
                except Exception as e:
                    print(f"处理图片 {file_path} 时出错: {e}")
        
        return results
    
    def display_prediction(self, prediction_result, show_image=True):
        """显示预测结果"""
        if prediction_result is None:
            print("没有预测结果")
            return
        
        print(
            "预测结果: 该图片最可能属于 '{}', 置信度: {:.2f}%"
            .format(prediction_result['predicted_class'], prediction_result['confidence'])
        )
        
        # 显示所有类别的置信度
        if self.class_names is not None:
            print("\n所有类别置信度:")
            for i, class_name in enumerate(self.class_names):
                score = prediction_result['all_scores'][i] * 100
                print(f"  {class_name}: {score:.2f}%")

main:

python 复制代码
from FlowerClassifier import FlowerClassifier
from FlowerClassifierTrainer import FlowerClassifierTrainer



def trainerModel():
     # 示例1: 训练模型
    print("=== 训练模型 ===")
    trainer = FlowerClassifierTrainer()
    
    # 设置GPU
    gpu_available = trainer.setup_gpu()
    if gpu_available:
        print("🎉 使用GPU进行训练!")
    else:
        print("⚠️ 没有可用的GPU,使用CPU")
    
    # 加载数据
    trainer.load_data()
    
    # 构建和训练模型
    trainer.build_model()
    trainer.model.summary()
    
    # 训练模型
    history = trainer.train(epochs=15)
    
    # 评估模型
    trainer.evaluate()
    
    # 绘制训练历史
    trainer.plot_training_history()
    
    # 保存模型
    trainer.save_model_simple('flower_classification_model.keras')



# 使用示例
if __name__ == "__main__":
   
    
    print("\n" + "="*50 + "\n")
    
    # 示例2: 使用训练好的模型进行预测
    print("=== 使用模型进行预测 ===")
    classifier = FlowerClassifier()
    
    # 示例图片URL预测
    sunflower_url = "https://storage.googleapis.com/download.tensorflow.org/example_images/592px-Red_sunflower.jpg"
    
    print("从URL预测图片:")
    result = classifier.predict_image_from_url(sunflower_url, 'Red_sunflower')
    classifier.display_prediction(result)
    
    print("\n" + "="*50 + "\n")

结果:75%左右

迁移学习方式:

训练模型:

python 复制代码
import json
import tensorflow as tf
import matplotlib.pyplot as plt
import numpy as np
import pathlib
from tensorflow import keras
from tensorflow.keras import layers
from tensorflow.keras.models import Sequential
import os


class FlowerClassifierTrainer:
    """花卉分类模型训练器 - 迁移学习版本"""
    
    def __init__(self, img_height=224, img_width=224, batch_size=32):
        self.img_height = img_height
        self.img_width = img_width
        self.batch_size = batch_size
        self.model = None
        self.history = None
        self.class_names = None
        self.train_ds = None
        self.val_ds = None
        self.base_model = None
        
    def setup_gpu(self):
        """配置GPU设置"""
        gpus = tf.config.experimental.list_physical_devices('GPU')
        if gpus:
            try:
                # 设置内存增长
                for gpu in gpus:
                    tf.config.experimental.set_memory_growth(gpu, True)
                # 设置逻辑GPU
                logical_gpus = tf.config.experimental.list_logical_devices('GPU')
                print(f"{len(gpus)} Physical GPU, {len(logical_gpus)} Logical GPU")
                return True
            except RuntimeError as e:
                print(e)
        return False
    
    def load_data(self, data_dir=None):
        """加载数据集"""
        if data_dir is None:
            # 下载默认数据集
            dataset_url = "https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz"
            data_dir = tf.keras.utils.get_file("flower_photos.tgz", origin=dataset_url, extract=True)
            data_dir = pathlib.Path(data_dir).parent / "flower_photos_extracted" / "flower_photos"
        
        # 创建训练和验证数据集
        self.train_ds = tf.keras.utils.image_dataset_from_directory(
            data_dir,
            validation_split=0.2,
            subset="training",
            seed=123,
            image_size=(self.img_height, self.img_width),
            batch_size=self.batch_size
        )
        
        self.val_ds = tf.keras.utils.image_dataset_from_directory(
            data_dir,
            validation_split=0.2,
            subset="validation",
            seed=123,
            image_size=(self.img_height, self.img_width),
            batch_size=self.batch_size
        )
        
        self.class_names = self.train_ds.class_names
        print(f"类别名称: {self.class_names}")
        print(f"类别数量: {len(self.class_names)}")
        
        # 优化数据管道
        AUTOTUNE = tf.data.AUTOTUNE
        self.train_ds = self.train_ds.prefetch(buffer_size=AUTOTUNE)
        self.val_ds = self.val_ds.prefetch(buffer_size=AUTOTUNE)
        
        return self.train_ds, self.val_ds
    
    def build_model_with_transfer_learning(self, base_model_name='MobileNetV2', fine_tune=False):
        """使用迁移学习构建模型
        
        Args:
            base_model_name: 预训练模型名称 ('MobileNetV2', 'EfficientNetB0', 'ResNet50')
            fine_tune: 是否进行微调(解冻部分层)
        """
        # 选择预训练模型
        if base_model_name == 'MobileNetV2':
            self.base_model = tf.keras.applications.MobileNetV2(
                input_shape=(self.img_height, self.img_width, 3),
                include_top=False,
                weights='imagenet'
            )
        elif base_model_name == 'EfficientNetB0':
            self.base_model = tf.keras.applications.EfficientNetB0(
                input_shape=(self.img_height, self.img_width, 3),
                include_top=False,
                weights='imagenet'
            )
        elif base_model_name == 'ResNet50':
            self.base_model = tf.keras.applications.ResNet50(
                input_shape=(self.img_height, self.img_width, 3),
                include_top=False,
                weights='imagenet'
            )
        else:
            raise ValueError(f"不支持的模型: {base_model_name}")
        
        print(f"使用预训练模型: {base_model_name}")
        
        # 第一阶段:冻结基础模型,只训练顶层
        self.base_model.trainable = False
        
        num_classes = len(self.class_names)
        
        # 数据增强
        data_augmentation = keras.Sequential([
            layers.RandomFlip("horizontal_and_vertical"),
            layers.RandomRotation(0.2),
            layers.RandomZoom(0.2),
            layers.RandomContrast(0.2),
        ])
        
        # 构建完整模型
        self.model = Sequential([
            data_augmentation,
            layers.Rescaling(1./127.5, offset=-1),  # 归一化到[-1, 1]
            self.base_model,
            layers.GlobalAveragePooling2D(),
            layers.Dropout(0.3),
            layers.Dense(128, activation='relu'),
            layers.BatchNormalization(),
            layers.Dropout(0.3),
            layers.Dense(num_classes, activation='softmax')
        ])
        
        # 编译模型
        self.model.compile(
            optimizer=tf.keras.optimizers.Adam(learning_rate=0.001),
            loss='sparse_categorical_crossentropy',
            metrics=['accuracy']
        )
        
        print("迁移学习模型构建完成(基础模型冻结)")
        return self.model
    
    def fine_tune_model(self, fine_tune_layers=100):
        """微调模型 - 解冻部分基础模型层
        
        Args:
            fine_tune_layers: 要解冻的层数
        """
        if self.base_model is None:
            raise ValueError("请先调用 build_model_with_transfer_learning()")
        
        # 解冻基础模型的部分层
        self.base_model.trainable = True
        
        # 冻结前面的层,只训练后面的层
        for layer in self.base_model.layers[:-fine_tune_layers]:
            layer.trainable = False
        
        # 重新编译模型,使用更小的学习率
        self.model.compile(
            optimizer=tf.keras.optimizers.Adam(learning_rate=0.0001),  # 更小的学习率
            loss='sparse_categorical_crossentropy',
            metrics=['accuracy']
        )
        
        print(f"模型已设置为微调模式,解冻了最后 {fine_tune_layers} 层")
        print(f"使用学习率: 0.0001")
    
    def train(self, epochs_initial=10, epochs_fine_tune=10, verbose=1):
        """训练模型(两阶段训练)
        
        Args:
            epochs_initial: 初始训练轮数(基础模型冻结)
            epochs_fine_tune: 微调训练轮数(基础模型部分解冻)
        """
        if self.model is None:
            self.build_model_with_transfer_learning()
        
        # 定义回调函数
        callbacks = [
            tf.keras.callbacks.EarlyStopping(
                monitor='val_accuracy',
                patience=5,
                restore_best_weights=True,
                mode='max'
            ),
            tf.keras.callbacks.ReduceLROnPlateau(
                monitor='val_loss',
                factor=0.2,
                patience=3,
                min_lr=1e-7
            )
        ]
        
        print("=" * 50)
        print("第一阶段:基础模型冻结训练")
        print("=" * 50)
        
        # 第一阶段训练:基础模型冻结
        history_initial = self.model.fit(
            self.train_ds,
            validation_data=self.val_ds,
            epochs=epochs_initial,
            callbacks=callbacks,
            verbose=verbose
        )
        
        # 第二阶段:微调
        print("=" * 50)
        print("第二阶段:模型微调")
        print("=" * 50)
        
        self.fine_tune_model(fine_tune_layers=100)
        
        history_fine_tune = self.model.fit(
            self.train_ds,
            validation_data=self.val_ds,
            epochs=epochs_fine_tune,
            callbacks=callbacks,
            verbose=verbose
        )
        
        # 合并训练历史
        self.history = {
            'accuracy': history_initial.history['accuracy'] + history_fine_tune.history['accuracy'],
            'val_accuracy': history_initial.history['val_accuracy'] + history_fine_tune.history['val_accuracy'],
            'loss': history_initial.history['loss'] + history_fine_tune.history['loss'],
            'val_loss': history_initial.history['val_loss'] + history_fine_tune.history['val_loss']
        }
        
        return self.history
    
    def evaluate(self):
        """评估模型"""
        if self.model is None:
            raise ValueError("模型尚未训练,请先调用 train() 方法")
        
        print("评估模型在验证集上的表现:")
        val_loss, val_accuracy = self.model.evaluate(self.val_ds, verbose=0)
        print(f"验证损失: {val_loss:.4f}")
        print(f"验证准确率: {val_accuracy:.4f}")
        
        return val_loss, val_accuracy
    
    def plot_training_history(self):
        """绘制训练历史"""
        if self.history is None:
            raise ValueError("没有训练历史数据,请先训练模型")
        
        acc = self.history['accuracy']
        val_acc = self.history['val_accuracy']
        loss = self.history['loss']
        val_loss = self.history['val_loss']
        
        epochs_range = range(len(acc))
        
        plt.figure(figsize=(12, 4))
        
        plt.subplot(1, 2, 1)
        plt.plot(epochs_range, acc, label='Training Accuracy')
        plt.plot(epochs_range, val_acc, label='Validation Accuracy')
        plt.axvline(x=len(self.history['accuracy']) - len([x for x in self.history['accuracy'] if x > 0])//2, 
                   color='r', linestyle='--', alpha=0.7, label='开始微调')
        plt.legend(loc='lower right')
        plt.title('Training and Validation Accuracy')
        
        plt.subplot(1, 2, 2)
        plt.plot(epochs_range, loss, label='Training Loss')
        plt.plot(epochs_range, val_loss, label='Validation Loss')
        plt.axvline(x=len(self.history['loss']) - len([x for x in self.history['loss'] if x > 0])//2, 
                   color='r', linestyle='--', alpha=0.7, label='开始微调')
        plt.legend(loc='upper right')
        plt.title('Training and Validation Loss')
        
        plt.tight_layout()
        plt.show()
    
    def save_model(self, model_path='flower_classification_model_transfer.keras'):
        """保存模型"""
        if self.model is None:
            raise ValueError("没有可保存的模型")
        
        # 确保文件扩展名正确
        if not model_path.endswith(('.keras', '.h5')):
            model_path += '.keras'
        
        self.model.save(model_path)
        print(f"模型已保存到: {model_path}")
        
        # 保存类别名称
        class_names_path = model_path.rsplit('.', 1)[0] + '_class_names.json'
        with open(class_names_path, 'w', encoding='utf-8') as f:
            json.dump(self.class_names, f, ensure_ascii=False, indent=2)
        print(f"类别名称已保存到: {class_names_path}")
        
        # 保存训练配置
        config_path = model_path.rsplit('.', 1)[0] + '_training_config.json'
        config = {
            'img_height': self.img_height,
            'img_width': self.img_width,
            'batch_size': self.batch_size,
            'model_type': 'transfer_learning'
        }
        with open(config_path, 'w', encoding='utf-8') as f:
            json.dump(config, f, ensure_ascii=False, indent=2)
        print(f"训练配置已保存到: {config_path}")

使用模型:

python 复制代码
import json
import tensorflow as tf
import matplotlib.pyplot as plt
import numpy as np
import pathlib
from tensorflow import keras
from tensorflow.keras import layers
from tensorflow.keras.models import Sequential
import os

class FlowerClassifier:
    """花卉分类预测器 - 兼容迁移学习模型"""
    
    def __init__(self, model_path='flower_classification_model_transfer.keras'):
        self.model = None
        self.class_names = None
        self.img_height = 224  # 默认使用迁移学习的标准尺寸
        self.img_width = 224
        self.load_model(model_path)
    
    def load_model(self, model_path):
        """加载已训练的模型"""
        if not os.path.exists(model_path):
            # 如果指定的路径不存在,尝试其他可能的格式
            possible_paths = [
                model_path,
                model_path.replace('.keras', '.h5'),
                model_path.replace('.h5', '.keras'),
                model_path + '.keras',
                model_path + '.h5'
            ]
            
            for path in possible_paths:
                if os.path.exists(path):
                    model_path = path
                    break
            else:
                raise FileNotFoundError(f"模型文件不存在: {model_path}")
        
        # 加载模型
        self.model = tf.keras.models.load_model(model_path)
        print(f"模型已从 {model_path} 加载")
        
        # 加载类别名称
        base_path = model_path.rsplit('.', 1)[0]
        class_names_path = base_path + '_class_names.json'
        
        if os.path.exists(class_names_path):
            with open(class_names_path, 'r', encoding='utf-8') as f:
                self.class_names = json.load(f)
            print(f"类别名称已加载: {self.class_names}")
        else:
            print("警告: 未找到类别名称文件,将使用默认类别索引")
            self.class_names = None
        
        # 尝试加载训练配置
        config_path = base_path + '_training_config.json'
        if os.path.exists(config_path):
            with open(config_path, 'r', encoding='utf-8') as f:
                config = json.load(f)
                self.img_height = config.get('img_height', 224)
                self.img_width = config.get('img_width', 224)
            print(f"图像尺寸配置: {self.img_height}x{self.img_width}")
    
    def predict_image(self, image_path):
        """预测单张图片"""
        if self.model is None:
            raise ValueError("模型未加载")
        
        # 加载和预处理图片
        img = tf.keras.utils.load_img(
            image_path, target_size=(self.img_height, self.img_width)
        )
        img_array = tf.keras.utils.img_to_array(img)
        img_array = tf.expand_dims(img_array, 0)  # 创建批次
        
        # 进行预测
        predictions = self.model.predict(img_array, verbose=0)
        scores = tf.nn.softmax(predictions[0])
        
        # 获取预测结果
        predicted_class_idx = np.argmax(scores)
        confidence = 100 * np.max(scores)
        
        if self.class_names is not None:
            predicted_class = self.class_names[predicted_class_idx]
        else:
            predicted_class = f"Class_{predicted_class_idx}"
        
        return {
            'predicted_class': predicted_class,
            'confidence': confidence,
            'all_scores': scores.numpy(),
            'class_index': predicted_class_idx
        }
    
    def predict_image_from_url(self, image_url, image_name='temp_image'):
        """从URL下载图片并进行预测"""
        try:
            image_path = tf.keras.utils.get_file(image_name, origin=image_url)
            return self.predict_image(image_path)
        except Exception as e:
            print(f"从URL加载图片时出错: {e}")
            return None
    
    def predict_batch(self, image_dir):
        """批量预测目录中的图片"""
        if self.model is None:
            raise ValueError("模型未加载")
        
        results = []
        image_extensions = ['.jpg', '.jpeg', '.png', '.bmp']
        
        for file_path in pathlib.Path(image_dir).glob('*'):
            if file_path.suffix.lower() in image_extensions:
                try:
                    result = self.predict_image(str(file_path))
                    result['file_path'] = str(file_path)
                    results.append(result)
                except Exception as e:
                    print(f"处理图片 {file_path} 时出错: {e}")
        
        return results
    
    def display_prediction(self, prediction_result, show_image=True):
        """显示预测结果"""
        if prediction_result is None:
            print("没有预测结果")
            return
        
        print(
            "预测结果: 该图片最可能属于 '{}', 置信度: {:.2f}%"
            .format(prediction_result['predicted_class'], prediction_result['confidence'])
        )
        
        # 显示所有类别的置信度
        if self.class_names is not None:
            print("\n所有类别置信度:")
            for i, class_name in enumerate(self.class_names):
                score = prediction_result['all_scores'][i] * 100
                print(f"  {class_name}: {score:.2f}%")

main:

python 复制代码
from FlowerClassifier import FlowerClassifier
from FlowerClassifierTrainer import FlowerClassifierTrainer


# 使用示例
def train_flower_classifier():
    """训练迁移学习花卉分类器"""
    trainer = FlowerClassifierTrainer(img_height=224, img_width=224, batch_size=32)
    trainer.setup_gpu()
    trainer.load_data()
    
    # 使用迁移学习构建模型
    trainer.build_model_with_transfer_learning('MobileNetV2')
    
    # 训练模型(两阶段)
    history = trainer.train(epochs_initial=10, epochs_fine_tune=10)
    
    # 评估模型
    trainer.evaluate()
    
    # 绘制训练历史
    trainer.plot_training_history()
    
    # 保存模型
    trainer.save_model('flower_model_transfer_learning.keras')
    
    return trainer


if __name__ == "__main__":
    # 训练模型
    trainer = train_flower_classifier()
    
    # 使用训练好的模型进行预测
    classifier = FlowerClassifier('flower_model_transfer_learning.keras')
    
    # 示例预测(需要实际图片路径)
    # result = classifier.predict_image('path_to_your_flower_image.jpg')
    # classifier.display_prediction(result)

结果:准确率接近95%

相关推荐
大千AI助手1 天前
Hoeffding树:数据流挖掘中的高效分类算法详解
人工智能·机器学习·分类·数据挖掘·流数据··hoeffding树
新子y1 天前
【小白笔记】区分类方法/实例方法和静态函数/命名空间函数
笔记·分类
大千AI助手1 天前
独热编码:分类数据处理的基石技术
人工智能·机器学习·分类·数据挖掘·特征工程·one-hot·独热编码
十三画者2 天前
【文献分享】acmgscaler:用于在 ACMG/AMP 框架内对基因层面的变异效应得分进行标准化校准。
数据挖掘·数据分析·r语言
JJJJ_iii2 天前
【机器学习07】 激活函数精讲、Softmax多分类与优化器进阶
人工智能·笔记·python·算法·机器学习·分类·线性回归
星期天要睡觉2 天前
深度学习——基于 PyTorch 的蔬菜图像分类
人工智能·pytorch·python·深度学习·分类
言德斐2 天前
数据挖掘知识体系分析
人工智能·数据挖掘
nju_spy2 天前
复杂结构数据挖掘(三)关联规则挖掘实验
人工智能·数据挖掘·apriori·网格搜索·关联规则挖掘·fp-growth·位运算状态枚举
来酱何人2 天前
实时NLP数据处理:流数据的清洗、特征提取与模型推理适配
人工智能·深度学习·分类·nlp·bert
wktomo3 天前
数据挖掘比赛baseline参考
人工智能·数据挖掘