在相同的数据集上对比两种训练方式:传统方式(从零训练 CNN)的验证准确率只有约 70%~75%,迁移学习方式可达约 95%。
一、传统训练方式(从零训练一个小型 CNN)
训练模型(保存为 FlowerClassifierTrainer.py,以下为 Python 代码):
import json
import tensorflow as tf
import matplotlib.pyplot as plt
import numpy as np
import pathlib
from tensorflow import keras
from tensorflow.keras import layers
from tensorflow.keras.models import Sequential
import os
class FlowerClassifierTrainer:
    """Flower-classification trainer that builds a small CNN and trains it from scratch.

    Typical call order: ``setup_gpu()`` -> ``load_data()`` -> ``build_model()``
    -> ``train()`` -> ``evaluate()`` / ``plot_training_history()`` ->
    ``save_model_simple()``.
    """

    def __init__(self, img_height=180, img_width=180, batch_size=32):
        """Store hyper-parameters; data and model are created by later calls.

        Args:
            img_height: Height every image is resized to when loading.
            img_width: Width every image is resized to when loading.
            batch_size: Batch size of the training and validation datasets.
        """
        self.img_height = img_height
        self.img_width = img_width
        self.batch_size = batch_size
        self.model = None
        self.history = None  # keras History object returned by model.fit()
        self.class_names = None
        self.train_ds = None
        self.val_ds = None

    def setup_gpu(self):
        """Enable on-demand memory growth for every visible GPU.

        Returns:
            True if at least one GPU was found and configured; False when no
            GPU is visible or TensorFlow rejects the setting with RuntimeError
            (e.g. because the GPUs were already initialized).
        """
        gpus = tf.config.list_physical_devices('GPU')
        if not gpus:
            return False
        try:
            for gpu in gpus:
                tf.config.experimental.set_memory_growth(gpu, True)
            # Only reports the logical devices; nothing is configured here.
            logical_gpus = tf.config.list_logical_devices('GPU')
            print(f"{len(gpus)} Physical GPU, {len(logical_gpus)} Logical GPU")
            return True
        except RuntimeError as e:
            print(e)
            return False

    def load_data(self, data_dir=None):
        """Create an 80/20 training/validation split from an image directory.

        Args:
            data_dir: Directory with one sub-folder per class. When None, the
                TensorFlow ``flower_photos`` archive is downloaded and used.

        Returns:
            ``(train_ds, val_ds)``; the class names are stored on ``self``.
        """
        if data_dir is None:
            dataset_url = "https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz"
            archive_path = tf.keras.utils.get_file("flower_photos.tgz", origin=dataset_url, extract=True)
            # NOTE(review): assumes the archive is extracted into a
            # "flower_photos_extracted" folder next to it - this layout depends
            # on the installed Keras version; verify.
            data_dir = pathlib.Path(archive_path).parent / "flower_photos_extracted" / "flower_photos"

        split_options = dict(
            validation_split=0.2,
            seed=123,  # both subsets must share the seed so they stay disjoint
            image_size=(self.img_height, self.img_width),
            batch_size=self.batch_size,
        )
        self.train_ds = tf.keras.utils.image_dataset_from_directory(
            data_dir, subset="training", **split_options
        )
        self.val_ds = tf.keras.utils.image_dataset_from_directory(
            data_dir, subset="validation", **split_options
        )

        self.class_names = self.train_ds.class_names
        print(f"类别名称: {self.class_names}")
        print(f"类别数量: {len(self.class_names)}")
        return self.train_ds, self.val_ds

    def build_model(self):
        """Build and compile the CNN (augmentation + 3 conv blocks + dense head).

        The last Dense layer has no activation, i.e. the model outputs logits,
        which is why the loss uses ``from_logits=True``.

        Returns:
            The compiled ``Sequential`` model (also stored on ``self.model``).

        Raises:
            ValueError: if the class names are unknown (``load_data()`` not called).
        """
        if not self.class_names:
            raise ValueError("类别名称未知,请先调用 load_data() 方法")

        data_augmentation = keras.Sequential([
            # input_shape lets the outer model be built immediately (summary()).
            layers.RandomFlip("horizontal", input_shape=(self.img_height, self.img_width, 3)),
            layers.RandomRotation(0.1),
            layers.RandomZoom(0.1),
        ])
        num_classes = len(self.class_names)
        self.model = Sequential([
            data_augmentation,
            layers.Rescaling(1./255),
            layers.Conv2D(16, 3, padding='same', activation='relu'),
            layers.MaxPooling2D(),
            layers.Conv2D(32, 3, padding='same', activation='relu'),
            layers.MaxPooling2D(),
            layers.Conv2D(64, 3, padding='same', activation='relu'),
            layers.MaxPooling2D(),
            layers.Dropout(0.2),
            layers.Flatten(),
            layers.Dense(128, activation='relu'),
            layers.Dense(num_classes, name="outputs"),
        ])
        self.model.compile(
            optimizer='adam',
            loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
            metrics=['accuracy']
        )
        return self.model

    def train(self, epochs=15, verbose=1):
        """Train the model, building it first if necessary.

        Args:
            epochs: Number of training epochs.
            verbose: Verbosity passed to ``model.fit``.

        Returns:
            The keras ``History`` object (also stored on ``self.history``).

        Raises:
            ValueError: if the datasets have not been loaded yet.
        """
        if self.train_ds is None or self.val_ds is None:
            raise ValueError("数据集尚未加载,请先调用 load_data() 方法")
        if self.model is None:
            self.build_model()
        print("开始训练模型...")
        self.history = self.model.fit(
            self.train_ds,
            validation_data=self.val_ds,
            epochs=epochs,
            verbose=verbose
        )
        return self.history

    def evaluate(self):
        """Evaluate the model on the validation set.

        Returns:
            ``(val_loss, val_accuracy)``.

        Raises:
            ValueError: if no model exists yet.
        """
        if self.model is None:
            raise ValueError("模型尚未训练,请先调用 train() 方法")
        print("评估模型在验证集上的表现:")
        val_loss, val_accuracy = self.model.evaluate(self.val_ds, verbose=0)
        print(f"验证损失: {val_loss:.4f}")
        print(f"验证准确率: {val_accuracy:.4f}")
        return val_loss, val_accuracy

    def plot_training_history(self):
        """Plot training/validation accuracy and loss curves side by side.

        Raises:
            ValueError: if ``train()`` has not been run.
        """
        if self.history is None:
            raise ValueError("没有训练历史数据,请先训练模型")
        metrics = self.history.history
        acc = metrics['accuracy']
        val_acc = metrics['val_accuracy']
        loss = metrics['loss']
        val_loss = metrics['val_loss']
        epochs_range = range(len(acc))

        plt.figure(figsize=(12, 4))
        plt.subplot(1, 2, 1)
        plt.plot(epochs_range, acc, label='Training Accuracy')
        plt.plot(epochs_range, val_acc, label='Validation Accuracy')
        plt.legend(loc='lower right')
        plt.title('Training and Validation Accuracy')

        plt.subplot(1, 2, 2)
        plt.plot(epochs_range, loss, label='Training Loss')
        plt.plot(epochs_range, val_loss, label='Validation Loss')
        plt.legend(loc='upper right')
        plt.title('Training and Validation Loss')
        plt.tight_layout()
        plt.show()

    def save_model_simple(self, model_path='flower_classification_model.keras'):
        """Save the model plus the side files the predictor reads.

        Writes ``<base>_class_names.json`` and ``<base>_training_config.json``
        (image size and batch size) next to the model, where ``<base>`` is
        *model_path* without its extension. Without the config the predictor
        would fall back to its hard-coded image size.

        Args:
            model_path: Target path; ``.keras`` is appended when the path has
                neither a ``.keras`` nor an ``.h5`` extension.

        Raises:
            ValueError: if there is no model to save.
        """
        if self.model is None:
            raise ValueError("没有可保存的模型")
        if not model_path.endswith(('.keras', '.h5')):
            model_path += '.keras'  # default to the native Keras format
        self.model.save(model_path)
        print(f"模型已保存到: {model_path}")

        base_path = os.path.splitext(model_path)[0]
        class_names_path = base_path + '_class_names.json'
        with open(class_names_path, 'w', encoding='utf-8') as f:
            json.dump(self.class_names, f, ensure_ascii=False, indent=2)
        print(f"类别名称已保存到: {class_names_path}")

        config_path = base_path + '_training_config.json'
        config = {
            'img_height': self.img_height,
            'img_width': self.img_width,
            'batch_size': self.batch_size,
        }
        with open(config_path, 'w', encoding='utf-8') as f:
            json.dump(config, f, ensure_ascii=False, indent=2)
        print(f"训练配置已保存到: {config_path}")
