Unet实现脑肿瘤分割检测

最终的实现效果是这样的,可以预测医学影像里面的脑肿瘤并且画出位置

如果说我们没有任何的东西的话,需要先下载一下数据集。(运行下面的代码的时候下载是在github下载的,可能需要科学上网)

对于数据集我们可以运行下面的download.py来下载(我们下载的照片COCO数据集是自带掩码的不需要其它替标签再给他打上去了,就是图片对应的joson文件)

python 复制代码
import requests
import os

def download_file(url, save_path):
    # 发送GET请求下载文件
    print(f"开始下载文件: {url}")
    response = requests.get(url, stream=True)
    response.raise_for_status()  # 检查是否下载成功
    
    # 获取文件大小
    file_size = int(response.headers.get('content-length', 0))
    
    # 写入文件
    with open(save_path, 'wb') as f:
        if file_size == 0:
            f.write(response.content)
        else:
            downloaded = 0
            for chunk in response.iter_content(chunk_size=8192):
                if chunk:
                    f.write(chunk)
                    downloaded += len(chunk)
                    # 显示下载进度
                    progress = int(50 * downloaded / file_size)
                    print(f"\r下载进度: [{'=' * progress}{' ' * (50-progress)}] {downloaded}/{file_size} bytes", end='')
    print("\n下载完成!")

# 设置下载链接和保存路径
url = "https://github.com/Zeyi-Lin/UNet-Medical/releases/download/data/Brain.Tumor.Image.DataSet.zip"
save_path = "dataset/Brain_Tumor_Image_DataSet.zip"

# 创建datasets目录
os.makedirs("dataset", exist_ok=True)

# 执行下载
download_file(url, save_path)

我们来看代码里面,首先response第一行是可以以流的方式传播大文件并且选地址是url。第二行是回复404之类的来检测是否下载成功。file_size是获取文件的大小。

下一行写入文件with open(save_path,"wb") as f:这个是以二进制形式写入文件

下一句if file_size==0也就是说是小文件的意思,这个时候就直接写入

else 的话就downloaded=0初始化下载,然后chunk每次读取8192B循环写入文件,每执行一次都会写入然后累计下载量,通过以恶搞progress的计算然后print输出下载的进度。

最后给下载地址url以及下载的路径save_path。如果没有dataset就给他创建一个。以及最后的调用函数。如此运行这个函数就可以下载数据集了,也可以直接点代码里面的地址下载。

后面需要data.py对数据集进行初步的处理。

python 复制代码
import os
import numpy as np
from PIL import Image
from torch.utils.data import Dataset
import torch
import torch.nn.functional as nn

class COCOSegmentationDataset(Dataset):
    def __init__(self, coco, image_dir, transform=None):
        self.coco = coco
        self.image_dir = image_dir
        self.image_ids = coco.getImgIds()
        self.transform = transform

    def __len__(self):
        return len(self.image_ids)

    def __getitem__(self, idx):
        image_id = self.image_ids[idx]
        image_info = self.coco.loadImgs(image_id)[0]
        image_path = os.path.join(self.image_dir, image_info['file_name'])

        # 加载图像
        image = Image.open(image_path)
        image = np.array(image, dtype=np.uint8)

        # 创建掩码
        ann_ids = self.coco.getAnnIds(imgIds=image_id)
        anns = self.coco.loadAnns(ann_ids)
        mask = np.zeros((image_info['height'], image_info['width']), dtype=np.uint8)
        for ann in anns:
            mask = np.maximum(mask, self.coco.annToMask(ann))

        # 转换为张量并预处理
        if self.transform:
            image = self.transform(image)
            mask = torch.from_numpy(mask).float().unsqueeze(0)
            mask = nn.interpolate(mask.unsqueeze(0), size=(256, 256), mode='nearest').squeeze(0)

        return image, mask

首先代码的刚开始是def __init__这样一个初始化,里面引入了dir图像位置,ids图像id

,transform初步变化这几个self属性

下面len函数是用来返回数据集的长度的

之后image_id来选择一个图片

后面info调用coco.loadImags来通过coco api获取图片选择第一个图片(即使是一个图片也会返回一个数组我们选择【0】就是第一个图片)的名称,长宽等字典

