Swin Unet 代码运行教程(总结了多个训练和测试的问题)

发现网上尽然没有 swin unet 的运行教程呜呜,那我来出一份吧

环境准备

克隆 Swin Unet 项目地址(github.com/HuCaoFighti... python=3.7 版本安装项目依赖

Bash 复制代码
pip install -r requirements.txt

训练

获取训练的数据集,接下来以 Synapse 为例

通过仓库给的链接 得到 project_TransUNet,根据 ./datasets/README.md 文件的信息,将文件夹按照格式放到项目当中的位置,在其中我们也会得到Synapse处理好后的数据集(注意 TransUNet 无需安装依赖)

根据官网给的例子开始训练模型

Bash 复制代码
python train.py --dataset Synapse --cfg configs/swin_tiny_patch4_window7_224_lite.yaml --root_path ./data/Synapse --list_dir ./lists/Synapse --max_epochs 150 --output_dir ./model_out --img_size 224 --base_lr 0.05 --batch_size 24

your DATA_DIR:./data/Synapse your OUT_DIR:./model_out

遇到的问题

一、文件缺少

在lists/Synapse 目录中缺少 train.txt 和 val.txt 文件 这个需要自己创建,对应Synapse数据集中test_vol_h5和train_npz的文件名 ● train.txt:包含 train_npz 文件夹中所有训练文件的名称或路径。 ● val.txt:包含 test_vol_h5 文件夹中所有验证文件的名称或路径。 可以叫GPT生成脚本,或者用下面的脚本创建

python 复制代码
import os

# 定义数据目录路径
train_dir = './data/Synapse/train_npz'
val_dir = './data/Synapse/test_vol_h5'

# 定义输出的列表文件路径
output_dir = './lists/Synapse'
os.makedirs(output_dir, exist_ok=True)  # 如果目录不存在则创建

# 生成 train.txt 文件
train_files = os.listdir(train_dir)
with open(os.path.join(output_dir, 'train.txt'), 'w') as train_f:
    for file_name in train_files:
        if file_name.endswith('.npz'):
            train_f.write(f"{file_name}\n")

# 生成 val.txt 文件
val_files = os.listdir(val_dir)
with open(os.path.join(output_dir, 'val.txt'), 'w') as val_f:
    for file_name in val_files:
        if file_name.endswith('.npy.h5'):
            val_f.write(f"{file_name}\n")

print("train.txt 和 val.txt 文件已成功生成!")

二、标签值不在类别范围内

项目当中的类别为 [0, 4] ,但是我们的数据集的 label 有些是超过4的,不在范围当中

报错信息提到:Assertion t >= 0 && t < n_classes failed

所以我们得对其做个映射,在根目录下面创建 map_convert.py

python 复制代码
import os
import numpy as np
import h5py

# 定义数据路径
train_dir = './data/Synapse/train_npz'
test_dir = './data/Synapse/test_vol_h5'

def remap_labels(label_array):
    """
    将标签值重新映射,保留标签值在 [0, 3] 范围内
    并将标签 7 重新映射为 0。
    """
    label_array[label_array == 7] = 0
    label_array = np.clip(label_array, 0, 3)
    return label_array

# 处理训练数据
for filename in os.listdir(train_dir):
    if filename.endswith('.npz'):
        file_path = os.path.join(train_dir, filename)
        # 加载 npz 文件
        data = np.load(file_path)
        image = data['image']
        label = data['label']

        # 重新映射标签
        label = remap_labels(label)

        # 重新保存到原始路径
        np.savez_compressed(file_path, image=image, label=label)
        print(f"Processed and saved: {file_path}")

# 处理测试数据
for filename in os.listdir(test_dir):
    if filename.endswith('.npy.h5'):
        file_path = os.path.join(test_dir, filename)
        # 加载 h5 文件
        with h5py.File(file_path, 'r+') as h5f:
            image = h5f['image'][:]
            label = h5f['label'][:]

            # 重新映射标签
            label = remap_labels(label)

            # 删除原始 label 数据集并重新写入
            del h5f['label']
            h5f.create_dataset('label', data=label, compression="gzip")
            print(f"Processed and saved: {file_path}")

print("标签重新映射完成!")

解决完成之后,就可以正常运行啦~当然没遇到最好不过了

结果

训练完成之后,就可以在 Swin-Unet/model_out/log.txt 得到训练信息

我用的是3080Ti,大概训练时间为3-4h

测试

测试命令如下:

bash 复制代码
python test.py --dataset Synapse --cfg configs/swin_tiny_patch4_window7_224_lite.yaml --is_savenii --root_path ./data/Synapse --output_dir ./model_out --max_epoch 150 --base_lr 0.05 --img_size 224 --batch_size 24

注意点:

  1. 官网中是--volume_path,我改成了--root_path,因为源码中是使用了root_path
  2. max_epoch 和 batch_size 得对应上面训练时的参数

遇到的问题:

一、'Namespace' object has no attribute 'volume_path'

test.py 中,代码改为如下:

python 复制代码
if args.dataset == "Synapse":
    args.volume_path = os.path.join(args.root_path, "test_vol_h5")

二、不存在 test.txt 文件

这个文件是根据我们的测试集数据的文件名得到的,非自带,可以 GPT 生成脚本 在 Swin-Unet/lists/Synapse 目录下创建 test.txt 文件,内容如下:

bash 复制代码
test_vol_h5/case0001
test_vol_h5/case0002
test_vol_h5/case0003
test_vol_h5/case0004
test_vol_h5/case0008
test_vol_h5/case0022
test_vol_h5/case0025
test_vol_h5/case0029
test_vol_h5/case0032
test_vol_h5/case0035
test_vol_h5/case0036
test_vol_h5/case0038

三、cannot select an axis to squeeze out which has size not equal to one

代码中维度问题,出问题的代码如下

python 复制代码
image, label = image.squeeze(0).cpu().detach().numpy().squeeze(0), label.squeeze(0).cpu().detach().numpy().squeeze(0)

改为如下

python 复制代码
image = image.cpu().detach().numpy()
    label = label.cpu().detach().numpy()

    if image.shape[0] == 1:
        image = image.squeeze(0)
    if label.shape[0] == 1:
        label = label.squeeze(0)

    #image, label = image.squeeze(0).cpu().detach().numpy().squeeze(0), label.squeeze(0).cpu().detach().numpy().squeeze(0)

结果

运行上述命令行结束后,得到以下结果 Over!

祝大家科研顺利~有问题可以在评论区交流

相关推荐
电棍23321 分钟前
工程记录:使用tello edu无人机进行计算机视觉工作(手势识别,yolo3搭载)
人工智能·计算机视觉·无人机
Bellafu6664 小时前
selenium常用的等待有哪些?
python·selenium·测试工具
小白学大数据5 小时前
Python爬虫常见陷阱:Ajax动态生成内容的URL去重与数据拼接
爬虫·python·ajax
2401_841495646 小时前
【计算机视觉】基于复杂环境下的车牌识别
人工智能·python·算法·计算机视觉·去噪·车牌识别·字符识别
Adorable老犀牛7 小时前
阿里云-ECS实例信息统计并发送统计报告到企业微信
python·阿里云·云计算·企业微信
倔强青铜三7 小时前
苦练Python第66天:文件操作终极武器!shutil模块完全指南
人工智能·python·面试
倔强青铜三7 小时前
苦练Python第65天:CPU密集型任务救星!多进程multiprocessing模块实战解析,攻破GIL限制!
人工智能·python·面试
Panda__Panda7 小时前
docker项目打包演示项目(数字排序服务)
运维·javascript·python·docker·容器·c#
Wnq100727 小时前
如何在移动 的巡检机器人上,实现管道跑冒滴漏的视觉识别
数码相机·opencv·机器学习·计算机视觉·目标跟踪·自动驾驶
Lris-KK8 小时前
力扣Hot100--94.二叉树的中序遍历、144.二叉树的前序遍历、145.二叉树的后序遍历
python·算法·leetcode