【译】在 Mac 上加速 PyTorch 训练

写在前面

  1. 为什么突然深度介入大模型领域了

因为最近在评估大模型用于行业应用,通过 OpenCompass 排行榜了解到了很多大模型,像文心一言是自己深度试用过的,趁着这次评估,也体验或者通过其他团队的介绍了解了通义千问、清华智谱、书生·浦语。

  1. 为什么翻译这篇文章

从清华智普开源的 ChatGLM3-6B 模型看到说苹果电脑本地(苹果芯片或者带了 AMD 独立显卡的)运行大模型有 MPS 在后面,不用关心显存,所以准备安装了环境,自己来深度体验一下ChatGLM3-6B的微调,这里要感谢 ChatGLM3-6B 的官方文档,写的很详细。

Mac开发者无需关注GPU的限制。对于搭载了 Apple Silicon 或者 AMD GPU 的 Mac,可以使用 MPS 后端来在 GPU 上运行 ChatGLM3-6B。需要参考 Apple 的 官方说明 安装 PyTorch-Nightly(正确的版本号应该是2.x.x.dev2023xxxx,而不是 2.x.x)。

原文地址:Accelerated PyTorch training on Mac

以下是译文

一、Metal 加速

PyTorch 使用新的 Metal Performance Shaders (MPS) 后端为 GPU 训练加速。MPS 后端扩展了 PyTorch 框架,提供了在 Mac 上设置和运行操作的脚本和功能。MPS 框架通过针对每个 Metal GPU 系列的独特特性进行微调的内核来优化计算性能。新的 MPS 设备将机器学习计算图形和基元映射到 MPS Graph 框架和 MPS 提供的调整内核上。

二、要求

  • 配备 Apple silicon 或 AMD GPU 的 Mac 电脑
  • macOS 12.3 或更高版本
  • Python 3.7 或更高版本
  • Xcode 命令行工具:xcode-select --install

三、开始

您可以使用 Anaconda 或 pip。请注意,使用 Apple 芯片的 Mac 和使用 Intel x86 的 Mac 的环境设置会有所不同。

使用安装页面上的 PyTorch 安装选择器,为 MPS 设备加速选择 Preview (Nightly)。MPS 后端支持是 PyTorch 1.12 正式版的一部分。PyTorch 的预览版(夜间版)将为您的设备提供最新的 MPS 支持。

  1. 设置

Anaconda

Apple silicon

bash 复制代码
curl -O https://repo.anaconda.com/miniconda/Miniconda3-latest-MacOSX-arm64.sh
sh Miniconda3-latest-MacOSX-arm64.sh

x86

bash 复制代码
curl -O https://repo.anaconda.com/miniconda/Miniconda3-latest-MacOSX-x86_64.sh
sh Miniconda3-latest-MacOSX-x86_64.sh

pip

你可以使用 macOS 预装的 pip3。或者,你也可以从 Python 网站或 Homebrew 软件包管理器中安装。

  1. 安装

Anaconda

bash 复制代码
conda install pytorch torchvision torchaudio -c pytorch-nightly

pip

bash 复制代码
pip3 install --pre torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/nightly/cpu

从源代码构建

构建支持 MPS 的 PyTorch 需要 Xcode 13.3.1 或更高版本。您可以从 Mac App Store 下载最新的公开 Xcode 版本,或从 Mac App Store 下载最新的测试版,或从 Apple Developer 网站下载最新的测试版。USE_MPS 环境变量控制 PyTorch 的构建,并包含 MPS 支持。

要构建 PyTorch,请遵循 PyTorch 网站上提供的说明。

  1. 验证

您可以使用简单的 Python 脚本验证 MPS 支持:

bash 复制代码
import torch
if torch.backends.mps.is_available():
    mps_device = torch.device("mps")
    x = torch.ones(1, device=mps_device)
    print (x)
else:
    print ("MPS device not found.")

输出结果应显示

bash 复制代码
tensor([1.], device='mps:0')

四、反馈意见

MPS 后端处于测试阶段,我们正在积极解决问题和修复错误。要报告问题,请使用 GitHub 问题跟踪器,标签为 "module: mps"。

五、Resources

相关推荐
小明的IT世界4 分钟前
Agent系列3:改变你做 AI Agent 的方式
人工智能
AI科技摆渡10 分钟前
三步极速对接 Grok-Video-3 视频生成 API
人工智能·音视频
是大强17 分钟前
NCNN简介
人工智能
数字游民952720 分钟前
gpt image 2怎么用?3个案例+使用方法
人工智能·ai·数字游民9527
minhuan25 分钟前
大模型反向优化传统算法:用大模型学习传统算法的缺陷,反向迭代算法逻辑.152
人工智能·大模型算法应用·大模型反向优化传统算法·算法优化方案
新缸中之脑34 分钟前
用Remotion构建AI生成视频
人工智能·音视频
belldeep34 分钟前
Blender + AI 全套工作流
人工智能·ai·blender
何陋轩35 分钟前
【重磅】悟空来了:国产AI编程助手深度测评,能否吊打Copilot?
人工智能·算法·面试
AI医影跨模态组学37 分钟前
如何将深度学习MRI表型与iCCA淋巴结转移的生物学机制(KRAS突变、MUC5AC、免疫抑制微环境、大导管亚型)关联,并解释其对治疗响应的意义
人工智能·深度学习·机器学习·论文·医学·医学影像
GreenTea41 分钟前
DeepSeek-V4 技术报告深度分析:基础研究创新全景
前端·人工智能·后端