jetson上进行量化

要在jetson上量化,实际只要配置好torch环境就好了。配置好环境之后,直接把模型的原始Pth和量化pth一并放在jetson上,然后编程,加载pth文件到模型上,并且运行就好了。

但是,在jetson上安装torch请不要直接使用pip install安装pytorch.org官网上的版本,这会导致torch的部分功能不兼容,量化就不被兼容。

后果

  • 从 PyPI(pip install torch)安装的是通用 x86_64 版本
  • 这个版本默认不包含 ARM 架构的 QNNPACK 引擎,而Jetson就是ARM架构
  • PyPI 上的标准 PyTorch wheel 主要针对 x86 CPU(使用 FBGEMM)和 NVIDIA GPU(CUDA)

为什么会这样?

PyTorch 官方在 PyPI 上发布的 wheel 包为了减小体积和复杂度

  • x86_64 版本 → 只包含 FBGEMM(x86 优化)
  • ARM 版本 → 理论上应该包含 QNNPACK,但 PyPI 上可能没有完整的 ARM wheel

你的 Jetson 是 ARM64 架构,需要:

  1. NVIDIA 官方为 Jetson 编译的版本(包含 QNNPACK + CUDA for Tegra)
  2. 或自己编译启用 QNNPACK 的版本

毫无疑问,最优选肯定是卸载当前的torch,直接对应在Jetson上的torch。

首先卸载当前的torch。

bash 复制代码
# 1. 备份当前包列表
pip freeze > ./torch_backup_$(date +%Y%m%d).txt

# 2. 卸载现有 PyTorch
pip uninstall torch torchvision torchaudio -y

查看自己的jetson环境。

bash 复制代码
# 查看JetPack版本
cat /etc/nv_tegra_release

# 查看系统架构
uname -m

# 查看CUDA版本
nvcc --version

更新系统包

bash 复制代码
sudo apt update
sudo apt upgrade -y
sudo apt autoremove -y

安装必要依赖项

bash 复制代码
sudo apt install -y python3-pip libopenblas-base libopenmpi-dev libjpeg-dev zlib1g-dev

根据自己的环境,确定需要安装的torch包。

bash 复制代码
(venv) liuyang@ubuntu:~/pyq/jetson_deploy$ pip list | grep -i numpy
numpy             2.2.6
(venv) liuyang@ubuntu:~/pyq/jetson_deploy$ cat /etc/nv_tegra_release
# R36 (release), REVISION: 4.4, GCID: 41062509, BOARD: generic, EABI: aarch64, DATE: Mon Jun 16 16:07:13 UTC 2025
# KERNEL_VARIANT: oot
TARGET_USERSPACE_LIB_DIR=nvidia
TARGET_USERSPACE_LIB_DIR_PATH=usr/lib/aarch64-linux-gnu/nvidia
(venv) liuyang@ubuntu:~/pyq/jetson_deploy$ uname -m
aarch64
(venv) liuyang@ubuntu:~/pyq/jetson_deploy$ nvcc --version
nvcc: NVIDIA (R) Cuda compiler driver
Copyright (c) 2005-2024 NVIDIA Corporation
Built on Wed_Aug_14_10:14:07_PDT_2024
Cuda compilation tools, release 12.6, V12.6.68
Build cuda_12.6.r12.6/compiler.34714021_0

Jetson 配置

  • JetPack 版本: R36.4.4 (JetPack 6.x 系列)
  • CUDA 版本: 12.6
  • 架构: aarch64 (ARM64)
  • 发布日期: 2025年6月 (非常新!)

这个地方下载的版本要选择最新的,因为我的cuDNN是9,所以直接选择jp6.1最新的。

安装代码如下:

bash 复制代码
pip3 install --no-cache https://developer.download.nvidia.com/compute/redist/jp/v61/pytorch/torch-2.5.0a0+872d972e41.nv24.08.17622132-cp310-cp310-linux_aarch64.whl

此时大概率还cusparselt库,去 https://developer.nvidia.com/cusparselt-downloads 安装即可。

选择后,这个页面会自动给出指令。