使用模型(保存为 FlowerClassifier.py,以下为 Python 代码):
import json
import tensorflow as tf
import matplotlib.pyplot as plt
import numpy as np
import pathlib
from tensorflow import keras
from tensorflow.keras import layers
from tensorflow.keras.models import Sequential
import os
class FlowerClassifier:
    """Flower-classification predictor built on a saved Keras model.

    Besides the model file it reads two optional side files written by the
    trainer next to the model: ``<base>_class_names.json`` (label list) and
    ``<base>_training_config.json`` (input image size).
    """

    def __init__(self, model_path='flower_classification_model.keras'):
        """Load the model (and its side files) from *model_path*.

        Args:
            model_path: Saved model path; see ``load_model`` for the fallbacks
                tried when it does not exist.
        """
        self.model = None
        self.class_names = None
        # Used when no training config is found next to the model.
        self.img_height = 180
        self.img_width = 180
        # True/False when the model's last layer reveals whether it already
        # outputs probabilities; None when that cannot be determined.
        self._outputs_probabilities = None
        self.load_model(model_path)

    def load_model(self, model_path):
        """Load the model, its class names and (optionally) its training config.

        Args:
            model_path: Path to a ``.keras``/``.h5`` file. If it does not
                exist, the path with the other extension, or with an extension
                appended, is tried.

        Raises:
            FileNotFoundError: if none of the candidate paths exists.
        """
        if not os.path.exists(model_path):
            candidates = [
                model_path.replace('.keras', '.h5'),
                model_path.replace('.h5', '.keras'),
                model_path + '.keras',
                model_path + '.h5',
            ]
            existing = next((path for path in candidates if os.path.exists(path)), None)
            if existing is None:
                raise FileNotFoundError(f"模型文件不存在: {model_path}")
            model_path = existing

        self.model = tf.keras.models.load_model(model_path)
        print(f"模型已从 {model_path} 加载")
        self._outputs_probabilities = self._detect_probability_output(self.model)

        # Side files share the model path without its extension.
        base_path = os.path.splitext(model_path)[0]
        class_names_path = base_path + '_class_names.json'
        if os.path.exists(class_names_path):
            with open(class_names_path, 'r', encoding='utf-8') as f:
                self.class_names = json.load(f)
            print(f"类别名称已加载: {self.class_names}")
        else:
            print("警告: 未找到类别名称文件,将使用默认类别索引")
            self.class_names = None

        config_path = base_path + '_training_config.json'
        if os.path.exists(config_path):
            with open(config_path, 'r', encoding='utf-8') as f:
                config = json.load(f)
            self.img_height = config.get('img_height', 180)
            self.img_width = config.get('img_width', 180)
            print(f"图像尺寸配置: {self.img_height}x{self.img_width}")

    @staticmethod
    def _detect_probability_output(model):
        """Inspect the last layer's activation.

        Returns:
            True for softmax (model outputs probabilities), False for linear
            (model outputs logits), None when it cannot be determined.
        """
        try:
            last_layer = model.layers[-1]
        except (AttributeError, IndexError, TypeError):
            return None
        activation_name = getattr(getattr(last_layer, 'activation', None), '__name__', None)
        if activation_name == 'softmax':
            return True
        if activation_name == 'linear':
            return False
        return None

    @staticmethod
    def _to_probabilities(raw_output, outputs_probabilities=None):
        """Convert one model output vector into class probabilities.

        Applying softmax to values that are already probabilities would
        flatten them (and understate the confidence), so softmax is applied
        only to logits. When *outputs_probabilities* is None the output is
        treated as probabilities if it is non-negative and sums to ~1.
        """
        raw = np.asarray(raw_output, dtype=np.float32)
        if outputs_probabilities is None:
            outputs_probabilities = bool(
                raw.size
                and np.all(raw >= 0)
                and np.isclose(raw.sum(), 1.0, atol=1e-3)
            )
        if outputs_probabilities:
            return raw
        exp = np.exp(raw - raw.max())  # shift by the max for numerical stability
        return exp / exp.sum()

    def predict_image(self, image_path):
        """Classify a single image file.

        Returns:
            Dict with ``predicted_class``, ``confidence`` (percent),
            ``all_scores`` (per-class probabilities) and ``class_index``.

        Raises:
            ValueError: if no model is loaded.
        """
        if self.model is None:
            raise ValueError("模型未加载")
        img = tf.keras.utils.load_img(
            image_path, target_size=(self.img_height, self.img_width)
        )
        batch = tf.expand_dims(tf.keras.utils.img_to_array(img), 0)  # batch of one
        raw_output = self.model.predict(batch, verbose=0)[0]
        scores = self._to_probabilities(
            raw_output, getattr(self, '_outputs_probabilities', None)
        )

        predicted_class_idx = np.argmax(scores)
        confidence = 100 * np.max(scores)
        if self.class_names is not None and predicted_class_idx < len(self.class_names):
            predicted_class = self.class_names[predicted_class_idx]
        else:
            predicted_class = f"Class_{predicted_class_idx}"
        return {
            'predicted_class': predicted_class,
            'confidence': confidence,
            'all_scores': scores,
            'class_index': predicted_class_idx
        }

    @staticmethod
    def _url_cache_name(image_url, image_name):
        """Return a download file name unique to *image_url*.

        ``get_file`` reuses an already-downloaded file with the same name, so
        naming downloads by *image_name* alone would silently return a stale
        image for a different URL; the URL digest keeps entries apart.
        """
        import hashlib
        digest = hashlib.sha256(image_url.encode('utf-8')).hexdigest()[:16]
        return f"{image_name}_{digest}"

    def predict_image_from_url(self, image_url, image_name='temp_image'):
        """Download an image (cached per URL) and classify it.

        Returns:
            The ``predict_image`` result, or None if downloading or predicting
            failed (the error is printed; this is deliberately best-effort).
        """
        try:
            image_path = tf.keras.utils.get_file(
                self._url_cache_name(image_url, image_name), origin=image_url
            )
            return self.predict_image(image_path)
        except Exception as e:
            print(f"从URL加载图片时出错: {e}")
            return None

    def predict_batch(self, image_dir):
        """Classify every image file directly inside *image_dir*, in name order.

        Files that fail are reported and skipped.

        Returns:
            List of ``predict_image`` results, each with an extra ``file_path``.

        Raises:
            ValueError: if no model is loaded.
        """
        if self.model is None:
            raise ValueError("模型未加载")
        image_extensions = {'.jpg', '.jpeg', '.png', '.bmp'}
        results = []
        for file_path in sorted(pathlib.Path(image_dir).glob('*')):
            if not file_path.is_file() or file_path.suffix.lower() not in image_extensions:
                continue
            try:
                result = self.predict_image(str(file_path))
            except Exception as e:
                print(f"处理图片 {file_path} 时出错: {e}")
                continue
            result['file_path'] = str(file_path)
            results.append(result)
        return results

    def display_prediction(self, prediction_result, show_image=True):
        """Print the top prediction and, when labels are known, every class score.

        Args:
            prediction_result: A ``predict_image`` result, or None.
            show_image: Accepted for API compatibility; currently unused.
        """
        if prediction_result is None:
            print("没有预测结果")
            return
        print(
            "预测结果: 该图片最可能属于 '{}', 置信度: {:.2f}%"
            .format(prediction_result['predicted_class'], prediction_result['confidence'])
        )
        if self.class_names is not None:
            print("\n所有类别置信度:")
            for class_name, raw_score in zip(self.class_names, prediction_result['all_scores']):
                score = raw_score * 100
                print(f" {class_name}: {score:.2f}%")