(COCO API提供了一系列函数,用于加载、解析和可视化COCO数据集中的注释(annotations)和图像(images)。它可以帮助我们轻松地获取图像的信息、标注信息、类别信息等。)

之后通过path调用info里面的name以及image_dir来找到这个路径。

后面通过open打开对应路径的图片再用numpy转化成数组

下面有ann_ids通过getAnnIds来获取标注id

之后通过标注ID列表,获取每个标注的完整详细信息(可以根据下面的这个理解)

{

'id': 456, # 标注ID

'image_id': 123, # 对应的图片ID

'category_id': 1, # 类别ID(比如1代表人,2代表车)

'bbox': [x, y, width, height], # 边界框坐标

'segmentation': [...], # 分割掩码的多边形坐标

'area': 1500.5, # 物体面积

'iscrowd': 0 # 是否被遮挡(0表示单个物体)

}

之后是mask作为一个掩码,让我们想处理相应的区域的时候让掩码操作只显示对应的区域

image = self.transform(image):对图像应用变换(如归一化、re

size等)。

from_numpy那个是转换成Tensoe张量因为PyTorch的所有操作(神经网络层、损失函数、优化器等)都基于torch.Tensor,保持数据类型统一可以避免后续的重复转换

这两步专门用于掩码的后处理,目的是将原始的掩码数据转换为适合神经网络输入的格式

一个return image ,mask返回图片的信息以及掩码的tensor

net.py主要是一些Unet模型的一些东西

python 复制代码
import torch
import torch.nn as nn

# 定义U-Net模型的下采样块
class DownBlock(nn.Module):
    def __init__(self, in_channels, out_channels, dropout_prob=0, max_pooling=True):
        super(DownBlock, self).__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, 3, padding=1)
        self.conv2 = nn.Conv2d(out_channels, out_channels, 3, padding=1)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(2) if max_pooling else None
        self.dropout = nn.Dropout(dropout_prob) if dropout_prob > 0 else None

    def forward(self, x):
        x = self.relu(self.conv1(x))
        x = self.relu(self.conv2(x))
        if self.dropout:
            x = self.dropout(x)
        skip = x
        if self.maxpool:
            x = self.maxpool(x)
        return x, skip

# 定义U-Net模型的上采样块
class UpBlock(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(UpBlock, self).__init__()
        self.up = nn.ConvTranspose2d(in_channels, out_channels, kernel_size=2, stride=2)
        self.conv1 = nn.Conv2d(out_channels * 2, out_channels, 3, padding=1)
        self.conv2 = nn.Conv2d(out_channels, out_channels, 3, padding=1)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x, skip):
        x = self.up(x)
        x = torch.cat([x, skip], dim=1)
        x = self.relu(self.conv1(x))
        x = self.relu(self.conv2(x))
        return x

# 定义完整的U-Net模型
class UNet(nn.Module):
    def __init__(self, n_channels=3, n_classes=1, n_filters=32):
        super(UNet, self).__init__()
        
        # 编码器路径
        self.down1 = DownBlock(n_channels, n_filters)
        self.down2 = DownBlock(n_filters, n_filters * 2)
        self.down3 = DownBlock(n_filters * 2, n_filters * 4)
        self.down4 = DownBlock(n_filters * 4, n_filters * 8)
        self.down5 = DownBlock(n_filters * 8, n_filters * 16)
        
        # 瓶颈层 - 移除最后的maxpooling
        self.bottleneck = DownBlock(n_filters * 16, n_filters * 32, dropout_prob=0.4, max_pooling=False)
        
        # 解码器路径
        self.up1 = UpBlock(n_filters * 32, n_filters * 16)
        self.up2 = UpBlock(n_filters * 16, n_filters * 8)
        self.up3 = UpBlock(n_filters * 8, n_filters * 4)
        self.up4 = UpBlock(n_filters * 4, n_filters * 2)
        self.up5 = UpBlock(n_filters * 2, n_filters)
        
        # 输出层
        self.outc = nn.Conv2d(n_filters, n_classes, 1)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        # 编码器路径
        x1, skip1 = self.down1(x)      # 128
        x2, skip2 = self.down2(x1)     # 64
        x3, skip3 = self.down3(x2)     # 32
        x4, skip4 = self.down4(x3)     # 16
        x5, skip5 = self.down5(x4)     # 8
        
        # 瓶颈层
        x6, skip6 = self.bottleneck(x5)  # 8 (无下采样)
        
        # 解码器路径
        x = self.up1(x6, skip5)    # 16
        x = self.up2(x, skip4)     # 32
        x = self.up3(x, skip3)     # 64
        x = self.up4(x, skip2)     # 128
        x = self.up5(x, skip1)     # 256
        
        x = self.outc(x)
        x = self.sigmoid(x)
        return x

