VMamba环境本地适配配置

由于需要从源码改造VMamba,但官方给出的编译好的causual1d和mamba_ssm都是版本都于本地其他库不对应,因此从源码层级编译安装,这里记录一下,其他需要编译适配CUDA的都是一样,仅供参考。

  • 系统环境: Ubuntu 22.04
  • 核心驱动: CUDA 12.1 (nvcc -V 确认) 注意,官方提供的所有现成的包都是基于CUDA11.8或者CUDA12.2导致安装后导入环境报错,但由于个人电脑里面很多项目都是基于CUDA12.1,不想变更影响更多。
  • 核心框架: PyTorch 2.1.0 + Mamba 1.x

基础环境创建

python 复制代码
# 创建虚拟环境
conda create -n vmamba_dr python=3.10 -y
conda activate vmamba_dr

# 安装适配 CUDA 12.1 的 PyTorch
pip install torch==2.1.0 torchvision==0.16.0 torchaudio==2.1.0 --index-url https://download.pytorch.org/whl/cu121

# 安装基础依赖
pip install packaging timm==0.9.2 termcolor yacs pyyaml scipy wheel

# 装兼容 Torch 2.1 的 Transformers 版本
# 避免报错: AttributeError: module 'torch.utils._pytree' has no attribute 'register_pytree_node'
pip install transformers==4.37.2

2. 编译安装核心算子 (Mamba 基础)

注意:必须使用源码编译而不能直接pip install 安装现成的,且需禁用构建隔离 (--no-build-isolation) 以链接本地 Torch 环境。之前pip install后报错。

这是因为 pip 在编译源码(build wheel)时,默认会创建一个临时隔离环境 (build environment)。在这个临时环境里,它只安装了 setup.py 中声明的最基础构建工具(如 setuptools),但没有包含你主环境中安装好的 PyTorch 。而 causal-conv1dsetup.py 在运行初期就需要 import torch 来查询 CUDA 扩展的相关信息。

A. 安装 causal-conv1d (v1.4.0)

经过测试不同版本,这个版本兼容性相对比较好

注意一定要加入--no-build-isolation,否则

python 复制代码
cd ~
# 克隆源码
git clone https://github.com/Dao-AILab/causal-conv1d.git
cd causal-conv1d
# 切换到稳定版本
git checkout v1.4.0

# 强制本地编译安装
CAUSAL_CONV1D_FORCE_BUILD=TRUE pip install . --no-build-isolation

B. 安装 mamba-ssm (v1.2.0.post1)

经测试,这个版本和Pytorch2.0兼容较好。同样必须源码安装。

python 复制代码
cd ~
git clone https://github.com/state-spaces/mamba.git
cd mamba
# 切换到 1.x 版本 (非常重要,切勿使用 2.x)
git checkout v1.2.0.post1

# 强制本地编译安装
MAMBA_FORCE_BUILD=TRUE pip install . --no-build-isolation

3.安装 VMamba 核心模块

python 复制代码
cd ~
git clone https://github.com/MzeroMiko/VMamba.git
cd VMamba/kernels/selective_scan

# 编译 VMamba 专用的 SS2D 扫描算子
pip install . --no-build-isolation

注意,也要加上 --no-build-isolation,原因同之前。

前面几个安装由于都会从github上下载源码,网络原因会失败,因此采用git到本地在安装,如果网络通畅的情况下,选项可以改为如下,要加上XX_FORCE_BUILD=TRUE,强制使用本地源码编译,而不是使用网上现成编译好的库。

python 复制代码
CAUSAL_CONV1D_FORCE_BUILD=TRUE pip install causal-conv1d==1.4.0 --no-build-isolation
MAMBA_FORCE_BUILD=TRUE pip install mamba-ssm==1.2.0.post1 --no-build-isolation

排错

以上具备后,运行原始的vmamba.py还是会报错:

报错1: NameError: name 'selective_scan_cuda_core' is not defined

通过单步调测发现,这个报错实际上不是 发生在 vmamba.py 文件里,而是发生在它引用的 csms6s.py 文件中。kernels/selective_scan 下编译安装生成的包名通常叫 selective_scan_cuda (这是 C++ 扩展在 setup.py 里定义的真实名字)。但是,VMamba 项目里的 csms6s.py(这是 Python 的封装层)里,代码却试图调用一个叫 selective_scan_cuda_core 的变量。 csms6s.py 里有一个 try...except ImportError 的静默失败逻辑,导致导入失败时没有报错,直到代码真正运行起来调用算子时,才报"名字未定义"。

解决1: 找到**csms6s.py** 文件,按以下步骤修改,文件开头的导入部分长这样:

python 复制代码
# 第 12-16 行
try:
    import selective_scan_cuda_core
except ImportError:
    print(f"CUDA扩展不可用:")
    WITH_SELECTIVESCAN_CORE = False

# 第 17-20 行
try:
    import selective_scan_cuda
except ImportError:
    WITH_SELECTIVESCAN_MAMBA = False

