【vLLM 学习】Rlhf Utils

vLLM 是一款专为大语言模型推理加速而设计的框架,实现了 KV 缓存内存几乎零浪费,解决了内存管理瓶颈问题。

更多 vLLM 中文文档及教程可访问 →https://vllm.hyper.ai/

*在线运行 vLLM 入门教程:零基础分步指南

源码 examples/offline_inference/rlhf_utils.py

# 复制代码
import torch


def stateless_init_process_group(master_address, master_port, rank, world_size,
                                 device):

    """
    vLLM 提供 `StatelessProcessGroup` 来创建进程组,
    无需考虑 torch.distributed 中的全局进程组。
    建议先创建 `StatelessProcessGroup`,然后初始化
    外部(训练进程)与 vLLM 工作进程之间的数据平面通信(NCCL)。
    """
    from vllm.distributed.device_communicators.pynccl import PyNcclCommunicator
    from vllm.distributed.utils import StatelessProcessGroup
    pg = StatelessProcessGroup.create(host=master_address,
                                      port=master_port,
                                      rank=rank,
                                      world_size=world_size)
    pynccl = PyNcclCommunicator(pg, device=device)
    return pynccl


class WorkerExtension:

    """
    vLLM 工作进程的基类。
    通过定义扩展类,无论底层工作进程类是什么,代码都能正常工作。
    这种方式使代码能同时兼容 vLLM V0 和 V1。
    注意:我们在单独模块中定义此类,主模块应将完整限定名
    作为 `worker_extension_cls` 参数传递。
    """

    def init_weight_update_group(self, master_address, master_port,
                                 rank_offset, world_size):
        from vllm.distributed.parallel_state import get_world_group
        rank = get_world_group().rank + rank_offset
        self.model_update_group = stateless_init_process_group(
            master_address,
            master_port,
            rank,
            world_size,
            self.device,
        )

    def update_weight(self, name, dtype, shape):
        weight = torch.empty(shape, dtype=dtype, device="cuda")
        self.model_update_group.broadcast(weight,
                                          src=0,
                                          stream=torch.cuda.current_stream())

        self.model_runner.model.load_weights(weights=[(name, weight)])

        del weight

    def check_weights_changed(self):
        """
        Check if the weights are updated to 0.
        """
        """
        检查权重是否已更新为 0。
        """
        weights_updated = True
        for name, p in self.model_runner.model.named_parameters():
            weights_updated = weights_updated and torch.allclose(
                p, torch.zeros_like(p))
        return weights_updated


class ColocateWorkerExtension:

    """
    vLLM 工作进程在协同部署场景下的基类。
    通过定义扩展类,无论底层工作进程类是什么,代码都能正常工作。
    这种方式使代码能同时兼容 vLLM V0 和 V1。
    注意:我们在单独模块中定义此类,主模块应将完整限定名
    作为 `worker_extension_cls` 参数传递。
    """

    def report_device_id(self) -> str:
        from vllm.platforms import current_platform
        self.device_uuid = current_platform.get_device_uuid(self.device.index)
        return self.device_uuid

    def update_weights_from_ipc_handles(self, ipc_handles):
        handles = ipc_handles[self.device_uuid]
        device_id = self.device.index
        weights = []
        for name, handle in handles.items():
            func, args = handle
            list_args = list(args)
            # the key is to change device id to the current device id
            # in case two processes have different CUDA_VISIBLE_DEVICES
            # 关键是将设备 ID 改为当前设备 ID,
            # 以防两个进程有不同的 CUDA_VISIBLE_DEVICES
            list_args[6] = device_id
            tensor = func(*list_args)
            weights.append((name, tensor))
        self.model_runner.model.load_weights(weights=weights)
        torch.cuda.synchronize()

    def check_weights_changed(self):

        """
        检查权重是否已更新为0。
        """
        weights_updated = True
        for name, p in self.model_runner.model.named_parameters():
            weights_updated = weights_updated and torch.allclose(
                p, torch.zeros_like(p))
        return weights_updated
相关推荐
皮卡丘不断更3 小时前
手搓本地 RAG:我用 Python 和 Spring Boot 给 AI 装上了“实时代码监控”
人工智能·spring boot·python·ai编程
冬奇Lab3 小时前
Hook 机制实战:让 ClaudeCode 主动通知你
ai编程·claude
浪子小院3 小时前
ModelEngine 智能体全流程开发实战:从 0 到 1 搭建多协作办公助手
大数据·人工智能
程序员打怪兽3 小时前
详解YOLOv8网络结构
人工智能·深度学习
Yuer20253 小时前
全国首例“AI 幻觉”侵权案判了:这不是 AI 准不准的问题,而是谁该为 AI 负责
人工智能·edca os·可控ai
码路飞3 小时前
语音 AI Agent 延迟优化实战:我是怎么把响应时间从 2 秒干到 500ms 以内的
ai编程
一切尽在,你来3 小时前
1.1 AI大模型应用开发和Langchain的关系
人工智能·langchain
Coder_Boy_4 小时前
基于Spring AI的分布式在线考试系统-事件处理架构实现方案
人工智能·spring boot·分布式·spring
wdfk_prog4 小时前
[Linux]学习笔记系列 -- [drivers][input]serio
linux·笔记·学习
Light604 小时前
智链未来:彭山物流园区从物理基建到数据智能体的全维度构建方案
人工智能·系统架构·数字孪生·智慧物流·实施路径·彭山项目