Swin Unet 代码运行教程(总结了多个训练和测试的问题)

发现网上尽然没有 swin unet 的运行教程呜呜,那我来出一份吧

环境准备

克隆 Swin Unet 项目地址(github.com/HuCaoFighti... python=3.7 版本安装项目依赖

Bash 复制代码
pip install -r requirements.txt

训练

获取训练的数据集,接下来以 Synapse 为例

通过仓库给的链接 得到 project_TransUNet,根据 ./datasets/README.md 文件的信息,将文件夹按照格式放到项目当中的位置,在其中我们也会得到Synapse处理好后的数据集(注意 TransUNet 无需安装依赖)

根据官网给的例子开始训练模型

Bash 复制代码
python train.py --dataset Synapse --cfg configs/swin_tiny_patch4_window7_224_lite.yaml --root_path ./data/Synapse --list_dir ./lists/Synapse --max_epochs 150 --output_dir ./model_out --img_size 224 --base_lr 0.05 --batch_size 24

your DATA_DIR:./data/Synapse your OUT_DIR:./model_out

遇到的问题

一、文件缺少

在lists/Synapse 目录中缺少 train.txt 和 val.txt 文件 这个需要自己创建,对应Synapse数据集中test_vol_h5和train_npz的文件名 ● train.txt:包含 train_npz 文件夹中所有训练文件的名称或路径。 ● val.txt:包含 test_vol_h5 文件夹中所有验证文件的名称或路径。 可以叫GPT生成脚本,或者用下面的脚本创建

python 复制代码
import os

# 定义数据目录路径
train_dir = './data/Synapse/train_npz'
val_dir = './data/Synapse/test_vol_h5'

# 定义输出的列表文件路径
output_dir = './lists/Synapse'
os.makedirs(output_dir, exist_ok=True)  # 如果目录不存在则创建

# 生成 train.txt 文件
train_files = os.listdir(train_dir)
with open(os.path.join(output_dir, 'train.txt'), 'w') as train_f:
    for file_name in train_files:
        if file_name.endswith('.npz'):
            train_f.write(f"{file_name}\n")

# 生成 val.txt 文件
val_files = os.listdir(val_dir)
with open(os.path.join(output_dir, 'val.txt'), 'w') as val_f:
    for file_name in val_files:
        if file_name.endswith('.npy.h5'):
            val_f.write(f"{file_name}\n")

print("train.txt 和 val.txt 文件已成功生成!")

二、标签值不在类别范围内

项目当中的类别为 [0, 4] ,但是我们的数据集的 label 有些是超过4的,不在范围当中

报错信息提到:Assertion t >= 0 && t < n_classes failed

所以我们得对其做个映射,在根目录下面创建 map_convert.py

python 复制代码
import os
import numpy as np
import h5py

# 定义数据路径
train_dir = './data/Synapse/train_npz'
test_dir = './data/Synapse/test_vol_h5'

def remap_labels(label_array):
    """
    将标签值重新映射,保留标签值在 [0, 3] 范围内
    并将标签 7 重新映射为 0。
    """
    label_array[label_array == 7] = 0
    label_array = np.clip(label_array, 0, 3)
    return label_array

# 处理训练数据
for filename in os.listdir(train_dir):
    if filename.endswith('.npz'):
        file_path = os.path.join(train_dir, filename)
        # 加载 npz 文件
        data = np.load(file_path)
        image = data['image']
        label = data['label']

        # 重新映射标签
        label = remap_labels(label)

        # 重新保存到原始路径
        np.savez_compressed(file_path, image=image, label=label)
        print(f"Processed and saved: {file_path}")

# 处理测试数据
for filename in os.listdir(test_dir):
    if filename.endswith('.npy.h5'):
        file_path = os.path.join(test_dir, filename)
        # 加载 h5 文件
        with h5py.File(file_path, 'r+') as h5f:
            image = h5f['image'][:]
            label = h5f['label'][:]

            # 重新映射标签
            label = remap_labels(label)

            # 删除原始 label 数据集并重新写入
            del h5f['label']
            h5f.create_dataset('label', data=label, compression="gzip")
            print(f"Processed and saved: {file_path}")

print("标签重新映射完成!")

解决完成之后,就可以正常运行啦~当然没遇到最好不过了

结果

训练完成之后,就可以在 Swin-Unet/model_out/log.txt 得到训练信息

我用的是3080Ti,大概训练时间为3-4h

测试

测试命令如下:

bash 复制代码
python test.py --dataset Synapse --cfg configs/swin_tiny_patch4_window7_224_lite.yaml --is_savenii --root_path ./data/Synapse --output_dir ./model_out --max_epoch 150 --base_lr 0.05 --img_size 224 --batch_size 24

注意点:

  1. 官网中是--volume_path,我改成了--root_path,因为源码中是使用了root_path
  2. max_epoch 和 batch_size 得对应上面训练时的参数

遇到的问题:

一、'Namespace' object has no attribute 'volume_path'

test.py 中,代码改为如下:

python 复制代码
if args.dataset == "Synapse":
    args.volume_path = os.path.join(args.root_path, "test_vol_h5")

二、不存在 test.txt 文件

这个文件是根据我们的测试集数据的文件名得到的,非自带,可以 GPT 生成脚本 在 Swin-Unet/lists/Synapse 目录下创建 test.txt 文件,内容如下:

bash 复制代码
test_vol_h5/case0001
test_vol_h5/case0002
test_vol_h5/case0003
test_vol_h5/case0004
test_vol_h5/case0008
test_vol_h5/case0022
test_vol_h5/case0025
test_vol_h5/case0029
test_vol_h5/case0032
test_vol_h5/case0035
test_vol_h5/case0036
test_vol_h5/case0038

三、cannot select an axis to squeeze out which has size not equal to one

代码中维度问题,出问题的代码如下

python 复制代码
image, label = image.squeeze(0).cpu().detach().numpy().squeeze(0), label.squeeze(0).cpu().detach().numpy().squeeze(0)

改为如下

python 复制代码
image = image.cpu().detach().numpy()
    label = label.cpu().detach().numpy()

    if image.shape[0] == 1:
        image = image.squeeze(0)
    if label.shape[0] == 1:
        label = label.squeeze(0)

    #image, label = image.squeeze(0).cpu().detach().numpy().squeeze(0), label.squeeze(0).cpu().detach().numpy().squeeze(0)

结果

运行上述命令行结束后,得到以下结果 Over!

祝大家科研顺利~有问题可以在评论区交流

相关推荐
W.KN3 小时前
PyTorch 数据类型和使用
人工智能·pytorch·python
虾饺爱下棋3 小时前
FCN语义分割算法原理与实战
人工智能·python·神经网络·算法
千册4 小时前
python+pyside6+sqlite 数据库测试
数据库·python·sqlite
双翌视觉5 小时前
智能制造的空间度量:机器视觉标定技术解析
数码相机·计算机视觉·视觉标定
悠哉悠哉愿意7 小时前
【电赛学习笔记】MaixCAM 的OCR图片文字识别
笔记·python·嵌入式硬件·学习·视觉检测·ocr
nbsaas-boot7 小时前
SQL Server 窗口函数全指南(函数用法与场景)
开发语言·数据库·python·sql·sql server
Catching Star7 小时前
【代码问题】【包安装】MMCV
python
摸鱼仙人~7 小时前
Spring Boot中的this::语法糖详解
windows·spring boot·python
Warren987 小时前
Java Stream流的使用
java·开发语言·windows·spring boot·后端·python·硬件工程
点云SLAM8 小时前
PyTorch中flatten()函数详解以及与view()和 reshape()的对比和实战代码示例
人工智能·pytorch·python·计算机视觉·3d深度学习·张量flatten操作·张量数据结构