Table of Contents
- [0 References](#0-references)
- [1 Virtual Environment Dependency Versions](#1-virtual-environment-dependency-versions)
- [2 EasyR1 Training Environment Setup](#2-easyr1-training-environment-setup)
- [3 vllm_utils.py API Fix](#3-vllm_utilspy-api-fix)
- [4 Training Dataset Format](#4-training-dataset-format)
- [5 Reward Function Configuration (for dual H800)](#5-reward-function-configuration-for-dual-h800)
- [6 Training Configuration File](#6-training-configuration-file)
- [7 Training Launch Script](#7-training-launch-script)
0 References
- GitHub - hiyouga/EasyR1: EasyR1: An Efficient, Scalable, Multi-Modality RL Training Framework based on veRL
- Release v2.8.3 · Dao-AILab/flash-attention · GitHub
- Previous PyTorch Versions
1 Virtual Environment Dependency Versions
- The dependencies and versions in the table below have been verified by the author to train successfully. The training breakage caused by the vLLM 0.15.1 API changes is addressed with compatible code later in this post.
- If you run into dependency conflicts during installation, consult this table; it should help you resolve environment compatibility issues quickly.
| Package | Version | Description |
|---|---|---|
| accelerate | 1.12.0 | Hugging Face library for accelerating model training and inference |
| aiohappyeyeballs | 2.6.1 | Happy Eyeballs connection algorithm for asyncio |
| aiohttp | 3.13.3 | Asynchronous HTTP client/server framework |
| aiohttp-cors | 0.8.1 | Cross-Origin Resource Sharing (CORS) support for aiohttp |
| aiosignal | 1.4.0 | Async signal support |
| annotated-doc | 0.0.4 | Documentation annotation tooling |
| annotated-types | 0.7.0 | Extensions for Python type annotations |
| anthropic | 0.79.0 | Anthropic API client (Claude models) |
| antlr4-python3-runtime | 4.9.3 | Python runtime for ANTLR4 parsers |
| anyio | 4.12.1 | Async I/O abstraction layer, compatible with asyncio/trio |
| apache-tvm-ffi | 0.1.8.post2 | FFI bindings for the Apache TVM deep learning compiler |
| astor | 0.8.1 | Python AST (abstract syntax tree) manipulation |
| attrs | 25.4.0 | Class decorators that simplify class definitions |
| av | 16.1.0 | PyAV multimedia processing (FFmpeg bindings) |
| blake3 | 1.0.8 | BLAKE3 hash algorithm implementation |
| cachetools | 7.0.1 | Extensible cache collections (in-memory caches, etc.) |
| cbor2 | 5.8.0 | CBOR (Concise Binary Object Representation) encoding/decoding |
| certifi | 2026.1.4 | Mozilla CA certificate bundle |
| cffi | 2.0.0 | C Foreign Function Interface |
| charset-normalizer | 3.4.4 | Automatic character-encoding detection |
| click | 8.3.1 | Command-line interface toolkit |
| cloudpickle | 3.1.2 | Extended pickle supporting more serializable types |
| codetiming | 1.4.0 | Code execution timing |
| colorful | 0.5.8 | Colored terminal output |
| compressed-tensors | 0.13.0 | Compressed tensor data handling |
| conda-pack | 0.8.1 | Conda environment packaging tool |
| cryptography | 46.0.5 | Cryptographic algorithms and protocols |
| cuda-bindings | 13.1.1 | CUDA Python bindings |
| cuda-pathfinder | 1.3.4 | CUDA path discovery utility |
| cuda-python | 13.1.1 | NVIDIA CUDA Python interface |
| cupy-cuda12x | 13.6.0 | NumPy-compatible CUDA array library |
| datasets | 4.5.0 | Hugging Face dataset loading and processing |
| depyf | 0.20.0 | Python decompiler |
| dill | 0.4.0 | Extended pickle for serializing more Python objects |
| diskcache | 5.6.3 | Disk- and file-backed caching |
| distlib | 0.4.0 | Low-level distribution packaging tools |
| distro | 1.9.0 | Linux distribution information |
| dnspython | 2.8.0 | DNS toolkit |
| docstring_parser | 0.17.0 | Docstring parser |
| einops | 0.8.2 | Tensor operations with Einstein-notation syntax |
| email-validator | 2.3.0 | Email address validation |
| fastapi | 0.129.0 | Modern high-performance web framework |
| fastapi-cli | 0.0.21 | FastAPI command-line tools |
| fastapi-cloud-cli | 0.12.0 | FastAPI cloud deployment CLI |
| fastar | 0.8.0 | Fast array operations |
| fastrlock | 0.8.3 | Fast reentrant lock implementation |
| filelock | 3.24.0 | File locking for concurrency control |
| flash_attn | 2.8.3 | FlashAttention fast attention implementation |
| flashinfer-python | 0.6.1 | LLM inference acceleration (sampling, attention, etc.) |
| frozenlist | 1.8.0 | Immutable list implementation |
| fsspec | 2025.10.0 | Filesystem spec, unified filesystem interface |
| gguf | 0.17.1 | GGUF (GGML universal file format) handling |
| gitdb | 4.0.12 | Git object database |
| GitPython | 3.1.46 | Git operations from Python |
| google-api-core | 2.29.0 | Google API client core library |
| google-auth | 2.48.0 | Google authentication library |
| googleapis-common-protos | 1.72.0 | Common Google API protocol buffers |
| grpcio | 1.78.0 | gRPC implementation for Python |
| grpcio-reflection | 1.78.0 | gRPC server reflection |
| h11 | 0.16.0 | Pure-Python HTTP/1.1 implementation |
| hf-xet | 1.2.0 | Hugging Face XET storage backend |
| httpcore | 1.0.9 | HTTP core library |
| httptools | 0.7.1 | HTTP parsing tools |
| httpx | 0.28.1 | Modern async HTTP client |
| httpx-sse | 0.4.3 | Server-Sent Events (SSE) support for HTTPX |
| huggingface_hub | 0.36.2 | Hugging Face model and dataset management |
| idna | 3.11 | Internationalized domain name encoding |
| ijson | 3.4.0.post0 | Iterative JSON parser |
| importlib_metadata | 8.7.1 | Import metadata reader (backport for older Python) |
| interegular | 0.3.3 | Regular-expression intersection computation |
| Jinja2 | 3.1.6 | Template engine |
| jiter | 0.13.0 | Fast JSON parser |
| jmespath | 1.1.0 | JSON query language |
| jsonschema | 4.26.0 | JSON Schema validation |
| jsonschema-specifications | 2025.9.1 | JSON Schema specification data |
| lark | 1.2.2 | Parsing library supporting multiple grammars |
| liger_kernel | 0.7.0 | Optimized (fused) kernels for LLM training |
| llguidance | 1.3.0 | Constrained structured generation for LLMs |
| llvmlite | 0.44.0 | Lightweight Python bindings for LLVM |
| lm-format-enforcer | 0.11.3 | Output format enforcement for language models |
| loguru | 0.7.3 | Modern logging library |
| markdown-it-py | 4.0.0 | Markdown parser |
| MarkupSafe | 2.1.5 | Markup-safe strings for XML/HTML/XHTML |
| mathruler | 0.1.0 | Math rule processing |
| mcp | 1.26.0 | Model Context Protocol implementation |
| mdurl | 0.1.2 | Markdown URL handling |
| mistral_common | 1.9.1 | Mistral AI common utilities |
| model-hosting-container-standards | 0.1.13 | Model hosting container standards |
| mpmath | 1.3.0 | Arbitrary-precision math |
| msgpack | 1.1.2 | MessagePack serialization format |
| msgspec | 0.20.0 | High-performance serialization/validation |
| multidict | 6.7.1 | Multi-value dictionary implementation |
| multiprocess | 0.70.18 | Multiprocessing library (with dill support) |
| networkx | 3.6.1 | Complex network analysis |
| ninja | 1.13.0 | Ninja build system |
| numba | 0.61.2 | JIT compiler for accelerating numerical Python |
| numpy | 2.2.6 | Foundational numerical computing library |
| nvidia-cublas-cu12 | 12.8.4.1 | NVIDIA cuBLAS linear algebra (CUDA 12) |
| nvidia-cuda-cupti-cu12 | 12.8.90 | NVIDIA CUDA Profiling Tools Interface |
| nvidia-cuda-nvrtc-cu12 | 12.8.93 | NVIDIA CUDA runtime compilation |
| nvidia-cuda-runtime-cu12 | 12.8.90 | NVIDIA CUDA runtime |
| nvidia-cudnn-cu12 | 9.10.2.21 | NVIDIA cuDNN deep learning library |
| nvidia-cudnn-frontend | 1.18.0 | cuDNN frontend API |
| nvidia-cufft-cu12 | 11.3.3.83 | NVIDIA cuFFT fast Fourier transforms |
| nvidia-cufile-cu12 | 1.13.1.3 | NVIDIA cuFile GPUDirect Storage |
| nvidia-curand-cu12 | 10.3.9.90 | NVIDIA cuRAND random number generation |
| nvidia-cusolver-cu12 | 11.7.3.90 | NVIDIA cuSolver dense and sparse solvers |
| nvidia-cusparse-cu12 | 12.5.8.93 | NVIDIA cuSPARSE sparse matrix library |
| nvidia-cusparselt-cu12 | 0.7.1 | NVIDIA cuSPARSELt sparse matrix library |
| nvidia-cutlass-dsl | 4.4.0 | NVIDIA CUTLASS DSL |
| nvidia-cutlass-dsl-libs-base | 4.4.0 | CUTLASS base libraries |
| nvidia-ml-py | 13.590.48 | NVIDIA Management Library Python bindings |
| nvidia-nccl-cu12 | 2.27.5 | NVIDIA NCCL multi-GPU communication |
| nvidia-nvjitlink-cu12 | 12.8.93 | NVIDIA JIT link library |
| nvidia-nvshmem-cu12 | 3.3.20 | NVIDIA NVSHMEM shared memory |
| nvidia-nvtx-cu12 | 12.8.90 | NVIDIA NVTX annotation tools |
| omegaconf | 2.3.0 | Hierarchical configuration management |
| openai | 2.21.0 | OpenAI API client |
| openai-harmony | 0.0.8 | OpenAI Harmony tooling |
| opencensus | 0.11.4 | Distributed tracing and metrics collection |
| opencensus-context | 0.1.3 | OpenCensus context propagation |
| opencv-python-headless | 4.13.0.92 | OpenCV computer vision (no GUI) |
| opentelemetry-api | 1.39.1 | OpenTelemetry API |
| opentelemetry-exporter-prometheus | 0.60b1 | OpenTelemetry Prometheus exporter |
| opentelemetry-proto | 1.39.1 | OpenTelemetry protocol buffers |
| opentelemetry-sdk | 1.39.1 | OpenTelemetry SDK |
| opentelemetry-semantic-conventions | 0.60b1 | OpenTelemetry semantic conventions |
| orjson | 3.11.7 | High-performance JSON library |
| outlines_core | 0.2.11 | Structured generation core library |
| packaging | 25.0 | Package version and dependency handling |
| pandas | 3.0.0 | Data analysis and manipulation |
| partial-json-parser | 0.2.1.1.post7 | Partial JSON parser |
| peft | 0.18.1 | Parameter-efficient fine-tuning |
| pillow | 12.0.0 | Image processing |
| pip | 26.0.1 | Python package installer |
| platformdirs | 4.9.1 | Cross-platform directory path resolution |
| prettytable | 3.17.0 | Pretty ASCII table generation |
| prometheus_client | 0.24.1 | Prometheus monitoring client |
| prometheus-fastapi-instrumentator | 7.1.0 | Prometheus metrics collection for FastAPI |
| propcache | 0.4.1 | Property caching utilities |
| proto-plus | 1.27.1 | Protocol Buffers enhancements |
| protobuf | 6.33.5 | Protocol Buffers serialization |
| psutil | 7.2.2 | System and process monitoring |
| py-cpuinfo | 9.0.0 | CPU information |
| py-spy | 0.4.1 | Sampling profiler for Python |
| pyarrow | 23.0.0 | Apache Arrow columnar data format |
| pyasn1 | 0.6.2 | ASN.1 encoding/decoding |
| pyasn1_modules | 0.4.2 | Collection of ASN.1 modules |
| pybase64 | 1.4.3 | Accelerated Base64 encoding/decoding |
| pycountry | 24.6.1 | Country/region data |
| pycparser | 3.0 | C parser |
| pydantic | 2.12.5 | Data validation and settings management |
| pydantic_core | 2.41.5 | Pydantic core functionality |
| pydantic-extra-types | 2.11.0 | Extra Pydantic types |
| pydantic-settings | 2.12.0 | Pydantic settings management |
| pyecharts | 2.1.0 | Python bindings for ECharts charts |
| Pygments | 2.19.2 | Syntax highlighting |
| PyJWT | 2.11.0 | JSON Web Token implementation |
| pylatexenc | 2.10 | LaTeX encoding handling |
| python-dateutil | 2.9.0.post0 | Date/time handling extensions |
| python-dotenv | 1.2.1 | Load environment variables from .env files |
| python-json-logger | 4.0.0 | JSON-format logging |
| python-multipart | 0.0.22 | Multipart form data parsing |
| pyvers | 0.2.2 | Python version management |
| PyYAML | 6.0.3 | YAML parsing and emitting |
| pyzmq | 27.1.0 | ZeroMQ Python bindings |
| qwen-vl-utils | 0.0.14 | Qwen vision-language model utilities |
| ray | 2.53.0 | Distributed computing and ML framework |
| referencing | 0.37.0 | JSON reference resolution |
| regex | 2026.1.15 | Enhanced regular expressions |
| requests | 2.32.5 | HTTP request library |
| rich | 13.9.4 | Rich text and pretty terminal output |
| rich-toolkit | 0.19.4 | Rich tooling |
| rignore | 0.7.6 | Gitignore handling |
| rpds-py | 0.30.0 | Python bindings for Rust persistent data structures |
| rsa | 4.9.1 | RSA encryption implementation |
| safetensors | 0.7.0 | Safe tensor serialization format |
| sentencepiece | 0.2.1 | Text tokenization library |
| sentry-sdk | 2.52.0 | Sentry error monitoring SDK |
| setproctitle | 1.3.7 | Process title setting |
| setuptools | 80.10.2 | Python package building and installation |
| shellingham | 1.5.4 | Detect the current shell |
| simplejson | 3.20.2 | JSON encoder/decoder |
| six | 1.17.0 | Python 2/3 compatibility utilities |
| smart_open | 7.5.0 | Smart file opening (cloud storage support) |
| smmap | 5.0.2 | Sliding-window memory map manager |
| sniffio | 1.3.1 | Async library detection |
| sse-starlette | 3.2.0 | Server-Sent Events for Starlette |
| starlette | 0.52.1 | ASGI toolkit (FastAPI's foundation) |
| supervisor | 4.3.0 | Process control system |
| swanlab | 0.7.8 | Experiment tracking and visualization |
| sympy | 1.14.0 | Symbolic mathematics |
| tabulate | 0.9.0 | Tabular data formatting |
| tensordict | 0.11.0 | Tensor dictionary data structure |
| tiktoken | 0.12.0 | OpenAI's BPE tokenizer |
| tokenizers | 0.22.2 | Hugging Face fast tokenizers |
| torch | 2.9.1+cu128 | PyTorch deep learning framework (CUDA 12.8) |
| torchaudio | 2.9.1+cu128 | PyTorch audio processing |
| torchdata | 0.11.0 | PyTorch data loading utilities |
| torchvision | 0.24.1+cu128 | PyTorch vision library |
| tqdm | 4.67.3 | Progress bars |
| transformers | 4.56.2 | Hugging Face pretrained model library |
| triton | 3.5.1 | OpenAI Triton GPU compiler |
| typer | 0.23.1 | Type-hint-based CLI framework |
| typing_extensions | 4.15.0 | Backports of typing features |
| typing-inspection | 0.4.2 | Type inspection utilities |
| urllib3 | 2.6.3 | HTTP client |
| uvicorn | 0.40.0 | High-performance ASGI server |
| uvloop | 0.22.1 | Fast asyncio event loop |
| verl | 0.3.3.dev0 | RL framework (editable install: /autodl-fs/data/EasyR1) |
| virtualenv | 20.36.1 | Virtual environment creation |
| vllm | 0.15.1 | LLM inference and serving engine |
| wandb | 0.25.0 | Weights & Biases experiment tracking |
| watchfiles | 1.1.1 | File watching |
| wcwidth | 0.6.0 | Wide-character width calculation |
| websockets | 16.0 | WebSocket implementation |
| wheel | 0.46.3 | Python wheel package format |
| wrapt | 2.1.1 | Object wrapping and decorators |
| xgrammar | 0.1.29 | Structured generation grammar engine |
| xxhash | 3.6.0 | xxHash fast hashing |
| yarl | 1.22.0 | URL parsing |
| zipp | 3.23.0 | Zip file path traversal |
Environment highlights:
- Deep learning stack: PyTorch 2.9.1 (CUDA 12.8), Transformers, vLLM
- Inference optimization: FlashAttention, FlashInfer, Triton, xGrammar
- Distributed training: Ray
- LLM tooling: the full Hugging Face ecosystem, OpenAI/Anthropic API clients
- Web serving: FastAPI + Uvicorn, supporting model deployment
- CUDA ecosystem: the complete NVIDIA CUDA 12.8 toolchain
- Experiment tracking: wandb, swanlab
- RL framework: verl (EasyR1 project, editable install)
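- Before proceeding, you can spot-check a few of the critical pins against your live environment. A minimal sketch (the package selection is illustrative, not exhaustive; expected versions are taken from the table above):
```python
from importlib.metadata import PackageNotFoundError, version

# Key pins from the table above; extend as needed.
expected = {
    "torch": "2.9.1+cu128",
    "vllm": "0.15.1",
    "flash_attn": "2.8.3",
    "transformers": "4.56.2",
    "ray": "2.53.0",
}
for pkg, want in expected.items():
    try:
        got = version(pkg)
    except PackageNotFoundError:
        got = "MISSING"
    status = "OK" if got == want else "MISMATCH"
    print(f"{pkg:<14} expected={want:<14} installed={got:<14} {status}")
```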
2 EasyR1 Training Environment Setup
- Create the virtual environment and install PyTorch
```bash
# Create the virtual environment
conda create -n easy-r1 python=3.12 -y
# Install CUDA 12.8 + torch 2.9.1
conda activate easy-r1
pip install torch==2.9.1 torchvision==0.24.1 torchaudio==2.9.1 --index-url https://download.pytorch.org/whl/cu128
# Verify the installation
python -c "import torch; print(f'PyTorch: {torch.__version__}'); print(f'CUDA: {torch.version.cuda}'); print(f'GPU: {torch.cuda.get_device_name(0)}')"
```
- Clone EasyR1 and install the project dependencies
```bash
apt update
apt install git
# Download the EasyR1 project
git clone https://github.com/hiyouga/EasyR1.git
cd EasyR1
# Building FlashAttention requires packaging and ninja
pip install packaging ninja psutil
# Install FlashAttention-2
pip install flash-attn==2.8.3 --no-build-isolation
# Verify flash_attn
python -c "import flash_attn; print(flash_attn.__version__)"
# Install the project dependencies
pip install -e .
```
- Install the experiment-logging tool
```bash
pip install swanlab
# https://swanlab.cn/space/~
swanlab login
```
- Verification script for the key dependencies
```bash
vim test_flash_attn.py
```
```python
import torch
import flash_attn
print(f"PyTorch: {torch.__version__}")
print(f"CUDA: {torch.version.cuda}")
print(f"Flash Attention: {flash_attn.__version__}")
# Check that the flash_attn function can be imported
from flash_attn import flash_attn_func
print("Flash Attention imported successfully!")
```
```bash
python test_flash_attn.py
```
3 vllm_utils.py API Fix
The vLLM API used by /xxx/EasyR1/verl/utils/vllm_utils.py is outdated; below is a drop-in vllm_utils.py compatible with vLLM 0.15.1.
```python
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from importlib.metadata import version
from typing import List
from msgspec import field
from packaging import version as vs
from vllm.lora.lora_model import LoRAModel
from vllm.lora.request import LoRARequest
from vllm.lora.utils import get_adapter_absolute_path
from vllm.lora.worker_manager import LRUCacheWorkerLoRAManager
class TensorLoRARequest(LoRARequest):
peft_config: dict = field(default=None)
lora_tensors: dict = field(default=None)
class VLLMHijack:
@staticmethod
def hijack():
def hijack__load_adapter(self, lora_request: TensorLoRARequest) -> LoRAModel:
"""
based on vllm.lora.worker_manager.WorkerLoRAManager._load_adapter, support load adapter with lora tensors
Reason:
VLLM does not support adding LoRA from tensors directly. It only supports adding LoRA via file paths.
To synchronize the LoRA tensors of the actor model, we need to find a workaround to enable VLLM to load memory-based LoRA tensors.
"""
supported_lora_modules = self._adapter_manager.supported_lora_modules
packed_modules_mapping = self._adapter_manager.packed_modules_mapping
expected_lora_modules: List[str] = []
for module in supported_lora_modules:
if module in packed_modules_mapping:
expected_lora_modules.extend(packed_modules_mapping[module])
else:
expected_lora_modules.append(module)
expected_lora_modules = list(set(expected_lora_modules))
lora_tensors = None
from vllm.lora.peft_helper import PEFTHelper
if isinstance(lora_request, TensorLoRARequest):
peft_config = lora_request.peft_config
lora_tensors = lora_request.lora_tensors
peft_helper = PEFTHelper.from_dict(peft_config)
else:
lora_path = get_adapter_absolute_path(lora_request.lora_path)
peft_helper = PEFTHelper.from_local_dir(lora_path, self.max_position_embeddings)
# Validates the LoRA configuration against requirements before
# loading weights, throwing an exception if validation fails.
peft_helper.validate_legal(self.lora_config)
# For some models like Qwen2VL, we need to use hf_to_vllm_mapper
# to ensure correct loading of lora weights.
model = self._adapter_manager.model
hf_to_vllm_mapper = None
if hasattr(model, "hf_to_vllm_mapper") and model.hf_to_vllm_mapper is not None:
hf_to_vllm_mapper = model.hf_to_vllm_mapper
# vLLM 0.15.1+ compatibility: use getattr for embedding_modules and embedding_padding_modules
embedding_modules = getattr(self, 'embedding_modules', None)
embedding_padding_modules = getattr(self, 'embedding_padding_modules', None)
# Build kwargs dynamically based on vLLM version
kwargs = {
'lora_model_id': lora_request.lora_int_id,
'device': "cpu",
'dtype': self.lora_config.lora_dtype,
}
# Add optional parameters only if they exist in the method signature
import inspect
sig = inspect.signature(self._lora_model_cls.from_lora_tensors)
sig_params = list(sig.parameters.keys())
if 'tensors' in sig_params:
kwargs['tensors'] = lora_tensors
if 'peft_helper' in sig_params:
kwargs['peft_helper'] = peft_helper
if 'target_embedding_padding' in sig_params:
kwargs['target_embedding_padding'] = self.vocab_size + getattr(self.lora_config, 'lora_extra_vocab_size', 0)
if 'embedding_modules' in sig_params:
kwargs['embedding_modules'] = embedding_modules
if 'embedding_padding_modules' in sig_params:
kwargs['embedding_padding_modules'] = embedding_padding_modules
if 'weights_mapper' in sig_params:
kwargs['weights_mapper'] = hf_to_vllm_mapper
if isinstance(lora_request, TensorLoRARequest):
lora = self._lora_model_cls.from_lora_tensors(**kwargs)
else:
# For from_local_checkpoint, build different kwargs
local_kwargs = {
'lora_path': lora_path,
'expected_lora_modules': expected_lora_modules,
'peft_helper': peft_helper,
'lora_model_id': lora_request.lora_int_id,
'device': "cpu",
'dtype': self.lora_config.lora_dtype,
}
# Check which parameters are supported
local_sig = inspect.signature(self._lora_model_cls.from_local_checkpoint)
local_sig_params = list(local_sig.parameters.keys())
if 'target_embedding_padding' in local_sig_params:
local_kwargs['target_embedding_padding'] = self.vocab_size + getattr(self.lora_config, 'lora_extra_vocab_size', 0)
if 'embedding_modules' in local_sig_params:
local_kwargs['embedding_modules'] = embedding_modules
if 'embedding_padding_modules' in local_sig_params:
local_kwargs['embedding_padding_modules'] = embedding_padding_modules
if 'weights_mapper' in local_sig_params:
local_kwargs['weights_mapper'] = hf_to_vllm_mapper
lora = self._lora_model_cls.from_local_checkpoint(**local_kwargs)
# vLLM 0.15.1+ compatibility: extra_vocab_size may not exist
lora_extra_vocab_size = getattr(lora, 'extra_vocab_size', 0)
config_extra_vocab_size = getattr(self.lora_config, 'lora_extra_vocab_size', 0)
if lora_extra_vocab_size > config_extra_vocab_size:
raise ValueError(
f"LoRA added vocab size {lora_extra_vocab_size} "
f"is greater than lora_extra_vocab_size "
f"{config_extra_vocab_size}."
)
return lora
setattr(LRUCacheWorkerLoRAManager, "_load_adapter", hijack__load_adapter)
if vs.parse(version("vllm")).base_version == "0.11.0":
from vllm.model_executor.models.module_mapping import MultiModelKeys
from vllm.model_executor.models.qwen3_vl import Qwen3VLForConditionalGeneration
def hijack__get_mm_mapping(self) -> MultiModelKeys:
"""
Patch vllm.model_executor.models.qwen3_vl.Qwen3VLForConditionalGeneration.get_mm_mapping in vLLM 0.11.0
Reason:
vLLM 0.11.0 uses "model.visual.*" prefixes for Qwen3-VL, but the real module names are "visual.*".
This breaks LoRA filtering for multimodal parts, so we align the prefixes to the real module names.
Fixed upstream: https://github.com/vllm-project/vllm/commit/9f4e309
"""
return MultiModelKeys.from_string_field(
language_model="language_model",
connector="visual.merger.",
tower_model="visual.",
)
            setattr(Qwen3VLForConditionalGeneration, "get_mm_mapping", hijack__get_mm_mapping)
```
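- For reference, a minimal usage sketch of the patch above. The adapter payload is illustrative: in EasyR1 the `peft_config` and `lora_tensors` come from the actor model, and the field names follow the `TensorLoRARequest` definition above.
```python
import torch
from verl.utils.vllm_utils import TensorLoRARequest, VLLMHijack

# Apply the monkey patch once, before any vLLM engine or worker is created.
VLLMHijack.hijack()

# Illustrative adapter payload; real values come from the trained actor model.
peft_config = {"r": 32, "lora_alpha": 64, "target_modules": ["q_proj", "v_proj"]}
lora_tensors = {}  # parameter name -> torch.Tensor, e.g. "...q_proj.lora_A.weight"

request = TensorLoRARequest(
    lora_name="actor_lora",
    lora_int_id=1,
    lora_path="in_memory",  # placeholder; ignored for tensor-based requests
    peft_config=peft_config,
    lora_tensors=lora_tensors,
)
# `request` can then be passed wherever vLLM accepts a LoRARequest,
# e.g. llm.generate(..., lora_request=request).
```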
4 Training Dataset Format
- Only the dataset format EasyR1 expects is given here (a JSONL file, one object per line); collecting the data and converting it into this format is left to the reader.
```json
{"problem": "xxx", "answer": "xxx"}
```
5 Reward Function Configuration (for dual H800)
- Create EasyR1/easyr1/reward/pentest_reward.py:
```python
import re
from typing import List, Dict, Any
# Regex patterns (matching the original project)
THINK_PATTERN = re.compile(r"<think>(.*?)</think>", re.DOTALL | re.IGNORECASE)
STEP_PATTERN = re.compile(r"===\s*Step\s*\d+\s*===\s*Thought:.*?Command:", re.DOTALL | re.IGNORECASE)
POST_THINK_PATTERN = re.compile(r"</think>(.*)", re.DOTALL)
STEP_NUM_PATTERN = re.compile(r"===\s*Step\s*(\d+)\s*===", re.IGNORECASE)
COMMAND_PATTERN = re.compile(r"Command:\s*(.*?)(?:\n|$)", re.DOTALL | re.IGNORECASE)
class PentestRewardFunction:
"""
渗透测试 GRPO 奖励函数
包含两个奖励组件:
1. format_reward: 评估输出格式是否符合要求
2. accuracy_reward: 评估内容与参考答案的匹配程度
"""
def __init__(
self,
format_weight: float = 0.3,
accuracy_weight: float = 0.7,
**kwargs
):
"""
初始化奖励函数
Args:
format_weight: 格式奖励权重
accuracy_weight: 准确性奖励权重
"""
self.format_weight = format_weight
self.accuracy_weight = accuracy_weight
def __call__(
self,
prompts: List[str],
completions: List[str],
**kwargs
) -> Dict[str, List[float]]:
"""
计算奖励值(EasyR1 接口)
Args:
prompts: 提示列表
completions: 模型生成的完成列表
kwargs: 包含 answer 等额外信息
Returns:
Dict 包含各个奖励分量
"""
# 获取参考答案
answers = kwargs.get("answer", [""] * len(completions))
rewards = []
format_rewards = []
accuracy_rewards = []
        for prompt, completion, answer in zip(prompts, completions, answers):
            # Format reward
            format_score = self._compute_format_reward(completion)
            format_rewards.append(format_score)
            # Accuracy reward
            accuracy_score = self._compute_accuracy_reward(completion, answer)
            accuracy_rewards.append(accuracy_score)
            # Weighted sum
            total_score = (
self.format_weight * format_score +
self.accuracy_weight * accuracy_score
)
rewards.append(total_score)
return {
"rewards": rewards,
"format_reward": format_rewards,
"accuracy_reward": accuracy_rewards
}
def _compute_format_reward(self, completion: str) -> float:
"""
格式奖励:评估输出是否符合预期格式
评分标准:
- 包含思考部分: +0.3
- 包含正确步骤格式: +0.3
- 满分: 0.6(原项目设计)
"""
score = 0.0
# 检查思考部分
think_match = THINK_PATTERN.search(completion)
if think_match and think_match.group(1).strip():
score += 0.3
# 检查步骤格式
if STEP_PATTERN.search(completion):
score += 0.3
return score
def _compute_accuracy_reward(self, completion: str, answer: str) -> float:
"""
准确性奖励:评估内容与参考答案的匹配度
评分标准:
- 步骤编号正确: +0.2
- 命令完全匹配: +1.0
- 命令部分匹配 (Jaccard > 0.5): +0.7 * similarity
"""
score = 0.0
# 提取思考后的内容
match_format = POST_THINK_PATTERN.search(completion)
if not match_format:
return score # 没有思考部分,准确性为0
analysis_content = match_format.group(1).strip()
# 检查步骤编号匹配
gen_step_match = STEP_NUM_PATTERN.search(analysis_content)
true_step_match = STEP_NUM_PATTERN.search(answer)
if (gen_step_match and true_step_match and
gen_step_match.group(1) == true_step_match.group(1)):
score += 0.2
        # Extract the commands
        gen_command_match = COMMAND_PATTERN.search(analysis_content)
        true_command_match = COMMAND_PATTERN.search(answer)
        gen_command = gen_command_match.group(1).strip() if gen_command_match else ""
        true_command = true_command_match.group(1).strip() if true_command_match else ""
        if not true_command:
            return score
        # Exact command match
        if gen_command == true_command:
            score += 1.0
        elif gen_command:
            # Jaccard similarity
gen_words = set(gen_command.split())
true_words = set(true_command.split())
intersection = gen_words & true_words
union_size = len(gen_words) + len(true_words) - len(intersection)
if union_size > 0:
similarity = len(intersection) / union_size
if similarity > 0.5:
score += similarity * 0.7
        return min(score, 1.5)  # defensive cap; the achievable maximum is 1.2 (0.2 + 1.0)
# Global reward-function instance (used by EasyR1)
_reward_fn_instance = PentestRewardFunction()
def create_reward_function(reward_inputs):
"""
EasyR1 奖励函数入口点
Args:
reward_inputs: List[Dict],每个 dict 包含:
- response: 模型生成的文本
- ground_truth: 参考答案
Returns:
List[Dict[str, float]]: 每个输入的奖励分数
"""
prompts = []
completions = []
answers = []
for item in reward_inputs:
prompts.append("")
completions.append(item.get("response", ""))
answers.append(item.get("ground_truth", ""))
result = _reward_fn_instance(prompts, completions, answer=answers)
scores = []
for i in range(len(reward_inputs)):
scores.append({
"overall": result["rewards"][i],
"format": result["format_reward"][i],
"accuracy": result["accuracy_reward"][i],
})
return scores
REWARD_NAME = "pentest_reward"
REWARD_TYPE = "batch"
```
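- Before wiring the file into training, you can smoke-test it offline. A minimal sketch, assuming it is run from the reward directory so that `pentest_reward` is importable (the strings are illustrative):
```python
from pentest_reward import create_reward_function

reward_inputs = [
    {
        "response": (
            "<think>enumerate services first</think>\n"
            "=== Step 1 ===\nThought: scan the host.\nCommand: nmap -sV 10.0.0.5"
        ),
        "ground_truth": "=== Step 1 ===\nThought: scan.\nCommand: nmap -sV 10.0.0.5",
    }
]
print(create_reward_function(reward_inputs))
# expected: format 0.6, accuracy 1.2, overall ~1.02 (0.3*0.6 + 0.7*1.2)
```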
- Register the reward function: edit EasyR1/easyr1/reward/__init__.py and add:
```python
from .pentest_reward import PentestRewardFunction, create_reward_function
__all__ = [
    # ... existing exports
    "PentestRewardFunction",
    "create_reward_function",
]
```
- Also create EasyR1/easyr1/__init__.py in the parent directory. (Since the training config below references the reward file by absolute path, this step may not actually be required.)
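- For context, the `reward_function` entry in the config below uses a `file-path:function` spec, which a loader can resolve straight from disk without any package import; a rough illustration of the idea (EasyR1's actual loader may differ in detail):
```python
import importlib.util

def load_reward_fn(spec: str):
    # Split "/path/to/file.py:function_name" and load the module from the file.
    path, func_name = spec.rsplit(":", 1)
    module_spec = importlib.util.spec_from_file_location("custom_reward", path)
    module = importlib.util.module_from_spec(module_spec)
    module_spec.loader.exec_module(module)
    return getattr(module, func_name)

reward_fn = load_reward_fn(
    "/xxx/EasyR1/easyr1/reward/pentest_reward.py:create_reward_function"
)
```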
6 Training Configuration File
- Create EasyR1/examples/pentest_grpo_h800_optimized.yaml, using verl's configuration format.
```yaml
data:
train_files: /xxx/EasyR1/data/pentest/pentest_grpo_train.jsonl
val_files: /xxx/EasyR1/data/pentest/pentest_grpo_eval.jsonl
prompt_key: problem
answer_key: answer
image_key: images
video_key: videos
image_dir: null
video_fps: 2.0
  # Context length settings (tuned from analysis of the actual data)
  # Stats: 12,679 samples; 95th percentile = 3,555 tokens; 99th percentile = 4,817 tokens
  # Set to 5120 prompt + 3072 response = 8192 total
  # - 5120 > the 99th percentile (4,817), so only ~0.3% of extremely long samples are filtered out
  # - 3072 leaves ample response room for more detailed thoughts and commands
max_prompt_length: 5120
max_response_length: 3072
  # Larger batch sizes for better GPU utilization
rollout_batch_size: 16
mini_rollout_batch_size: 8
val_batch_size: 8
format_prompt: null
override_chat_template: null
shuffle: true
seed: 1
min_pixels: 262144
max_pixels: 4194304
filter_overlong_prompts: true
algorithm:
adv_estimator: grpo
disable_kl: false
use_kl_loss: true
kl_penalty: low_var_kl
kl_coef: 1.0e-2
online_filtering: false
filter_key: overall
filter_low: 0.01
filter_high: 0.99
worker:
actor:
    # Larger global batch size
global_batch_size: 16
micro_batch_size_per_device_for_update: 4
micro_batch_size_per_device_for_experience: 4
max_grad_norm: 1.0
padding_free: true
dynamic_batching: true
ulysses_size: 1
model:
model_path: unsloth/DeepSeek-R1-0528-Qwen3-8B
enable_gradient_checkpointing: true
trust_remote_code: true
freeze_vision_tower: false
lora:
rank: 32
alpha: 64
target_modules: q_proj,k_proj,v_proj,o_proj,gate_proj,up_proj,down_proj
exclude_modules: null
optim:
lr: 5.0e-6
weight_decay: 0.01
strategy: adamw
lr_warmup_ratio: 0.0
fsdp:
enable_full_shard: true
enable_cpu_offload: false
enable_rank0_init: true
torch_dtype: bf16
offload:
      # Disable offload for speed (the H800s have ample VRAM)
offload_params: false
offload_optimizer: false
rollout:
n: 4
temperature: 0.6
top_p: 0.9
limit_images: 0
    # Higher VRAM utilization
gpu_memory_utilization: 0.75
    # Disable eager mode to enable CUDA graph acceleration
enforce_eager: false
    # Enable chunked prefill for better long-sequence efficiency
enable_chunked_prefill: true
tensor_parallel_size: 2
    # Larger max batched tokens to support bigger batches
max_num_batched_tokens: 8192
disable_tqdm: false
val_override_config:
temperature: 0.6
top_p: 0.9
n: 1
ref:
fsdp:
enable_full_shard: true
enable_cpu_offload: true
enable_rank0_init: true
torch_dtype: bf16
offload:
offload_params: false
reward:
reward_function: /xxx/EasyR1/easyr1/reward/pentest_reward.py:create_reward_function
trainer:
total_epochs: 2
max_steps: null
project_name: pentest
experiment_name: grpo-h800x2-optimized
logger: ["console", "swanlab"]
nnodes: 1
n_gpus_per_node: 2
max_try_make_batch: 20
  # Validation settings
  val_freq: 100  # validate every 100 steps (balances training speed against monitoring frequency)
  val_before_train: true  # validate before training to gauge the initial model
  val_only: false  # validation-only mode (for standalone evaluation)
  val_generations_to_log: 3  # log 3 generated samples per validation for inspecting model output
  # Checkpoint settings
  save_freq: 50  # save a checkpoint every 50 steps
  save_limit: 3  # keep at most 3 checkpoints; older ones are deleted automatically
  save_model_only: false  # false = full checkpoint (incl. optimizer state); true = model weights only
  save_checkpoint_path: /xxx/EasyR1/checkpoints/pentest-optimized  # checkpoint directory
  load_checkpoint_path: null  # set a path here to resume from a checkpoint
  find_last_checkpoint: true  # automatically pick up the latest checkpoint
```
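- A quick sanity check of how the batch-size knobs above compose (this reflects my reading of verl's semantics; treat the exact relationships as assumptions):
```python
# Values from the config above.
rollout_batch_size = 16  # prompts per rollout step
n = 4                    # responses sampled per prompt
global_batch_size = 16   # samples per actor optimizer step
micro_bsz_update = 4     # per-device micro batch for updates
n_gpus = 2

samples_per_step = rollout_batch_size * n                      # 64 trajectories
optim_steps = samples_per_step // global_batch_size            # 4 updates per rollout
grad_accum = global_batch_size // (micro_bsz_update * n_gpus)  # 2 accumulation steps
print(samples_per_step, optim_steps, grad_accum)               # 64 4 2
```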
7 Training Launch Script
- Create the launch script start_training.sh in the EasyR1 project root.
```bash
#!/bin/bash
# Pentest-R1 GRPO training launcher
# Runs in the background via nohup so training survives SSH disconnects
# Environment variables
export CUDA_VISIBLE_DEVICES=0,1
export PYTHONPATH=/root/autodl-fs/EasyR1:$PYTHONPATH
# Create the log and checkpoint directories
mkdir -p /root/autodl-fs/EasyR1/logs
mkdir -p /root/autodl-fs/EasyR1/checkpoints/pentest-optimized
# Timestamp for the log file
TIMESTAMP=$(date +"%Y%m%d_%H%M%S")
LOG_FILE="/root/autodl-fs/EasyR1/logs/training_${TIMESTAMP}.log"
echo "========================================"
echo "Starting Pentest-R1 GRPO training"
echo "Config file: examples/pentest_grpo_h800_optimized.yaml"
echo "Log file: ${LOG_FILE}"
echo "Start time: $(date)"
echo "========================================"
# Launch training with nohup
# 2>&1: redirect stderr to stdout
# &: run in the background
nohup python3 -m verl.trainer.main \
    config=examples/pentest_grpo_h800_optimized.yaml \
    trainer.n_gpus_per_node=2 \
    > "${LOG_FILE}" 2>&1 &
# Record the process ID
PID=$!
echo "Training PID: ${PID}"
echo "${PID}" > /root/autodl-fs/EasyR1/logs/training.pid
echo ""
echo "Training started in the background!"
echo ""
echo "Useful commands:"
echo "  Tail the log:       tail -f ${LOG_FILE}"
echo "  Check the process:  ps aux | grep verl.trainer"
echo "  Check the GPUs:     nvidia-smi"
echo "  Stop training:      kill ${PID}"
```
- Resuming after an interruption: with find_last_checkpoint: true in the config, re-running start_training.sh resumes from the most recent checkpoint; alternatively, set trainer.load_checkpoint_path explicitly.