我们在这个里面有DownBlock这样的一个函数定义的是系采样模块

还有UpBlock这样的上采样模块

然后再通过一个完整的Unet模型给完整的呈现出来

通过一个forward前向传播作为后面处理数据的函数

(例如你可以通过下面如图所示来进行一个而简单的模型的调用)

train,py

python 复制代码
import torch
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt
from pycocotools.coco import COCO
from torch.utils.data import DataLoader
import torchvision.transforms as transforms
import random
import os
import time
from net import UNet
from data import COCOSegmentationDataset

# 获取当前脚本所在目录
base_dir = os.path.dirname(os.path.abspath(__file__))

# 数据路径设置
train_dir = os.path.join(base_dir, 'dataset', 'train')
val_dir = os.path.join(base_dir, 'dataset', 'valid')
test_dir = os.path.join(base_dir, 'dataset', 'test')

train_annotation_file = os.path.join(base_dir, 'dataset', 'train', '_annotations.coco.json')
test_annotation_file = os.path.join(base_dir, 'dataset', 'test', '_annotations.coco.json')
val_annotation_file = os.path.join(base_dir, 'dataset', 'valid', '_annotations.coco.json')

# 检查文件是否存在
def check_files_exist():
    required_files = [
        train_annotation_file,
        val_annotation_file,
        test_annotation_file
    ]
    
    for file_path in required_files:
        if not os.path.exists(file_path):
            print(f"错误: 找不到文件 {file_path}")
            print(f"当前工作目录: {os.getcwd()}")
            print(f"基础目录: {base_dir}")
            return False
        else:
            print(f"找到文件: {file_path}")
    return True

print("检查数据文件...")
if not check_files_exist():
    print("请确保数据集已正确解压且路径正确")
    print("数据集应该包含: dataset/train/, dataset/valid/, dataset/test/ 文件夹")
    print("每个文件夹中应该有 _annotations.coco.json 文件")
    exit(1)

# 加载COCO数据集
print("加载COCO数据集...")
train_coco = COCO(train_annotation_file)
val_coco = COCO(val_annotation_file)
test_coco = COCO(test_annotation_file)
print("COCO数据集加载完成!")

# 配置参数
config = {
    "batch_size": 4,  # 暂时减小批次大小以便调试
    "learning_rate": 1e-4,
    "num_epochs": 40,
    "device": "cuda" if torch.cuda.is_available() else "cpu",
}

# 定义损失函数
def dice_loss(pred, target, smooth=1e-6):
    pred_flat = pred.view(-1)
    target_flat = target.view(-1)
    intersection = (pred_flat * target_flat).sum()
    return 1 - ((2. * intersection + smooth) / (pred_flat.sum() + target_flat.sum() + smooth))

def combined_loss(pred, target):
    dice = dice_loss(pred, target)
    bce = nn.BCELoss()(pred, target)
    return 0.6 * dice + 0.4 * bce