主程序 main(以下为 Python 代码):
from FlowerClassifier import FlowerClassifier
from FlowerClassifierTrainer import FlowerClassifierTrainer
def trainerModel(epochs=15, model_path='flower_classification_model.keras'):
    """Train, evaluate, plot and save the from-scratch CNN flower classifier.

    Args:
        epochs: Number of training epochs.
        model_path: Where the trained model (and its side files) are saved.

    Returns:
        The trained ``FlowerClassifierTrainer`` instance.
    """
    # Example 1: train the model
    print("=== 训练模型 ===")
    trainer = FlowerClassifierTrainer()

    if trainer.setup_gpu():
        print("🎉 使用GPU进行训练!")
    else:
        print("⚠️ 没有可用的GPU,使用CPU")

    trainer.load_data()
    trainer.build_model()
    trainer.model.summary()
    trainer.train(epochs=epochs)
    trainer.evaluate()
    trainer.plot_training_history()
    trainer.save_model_simple(model_path)
    return trainer
# Usage example
if __name__ == "__main__":
    print("\n" + "="*50 + "\n")

    # Example 2: predict with the trained model
    print("=== 使用模型进行预测 ===")
    try:
        classifier = FlowerClassifier()
    except FileNotFoundError:
        # Nothing saved yet (e.g. first run): train and save a model, then
        # load it. trainerModel() saves to FlowerClassifier's default path.
        print("未找到已保存的模型,先训练模型...")
        trainerModel()
        classifier = FlowerClassifier()

    # Predict an example image downloaded from a URL
    sunflower_url = "https://storage.googleapis.com/download.tensorflow.org/example_images/592px-Red_sunflower.jpg"
    print("从URL预测图片:")
    result = classifier.predict_image_from_url(sunflower_url, 'Red_sunflower')
    classifier.display_prediction(result)
    print("\n" + "="*50 + "\n")