bash 复制代码
wget https://developer.download.nvidia.com/compute/cusparselt/0.8.1/local_installers/cusparselt-local-tegra-repo-ubuntu2204-0.8.1_0.8.1-1_arm64.deb
sudo dpkg -i cusparselt-local-tegra-repo-ubuntu2204-0.8.1_0.8.1-1_arm64.deb
sudo cp /var/cusparselt-local-tegra-repo-ubuntu2204-0.8.1/cusparselt-*-keyring.gpg /usr/share/keyrings/
sudo apt-get update
sudo apt-get -y install cusparselt

这里我踩坑了。随便选了个jp6.0,torch2.4的版本,结果这玩意对应的cuDNN是8,而我安装的是cuDNN9,直接导致torch无法运行。

如下是我当时的错误。

NVIDIA 为 JetPack 6.x 提供了官方 PyTorch wheel,大模型建议我下载 JetPack 6.x。

bash 复制代码
wget https://developer.download.nvidia.com/compute/redist/jp/v60/pytorch/torch-2.4.0a0+3bcc3cddb5.nv24.07.16234504-cp310-cp310-linux_aarch64.whl

下载到jetson的文件夹上后,开始安装

bash 复制代码
pip install torch-2.4.0a0+3bcc3cddb5.nv24.07.16234504-cp310-cp310-linux_aarch64.whl

此时运行测试文件的报错如下。

bash 复制代码
(venv) liuyang@ubuntu:~/pyq/jetson_deploy$ python test_torch.py 
Traceback (most recent call last):
  File "/home/liuyang/pyq/jetson_deploy/test_torch.py", line 86, in <module>
    import torch
  File "/home/liuyang/pyq/venv/lib/python3.10/site-packages/torch/__init__.py", line 289, in <module>
    from torch._C import *  # noqa: F403
ImportError: libcudnn.so.8: cannot open shared object file: No such file or directory

这个地方我卡了几小时,sb大模型还建议我重装cuDNN,实际上卸载torch重装就好了。

编写测试文件。

python 复制代码
import torch
print("="*60)
print(f"PyTorch 版本: {torch.__version__}")
print(f"CUDA 可用: {torch.cuda.is_available()}")
print(f"CUDA 版本: {torch.version.cuda if torch.cuda.is_available() else 'N/A'}")
print(f"量化引擎: {torch.backends.quantized.supported_engines}")

# 测试设置引擎
if 'qnnpack' in torch.backends.quantized.supported_engines:
    torch.backends.quantized.engine = 'qnnpack'
    print(f"✅ 成功设置量化引擎为: {torch.backends.quantized.engine}")
else:
    print(f"❌ QNNPACK 不可用")
    
print("="*60)

此时输出结果。