def main():
    print("=" * 50)
    print("开始训练程序")
    print("=" * 50)
    
    print("配置信息:")
    for key, value in config.items():
        print(f"  {key}: {value}")
    
    # 设置设备
    device = torch.device(config["device"])
    print(f"使用设备: {device}")
    
    # 数据预处理
    print("创建数据预处理转换...")
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Resize((256, 256)),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])
    
    # 创建数据集
    print("创建数据集中...")
    try:
        train_dataset = COCOSegmentationDataset(train_coco, train_dir, transform=transform)
        print(f"训练集创建成功,样本数: {len(train_dataset)}")
        
        val_dataset = COCOSegmentationDataset(val_coco, val_dir, transform=transform)
        print(f"验证集创建成功,样本数: {len(val_dataset)}")
        
        test_dataset = COCOSegmentationDataset(test_coco, test_dir, transform=transform)
        print(f"测试集创建成功,样本数: {len(test_dataset)}")
    except Exception as e:
        print(f"创建数据集时出错: {e}")
        return
    
    # 创建数据加载器
    print("创建数据加载器中...")
    try:
        BATCH_SIZE = config["batch_size"]
        train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=0)
        val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, num_workers=0)
        test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, num_workers=0)
        print("数据加载器创建成功!")
    except Exception as e:
        print(f"创建数据加载器时出错: {e}")
        return
    
    # 测试数据加载
    print("测试数据加载...")
    try:
        start_time = time.time()
        for i, (images, masks) in enumerate(train_loader):
            print(f"批次 {i+1}: 图像形状 {images.shape}, 掩码形状 {masks.shape}")
            if i >= 2:  # 只测试前3个批次
                break
        end_time = time.time()
        print(f"数据加载测试成功! 耗时: {end_time - start_time:.2f}秒")
    except Exception as e:
        print(f"数据加载测试失败: {e}")
        return
    
    # 初始化模型
    print("初始化模型中...")
    try:
        model = UNet(n_filters=16).to(device)  # 使用更小的模型进行测试
        print(f"模型初始化成功! 参数数量: {sum(p.numel() for p in model.parameters())}")
    except Exception as e:
        print(f"模型初始化失败: {e}")
        return
    
    # 设置优化器
    print("设置优化器...")
    try:
        optimizer = optim.Adam(model.parameters(), lr=config["learning_rate"])
        print("优化器设置成功!")
    except Exception as e:
        print(f"优化器设置失败: {e}")
        return
    
    # 开始训练
    print("开始训练...")
    try:
        train_model(
            model=model,
            train_loader=train_loader,
            val_loader=val_loader,
            criterion=combined_loss,
            optimizer=optimizer,
            num_epochs=config["num_epochs"],
            device=device,
        )
    except Exception as e:
        print(f"训练过程中出错: {e}")
        return
    
    print("训练完成!")

# 训练函数
def train_model(model, train_loader, val_loader, criterion, optimizer, num_epochs, device):
    print("进入训练循环...")
    best_val_loss = float('inf')
    patience = 8
    patience_counter = 0
    
    # 记录训练历史
    history = {
        'train_loss': [],
        'train_acc': [],
        'val_loss': [],
        'val_acc': []
    }

    for epoch in range(num_epochs):
        print(f"开始第 {epoch+1}/{num_epochs} 轮训练...")
        model.train()
        train_loss = 0
        train_acc = 0
        
        batch_count = 0
        for images, masks in train_loader:
            batch_count += 1
            if batch_count % 10 == 0:  # 每10个批次打印一次进度
                print(f"  处理第 {batch_count} 个批次...")
                
            images, masks = images.to(device), masks.to(device)
            
            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, masks)
            
            loss.backward()
            optimizer.step()
            
            train_loss += loss.item()
            train_acc += (outputs.round() == masks).float().mean().item()

        train_loss /= len(train_loader)
        train_acc /= len(train_loader)
        
        # 验证
        print(f"开始第 {epoch+1} 轮验证...")
        model.eval()
        val_loss = 0
        val_acc = 0
        
        with torch.no_grad():
            for images, masks in val_loader:
                images, masks = images.to(device), masks.to(device)
                outputs = model(images)
                loss = criterion(outputs, masks)
                
                val_loss += loss.item()
                val_acc += (outputs.round() == masks).float().mean().item()
        
        val_loss /= len(val_loader)
        val_acc /= len(val_loader)
        
        # 记录历史
        history['train_loss'].append(train_loss)
        history['train_acc'].append(train_acc)
        history['val_loss'].append(val_loss)
        history['val_acc'].append(val_acc)
        
        print(f'Epoch {epoch+1}/{num_epochs}:')
        print(f'Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.4f}')
        print(f'Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.4f}')
        print('-' * 50)
        
        # 早停
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            patience_counter = 0
            torch.save(model.state_dict(), 'best_model.pth')
            print(f"保存最佳模型,验证损失: {val_loss:.4f}")
        else:
            patience_counter += 1
            if patience_counter >= patience:
                print("早停触发")
                break
    
    return history

