6.1.1.3 Hardware-Aware State-Space Optimization (FlashConv/FlashFFTConv)
1.1 The IO-Aware Computing Paradigm
1.1.1 Memory-Hierarchy Principles of FlashAttention
Modern GPUs follow a strict memory hierarchy in which high-bandwidth memory (HBM) and on-chip SRAM differ significantly in bandwidth. FlashAttention uses tiling and recomputation to shift the bottleneck from HBM traffic to on-chip compute, achieving memory usage that grows linearly with sequence length.
The core of this paradigm is minimizing HBM traffic. Let L be the sequence length and d the hidden dimension. A naive attention implementation materializes the intermediate matrix QK^\top \in \mathbb{R}^{L \times L} and the attention weights A \in \mathbb{R}^{L \times L}, incurring O(L^2) memory. FlashAttention instead computes an online softmax that decomposes the global normalization into local accumulation, maintaining only the running statistics m (maximum) and \ell (sum of exponentials), for O(L) memory.
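To make the m/\ell bookkeeping concrete, here is a minimal NumPy sketch of the online softmax. The function name and block size are illustrative, not FlashAttention's actual kernel (which additionally rescales a running output accumulator instead of making a second pass):

```python
import numpy as np

def online_softmax(scores, block=4):
    """Blocked online softmax: maintain only the running max m and
    exponential sum l, never materializing a full normalized row early."""
    m, l = -np.inf, 0.0
    # Pass 1: update the statistics (m, l) block by block
    for i in range(0, len(scores), block):
        s = scores[i:i + block]
        m_new = max(m, s.max())
        # rescale the old sum to the new max, then add this block
        l = l * np.exp(m - m_new) + np.exp(s - m_new).sum()
        m = m_new
    # Pass 2: normalize with the final statistics
    # (FlashAttention instead rescales its running output on the fly)
    return np.exp(scores - m) / l

x = np.linspace(-2.0, 2.0, 10)
ref = np.exp(x - x.max())
ref = ref / ref.sum()
assert np.allclose(online_softmax(x, block=4), ref)
```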
1.1.2 The Memory Wall in Convolution
The convolutional mode of state-space models hits a similar memory bottleneck. Convolving the global kernel \bar{K} \in \mathbb{R}^L with the input sequence u \in \mathbb{R}^L via the FFT computes:
y = \text{iFFT}\big(\text{FFT}(\bar{K}_{\text{padded}}) \odot \text{FFT}(u_{\text{padded}})\big)
A naive implementation performs three global FFT passes, each consuming O(L \log L) HBM bandwidth. At L = 16K or 128K, materializing the intermediate frequency-domain representations overflows GPU memory. FlashFFTConv fuses the kernel FFT with the input FFT and executes the frequency-domain pointwise product on the Tensor Cores' matrix-multiply units, cutting HBM round trips.
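The three global transforms can be spelled out directly; a small NumPy sketch of this naive path (the helper name is ours, not a library API):

```python
import numpy as np

def naive_fft_conv(K, u):
    """Naive FFT convolution: three global transforms (FFT of the kernel,
    FFT of the input, one inverse FFT), materializing two length-L_fft
    frequency-domain intermediates."""
    L = len(u)
    L_fft = 2 * L - 1                # zero-pad to avoid circular aliasing
    K_f = np.fft.fft(K, n=L_fft)     # global transform 1 (kernel)
    u_f = np.fft.fft(u, n=L_fft)     # global transform 2 (input)
    y = np.fft.ifft(K_f * u_f).real  # pointwise product + global transform 3
    return y[:L]                     # causal part of the linear convolution

u = np.sin(np.arange(512) / 5.0)
K = 0.9 ** np.arange(512.0)          # an exponentially decaying example kernel
assert np.allclose(naive_fft_conv(K, u), np.convolve(u, K)[:512])
```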
1.2 The FlashFFTConv Architecture
1.2.1 Blocked FFT Fused with Tensor Cores
FlashFFTConv's core innovation is recasting the one-dimensional FFT as two-dimensional matrix operations that run on modern GPUs' Tensor Cores. The algorithm splits the sequence into B blocks of length T each, with L = B \times T. Via the Cooley-Tukey decomposition, the one-dimensional DFT can be written as:
Y_{k_1, k_2} = \sum_{n_2=0}^{T-1} \omega_L^{k_1 n_2} \omega_T^{k_2 n_2} \sum_{n_1=0}^{B-1} x_{n_1, n_2} \omega_B^{k_1 n_1}
where \omega_N = e^{-2\pi i / N}. The inner and outer sums each map to a matrix multiplication, executed by Tensor Cores with an O(\sqrt{L}) SRAM footprint. This decomposition turns the global FFT's O(L \log L) HBM traffic into O(L) compute-bound work.
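Under the index conventions n = T n_1 + n_2 and k = k_1 + B k_2, the decomposition above can be checked numerically; in this NumPy sketch the two matrix products stand in for the GEMMs that a real kernel would hand to the Tensor Cores:

```python
import numpy as np

def fft_as_matmul(x, B, T):
    """Recast a length L = B*T 1-D DFT as two small matrix multiplies plus
    twiddle factors (Cooley-Tukey), matching the formula in the text.
    Index conventions: n = T*n1 + n2, output k = k1 + B*k2."""
    L = B * T
    X = x.reshape(B, T)  # X[n1, n2] = x[T*n1 + n2]
    # Inner sum: B-point DFT over n1 (a matrix multiply) -> Y[k1, n2]
    F_B = np.exp(-2j * np.pi * np.outer(np.arange(B), np.arange(B)) / B)
    Y = F_B @ X
    # Twiddle factors omega_L^{k1 n2}
    Y = Y * np.exp(-2j * np.pi * np.outer(np.arange(B), np.arange(T)) / L)
    # Outer sum: T-point DFT over n2 (a second matrix multiply) -> Y[k1, k2]
    F_T = np.exp(-2j * np.pi * np.outer(np.arange(T), np.arange(T)) / T)
    Y = Y @ F_T.T
    # Flatten with k = k1 + B*k2
    return Y.T.reshape(-1)

x = np.random.randn(4 * 8)
assert np.allclose(fft_as_matmul(x, 4, 8), np.fft.fft(x))
```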
1.2.2 Kernel Decomposition and State-Expansion Optimization
The S4 convolution kernel \bar{K} has special structure: \bar{K}_k = C\bar{A}^k \bar{B}. FlashFFTConv exploits this structure to avoid materializing the full kernel vector. A partial-fraction decomposition expresses the kernel as a sum of P low-order rational terms:
\bar{K}(z) = \sum_{p=1}^P \frac{R_p}{z - \lambda_p}
where the \lambda_p are eigenvalues of \bar{A} and the R_p are residues. This rational form lets the convolution be computed directly as an IIR filter in O(NL), where N is the state dimension. When N \ll L, this avoids the O(L \log L) FFT cost and needs no O(L) kernel storage.
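For a diagonal \bar{A} (as in S4D-style parameterizations), the residues reduce to elementwise products R_p = C_p B_p. The following NumPy sketch (arbitrary test sizes and a fixed RNG seed, not part of any FlashFFTConv API) checks that the residue form reproduces C\bar{A}^k\bar{B} exactly while storing only O(N) parameters:

```python
import numpy as np

rng = np.random.default_rng(0)
N, L = 8, 64  # state dimension and kernel length (arbitrary test sizes)

# Diagonal A_bar: poles strictly inside the unit circle for stability
lam = 0.9 * np.exp(2j * np.pi * rng.random(N))
Bv = rng.standard_normal(N) + 0j
Cv = rng.standard_normal(N) + 0j

# In the diagonal case the residues are elementwise products: R_p = C_p * B_p
R = Cv * Bv

# Materialized kernel K_k = C A_bar^k B_bar: needs O(L) storage
K_direct = np.array([(Cv * lam ** k * Bv).sum() for k in range(L)])

# Residue form: rebuilt on the fly from (lambda_p, R_p), O(N) parameters stored
k = np.arange(L)
K_resid = (R[:, None] * lam[:, None] ** k[None, :]).sum(axis=0)

assert np.allclose(K_direct, K_resid)
```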
1.3 Theoretical Memory-Time Tradeoff Analysis
1.3.1 The Quadratic Bottleneck of Transformers
The cost of Transformer self-attention is:
\text{Time} \propto \frac{2L^2 d}{W} + \frac{2L^2}{B_{HBM}}, \qquad \text{Memory} \propto 2L^2 + 4Ld
where W is Tensor Core throughput (FLOPS) and B_{HBM} is HBM bandwidth. Once L exceeds the critical value L_{crit} \approx \sqrt{M} (with M the SRAM capacity), the system enters the memory-wall regime, where compute units sit idle waiting for data.
1.3.2 Linear Scaling of SSMs
In convolutional mode, S4's theoretical complexity is:
\text{Time} \propto \frac{c_1 L \log L \cdot d}{W} + \frac{c_2 L \cdot d}{B_{HBM}}, \qquad \text{Memory} \propto c_3 L \cdot d + c_4 N \cdot d
The coefficient c_1 covers FFT compute and c_2 covers data movement. FlashFFTConv drives c_1 toward 1 via kernel decomposition and blocking, and drives c_2 toward 0 via operator fusion. State expansion adds O(N) extra memory, but N is typically a fixed constant (e.g., 64 or 256), independent of L.
1.3.3 Critical Points for Long Sequences
Across sequence lengths L \in \{16K, 32K, 128K\}, the tradeoff curves of the three architectures diverge sharply:
- Transformer: hits the memory wall by 16K; time rises steeply as L^2 and memory approaches the hardware ceiling (e.g., an 80GB A100).
- S4 with naive FFT: becomes bandwidth-bound around 32K; the FFT's O(L \log L) compute is heavy, but memory grows linearly at O(L), permitting 128K sequences.
- S4 with FlashFFTConv: thanks to Tensor Core fusion and kernel decomposition, stays compute-bound at all tested lengths, with O(Nd + Ld) memory and near-linear time growth.
2 Structured Pseudocode
1. FlashFFTConv blocked convolution
This algorithm splits a long sequence into smaller blocks, accelerates the FFT with Tensor Cores, and convolves in the frequency domain using the offline-decomposed kernel.
Procedure: FlashFFTConv(u, \bar{A}, \bar{B}, C, T_{block})
- Initialization and preprocessing
  - L \leftarrow \text{length}(u)
  - N \leftarrow \text{dim}(\bar{A})
  - B \leftarrow \lceil L / T_{block} \rceil
- Stage 1: kernel decomposition (offline/precomputed)
  - \{\lambda_p, R_p\}_{p=1}^N \leftarrow \text{EigenDecompose}(\bar{A}, C, \bar{B})
- Stage 2: blocked frequency-domain computation
  - y \leftarrow \text{zeros}(L)
  - For b \leftarrow 0 to B-1 do:
    - u_b \leftarrow u[b \cdot T_{block} : \min((b+1) \cdot T_{block}, L)]
    - u_b \leftarrow \text{PadToPowerOfTwo}(u_b)
    - U_{freq} \leftarrow \text{TensorCoreFFT}(u_b)
    - K_{freq} \leftarrow \text{ComputeKernelSpectrum}(\{\lambda_p, R_p\}, \text{length}(u_b))  (kernel frequency response via residues)
    - Y_{freq} \leftarrow U_{freq} \odot K_{freq}
    - y_b \leftarrow \text{TensorCoreIFFT}(Y_{freq})
    - y[b \cdot T_{block} :] \leftarrow y[b \cdot T_{block} :] + y_b  (overlap-add)
  - End For
- Return y
2. Memory-optimized state-expansion computation
This procedure optimizes the memory footprint of state expansion in a state-space model (SSM), using fast SRAM buffering to cut global memory (HBM) reads and writes.
Procedure: StateExpansion(u, \bar{A}, \bar{B}, N_{expand})
- Initialization
  - L \leftarrow \text{length}(u)
  - x \leftarrow \text{zeros}(N_{expand})
  - SRAM_{buffer} \leftarrow \text{allocate}(N_{expand} \times T_{chunk})
  - y \leftarrow \text{zeros}(L)
- Chunked iteration
  - For t \leftarrow 0 to L-1 step T_{chunk} do:
    - T_{end} \leftarrow \min(t + T_{chunk}, L)
    - Inner loop: compute within SRAM
      - For \tau \leftarrow t to T_{end}-1 do:
        - x \leftarrow \bar{A} \cdot x + \bar{B} \cdot u[\tau]
        - SRAM_{buffer}[:, \tau - t] \leftarrow x
      - End For
    - Global output projection
      - y[t : T_{end}] \leftarrow C \cdot SRAM_{buffer}[:, : T_{end}-t] + D \cdot u[t : T_{end}]
  - End For
- Return y
3. Memory estimator for Transformer vs. SSM
This routine compares the memory requirements of a Transformer (attention) and a FlashFFTConv-based SSM across sequence lengths.
Procedure: MemoryEstimator(L, d, N_{ssm}, B_{batch})
- Transformer analysis
  - M_{attn} \leftarrow 2 \cdot B_{batch} \cdot L^2  (attention matrix grows quadratically in sequence length)
  - M_{trans} \leftarrow M_{attn} + 4 \cdot B_{batch} \cdot L \cdot d  (total memory)
- S4-FlashFFTConv analysis
  - M_{kernel} \leftarrow N_{ssm} \cdot d  (residue storage)
  - M_{fftbuf} \leftarrow 2 \cdot T_{block} \cdot B_{batch}  (blocked FFT buffer)
  - M_{ssm} \leftarrow M_{kernel} + M_{fftbuf} + 2 \cdot B_{batch} \cdot L \cdot d
- Break-even point
  - L_{crit} \leftarrow \sqrt{\frac{M_{ssm} - 4Ld}{2B_{batch}}}  (sequence length at which the two architectures consume equal memory)
- Return M_{trans}, M_{ssm}, L_{crit}
3 Code Implementations
Script 1: Theoretical memory-time tradeoff analysis
This script performs a theory-level complexity analysis, generating memory-time tradeoff curves for Mamba/Transformer at 16K/32K/128K sequence lengths. Run with: python theoretical_analysis.py
Python
"""
脚本1:理论内存-时间权衡分析
内容:分析Transformer与S4(Mamba)在长序列下的理论复杂度与硬件限制
使用方式:直接运行 python theoretical_analysis.py
"""
import numpy as np
import matplotlib.pyplot as plt
from dataclasses import dataclass
from typing import Tuple, List
@dataclass
class HardwareSpecs:
"""GPU硬件规格(以A100 80GB为例)"""
hbm_capacity_gb: float = 80.0
hbm_bandwidth_gb_s: float = 2039.0 # GB/s
tensor_core_tflops: float = 312.0 # TFLOPS (FP16)
sram_capacity_mb: float = 40.0 # 共享内存/SRAM每SM
def get_critical_length(self, d: int, bytes_per_elem: int = 2) -> float:
"""
计算SRAM能容纳的最大序列长度
假设batch=1,仅存储注意力矩阵
"""
sram_bytes = self.sram_capacity_mb * 1024 * 1024
# 注意力矩阵大小: L^2 * bytes_per_elem
max_L = np.sqrt(sram_bytes / bytes_per_elem)
return max_L
@dataclass
class ModelConfig:
"""模型配置参数"""
d_model: int = 768
n_layers: int = 12
n_ssm_states: int = 64 # S4状态维度N
batch_size: int = 1
class TheoreticalAnalyzer:
"""
理论分析器:计算不同架构的理论内存占用与计算时间
"""
def __init__(self, hw: HardwareSpecs, config: ModelConfig):
self.hw = hw
self.cfg = config
def transformer_memory(self, L: int) -> Tuple[float, float]:
"""
计算Transformer的显存占用(单位:GB)
组件:
1. Attention矩阵: 2 * L^2 (QK^T保存与梯度)
2. 激活值: 4 * L * d * n_layers (Q,K,V,O每层)
3. 参数: 忽略(相对激活可忽略)
"""
bytes_per_elem = 2 # FP16
# Attention矩阵 (activations + gradients)
attn_bytes = 2 * (L ** 2) * bytes_per_elem * self.cfg.batch_size
# 每层激活: Q,K,V,O = 4 * L * d
# 通常需要保存用于反向传播
activations_bytes = (4 * L * self.cfg.d_model * self.cfg.n_layers *
bytes_per_elem * self.cfg.batch_size)
total_bytes = attn_bytes + activations_bytes
total_gb = total_bytes / (1024**3)
# HBM带宽限制下的加载时间估计(秒)
# 假设每次前向+反向需要3次全量读写
io_time = 3 * total_gb / self.hw.hbm_bandwidth_gb_s
return total_gb, io_time
def s4_flashfftconv_memory(self, L: int) -> Tuple[float, float]:
"""
计算S4+FlashFFTConv的显存占用(单位:GB)
组件:
1. 卷积核参数: N * d (残差存储,与L无关)
2. 分块FFT缓冲区: 2 * T_block * batch (SRAM复用)
3. 序列激活: 2 * L * d (输入+输出)
4. 状态缓存: N * d (递归状态)
"""
bytes_per_elem = 2
# 卷积核分解后的参数(常数大小)
kernel_bytes = (self.cfg.n_ssm_states * self.cfg.d_model * bytes_per_elem)
# 分块FFT缓冲区(最大块大小,假设使用8K分块)
T_block = min(8192, L)
fft_buffer_bytes = 2 * T_block * self.cfg.batch_size * bytes_per_elem * 2 # 复数
# 序列激活
seq_bytes = 2 * L * self.cfg.d_model * bytes_per_elem * self.cfg.batch_size
# 状态缓存(双重缓冲)
state_bytes = 2 * self.cfg.n_ssm_states * self.cfg.d_model * bytes_per_elem
total_bytes = kernel_bytes + fft_buffer_bytes + seq_bytes + state_bytes
total_gb = total_bytes / (1024**3)
# 计算时间估计(秒)
# FFT复杂度: O(L log L) * d / TensorCore_TFLOPS
# 加上内存传输时间
flop_count = L * np.log2(L) * self.cfg.d_model * 5 # 5 FLOP per complex op
compute_time = flop_count / (self.hw.tensor_core_tflops * 1e12)
# 内存传输(假设分块加载,带宽受限部分)
io_time = total_gb / self.hw.hbm_bandwidth_gb_s
total_time = max(compute_time, io_time) # 重叠计算与通信
return total_gb, total_time
def naive_ssm_memory(self, L: int) -> Tuple[float, float]:
"""
朴素S4实现(全局FFT,物化卷积核)的内存占用
"""
bytes_per_elem = 2
# 物化卷积核: L * d
kernel_bytes = L * self.cfg.d_model * bytes_per_elem
# FFT缓冲区: 2 * L (零填充后)
fft_buffer = 2 * L * 2 * bytes_per_elem # 复数
# 序列激活
seq_bytes = 2 * L * self.cfg.d_model * bytes_per_elem
total_bytes = kernel_bytes + fft_buffer + seq_bytes
total_gb = total_bytes / (1024**3)
# 朴素FFT需要3次全局FFT(核FFT、输入FFT、逆FFT)
io_time = 3 * total_gb / self.hw.hbm_bandwidth_gb_s
return total_gb, io_time
def generate_tradeoff_curves(self, lengths: np.ndarray) -> dict:
"""
生成内存-时间权衡曲线数据
"""
results = {
'transformer': {'memory': [], 'time': [], 'feasible': []},
's4_flashfftconv': {'memory': [], 'time': [], 'feasible': []},
's4_naive': {'memory': [], 'time': [], 'feasible': []}
}
max_memory = self.hw.hbm_capacity_gb * 0.9 # 90%容量限制
for L in lengths:
# Transformer
mem_t, time_t = self.transformer_memory(int(L))
results['transformer']['memory'].append(mem_t)
results['transformer']['time'].append(time_t)
results['transformer']['feasible'].append(mem_t < max_memory)
# S4 FlashFFTConv
mem_s4, time_s4 = self.s4_flashfftconv_memory(int(L))
results['s4_flashfftconv']['memory'].append(mem_s4)
results['s4_flashfftconv']['time'].append(time_s4)
results['s4_flashfftconv']['feasible'].append(mem_s4 < max_memory)
# S4 Naive
mem_naive, time_naive = self.naive_ssm_memory(int(L))
results['s4_naive']['memory'].append(mem_naive)
results['s4_naive']['time'].append(time_naive)
results['s4_naive']['feasible'].append(mem_naive < max_memory)
return results
def visualize_tradeoffs(self, lengths: np.ndarray, results: dict):
"""
可视化内存-时间权衡曲线与可行性边界
"""
fig = plt.figure(figsize=(16, 10))
gs = fig.add_gridspec(2, 3, hspace=0.3, wspace=0.3)
# 1. 内存随序列长度增长
ax1 = fig.add_subplot(gs[0, 0])
ax1.plot(lengths/1024, results['transformer']['memory'], 'r-o',
label='Transformer', linewidth=2, markersize=6)
ax1.plot(lengths/1024, results['s4_flashfftconv']['memory'], 'b-s',
label='S4-FlashFFTConv', linewidth=2, markersize=6)
ax1.plot(lengths/1024, results['s4_naive']['memory'], 'g--^',
label='S4-NaiveFFT', linewidth=2, markersize=6)
ax1.axhline(y=self.hw.hbm_capacity_gb, color='k', linestyle='--',
label='HBM Capacity (80GB)')
ax1.set_xlabel('Sequence Length (K tokens)')
ax1.set_ylabel('Memory Usage (GB)')
ax1.set_title('Memory Scaling Comparison')
ax1.set_yscale('log')
ax1.legend()
ax1.grid(True, alpha=0.3)
# 2. 计算时间随序列长度增长
ax2 = fig.add_subplot(gs[0, 1])
ax2.plot(lengths/1024, results['transformer']['time'], 'r-o',
label='Transformer', linewidth=2, markersize=6)
ax2.plot(lengths/1024, results['s4_flashfftconv']['time'], 'b-s',
label='S4-FlashFFTConv', linewidth=2, markersize=6)
ax2.plot(lengths/1024, results['s4_naive']['time'], 'g--^',
label='S4-NaiveFFT', linewidth=2, markersize=6)
ax2.set_xlabel('Sequence Length (K tokens)')
ax2.set_ylabel('Estimated Time (seconds)')
ax2.set_title('Computational Time Scaling')
ax2.set_yscale('log')
ax2.legend()
ax2.grid(True, alpha=0.3)
# 3. 内存-时间权衡散点图(帕累托前沿)
ax3 = fig.add_subplot(gs[0, 2])
for model, color, marker in [('transformer', 'red', 'o'),
('s4_flashfftconv', 'blue', 's'),
('s4_naive', 'green', '^')]:
mem = np.array(results[model]['memory'])
time = np.array(results[model]['time'])
feasible = np.array(results[model]['feasible'])
# 可行点(实心)
ax3.scatter(mem[feasible], time[feasible], c=color, marker=marker,
s=100, alpha=0.7, label=f'{model.replace("_", "-")} (feasible)')
# 不可行点(空心)
if np.any(~feasible):
ax3.scatter(mem[~feasible], time[~feasible], facecolors='none',
edgecolors=color, marker=marker, s=100, alpha=0.5)
ax3.set_xlabel('Memory (GB)')
ax3.set_ylabel('Time (seconds)')
ax3.set_title('Memory-Time Tradeoff Curve')
ax3.set_xscale('log')
ax3.set_yscale('log')
ax3.legend()
ax3.grid(True, alpha=0.3)
# 4. 特定长度对比(16K, 32K, 128K)
ax4 = fig.add_subplot(gs[1, :])
target_lengths = [16*1024, 32*1024, 128*1024]
x_pos = np.arange(len(target_lengths))
width = 0.25
for i, (model, color, label) in enumerate([
('transformer', 'red', 'Transformer'),
('s4_flashfftconv', 'blue', 'S4-FlashFFTConv'),
('s4_naive', 'green', 'S4-Naive')
]):
mem_vals = []
for L in target_lengths:
idx = np.argmin(np.abs(lengths - L))
mem_vals.append(results[model]['memory'][idx])
bars = ax4.bar(x_pos + i*width, mem_vals, width, label=label,
color=color, alpha=0.7)
# 添加数值标签
for bar, val in zip(bars, mem_vals):
height = bar.get_height()
ax4.annotate(f'{val:.1f}G',
xy=(bar.get_x() + bar.get_width() / 2, height),
xytext=(0, 3), textcoords="offset points",
ha='center', va='bottom', fontsize=8, rotation=45)
ax4.set_xlabel('Sequence Length')
ax4.set_ylabel('Memory Usage (GB)')
ax4.set_title('Memory Footprint at Key Sequence Lengths (16K/32K/128K)')
ax4.set_xticks(x_pos + width)
ax4.set_xticklabels(['16K', '32K', '128K'])
ax4.legend()
ax4.axhline(y=self.hw.hbm_capacity_gb, color='k', linestyle='--', alpha=0.5)
ax4.grid(True, alpha=0.3, axis='y')
plt.tight_layout()
plt.savefig('memory_time_tradeoff.png', dpi=300, bbox_inches='tight')
plt.show()
# 打印关键数据表
print("\n" + "="*80)
print("关键长度理论分析结果")
print("="*80)
for L in target_lengths:
idx = np.argmin(np.abs(lengths - L))
print(f"\n序列长度: {L/1024:.0f}K")
print("-" * 40)
for model in ['transformer', 's4_flashfftconv', 's4_naive']:
mem = results[model]['memory'][idx]
time = results[model]['time'][idx]
feasible = "✓" if results[model]['feasible'][idx] else "✗ OOM"
print(f"{model:20s}: Memory={mem:6.2f}GB, Time={time:.4f}s {feasible}")
if __name__ == "__main__":
# 配置硬件与模型
hw = HardwareSpecs()
config = ModelConfig(d_model=768, n_layers=12, n_ssm_states=64, batch_size=1)
# 初始化分析器
analyzer = TheoreticalAnalyzer(hw, config)
# 生成序列长度范围(4K到256K)
lengths = np.logspace(np.log10(4096), np.log10(262144), 50).astype(int)
# 执行分析
print("正在执行理论分析...")
results = analyzer.generate_tradeoff_curves(lengths)
# 可视化
analyzer.visualize_tradeoffs(lengths, results)
# 输出临界点分析
critical_L = hw.get_critical_length(config.d_model)
print(f"\n硬件SRAM临界点(最大免HBM交换长度): {critical_L:.0f}")
print(f"A100 HBM容量允许Transformer最大长度(估计): "
f"{np.sqrt(hw.hbm_capacity_gb * 0.8 * 1024**3 / 4 / config.batch_size):.0f}")
Script 2: FlashFFTConv algorithm implementation
This script implements FlashFFTConv's core blocked FFT and kernel-decomposition algorithms, simulating the effect of Tensor Core optimization. Run with: python flashfftconv_impl.py
Python
"""
脚本2:FlashFFTConv算法实现
内容:实现分块FFT卷积与核函数分解优化,模拟IO-Aware计算
使用方式:直接运行 python flashfftconv_impl.py
"""
import numpy as np
import matplotlib.pyplot as plt
from scipy.linalg import eigvals, eig
import time
class FlashFFTConv:
"""
FlashFFTConv高效实现
包含核函数分解与分块FFT优化
"""
def __init__(self, A, B, C, D=0.0, block_size=8192):
"""
初始化FlashFFTConv
Args:
A, B, C, D: SSM参数
block_size: SRAM分块大小(模拟GPU共享内存限制)
"""
self.N = A.shape[0]
self.D = D
self.block_size = block_size
# 离散化(假设已离散化的A_bar, B_bar)
self.A = A
self.B = B.reshape(-1)
self.C = C.reshape(-1)
# 执行核函数分解(关键优化)
self._decompose_kernel()
def _decompose_kernel(self):
"""
核函数有理分解:将全局卷积核分解为残差形式
数学原理:
若 A 可对角化,A = V * diag(lambda) * V^{-1}
则 C A^k B = sum_{p} (C * v_p) * (w_p^T * B) * lambda_p^k
= sum_{p} R_p * lambda_p^k
其中 R_p 为残差,lambda_p 为极点
"""
# 特征分解(对于DPLR或Diagonal结构)
eigenvalues, eigenvectors = eig(self.A)
# 计算残差 R_p = (C * v_p) * (w_p^T * B)
# 对于简化情况,假设A为对角或接近对角
self.poles = eigenvalues # lambda_p
self.residuals = (self.C * eigenvectors).dot(np.linalg.inv(eigenvectors).dot(self.B))
# 仅保留显著残差(数值稳定性)
threshold = np.max(np.abs(self.residuals)) * 1e-6
mask = np.abs(self.residuals) > threshold
self.poles = self.poles[mask]
self.residuals = self.residuals[mask]
self.n_poles = len(self.poles)
def _compute_kernel_spectrum(self, L, dtype=complex):
"""
通过残差计算卷积核的频谱响应,避免物化时域核
K_freq[k] = sum_p R_p / (1 - lambda_p * exp(-2*pi*i*k/L))
"""
freqs = np.fft.fftfreq(L) * 2 * np.pi * 1j
K_freq = np.zeros(L, dtype=dtype)
# 并行计算所有极点的贡献
for p in range(self.n_poles):
lam = self.poles[p]
R = self.residuals[p]
# 几何级数求和:sum_{k=0}^{L-1} lambda^k * exp(-2*pi*i*n*k/L)
if np.abs(lam) < 0.9999: # 稳定性检查
K_freq += R / (1 - lam * np.exp(-freqs))
return K_freq
def tensorcore_fft(self, x, simulate_fusion=True):
"""
模拟Tensor Core优化的FFT
实际实现会将FFT分解为矩阵乘法,利用Tensor Core加速
"""
# 这里使用标准FFT,但在实际硬件中会是:
# 1. 将一维DFT分解为二维矩阵运算
# 2. 使用WMMA (Warp Matrix Multiply Accumulate) 指令
# 3. 在SRAM内完成分块计算
if simulate_fusion:
# 模拟融合操作:减少HBM写入
x_fft = np.fft.fft(x)
return x_fft
else:
# 朴素实现(多次HBM往返)
x_copy = x.copy() # 模拟HBM写入
x_fft = np.fft.fft(x_copy)
result = x_fft.copy() # 模拟HBM读取
return result
def flash_convolution(self, u):
"""
FlashFFTConv主算法:分块卷积与核分解
策略:
1. 长序列分块处理,每块在SRAM内完成FFT
2. 核函数通过频域残差直接计算,避免存储L长度向量
3. 使用重叠相加法(Overlap-Add)处理长序列
"""
L = len(u)
if L <= self.block_size:
# 短序列:直接FFT(SRAM容纳)
return self._short_sequence_conv(u)
else:
# 长序列:分块处理
return self._long_sequence_conv(u)
def _short_sequence_conv(self, u):
"""
短序列卷积(单块处理)
"""
L = len(u)
# 核频谱(通过分解公式计算,不存储时域核)
K_freq = self._compute_kernel_spectrum(L)
# 输入FFT(Tensor Core优化)
u_freq = self.tensorcore_fft(u)
# 频域乘积(点积)
y_freq = u_freq * K_freq
# 逆FFT
y = np.fft.ifft(y_freq).real
# D项
y += self.D * u
return y
def _long_sequence_conv(self, u):
"""
长序列分块卷积(重叠相加法)
内存复杂度:O(block_size) 而非 O(L)
"""
L = len(u)
B = self.block_size
# 计算FFT长度(2的幂次,且满足线性卷积需求)
L_fft = 2 * B
# 核频谱(分块大小)
K_freq_block = self._compute_kernel_spectrum(L_fft)
y = np.zeros(L)
# 分块处理
num_blocks = (L + B - 1) // B
for b in range(num_blocks):
start = b * B
end = min(start + B, L)
block_len = end - start
# 提取块并零填充至FFT长度
u_block = np.zeros(L_fft)
u_block[:block_len] = u[start:end]
# SRAM内计算
u_freq = self.tensorcore_fft(u_block)
y_freq = u_freq * K_freq_block
y_block = np.fft.ifft(y_freq).real
# 重叠相加
overlap_start = start
overlap_end = min(start + L_fft, L)
add_len = overlap_end - overlap_start
y[overlap_start:overlap_end] += y_block[:add_len]
# D项(仅有效数据段)
y += self.D * u
return y
def naive_convolution(self, u):
"""
朴素卷积(物化全局核,全局FFT)
用于对比验证正确性与内存占用
"""
L = len(u)
# 物化卷积核(内存密集)
K = np.zeros(L)
power = np.eye(self.N)
for k in range(L):
K[k] = self.C @ power @ self.B
power = power @ self.A
# 全局FFT(HBM密集)
L_fft = 2 * L - 1
K_padded = np.zeros(L_fft)
K_padded[:L] = K
u_padded = np.zeros(L_fft)
u_padded[:L] = u
K_freq = np.fft.fft(K_padded)
u_freq = np.fft.fft(u_padded)
y_full = np.fft.ifft(K_freq * u_freq).real
return y_full[:L] + self.D * u
def benchmark_memory_access(self, lengths):
"""
对比FlashFFTConv与朴素实现的内存访问量
"""
flash_hbm_access = []
naive_hbm_access = []
bytes_per_elem = 4 # float32
for L in lengths:
# FlashFFTConv: 分块加载,每块2*B大小,共L/B块
B = min(self.block_size, L)
num_blocks = (L + B - 1) // B
# 每次块处理:加载B,存储B(假设融合)
flash_access = num_blocks * (2 * B * bytes_per_elem)
# 朴素实现:物化核L + 零填充2L + 输入L + 输出2L
naive_access = (L + 2*L + L + 2*L) * bytes_per_elem
flash_hbm_access.append(flash_access)
naive_hbm_access.append(naive_access)
return np.array(flash_hbm_access), np.array(naive_hbm_access)
def visualize_kernel_decomposition(self):
"""
可视化核函数分解的效果
"""
fig, axes = plt.subplots(2, 2, figsize=(12, 10))
L_show = 512
# 1. 精确核(通过矩阵幂计算)
K_exact = np.zeros(L_show)
power = np.eye(self.N)
for k in range(L_show):
K_exact[k] = self.C @ power @ self.B
power = power @ self.A
# 2. 分解近似核(通过残差重建)
t = np.arange(L_show)
K_approx = np.zeros(L_show, dtype=complex)
for p in range(self.n_poles):
lam = self.poles[p]
R = self.residuals[p]
K_approx += R * (lam ** t)
K_approx = K_approx.real
# 3. 频域对比
K_exact_freq = np.fft.fft(K_exact)
K_approx_freq = np.fft.fft(K_approx)
# 绘图
axes[0,0].plot(K_exact[:100], 'b-', label='Exact Kernel', linewidth=2)
axes[0,0].plot(K_approx[:100].real, 'r--', label='Decomposed Approx', linewidth=2)
axes[0,0].set_title('Time Domain Kernel (First 100 steps)')
axes[0,0].set_xlabel('Lag $k$')
axes[0,0].set_ylabel('Amplitude')
axes[0,0].legend()
axes[0,0].grid(True, alpha=0.3)
axes[0,1].plot(np.abs(K_exact_freq[:50]), 'b-', label='Exact Spectrum', linewidth=2)
axes[0,1].plot(np.abs(K_approx_freq[:50]), 'r--', label='Approx Spectrum', linewidth=2)
axes[0,1].set_title('Frequency Domain Magnitude')
axes[0,1].set_xlabel('Frequency Index')
axes[0,1].set_ylabel('Magnitude')
axes[0,1].legend()
axes[0,1].grid(True, alpha=0.3)
# 误差分析
error = np.abs(K_exact - K_approx.real)
axes[1,0].semilogy(error[:100], 'g-', linewidth=2)
axes[1,0].set_title('Approximation Error (Log Scale)')
axes[1,0].set_xlabel('Lag $k$')
axes[1,0].set_ylabel('Absolute Error')
axes[1,0].grid(True, alpha=0.3)
# 极点分布(复平面)
axes[1,1].scatter(self.poles.real, self.poles.imag, c='red', s=100, alpha=0.6)
axes[1,1].add_patch(plt.Circle((0,0), 1, fill=False, color='k', linestyle='--'))
axes[1,1].set_title(f'Pole Distribution ({self.n_poles} poles)')
axes[1,1].set_xlabel('Real')
axes[1,1].set_ylabel('Imaginary')
axes[1,1].grid(True, alpha=0.3)
axes[1,1].axis('equal')
plt.tight_layout()
plt.savefig('kernel_decomposition.png', dpi=300)
plt.show()
# 输出分解效率
compression_ratio = L_show / (2 * self.n_poles) # 存储2*N个参数vs L个
print(f"核函数分解压缩比: {compression_ratio:.1f}x (存储 {2*self.n_poles} vs {L_show} 参数)")
if __name__ == "__main__":
# 构造测试系统(对角化友好)
N = 16
# 构造具有良好极点分布的A矩阵
poles_real = -np.linspace(0.1, 1.0, N)
poles_imag = np.linspace(-0.5, 0.5, N)
A = np.diag(poles_real + 1j * poles_imag).real
# 添加轻微耦合使其非纯对角(更真实)
if N > 1:
A += np.diag(np.ones(N-1) * 0.05, k=1)
B = np.ones(N) / np.sqrt(N)
C = np.ones(N) / np.sqrt(N)
D = 0.01
# 初始化FlashFFTConv
flashconv = FlashFFTConv(A, B, C, D, block_size=2048)
# 可视化分解效果
flashconv.visualize_kernel_decomposition()
# 测试不同长度序列
test_lengths = [512, 2048, 8192, 16384, 32768]
print("\n" + "="*60)
print("FlashFFTConv性能验证")
print("="*60)
for L in test_lengths:
u = np.random.randn(L)
# FlashFFTConv
start = time.time()
y_flash = flashconv.flash_convolution(u)
time_flash = time.time() - start
# 朴素实现(仅短序列)
if L <= 8192:
start = time.time()
y_naive = flashconv.naive_convolution(u)
time_naive = time.time() - start
error = np.max(np.abs(y_flash - y_naive))
print(f"L={L:6d}: Flash={time_flash:.4f}s, Naive={time_naive:.4f}s, "
f"Error={error:.2e}, Speedup={time_naive/time_flash:.1f}x")
else:
print(f"L={L:6d}: Flash={time_flash:.4f}s, Naive=Skipped (长序列)")
# 内存访问分析
lengths = np.array([4096, 8192, 16384, 32768, 65536, 131072])
flash_mem, naive_mem = flashconv.benchmark_memory_access(lengths)
fig, ax = plt.subplots(figsize=(10, 6))
ax.plot(lengths/1024, flash_mem/(1024**2), 'b-o', label='FlashFFTConv', linewidth=2, markersize=8)
ax.plot(lengths/1024, naive_mem/(1024**2), 'r-s', label='Naive FFT', linewidth=2, markersize=8)
ax.set_xlabel('Sequence Length (K tokens)')
ax.set_ylabel('HBM Access Volume (MB)')
ax.set_title('Memory Bandwidth Usage Comparison')
ax.set_yscale('log')
ax.legend()
ax.grid(True, alpha=0.3)
plt.tight_layout()
plt.savefig('memory_access_comparison.png', dpi=300)
plt.show()
print("\n分析完成。关键优化:")
print(f"- 分块大小: {flashconv.block_size}")
print(f"- 极点半径范围: [{np.min(np.abs(flashconv.poles)):.4f}, {np.max(np.abs(flashconv.poles)):.4f}]")
print(f"- 最大内存节省: {np.max(naive_mem/flash_mem):.1f}x (at 128K)")
Script 3: Long-sequence experiments and tradeoff curves
This script runs a (simulated) head-to-head of Mamba (SSM) vs. Transformer at 16K/32K/128K lengths and generates the full memory-time tradeoff curves. Run with: python long_sequence_experiment.py
Python
"""
脚本3:长序列实验对比与权衡曲线生成
内容:模拟Mamba与Transformer在16K/32K/128K序列下的内存-时间权衡
使用方式:直接运行 python long_sequence_experiment.py
"""
import numpy as np
import matplotlib.pyplot as plt
from dataclasses import dataclass
import time
@dataclass
class ExperimentConfig:
"""实验配置"""
d_model: int = 768
n_heads: int = 12
d_head: int = 64
n_layers: int = 12
n_ssm_states: int = 64
batch_size: int = 1
class TransformerSimulator:
"""
Transformer计算与内存模拟器
模拟FlashAttention优化后的Transformer
"""
def __init__(self, config: ExperimentConfig):
self.cfg = config
def compute_memory(self, L: int) -> float:
"""
计算显存占用(GB)
包含:
- FlashAttention不需要存储L^2矩阵,但需要KV Cache
- 激活值:2 * L * d * n_layers (简化估计)
"""
bytes_per_elem = 2 # FP16
# KV Cache(推理时)或激活值(训练时)
# FlashAttention采用分块重计算,激活值约为O(L*d)而非O(L^2)
kv_cache = 2 * L * self.cfg.d_model * bytes_per_elem # K,V各一份
# 注意力统计量(online softmax的m和l)
stats = 2 * L * self.cfg.n_heads * 4 # float32统计量
# 模型参数与梯度(与L无关,常数)
params = 12 * self.cfg.d_model**2 * self.cfg.n_layers * bytes_per_elem
total = (kv_cache + stats + params) * self.cfg.batch_size / (1024**3)
return total
def compute_time(self, L: int) -> float:
"""
计算时间(秒)
时间 = 计算时间 + 内存时间(受限于两者较大值)
"""
# FLOPs: 2 * L^2 * d (注意力) + 其他线性层
flops = 2 * (L ** 2) * self.cfg.d_model + 6 * L * self.cfg.d_model * self.cfg.n_layers
compute_time = flops / (312e12) # A100 Tensor Core 312 TFLOPS
# 内存带宽限制(假设需要读写KV cache)
memory_gb = self.compute_memory(L)
memory_time = memory_gb / 2039 # A100 HBM带宽
# 实际时间为两者较大值(掩盖或带宽限制)
return max(compute_time, memory_time) * 1.5 # 系数1.5模拟框架开销
class MambaSimulator:
"""
Mamba/SSM计算与内存模拟器
模拟FlashFFTConv优化后的S4
"""
def __init__(self, config: ExperimentConfig, optimized: bool = True):
self.cfg = config
self.optimized = optimized # True: FlashFFTConv, False: Naive
def compute_memory(self, L: int) -> float:
"""
计算显存占用(GB)
"""
bytes_per_elem = 2
if self.optimized:
# FlashFFTConv: 核函数分解存储,不物化L长度向量
# 仅需存储:残差参数(N*d) + 分块缓冲区(block_size) + 输入输出(2*L*d)
kernel_params = self.cfg.n_ssm_states * self.cfg.d_model * bytes_per_elem
block_buffer = 2 * 8192 * bytes_per_elem # 8K分块
seq_activations = 2 * L * self.cfg.d_model * bytes_per_elem
total = (kernel_params + block_buffer + seq_activations) / (1024**3)
else:
# 朴素实现:物化卷积核 L*d + FFT缓冲区 2L*d
kernel_materialized = L * self.cfg.d_model * bytes_per_elem
fft_buffer = 2 * L * bytes_per_elem # 复数缓冲区
seq_activations = 2 * L * self.cfg.d_model * bytes_per_elem
total = (kernel_materialized + fft_buffer + seq_activations) / (1024**3)
return total * self.cfg.batch_size
def compute_time(self, L: int) -> float:
"""
计算时间(秒)
"""
if self.optimized:
# FlashFFTConv: O(L log L) 但常数极小,接近线性
# 利用Tensor Core与分块优化
flops = L * np.log2(L) * self.cfg.d_model * 10 # FFT常数小
compute_time = flops / (312e12)
# 内存带宽:分块加载,带宽需求低
memory_gb = self.compute_memory(L)
memory_time = memory_gb / 2039
# FlashFFTConv计算密集,通常compute > memory
return compute_time * 1.2 + memory_time * 0.1
else:
# 朴素FFT: 3次全局FFT,内存带宽受限
flops = L * np.log2(L) * self.cfg.d_model * 10
compute_time = flops / (312e12)
memory_gb = self.compute_memory(L) * 3 # 3次全局读写
memory_time = memory_gb / 2039
return max(compute_time, memory_time) * 2.0
def run_experiment():
"""
执行完整实验:生成16K, 32K, 128K下的权衡曲线
"""
config = ExperimentConfig(d_model=768, n_layers=12, n_ssm_states=64)
trans_sim = TransformerSimulator(config)
mamba_opt_sim = MambaSimulator(config, optimized=True)
mamba_naive_sim = MambaSimulator(config, optimized=False)
# 测试长度
target_lengths = [16*1024, 32*1024, 128*1024]
extended_lengths = np.linspace(4096, 150000, 100)
# 收集数据
results = {
'lengths': np.array(target_lengths),
'transformer': {'mem': [], 'time': []},
'mamba_flash': {'mem': [], 'time': []},
'mamba_naive': {'mem': [], 'time': []}
}
# 目标长度精确计算
for L in target_lengths:
results['transformer']['mem'].append(trans_sim.compute_memory(L))
results['transformer']['time'].append(trans_sim.compute_time(L))
results['mamba_flash']['mem'].append(mamba_opt_sim.compute_memory(L))
results['mamba_flash']['time'].append(mamba_opt_sim.compute_time(L))
results['mamba_naive']['mem'].append(mamba_naive_sim.compute_memory(L))
results['mamba_naive']['time'].append(mamba_naive_sim.compute_time(L))
# 扩展曲线(用于绘制趋势线)
trend_data = {
'trans_mem': [trans_sim.compute_memory(l) for l in extended_lengths],
'trans_time': [trans_sim.compute_time(l) for l in extended_lengths],
'mamba_flash_mem': [mamba_opt_sim.compute_memory(l) for l in extended_lengths],
'mamba_flash_time': [mamba_opt_sim.compute_time(l) for l in extended_lengths],
}
return results, trend_data, extended_lengths
def visualize_experiment_results(results, trend_data, extended_lengths):
"""
生成论文级可视化图表
"""
fig = plt.figure(figsize=(16, 12))
gs = fig.add_gridspec(3, 3, hspace=0.35, wspace=0.35)
colors = {
'transformer': '#d62728', # 红色
'mamba_flash': '#1f77b4', # 蓝色
'mamba_naive': '#2ca02c' # 绿色
}
labels = {
'transformer': 'Transformer (FlashAttention)',
'mamba_flash': 'Mamba (FlashFFTConv)',
'mamba_naive': 'Mamba (Naive FFT)'
}
# 1. 内存增长曲线(对数坐标)
ax1 = fig.add_subplot(gs[0, :2])
ax1.plot(extended_lengths/1024, trend_data['trans_mem'], '--',
color=colors['transformer'], alpha=0.5, linewidth=2)
ax1.plot(extended_lengths/1024, trend_data['mamba_flash_mem'], '--',
color=colors['mamba_flash'], alpha=0.5, linewidth=2)
# 标记关键长度点
for i, L in enumerate(results['lengths']):
L_k = L/1024
ax1.scatter([L_k], [results['transformer']['mem'][i]],
s=200, c=colors['transformer'], marker='o', zorder=5, edgecolors='black')
ax1.scatter([L_k], [results['mamba_flash']['mem'][i]],
s=200, c=colors['mamba_flash'], marker='s', zorder=5, edgecolors='black')
# 添加80GB线
ax1.axhline(y=80, color='k', linestyle='--', alpha=0.5, label='A100 80GB Limit')
ax1.fill_between(extended_lengths/1024, 80, 200, alpha=0.1, color='red', label='OOM Zone')
ax1.set_xlabel('Sequence Length (K tokens)')
ax1.set_ylabel('Memory Usage (GB)')
ax1.set_title('Memory Scaling: Theoretical Limits', fontsize=14, fontweight='bold')
ax1.legend()
ax1.grid(True, alpha=0.3)
ax1.set_ylim(0, 100)
# 2. 时间增长曲线
ax2 = fig.add_subplot(gs[0, 2])
ax2.plot(extended_lengths/1024, trend_data['trans_time'], '--',
color=colors['transformer'], alpha=0.5, linewidth=2)
ax2.plot(extended_lengths/1024, trend_data['mamba_flash_time'], '--',
color=colors['mamba_flash'], alpha=0.5, linewidth=2)
for i, L in enumerate(results['lengths']):
L_k = L/1024
ax2.scatter([L_k], [results['transformer']['time'][i]],
s=150, c=colors['transformer'], marker='o', zorder=5)
ax2.scatter([L_k], [results['mamba_flash']['time'][i]],
s=150, c=colors['mamba_flash'], marker='s', zorder=5)
ax2.set_xlabel('Sequence Length (K tokens)')
ax2.set_ylabel('Time (seconds)')
ax2.set_title('Latency Comparison', fontsize=14, fontweight='bold')
ax2.set_yscale('log')
ax2.grid(True, alpha=0.3)
# 3. 关键长度对比柱状图(内存)
ax3 = fig.add_subplot(gs[1, 0])
x = np.arange(len(results['lengths']))
width = 0.25
bars1 = ax3.bar(x - width, results['transformer']['mem'], width,
label='Transformer', color=colors['transformer'], alpha=0.8)
bars2 = ax3.bar(x, results['mamba_flash']['mem'], width,
label='Mamba-Flash', color=colors['mamba_flash'], alpha=0.8)
bars3 = ax3.bar(x + width, results['mamba_naive']['mem'], width,
label='Mamba-Naive', color=colors['mamba_naive'], alpha=0.8)
ax3.axhline(y=80, color='k', linestyle='--', alpha=0.5)
ax3.set_xlabel('Sequence Length')
ax3.set_ylabel('Memory (GB)')
ax3.set_title('Memory Footprint at Target Lengths')
ax3.set_xticks(x)
ax3.set_xticklabels(['16K', '32K', '128K'])
ax3.legend()
ax3.grid(True, alpha=0.3, axis='y')
# 添加数值标签
for bars in [bars1, bars2, bars3]:
for bar in bars:
height = bar.get_height()
ax3.annotate(f'{height:.1f}',
xy=(bar.get_x() + bar.get_width() / 2, height),
xytext=(0, 3), textcoords="offset points",
ha='center', va='bottom', fontsize=8)
# 4. 关键长度对比柱状图(时间)
ax4 = fig.add_subplot(gs[1, 1])
bars1 = ax4.bar(x - width, results['transformer']['time'], width,
label='Transformer', color=colors['transformer'], alpha=0.8)
bars2 = ax4.bar(x, results['mamba_flash']['time'], width,
label='Mamba-Flash', color=colors['mamba_flash'], alpha=0.8)
bars3 = ax4.bar(x + width, results['mamba_naive']['time'], width,
label='Mamba-Naive', color=colors['mamba_naive'], alpha=0.8)
ax4.set_xlabel('Sequence Length')
ax4.set_ylabel('Time (seconds)')
ax4.set_title('Latency at Target Lengths')
ax4.set_xticks(x)
ax4.set_xticklabels(['16K', '32K', '128K'])
ax4.legend()
ax4.set_yscale('log')
ax4.grid(True, alpha=0.3, axis='y')
# 5. 内存-时间权衡散点图(帕累托前沿)
ax5 = fig.add_subplot(gs[1, 2])
# 生成多配置点以显示权衡曲线
configs = [(512, 8), (768, 12), (1024, 16), (2048, 24)] # (d_model, n_layers)
for d, l in configs:
cfg = ExperimentConfig(d_model=d, n_layers=l)
t_sim = TransformerSimulator(cfg)
m_sim = MambaSimulator(cfg, optimized=True)
for L in [16*1024, 32*1024, 128*1024]:
t_mem = t_sim.compute_memory(L)
t_time = t_sim.compute_time(L)
m_mem = m_sim.compute_memory(L)
m_time = m_sim.compute_time(L)
ax5.scatter(t_mem, t_time, c=colors['transformer'], s=100, alpha=0.6, marker='o')
ax5.scatter(m_mem, m_time, c=colors['mamba_flash'], s=100, alpha=0.6, marker='s')
ax5.set_xlabel('Memory (GB)')
ax5.set_ylabel('Time (seconds)')
ax5.set_title('Memory-Time Pareto Frontier\n(Multiple Configs)', fontsize=12)
ax5.set_xscale('log')
ax5.set_yscale('log')
ax5.grid(True, alpha=0.3)
# 添加图例说明
ax5.scatter([], [], c=colors['transformer'], s=100, marker='o', label='Transformer')
ax5.scatter([], [], c=colors['mamba_flash'], s=100, marker='s', label='Mamba')
ax5.legend()
# 6. 效率比率热图
ax6 = fig.add_subplot(gs[2, :])
# 创建效率比率矩阵(不同长度 vs 不同batch size)
batch_sizes = [1, 2, 4, 8, 16, 32]
lengths_grid = [16*1024, 32*1024, 64*1024, 128*1024]
efficiency_matrix = np.zeros((len(batch_sizes), len(lengths_grid)))
for i, bs in enumerate(batch_sizes):
cfg = ExperimentConfig(batch_size=bs)
t_sim = TransformerSimulator(cfg)
m_sim = MambaSimulator(cfg, optimized=True)
for j, L in enumerate(lengths_grid):
t_time = t_sim.compute_time(L)
m_time = m_sim.compute_time(L)
# 效率比 = Transformer时间 / Mamba时间
efficiency_matrix[i, j] = t_time / m_time if m_time > 0 else 0
im = ax6.imshow(efficiency_matrix, cmap='RdYlGn', aspect='auto', vmin=0, vmax=10)
ax6.set_xticks(range(len(lengths_grid)))
ax6.set_xticklabels(['16K', '32K', '64K', '128K'])
ax6.set_yticks(range(len(batch_sizes)))
ax6.set_yticklabels([str(bs) for bs in batch_sizes])
ax6.set_xlabel('Sequence Length')
ax6.set_ylabel('Batch Size')
ax6.set_title('Speedup Ratio (Transformer Time / Mamba Time)\nRed: Transformer Faster, Green: Mamba Faster',
fontsize=12)
# 添加数值标注
for i in range(len(batch_sizes)):
for j in range(len(lengths_grid)):
text = ax6.text(j, i, f'{efficiency_matrix[i, j]:.1f}x',
ha="center", va="center", color="black", fontsize=9)
plt.colorbar(im, ax=ax6, label='Speedup Factor')
plt.tight_layout()
plt.savefig('long_sequence_experiment.png', dpi=300, bbox_inches='tight')
plt.show()
# 打印实验报告
print("\n" + "="*80)
print("长序列实验结果报告 (16K/32K/128K)")
print("="*80)
print(f"{'Length':<10} {'Model':<20} {'Memory(GB)':<12} {'Time(s)':<10} {'Status'}")
print("-" * 80)
for i, L in enumerate(results['lengths']):
L_str = f"{int(L/1024)}K"
for model_key, model_name in [('transformer', 'Transformer'),
('mamba_flash', 'Mamba-Flash'),
('mamba_naive', 'Mamba-Naive')]:
mem = results[model_key]['mem'][i]
tim = results[model_key]['time'][i]
status = "✓ OK" if mem < 80 else "✗ OOM"
print(f"{L_str:<10} {model_name:<20} {mem:<12.2f} {tim:<10.4f} {status}")
print("-" * 80)
if __name__ == "__main__":
print("开始执行长序列对比实验...")
results, trend_data, extended_lengths = run_experiment()
visualize_experiment_results(results, trend_data, extended_lengths)
print("\n实验完成。所有可视化结果已保存至 long_sequence_experiment.png")