infoGCN++: Getting Training to Run

Table of Contents

1. Preface

2. Downloading the Dataset

3. Base conda Environment

  🎯 Option 1: use the --no-build-isolation flag (most recommended)

  🔧 Option 2: compile inside the activated conda environment

  💡 Option 3: install a pinned older Apex (fallback)

  🧹 Option 4: environment checks and repair

4. Generating the skes_available_name.txt File

5. python get_raw_skes_data.py

  🔍 Root cause: the dataset ships with corrupted samples

  Summary of changes

  💡 Tips for running

6. python get_raw_denoised_data.py

  🔍 Locating the error

  💡 Solution

  Fix: catch the empty list in get_two_actors_points

7. python seq_transformation.py

  📋 Required files checklist

  📝 Filename parsing rules

  Caveats

  🔧 Solution

  How to modify

  Complete modified code snippet

  💡 Extra checks

8. What does seq_transformation.py do? What are its inputs, outputs, and intermediate steps?

  📌 Where it sits in the overall pipeline

  🎯 The script's core task

  📥 Input files at a glance

  ⚙️ Intermediate processing in detail

  📤 Final output files

9. python main.py

  Overall code structure

    1. Entry and initialization (main → Processor.init)

    2. Data loading (load_data)

    3. Model architecture (SODE)

    4. Training loop (train)

    5. Evaluation loop (eval)

    6. Learning-rate schedule (adjust_learning_rate)

  Adapting the training command to NTU60

    Example command (NTU60 training)

  The four file types: roles and differences

  Key conversion logic (create_aligned_dataset)

10. Switching wandb to Offline Mode

  Method 1: disable wandb via an environment variable (recommended, no code changes)

  Method 2: force wandb offline/disabled in code

  Method 3: add a command-line flag to toggle wandb (once and for all)

  Method 4: comment out the wandb code (brute force)

  Quick recommendations


1. Preface

In this post we try to get the project's training to run. I got it working on Linux; if you train on Windows instead, you may hit problems beyond the ones covered here.

https://github.com/stnoah1/infogcn2/tree/main

2. Downloading the Dataset

The author asks you to download the following three datasets:

  • NTU RGB+D 60 Skeleton
  • NTU RGB+D 120 Skeleton
  • NW-UCLA

I downloaded NTU RGB+D 60 Skeleton, ran the data processing on it, and got training to run.

https://aistudio.baidu.com/datasetdetail/146482

After downloading, extract it to the path below; that directory should then contain many .skeleton files:

D:\zero_track\infogcn2\data\nturgbd_raw\nturgb+d_skeletons
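Before running the preprocessing scripts, it is worth sanity-checking the extracted files. A minimal sketch (the helper name is mine, not part of the project):

```python
import glob
import os

def count_skeleton_files(skes_dir):
    """Count the .skeleton files in a directory."""
    return len(glob.glob(os.path.join(skes_dir, "*.skeleton")))

# Path from above; adjust to your machine. The full NTU RGB+D 60
# release contains 56,880 skeleton files, so expect a number near that.
n = count_skeleton_files(r"D:\zero_track\infogcn2\data\nturgbd_raw\nturgb+d_skeletons")
print(n)
```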

3. Base conda Environment

I use my usual conda environment: Python 3.12 + torch 2.5.1 + cu121. For details, see "Setting up Python open-source projects on Windows, part 1" in my environment-deployment column.

# 1. Clone the project
git clone https://github.com/stnoah1/infogcn2.git
cd infogcn2

# 2. Create and activate a virtual environment (recommended)
# conda create -n infogcn2 python=3.8 -y
conda create -n infogcn2 python=3.12 -y  # I use 3.12
conda activate infogcn2

# 3. Install PyTorch (adjust the command for your setup)
# For example, with CUDA 11.3:
conda install pytorch==1.12.0 torchvision torchaudio cudatoolkit=11.3 -c pytorch
# I use torch 2.5.1

# 4. Install the remaining dependencies
pip install tqdm tensorboardX wandb einops torchdiffeq

# 5. Install NVIDIA Apex (for mixed-precision training speedup)
git clone https://github.com/NVIDIA/apex
cd apex
pip install -v --disable-pip-version-check --no-cache-dir --global-option="--cpp_ext" --global-option="--cuda_ext" ./
cd ..

If the following command:

pip install -v --disable-pip-version-check --no-cache-dir --global-option="--cpp_ext" --global-option="--cuda_ext" ./

fails with: File "/tmp/pip-build-env-bf2piphf/overlay/lib/python3.12/site-packages/setuptools/build_meta.py", line 317, in run_setup exec(code, locals()) ModuleNotFoundError: No module named 'torch'

then try instead: pip install -v --disable-pip-version-check --no-build-isolation --no-cache-dir ./

If it still fails, you can skip Apex and later disable mixed-precision training by passing --half=False to python main.py.
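If you end up training without Apex, note that a hard import of apex.amp in the training code will still crash. A hedged sketch of a guarded-import pattern you could apply to the training script; the function name and structure here are illustrative, not the project's actual code:

```python
# Guarded import: fall back to plain FP32 training when Apex is absent.
try:
    from apex import amp  # NVIDIA Apex, optional dependency
    HAS_APEX = True
except ImportError:
    amp = None
    HAS_APEX = False

def maybe_initialize(model, optimizer, use_half):
    """Wrap model/optimizer with apex.amp only when available and requested."""
    if use_half and HAS_APEX:
        # opt_level="O1" is Apex's standard mixed-precision mode
        return amp.initialize(model, optimizer, opt_level="O1")
    return model, optimizer

print("Apex available:", HAS_APEX)
```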

This error occurs because pip builds Apex in an isolated build environment, and that isolated environment cannot see the torch module installed in your conda environment.

Below are several verified workarounds, ordered by how strongly they are recommended:

🎯 Option 1: use the --no-build-isolation flag (most recommended)

This option reuses the environment where torch is already installed, so nothing has to be re-downloaded, and it succeeds in most cases. Run the following inside the apex directory:

pip install -v --disable-pip-version-check --no-build-isolation --no-cache-dir ./

🔧 Option 2: compile inside the activated conda environment

  1. Activate your conda environment: make sure it is the environment where PyTorch is already installed.

    conda activate your_env_name

  2. Install via python -m pip: this sometimes resolves path conflicts between pip and the environment.

    python -m pip install -v --disable-pip-version-check --no-cache-dir --global-option="--cpp_ext" --global-option="--cuda_ext" ./

💡 Option 3: install a pinned older Apex (fallback)

If the methods above fail due to compatibility issues, try this older fork, which is known to build more reliably:

git clone https://github.com/ptrblck/apex.git
cd apex
git checkout apex_no_distributed
pip install -v --no-cache-dir ./

🧹 Option 4: environment checks and repair

If none of the above works, the environment may have a deeper conflict.

  1. Clean up: uninstall any package that could shadow the official Apex:

    pip uninstall apex

  2. Check version compatibility: confirm that your PyTorch and CUDA versions match:

    import torch
    print(torch.__version__)
    print(torch.version.cuda)

  3. Install build tools: make sure a complete C++ toolchain is present:

    # Ubuntu/Debian
    sudo apt-get update
    sudo apt-get install -y build-essential

Option 1 resolves the problem in most cases.
A related failure seen at training time is self.model, self.optimizer = apex.amp.initialize(...) raising AttributeError: module 'apex' has no attribute 'amp'. This means the installed apex is not a working NVIDIA Apex build; either rebuild it with one of the options above, or pass --half=False to python main.py to avoid Apex entirely.

VS Code search-and-replace: replace np.int with int across the data-processing scripts (np.int was deprecated in NumPy 1.20 and removed in 1.24, so these older scripts crash on modern NumPy).
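A quick illustration of what that replacement does; a word boundary keeps names like np.int64 intact (hypothetical snippet, not project code):

```python
import re

src = "x = np.zeros(10, dtype=np.int); y = np.array(a, dtype=np.int64)"
# \b ensures np.int64 / np.int32 are left untouched
fixed = re.sub(r"np\.int\b", "int", src)
print(fixed)  # x = np.zeros(10, dtype=int); y = np.array(a, dtype=np.int64)
```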

4. Generating the skes_available_name.txt File

The steps are as follows:

  1. Check the dataset path: make sure the NTU RGB+D skeleton files (.skeleton) are placed under ../nturgbd_raw/nturgb+d_skeletons/.

  2. Create the statistics directory: create a statistics folder under ./data/ntu/.

  3. Generate skes_available_name.txt: build the list of all skeleton filenames yourself. A small Python script does this quickly; create a file under ./data/ntu/ (for example gen_skes_list.py) with the following code:

import os
import glob

# Directory containing the skeleton files
skes_dir = '../nturgbd_raw/nturgb+d_skeletons/'
# Output file path
output_file = './statistics/skes_available_name.txt'

# Find all .skeleton files
skes_files = glob.glob(os.path.join(skes_dir, '*.skeleton'))

# Extract the filenames (without the extension)
skes_names = [os.path.splitext(os.path.basename(f))[0] for f in skes_files]

# Sort (optional)
skes_names.sort()

# Write them to the output file
with open(output_file, 'w') as f:
    for name in skes_names:
        f.write(name + '\n')

print(f"Generated {output_file} with {len(skes_names)} entries.")

Run the script:

(infogcn2_env) D:\zero_track\infogcn2\data\ntu> python gen_skes_list.py

5. python get_raw_skes_data.py

Running python get_raw_skes_data.py hits another problem:

assert num_frames_drop < num_frames
AssertionError: Error: All frames data (71) of S001C002P006R001A008 is missing or lost

This AssertionError means that while processing S001C002P006R001A008, the script found that all 71 of its frames are missing, so it cannot continue. This is not a problem with your code or configuration; it is a known issue with the NTU dataset itself.

🔍 Root cause: the dataset ships with corrupted samples

In the officially released NTU RGB+D dataset, a small number of samples have missing or incomplete skeleton data out of the box: of the 56,880 samples in that release, 302 are affected. S001C002P006R001A008, which you just hit, is one of them.

The full modified get_raw_skes_data.py:
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
import os.path as osp
import os
import numpy as np
import pickle
import logging

def get_raw_bodies_data(skes_path, ske_name, frames_drop_skes, frames_drop_logger):
    """
    Modified: if all frames are dropped, return None instead of raising an AssertionError.
    """
    ske_file = osp.join(skes_path, ske_name + '.skeleton')
    assert osp.exists(ske_file), 'Error: Skeleton file %s not found' % ske_file
    print('Reading data from %s' % ske_file[-29:])
    with open(ske_file, 'r') as fr:
        str_data = fr.readlines()

    num_frames = int(str_data[0].strip('\r\n'))
    frames_drop = []
    bodies_data = dict()
    valid_frames = -1
    current_line = 1

    for f in range(num_frames):
        num_bodies = int(str_data[current_line].strip('\r\n'))
        current_line += 1

        if num_bodies == 0:
            frames_drop.append(f)
            continue

        valid_frames += 1
        joints = np.zeros((num_bodies, 25, 3), dtype=np.float32)
        colors = np.zeros((num_bodies, 25, 2), dtype=np.float32)

        for b in range(num_bodies):
            bodyID = str_data[current_line].strip('\r\n').split()[0]
            current_line += 1
            num_joints = int(str_data[current_line].strip('\r\n'))
            current_line += 1

            for j in range(num_joints):
                temp_str = str_data[current_line].strip('\r\n').split()
                joints[b, j, :] = np.array(temp_str[:3], dtype=np.float32)
                colors[b, j, :] = np.array(temp_str[5:7], dtype=np.float32)
                current_line += 1

            if bodyID not in bodies_data:
                body_data = dict()
                body_data['joints'] = joints[b]
                body_data['colors'] = colors[b, np.newaxis]
                body_data['interval'] = [valid_frames]
            else:
                body_data = bodies_data[bodyID]
                body_data['joints'] = np.vstack((body_data['joints'], joints[b]))
                body_data['colors'] = np.vstack((body_data['colors'], colors[b, np.newaxis]))
                pre_frame_idx = body_data['interval'][-1]
                body_data['interval'].append(pre_frame_idx + 1)

            bodies_data[bodyID] = body_data

    num_frames_drop = len(frames_drop)
    
    # Key change: replace the assert with a conditional check
    if num_frames_drop >= num_frames:
        print(f'Warning: All frames ({num_frames}) of {ske_name} are missing or lost. Skipping this sample.')
        frames_drop_skes[ske_name] = np.array(frames_drop, dtype=int)
        frames_drop_logger.info(f'{ske_name}: ALL FRAMES MISSED ({num_frames_drop}/{num_frames})\n')
        return None  # None marks this sample as completely unusable

    if num_frames_drop > 0:
        frames_drop_skes[ske_name] = np.array(frames_drop, dtype=int)
        frames_drop_logger.info(f'{ske_name}: {num_frames_drop} frames missed: {frames_drop}\n')

    if len(bodies_data) > 1:
        for body_data in bodies_data.values():
            body_data['motion'] = np.sum(np.var(body_data['joints'], axis=0))

    return {'name': ske_name, 'data': bodies_data, 'num_frames': num_frames - num_frames_drop}


def get_raw_skes_data():
    skes_name = np.loadtxt(skes_name_file, dtype=str)
    num_files = skes_name.size
    print('Found %d available skeleton files.' % num_files)

    raw_skes_data = []
    frames_cnt = []
    corrupted_files = []  # names of corrupted files

    for (idx, ske_name) in enumerate(skes_name):
        bodies_data = get_raw_bodies_data(skes_path, ske_name, frames_drop_skes, frames_drop_logger)
        if bodies_data is None:
            corrupted_files.append(ske_name)  # record the corrupted file
            frames_cnt.append(0)  # placeholder
        else:
            raw_skes_data.append(bodies_data)
            frames_cnt.append(bodies_data['num_frames'])

        if (idx + 1) % 1000 == 0:
            print('Processed: %.2f%% (%d / %d)' % \
                  (100.0 * (idx + 1) / num_files, idx + 1, num_files))

    # Save the valid data
    with open(save_data_pkl, 'wb') as fw:
        pickle.dump(raw_skes_data, fw, pickle.HIGHEST_PROTOCOL)
    np.savetxt(osp.join(save_path, 'raw_data', 'frames_cnt.txt'), np.array(frames_cnt), fmt='%d')

    print('Saved raw bodies data into %s' % save_data_pkl)
    print('Total valid samples: %d' % len(raw_skes_data))

    # Update skes_available_name.txt by removing the corrupted filenames
    if corrupted_files:
        print(f'Found {len(corrupted_files)} corrupted samples. Updating skes_available_name.txt...')
        # Filter out the corrupted filenames
        clean_names = [name for name in skes_name if name not in corrupted_files]
        # Overwrite the original file (alternatively, save a separate clean copy)
        with open(skes_name_file, 'w') as f:
            for name in clean_names:
                f.write(name + '\n')
        print(f'Updated {skes_name_file} with {len(clean_names)} clean entries.')
    else:
        print('No corrupted samples found.')


if __name__ == '__main__':
    save_path = './'
    skes_path = '../nturgbd_raw/nturgb+d_skeletons/'
    stat_path = osp.join(save_path, 'statistics')
    if not osp.exists('./raw_data'):
        os.makedirs('./raw_data')

    skes_name_file = osp.join(stat_path, 'skes_available_name.txt')
    save_data_pkl = osp.join(save_path, 'raw_data', 'raw_skes_data.pkl')
    frames_drop_pkl = osp.join(save_path, 'raw_data', 'frames_drop_skes.pkl')

    frames_drop_logger = logging.getLogger('frames_drop')
    frames_drop_logger.setLevel(logging.INFO)
    frames_drop_logger.addHandler(logging.FileHandler(osp.join(save_path, 'raw_data', 'frames_drop.log')))
    frames_drop_skes = dict()

    get_raw_skes_data()

    with open(frames_drop_pkl, 'wb') as fw:
        pickle.dump(frames_drop_skes, fw, pickle.HIGHEST_PROTOCOL)

Run the script:

(infogcn2_env) D:\zero_track\infogcn2\data\ntu> python get_raw_skes_data.py

Summary of changes

| Where | Original behavior | New behavior |
| --- | --- | --- |
| Bottom of get_raw_bodies_data | assert num_frames_drop < num_frames | If all frames are dropped, print a warning and return None |
| Loop in get_raw_skes_data | No error handling | Check the return value; on None, record the corrupted filename and skip saving |
| File-list update | (new step) | After processing, remove all corrupted filenames from skes_available_name.txt |

💡 Tips for running

  1. Back up the original file: save a copy of get_raw_skes_data.py first.

  2. Replace the code: overwrite the file's contents with the complete code above.

  3. Re-run:

    python get_raw_skes_data.py

The script processes every file automatically and finally updates statistics/skes_available_name.txt, leaving you with a fully clean file list. The log file frames_drop.log also records which files were dropped entirely.

Output:

Saved raw bodies data into ./raw_data\raw_skes_data.pkl

Total valid samples: 56715

Found 165 corrupted samples. Updating skes_available_name.txt...

Updated ./statistics\skes_available_name.txt with 56715 clean entries.

6. python get_raw_denoised_data.py

Error:

bodyID, actor1 = bodies_data[0]
IndexError: list index out of range

This IndexError occurs because, for some samples, denoising_bodies_data returns an empty list after denoising, so bodies_data[0] has no element to take. It typically happens when every skeleton in a sample is too short (≤ 11 frames) and all of them get filtered out by denoising_by_length, leaving no valid bodyID at all.
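A toy illustration of how the empty list can arise (the threshold matches the script's noise_len_thres; the body names are made up):

```python
# If every bodyID in a sample spans <= 11 frames, length-based denoising
# (denoising_by_length) removes them all, leaving nothing to index.
noise_len_thres = 11

bodies = {
    "body_a": {"interval": list(range(5))},   # 5 frames  -> filtered out
    "body_b": {"interval": list(range(9))},   # 9 frames  -> filtered out
}
survivors = {k: v for k, v in bodies.items()
             if len(v["interval"]) > noise_len_thres}
print(len(survivors))  # 0 -> bodies_data[0] would raise IndexError
```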

🔍 Locating the error

# inside get_two_actors_points
bodies_data, noise_info = denoising_bodies_data(bodies_data)  # may return an empty list
bodies_data = list(bodies_data)
if len(bodies_data) == 1:
    ...
else:
    bodyID, actor1 = bodies_data[0]   # index out of range here

💡 Solution

Add an empty-list check inside denoising_bodies_data or at its call site.

Fix: catch the empty list in get_two_actors_points

Near the top of get_two_actors_points, right after the call to denoising_bodies_data, check whether the result is empty; if so, return None, and skip the sample in the main loop.

Modify get_two_actors_points as follows (key part only):

def get_two_actors_points(bodies_data):
    ske_name = bodies_data['name']
    label = int(ske_name[-2:])
    num_frames = bodies_data['num_frames']
    bodies_info = get_bodies_info(bodies_data['data'])

    bodies_data, noise_info = denoising_bodies_data(bodies_data)
    bodies_info += noise_info
    bodies_data = list(bodies_data)

    # ---- New check ----
    if len(bodies_data) == 0:
        print(f'Warning: No valid bodyID left after denoising for {ske_name}. Skipping.')
        return None, None  # None marks this sample as invalid
    # -------------------

    if len(bodies_data) == 1:
        ...

At the same time, in the loop of the main function get_raw_denoised_data, check the return value:

for (idx, bodies_data) in enumerate(raw_skes_data):
    ske_name = bodies_data['name']
    print('Processing %s' % ske_name)
    num_bodies = len(bodies_data['data'])

    if num_bodies == 1:
        ...
    else:
        joints, colors = get_two_actors_points(bodies_data)
        if joints is None:           # skip invalid samples
            continue
        joints, colors = remove_missing_frames(ske_name, joints, colors)
        num_frames = joints.shape[0]

    raw_denoised_joints.append(joints)
    raw_denoised_colors.append(colors)
    frames_cnt.append(num_frames)
    ...
The full modified get_raw_denoised_data.py:
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
import os
import os.path as osp
import numpy as np
import pickle
import logging

root_path = './'
raw_data_file = osp.join(root_path, 'raw_data', 'raw_skes_data.pkl')
save_path = osp.join(root_path, 'denoised_data')

if not osp.exists(save_path):
    os.mkdir(save_path)

rgb_ske_path = osp.join(save_path, 'rgb+ske')
if not osp.exists(rgb_ske_path):
    os.mkdir(rgb_ske_path)

actors_info_dir = osp.join(save_path, 'actors_info')
if not osp.exists(actors_info_dir):
    os.mkdir(actors_info_dir)

missing_count = 0
noise_len_thres = 11
noise_spr_thres1 = 0.8
noise_spr_thres2 = 0.69754
noise_mot_thres_lo = 0.089925
noise_mot_thres_hi = 2

noise_len_logger = logging.getLogger('noise_length')
noise_len_logger.setLevel(logging.INFO)
noise_len_logger.addHandler(logging.FileHandler(osp.join(save_path, 'noise_length.log')))
noise_len_logger.info('{:^20}\t{:^17}\t{:^8}\t{}'.format('Skeleton', 'bodyID', 'Motion', 'Length'))

noise_spr_logger = logging.getLogger('noise_spread')
noise_spr_logger.setLevel(logging.INFO)
noise_spr_logger.addHandler(logging.FileHandler(osp.join(save_path, 'noise_spread.log')))
noise_spr_logger.info('{:^20}\t{:^17}\t{:^8}\t{:^8}'.format('Skeleton', 'bodyID', 'Motion', 'Rate'))

noise_mot_logger = logging.getLogger('noise_motion')
noise_mot_logger.setLevel(logging.INFO)
noise_mot_logger.addHandler(logging.FileHandler(osp.join(save_path, 'noise_motion.log')))
noise_mot_logger.info('{:^20}\t{:^17}\t{:^8}'.format('Skeleton', 'bodyID', 'Motion'))

fail_logger_1 = logging.getLogger('noise_outliers_1')
fail_logger_1.setLevel(logging.INFO)
fail_logger_1.addHandler(logging.FileHandler(osp.join(save_path, 'denoised_failed_1.log')))

fail_logger_2 = logging.getLogger('noise_outliers_2')
fail_logger_2.setLevel(logging.INFO)
fail_logger_2.addHandler(logging.FileHandler(osp.join(save_path, 'denoised_failed_2.log')))

missing_skes_logger = logging.getLogger('missing_frames')
missing_skes_logger.setLevel(logging.INFO)
missing_skes_logger.addHandler(logging.FileHandler(osp.join(save_path, 'missing_skes.log')))
missing_skes_logger.info('{:^20}\t{}\t{}'.format('Skeleton', 'num_frames', 'num_missing'))

missing_skes_logger1 = logging.getLogger('missing_frames_1')
missing_skes_logger1.setLevel(logging.INFO)
missing_skes_logger1.addHandler(logging.FileHandler(osp.join(save_path, 'missing_skes_1.log')))
missing_skes_logger1.info('{:^20}\t{}\t{}\t{}\t{}\t{}'.format('Skeleton', 'num_frames', 'Actor1',
                                                              'Actor2', 'Start', 'End'))

missing_skes_logger2 = logging.getLogger('missing_frames_2')
missing_skes_logger2.setLevel(logging.INFO)
missing_skes_logger2.addHandler(logging.FileHandler(osp.join(save_path, 'missing_skes_2.log')))
missing_skes_logger2.info('{:^20}\t{}\t{}\t{}'.format('Skeleton', 'num_frames', 'Actor1', 'Actor2'))


def denoising_by_length(ske_name, bodies_data):
    """
    Denoise the data based on the frame length of each bodyID.
    Filter out any bodyID whose length is less than or equal to the predefined threshold.

    """
    noise_info = str()
    new_bodies_data = bodies_data.copy()
    for (bodyID, body_data) in new_bodies_data.items():
        length = len(body_data['interval'])
        if length <= noise_len_thres:
            noise_info += 'Filter out: %s, %d (length).\n' % (bodyID, length)
            noise_len_logger.info('{}\t{}\t{:.6f}\t{:^6d}'.format(ske_name, bodyID,
                                                                  body_data['motion'], length))
            del bodies_data[bodyID]
    if noise_info != '':
        noise_info += '\n'

    return bodies_data, noise_info


def get_valid_frames_by_spread(points):
    """
    Find the valid (or reasonable) frames (index) based on the spread of X and Y.

    :param points: joints or colors
    """
    num_frames = points.shape[0]
    valid_frames = []
    for i in range(num_frames):
        x = points[i, :, 0]
        y = points[i, :, 1]
        if (x.max() - x.min()) <= noise_spr_thres1 * (y.max() - y.min()):  # 0.8
            valid_frames.append(i)
    return valid_frames


def denoising_by_spread(ske_name, bodies_data):
    """
    Denoise the data based on the spread of the X and Y values.
    Filter out any bodyID whose ratio of noisy frames is higher than the predefined
    threshold.

    bodies_data: contains at least 2 bodyIDs
    """
    noise_info = str()
    denoised_by_spr = False  # mark if this sequence has been processed by spread.

    new_bodies_data = bodies_data.copy()
    # for (bodyID, body_data) in bodies_data.items():
    for (bodyID, body_data) in new_bodies_data.items():
        if len(bodies_data) == 1:
            break
        valid_frames = get_valid_frames_by_spread(body_data['joints'].reshape(-1, 25, 3))
        num_frames = len(body_data['interval'])
        num_noise = num_frames - len(valid_frames)
        if num_noise == 0:
            continue

        ratio = num_noise / float(num_frames)
        motion = body_data['motion']
        if ratio >= noise_spr_thres2:  # 0.69754
            del bodies_data[bodyID]
            denoised_by_spr = True
            noise_info += 'Filter out: %s (spread rate >= %.2f).\n' % (bodyID, noise_spr_thres2)
            noise_spr_logger.info('%s\t%s\t%.6f\t%.6f' % (ske_name, bodyID, motion, ratio))
        else:  # Update motion
            joints = body_data['joints'].reshape(-1, 25, 3)[valid_frames]
            body_data['motion'] = min(motion, np.sum(np.var(joints.reshape(-1, 3), axis=0)))
            noise_info += '%s: motion %.6f -> %.6f\n' % (bodyID, motion, body_data['motion'])
            # TODO: Consider removing noisy frames for each bodyID

    if noise_info != '':
        noise_info += '\n'

    return bodies_data, noise_info, denoised_by_spr


def denoising_by_motion(ske_name, bodies_data, bodies_motion):
    """
    Filter out any bodyID whose motion falls outside the predefined interval.

    """
    # Sort bodies based on the motion, return a list of tuples
    # bodies_motion = sorted(bodies_motion.items(), key=lambda x, y: cmp(x[1], y[1]), reverse=True)
    bodies_motion = sorted(bodies_motion.items(), key=lambda x: x[1], reverse=True)

    # Reserve the body data with the largest motion
    denoised_bodies_data = [(bodies_motion[0][0], bodies_data[bodies_motion[0][0]])]
    noise_info = str()

    for (bodyID, motion) in bodies_motion[1:]:
        if (motion < noise_mot_thres_lo) or (motion > noise_mot_thres_hi):
            noise_info += 'Filter out: %s, %.6f (motion).\n' % (bodyID, motion)
            noise_mot_logger.info('{}\t{}\t{:.6f}'.format(ske_name, bodyID, motion))
        else:
            denoised_bodies_data.append((bodyID, bodies_data[bodyID]))
    if noise_info != '':
        noise_info += '\n'

    return denoised_bodies_data, noise_info


def denoising_bodies_data(bodies_data):
    """
    Denoising data based on some heuristic methods, not necessarily correct for all samples.

    Return:
      denoised_bodies_data (list): tuple: (bodyID, body_data).
    """
    ske_name = bodies_data['name']
    bodies_data = bodies_data['data']

    # Step 1: Denoising based on frame length.
    bodies_data, noise_info_len = denoising_by_length(ske_name, bodies_data)

    if len(bodies_data) == 1:  # only has one bodyID left after step 1
        return bodies_data.items(), noise_info_len

    # Step 2: Denoising based on spread.
    bodies_data, noise_info_spr, denoised_by_spr = denoising_by_spread(ske_name, bodies_data)

    if len(bodies_data) == 1:
        return bodies_data.items(), noise_info_len + noise_info_spr

    bodies_motion = dict()  # get body motion
    for (bodyID, body_data) in bodies_data.items():
        bodies_motion[bodyID] = body_data['motion']
    # Sort bodies based on the motion
    # bodies_motion = sorted(bodies_motion.items(), key=lambda x, y: cmp(x[1], y[1]), reverse=True)
    bodies_motion = sorted(bodies_motion.items(), key=lambda x: x[1], reverse=True)
    denoised_bodies_data = list()
    for (bodyID, _) in bodies_motion:
        denoised_bodies_data.append((bodyID, bodies_data[bodyID]))

    return denoised_bodies_data, noise_info_len + noise_info_spr

    # TODO: Consider denoising further by integrating motion method

    # if denoised_by_spr:  # this sequence has been denoised by spread
    #     bodies_motion = sorted(bodies_motion.items(), lambda x, y: cmp(x[1], y[1]), reverse=True)
    #     denoised_bodies_data = list()
    #     for (bodyID, _) in bodies_motion:
    #         denoised_bodies_data.append((bodyID, bodies_data[bodyID]))
    #     return denoised_bodies_data, noise_info

    # Step 3: Denoising based on motion
    # bodies_data, noise_info = denoising_by_motion(ske_name, bodies_data, bodies_motion)

    # return bodies_data, noise_info


def get_one_actor_points(body_data, num_frames):
    """
    Get joints and colors for only one actor.
    For joints, each frame contains 75 X-Y-Z coordinates.
    For colors, each frame contains 25 x 2 (X, Y) coordinates.
    """
    joints = np.zeros((num_frames, 75), dtype=np.float32)
    colors = np.ones((num_frames, 1, 25, 2), dtype=np.float32) * np.nan
    start, end = body_data['interval'][0], body_data['interval'][-1]
    joints[start:end + 1] = body_data['joints'].reshape(-1, 75)
    colors[start:end + 1, 0] = body_data['colors']

    return joints, colors


def remove_missing_frames(ske_name, joints, colors):
    """
    Cut off missing frames in which all joint positions are zeros.

    For sequences with 2 actors' data, also record the number of missing frames for
    actor1 and actor2, respectively (for debugging).
    """
    num_frames = joints.shape[0]
    num_bodies = colors.shape[1]  # 1 or 2

    if num_bodies == 2:  # DEBUG
        missing_indices_1 = np.where(joints[:, :75].sum(axis=1) == 0)[0]
        missing_indices_2 = np.where(joints[:, 75:].sum(axis=1) == 0)[0]
        cnt1 = len(missing_indices_1)
        cnt2 = len(missing_indices_2)

        start = 1 if 0 in missing_indices_1 else 0
        end = 1 if num_frames - 1 in missing_indices_1 else 0
        if max(cnt1, cnt2) > 0:
            if cnt1 > cnt2:
                info = '{}\t{:^10d}\t{:^6d}\t{:^6d}\t{:^5d}\t{:^3d}'.format(ske_name, num_frames,
                                                                            cnt1, cnt2, start, end)
                missing_skes_logger1.info(info)
            else:
                info = '{}\t{:^10d}\t{:^6d}\t{:^6d}'.format(ske_name, num_frames, cnt1, cnt2)
                missing_skes_logger2.info(info)

    # Find valid frame indices that the data is not missing or lost
    # For two-subjects action, this means both data of actor1 and actor2 is missing.
    valid_indices = np.where(joints.sum(axis=1) != 0)[0]  # 0-based index
    missing_indices = np.where(joints.sum(axis=1) == 0)[0]
    num_missing = len(missing_indices)

    if num_missing > 0:  # Update joints and colors
        joints = joints[valid_indices]
        colors[missing_indices] = np.nan
        global missing_count
        missing_count += 1
        missing_skes_logger.info('{}\t{:^10d}\t{:^11d}'.format(ske_name, num_frames, num_missing))

    return joints, colors


def get_bodies_info(bodies_data):
    bodies_info = '{:^17}\t{}\t{:^8}\n'.format('bodyID', 'Interval', 'Motion')
    for (bodyID, body_data) in bodies_data.items():
        start, end = body_data['interval'][0], body_data['interval'][-1]
        bodies_info += '{}\t{:^8}\t{:f}\n'.format(bodyID, str([start, end]), body_data['motion'])

    return bodies_info + '\n'


def get_two_actors_points(bodies_data):
    """
    Get the first and second actor's joints positions and colors locations.

    # Arguments:
        bodies_data (dict): 3 key-value pairs: 'name', 'data', 'num_frames'.
        bodies_data['data'] is also a dict, while the key is bodyID, the value is
        the corresponding body_data which is also a dict with 4 keys:
          - joints: raw 3D joints positions. Shape: (num_frames x 25, 3)
          - colors: raw 2D color locations. Shape: (num_frames, 25, 2)
          - interval: a list which records the frame indices.
          - motion: motion amount

    # Return:
        joints, colors.
    """
    ske_name = bodies_data['name']
    label = int(ske_name[-2:])
    num_frames = bodies_data['num_frames']
    bodies_info = get_bodies_info(bodies_data['data'])

    bodies_data, noise_info = denoising_bodies_data(bodies_data)  # Denoising data
    bodies_info += noise_info

    bodies_data = list(bodies_data)

    # ---- New check ----
    if len(bodies_data) == 0:
        print(f'Warning: No valid bodyID left after denoising for {ske_name}. Skipping.')
        return None, None  # None marks this sample as invalid
    # -------------------

    if len(bodies_data) == 1:  # Only left one actor after denoising
        if label >= 50:  # DEBUG: Denoising failed for two-subjects action
            fail_logger_2.info(ske_name)

        bodyID, body_data = bodies_data[0]
        joints, colors = get_one_actor_points(body_data, num_frames)
        bodies_info += 'Main actor: %s' % bodyID
    else:
        if label < 50:  # DEBUG: Denoising failed for one-subject action
            fail_logger_1.info(ske_name)

        joints = np.zeros((num_frames, 150), dtype=np.float32)
        colors = np.ones((num_frames, 2, 25, 2), dtype=np.float32) * np.nan

        bodyID, actor1 = bodies_data[0]  # the 1st actor with largest motion
        start1, end1 = actor1['interval'][0], actor1['interval'][-1]
        joints[start1:end1 + 1, :75] = actor1['joints'].reshape(-1, 75)
        colors[start1:end1 + 1, 0] = actor1['colors']
        actor1_info = '{:^17}\t{}\t{:^8}\n'.format('Actor1', 'Interval', 'Motion') + \
                      '{}\t{:^8}\t{:f}\n'.format(bodyID, str([start1, end1]), actor1['motion'])
        del bodies_data[0]

        actor2_info = '{:^17}\t{}\t{:^8}\n'.format('Actor2', 'Interval', 'Motion')
        start2, end2 = [0, 0]  # initial interval for actor2 (virtual)

        while len(bodies_data) > 0:
            bodyID, actor = bodies_data[0]
            start, end = actor['interval'][0], actor['interval'][-1]
            if min(end1, end) - max(start1, start) <= 0:  # no overlap with actor1
                joints[start:end + 1, :75] = actor['joints'].reshape(-1, 75)
                colors[start:end + 1, 0] = actor['colors']
                actor1_info += '{}\t{:^8}\t{:f}\n'.format(bodyID, str([start, end]), actor['motion'])
                # Update the interval of actor1
                start1 = min(start, start1)
                end1 = max(end, end1)
            elif min(end2, end) - max(start2, start) <= 0:  # no overlap with actor2
                joints[start:end + 1, 75:] = actor['joints'].reshape(-1, 75)
                colors[start:end + 1, 1] = actor['colors']
                actor2_info += '{}\t{:^8}\t{:f}\n'.format(bodyID, str([start, end]), actor['motion'])
                # Update the interval of actor2
                start2 = min(start, start2)
                end2 = max(end, end2)
            del bodies_data[0]

        bodies_info += ('\n' + actor1_info + '\n' + actor2_info)

    with open(osp.join(actors_info_dir, ske_name + '.txt'), 'w') as fw:
        fw.write(bodies_info + '\n')

    return joints, colors


def get_raw_denoised_data():
    """
    Get denoised data (joints positions and color locations) from raw skeleton sequences.

    For each frame of a skeleton sequence, an actor's 3D positions of 25 joints represented
    by an 2D array (shape: 25 x 3) is reshaped into a 75-dim vector by concatenating each
    3-dim (x, y, z) coordinates along the row dimension in joint order. Each frame contains
    two actor's joints positions constituting a 150-dim vector. If there is only one actor,
    then the last 75 values are filled with zeros. Otherwise, select the main actor and the
    second actor based on the motion amount. Each 150-dim vector as a row vector is put into
    a 2D numpy array where the number of rows equals the number of valid frames. All such
    2D arrays are put into a list and finally the list is serialized into a cPickle file.

    For the skeleton sequence which contains two or more actors (mostly corresponds to the
    last 11 classes), the filename and actors' information are recorded into log files.
    For better understanding, also generate RGB+skeleton videos for visualization.
    """

    with open(raw_data_file, 'rb') as fr:  # load raw skeletons data
        raw_skes_data = pickle.load(fr)

    num_skes = len(raw_skes_data)
    print('Found %d available skeleton sequences.' % num_skes)

    raw_denoised_joints = []
    raw_denoised_colors = []
    frames_cnt = []

    for (idx, bodies_data) in enumerate(raw_skes_data):
        ske_name = bodies_data['name']
        print('Processing %s' % ske_name)
        num_bodies = len(bodies_data['data'])

        if num_bodies == 1:  # only 1 actor
            num_frames = bodies_data['num_frames']
            body_data = list(bodies_data['data'].values())[0]
            joints, colors = get_one_actor_points(body_data, num_frames)
        else:  # more than 1 actor, select two main actors
            joints, colors = get_two_actors_points(bodies_data)

            if joints is None:           # 跳过无效样本
                continue

            # Remove missing frames
            joints, colors = remove_missing_frames(ske_name, joints, colors)
            num_frames = joints.shape[0]  # Update
            # Visualize selected actors' skeletons on RGB videos.

        raw_denoised_joints.append(joints)
        raw_denoised_colors.append(colors)
        frames_cnt.append(num_frames)

        if (idx + 1) % 1000 == 0:
            print('Processed: %.2f%% (%d / %d), ' % \
                  (100.0 * (idx + 1) / num_skes, idx + 1, num_skes) + \
                  'Missing count: %d' % missing_count)

    raw_skes_joints_pkl = osp.join(save_path, 'raw_denoised_joints.pkl')
    with open(raw_skes_joints_pkl, 'wb') as f:
        pickle.dump(raw_denoised_joints, f, pickle.HIGHEST_PROTOCOL)

    raw_skes_colors_pkl = osp.join(save_path, 'raw_denoised_colors.pkl')
    with open(raw_skes_colors_pkl, 'wb') as f:
        pickle.dump(raw_denoised_colors, f, pickle.HIGHEST_PROTOCOL)

    frames_cnt = np.array(frames_cnt, dtype=int)
    np.savetxt(osp.join(save_path, 'frames_cnt.txt'), frames_cnt, fmt='%d')

    print('Saved raw denoised positions of {} frames into {}'.format(np.sum(frames_cnt),
                                                                     raw_skes_joints_pkl))
    print('Found %d files that have missing data' % missing_count)

if __name__ == '__main__':
    get_raw_denoised_data()

运行结果:

Saved raw denoised positions of 4779321 frames into ./denoised_data\raw_denoised_joints.pkl

Found 155 files that have missing data

五、python seq_transformation.py

**报错:** ModuleNotFoundError: No module named 'h5py'

**解决方案:** pip install h5py

**报错:** from sklearn.model_selection import train_test_split ModuleNotFoundError: No module named 'sklearn'

**解决方案:** pip install scikit-learn
(infogcn2_env) D:\zero_track\infogcn2\data\ntu> python seq_transformation.py

在data/ntu目录下运行python seq_transformation.py报错:

from utils import create_aligned_dataset

ModuleNotFoundError: No module named 'utils'

原因是 seq_transformation.py 中有这样一行:

sys.path.append(['../..'])

这一行的本意是把上上级目录(也就是项目根目录)加入模块搜索路径,但 sys.path.append 接收的应该是一个字符串路径,这里却传入了一个列表。Python 不会把列表当成有效路径,所以导入依然失败。

解决方案:

把 sys.path.append(['../..']) 这一行注释掉,换成自己的项目根路径(字符串形式):

sys.path.append(r"D:\zero_track\infogcn2")
**报错:**

camera = np.loadtxt(camera_file, dtype=int)  # camera id: 1, 2, 3
File "D:\miniforge3\envs\infogcn2_env\Lib\site-packages\numpy\lib\_datasource.py", line 529, in open
    raise FileNotFoundError(f"{path} not found.")
FileNotFoundError: ./statistics\camera.txt not found.

这个错误是因为 seq_transformation.py 脚本需要一个记录每个骨架文件对应相机ID 的元数据文件 camera.txt,但该文件在 ./statistics/ 目录下不存在。

**报错:**

performer = np.loadtxt(performer_file, dtype=int)  # subject id: 1~40
File "D:\miniforge3\envs\infogcn2_env\Lib\site-packages\numpy\lib\_datasource.py", line 529, in open
    raise FileNotFoundError(f"{path} not found.")
FileNotFoundError: ./statistics\performer.txt not found.

在 seq_transformation.py 中搜索 np.loadtxt 会发现它还需要读取多个 txt 元数据文件,我们写个脚本一次性生成它们。

📋 所需文件清单

脚本中涉及读取的文件如下(均位于 ./data/ntu/statistics/ 目录):

| 变量名 | 文件名 | 内容说明 |
| --- | --- | --- |
| setup_file | setup.txt | 设置批次编号(S 后三位,如 001) |
| camera_file | camera.txt | 相机编号(C 后三位,1-3) |
| performer_file | performer.txt | 受试者编号(P 后三位,1-40) |
| replication_file | replication.txt | 重复次数编号(R 后三位,1-2) |
| label_file | label.txt | 动作类别编号(A 后三位,1-60) |
| skes_name_file | skes_available_name.txt | 已存在的骨架文件名列表 |

skes_available_name.txt 已存在,其余5个文件缺失。

📝 文件名解析规则

NTU 骨架文件命名格式:SsssCcccPpppRrrrAaaa.skeleton

例如:S001C002P006R001A008.skeleton

  • sss:设置批次(如 001)

  • ccc:相机 ID(001, 002, 003)

  • ppp:受试者 ID(001-040)

  • rrr:重复次数(001, 002)

  • aaa:动作类别(001-060)
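按上面的命名规则,也可以用正则一次性校验并提取五个字段(示意,假设文件名严格符合固定三位数字的格式):

```python
import re

name = 'S001C002P006R001A008'
m = re.fullmatch(r'S(\d{3})C(\d{3})P(\d{3})R(\d{3})A(\d{3})', name)
setup, camera, performer, replication, label = (int(g) for g in m.groups())
print(setup, camera, performer, replication, label)  # 1 2 6 1 8
```

fullmatch 匹配失败时返回 None,可以顺便过滤掉命名不规范的文件。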

注意事项

  • 该脚本假设文件名严格遵循 NTU 官方命名规则,且已通过 skes_available_name.txt 排除了损坏样本。

  • label.txt 中保存的是原始动作编号(1-60),seq_transformation.py 会自动将其减 1 转为 0-59 的标签。

  • NTU 文件名的各字段均为固定的三位数字(如 S001C002),脚本用 split 方法按 S/C/P/R/A 分隔字母逐段提取,保证字段定位准确。

生成这些文件后,再运行 python seq_transformation.py 就不会再报文件缺失的错误了。

python 复制代码
import os

# 路径配置
stat_dir = './data/ntu/statistics'
skes_name_file = os.path.join(stat_dir, 'skes_available_name.txt')

# 输出文件路径
setup_file = os.path.join(stat_dir, 'setup.txt')
camera_file = os.path.join(stat_dir, 'camera.txt')
performer_file = os.path.join(stat_dir, 'performer.txt')
replication_file = os.path.join(stat_dir, 'replication.txt')
label_file = os.path.join(stat_dir, 'label.txt')

# 读取所有文件名
with open(skes_name_file, 'r') as f:
    skes_names = [line.strip() for line in f]

# 初始化列表
setups = []
cameras = []
performers = []
replications = []
labels = []

for name in skes_names:
    # 示例:S001C002P006R001A008
    # 使用 split 方法提取各字段
    parts = name.split('C')
    setup = int(parts[0][1:])          # S001 -> 1
    rest = parts[1]
    camera_str, rest = rest[:3], rest[3:]   # 002
    camera = int(camera_str)
    performer_str, rest = rest.split('P')[1][:3], rest.split('P')[1][3:]  # 006
    performer = int(performer_str)
    replication_str, rest = rest.split('R')[1][:3], rest.split('R')[1][3:]  # 001
    replication = int(replication_str)
    label_str = rest.split('A')[1][:3]  # 008
    label = int(label_str)

    setups.append(setup)
    cameras.append(camera)
    performers.append(performer)
    replications.append(replication)
    labels.append(label)

# 写入各个文件
with open(setup_file, 'w') as f:
    for val in setups:
        f.write(f"{val}\n")
print(f"Generated {setup_file}")

with open(camera_file, 'w') as f:
    for val in cameras:
        f.write(f"{val}\n")
print(f"Generated {camera_file}")

with open(performer_file, 'w') as f:
    for val in performers:
        f.write(f"{val}\n")
print(f"Generated {performer_file}")

with open(replication_file, 'w') as f:
    for val in replications:
        f.write(f"{val}\n")
print(f"Generated {replication_file}")

with open(label_file, 'w') as f:
    for val in labels:
        f.write(f"{val}\n")
print(f"Generated {label_file}")

print("\nAll statistics files have been created successfully.")

报错:AttributeError: `np.string_` was removed in the NumPy 2.0 release. Use `np.bytes_` instead.

这个错误是因为 NumPy 2.0 移除了 np.string_ 这个别名。在 seq_transformation.py 中,有一行代码使用了它:

复制代码
skes_name = np.loadtxt(skes_name_file, dtype=np.string_)

🔧 解决方案

np.string_ 替换为 NumPy 2.0 兼容的 np.bytes_,或者直接使用 Python 内置的 str 类型。推荐改为 dtype=str,因为它返回的是 Unicode 字符串,更方便处理。

修改方法

找到 seq_transformation.py 中的这一行(大约在 if __name__ == '__main__': 代码块内),将其修改为:

复制代码
skes_name = np.loadtxt(skes_name_file, dtype=str)
完整的修改后代码段示例:
复制代码
if __name__ == '__main__':
    camera = np.loadtxt(camera_file, dtype=int)
    performer = np.loadtxt(performer_file, dtype=int)
    label = np.loadtxt(label_file, dtype=int) - 1

    frames_cnt = np.loadtxt(frames_file, dtype=int)
    skes_name = np.loadtxt(skes_name_file, dtype=str)  # 修改这一行

    with open(raw_skes_joints_pkl, 'rb') as fr:
        skes_joints = pickle.load(fr)
    ...

💡 额外检查

整个 InfoGCN2 项目中可能还有其他地方使用了 np.intnp.float 等 NumPy 1.x 的别名,在 NumPy 2.0 下都会报错。建议在项目根目录下运行一次全局搜索,将以下旧别名全部替换:

| 旧写法 | 新写法 |
| --- | --- |
| np.int | int |
| np.float | float |
| np.string_ | str 或 np.bytes_ |
| np.bool | bool |

你可以用 IDE 的"在文件中查找"功能搜索这些关键词并批量替换。修改后重新运行 python data/ntu/seq_transformation.py 即可继续。
报错:

File "D:\zero_track\infogcn2\data\ntu\seq_transformation.py", line 172, in split_dataset
    test_x = skes_joints[test_indices]
IndexError: index 56711 is out of bounds for axis 0 with size 56711

原因是 statistics 目录下的元数据(按全部文件名生成)条目数多于去噪后实际保留的骨架序列数:第四步中部分无效样本被 continue 跳过了,按元数据算出的索引因此会越界。解决办法是在 split_dataset 里把超出样本总数的索引过滤掉(对应下文脚本中的"新增过滤步骤"注释)。
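这个过滤步骤本身很简单,一个最小示意如下(clip_indices 是为演示取的名字,与正文脚本中 split_dataset 内的过滤逻辑等价):

```python
import numpy as np

def clip_indices(indices, num_samples):
    """丢弃超出有效样本数的索引,避免 IndexError。"""
    indices = np.asarray(indices, dtype=int)
    return indices[indices < num_samples]

train_indices = clip_indices([0, 5, 56710, 56711], num_samples=56711)
print(train_indices.tolist())  # [0, 5, 56710]
```

代价是少量样本被直接丢弃;更彻底的做法是重新生成与去噪结果一一对应的元数据文件。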

六、seq_transformation.py是做什么的?它的输入输出和中间操作是什么?

seq_transformation.py 是 InfoGCN2 预处理流程中的第三步 ,也是核心转换脚本 。它的任务是将上一步去噪后的原始骨架数据,转换成可以直接输入深度学习模型的标准格式,并按官方协议划分好训练集和测试集。

📌 在整体流程中的位置

| 步骤 | 脚本 | 作用 |
| --- | --- | --- |
| 1 | get_raw_skes_data.py | 从 .skeleton 文件解析原始骨架,生成 raw_skes_data.pkl |
| 2 | get_raw_denoised_data.py | 去噪、筛选主要演员,生成 raw_denoised_joints.pkl |
| 3 | seq_transformation.py | 坐标变换、帧对齐、数据集划分 ← 当前脚本 |

🎯 脚本的核心任务

简单说,它做了三件事:

  1. 坐标归一化:将骨架坐标从世界坐标系转换到以脊柱中心为原点的局部坐标系,消除人物在画面中的位置差异。

  2. 帧数对齐:将所有动作序列填充或截断到相同长度(300帧),方便批量训练。

  3. 数据集划分 :按照 NTU 官方标准(Cross-Subject 和 Cross-View)划分训练/测试集,并保存为 .npz 格式。

📥 输入文件一览

脚本运行时需要读取以下文件(均来自前两步的输出或我们手动生成的元数据):

| 文件路径 | 内容 | 来源 |
| --- | --- | --- |
| ./denoised_data/raw_denoised_joints.pkl | 去噪后的骨架坐标列表 | 步骤2生成 |
| ./denoised_data/frames_cnt.txt | 每个序列的有效帧数 | 步骤2生成 |
| ./statistics/skes_available_name.txt | 有效样本的文件名列表 | 步骤1生成 |
| ./statistics/label.txt | 每个样本的动作标签(1-60) | 手动生成 |
| ./statistics/performer.txt | 每个样本的受试者ID(1-40) | 手动生成 |
| ./statistics/camera.txt | 每个样本的相机ID(1-3) | 手动生成 |

⚙️ 中间处理流程详解

  1. 坐标平移 (seq_translation)

    • 将每一帧的**脊柱中点(joint-2)**作为新原点,所有关节坐标减去该点坐标。

    • 目的:消除人体在画面中的绝对位置,让模型关注姿态本身。

  2. 帧对齐 (align_frames)

    • 将所有序列统一填充到最大帧数(通常为300帧),不足的补0。

    • 对于单人动作,后 75 维用 0 填充(见下文 align_frames 中的 np.zeros_like),凑成与双人数据一致的 150 维格式。

  3. 数据集划分 (split_dataset)

    • 根据 performer.txtcamera.txt 决定每个样本属于训练集还是测试集。

    • Cross-Subject (CS):按人物ID划分(20人训练,20人测试)。

    • Cross-View (CV):按相机ID划分(相机2、3训练,相机1测试)。

    • 同时生成对应的 one-hot 标签。
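上面的划分逻辑(对应脚本中的 get_indices)可以用 np.isin 简洁地示意,这里用 6 个样本的玩具数据演示 Cross-Subject 划分:

```python
import numpy as np

# 玩具数据:6 个样本的受试者 ID
performer = np.array([1, 3, 2, 6, 28, 40])
train_ids = [1, 2, 4, 5, 8, 9, 13, 14, 15, 16,
             17, 18, 19, 25, 27, 28, 31, 34, 35, 38]  # CS 协议的训练受试者

train_indices = np.where(np.isin(performer, train_ids))[0]
test_indices = np.where(~np.isin(performer, train_ids))[0]
print(train_indices.tolist())  # [0, 2, 4]
print(test_indices.tolist())   # [1, 3, 5]
```

np.isin 一次性生成布尔掩码,与脚本中逐个 ID 循环 np.where 再 hstack 的写法结果相同。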

📤 最终输出文件

脚本运行完成后,会在项目根目录 下生成两个 .npz 文件:

| 文件名 | 内容 |
| --- | --- |
| NTU60_CS.npz | Cross-Subject 划分的训练/测试数据 |
| NTU60_CV.npz | Cross-View 划分的训练/测试数据 |

每个 .npz 文件包含四个数组:

  • x_train / x_test:形状为 (样本数, 300, 150) 的骨架序列

  • y_train / y_test:形状为 (样本数, 60) 的 one-hot 标签

这两个文件就是最终用于训练 InfoGCN 模型的数据集。
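可以用一个迷你示例验证这种 npz 的读写格式(形状仅作示意,真实文件的样本数远大于 2,文件名 demo_NTU60.npz 是临时取的):

```python
import numpy as np

# 2 个样本、300 帧、150 维(双人 × 25 关节 × 3 坐标)
x_train = np.zeros((2, 300, 150), dtype=np.float32)
y_train = np.zeros((2, 60), dtype=np.float32)
y_train[0, 3] = 1   # 样本0 的动作类别为 3(one-hot)
y_train[1, 59] = 1  # 样本1 的动作类别为 59

np.savez('demo_NTU60.npz', x_train=x_train, y_train=y_train,
         x_test=x_train, y_test=y_train)

data = np.load('demo_NTU60.npz')
print(data['x_train'].shape)                       # (2, 300, 150)
print(data['y_train'].argmax(axis=1).tolist())     # [3, 59]
```

argmax 可以把 one-hot 标签还原成整数类别,后文 create_aligned_dataset 生成的 *_aligned.npz 里存的就是整数标签。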

python 复制代码
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
import sys
sys.path.append(r"D:\zero_track\infogcn2")
import os
import os.path as osp
import numpy as np
import pickle
import logging
import h5py
from sklearn.model_selection import train_test_split
from utils import create_aligned_dataset

root_path = './'
stat_path = osp.join(root_path, 'statistics')
setup_file = osp.join(stat_path, 'setup.txt')
camera_file = osp.join(stat_path, 'camera.txt')
performer_file = osp.join(stat_path, 'performer.txt')
replication_file = osp.join(stat_path, 'replication.txt')
label_file = osp.join(stat_path, 'label.txt')
skes_name_file = osp.join(stat_path, 'skes_available_name.txt')

denoised_path = osp.join(root_path, 'denoised_data')
raw_skes_joints_pkl = osp.join(denoised_path, 'raw_denoised_joints.pkl')
frames_file = osp.join(denoised_path, 'frames_cnt.txt')

save_path = './'


if not osp.exists(save_path):
    os.mkdir(save_path)


def remove_nan_frames(ske_name, ske_joints, nan_logger):
    num_frames = ske_joints.shape[0]
    valid_frames = []

    for f in range(num_frames):
        if not np.any(np.isnan(ske_joints[f])):
            valid_frames.append(f)
        else:
            nan_indices = np.where(np.isnan(ske_joints[f]))[0]
            nan_logger.info('{}\t{:^5}\t{}'.format(ske_name, f + 1, nan_indices))

    return ske_joints[valid_frames]

def seq_translation(skes_joints):
    for idx, ske_joints in enumerate(skes_joints):
        num_frames = ske_joints.shape[0]
        num_bodies = 1 if ske_joints.shape[1] == 75 else 2
        if num_bodies == 2:
            missing_frames_1 = np.where(ske_joints[:, :75].sum(axis=1) == 0)[0]
            missing_frames_2 = np.where(ske_joints[:, 75:].sum(axis=1) == 0)[0]
            cnt1 = len(missing_frames_1)
            cnt2 = len(missing_frames_2)

        i = 0  # get the "real" first frame of actor1
        while i < num_frames:
            if np.any(ske_joints[i, :75] != 0):
                break
            i += 1

        origin = np.copy(ske_joints[i, 3:6])  # new origin: joint-2

        for f in range(num_frames):
            if num_bodies == 1:
                ske_joints[f] -= np.tile(origin, 25)
            else:  # for 2 actors
                ske_joints[f] -= np.tile(origin, 50)

        if (num_bodies == 2) and (cnt1 > 0):
            ske_joints[missing_frames_1, :75] = np.zeros((cnt1, 75), dtype=np.float32)

        if (num_bodies == 2) and (cnt2 > 0):
            ske_joints[missing_frames_2, 75:] = np.zeros((cnt2, 75), dtype=np.float32)

        skes_joints[idx] = ske_joints  # Update

    return skes_joints


def frame_translation(skes_joints, skes_name, frames_cnt):
    nan_logger = logging.getLogger('nan_skes')
    nan_logger.setLevel(logging.INFO)
    nan_logger.addHandler(logging.FileHandler("./nan_frames.log"))
    nan_logger.info('{}\t{}\t{}'.format('Skeleton', 'Frame', 'Joints'))

    for idx, ske_joints in enumerate(skes_joints):
        num_frames = ske_joints.shape[0]
        # Calculate the distance between spine base (joint-1) and spine (joint-21)
        j1 = ske_joints[:, 0:3]
        j21 = ske_joints[:, 60:63]
        dist = np.sqrt(((j1 - j21) ** 2).sum(axis=1))

        for f in range(num_frames):
            origin = ske_joints[f, 3:6]  # new origin: middle of the spine (joint-2)
            if (ske_joints[f, 75:] == 0).all():
                ske_joints[f, :75] = (ske_joints[f, :75] - np.tile(origin, 25)) / \
                                      dist[f] + np.tile(origin, 25)
            else:
                ske_joints[f] = (ske_joints[f] - np.tile(origin, 50)) / \
                                 dist[f] + np.tile(origin, 50)

        ske_name = skes_name[idx]
        ske_joints = remove_nan_frames(ske_name, ske_joints, nan_logger)
        frames_cnt[idx] = num_frames  # update valid number of frames
        skes_joints[idx] = ske_joints

    return skes_joints, frames_cnt


def align_frames(skes_joints, frames_cnt):
    """
    Align all sequences with the same frame length.

    """
    num_skes = len(skes_joints)
    max_num_frames = frames_cnt.max()  # 300
    aligned_skes_joints = np.zeros((num_skes, max_num_frames, 150), dtype=np.float32)

    for idx, ske_joints in enumerate(skes_joints):
        num_frames = ske_joints.shape[0]
        num_bodies = 1 if ske_joints.shape[1] == 75 else 2
        if num_bodies == 1:
            aligned_skes_joints[idx, :num_frames] = np.hstack((ske_joints,
                                                               np.zeros_like(ske_joints)))
        else:
            aligned_skes_joints[idx, :num_frames] = ske_joints

    return aligned_skes_joints


def one_hot_vector(labels):
    num_skes = len(labels)
    labels_vector = np.zeros((num_skes, 60))
    for idx, l in enumerate(labels):
        labels_vector[idx, l] = 1

    return labels_vector


def split_train_val(train_indices, method='sklearn', ratio=0.05):
    """
    Get validation set by splitting data randomly from training set with two methods.
    In fact, I thought these two methods are equal as they got the same performance.

    """
    if method == 'sklearn':
        return train_test_split(train_indices, test_size=ratio, random_state=10000)
    else:
        np.random.seed(10000)
        np.random.shuffle(train_indices)
        val_num_skes = int(np.ceil(0.05 * len(train_indices)))
        val_indices = train_indices[:val_num_skes]
        train_indices = train_indices[val_num_skes:]
        return train_indices, val_indices


def split_dataset(skes_joints, label, performer, camera, evaluation, save_path):
    train_indices, test_indices = get_indices(performer, camera, evaluation)
    m = 'sklearn'  # 'sklearn' or 'numpy'
    # Select validation set from training set
    # train_indices, val_indices = split_train_val(train_indices, m)

    # === 新增过滤步骤:确保索引在有效范围内 ===
    num_samples = len(skes_joints)
    train_indices = train_indices[train_indices < num_samples]
    test_indices = test_indices[test_indices < num_samples]
    # ==========================================

    # Save labels and num_frames for each sequence of each data set
    train_labels = label[train_indices]
    test_labels = label[test_indices]

    train_x = skes_joints[train_indices]
    train_y = one_hot_vector(train_labels)
    test_x = skes_joints[test_indices]
    test_y = one_hot_vector(test_labels)

    save_name = 'NTU60_%s.npz' % evaluation
    np.savez(save_name, x_train=train_x, y_train=train_y, x_test=test_x, y_test=test_y)

    # Save data into a .h5 file
    # h5file = h5py.File(osp.join(save_path, 'NTU_%s.h5' % (evaluation)), 'w')
    # Training set
    # h5file.create_dataset('x', data=skes_joints[train_indices])
    # train_one_hot_labels = one_hot_vector(train_labels)
    # h5file.create_dataset('y', data=train_one_hot_labels)
    # Validation set
    # h5file.create_dataset('valid_x', data=skes_joints[val_indices])
    # val_one_hot_labels = one_hot_vector(val_labels)
    # h5file.create_dataset('valid_y', data=val_one_hot_labels)
    # Test set
    # h5file.create_dataset('test_x', data=skes_joints[test_indices])
    # test_one_hot_labels = one_hot_vector(test_labels)
    # h5file.create_dataset('test_y', data=test_one_hot_labels)

    # h5file.close()


def get_indices(performer, camera, evaluation='CS'):
    test_indices = np.empty(0)
    train_indices = np.empty(0)

    if evaluation == 'CS':  # Cross Subject (Subject IDs)
        train_ids = [1,  2,  4,  5,  8,  9,  13, 14, 15, 16,
                     17, 18, 19, 25, 27, 28, 31, 34, 35, 38]
        test_ids = [3,  6,  7,  10, 11, 12, 20, 21, 22, 23,
                    24, 26, 29, 30, 32, 33, 36, 37, 39, 40]

        # Get indices of test data
        for idx in test_ids:
            temp = np.where(performer == idx)[0]  # 0-based index
            test_indices = np.hstack((test_indices, temp)).astype(int)

        # Get indices of training data
        for train_id in train_ids:
            temp = np.where(performer == train_id)[0]  # 0-based index
            train_indices = np.hstack((train_indices, temp)).astype(int)
    else:  # Cross View (Camera IDs)
        train_ids = [2, 3]
        test_ids = 1
        # Get indices of test data
        temp = np.where(camera == test_ids)[0]  # 0-based index
        test_indices = np.hstack((test_indices, temp)).astype(int)

        # Get indices of training data
        for train_id in train_ids:
            temp = np.where(camera == train_id)[0]  # 0-based index
            train_indices = np.hstack((train_indices, temp)).astype(int)

    return train_indices, test_indices


if __name__ == '__main__':
    camera = np.loadtxt(camera_file, dtype=int)  # camera id: 1, 2, 3
    performer = np.loadtxt(performer_file, dtype=int)  # subject id: 1~40
    label = np.loadtxt(label_file, dtype=int) - 1  # action label: 0~59

    frames_cnt = np.loadtxt(frames_file, dtype=int)  # frames_cnt
    skes_name = np.loadtxt(skes_name_file, dtype=str)

    with open(raw_skes_joints_pkl, 'rb') as fr:
        skes_joints = pickle.load(fr)  # a list

    skes_joints = seq_translation(skes_joints)

    skes_joints = align_frames(skes_joints, frames_cnt)  # aligned to the same frame length

    evaluations = ['CS', 'CV']
    for evaluation in evaluations:
        split_dataset(skes_joints, label, performer, camera, evaluation, save_path)

    create_aligned_dataset(file_list=['NTU60_CS.npz', 'NTU60_CV.npz'])

七、python main.py

pip install wandb torch_dct torchdiffeq matplotlib

这几个是运行main.py进行训练必要的依赖包

作者在readme中给出一个训练的命令示例:

python 复制代码
python main.py --half=True --batch_size=32 --test_batch_size=64 \
    --step 50 60 --num_epoch=70 --num_worker=4 --dataset=NW-UCLA --num_class=10 \
    --datacase=ucla --weight_decay=0.0003 --num_person=1 --num_point=20 --graph=graph.ucla.Graph \
    --feeder=feeders.feeder_ucla.Feeder --base_lr 1e-1 --base_channel 64 \
    --window_size 52 --lambda_1=1e-0 --lambda_2=1e-1 --lambda_3=1e-3 --n_step 3

但这是--dataset=NW-UCLA,而我的数据集是ntu60,我刚刚运行了python seq_transformation.py获得了NTU60那些npz文件

这段代码是一个基于 SODE (Skeleton-based ODE Network) 的骨架动作识别训练与评估框架。下面我会先逐模块解释代码的输入、输出和中间操作,然后针对 NTU60 数据集 给出训练命令的修改指南。


一、整体代码结构解析

1. 入口与初始化 (mainProcessor.__init__)

  • 参数解析 :通过 args.py 接收命令行参数,并用 wandb 记录实验配置。

  • 随机种子固定init_seed(seed) 保证可复现性。

  • Processor 类初始化

    • 保存参数到 wandb.run.dir(动态工作目录)。

    • 加载模型 (SODE)。

    • 获取预定义的邻接矩阵 A_vector(若 k != 8)。

    • 加载优化器。

    • 加载数据 (load_data)。

2. 数据加载 (load_data)

  • 数据路径data/{dataset}/{datacase}_aligned.npz

    • 例如 NW-UCLA 对应 data/NW-UCLA/ucla_aligned.npz

    • 你需要确保 NTU60 的 .npz 文件放在类似 data/NTU60/xxx_aligned.npz 路径下。

  • Feeder :从 feeders.feeder_ucla.Feeder 动态导入(实际是一个 PyTorch Dataset 类)。

    • 输入

      • data_path:npz 文件路径。

      • split'train''test'

      • p_interval:用于随机裁剪时序片段的数据增强参数。

      • vel:是否使用速度信息。

      • random_rot:随机旋转增强。

      • A:预计算邻接矩阵(可选)。

      • window_size:输入序列长度(帧数)。

    • 输出(每个 batch)

      • x:形状 (B, C, T, V, M) ------ 批量、通道数(通常3,即关节坐标)、时间帧、关节点数、人数。

      • y:标签,形状 (B,)

      • mask:有效帧掩码,用于处理变长序列(填充部分置0)。

      • index:样本索引(评估时用于记录)。

3. 模型结构 (SODE)

  • 输入x 如上所述。

  • 前向过程

    • 编码器:提取时空特征,生成潜在变量序列 z_0

    • ODE 求解器:对 z_0 进行连续演化,得到预测的潜在序列 z_hat

    • 解码器/分类器:

      • y_hat:多时间步的分类预测,形状 (N_cls, B, T, num_class)N_cls 取决于 window_sizen_step)。

      • x_hat:重建的骨架序列,形状 (N_rec, B, C, T, V, M)

      • kl_div:KL 散度正则项(若使用 VAE 结构)。

  • 输出y_hat, x_hat, z_0, z_hat, kl_div

4. 训练流程 (train)

  • 前向传播 :调用 model(x)

  • 损失计算

    • 分类损失 (cls_loss):LabelSmoothingCrossEntropy,作用于多时间步预测与标签。

    • 重建损失 (recon_loss):masked_recon_loss,只计算非填充帧的重建误差。

    • 特征一致性损失 (feature_loss):强制 ODE 演化的 z_hat 与原始 z_0 在有效区域一致。

    • KL 散度 (kl_div):可选正则。

  • 反向传播 :支持混合精度 (apex.amp)。

  • 日志记录 :每个 epoch 结束后用 wandb 记录损失、准确率曲线下面积 (AUC) 等。

5. 评估流程 (eval)

  • 类似训练阶段,但不更新梯度。

  • 额外保存每个样本的得分 (score_dict),用于后期分析。

6. 学习率调整 (adjust_learning_rate)

  • 使用 warm-up + 阶梯衰减策略。
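warm-up + 阶梯衰减的典型实现大致如下(仅为示意,get_lr 及各默认系数是为演示假设的,具体以项目里 adjust_learning_rate 的实现为准):

```python
def get_lr(epoch, base_lr=0.1, warm_up_epoch=5, steps=(60, 80), decay=0.1):
    """前 warm_up_epoch 个 epoch 线性升温,之后每越过一个 step 节点就乘以 decay。"""
    if epoch < warm_up_epoch:
        return base_lr * (epoch + 1) / warm_up_epoch
    return base_lr * decay ** sum(epoch >= s for s in steps)

print([round(get_lr(e), 4) for e in (0, 4, 10, 60, 80)])
# [0.02, 0.1, 0.1, 0.01, 0.001]
```

这也是 --step 60 80 这类参数的含义:在第 60、80 个 epoch 处各衰减一次学习率。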

二、针对 NTU60 数据集的训练命令修改

已经用 seq_transformation.py 生成了 NTU60 的 .npz 文件,那么需要修改以下参数:

| 参数 | NW-UCLA 值 | NTU60 建议值 | 说明 |
| --- | --- | --- | --- |
| --dataset | NW-UCLA | ntu | 数据集名称(用于拼接 data/{dataset}/ 路径) |
| --datacase | ucla | NTU60_CS(或 NTU60_CV) | npz 文件前缀(对应 NTU60_CS_aligned.npz) |
| --num_class | 10 | 60 | 动作类别数 |
| --num_person | 1 | 2 | NTU60 最多 2 人 |
| --num_point | 20 | 25 | NTU60 骨架为 25 个关节点 |
| --graph | graph.ucla.Graph | graph.ntu_rgb_d.Graph | 图构建类(需确认代码库中是否有 ntu_rgb_d 的图定义) |
| --feeder | feeders.feeder_ucla.Feeder | feeders.feeder_ntu.Feeder | 确保 feeder 支持 NTU60 的 npz 格式 |
| --window_size | 52 | 64(或 128 等常用值) | NTU60 单样本帧数较多(~300),可适当减小以省显存 |
| --base_lr | 1e-1 | 可保持或微调 | 取决于 batch size 和优化器 |
| --step | 50 60 | 60 80 | 学习率衰减的 epoch 节点,随 num_epoch 调整 |

示例命令(NTU60 训练)

python 复制代码
# linux系统命令(已跑通)
python main.py --half=False --batch_size=32 --test_batch_size=64 \
    --step 60 80 --num_epoch=100 --num_worker=4 --dataset=ntu --num_class=60 \
    --datacase=NTU60_CS --weight_decay=0.0003 --num_person=2 --num_point=25 --graph=graph.ntu_rgb_d.Graph \
    --feeder=feeders.feeder_ntu.Feeder --base_lr 1e-1 --base_channel 64 \
    --window_size 64 --lambda_1=1e-0 --lambda_2=1e-1 --lambda_3=1e-3 --n_step 3

注意

  • 确保 feeders.feeder_ntu.Feeder 存在且能正确解析你的 .npz 文件格式(通常 x(N, C, T, V, M) 的 float32 数组,y 为 int 标签,可能还有 mask)。

  • 若代码库中无 graph.ntu_rgb_d.Graph,可参考 graph.ucla.py 自行创建 25 节点的邻接矩阵定义。

  • 由于 NTU60 数据量大,建议适当减小 window_size 以避免显存溢出,并增加 num_epoch 以充分训练。

报错:ModuleNotFoundError: No module named 'resource'

解决方案:在 Linux 上跑训练。resource 是 Unix 专有的标准库模块,这个问题在 Windows 系统上不好解决。

报错:ModuleNotFoundError: No module named 'apex'

解决方案:main.py 通过 --half=False 禁用混合精度训练,避开这个问题。如果要解决的话在第二节已经说过,但是难度比较大。

报错:ModuleNotFoundError: No module named 'torch_dct'

解决方案:pip install torch-dct

报错:ModuleNotFoundError: No module named 'matplotlib'

解决方案:pip install matplotlib
报错:

File "D:\zero_track\infogcn2\feeders\feeder_ntu.py", line 52, in load_data
    npz_data = np.load(self.data_path)
File "D:\miniforge3\envs\infogcn2_env\Lib\site-packages\numpy\lib\_npyio_impl.py", line 454, in load
    fid = stack.enter_context(open(os.fspath(file), "rb"))
FileNotFoundError: [Errno 2] No such file or directory: 'data/NTU60/ntu60_aligned.npz'

seq_transformation.py 在 data/ntu 下生成了四个文件:NTU60_CS.npz、NTU60_CS_aligned.npz、NTU60_CV.npz、NTU60_CV_aligned.npz。解释一下这四个分别是什么、有什么用,以及为什么报错里显示的文件名和它们都不一样?
从你提供的 seq_transformation.py 代码可以看出,NTU60 数据处理分两个阶段,最终产生了四种 .npz 文件。下面详细解释它们的区别、用途,以及为什么训练时报错名字不匹配。


一、四种文件的作用与区别

| 文件名 | 生成方式 | 内容与用途 |
| --- | --- | --- |
| NTU60_CS.npz | split_dataset() 直接生成 | 包含 Cross-Subject 划分的原始对齐数据:x_train / y_train(训练集骨架与 one-hot 标签)、x_test / y_test(测试集骨架与 one-hot 标签)。⚠️ 数据格式为 (N, T, V*C)(展平后的关节坐标),不适合直接用于图卷积网络。 |
| NTU60_CS_aligned.npz | create_aligned_dataset() 处理生成 | 将 NTU60_CS.npz 中的展平数据重塑为模型所需的标准格式:x_train / x_test 形状为 (N, C, T, V, M),y_train / y_test 为整数标签(非 one-hot)。✅ 这是训练和测试时 Feeder 实际加载的文件。 |
| NTU60_CV.npz | 同上,针对 Cross-View 划分 | 与 CS 版本结构完全相同,只是训练/测试划分依据不同(按相机视角)。 |
| NTU60_CV_aligned.npz | 同上 | CV 版本的标准格式文件。 |
关键转换逻辑(create_aligned_dataset)

该函数(位于 utils.py 中,未在问题中贴出)大致完成以下操作:

复制代码
# 将 (N, T, 150) 的展平数据转换为 (N, 3, T, 25, 2)
# 每帧前 75 维是演员1、后 75 维是演员2,每人 25 个关节 × 3 维坐标
x_train = x_train.reshape(N, T, 2, 25, 3).transpose(0, 4, 1, 3, 2)

因此,只有 *_aligned.npz 文件能被模型正确加载

至于名字不一样:报错里的 data/NTU60/ntu60_aligned.npz 是由 --dataset=NTU60 --datacase=ntu60 按 data/{dataset}/{datacase}_aligned.npz 规则拼出来的路径;把参数改成 --dataset=ntu --datacase=NTU60_CS(与上文已跑通的 Linux 命令一致),加载的就是 data/ntu/NTU60_CS_aligned.npz,文件名自然就对上了。
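utils.py 未贴出,这里按照前文的数据布局(每帧前 75 维为演员1、后 75 维为演员2)写一个等价的转换示意,可以自行验证形状与元素的对应关系:

```python
import numpy as np

N, T = 2, 4
flat = np.arange(N * T * 150, dtype=np.float32).reshape(N, T, 150)

# (N, T, 150) -> (N, T, M=2, V=25, C=3) -> (N, C, T, V, M)
x = flat.reshape(N, T, 2, 25, 3).transpose(0, 4, 1, 3, 2)
print(x.shape)  # (2, 3, 4, 25, 2)

# 演员1(m=0)第0帧第0关节的 x 坐标,对应展平数据的第0个值
assert x[0, 0, 0, 0, 0] == flat[0, 0, 0]
# 演员2(m=1)的数据来自每帧的后 75 维
assert x[0, 0, 0, 0, 1] == flat[0, 0, 75]
```

注意 reshape 的维度顺序必须与展平时的拼接顺序一致(先演员、再关节、最后坐标),否则数据会被错位交织。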

八、wandb改为离线模式

wandb.errors.errors.UsageError: No API key configured. Use `wandb login` to log in. 怎么用默认参数避开这个wandb的登录?

你遇到的 wandb 登录错误是因为项目默认启用了 wandb 进行实验跟踪,而本地没有配置 API 密钥。要绕过登录,有以下几种方法,从最推荐到最直接排列:


方法一:通过环境变量禁用 wandb(推荐,无需改代码)

在运行 python main.py 之前,设置环境变量 WANDB_MODE=disabled

Windows (CMD):

复制代码
set WANDB_MODE=disabled
python main.py

Windows (PowerShell):

powershell

复制代码
$env:WANDB_MODE="disabled"
python main.py

Linux/Mac:

bash

复制代码
export WANDB_MODE=disabled
python main.py

这样 wandb.init() 会直接跳过,不要求登录。
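如果不方便改 shell 环境,也可以在 main.py 最顶部用代码设置同样的环境变量,效果一致:

```python
import os

# 必须在 import wandb 之前执行,否则不生效
os.environ['WANDB_MODE'] = 'disabled'
print(os.environ['WANDB_MODE'])  # disabled
```

这一行要放在所有 wandb 相关 import 之前,wandb 在初始化时才会读到它。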


方法二:在代码中强制设置 wandb 为离线/禁用模式

打开 main.py,找到 wandb.init(project=arg.project) 这一行,修改为:

python

复制代码
wandb.init(project=arg.project, mode="disabled")

或者如果你希望仍然记录日志但不上传到云端,可以用离线模式:

python

复制代码
wandb.init(project=arg.project, mode="offline")

(离线模式不会上传数据,日志保存在本地 wandb 目录下,之后如需上传可用 wandb sync 手动同步。)


方法三:添加命令行参数控制 wandb 开关(一劳永逸)

args.py 中添加一个参数:

python

复制代码
parser.add_argument('--use_wandb', type=bool, default=False)

然后在 main.py 中修改:

python

复制代码
if arg.use_wandb:
    wandb.init(project=arg.project)
    wandb.config.update(arg)
else:
    # 直接跳过 wandb,但仍需记录工作目录等(可手动创建)
    os.makedirs(arg.work_dir, exist_ok=True)

同时,需要处理后续代码中所有 wandb.log() 的地方,加条件判断 if arg.use_wandb:
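顺带一提,argparse 的 type=bool 是个经典的坑:bool('False') 仍是 True,所以命令行传字符串 "False" 并不能关掉开关。下面的小实验可以验证(--flag_bool、--flag_store 是为演示取的参数名):

```python
import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--flag_bool', type=bool, default=False)  # 有坑的写法
parser.add_argument('--flag_store', action='store_true')      # 推荐写法

arg = parser.parse_args(['--flag_bool', 'False'])
print(arg.flag_bool)   # True,非空字符串 'False' 被 bool() 转成了 True
print(arg.flag_store)  # False
```

这也是上面建议用 action='store_true' 的原因:不传参数即为 False,传了 --use_wandb 即为 True,不存在字符串转换的歧义。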


方法四:直接注释掉 wandb 相关代码(最暴力)

main.py 中:

python

复制代码
# wandb.init(project=arg.project)
# arg.work_dir = wandb.run.dir
# wandb.config.update(arg)
arg.work_dir = f"./work_dir/{time.strftime('%Y%m%d_%H%M%S')}"  # 手动指定目录

并注释掉 train()eval() 函数中的 wandb.log(...) 调用。


快速选择建议

  • 只想立刻跑通 → 用 方法一(环境变量),一行命令解决问题。

  • 希望长期不用 wandb → 用 方法二 ,直接在代码里写死 mode="disabled"

  • 以后可能还想用 wandb → 用 方法一方法二 中的 mode="offline"

使用 方法一 后,训练时就不会再弹出 wandb 登录错误了。