if __name__ == '__main__':
    main()

这里刚开始就是通过给train_dir,val_dir,test_dir来设置数据集的路径

下面的annotation这些调用.json文件作为上面的数据集的一个标签。。

但是呢这个.json的赋值的annotation的标签需要通过COCO来进行一个数据加载和预处理

主要在transforms.Compose这里进行了转Tensor

之后我们通过train_dataset = COCOSegmentationDataset(train_coco, train_dir, transform=transform)

train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=0)这样两行代码进行了数据集和数据加载器的创建,后面的那个数据加载其的作用是把我们的这个dataset给他分成很多个batch以便后面的模型的批次的训练

以及这里的损失函数的定义dice_loss,这个函数是专门用于分割任务的衡量预测和真实掩码的重叠度的一个函数,,并且我们用的是diceLoss和bceLoss的结合,而后面的bce是为了提供稳定的梯度信号,帮助模型收敛。 以及我们选择的优化器Adam

最后就是核心的部分:model.train()训练模式,这里有一个典型的步骤就是

1,清空梯度,2,前向传播,3,计算损失,4,反向传播,5,更新参数。。

如此我们的训练模型就这么结束了。(里面还有会有对test集的预估从而检测我们的训练好坏,以及可视化的部分,这里可以自己看一下,我只讲了大概的框架)

最后是我们的predict.py

python 复制代码
import torch
import torchvision.transforms as transforms
from PIL import Image
import matplotlib.pyplot as plt
from net import UNet
import numpy as np
import os
import glob
import time
import datetime

# 设置matplotlib使用支持中文的字体
plt.rcParams['font.sans-serif'] = ['SimHei', 'Microsoft YaHei', 'DejaVu Sans']  # 用来正常显示中文标签
plt.rcParams['axes.unicode_minus'] = False  # 用来正常显示负号

def load_model(model_path='readme_files/best_model.pth', device='cpu'):
    """加载训练好的模型"""
    try:
        if not os.path.exists(model_path):
            raise FileNotFoundError(f"模型文件不存在: {model_path}")
            
        print(f"从 {model_path} 加载模型...")
        
        # 使用与训练时相同的配置
        model = UNet(n_filters=16).to(device)
        
        # 加载权重
        state_dict = torch.load(model_path, map_location=device, weights_only=True)
        model.load_state_dict(state_dict)
        model.eval()
        
        print("模型加载成功!")
        print(f"模型参数数量: {sum(p.numel() for p in model.parameters())}")
        return model
        
    except Exception as e:
        print(f"加载模型时出错: {e}")
        raise

def preprocess_image(image_path):
    """预处理输入图像"""
    image = Image.open(image_path).convert('RGB')
    display_image = image.resize((256, 256), Image.Resampling.BILINEAR)
    
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Resize((256, 256)),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])
    
    image_tensor = transform(image)
    return image_tensor.unsqueeze(0), display_image

def predict_mask(model, image_tensor, device='cpu', threshold=0.5):
    """预测分割掩码"""
    with torch.no_grad():
        image_tensor = image_tensor.to(device)
        prediction = model(image_tensor)
        prediction = (prediction > threshold).float()
    return prediction

def generate_unique_filename(image_path, suffix):
    """生成唯一的文件名,避免覆盖"""
    # 获取原始图像的文件名(不含扩展名)
    base_name = os.path.splitext(os.path.basename(image_path))[0]
    
    # 添加时间戳确保唯一性
    timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
    
    # 组合成新文件名
    return f"{base_name}_{suffix}_{timestamp}.png"

