PyTorch CUDA设备不可用错误解决方案

PyTorch CUDA设备不可用错误解决方案

问题描述

在使用PyTorch加载模型时遇到以下错误:

复制代码
RuntimeError: Attempting to deserialize object on a CUDA device but torch.cuda.is_available() is False. If you are running on a CPU-only machine, please use torch.load with map_location=torch.device('cpu') to map your storages to the CPU.
已中止 (核心已转储)

解决方案

方案一:强制加载到CPU(快速解决)

python 复制代码
import torch

# 方法1:直接指定CPU设备
model = torch.load('model.pth', map_location=torch.device('cpu'))

# 方法2:使用字符串指定
model = torch.load('model.pth', map_location='cpu')

# 方法3:使用lambda函数
model = torch.load('model.pth', map_location=lambda storage, loc: storage.cpu())

方案二:检查并修复GPU环境(根本解决)

步骤1:检查显卡驱动
bash 复制代码
nvidia-smi
步骤2:验证PyTorch安装
python 复制代码
import torch
print(f"PyTorch版本: {torch.__version__}")
print(f"CUDA可用: {torch.cuda.is_available()}")
print(f"CUDA版本: {torch.version.cuda}")
步骤3:重新安装正确的PyTorch版本
bash 复制代码
# 安装支持CUDA的PyTorch
pip3 install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118

# 或使用conda
conda install pytorch torchvision torchaudio pytorch-cuda=11.8 -c pytorch -c nvidia

方案三:使用state_dict方式(最佳实践)

python 复制代码
# 保存模型时使用state_dict
torch.save(model.state_dict(), 'model_weights.pth')

# 加载模型时指定设备
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.load_state_dict(torch.load('model_weights.pth', map_location=device))

特殊情况处理

AMD显卡用户

python 复制代码
# AMD显卡不支持CUDA,使用CPU
device = torch.device('cpu')

Apple M系列芯片

python 复制代码
# Apple M系列芯片使用MPS加速
if torch.backends.mps.is_available():
    device = torch.device('mps')
else:
    device = torch.device('cpu')

完整解决方案代码

python 复制代码
import torch
import os

def load_model_safely(model_path, model_class, **kwargs):
    """
    安全加载模型,自动处理设备兼容性问题
    """
    # 检测可用设备
    if torch.cuda.is_available():
        device = torch.device('cuda')
        print("✅ 检测到CUDA设备,使用GPU加速")
    elif hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
        device = torch.device('mps')
        print("✅ 检测到Apple M系列芯片,使用MPS加速")
    else:
        device = torch.device('cpu')
        print("💻 未检测到GPU设备,使用CPU")
    
    # 创建模型实例
    model = model_class(**kwargs)
    
    try:
        # 尝试加载state_dict
        state_dict = torch.load(model_path, map_location=device)
        
        # 处理checkpoint格式
        if isinstance(state_dict, dict) and 'model_state_dict' in state_dict:
            state_dict = state_dict['model_state_dict']
        
        # 处理多GPU训练的模型
        if list(state_dict.keys())[0].startswith('module.'):
            state_dict = {k.replace('module.', ''): v for k, v in state_dict.items()}
        
        model.load_state_dict(state_dict)
        model = model.to(device)
        model.eval()
        
        print(f"✅ 模型成功加载到 {device}")
        return model, device
        
    except Exception as e:
        print(f"❌ 模型加载失败: {str(e)}")
        raise

# 使用示例
if __name__ == "__main__":
    # 假设你有一个模型类MyModel
    # model, device = load_model_safely('model.pth', MyModel, num_classes=10)
    pass

预防措施

  1. 保存模型时使用state_dict
python 复制代码
torch.save(model.state_dict(), 'model_weights.pth')
  1. 加载模型时指定设备
python 复制代码
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.load_state_dict(torch.load('model_weights.pth', map_location=device))
  1. 定期检查环境配置
python 复制代码
print(f"PyTorch版本: {torch.__version__}")
print(f"CUDA可用: {torch.cuda.is_available()}")
print(f"GPU数量: {torch.cuda.device_count()}")
if torch.cuda.is_available():
    print(f"GPU名称: {torch.cuda.get_device_name(0)}")

常见问题排查

问题1:安装了GPU版PyTorch但CUDA不可用

  • 检查显卡驱动是否最新
  • 确认PyTorch版本与CUDA版本匹配
  • 重启Python环境

问题2:版本不兼容

  • 卸载当前PyTorch:pip uninstall torch torchvision torchaudio
  • 重新安装匹配版本:参考PyTorch官网安装命令

问题3:多环境冲突

  • 确认当前使用的Python环境
  • 检查虚拟环境中的PyTorch安装
  • 使用which pythonwhich pip确认路径

注意:在生产环境中,建议始终使用state_dict方式保存和加载模型,以获得最佳的兼容性和灵活性。

相关推荐
Soari3 小时前
告别玩具级 Demo!深度拆解 agents-towards-production,用硬核工程把 AI Agent 推向工业级生产线
人工智能·软件工程·llmops·架构优化·genai·aiagent·生产级部署
minhuan3 小时前
RTX 4090显存终极优化:模型分层加载、CPU Offload显存和内存动态置换实践.179
人工智能·大模型应用·rtx 4090显存优化·模型分层加载·cpu offload优化
小郑加油3 小时前
python学习Day15:综合训练——数据清洗与缺失值补充
开发语言·python·学习
完成大叔3 小时前
Agent入门:用本地模型从零搭建
开发语言·python·langchain
2601_958548483 小时前
电镀整流机源头厂家:企业采购选型策略深度解析
人工智能
光锥智能3 小时前
智元WITA成为全国首例完成大模型备案的具身智能交互模型
人工智能
墨神谕3 小时前
人工智能(一)—AI的起源和发展
人工智能
科技云报道3 小时前
当攻击开始“自主决策”,安全体系如何应战?
人工智能
一切皆是因缘际会3 小时前
AI低代码开发实战:轻量化部署与多场景落地
人工智能·深度学习·低代码·机器学习·ai·架构