结果:验证准确率约 75%。

二、迁移学习方式(基于 ImageNet 预训练模型)
训练模型(保存为 FlowerClassifierTrainer.py,以下为 Python 代码):
import json
import tensorflow as tf
import matplotlib.pyplot as plt
import numpy as np
import pathlib
from tensorflow import keras
from tensorflow.keras import layers
from tensorflow.keras.models import Sequential
import os
class FlowerClassifierTrainer:
    """Flower-classification trainer using transfer learning from an ImageNet backbone.

    Training runs in two phases: first only the new classification head is
    trained on top of a frozen backbone, then the last backbone layers are
    unfrozen and fine-tuned with a smaller learning rate.
    """

    def __init__(self, img_height=224, img_width=224, batch_size=32):
        """Store hyper-parameters; data and model are created by later calls.

        Args:
            img_height: Height every image is resized to when loading.
            img_width: Width every image is resized to when loading.
            batch_size: Batch size of the training and validation datasets.
        """
        self.img_height = img_height
        self.img_width = img_width
        self.batch_size = batch_size
        self.model = None
        self.history = None  # merged dict of per-epoch metric lists (both phases)
        self.class_names = None
        self.train_ds = None
        self.val_ds = None
        self.base_model = None
        self.base_model_name = None
        # Index (in self.history) of the first fine-tuning epoch.
        self.fine_tune_epoch = None

    def setup_gpu(self):
        """Enable on-demand memory growth for every visible GPU.

        Returns:
            True if at least one GPU was found and configured; False when no
            GPU is visible or TensorFlow rejects the setting with RuntimeError
            (e.g. because the GPUs were already initialized).
        """
        gpus = tf.config.list_physical_devices('GPU')
        if not gpus:
            return False
        try:
            for gpu in gpus:
                tf.config.experimental.set_memory_growth(gpu, True)
            # Only reports the logical devices; nothing is configured here.
            logical_gpus = tf.config.list_logical_devices('GPU')
            print(f"{len(gpus)} Physical GPU, {len(logical_gpus)} Logical GPU")
            return True
        except RuntimeError as e:
            print(e)
            return False

    def load_data(self, data_dir=None):
        """Create a prefetched 80/20 training/validation split.

        Args:
            data_dir: Directory with one sub-folder per class. When None, the
                TensorFlow ``flower_photos`` archive is downloaded and used.

        Returns:
            ``(train_ds, val_ds)``; the class names are stored on ``self``.
        """
        if data_dir is None:
            dataset_url = "https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz"
            archive_path = tf.keras.utils.get_file("flower_photos.tgz", origin=dataset_url, extract=True)
            # NOTE(review): assumes the archive is extracted into a
            # "flower_photos_extracted" folder next to it - this layout depends
            # on the installed Keras version; verify.
            data_dir = pathlib.Path(archive_path).parent / "flower_photos_extracted" / "flower_photos"

        split_options = dict(
            validation_split=0.2,
            seed=123,  # both subsets must share the seed so they stay disjoint
            image_size=(self.img_height, self.img_width),
            batch_size=self.batch_size,
        )
        self.train_ds = tf.keras.utils.image_dataset_from_directory(
            data_dir, subset="training", **split_options
        )
        self.val_ds = tf.keras.utils.image_dataset_from_directory(
            data_dir, subset="validation", **split_options
        )

        # Read class_names before prefetch(): the prefetched dataset drops it.
        self.class_names = self.train_ds.class_names
        print(f"类别名称: {self.class_names}")
        print(f"类别数量: {len(self.class_names)}")

        autotune = tf.data.AUTOTUNE
        self.train_ds = self.train_ds.prefetch(buffer_size=autotune)
        self.val_ds = self.val_ds.prefetch(buffer_size=autotune)
        return self.train_ds, self.val_ds

    @staticmethod
    def _caffe_preprocessing_layer():
        """Return a frozen 1x1 convolution implementing ResNet50's "caffe" preprocessing.

        It swaps RGB -> BGR and subtracts the ImageNet BGR channel means (no
        scaling). Unlike a Lambda layer it is saved with the model weights, so
        the predictor can load the model without any custom code.
        """
        kernel = np.zeros((1, 1, 3, 3), dtype=np.float32)
        for out_channel in range(3):
            # BGR output channel c comes from RGB input channel 2 - c.
            kernel[0, 0, 2 - out_channel, out_channel] = 1.0
        bias = -np.array([103.939, 116.779, 123.68], dtype=np.float32)  # B, G, R means
        conv = layers.Conv2D(3, 1, trainable=False, name='caffe_preprocessing')
        conv.build((None, None, None, 3))
        conv.set_weights([kernel, bias])
        return conv

    def _create_base_model(self, base_model_name):
        """Create the ImageNet backbone and the preprocessing layer it expects.

        Returns:
            ``(base_model, preprocessing_layer_or_None)``.

        Raises:
            ValueError: for an unsupported backbone name.
        """
        input_shape = (self.img_height, self.img_width, 3)
        applications = tf.keras.applications
        if base_model_name == 'MobileNetV2':
            base = applications.MobileNetV2(input_shape=input_shape, include_top=False, weights='imagenet')
            # MobileNetV2 expects pixels scaled to [-1, 1].
            preprocessing = layers.Rescaling(1./127.5, offset=-1)
        elif base_model_name == 'EfficientNetB0':
            base = applications.EfficientNetB0(input_shape=input_shape, include_top=False, weights='imagenet')
            # Keras EfficientNet models normalize internally and expect raw [0, 255] pixels.
            preprocessing = None
        elif base_model_name == 'ResNet50':
            base = applications.ResNet50(input_shape=input_shape, include_top=False, weights='imagenet')
            preprocessing = self._caffe_preprocessing_layer()
        else:
            raise ValueError(f"不支持的模型: {base_model_name}")
        return base, preprocessing

    def build_model_with_transfer_learning(self, base_model_name='MobileNetV2', fine_tune=False):
        """Build and compile the transfer-learning model with a frozen backbone.

        Args:
            base_model_name: Pretrained backbone: 'MobileNetV2',
                'EfficientNetB0' or 'ResNet50'.
            fine_tune: Kept for backward compatibility; not used here.
                Fine-tuning is enabled by ``fine_tune_model()``, which
                ``train()`` calls for its second phase.

        Returns:
            The compiled ``Sequential`` model (also stored on ``self.model``).

        Raises:
            ValueError: if the backbone name is unsupported or the class names
                are unknown (``load_data()`` not called).
        """
        if not self.class_names:
            raise ValueError("类别名称未知,请先调用 load_data() 方法")
        self.base_model, preprocessing = self._create_base_model(base_model_name)
        self.base_model_name = base_model_name
        print(f"使用预训练模型: {base_model_name}")

        # Phase 1: freeze the backbone and train only the new head.
        self.base_model.trainable = False
        num_classes = len(self.class_names)

        data_augmentation = keras.Sequential([
            layers.RandomFlip("horizontal_and_vertical"),
            layers.RandomRotation(0.2),
            layers.RandomZoom(0.2),
            layers.RandomContrast(0.2),
        ])

        # The model receives raw pixel arrays (the predictor feeds
        # img_to_array output unchanged), so the backbone-specific
        # preprocessing is built into the model itself.
        stack = [data_augmentation]
        if preprocessing is not None:
            stack.append(preprocessing)
        stack += [
            self.base_model,
            layers.GlobalAveragePooling2D(),
            layers.Dropout(0.3),
            layers.Dense(128, activation='relu'),
            layers.BatchNormalization(),
            layers.Dropout(0.3),
            # softmax output: the model emits probabilities, not logits
            layers.Dense(num_classes, activation='softmax'),
        ]
        self.model = Sequential(stack)

        self.model.compile(
            optimizer=tf.keras.optimizers.Adam(learning_rate=0.001),
            loss='sparse_categorical_crossentropy',
            metrics=['accuracy']
        )
        print("迁移学习模型构建完成(基础模型冻结)")
        return self.model

    @staticmethod
    def _freeze_cutoff(num_layers, fine_tune_layers):
        """Return how many leading backbone layers stay frozen.

        Clamped to ``[0, num_layers]`` so that ``fine_tune_layers=0`` freezes
        everything (a ``layers[:-0]`` slice would freeze nothing).
        """
        return min(max(num_layers - fine_tune_layers, 0), num_layers)

    def fine_tune_model(self, fine_tune_layers=100):
        """Unfreeze the last *fine_tune_layers* backbone layers and recompile.

        BatchNormalization layers stay frozen: a frozen BN layer runs in
        inference mode, so fine-tuning does not overwrite its pretrained
        moving statistics with small-batch statistics.

        Args:
            fine_tune_layers: Number of trailing backbone layers to unfreeze.

        Raises:
            ValueError: if the model has not been built yet.
        """
        if self.base_model is None:
            raise ValueError("请先调用 build_model_with_transfer_learning()")

        self.base_model.trainable = True
        backbone_layers = self.base_model.layers
        cutoff = self._freeze_cutoff(len(backbone_layers), fine_tune_layers)
        for index, layer in enumerate(backbone_layers):
            if index < cutoff or isinstance(layer, layers.BatchNormalization):
                layer.trainable = False

        # Trainability changes only take effect after recompiling; use a
        # smaller learning rate so the pretrained weights are only nudged.
        self.model.compile(
            optimizer=tf.keras.optimizers.Adam(learning_rate=0.0001),
            loss='sparse_categorical_crossentropy',
            metrics=['accuracy']
        )
        print(f"模型已设置为微调模式,解冻了最后 {fine_tune_layers} 层")
        print("使用学习率: 0.0001")

    @staticmethod
    def _make_callbacks():
        """Fresh early-stopping and LR-reduction callbacks for one training phase."""
        return [
            tf.keras.callbacks.EarlyStopping(
                monitor='val_accuracy',
                patience=5,
                restore_best_weights=True,
                mode='max'
            ),
            tf.keras.callbacks.ReduceLROnPlateau(
                monitor='val_loss',
                factor=0.2,
                patience=3,
                min_lr=1e-7
            ),
        ]

    @staticmethod
    def _merge_histories(first, second):
        """Concatenate the per-epoch metric lists of two ``History.history`` dicts."""
        keys = ('accuracy', 'val_accuracy', 'loss', 'val_loss')
        return {key: list(first.get(key, [])) + list(second.get(key, [])) for key in keys}

    def train(self, epochs_initial=10, epochs_fine_tune=10, verbose=1, fine_tune_layers=100):
        """Two-phase training: frozen backbone, then fine-tuning.

        Args:
            epochs_initial: Epochs with the backbone frozen.
            epochs_fine_tune: Epochs with the last backbone layers unfrozen.
            verbose: Verbosity passed to ``model.fit``.
            fine_tune_layers: Number of trailing backbone layers to unfreeze.

        Returns:
            Merged history dict (also stored on ``self.history``).

        Raises:
            ValueError: if the datasets have not been loaded yet.
        """
        if self.train_ds is None or self.val_ds is None:
            raise ValueError("数据集尚未加载,请先调用 load_data() 方法")
        if self.model is None:
            self.build_model_with_transfer_learning()

        print("=" * 50)
        print("第一阶段:基础模型冻结训练")
        print("=" * 50)
        history_initial = self.model.fit(
            self.train_ds,
            validation_data=self.val_ds,
            epochs=epochs_initial,
            callbacks=self._make_callbacks(),
            verbose=verbose
        )
        # Phase 1 may stop early, so record where fine-tuning actually starts.
        self.fine_tune_epoch = len(history_initial.history.get('accuracy', []))

        print("=" * 50)
        print("第二阶段:模型微调")
        print("=" * 50)
        self.fine_tune_model(fine_tune_layers=fine_tune_layers)
        # Continue the epoch numbering instead of restarting at 0.
        history_fine_tune = self.model.fit(
            self.train_ds,
            validation_data=self.val_ds,
            epochs=self.fine_tune_epoch + epochs_fine_tune,
            initial_epoch=self.fine_tune_epoch,
            callbacks=self._make_callbacks(),
            verbose=verbose
        )

        self.history = self._merge_histories(history_initial.history, history_fine_tune.history)
        return self.history

    def evaluate(self):
        """Evaluate the model on the validation set.

        Returns:
            ``(val_loss, val_accuracy)``.

        Raises:
            ValueError: if no model exists yet.
        """
        if self.model is None:
            raise ValueError("模型尚未训练,请先调用 train() 方法")
        print("评估模型在验证集上的表现:")
        val_loss, val_accuracy = self.model.evaluate(self.val_ds, verbose=0)
        print(f"验证损失: {val_loss:.4f}")
        print(f"验证准确率: {val_accuracy:.4f}")
        return val_loss, val_accuracy

    def plot_training_history(self):
        """Plot accuracy and loss curves, marking where fine-tuning started.

        Raises:
            ValueError: if ``train()`` has not been run.
        """
        if self.history is None:
            raise ValueError("没有训练历史数据,请先训练模型")
        fine_tune_epoch = getattr(self, 'fine_tune_epoch', None)
        epochs_range = range(len(self.history['accuracy']))
        panels = (
            ('accuracy', 'val_accuracy', 'Training Accuracy', 'Validation Accuracy',
             'lower right', 'Training and Validation Accuracy'),
            ('loss', 'val_loss', 'Training Loss', 'Validation Loss',
             'upper right', 'Training and Validation Loss'),
        )

        plt.figure(figsize=(12, 4))
        for position, (train_key, val_key, train_label, val_label, legend_loc, title) in enumerate(panels, start=1):
            plt.subplot(1, 2, position)
            plt.plot(epochs_range, self.history[train_key], label=train_label)
            plt.plot(epochs_range, self.history[val_key], label=val_label)
            if fine_tune_epoch is not None:
                # Boundary between the last frozen epoch and the first fine-tuning epoch.
                plt.axvline(x=fine_tune_epoch - 0.5, color='r', linestyle='--', alpha=0.7, label='开始微调')
            plt.legend(loc=legend_loc)
            plt.title(title)
        plt.tight_layout()
        plt.show()

    def save_model(self, model_path='flower_classification_model_transfer.keras'):
        """Save the model plus the side files the predictor reads.

        Writes ``<base>_class_names.json`` and ``<base>_training_config.json``
        next to the model, where ``<base>`` is *model_path* without its
        extension.

        Args:
            model_path: Target path; ``.keras`` is appended when the path has
                neither a ``.keras`` nor an ``.h5`` extension.

        Raises:
            ValueError: if there is no model to save.
        """
        if self.model is None:
            raise ValueError("没有可保存的模型")
        if not model_path.endswith(('.keras', '.h5')):
            model_path += '.keras'
        self.model.save(model_path)
        print(f"模型已保存到: {model_path}")

        base_path = os.path.splitext(model_path)[0]
        class_names_path = base_path + '_class_names.json'
        with open(class_names_path, 'w', encoding='utf-8') as f:
            json.dump(self.class_names, f, ensure_ascii=False, indent=2)
        print(f"类别名称已保存到: {class_names_path}")

        config_path = base_path + '_training_config.json'
        config = {
            'img_height': self.img_height,
            'img_width': self.img_width,
            'batch_size': self.batch_size,
            'model_type': 'transfer_learning',
            'base_model_name': getattr(self, 'base_model_name', None),
        }
        with open(config_path, 'w', encoding='utf-8') as f:
            json.dump(config, f, ensure_ascii=False, indent=2)
        print(f"训练配置已保存到: {config_path}")