def visualize_result(original_image, predicted_mask, image_path, save_dir='./results'):
    """可视化预测结果"""
    # 创建结果目录
    os.makedirs(save_dir, exist_ok=True)
    
    # 生成唯一文件名
    save_path = os.path.join(save_dir, generate_unique_filename(image_path, "predictions"))
    
    plt.figure(figsize=(15, 5))
    
    plt.subplot(1, 3, 1)
    plt.imshow(original_image)
    plt.title('原始脑部图像', fontsize=12)
    plt.axis('off')
    
    plt.subplot(1, 3, 2)
    plt.imshow(predicted_mask.squeeze(), cmap='gray')
    plt.title('预测肿瘤区域', fontsize=12)
    plt.axis('off')
    
    plt.subplot(1, 3, 3)
    plt.imshow(np.array(original_image))
    plt.imshow(predicted_mask.squeeze(), cmap='Reds', alpha=0.4)
    plt.title('叠加显示 (红色=肿瘤)', fontsize=12)
    plt.axis('off')
        
    plt.tight_layout()
    plt.savefig(save_path, dpi=150, bbox_inches='tight')
    print(f"可视化结果已保存为: {save_path}")
    
    # 显示图像(可选)
    plt.show()
    
    return save_path

def save_mask_as_image(mask_array, image_path, save_dir='./results'):
    """将分割掩码保存为图像文件"""
    # 创建结果目录
    os.makedirs(save_dir, exist_ok=True)
    
    # 生成唯一文件名
    output_path = os.path.join(save_dir, generate_unique_filename(image_path, "mask"))
    
    mask_uint8 = (mask_array.squeeze() * 255).astype(np.uint8)
    mask_image = Image.fromarray(mask_uint8, mode='L')
    mask_image.save(output_path)
    print(f"分割掩码已保存为: {output_path}")
    
    return output_path

def save_mask_as_numpy(mask_array, image_path, save_dir='./results'):
    """将分割掩码保存为numpy文件"""
    # 创建结果目录
    os.makedirs(save_dir, exist_ok=True)
    
    # 生成唯一文件名
    output_path = os.path.join(save_dir, generate_unique_filename(image_path, "data").replace('.png', '.npy'))
    
    np.save(output_path, mask_array.squeeze())
    print(f"分割数据已保存为: {output_path}")
    
    return output_path

def find_any_image():
    """在当前目录和子目录中查找任何图像文件"""
    print("在当前目录和子目录中查找图像文件...")
    
    # 支持的图像格式
    extensions = ['*.jpg', '*.jpeg', '*.png', '*.bmp', '*.tiff']
    
    for ext in extensions:
        # 在当前目录查找
        files = glob.glob(f"*{ext}")
        if files:
            print(f"在当前目录找到图像: {files[0]}")
            return files[0]
        
        # 在所有子目录中查找
        files = glob.glob(f"**/*{ext}", recursive=True)
        if files:
            print(f"在子目录找到图像: {files[0]}")
            return files[0]
    
    return None