在 VMamba 的 kernels/selective_scan 目录下执行 pip install . 后,安装进 Python 环境的包名叫做 selective_scan_cuda

但是,这份代码里的 selective_scan_cuda_core 并没有被正确关联到 selective_scan_cuda 上。

虽然下面的代码尝试导入了 selective_scan_cuda,但在第 120 行左右的 forward 函数里:

python 复制代码
        elif backend == "core":
            # 这里调用了 selective_scan_cuda_core,但它可能是未定义的!
            out, x, *rest = selective_scan_cuda_core.fwd(u, delta, A, B, C, D, delta_bias, delta_softplus, 1)

如果 WITH_SELECTIVESCAN_CORE 被设为 False(因为导入失败),但如果不小心走到了 backend="core" 的分支,或者在没有指定 backend 时逻辑判断有误,就会报 name 'selective_scan_cuda_core' is not defined

将其修改为:

python 复制代码
import time
import torch
import warnings

# ===================================================================
# [核心修复] 强制映射已安装的 CUDA 算子
# ===================================================================
WITH_SELECTIVESCAN_OFLEX = False # 通常不需要这个,设为 False 防止报错
WITH_SELECTIVESCAN_CORE = True   # 强制开启 Core 后端
WITH_SELECTIVESCAN_MAMBA = True  # 强制开启 Mamba 后端

try:
    # 1. 导入编译好的包
    import selective_scan_cuda
    
    # 2. [关键修复] 将其赋值给 selective_scan_cuda_core
    # 这样代码里无论调用 core 还是 mamba 后端,都能指向同一个正确的 C++ 实现
    selective_scan_cuda_core = selective_scan_cuda 
    
    print("成功导入 selective_scan_cuda 并映射到 selective_scan_cuda_core")

except ImportError:
    # 只有真的没装才报错
    selective_scan_cuda = None
    selective_scan_cuda_core = None
    WITH_SELECTIVESCAN_CORE = False
    WITH_SELECTIVESCAN_MAMBA = False
    print("严重错误: 无法导入 selective_scan_cuda!请确认已在 kernels/selective_scan 下运行 pip install .")

# 尝试导入 oflex (可选,不重要)
try:
    import selective_scan_cuda_oflex
    WITH_SELECTIVESCAN_OFLEX = True
except ImportError:
    pass
# ===================================================================

**报错2:**解决完上面的问题后,现在前向传播时候又报错

python 复制代码
"fwd(): incompatible function arguments. The following argument types are supported: 1. (arg0: torch.Tensor, arg1: torch.Tensor, arg2: torch.Tensor, arg3: torch.Tensor, arg4: torch.Tensor, arg5: Optional[torch.Tensor], arg6: Optional[torch.Tensor], arg7: Optional[torch.Tensor], arg8: bool) -> List[torch.Tensor]

解决2:这是因为Python 代码调用算子的参数顺序/数量 与 C++ 底层定义的接口不匹配。可能 VMamba 的不同版本(以及它依赖的 Mamba 版本)对底层 fwd 函数的参数定义微调过。

报错信息明确告诉你,底层的 fwd 函数期望接收 9 个参数

  1. u: Tensor
  2. delta: Tensor
  3. A: Tensor
  4. B: Tensor
  5. C: Tensor
  6. D: Optional[Tensor]
  7. z: Optional[Tensor] <-- 注意这里,这里通常是门控参数
  8. delta_bias: Optional[Tensor]
  9. delta_softplus: bool

csms6s.py 代码中,调用逻辑很可能漏掉了 z 参数 ,或者多传了一个不必要的参数(比如旧版本需要的 nrows 整数),导致参数错位。

找到这行代码:

python 复制代码
out, x, *rest = selective_scan_cuda_core.fwd(u, delta, A, B, C, D, delta_bias, delta_softplus, 1)

修改为:

python 复制代码
# 将 z 替换为 None
out, x, *rest = selective_scan_cuda_core.fwd(u, delta, A, B, C, D, None, delta_bias, delta_softplus)

前向出问题了,反向也会出问题,一起修改,找到:

python 复制代码
        elif backend == "core":
            du, ddelta, dA, dB, dC, dD, ddelta_bias, *rest = selective_scan_cuda_core.bwd(
                u, delta, A, B, C, D, delta_bias, dout, x, ctx.delta_softplus, 1

修改为:

python 复制代码
# 如果报错,找到 bwd 调用,把 z 换成 None
du, ddelta, dA, dB, dC, dD, ddelta_bias, *rest = selective_scan_cuda_core.bwd(
    u, delta, A, B, C, D, None, delta_bias, dout, x, out, None, delta_softplus, False
)
# 注意:这里也是把 z 的位置换成了 None

另外在backward中修改:

python 复制代码
    @staticmethod
    def backward(ctx, dout, *args):
        # =============================================================
        # 1. 提取保存的张量
        # =============================================================
        if not ctx.has_z:
            u, delta, A, B, C, D, delta_bias, x = ctx.saved_tensors
            z = None
            out = None  # <--- 【关键修复1】补上这一行,防止后面报 undefined
        else:
            u, delta, A, B, C, D, z, delta_bias, x, out = ctx.saved_tensors

        # 确保显存连续
        if dout.stride(-1) != 1:
            dout = dout.contiguous()

        # =============================================================
        # 2. 调用 C++ 反向传播核心
        # =============================================================
        # 注意:这里参数列表必须对应 fwd 的修改,把 z 位置显式设为 None
        du, ddelta, dA, dB, dC, dD, ddelta_bias, *rest = selective_scan_cuda_core.bwd(
            u, 
            delta, 
            A, 
            B, 
            C, 
            D, 
            None,       # <--- 【关键修复2】这里必须强制传 None,对应 fwd 的修改
            delta_bias, 
            dout, 
            x, 
            out,        # 这里现在是 None (如果走上面分支) 或者 Tensor (如果走下面)
            None,       # dz (optional)
            ctx.delta_softplus,
            False       # recompute_out_z
        )

        dz = rest[0] if ctx.has_z else None
        dB = dB.squeeze(1) if getattr(ctx, "squeeze_B", False) else dB
        dC = dC.squeeze(1) if getattr(ctx, "squeeze_C", False) else dC
        
        return (du, ddelta, dA, dB, dC,
                dD if D is not None else None,
                dz,
                ddelta_bias if delta_bias is not None else None,
                None,
                None)

这样代码能正常跑通,又带来一个疑问,这个z是可以随意设置为None的吗?究竟是什么含义?

这里简单回答一下,不确保正确,哈哈:z 是 Mamba 架构中的"门控信号(Gate)"。在我想魔改的代码里,设置为 None 是安全的,因为门控乘法操作被移到了 Python 层执行,而不是在 CUDA 核心里执行。

Mamba(以及类似的 SSM 模型)采用了 Gated Linear Unit (GLU) 的结构。输入数据通常会被线性投影成两份:

  • xxx (SSM 分支):进入复杂的 selective scan 状态空间模型进行时序建模。
  • zzz (Gate 分支):作为"门控开关",通常经过激活函数(如 SiLU)。

最终的模型输出 yyy 是两者的逐元素乘积:

y=SSM(x)⊙Activation(z)y = \text{SSM}(x) \odot \text{Activation}(z)y=SSM(x)⊙Activation(z)

这个 zzz 的作用就像一个调节阀 ,它决定了 SSM 提取的特征有多少能流向下一层。如果没有 zzz,模型就退化成了普通的 RNN/CNN,失去了 Mamba 强大的选择性过滤能力。


那为什么在 CUDA 接口里可以设为 None

既然 zzz 这么重要,为什么我们敢在 selective_scan_cuda_core.fwd(..., z=None) 里把它扔掉?

这就涉及到了 "算子融合(Kernel Fusion)" 的问题。

情况 A:融合算子 (Fused Kernel) ------ 高效但死板

原版 Mamba 为了追求极致速度,把"SSM 计算"和"乘以 zzz"这两步合并在一个 CUDA Kernel 里做完了。

  • 代码逻辑Kernel(x, z) -> 输出
  • 要求 :必须传入 z

情况 B:非融合算子 (Unfused) ------ 灵活但稍慢

VMamba 为了适应 2D 图像扫描(Cross Scan),逻辑比 1D 文本复杂。为了避免 CUDA 代码过于臃肿,作者通常选择把"乘以 zzz"这一步拿出来,放在 Python 里做

  • CUDA Kernel 逻辑 :只负责算 SSM(x),不负责乘 z。此时传 z=None 给 Kernel。
  • Python 逻辑 :拿到 Kernel 的输出后,手动写一行 output = output * z
相关推荐
victory04312 小时前
minimind SFT失败原因排查和解决办法
人工智能·python·深度学习
逐梦苍穹2 小时前
世界模型通俗讲解:AI大脑里的“物理模拟器“
人工智能·世界模型
发哥来了2 小时前
主流AI视频生成工具商用化能力评测:五大关键维度对比分析
大数据·人工智能·音视频
跳跳糖炒酸奶2 小时前
基于深度学习的单目深度估计综述阅读(1)
人工智能·深度学习·数码相机·单目深度估计
yangpipi-2 小时前
第一章 语言模型基础
人工智能·语言模型·自然语言处理
Piar1231sdafa2 小时前
基于yolo13-C3k2-RVB的洗手步骤识别与检测系统实现_1
人工智能·算法·目标跟踪
小北方城市网2 小时前
SpringBoot 集成 MyBatis-Plus 实战(高效 CRUD 与复杂查询):简化数据库操作
java·数据库·人工智能·spring boot·后端·安全·mybatis
川西胖墩墩2 小时前
开发者友好型AI调试与可观测性工具
人工智能
学统计的程序员2 小时前
一篇文章简述如何安装claude code并接入国产智谱AI大模型
人工智能·ai编程·claude