基于迁移学习实现肺炎X光片诊断分类

大家好,我是带我去滑雪!

肺炎是全球范围内致死率较高的疾病之一,尤其是在老年人、免疫系统较弱的患者群体中,更容易引发严重并发症。传统上,肺炎的诊断依赖于医生的临床经验以及影像学检查,尤其是X光片,它在肺炎的早期筛查和诊断中扮演了至关重要的角色。然而,X光片的读取不仅需要专业的放射科医生,而且受到经验和疲劳等因素的影响,导致诊断结果的准确性存在一定的偏差。近年来,人工智能(AI)技术,尤其是深度学习在医学影像领域取得了显著进展。通过深度学习模型,计算机能够高效地从大量影像数据中学习到复杂的模式,并实现对疾病的自动识别和分类,极大地提高了诊断的速度和准确性。迁移学习作为深度学习的一种重要方法,能够通过在已有的、大规模的医学图像数据上预训练模型,并迁移到肺炎X光片的分类任务上,减少对大量标注数据的需求,这对资源有限、标注困难的医学领域尤为重要。

基于迁移学习的肺炎X光片诊断分类研究,不仅可以缓解医生在实际工作中因繁重工作负担导致的诊断错误问题,还能够通过高效、准确的自动化诊断方法,在早期筛查中提供帮助,尤其是在偏远地区或医疗资源匮乏的环境中,为患者提供及时的诊疗建议,极大地促进了医疗资源的合理分配。此外,该研究的成功实现还可以为其他疾病的X光片图像诊断提供借鉴,推动人工智能技术在医学领域的广泛应用。下面开始代码实战。

目录

(1)导入相关模块

(2)构建数据集

(3)加载训练的网络

(4)调整模型

(5)设置测试集加载参数


(1)导入相关模块

python 复制代码
import os
from PIL import Image
from glob import glob

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader

from torchvision import transforms
from torchvision.models import resnet50, ResNet50_Weights

(2)构建数据集

python 复制代码
class ChestXRayDataset(Dataset):
    def __init__(
            self,
            dataset_dir,
            transform=None) -> None:
        self.dataset_dir = dataset_dir
        self.transform = transform
        # 获取文件夹下所有图片路径
        self.dataset_images = glob(f"{self.dataset_dir}/**/*.jpeg", recursive=True)

    # 获取数据集大小
    def __len__(self):
        return len(self.dataset_images)

    # 读取图像,获取类别
    def __getitem__(self, idx):
        image_path = self.dataset_images[idx]
        image_name = os.path.basename(image_path)

        image = Image.open(image_path)
        if "NORMAL" in image_name:
            category = 0
        else:
            category = 1

        if self.transform:
            image = self.transform(image)

        return image, category

(3)加载训练的网络

python 复制代码
def prepare_model():
    # 加载预训练的模型
    resnet50_weight = ResNet50_Weights.DEFAULT
    resnet50_mdl = resnet50(weights=resnet50_weight)
    # 替换模型最后的全连接层
    num_ftrs = resnet50_mdl.fc.in_features
    resnet50_mdl.fc = nn.Linear(num_ftrs, 2)

    return resnet50_mdl


