import os
import numpy as np
import pandas as pd
import tensorflow as tf
from keras.applications.inception_v3 import InceptionV3
from keras.layers import Dense, Flatten, Concatenate, Input
from keras.models import Model
from tensorflow.keras.optimizers import Adam
from keras.callbacks import EarlyStopping, ReduceLROnPlateau
from keras.preprocessing import image
from sklearn.model_selection import train_test_split
# 定义超参数
IMG_WIDTH, IMG_HEIGHT = 500, 500
BATCH_SIZE = 32
EPOCHS = 300
LEARNING_RATE = 0.001
# 输入参数:图像文件目录和骨龄信息CSV文件路径
images_dir = r'D:\python_code\Bone-Age-Assessment-master\train'
boneage_info_path = r'D:\boneage_data\boneage-training-dataset.csv'
# 读取骨龄信息CSV文件
boneage_info = pd.read_csv(boneage_info_path)
# 确保id列的值为字符串类型
boneage_info['ID'] = boneage_info['ID'].astype(str)
# 使用 sklearn 自动划分训练集和验证集
train_boneage_info, val_boneage_info = train_test_split(boneage_info, test_size=0.2, random_state=42)
# 创建一个ImageDataGenerator用于加载和预处理图像
datagen = tf.keras.preprocessing.image.ImageDataGenerator(
rescale=1. / 255,
rotation_range=20,
width_shift_range=0.2,
height_shift_range=0.2,
zoom_range=0.2,
horizontal_flip=True
)
# 自定义生成器类
class CustomDataGenerator(tf.keras.utils.Sequence):
def __init__(self, df, image_dir, batch_size, img_size):
self.df = df
self.image_dir = image_dir
self.batch_size = batch_size
self.img_size = img_size
self.on_epoch_end()
def __len__(self):
return int(np.ceil(len(self.df) / float(self.batch_size)))
def __getitem__(self, idx):
batch_df = self.df[idx * self.batch_size:(idx + 1) * self.batch_size]
X, y = self.__data_generation(batch_df)
return X, y
def on_epoch_end(self):
self.indexes = np.arange(len(self.df))
np.random.shuffle(self.indexes)
def __data_generation(self, batch_df):
X_img = []
X_gender = []
y = []
for _, row in batch_df.iterrows():
# 获取图像数据
img_name = row['ID'] + '.png' # 添加扩展名
img_path = os.path.join(self.image_dir, img_name)
try:
img = image.load_img(img_path, color_mode='grayscale', target_size=self.img_size)
except FileNotFoundError:
print(f"File not found: {img_path}")
continue
img = image.img_to_array(img)
# Convert grayscale to RGB by repeating the single channel 3 times
img = np.repeat(img, 3, axis=2)
img = datagen.standardize(img)
X_img.append(img)
# 获取性别数据
gender = row['Male']
X_gender.append(gender)
# 获取骨龄
boneage = row['BoneAge']
y.append(boneage)
X_img = np.array(X_img)
X_gender = np.array(X_gender)
y = np.array(y)
return ([X_img, X_gender], y)
# 创建训练和验证数据生成器
train_generator = CustomDataGenerator(train_boneage_info, images_dir, BATCH_SIZE, (IMG_WIDTH, IMG_HEIGHT))
validation_generator = CustomDataGenerator(val_boneage_info, images_dir, BATCH_SIZE, (IMG_WIDTH, IMG_HEIGHT))
# 构建模型
IMG_CHANNELS = 3
base_model = InceptionV3(weights='imagenet', include_top=False, input_shape=(IMG_WIDTH, IMG_HEIGHT, IMG_CHANNELS))
# 添加性别信息
gender_input = Input(shape=(1,), dtype=tf.int32)
# 将整数值转换为浮点数,以便进行后续处理
gender_float = tf.cast(gender_input, tf.float32)
gender_dense = Dense(32, activation='relu')(gender_float)
# 提取InceptionV3的最后一层
last_layer = base_model.output
flatten = Flatten()(last_layer)
# 拼接性别信息
concat = Concatenate()([flatten, gender_dense])
# 添加额外的全连接层
dense1 = Dense(1000, activation='relu')(concat)
dense2 = Dense(1000, activation='relu')(dense1)
# 最终输出层
output = Dense(1)(dense2)
# 创建模型
model = Model(inputs=[base_model.input, gender_input], outputs=output)
# 编译模型
model.compile(optimizer=Adam(learning_rate=LEARNING_RATE), loss='mean_absolute_error')
# 定义早停和学习率衰减回调
#early_stopping = EarlyStopping(monitor='val_loss', patience=10)
educe_lr = ReduceLROnPlateau(monitor='val_loss', factor=0.2, patience=5, min_lr=0.0001)
early_stopping = EarlyStopping(monitor='val_loss', patience=10, restore_best_weights=True)
# 训练模型
history = model.fit(
train_generator,
steps_per_epoch=len(train_generator),
validation_data=validation_generator,
validation_steps=len(validation_generator),
epochs=EPOCHS,
callbacks=[reduce_lr,early_stopping]
)
# 保存模型
model.save(r'D:\python_code\Bone-Age-Assessment-master\bone_age_model_processed.h5')
gl2222
yyfhq2024-11-05 3:04
相关推荐
API快乐传递者16 分钟前
用 Python 爬取淘宝商品价格信息时需要注意什么?Aurora_th20 分钟前
蓝桥杯 Python组-神奇闹钟(datetime库)萧鼎35 分钟前
【Python】计算机视觉应用:OpenCV库图像处理入门子午1 小时前
【车辆车型识别】Python+卷积神经网络算法+深度学习+人工智能+TensorFlow+算法模型是个热心市民1 小时前
构建一个导航栏web大哇唧2 小时前
python批量合并excel文件墨城烟柳Q2 小时前
自动化爬虫-selenium模块万字详解raoxiaoya2 小时前
python安装selenium,geckodriver,chromedriverDxy12393102163 小时前
python使用requests发送请求ssl错误gxchai3 小时前
利用pythonstudio写的PDF、图片批量水印生成器,可同时为不同读者生成多组水印