使用模型(保存为 FlowerClassifier.py,以下为 Python 代码):
import json
import tensorflow as tf
import matplotlib.pyplot as plt
import numpy as np
import pathlib
from tensorflow import keras
from tensorflow.keras import layers
from tensorflow.keras.models import Sequential
import os
class FlowerClassifier:
    """Flower-classification predictor compatible with the transfer-learning model.

    Besides the model file it reads two optional side files written by the
    trainer next to the model: ``<base>_class_names.json`` (label list) and
    ``<base>_training_config.json`` (input image size).
    """

    def __init__(self, model_path='flower_classification_model_transfer.keras'):
        """Load the model (and its side files) from *model_path*.

        Args:
            model_path: Saved model path; see ``load_model`` for the fallbacks
                tried when it does not exist.
        """
        self.model = None
        self.class_names = None
        # Standard transfer-learning input size; overridden by the saved
        # training config when one exists.
        self.img_height = 224
        self.img_width = 224
        # True/False when the model's last layer reveals whether it already
        # outputs probabilities; None when that cannot be determined.
        self._outputs_probabilities = None
        self.load_model(model_path)

    def load_model(self, model_path):
        """Load the model, its class names and (optionally) its training config.

        Args:
            model_path: Path to a ``.keras``/``.h5`` file. If it does not
                exist, the path with the other extension, or with an extension
                appended, is tried.

        Raises:
            FileNotFoundError: if none of the candidate paths exists.
        """
        if not os.path.exists(model_path):
            candidates = [
                model_path.replace('.keras', '.h5'),
                model_path.replace('.h5', '.keras'),
                model_path + '.keras',
                model_path + '.h5',
            ]
            existing = next((path for path in candidates if os.path.exists(path)), None)
            if existing is None:
                raise FileNotFoundError(f"模型文件不存在: {model_path}")
            model_path = existing

        self.model = tf.keras.models.load_model(model_path)
        print(f"模型已从 {model_path} 加载")
        self._outputs_probabilities = self._detect_probability_output(self.model)

        # Side files share the model path without its extension.
        base_path = os.path.splitext(model_path)[0]
        class_names_path = base_path + '_class_names.json'
        if os.path.exists(class_names_path):
            with open(class_names_path, 'r', encoding='utf-8') as f:
                self.class_names = json.load(f)
            print(f"类别名称已加载: {self.class_names}")
        else:
            print("警告: 未找到类别名称文件,将使用默认类别索引")
            self.class_names = None

        config_path = base_path + '_training_config.json'
        if os.path.exists(config_path):
            with open(config_path, 'r', encoding='utf-8') as f:
                config = json.load(f)
            self.img_height = config.get('img_height', 224)
            self.img_width = config.get('img_width', 224)
            print(f"图像尺寸配置: {self.img_height}x{self.img_width}")

    @staticmethod
    def _detect_probability_output(model):
        """Inspect the last layer's activation.

        Returns:
            True for softmax (model outputs probabilities), False for linear
            (model outputs logits), None when it cannot be determined.
        """
        try:
            last_layer = model.layers[-1]
        except (AttributeError, IndexError, TypeError):
            return None
        activation_name = getattr(getattr(last_layer, 'activation', None), '__name__', None)
        if activation_name == 'softmax':
            return True
        if activation_name == 'linear':
            return False
        return None

    @staticmethod
    def _to_probabilities(raw_output, outputs_probabilities=None):
        """Convert one model output vector into class probabilities.

        The transfer-learning model already ends in softmax; applying softmax
        again would flatten the distribution and understate the confidence,
        so softmax is applied only to logits. When *outputs_probabilities* is
        None the output is treated as probabilities if it is non-negative and
        sums to ~1.
        """
        raw = np.asarray(raw_output, dtype=np.float32)
        if outputs_probabilities is None:
            outputs_probabilities = bool(
                raw.size
                and np.all(raw >= 0)
                and np.isclose(raw.sum(), 1.0, atol=1e-3)
            )
        if outputs_probabilities:
            return raw
        exp = np.exp(raw - raw.max())  # shift by the max for numerical stability
        return exp / exp.sum()

    def predict_image(self, image_path):
        """Classify a single image file.

        Returns:
            Dict with ``predicted_class``, ``confidence`` (percent),
            ``all_scores`` (per-class probabilities) and ``class_index``.

        Raises:
            ValueError: if no model is loaded.
        """
        if self.model is None:
            raise ValueError("模型未加载")
        img = tf.keras.utils.load_img(
            image_path, target_size=(self.img_height, self.img_width)
        )
        batch = tf.expand_dims(tf.keras.utils.img_to_array(img), 0)  # batch of one
        raw_output = self.model.predict(batch, verbose=0)[0]
        scores = self._to_probabilities(
            raw_output, getattr(self, '_outputs_probabilities', None)
        )

        predicted_class_idx = np.argmax(scores)
        confidence = 100 * np.max(scores)
        if self.class_names is not None and predicted_class_idx < len(self.class_names):
            predicted_class = self.class_names[predicted_class_idx]
        else:
            predicted_class = f"Class_{predicted_class_idx}"
        return {
            'predicted_class': predicted_class,
            'confidence': confidence,
            'all_scores': scores,
            'class_index': predicted_class_idx
        }

    @staticmethod
    def _url_cache_name(image_url, image_name):
        """Return a download file name unique to *image_url*.

        ``get_file`` reuses an already-downloaded file with the same name, so
        naming downloads by *image_name* alone would silently return a stale
        image for a different URL; the URL digest keeps entries apart.
        """
        import hashlib
        digest = hashlib.sha256(image_url.encode('utf-8')).hexdigest()[:16]
        return f"{image_name}_{digest}"

    def predict_image_from_url(self, image_url, image_name='temp_image'):
        """Download an image (cached per URL) and classify it.

        Returns:
            The ``predict_image`` result, or None if downloading or predicting
            failed (the error is printed; this is deliberately best-effort).
        """
        try:
            image_path = tf.keras.utils.get_file(
                self._url_cache_name(image_url, image_name), origin=image_url
            )
            return self.predict_image(image_path)
        except Exception as e:
            print(f"从URL加载图片时出错: {e}")
            return None

    def predict_batch(self, image_dir):
        """Classify every image file directly inside *image_dir*, in name order.

        Files that fail are reported and skipped.

        Returns:
            List of ``predict_image`` results, each with an extra ``file_path``.

        Raises:
            ValueError: if no model is loaded.
        """
        if self.model is None:
            raise ValueError("模型未加载")
        image_extensions = {'.jpg', '.jpeg', '.png', '.bmp'}
        results = []
        for file_path in sorted(pathlib.Path(image_dir).glob('*')):
            if not file_path.is_file() or file_path.suffix.lower() not in image_extensions:
                continue
            try:
                result = self.predict_image(str(file_path))
            except Exception as e:
                print(f"处理图片 {file_path} 时出错: {e}")
                continue
            result['file_path'] = str(file_path)
            results.append(result)
        return results

    def display_prediction(self, prediction_result, show_image=True):
        """Print the top prediction and, when labels are known, every class score.

        Args:
            prediction_result: A ``predict_image`` result, or None.
            show_image: Accepted for API compatibility; currently unused.
        """
        if prediction_result is None:
            print("没有预测结果")
            return
        print(
            "预测结果: 该图片最可能属于 '{}', 置信度: {:.2f}%"
            .format(prediction_result['predicted_class'], prediction_result['confidence'])
        )
        if self.class_names is not None:
            print("\n所有类别置信度:")
            for class_name, raw_score in zip(self.class_names, prediction_result['all_scores']):
                score = raw_score * 100
                print(f" {class_name}: {score:.2f}%")