def train_model():
    # 确定使用CPU还是GPU
    if torch.cuda.is_available():
        device = "cuda:0"
    else:
        device = "cpu"
    # 加载模型
    model = prepare_model()
    model = model.to(device)
    model.train()
    # 设置loss函数和optimizer
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)

    # 设置训练集数据加载相关变量
    batch_size = 32
    chest_xray = r"E:\工作\硕士\博客\博客99-深度学习医学特征提取\deeplea test\deeplea test\archive\chest_xray"
    train_dataset_dir = os.path.join(chest_xray, "train")
    train_transforms = transforms.Compose([
        transforms.ToTensor(),
        transforms.Lambda(lambda x: x.repeat(3, 1, 1) if x.size(0) == 1 else x),
        transforms.Resize((224, 224)),
        transforms.RandomHorizontalFlip(),
        transforms.RandomVerticalFlip(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ])
    train_dataset = ChestXRayDataset(train_dataset_dir, train_transforms)
    train_dataloader = DataLoader(
        train_dataset,
        batch_size=batch_size,
        shuffle=True)

(4)调整模型

python 复制代码
   for epoch in range(5):
        print_batch = 50
        running_loss = 0
        running_corrects = 0
        for i, data in enumerate(train_dataloader):
            inputs, labels = data
            inputs = inputs.to(device)
            labels = labels.to(device)
            optimizer.zero_grad()
            outputs = model(inputs)
            _, preds = torch.max(outputs, 1)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            running_loss += (loss.item() * batch_size)
            running_corrects += torch.sum(preds == labels.data)
            if i % print_batch == (print_batch - 1):  # print every 100 mini-batches
                accuracy = running_corrects / (print_batch * batch_size)
                print(
                    f'Epoch: {epoch + 1}, Batch: {i + 1:5d} Running Loss: {running_loss / 50:.3f} Accuracy: {accuracy:.3f}')
                running_loss = 0.0
                running_corrects = 0
        checkpoint_name = f"epoch_{epoch}.pth"
        torch.save(model.state_dict(), checkpoint_name)


def test_model():
    if torch.cuda.is_available():
        device = "cuda:0"
    else:
        device = "cpu"
    # 加载模型
    checkpoint_name = "epoch_4.pth"
    model = prepare_model()
    model.load_state_dict(torch.load(checkpoint_name, map_location=device))
    model = model.to(device)
    model.eval()

(5)设置测试集加载参数

python 复制代码
    batch_size = 32
    chest_xray = r"E:\工作\硕士\博客\博客99-深度学习医学特征提取\deeplea test\deeplea test\archive\chest_xray"
    test_dataset_dir = os.path.join(chest_xray, "test")
    test_transforms = transforms.Compose([
        transforms.ToTensor(),
        transforms.Lambda(lambda x: x.repeat(3, 1, 1) if x.size(0) == 1 else x),
        transforms.Resize((224, 224)),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ])
    test_dataset = ChestXRayDataset(test_dataset_dir, test_transforms)
    test_dataloader = DataLoader(
        test_dataset,
        batch_size=batch_size,
        shuffle=False)
    # 在测试集测试模型
    with torch.no_grad():
        preds_list = []
        labels_list = []

        for i, data in enumerate(test_dataloader):
            inputs, labels = data
            inputs = inputs.to(device)
            labels = labels.to(device)
            outputs = model(inputs)
            _, preds = torch.max(outputs, 1)
            preds_list.append(preds)
            labels_list.append(labels)
        preds = torch.cat(preds_list)
        labels = torch.cat(labels_list)
        # 计算评价指标
        corrects_num = torch.sum(preds == labels.data)
        accuracy = corrects_num / labels.shape[0]
        # 输出评价指标
        print(f"Accuracy on test dataset: {accuracy:.2%}")


if __name__ == "__main__":
    train_model()
    test_model()

输出结果:


更多优质内容持续发布中,请移步主页查看。

若有问题可邮箱联系:1736732074@qq.com

博主的WeChat:TCB1736732074

点赞+关注,下次不迷路!

相关推荐
打螺丝否2 小时前
稠密矩阵和稀疏矩阵的对比
python·机器学习·矩阵
初级炼丹师(爱说实话版)3 小时前
2025算法八股——机器学习——SVM损失函数
算法·机器学习·支持向量机
大学生毕业题目3 小时前
毕业项目推荐:83-基于yolov8/yolov5/yolo11的农作物杂草检测识别系统(Python+卷积神经网络)
人工智能·python·yolo·目标检测·cnn·pyqt·杂草识别
非门由也3 小时前
《sklearn机器学习——聚类性能指标》Fowlkes-Mallows 得分
机器学习·聚类·sklearn
居7然3 小时前
美团大模型“龙猫”登场,能否重塑本地生活新战局?
人工智能·大模型·生活·美团
说私域3 小时前
社交新零售时代本地化微商的发展路径研究——基于开源AI智能名片链动2+1模式S2B2C商城小程序源的创新实践
人工智能·开源·零售
IT_陈寒3 小时前
Python性能优化:5个被低估的魔法方法让你的代码提速50%
前端·人工智能·后端
Deng_Xian_Sheng4 小时前
有哪些任务可以使用无监督的方式训练深度学习模型?
人工智能·深度学习·无监督
数据科学作家6 小时前
学数据分析必囤!数据分析必看!清华社9本书覆盖Stata/SPSS/Python全阶段学习路径
人工智能·python·机器学习·数据分析·统计·stata·spss
CV缝合救星8 小时前
【Arxiv 2025 预发行论文】重磅突破!STAR-DSSA 模块横空出世:显著性+拓扑双重加持,小目标、大场景统统拿下!
人工智能·深度学习·计算机视觉·目标跟踪·即插即用模块