【vLLM 学习】Rlhf Utils

vLLM 是一款专为大语言模型推理加速而设计的框架,实现了 KV 缓存内存几乎零浪费,解决了内存管理瓶颈问题。

更多 vLLM 中文文档及教程可访问 →https://vllm.hyper.ai/

*在线运行 vLLM 入门教程:零基础分步指南

源码 examples/offline_inference/rlhf_utils.py

# 复制代码
import torch


def stateless_init_process_group(master_address, master_port, rank, world_size,
                                 device):

    """
    vLLM 提供 `StatelessProcessGroup` 来创建进程组,
    无需考虑 torch.distributed 中的全局进程组。
    建议先创建 `StatelessProcessGroup`,然后初始化
    外部(训练进程)与 vLLM 工作进程之间的数据平面通信(NCCL)。
    """
    from vllm.distributed.device_communicators.pynccl import PyNcclCommunicator
    from vllm.distributed.utils import StatelessProcessGroup
    pg = StatelessProcessGroup.create(host=master_address,
                                      port=master_port,
                                      rank=rank,
                                      world_size=world_size)
    pynccl = PyNcclCommunicator(pg, device=device)
    return pynccl


class WorkerExtension:

    """
    vLLM 工作进程的基类。
    通过定义扩展类,无论底层工作进程类是什么,代码都能正常工作。
    这种方式使代码能同时兼容 vLLM V0 和 V1。
    注意:我们在单独模块中定义此类,主模块应将完整限定名
    作为 `worker_extension_cls` 参数传递。
    """

    def init_weight_update_group(self, master_address, master_port,
                                 rank_offset, world_size):
        from vllm.distributed.parallel_state import get_world_group
        rank = get_world_group().rank + rank_offset
        self.model_update_group = stateless_init_process_group(
            master_address,
            master_port,
            rank,
            world_size,
            self.device,
        )

    def update_weight(self, name, dtype, shape):
        weight = torch.empty(shape, dtype=dtype, device="cuda")
        self.model_update_group.broadcast(weight,
                                          src=0,
                                          stream=torch.cuda.current_stream())

        self.model_runner.model.load_weights(weights=[(name, weight)])

        del weight

    def check_weights_changed(self):
        """
        Check if the weights are updated to 0.
        """
        """
        检查权重是否已更新为 0。
        """
        weights_updated = True
        for name, p in self.model_runner.model.named_parameters():
            weights_updated = weights_updated and torch.allclose(
                p, torch.zeros_like(p))
        return weights_updated


class ColocateWorkerExtension:

    """
    vLLM 工作进程在协同部署场景下的基类。
    通过定义扩展类,无论底层工作进程类是什么,代码都能正常工作。
    这种方式使代码能同时兼容 vLLM V0 和 V1。
    注意:我们在单独模块中定义此类,主模块应将完整限定名
    作为 `worker_extension_cls` 参数传递。
    """

    def report_device_id(self) -> str:
        from vllm.platforms import current_platform
        self.device_uuid = current_platform.get_device_uuid(self.device.index)
        return self.device_uuid

    def update_weights_from_ipc_handles(self, ipc_handles):
        handles = ipc_handles[self.device_uuid]
        device_id = self.device.index
        weights = []
        for name, handle in handles.items():
            func, args = handle
            list_args = list(args)
            # the key is to change device id to the current device id
            # in case two processes have different CUDA_VISIBLE_DEVICES
            # 关键是将设备 ID 改为当前设备 ID,
            # 以防两个进程有不同的 CUDA_VISIBLE_DEVICES
            list_args[6] = device_id
            tensor = func(*list_args)
            weights.append((name, tensor))
        self.model_runner.model.load_weights(weights=weights)
        torch.cuda.synchronize()

    def check_weights_changed(self):

        """
        检查权重是否已更新为0。
        """
        weights_updated = True
        for name, p in self.model_runner.model.named_parameters():
            weights_updated = weights_updated and torch.allclose(
                p, torch.zeros_like(p))
        return weights_updated
相关推荐
P.H. Infinity2 小时前
【QLIB】三、学习层(一)
学习
OpenCSG2 小时前
提示词工程到AgenticOps:OpenCSG公益课
大数据·人工智能·开源·opencsg
悟纤2 小时前
Suno AI 音乐节奏设计完全指南 | Suno高级篇 | 第23篇
人工智能·suno·suno ai·suno api·ai music
EasyGBS2 小时前
EasyGBS的金融网点全场景智能可视化监管方案设计
大数据·人工智能
laplace01232 小时前
第八章 agent记忆与检索 下
数据库·人工智能·笔记·agent·rag
茜子.Java2 小时前
分享一个输入 allure --version 报错解决方式
学习
应用市场2 小时前
Intel Core Ultra 9做目标检测够用吗?CPU+GPU+NPU三路并发实测
人工智能·目标检测·计算机视觉
CoderOnly2 小时前
SFT后训练32B-LLM的一些观察
人工智能·大模型
拓端研究室3 小时前
2026中国医美护肤产品行业发展与未来趋势蓝皮书:射频、胶原蛋白、PDRN与肉毒素|附90+份报告PDF、数据、可视化模板汇总下载
大数据·人工智能