发现网上尽然没有 swin unet 的运行教程呜呜,那我来出一份吧
环境准备
克隆 Swin Unet 项目地址(github.com/HuCaoFighti... python=3.7 版本安装项目依赖
Bash
pip install -r requirements.txt
训练
获取训练的数据集,接下来以 Synapse 为例
通过仓库给的链接 得到 project_TransUNet,根据 ./datasets/README.md 文件的信息,将文件夹按照格式放到项目当中的位置,在其中我们也会得到Synapse处理好后的数据集(注意 TransUNet 无需安装依赖)
根据官网给的例子开始训练模型
Bash
python train.py --dataset Synapse --cfg configs/swin_tiny_patch4_window7_224_lite.yaml --root_path ./data/Synapse --list_dir ./lists/Synapse --max_epochs 150 --output_dir ./model_out --img_size 224 --base_lr 0.05 --batch_size 24
your DATA_DIR:./data/Synapse your OUT_DIR:./model_out
遇到的问题
一、文件缺少
在lists/Synapse 目录中缺少 train.txt 和 val.txt 文件 这个需要自己创建,对应Synapse数据集中test_vol_h5和train_npz的文件名 ● train.txt:包含 train_npz 文件夹中所有训练文件的名称或路径。 ● val.txt:包含 test_vol_h5 文件夹中所有验证文件的名称或路径。 可以叫GPT生成脚本,或者用下面的脚本创建
python
import os
# 定义数据目录路径
train_dir = './data/Synapse/train_npz'
val_dir = './data/Synapse/test_vol_h5'
# 定义输出的列表文件路径
output_dir = './lists/Synapse'
os.makedirs(output_dir, exist_ok=True) # 如果目录不存在则创建
# 生成 train.txt 文件
train_files = os.listdir(train_dir)
with open(os.path.join(output_dir, 'train.txt'), 'w') as train_f:
for file_name in train_files:
if file_name.endswith('.npz'):
train_f.write(f"{file_name}\n")
# 生成 val.txt 文件
val_files = os.listdir(val_dir)
with open(os.path.join(output_dir, 'val.txt'), 'w') as val_f:
for file_name in val_files:
if file_name.endswith('.npy.h5'):
val_f.write(f"{file_name}\n")
print("train.txt 和 val.txt 文件已成功生成!")
二、标签值不在类别范围内
项目当中的类别为 [0, 4] ,但是我们的数据集的 label 有些是超过4的,不在范围当中
报错信息提到:Assertion t >= 0 && t < n_classes failed
所以我们得对其做个映射,在根目录下面创建 map_convert.py
python
import os
import numpy as np
import h5py
# 定义数据路径
train_dir = './data/Synapse/train_npz'
test_dir = './data/Synapse/test_vol_h5'
def remap_labels(label_array):
"""
将标签值重新映射,保留标签值在 [0, 3] 范围内
并将标签 7 重新映射为 0。
"""
label_array[label_array == 7] = 0
label_array = np.clip(label_array, 0, 3)
return label_array
# 处理训练数据
for filename in os.listdir(train_dir):
if filename.endswith('.npz'):
file_path = os.path.join(train_dir, filename)
# 加载 npz 文件
data = np.load(file_path)
image = data['image']
label = data['label']
# 重新映射标签
label = remap_labels(label)
# 重新保存到原始路径
np.savez_compressed(file_path, image=image, label=label)
print(f"Processed and saved: {file_path}")
# 处理测试数据
for filename in os.listdir(test_dir):
if filename.endswith('.npy.h5'):
file_path = os.path.join(test_dir, filename)
# 加载 h5 文件
with h5py.File(file_path, 'r+') as h5f:
image = h5f['image'][:]
label = h5f['label'][:]
# 重新映射标签
label = remap_labels(label)
# 删除原始 label 数据集并重新写入
del h5f['label']
h5f.create_dataset('label', data=label, compression="gzip")
print(f"Processed and saved: {file_path}")
print("标签重新映射完成!")
解决完成之后,就可以正常运行啦~当然没遇到最好不过了
结果
训练完成之后,就可以在 Swin-Unet/model_out/log.txt 得到训练信息
我用的是3080Ti,大概训练时间为3-4h
测试
测试命令如下:
bash
python test.py --dataset Synapse --cfg configs/swin_tiny_patch4_window7_224_lite.yaml --is_savenii --root_path ./data/Synapse --output_dir ./model_out --max_epoch 150 --base_lr 0.05 --img_size 224 --batch_size 24
注意点:
- 官网中是--volume_path,我改成了--root_path,因为源码中是使用了root_path
- max_epoch 和 batch_size 得对应上面训练时的参数
遇到的问题:
一、'Namespace' object has no attribute 'volume_path'
将 test.py 中,代码改为如下:
python
if args.dataset == "Synapse":
args.volume_path = os.path.join(args.root_path, "test_vol_h5")
二、不存在 test.txt 文件
这个文件是根据我们的测试集数据的文件名得到的,非自带,可以 GPT 生成脚本 在 Swin-Unet/lists/Synapse 目录下创建 test.txt 文件,内容如下:
bash
test_vol_h5/case0001
test_vol_h5/case0002
test_vol_h5/case0003
test_vol_h5/case0004
test_vol_h5/case0008
test_vol_h5/case0022
test_vol_h5/case0025
test_vol_h5/case0029
test_vol_h5/case0032
test_vol_h5/case0035
test_vol_h5/case0036
test_vol_h5/case0038
三、cannot select an axis to squeeze out which has size not equal to one
代码中维度问题,出问题的代码如下
python
image, label = image.squeeze(0).cpu().detach().numpy().squeeze(0), label.squeeze(0).cpu().detach().numpy().squeeze(0)
改为如下
python
image = image.cpu().detach().numpy()
label = label.cpu().detach().numpy()
if image.shape[0] == 1:
image = image.squeeze(0)
if label.shape[0] == 1:
label = label.squeeze(0)
#image, label = image.squeeze(0).cpu().detach().numpy().squeeze(0), label.squeeze(0).cpu().detach().numpy().squeeze(0)
结果
运行上述命令行结束后,得到以下结果 Over!
祝大家科研顺利~有问题可以在评论区交流