主程序 main(以下为 Python 代码):
from FlowerClassifier import FlowerClassifier
from FlowerClassifierTrainer import FlowerClassifierTrainer
# Usage example
def train_flower_classifier(base_model_name='MobileNetV2',
                            model_path='flower_model_transfer_learning.keras'):
    """Train, evaluate, plot and save the transfer-learning flower classifier.

    Args:
        base_model_name: Backbone passed to ``build_model_with_transfer_learning``.
        model_path: Where the trained model (and its side files) are saved.

    Returns:
        The trained ``FlowerClassifierTrainer`` instance.
    """
    trainer = FlowerClassifierTrainer(img_height=224, img_width=224, batch_size=32)
    trainer.setup_gpu()
    trainer.load_data()

    # Build the model on a frozen pretrained backbone.
    trainer.build_model_with_transfer_learning(base_model_name)

    # Two-phase training: frozen backbone first, then fine-tuning.
    trainer.train(epochs_initial=10, epochs_fine_tune=10)

    trainer.evaluate()
    trainer.plot_training_history()
    trainer.save_model(model_path)
    return trainer
if __name__ == "__main__":
    # Train and save the transfer-learning model.
    train_flower_classifier()

    # Load the saved model back, the same way a separate inference process would.
    classifier = FlowerClassifier('flower_model_transfer_learning.keras')

    # Predict an example image downloaded from a URL.
    sunflower_url = "https://storage.googleapis.com/download.tensorflow.org/example_images/592px-Red_sunflower.jpg"
    print("从URL预测图片:")
    result = classifier.predict_image_from_url(sunflower_url, 'Red_sunflower')
    classifier.display_prediction(result)

    # To classify a local image instead, call classifier.predict_image(<path>)
    # and pass the result to classifier.display_prediction().
结果:验证准确率接近 95%。