def main():
    # 设置设备
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"使用设备: {device}")
    
    try:
        # 加载模型
        model_path = "./best_model.pth"
        model = load_model(model_path, device)
        
        # 查找图像文件
        image_path = None
        
        # 首先尝试使用您原来的路径
        original_path =r"D:\桌面\Things_have_been_completed\xiae_Study\UNet-Medical-master\UNet-Transformer\dataset\test\2603_jpg.rf.5e3809e5081d5f1a7f30ba781331c4b2.jpg"
        if os.path.exists(original_path):
            image_path = original_path
            print(f"使用原始路径图像: {image_path}")
        else:
            print(f"原始路径不存在: {original_path}")
            
            # 尝试在当前目录和子目录中查找任何图像
            image_path = find_any_image()
            
            if image_path is None:
                # 如果还是找不到,让用户输入
                print("\n无法自动找到图像文件,请选择以下选项:")
                print("1. 手动输入图像文件路径")
                print("2. 退出程序")
                
                choice = input("请选择 (1 或 2): ").strip()
                if choice == "1":
                    image_path = input("请输入图像文件的完整路径: ").strip()
                    # 去除可能的引号
                    image_path = image_path.strip('"').strip("'")
                else:
                    print("程序退出")
                    return
        
        # 验证图像文件是否存在
        if not os.path.exists(image_path):
            raise FileNotFoundError(f"图像文件不存在: {image_path}")
            
        print(f"处理图像: {image_path}")
        
        # 预处理图像
        image_tensor, original_image = preprocess_image(image_path)
        print(f"输入图像形状: {image_tensor.shape}")
        
        # 进行预测
        print("进行预测...")
        predicted_mask = predict_mask(model, image_tensor, device)
        
        # 将预测结果转回CPU并转换为numpy数组
        predicted_mask = predicted_mask.cpu().numpy()
        print(f"预测掩码形状: {predicted_mask.shape}")
        
        # 计算肿瘤区域统计
        tumor_pixels = np.sum(predicted_mask > 0)
        total_pixels = predicted_mask.size
        tumor_ratio = (tumor_pixels / total_pixels) * 100
        print(f"肿瘤区域统计: {tumor_pixels}/{total_pixels} 像素 ({tumor_ratio:.2f}%)")
        
        # 创建结果目录
        results_dir = './results'
        os.makedirs(results_dir, exist_ok=True)
        
        # 可视化结果
        print("生成可视化结果...")
        prediction_path = visualize_result(original_image, predicted_mask, image_path, results_dir)
        
        # 保存各种格式的结果
        mask_path = save_mask_as_image(predicted_mask, image_path, results_dir)
        data_path = save_mask_as_numpy(predicted_mask, image_path, results_dir)
        
        print("\n" + "="*50)
        print("所有结果已成功保存到 results 目录:")
        print(f"- 可视化对比图: {os.path.basename(prediction_path)}")
        print(f"- 二值分割掩码: {os.path.basename(mask_path)}")
        print(f"- 原始数据文件: {os.path.basename(data_path)}")
        print(f"- 肿瘤区域占比: {tumor_ratio:.2f}%")
        print("="*50)
        
    except Exception as e:
        print(f"预测过程中出错: {str(e)}")
        print("\n故障排除建议:")
        print("1. 确保当前目录或子目录中有图像文件")
        print("2. 图像文件格式支持: JPG, JPEG, PNG, BMP, TIFF")
        print("3. 您可以手动将图像文件复制到当前目录")
        print("4. 或者运行程序时手动输入图像文件完整路径")

if __name__ == '__main__':
    main()

首先我们通过load_model来进行模型的加载,主要是加载权重best_model.pth,然后采用cpu进行处理(因为这个过程不需要大量的计算,cpu完全可以)

通过preprocess_image来对图片进行预处理,主要是转化成Tensor然后调整大小Resize并且进行标准化处理Normalize,以及增加批次的维度unsqueeze(0)

之后通过predict_mask来进行预测的则么一个模块,这个模块是通过调用模型然后生成目标的一个Tensor这个里面是图形的一个掩码图,目标的界限处是掩码1 其他是0,后面在进行对掩码处理就可以得到我们对应的一个目标的区域了。

最后我们在主函数里面调用就可以了。

设置cuda还是cpu,load_model加载模型,image_path这个是设置我们要检测的一个照片,然后呢preprocess预处理,predicted_mask预测生成掩码。

(注意路径这里你用反斜杠\的时候前面要加一个r或者用双反斜杠\\,不然系统会读成转义字符)

后面在进行一些可视化就ok了,如此就完成了这么一个U-Net医学映像的脑肿瘤预测。

大家可以看我另一篇加了Transformer的多注意力,会有更好的一个检测效果

Unet+Transformer脑肿瘤分割检测-CSDN博客

相关推荐
2501_941111772 小时前
C++代码移植性设计
开发语言·c++·算法
~无忧花开~2 小时前
Vue.config.js配置全攻略
开发语言·前端·javascript·vue.js
脉动数据行情2 小时前
Go语言对接股票、黄金、外汇API实时数据教程
开发语言·后端·golang
w***Q3502 小时前
前端跨平台开发工具,Tauri与Electron
前端·javascript·electron
幸会同学2 小时前
在Cesium中实现飘动的红旗
javascript·three.js·cesium
橘子真甜~2 小时前
C/C++ Linux网络编程5 - 网络IO模型与select解决客户端并发连接问题
linux·运维·服务器·c语言·开发语言·网络·c++
霖003 小时前
ZYNQ——ultra scale+ IP 核详解与配置
服务器·开发语言·网络·笔记·网络协议·tcp/ip
flypwn3 小时前
justCTF 2025JSpositive_player知识
开发语言·javascript·原型模式
oliveira-time3 小时前
原型模式中的深浅拷贝
java·开发语言·原型模式