PyTorch CUDA设备不可用错误解决方案

PyTorch CUDA设备不可用错误解决方案

问题描述

在使用PyTorch加载模型时遇到以下错误:

复制代码
RuntimeError: Attempting to deserialize object on a CUDA device but torch.cuda.is_available() is False. If you are running on a CPU-only machine, please use torch.load with map_location=torch.device('cpu') to map your storages to the CPU.
已中止 (核心已转储)

解决方案

方案一:强制加载到CPU(快速解决)

python 复制代码
import torch

# 方法1:直接指定CPU设备
model = torch.load('model.pth', map_location=torch.device('cpu'))

# 方法2:使用字符串指定
model = torch.load('model.pth', map_location='cpu')

# 方法3:使用lambda函数
model = torch.load('model.pth', map_location=lambda storage, loc: storage.cpu())

方案二:检查并修复GPU环境(根本解决)

步骤1:检查显卡驱动
bash 复制代码
nvidia-smi
步骤2:验证PyTorch安装
python 复制代码
import torch
print(f"PyTorch版本: {torch.__version__}")
print(f"CUDA可用: {torch.cuda.is_available()}")
print(f"CUDA版本: {torch.version.cuda}")
步骤3:重新安装正确的PyTorch版本
bash 复制代码
# 安装支持CUDA的PyTorch
pip3 install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118

# 或使用conda
conda install pytorch torchvision torchaudio pytorch-cuda=11.8 -c pytorch -c nvidia

方案三:使用state_dict方式(最佳实践)

python 复制代码
# 保存模型时使用state_dict
torch.save(model.state_dict(), 'model_weights.pth')

# 加载模型时指定设备
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.load_state_dict(torch.load('model_weights.pth', map_location=device))

特殊情况处理

AMD显卡用户

python 复制代码
# AMD显卡不支持CUDA,使用CPU
device = torch.device('cpu')

Apple M系列芯片

python 复制代码
# Apple M系列芯片使用MPS加速
if torch.backends.mps.is_available():
    device = torch.device('mps')
else:
    device = torch.device('cpu')

完整解决方案代码

python 复制代码
import torch
import os

def load_model_safely(model_path, model_class, **kwargs):
    """
    安全加载模型,自动处理设备兼容性问题
    """
    # 检测可用设备
    if torch.cuda.is_available():
        device = torch.device('cuda')
        print("✅ 检测到CUDA设备,使用GPU加速")
    elif hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
        device = torch.device('mps')
        print("✅ 检测到Apple M系列芯片,使用MPS加速")
    else:
        device = torch.device('cpu')
        print("💻 未检测到GPU设备,使用CPU")
    
    # 创建模型实例
    model = model_class(**kwargs)
    
    try:
        # 尝试加载state_dict
        state_dict = torch.load(model_path, map_location=device)
        
        # 处理checkpoint格式
        if isinstance(state_dict, dict) and 'model_state_dict' in state_dict:
            state_dict = state_dict['model_state_dict']
        
        # 处理多GPU训练的模型
        if list(state_dict.keys())[0].startswith('module.'):
            state_dict = {k.replace('module.', ''): v for k, v in state_dict.items()}
        
        model.load_state_dict(state_dict)
        model = model.to(device)
        model.eval()
        
        print(f"✅ 模型成功加载到 {device}")
        return model, device
        
    except Exception as e:
        print(f"❌ 模型加载失败: {str(e)}")
        raise

# 使用示例
if __name__ == "__main__":
    # 假设你有一个模型类MyModel
    # model, device = load_model_safely('model.pth', MyModel, num_classes=10)
    pass

预防措施

  1. 保存模型时使用state_dict
python 复制代码
torch.save(model.state_dict(), 'model_weights.pth')
  1. 加载模型时指定设备
python 复制代码
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.load_state_dict(torch.load('model_weights.pth', map_location=device))
  1. 定期检查环境配置
python 复制代码
print(f"PyTorch版本: {torch.__version__}")
print(f"CUDA可用: {torch.cuda.is_available()}")
print(f"GPU数量: {torch.cuda.device_count()}")
if torch.cuda.is_available():
    print(f"GPU名称: {torch.cuda.get_device_name(0)}")

常见问题排查

问题1:安装了GPU版PyTorch但CUDA不可用

  • 检查显卡驱动是否最新
  • 确认PyTorch版本与CUDA版本匹配
  • 重启Python环境

问题2:版本不兼容

  • 卸载当前PyTorch:pip uninstall torch torchvision torchaudio
  • 重新安装匹配版本:参考PyTorch官网安装命令

问题3:多环境冲突

  • 确认当前使用的Python环境
  • 检查虚拟环境中的PyTorch安装
  • 使用which pythonwhich pip确认路径

注意:在生产环境中,建议始终使用state_dict方式保存和加载模型,以获得最佳的兼容性和灵活性。

相关推荐
人工智能AI技术7 分钟前
【VibeCoding系列教程12】 AI代码编辑器
人工智能
辣椒思密达13 分钟前
Python公开数据采集实战:如何解决请求高频拦截与Session会话中断问题
开发语言·python
zhangfeng113315 分钟前
ai训练 顿悟“总数据量是 m²,训练所需要的数据量是 log m
人工智能
半兽先生27 分钟前
05阶段:NLP自然语言处理基础
人工智能·自然语言处理
盈飞无限31 分钟前
SPC选型:智能VS传统,谁更懂中国制造?
人工智能·制造
li-xun32 分钟前
LINUX DO 社区注册机制调整与公益 AI 服务动态
linux·运维·人工智能
云烟成雨TD36 分钟前
Spring AI 1.x 系列【50】可观测性:接入 Prometheus + Grafana
人工智能·spring·prometheus
Albart57542 分钟前
Python 实战教程:用 30 分钟学会解决真实问题
开发语言·python
2301_773643621 小时前
ceph池
开发语言·ceph·python