Swin Unet 代码运行教程(总结了多个训练和测试的问题)

发现网上尽然没有 swin unet 的运行教程呜呜,那我来出一份吧

环境准备

克隆 Swin Unet 项目地址(github.com/HuCaoFighti... python=3.7 版本安装项目依赖

Bash 复制代码
pip install -r requirements.txt

训练

获取训练的数据集,接下来以 Synapse 为例

通过仓库给的链接 得到 project_TransUNet,根据 ./datasets/README.md 文件的信息,将文件夹按照格式放到项目当中的位置,在其中我们也会得到Synapse处理好后的数据集(注意 TransUNet 无需安装依赖)

根据官网给的例子开始训练模型

Bash 复制代码
python train.py --dataset Synapse --cfg configs/swin_tiny_patch4_window7_224_lite.yaml --root_path ./data/Synapse --list_dir ./lists/Synapse --max_epochs 150 --output_dir ./model_out --img_size 224 --base_lr 0.05 --batch_size 24

your DATA_DIR:./data/Synapse your OUT_DIR:./model_out

遇到的问题

一、文件缺少

在lists/Synapse 目录中缺少 train.txt 和 val.txt 文件 这个需要自己创建,对应Synapse数据集中test_vol_h5和train_npz的文件名 ● train.txt:包含 train_npz 文件夹中所有训练文件的名称或路径。 ● val.txt:包含 test_vol_h5 文件夹中所有验证文件的名称或路径。 可以叫GPT生成脚本,或者用下面的脚本创建

python 复制代码
import os

# 定义数据目录路径
train_dir = './data/Synapse/train_npz'
val_dir = './data/Synapse/test_vol_h5'

# 定义输出的列表文件路径
output_dir = './lists/Synapse'
os.makedirs(output_dir, exist_ok=True)  # 如果目录不存在则创建

# 生成 train.txt 文件
train_files = os.listdir(train_dir)
with open(os.path.join(output_dir, 'train.txt'), 'w') as train_f:
    for file_name in train_files:
        if file_name.endswith('.npz'):
            train_f.write(f"{file_name}\n")

# 生成 val.txt 文件
val_files = os.listdir(val_dir)
with open(os.path.join(output_dir, 'val.txt'), 'w') as val_f:
    for file_name in val_files:
        if file_name.endswith('.npy.h5'):
            val_f.write(f"{file_name}\n")

print("train.txt 和 val.txt 文件已成功生成!")

二、标签值不在类别范围内

项目当中的类别为 [0, 4] ,但是我们的数据集的 label 有些是超过4的,不在范围当中

报错信息提到:Assertion t >= 0 && t < n_classes failed

所以我们得对其做个映射,在根目录下面创建 map_convert.py

python 复制代码
import os
import numpy as np
import h5py

# 定义数据路径
train_dir = './data/Synapse/train_npz'
test_dir = './data/Synapse/test_vol_h5'

def remap_labels(label_array):
    """
    将标签值重新映射,保留标签值在 [0, 3] 范围内
    并将标签 7 重新映射为 0。
    """
    label_array[label_array == 7] = 0
    label_array = np.clip(label_array, 0, 3)
    return label_array

# 处理训练数据
for filename in os.listdir(train_dir):
    if filename.endswith('.npz'):
        file_path = os.path.join(train_dir, filename)
        # 加载 npz 文件
        data = np.load(file_path)
        image = data['image']
        label = data['label']

        # 重新映射标签
        label = remap_labels(label)

        # 重新保存到原始路径
        np.savez_compressed(file_path, image=image, label=label)
        print(f"Processed and saved: {file_path}")

# 处理测试数据
for filename in os.listdir(test_dir):
    if filename.endswith('.npy.h5'):
        file_path = os.path.join(test_dir, filename)
        # 加载 h5 文件
        with h5py.File(file_path, 'r+') as h5f:
            image = h5f['image'][:]
            label = h5f['label'][:]

            # 重新映射标签
            label = remap_labels(label)

            # 删除原始 label 数据集并重新写入
            del h5f['label']
            h5f.create_dataset('label', data=label, compression="gzip")
            print(f"Processed and saved: {file_path}")

print("标签重新映射完成!")

解决完成之后,就可以正常运行啦~当然没遇到最好不过了

结果

训练完成之后,就可以在 Swin-Unet/model_out/log.txt 得到训练信息

我用的是3080Ti,大概训练时间为3-4h

测试

测试命令如下:

bash 复制代码
python test.py --dataset Synapse --cfg configs/swin_tiny_patch4_window7_224_lite.yaml --is_savenii --root_path ./data/Synapse --output_dir ./model_out --max_epoch 150 --base_lr 0.05 --img_size 224 --batch_size 24

注意点:

  1. 官网中是--volume_path,我改成了--root_path,因为源码中是使用了root_path
  2. max_epoch 和 batch_size 得对应上面训练时的参数

遇到的问题:

一、'Namespace' object has no attribute 'volume_path'

test.py 中,代码改为如下:

python 复制代码
if args.dataset == "Synapse":
    args.volume_path = os.path.join(args.root_path, "test_vol_h5")

二、不存在 test.txt 文件

这个文件是根据我们的测试集数据的文件名得到的,非自带,可以 GPT 生成脚本 在 Swin-Unet/lists/Synapse 目录下创建 test.txt 文件,内容如下:

bash 复制代码
test_vol_h5/case0001
test_vol_h5/case0002
test_vol_h5/case0003
test_vol_h5/case0004
test_vol_h5/case0008
test_vol_h5/case0022
test_vol_h5/case0025
test_vol_h5/case0029
test_vol_h5/case0032
test_vol_h5/case0035
test_vol_h5/case0036
test_vol_h5/case0038

三、cannot select an axis to squeeze out which has size not equal to one

代码中维度问题,出问题的代码如下

python 复制代码
image, label = image.squeeze(0).cpu().detach().numpy().squeeze(0), label.squeeze(0).cpu().detach().numpy().squeeze(0)

改为如下

python 复制代码
image = image.cpu().detach().numpy()
    label = label.cpu().detach().numpy()

    if image.shape[0] == 1:
        image = image.squeeze(0)
    if label.shape[0] == 1:
        label = label.squeeze(0)

    #image, label = image.squeeze(0).cpu().detach().numpy().squeeze(0), label.squeeze(0).cpu().detach().numpy().squeeze(0)

结果

运行上述命令行结束后,得到以下结果 Over!

祝大家科研顺利~有问题可以在评论区交流

相关推荐
ROBOT玲玉25 分钟前
Milvus 中,FieldSchema 的 dim 参数和索引参数中的 “nlist“ 的区别
python·机器学习·numpy
Kai HVZ1 小时前
python爬虫----爬取视频实战
爬虫·python·音视频
古希腊掌管学习的神1 小时前
[LeetCode-Python版]相向双指针——611. 有效三角形的个数
开发语言·python·leetcode
m0_748244831 小时前
StarRocks 排查单副本表
大数据·数据库·python
B站计算机毕业设计超人1 小时前
计算机毕业设计PySpark+Hadoop中国城市交通分析与预测 Python交通预测 Python交通可视化 客流量预测 交通大数据 机器学习 深度学习
大数据·人工智能·爬虫·python·机器学习·课程设计·数据可视化
路人甲ing..2 小时前
jupyter切换内核方法配置问题总结
chrome·python·jupyter
游客5202 小时前
opencv中的常用的100个API
图像处理·人工智能·python·opencv·计算机视觉
每天都要学信号2 小时前
Python(第一天)
开发语言·python
凡人的AI工具箱2 小时前
每天40分玩转Django:Django国际化
数据库·人工智能·后端·python·django·sqlite
咸鱼桨3 小时前
《庐山派从入门到...》PWM板载蜂鸣器
人工智能·windows·python·k230·庐山派