bash 复制代码
============================================================
✅ PyTorch 版本: 2.5.0a0+872d972e41.nv24.08
✅ CUDA 可用: True
✅ CUDA 版本: 12.6
✅ cuDNN 版本: 90300
量化引擎: ['qnnpack', 'none']
✅ 成功设置量化引擎为: qnnpack
/home/liuyang/pyq/venv/lib/python3.10/site-packages/torch/ao/quantization/observer.py:1286: UserWarning: must run observer before calling calculate_qparams.                                    Returning default scale and zero point 
  warnings.warn(
✅ 量化操作测试成功!
============================================================

现在其实,还有一个numpy警告,但是我不想降级Numpy,因为已经有很多包是基于>2的numpy安装的,一旦卸载可能会出现新的错误。

接下来就可以在jetson上运行量化模型了。

python 复制代码
#!/usr/bin/env python3
"""
模型验证脚本 - 支持原始模型和量化模型

验证方式:
- 原始模型:从 candidate_xxx_full.pth 加载
- 量化模型:从 quantized_xxx.pth 加载(精确复现)

支持的量化模式:none, static, qat, qaft
"""
import sys
sys.path.insert(0, '/home/liuyang/pyq/jetson_deploy')

import torch
import torch.nn as nn
import numpy as np
from tqdm import tqdm
from data import get_data_loader, get_dataset_info
from models import CandidateModel
from models import fuse_model_modules, fuse_QATmodel_modules
from models import get_static_quantization_config

# ============ 添加这部分 ============
# 设置量化引擎(Jetson 是 ARM 架构,优先使用 qnnpack)
if torch.backends.quantized.supported_engines:
    if 'qnnpack' in torch.backends.quantized.supported_engines:
        torch.backends.quantized.engine = 'qnnpack'
        print(f"✅ 使用量化引擎: qnnpack (ARM 优化)")
    elif 'fbgemm' in torch.backends.quantized.supported_engines:
        torch.backends.quantized.engine = 'fbgemm'
        print(f"✅ 使用量化引擎: fbgemm (x86 优化)")
    else:
        print(f"⚠️ 支持的引擎: {torch.backends.quantized.supported_engines}")
else:
    raise RuntimeError("❌ 当前 PyTorch 不支持量化引擎")
# ===================================

# 固定随机种子
SEED = 42
torch.manual_seed(SEED)
torch.cuda.manual_seed(SEED)
np.random.seed(SEED)


def load_model(checkpoint_path):
    """
    加载模型(自动检测并处理 QAT/QAFT 权重)
    
    Returns:
        model, task_head, checkpoint
    """
    checkpoint = torch.load(checkpoint_path, map_location='cpu', weights_only=False)
    
    config = checkpoint['config']
    task_type = checkpoint['task_type']
    quant_mode = checkpoint.get('quant_mode', 'none')
    
    # 检查是否是 QAT/QAFT 权重
    model_keys = checkpoint['model'].keys()
    has_qat_keys = any('activation_post_process' in k or 'weight_fake_quant' in k 
                      for k in model_keys)
    
    # 重建模型
    candidate = CandidateModel(config=config, task_type=task_type)
    model = candidate.build_model()
    model.to('cpu')
    
    # 如果是 QAT/QAFT 权重,需要先准备对应结构
    if has_qat_keys and quant_mode in ['qat', 'qaft']:
        print(f"  - 检测到 {quant_mode.upper()} 权重,准备对应结构...")
        
        if quant_mode == 'qaft':
            model.qconfig = torch.quantization.QConfig(
                activation=torch.quantization.FakeQuantize.with_args(
                    observer=torch.quantization.MovingAverageMinMaxObserver,
                    quant_min=0, quant_max=255,
                    dtype=torch.quint8,
                    qscheme=torch.per_tensor_affine,
                    reduce_range=False
                ),
                weight=torch.quantization.FakeQuantize.with_args(
                    observer=torch.quantization.MovingAveragePerChannelMinMaxObserver,
                    quant_min=-128, quant_max=127,
                    dtype=torch.qint8,
                    qscheme=torch.per_channel_symmetric,
                    reduce_range=False
                )
            )
        else:  # qat
            model.qconfig = torch.quantization.get_default_qat_qconfig('fbgemm')
        
        fuse_QATmodel_modules(model)
        model.train()
        torch.quantization.prepare_qat(model, inplace=True)
        model.eval()
    
    # 加载权重
    model.load_state_dict(checkpoint['model'], strict=True)
    print(f"  ✅ 模型权重加载成功")
    
    # 加载 task_head
    head_shape = checkpoint['head']['weight'].shape
    task_head = nn.Linear(head_shape[1], head_shape[0])
    task_head.load_state_dict(checkpoint['head'])
    
    return model, task_head, checkpoint


def load_quantized_model(checkpoint_path, quantized_path):
    """
    从量化 checkpoint 加载量化模型(精确复现)
    
    Returns:
        quantized_model
    """
    
    checkpoint = torch.load(checkpoint_path, map_location='cpu', weights_only=False)
    quant_ckpt = torch.load(quantized_path, map_location='cpu', weights_only=False)
    
    config = checkpoint['config']
    task_type = checkpoint['task_type']
    quant_mode = checkpoint['quant_mode']
    quant_method = quant_ckpt.get('quant_method', 'unknown')
    
    print(f"  - 量化模式: {quant_mode}, 方法: {quant_method}")
    
    # 重建模型
    candidate = CandidateModel(config=config, task_type=task_type)
    model = candidate.build_model()
    model.to('cpu')
    
    # 根据量化模式准备结构并转换
    if quant_mode == 'static':
        model.eval()
        fuse_model_modules(model)
        
        # 根据方法选择精度
        if 'per_channel' in quant_method:
            precision = 'int8_per_channel'
        elif 'histogram' in quant_method:
            precision = 'histogram'
        elif 'moving_avg' in quant_method:
            precision = 'moving_average'
        else:
            precision = 'int8'
        
        quant_config = get_static_quantization_config(precision)
        model.qconfig = quant_config['qconfig']
        torch.quantization.prepare(model, inplace=True)
        quantized_model = torch.quantization.convert(model, inplace=False)
        
    elif quant_mode in ['qat', 'qaft']:
        if quant_mode == 'qaft':
            model.qconfig = torch.quantization.QConfig(
                activation=torch.quantization.FakeQuantize.with_args(
                    observer=torch.quantization.MovingAverageMinMaxObserver,
                    quant_min=0, quant_max=255,
                    dtype=torch.quint8,
                    qscheme=torch.per_tensor_affine,
                    reduce_range=False
                ),
                weight=torch.quantization.FakeQuantize.with_args(
                    observer=torch.quantization.MovingAveragePerChannelMinMaxObserver,
                    quant_min=-128, quant_max=127,
                    dtype=torch.qint8,
                    qscheme=torch.per_channel_symmetric,
                    reduce_range=False
                )
            )
        else:
            model.qconfig = torch.quantization.get_default_qat_qconfig('fbgemm')
        
        fuse_QATmodel_modules(model)
        model.train()
        torch.quantization.prepare_qat(model, inplace=True)
        model.eval()
        quantized_model = torch.quantization.convert(model, inplace=False)
    else:
        raise ValueError(f"未知量化模式: {quant_mode}")
    
    # 加载量化权重
    quantized_model.load_state_dict(quant_ckpt['quantized_model'], strict=True)
    print(f"  ✅ 量化权重加载成功")
    
    return quantized_model


def evaluate(model, task_head, dataloader, description="模型"):
    """评估模型准确率"""
    model.eval()
    task_head.eval()
    
    correct = 0
    total = 0
    
    with torch.no_grad():
        for inputs, labels in tqdm(dataloader['test'], desc=f"评估{description}"):
            features = model(inputs)
            
            # 处理量化输出
            if hasattr(features, 'dequantize'):
                features = features.dequantize()
            
            outputs = task_head(features)
            _, predicted = outputs.max(1)
            
            total += labels.size(0)
            correct += predicted.eq(labels).sum().item()
    
    accuracy = 100.0 * correct / total
    return accuracy


def verify(checkpoint_path, quantized_path=None):
    """
    验证模型
    
    Args:
        checkpoint_path: candidate_xxx_full.pth
        quantized_path: quantized_xxx.pth(可选)
    """
    print(f"\n{'='*60}")
    print(f"模型验证")
    print('='*60)
    
    # 加载 checkpoint 信息
    checkpoint = torch.load(checkpoint_path, map_location='cpu', weights_only=False)
    
    dataset_name = checkpoint['dataset_name']
    task_type = checkpoint['task_type']
    quant_mode = checkpoint['quant_mode']
    metrics = checkpoint['metrics']
    
    print(f"\n📋 模型信息:")
    print(f"  - 数据集: {dataset_name}")
    print(f"  - 量化模式: {quant_mode}")
    print(f"  - 保存的原始准确率: {metrics.get('original_accuracy_percent', 0):.2f}%")
    if quant_mode != 'none':
        print(f"  - 保存的量化准确率: {metrics.get('quantized_accuracy', 0):.2f}%")
    
    # 加载数据集
    print(f"\n📂 加载数据集...")
    data_loader = get_data_loader(task_type)
    root_dir = "/home/liuyang/pyq/jetson_deploy/data/UniMTS_data" if task_type == 'har' else "/home/liuyang/pyq/jetson_deploy/data"
    batch_size = 64 if task_type == 'har' else 128
    
    dataloaders = data_loader.get_dataloaders(
        root_dir=root_dir,
        batch_size=batch_size,
        datasets_list=[dataset_name],
        num_workers=0
    )
    test_dataloader = dataloaders[dataset_name]
    
    # ========== 验证原始模型 ==========
    print(f"\n🔧 加载原始模型...")
    model, task_head, _ = load_model(checkpoint_path)
    
    original_acc = evaluate(model, task_head, test_dataloader, "原始模型")
    saved_original_acc = metrics.get('original_accuracy_percent', 0)
    
    print(f"\n📊 原始模型验证:")
    print(f"  - 复现: {original_acc:.2f}% | 保存: {saved_original_acc:.2f}% | 差异: {abs(original_acc - saved_original_acc):.2f}%")
    
    if abs(original_acc - saved_original_acc) < 1.0:
        print(f"  ✅ 验证通过")
    else:
        print(f"  ⚠️ 差异较大(QAT/QAFT模式下正常,因为观察器状态未完全保存)")
    
    # ========== 验证量化模型 ==========
    if quant_mode != 'none' and quantized_path:
        print(f"\n🔧 加载量化模型...")
        quantized_model = load_quantized_model(checkpoint_path, quantized_path)
        
        # 重新加载 task_head
        _, task_head, _ = load_model(checkpoint_path)
        
        quant_acc = evaluate(quantized_model, task_head, test_dataloader, "量化模型")
        
        # 从量化 checkpoint 获取保存的准确率
        quant_ckpt = torch.load(quantized_path, map_location='cpu', weights_only=False)
        saved_quant_acc = quant_ckpt.get('quant_accuracy', 0)
        
        print(f"\n📊 量化模型验证:")
        print(f"  - 复现: {quant_acc:.2f}% | 保存: {saved_quant_acc:.2f}% | 差异: {abs(quant_acc - saved_quant_acc):.2f}%")
        
        if abs(quant_acc - saved_quant_acc) < 0.5:
            print(f"  ✅ 验证通过")
        else:
            print(f"  ⚠️ 差异较大")
    
    print(f"\n{'='*60}")
    print(f"验证完成")
    print('='*60)


if __name__ == "__main__":
    import argparse
    
    parser = argparse.ArgumentParser(description="模型验证工具")
    parser.add_argument('--checkpoint', type=str, required=True,
                       help='原始模型 checkpoint (candidate_xxx_full.pth)')
    parser.add_argument('--quantized', type=str, default=None,
                       help='量化模型 checkpoint (quantized_xxx.pth)')
    
    args = parser.parse_args()
    verify(args.checkpoint, args.quantized)
相关推荐
草莓熊Lotso32 分钟前
Linux 文件描述符与重定向实战:从原理到 minishell 实现
android·linux·运维·服务器·数据库·c++·人工智能
历程里程碑36 分钟前
Linux22 文件系统
linux·运维·c语言·开发语言·数据结构·c++·算法
七夜zippoe9 小时前
CANN Runtime任务描述序列化与持久化源码深度解码
大数据·运维·服务器·cann
Fcy64810 小时前
Linux下 进程(一)(冯诺依曼体系、操作系统、进程基本概念与基本操作)
linux·运维·服务器·进程
袁袁袁袁满10 小时前
Linux怎么查看最新下载的文件
linux·运维·服务器
代码游侠10 小时前
学习笔记——设备树基础
linux·运维·开发语言·单片机·算法
Harvey90311 小时前
通过 Helm 部署 Nginx 应用的完整标准化步骤
linux·运维·nginx·k8s
珠海西格电力科技12 小时前
微电网能量平衡理论的实现条件在不同场景下有哪些差异?
运维·服务器·网络·人工智能·云计算·智慧城市
释怀不想释怀12 小时前
Linux环境变量
linux·运维·服务器
zzzsde12 小时前
【Linux】进程(4):进程优先级&&调度队列
linux·运维·服务器