Easy R1 Training Environment Setup and Configuration Guide (GRPO Algorithm)

Table of Contents

  • 0 References
  • 1 Virtual Environment Dependency Versions
  • 2 Easy R1 Training Environment Deployment
  • 3 vllm_utils.py API Fix
  • 4 Training Dataset Configuration
  • 5 Training Reward Function Configuration (for dual H800 GPUs)
  • 6 Training Configuration File
  • 7 Training Launch Script

0 References

1 Virtual Environment Dependency Versions

  • The dependencies and versions in the table below have been verified by the author to train successfully. The training issue caused by API changes in vllm 0.15.1 is addressed with compatibility code later in this article.
  • If you run into dependency compatibility problems during installation, consult this table; it should help you resolve environment issues quickly.
Package Version Description
accelerate 1.12.0 Hugging Face library for accelerated model training and inference
aiohappyeyeballs 2.6.1 Async DNS resolution library implementing the Happy Eyeballs algorithm
aiohttp 3.13.3 Asynchronous HTTP client/server framework
aiohttp-cors 0.8.1 Cross-Origin Resource Sharing (CORS) support for aiohttp
aiosignal 1.4.0 Async signal support library
annotated-doc 0.0.4 Documentation annotation helper
annotated-types 0.7.0 Extensions for Python type annotations
anthropic 0.79.0 Anthropic API client (Claude models)
antlr4-python3-runtime 4.9.3 Python runtime for ANTLR4 parsers
anyio 4.12.1 Async I/O abstraction layer compatible with asyncio/trio
apache-tvm-ffi 0.1.8.post2 FFI interface for the Apache TVM deep learning compiler
astor 0.8.1 Python AST (abstract syntax tree) manipulation library
attrs 25.4.0 Class decorators that simplify class definitions
av 16.1.0 PyAV multimedia processing library (FFmpeg bindings)
blake3 1.0.8 BLAKE3 hash algorithm implementation
cachetools 7.0.1 Extensible collection of caches (in-memory caching, etc.)
cbor2 5.8.0 CBOR (Concise Binary Object Representation) encoding/decoding
certifi 2026.1.4 Mozilla SSL certificate bundle
cffi 2.0.0 C Foreign Function Interface
charset-normalizer 3.4.4 Automatic character encoding detection
click 8.3.1 Toolkit for building command-line interfaces
cloudpickle 3.1.2 Extended pickle module supporting serialization of more types
codetiming 1.4.0 Code execution timing utility
colorful 0.5.8 Colored terminal output library
compressed-tensors 0.13.0 Compressed tensor data handling
conda-pack 0.8.1 Conda environment packaging tool
cryptography 46.0.5 Cryptographic algorithms and protocols
cuda-bindings 13.1.1 CUDA Python bindings
cuda-pathfinder 1.3.4 CUDA path discovery utility
cuda-python 13.1.1 NVIDIA CUDA Python interface
cupy-cuda12x 13.6.0 NumPy-compatible CUDA array computation library
datasets 4.5.0 Hugging Face dataset loading and processing library
depyf 0.20.0 Python decompiler
dill 0.4.0 Extended pickle that serializes more Python objects
diskcache 5.6.3 Disk and file-based caching library
distlib 0.4.0 Low-level tools for distribution packages
distro 1.9.0 Linux distribution information
dnspython 2.8.0 DNS toolkit
docstring_parser 0.17.0 Docstring parser
einops 0.8.2 Tensor operation library with Einstein-notation style operations
email-validator 2.3.0 Email address validation
fastapi 0.129.0 Modern high-performance web framework
fastapi-cli 0.0.21 FastAPI command-line tools
fastapi-cloud-cli 0.12.0 FastAPI cloud deployment CLI
fastar 0.8.0 Fast array operations library
fastrlock 0.8.3 Fast reentrant lock implementation
filelock 3.24.0 File lock implementation for concurrency control
flash_attn 2.8.3 FlashAttention fast attention implementation
flashinfer-python 0.6.1 LLM inference acceleration library (sampling, attention, etc.)
frozenlist 1.8.0 Immutable list implementation
fsspec 2025.10.0 Filesystem spec, a unified filesystem interface
gguf 0.17.1 GGUF format (GGML universal file format) handling
gitdb 4.0.12 Git object database
GitPython 3.1.46 Python library for Git operations
google-api-core 2.29.0 Core library for Google API clients
google-auth 2.48.0 Google authentication library
googleapis-common-protos 1.72.0 Common protocol buffers for Google APIs
grpcio 1.78.0 gRPC Python implementation
grpcio-reflection 1.78.0 gRPC server reflection
h11 0.16.0 Pure-Python HTTP/1.1 protocol implementation
hf-xet 1.2.0 Hugging Face XET storage backend
httpcore 1.0.9 HTTP core library
httptools 0.7.1 HTTP parsing tools
httpx 0.28.1 Modern asynchronous HTTP client
httpx-sse 0.4.3 Server-Sent Events (SSE) support for HTTPX
huggingface_hub 0.36.2 Hugging Face model and dataset management
idna 3.11 Internationalized domain name encoding
ijson 3.4.0.post0 Iterative JSON parser
importlib_metadata 8.7.1 Import metadata reading (backport for older Python)
interegular 0.3.3 Regular expression intersection computation
Jinja2 3.1.6 Template engine
jiter 0.13.0 Fast JSON parser
jmespath 1.1.0 JSON query language
jsonschema 4.26.0 JSON Schema validation
jsonschema-specifications 2025.9.1 JSON Schema specification data
lark 1.2.2 Parsing library supporting multiple grammars
liger_kernel 0.7.0 Optimized kernels for LLM training (fused operators)
llguidance 1.3.0 Structured-generation constraint library for LLMs
llvmlite 0.44.0 Lightweight Python bindings for LLVM
lm-format-enforcer 0.11.3 Output format enforcement for language models
loguru 0.7.3 Modern logging library
markdown-it-py 4.0.0 Markdown parser
MarkupSafe 2.1.5 Safe strings for XML/HTML/XHTML markup
mathruler 0.1.0 Math rule processing
mcp 1.26.0 Model Context Protocol implementation
mdurl 0.1.2 Markdown URL handling
mistral_common 1.9.1 Mistral AI common utilities
model-hosting-container-standards 0.1.13 Model hosting container standards
mpmath 1.3.0 Multi-precision math library
msgpack 1.1.2 MessagePack serialization format
msgspec 0.20.0 High-performance serialization/validation library
multidict 6.7.1 Multi-value dictionary implementation
multiprocess 0.70.18 Multiprocessing library (with dill support)
networkx 3.6.1 Complex network analysis library
ninja 1.13.0 Ninja build system
numba 0.61.2 JIT compiler accelerating numerical Python code
numpy 2.2.6 Fundamental numerical computing library
nvidia-cublas-cu12 12.8.4.1 NVIDIA cuBLAS linear algebra library (CUDA 12)
nvidia-cuda-cupti-cu12 12.8.90 NVIDIA CUDA Profiling Tools Interface
nvidia-cuda-nvrtc-cu12 12.8.93 NVIDIA CUDA runtime compilation
nvidia-cuda-runtime-cu12 12.8.90 NVIDIA CUDA runtime
nvidia-cudnn-cu12 9.10.2.21 NVIDIA cuDNN deep learning library
nvidia-cudnn-frontend 1.18.0 cuDNN frontend API
nvidia-cufft-cu12 11.3.3.83 NVIDIA cuFFT fast Fourier transforms
nvidia-cufile-cu12 1.13.1.3 NVIDIA cuFile GPUDirect Storage
nvidia-curand-cu12 10.3.9.90 NVIDIA cuRAND random number generation
nvidia-cusolver-cu12 11.7.3.90 NVIDIA cuSolver dense and sparse solvers
nvidia-cusparse-cu12 12.5.8.93 NVIDIA cuSPARSE sparse matrix library
nvidia-cusparselt-cu12 0.7.1 NVIDIA cuSPARSELt sparse matrix library
nvidia-cutlass-dsl 4.4.0 NVIDIA CUTLASS DSL
nvidia-cutlass-dsl-libs-base 4.4.0 CUTLASS base libraries
nvidia-ml-py 13.590.48 Python bindings for the NVIDIA Management Library
nvidia-nccl-cu12 2.27.5 NVIDIA NCCL multi-GPU communication library
nvidia-nvjitlink-cu12 12.8.93 NVIDIA JIT linking library
nvidia-nvshmem-cu12 3.3.20 NVIDIA NVSHMEM shared memory
nvidia-nvtx-cu12 12.8.90 NVIDIA NVTX annotation tools
omegaconf 2.3.0 Hierarchical configuration management
openai 2.21.0 OpenAI API client
openai-harmony 0.0.8 OpenAI Harmony tooling
opencensus 0.11.4 Distributed tracing and metrics collection
opencensus-context 0.1.3 OpenCensus context propagation
opencv-python-headless 4.13.0.92 OpenCV computer vision library (no GUI)
opentelemetry-api 1.39.1 OpenTelemetry API
opentelemetry-exporter-prometheus 0.60b1 OpenTelemetry Prometheus exporter
opentelemetry-proto 1.39.1 OpenTelemetry protocol buffers
opentelemetry-sdk 1.39.1 OpenTelemetry SDK
opentelemetry-semantic-conventions 0.60b1 OpenTelemetry semantic conventions
orjson 3.11.7 High-performance JSON library
outlines_core 0.2.11 Structured generation core library
packaging 25.0 Package version and dependency handling
pandas 3.0.0 Data analysis and manipulation library
partial-json-parser 0.2.1.1.post7 Partial JSON parser
peft 0.18.1 Parameter-efficient fine-tuning library
pillow 12.0.0 Image processing library
pip 26.0.1 Python package installer
platformdirs 4.9.1 Cross-platform directory path resolution
prettytable 3.17.0 Pretty ASCII table generation
prometheus_client 0.24.1 Prometheus monitoring client
prometheus-fastapi-instrumentator 7.1.0 Prometheus metrics collection for FastAPI
propcache 0.4.1 Property caching utility
proto-plus 1.27.1 Protocol Buffers enhancements
protobuf 6.33.5 Protocol Buffers serialization
psutil 7.2.2 System and process monitoring
py-cpuinfo 9.0.0 CPU information retrieval
py-spy 0.4.1 Sampling profiler for Python
pyarrow 23.0.0 Apache Arrow columnar data format
pyasn1 0.6.2 ASN.1 encoding/decoding
pyasn1_modules 0.4.2 Collection of ASN.1 modules
pybase64 1.4.3 Accelerated Base64 encoding/decoding
pycountry 24.6.1 Country/region data
pycparser 3.0 C parser
pydantic 2.12.5 Data validation and settings management
pydantic_core 2.41.5 Pydantic core functionality
pydantic-extra-types 2.11.0 Extra types for Pydantic
pydantic-settings 2.12.0 Pydantic settings management
pyecharts 2.1.0 Python bindings for ECharts charts
Pygments 2.19.2 Syntax highlighting library
PyJWT 2.11.0 JSON Web Token implementation
pylatexenc 2.10 LaTeX encoding handling
python-dateutil 2.9.0.post0 Extensions for date/time handling
python-dotenv 1.2.1 Loads environment variables from .env files
python-json-logger 4.0.0 JSON-formatted logging
python-multipart 0.0.22 Multipart form data parsing
pyvers 0.2.2 Python version management
PyYAML 6.0.3 YAML parsing and generation
pyzmq 27.1.0 ZeroMQ Python bindings
qwen-vl-utils 0.0.14 Utilities for Qwen vision-language models
ray 2.53.0 Distributed computing and machine learning framework
referencing 0.37.0 JSON reference resolution
regex 2026.1.15 Enhanced regular expressions
requests 2.32.5 HTTP request library
rich 13.9.4 Rich text and pretty terminal output
rich-toolkit 0.19.4 Rich toolkit
rignore 0.7.6 Gitignore handling
rpds-py 0.30.0 Python bindings for Rust persistent data structures
rsa 4.9.1 RSA encryption implementation
safetensors 0.7.0 Safe tensor serialization format
sentencepiece 0.2.1 Text tokenization library
sentry-sdk 2.52.0 Sentry error monitoring SDK
setproctitle 1.3.7 Process title setting
setuptools 80.10.2 Python package building and installation
shellingham 1.5.4 Detects the current shell
simplejson 3.20.2 JSON encoder/decoder
six 1.17.0 Python 2/3 compatibility utilities
smart_open 7.5.0 Smart file opening (supports cloud storage)
smmap 5.0.2 Sliding-window memory map manager
sniffio 1.3.1 Async library detection
sse-starlette 3.2.0 Server-Sent Events for Starlette
starlette 0.52.1 ASGI toolkit (foundation of FastAPI)
supervisor 4.3.0 Process control system
swanlab 0.7.8 Experiment tracking and visualization
sympy 1.14.0 Symbolic mathematics library
tabulate 0.9.0 Tabular data formatting
tensordict 0.11.0 Tensor dictionary data structure
tiktoken 0.12.0 OpenAI BPE tokenizer
tokenizers 0.22.2 Hugging Face fast tokenizers
torch 2.9.1+cu128 PyTorch deep learning framework (CUDA 12.8)
torchaudio 2.9.1+cu128 PyTorch audio processing
torchdata 0.11.0 PyTorch data loading utilities
torchvision 0.24.1+cu128 PyTorch vision library
tqdm 4.67.3 Progress bar library
transformers 4.56.2 Hugging Face pretrained model library
triton 3.5.1 OpenAI Triton GPU compiler
typer 0.23.1 Type-hint-based CLI framework
typing_extensions 4.15.0 Backports of typing features
typing-inspection 0.4.2 Type inspection utilities
urllib3 2.6.3 HTTP client
uvicorn 0.40.0 High-performance ASGI server
uvloop 0.22.1 Fast asyncio event loop
verl 0.3.3.dev0 RL framework (editable install: /autodl-fs/data/EasyR1)
virtualenv 20.36.1 Virtual environment creation tool
vllm 0.15.1 Large language model inference and serving engine
wandb 0.25.0 Weights & Biases experiment tracking
watchfiles 1.1.1 File watching utility
wcwidth 0.6.0 Wide character width calculation
websockets 16.0 WebSocket implementation
wheel 0.46.3 Python wheel package format
wrapt 2.1.1 Object wrapping and decorators
xgrammar 0.1.29 Grammar engine for structured generation
xxhash 3.6.0 xxHash fast hashing algorithm
yarl 1.22.0 URL parsing library
zipp 3.23.0 Path traversal in zip archives

Environment summary:

  • Deep learning framework: PyTorch 2.9.1 (CUDA 12.8), Transformers, vLLM
  • Inference optimization: FlashAttention, FlashInfer, Triton, xGrammar
  • Distributed training: Ray and related components
  • LLM tooling: the full Hugging Face ecosystem, OpenAI/Anthropic API clients
  • Web serving: FastAPI + Uvicorn, supporting model deployment
  • CUDA ecosystem: the full NVIDIA CUDA 12.8 toolchain
  • Experiment tracking: wandb, swanlab, Ray Tune
  • RL framework: verl (EasyR1 project, editable install)
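
To check whether a freshly built environment matches the versions above, a small script like the following can print the installed versions of the core packages. This is a minimal sketch: the package list is illustrative and can be extended to anything from the table.
python
# check_versions.py - print installed versions of the core training stack
from importlib.metadata import version, PackageNotFoundError

packages = [
    "torch", "torchvision", "torchaudio", "transformers", "vllm",
    "flash_attn", "ray", "peft", "datasets", "swanlab", "verl",
]

for name in packages:
    try:
        print(f"{name:20s} {version(name)}")
    except PackageNotFoundError:
        print(f"{name:20s} NOT INSTALLED")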

2 Easy R1 Training Environment Deployment

  • Create the virtual environment and install PyTorch
bash
# Create the virtual environment
conda create -n easy-r1 python=3.12 -y
# Install CUDA 12.8 + torch 2.9.1
conda activate easy-r1
pip install torch==2.9.1 torchvision==0.24.1 torchaudio==2.9.1 --index-url https://download.pytorch.org/whl/cu128
# Verify the installation
python -c "import torch; print(f'PyTorch: {torch.__version__}'); print(f'CUDA: {torch.version.cuda}'); print(f'GPU: {torch.cuda.get_device_name(0)}')"
bash
apt update
# Download the EasyR1 project
apt install git
git clone https://github.com/hiyouga/EasyR1.git
cd EasyR1

# Flash Attention build dependencies: packaging, ninja (psutil is used by its setup)
pip install packaging ninja psutil
# Install Flash Attention 2
pip install flash-attn==2.8.3 --no-build-isolation
# Verify flash_attn
python -c "import flash_attn; print(flash_attn.__version__)"
# Install project dependencies
pip install -e .
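
After the editable install, a quick import check confirms that verl and vllm resolve from the same environment. This is a minimal sketch, assuming the easy-r1 environment is active.
bash
# Sanity check: the editable verl install and vllm import together
python -c "import verl, vllm; print('verl imported OK, vllm', vllm.__version__)"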
  • Install the logging/monitoring tool
bash
pip install swanlab
# https://swanlab.cn/space/~
swanlab login
  • Verification script for the main dependencies
bash
vim  test_flash_attn.py
python
import torch
import flash_attn

print(f"PyTorch: {torch.__version__}")
print(f"CUDA: {torch.version.cuda}")
print(f"Flash Attention: {flash_attn.__version__}")

# Check that the flash_attn function can be imported
from flash_attn import flash_attn_func
print("Flash Attention imported successfully!")
bash
python test_flash_attn.py

3 vllm_utils.py API Fix

  • /xxx/EasyR1/verl/utils/vllm_utils.py uses an outdated vLLM API. Below is a replacement vllm_utils.py that is compatible with vllm 0.15.1.
python
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from importlib.metadata import version
from typing import List

from msgspec import field
from packaging import version as vs
from vllm.lora.lora_model import LoRAModel
from vllm.lora.request import LoRARequest
from vllm.lora.utils import get_adapter_absolute_path
from vllm.lora.worker_manager import LRUCacheWorkerLoRAManager


class TensorLoRARequest(LoRARequest):
    peft_config: dict = field(default=None)
    lora_tensors: dict = field(default=None)


class VLLMHijack:
    @staticmethod
    def hijack():
        def hijack__load_adapter(self, lora_request: TensorLoRARequest) -> LoRAModel:
            """
            based on vllm.lora.worker_manager.WorkerLoRAManager._load_adapter, support load adapter with lora tensors
            Reason:
            VLLM does not support adding LoRA from tensors directly. It only supports adding LoRA via file paths.
            To synchronize the LoRA tensors of the actor model, we need to find a workaround to enable VLLM to load memory-based LoRA tensors.
            """
            supported_lora_modules = self._adapter_manager.supported_lora_modules
            packed_modules_mapping = self._adapter_manager.packed_modules_mapping
            expected_lora_modules: List[str] = []
            for module in supported_lora_modules:
                if module in packed_modules_mapping:
                    expected_lora_modules.extend(packed_modules_mapping[module])
                else:
                    expected_lora_modules.append(module)

            expected_lora_modules = list(set(expected_lora_modules))

            lora_tensors = None
            from vllm.lora.peft_helper import PEFTHelper

            if isinstance(lora_request, TensorLoRARequest):
                peft_config = lora_request.peft_config
                lora_tensors = lora_request.lora_tensors
                peft_helper = PEFTHelper.from_dict(peft_config)
            else:
                lora_path = get_adapter_absolute_path(lora_request.lora_path)

                peft_helper = PEFTHelper.from_local_dir(lora_path, self.max_position_embeddings)

            # Validates the LoRA configuration against requirements before
            # loading weights, throwing an exception if validation fails.
            peft_helper.validate_legal(self.lora_config)

            # For some models like Qwen2VL, we need to use hf_to_vllm_mapper
            # to ensure correct loading of lora weights.
            model = self._adapter_manager.model
            hf_to_vllm_mapper = None
            if hasattr(model, "hf_to_vllm_mapper") and model.hf_to_vllm_mapper is not None:
                hf_to_vllm_mapper = model.hf_to_vllm_mapper

            # vLLM 0.15.1+ compatibility: use getattr for embedding_modules and embedding_padding_modules
            embedding_modules = getattr(self, 'embedding_modules', None)
            embedding_padding_modules = getattr(self, 'embedding_padding_modules', None)

            # Build kwargs dynamically based on vLLM version
            kwargs = {
                'lora_model_id': lora_request.lora_int_id,
                'device': "cpu",
                'dtype': self.lora_config.lora_dtype,
            }

            # Add optional parameters only if they exist in the method signature
            import inspect
            sig = inspect.signature(self._lora_model_cls.from_lora_tensors)
            sig_params = list(sig.parameters.keys())

            if 'tensors' in sig_params:
                kwargs['tensors'] = lora_tensors
            if 'peft_helper' in sig_params:
                kwargs['peft_helper'] = peft_helper
            if 'target_embedding_padding' in sig_params:
                kwargs['target_embedding_padding'] = self.vocab_size + getattr(self.lora_config, 'lora_extra_vocab_size', 0)
            if 'embedding_modules' in sig_params:
                kwargs['embedding_modules'] = embedding_modules
            if 'embedding_padding_modules' in sig_params:
                kwargs['embedding_padding_modules'] = embedding_padding_modules
            if 'weights_mapper' in sig_params:
                kwargs['weights_mapper'] = hf_to_vllm_mapper

            if isinstance(lora_request, TensorLoRARequest):
                lora = self._lora_model_cls.from_lora_tensors(**kwargs)
            else:
                # For from_local_checkpoint, build different kwargs
                local_kwargs = {
                    'lora_path': lora_path,
                    'expected_lora_modules': expected_lora_modules,
                    'peft_helper': peft_helper,
                    'lora_model_id': lora_request.lora_int_id,
                    'device': "cpu",
                    'dtype': self.lora_config.lora_dtype,
                }
                # Check which parameters are supported
                local_sig = inspect.signature(self._lora_model_cls.from_local_checkpoint)
                local_sig_params = list(local_sig.parameters.keys())

                if 'target_embedding_padding' in local_sig_params:
                    local_kwargs['target_embedding_padding'] = self.vocab_size + getattr(self.lora_config, 'lora_extra_vocab_size', 0)
                if 'embedding_modules' in local_sig_params:
                    local_kwargs['embedding_modules'] = embedding_modules
                if 'embedding_padding_modules' in local_sig_params:
                    local_kwargs['embedding_padding_modules'] = embedding_padding_modules
                if 'weights_mapper' in local_sig_params:
                    local_kwargs['weights_mapper'] = hf_to_vllm_mapper

                lora = self._lora_model_cls.from_local_checkpoint(**local_kwargs)

            # vLLM 0.15.1+ compatibility: extra_vocab_size may not exist
            lora_extra_vocab_size = getattr(lora, 'extra_vocab_size', 0)
            config_extra_vocab_size = getattr(self.lora_config, 'lora_extra_vocab_size', 0)
            if lora_extra_vocab_size > config_extra_vocab_size:
                raise ValueError(
                    f"LoRA added vocab size {lora_extra_vocab_size} "
                    f"is greater than lora_extra_vocab_size "
                    f"{config_extra_vocab_size}."
                )
            return lora

        setattr(LRUCacheWorkerLoRAManager, "_load_adapter", hijack__load_adapter)

        if vs.parse(version("vllm")).base_version == "0.11.0":
            from vllm.model_executor.models.module_mapping import MultiModelKeys
            from vllm.model_executor.models.qwen3_vl import Qwen3VLForConditionalGeneration

            def hijack__get_mm_mapping(self) -> MultiModelKeys:
                """
                Patch vllm.model_executor.models.qwen3_vl.Qwen3VLForConditionalGeneration.get_mm_mapping in vLLM 0.11.0
                Reason:
                vLLM 0.11.0 uses "model.visual.*" prefixes for Qwen3-VL, but the real module names are "visual.*".
                This breaks LoRA filtering for multimodal parts, so we align the prefixes to the real module names.
                Fixed upstream: https://github.com/vllm-project/vllm/commit/9f4e309
                """
                return MultiModelKeys.from_string_field(
                    language_model="language_model",
                    connector="visual.merger.",
                    tower_model="visual.",
                )

            setattr(Qwen3VLForConditionalGeneration, "get_mm_mapping", hijack__get_mm_mapping)
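
To confirm that the replacement file imports cleanly and the patch applies, a quick check like the following can be run inside the easy-r1 environment. This is a minimal sketch; it only imports the module and calls the patch entry point, and does not show the call site inside EasyR1.
python
# Quick import/patch check for the replacement vllm_utils.py
from verl.utils.vllm_utils import VLLMHijack, TensorLoRARequest

VLLMHijack.hijack()  # installs the patched _load_adapter on LRUCacheWorkerLoRAManager
print("vllm_utils patch applied; TensorLoRARequest available:", TensorLoRARequest is not None)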

4 Training Dataset Configuration

  • Only the dataset format required by EasyR1 is shown here; collecting your own data and converting it into this format is left to the reader (a conversion sketch follows the example below).
json
{"problem": "xxx", "answer": "xxx"}

5 Training Reward Function Configuration (for dual H800 GPUs)

  • Create EasyR1/easyr1/reward/pentest_reward.py
python
import re
from typing import List, Dict, Any

# Regex patterns (identical to the original project)
THINK_PATTERN = re.compile(r"</think>(.*?)</think>", re.DOTALL | re.IGNORECASE)
STEP_PATTERN = re.compile(r"===\s*Step\s*\d+\s*===\s*Thought:.*?Command:", re.DOTALL | re.IGNORECASE)
POST_THINK_PATTERN = re.compile(r"</think>(.*)", re.DOTALL)
STEP_NUM_PATTERN = re.compile(r"===\s*Step\s*(\d+)\s*===", re.IGNORECASE)
COMMAND_PATTERN = re.compile(r"Command:\s*(.*?)(?:\n|$)", re.DOTALL | re.IGNORECASE)


class PentestRewardFunction:
    """
    渗透测试 GRPO 奖励函数
    
    包含两个奖励组件:
    1. format_reward: 评估输出格式是否符合要求
    2. accuracy_reward: 评估内容与参考答案的匹配程度
    """
    
    def __init__(
        self,
        format_weight: float = 0.3,
        accuracy_weight: float = 0.7,
        **kwargs
    ):
        """
        初始化奖励函数
        
        Args:
            format_weight: 格式奖励权重
            accuracy_weight: 准确性奖励权重
        """
        self.format_weight = format_weight
        self.accuracy_weight = accuracy_weight
        
    def __call__(
        self,
        prompts: List[str],
        completions: List[str],
        **kwargs
    ) -> Dict[str, List[float]]:
        """
        计算奖励值(EasyR1 接口)
        
        Args:
            prompts: 提示列表
            completions: 模型生成的完成列表
            kwargs: 包含 answer 等额外信息
        
        Returns:
            Dict 包含各个奖励分量
        """
        # 获取参考答案
        answers = kwargs.get("answer", [""] * len(completions))
        
        rewards = []
        format_rewards = []
        accuracy_rewards = []
        
        for prompt, completion, answer in zip(prompts, completions, answers):
            # Compute the format reward
            format_score = self._compute_format_reward(completion)
            format_rewards.append(format_score)
            
            # Compute the accuracy reward
            accuracy_score = self._compute_accuracy_reward(completion, answer)
            accuracy_rewards.append(accuracy_score)
            
            # Weighted sum
            total_score = (
                self.format_weight * format_score +
                self.accuracy_weight * accuracy_score
            )
            rewards.append(total_score)
        
        return {
            "rewards": rewards,
            "format_reward": format_rewards,
            "accuracy_reward": accuracy_rewards
        }
    
    def _compute_format_reward(self, completion: str) -> float:
        """
        格式奖励:评估输出是否符合预期格式
        
        评分标准:
        - 包含思考部分: +0.3
        - 包含正确步骤格式: +0.3
        - 满分: 0.6(原项目设计)
        """
        score = 0.0
        
        # Check the thinking section
        think_match = THINK_PATTERN.search(completion)
        if think_match and think_match.group(1).strip():
            score += 0.3
        
        # Check the step format
        if STEP_PATTERN.search(completion):
            score += 0.3
        
        return score
    
    def _compute_accuracy_reward(self, completion: str, answer: str) -> float:
        """
        准确性奖励:评估内容与参考答案的匹配度
        
        评分标准:
        - 步骤编号正确: +0.2
        - 命令完全匹配: +1.0
        - 命令部分匹配 (Jaccard > 0.5): +0.7 * similarity
        """
        score = 0.0
        
        # Extract the content after the thinking section
        match_format = POST_THINK_PATTERN.search(completion)
        if not match_format:
            return score  # no thinking section, so accuracy is 0
        
        analysis_content = match_format.group(1).strip()
        
        # Check the step number match
        gen_step_match = STEP_NUM_PATTERN.search(analysis_content)
        true_step_match = STEP_NUM_PATTERN.search(answer)
        
        if (gen_step_match and true_step_match and 
            gen_step_match.group(1) == true_step_match.group(1)):
            score += 0.2
        
        # Extract the commands
        gen_command_match = COMMAND_PATTERN.search(analysis_content)
        true_command_match = COMMAND_PATTERN.search(answer)
        
        gen_command = gen_command_match.group(1).strip() if gen_command_match else ""
        true_command = true_command_match.group(1).strip() if true_command_match else ""
        
        if not true_command:
            return score
        
        # Exact command match
        if gen_command == true_command:
            score += 1.0
        elif gen_command:
            # Compute the Jaccard similarity
            gen_words = set(gen_command.split())
            true_words = set(true_command.split())
            intersection = gen_words & true_words
            union_size = len(gen_words) + len(true_words) - len(intersection)
            
            if union_size > 0:
                similarity = len(intersection) / union_size
                if similarity > 0.5:
                    score += similarity * 0.7
        
        return min(score, 1.5)  # safety cap at 1.5; the achievable maximum is 1.2 (step 0.2 + exact command 1.0)


# Global reward function instance (used by EasyR1)
_reward_fn_instance = PentestRewardFunction()


def create_reward_function(reward_inputs):
    """
    EasyR1 奖励函数入口点
    
    Args:
        reward_inputs: List[Dict],每个 dict 包含:
            - response: 模型生成的文本
            - ground_truth: 参考答案
    
    Returns:
        List[Dict[str, float]]: 每个输入的奖励分数
    """
    prompts = []
    completions = []
    answers = []
    
    for item in reward_inputs:
        prompts.append("")  
        completions.append(item.get("response", ""))
        answers.append(item.get("ground_truth", ""))
    

    result = _reward_fn_instance(prompts, completions, answer=answers)
    
    scores = []
    for i in range(len(reward_inputs)):
        scores.append({
            "overall": result["rewards"][i],
            "format": result["format_reward"][i],
            "accuracy": result["accuracy_reward"][i],
        })
    
    return scores


REWARD_NAME = "pentest_reward"
REWARD_TYPE = "batch"
  • Register the reward function: edit EasyR1/easyr1/reward/__init__.py and add:
python
from .pentest_reward import PentestRewardFunction, create_reward_function

__all__ = [
    # ... other imports
    "PentestRewardFunction",
    "create_reward_function",
]
  • Also create EasyR1/easyr1/__init__.py in the parent directory. (Since the training configuration below references the reward file by absolute path, this step may not be strictly necessary.)

6 Training Configuration File

  • Create EasyR1/examples/pentest_grpo_h800_optimized.yaml, using verl's configuration format.
yaml
data:
  train_files: /xxx/EasyR1/data/pentest/pentest_grpo_train.jsonl
  val_files: /xxx/EasyR1/data/pentest/pentest_grpo_eval.jsonl
  prompt_key: problem
  answer_key: answer
  image_key: images
  video_key: videos
  image_dir: null
  video_fps: 2.0
  # Context length settings (tuned from an analysis of the actual data)
  # Statistics: 12,679 samples, 95th percentile = 3,555 tokens, 99th percentile = 4,817 tokens
  # Set to 5120 prompt + 3072 response = 8192 total
  # - 5120 > 99th percentile (4817), so only ~0.3% of extremely long samples are filtered out
  # - 3072 leaves ample response room for more detailed thoughts and commands
  # (see the token-length analysis sketch after this config)
  max_prompt_length: 5120
  max_response_length: 3072
  # Larger batch size to improve GPU utilization
  rollout_batch_size: 16
  mini_rollout_batch_size: 8
  val_batch_size: 8
  format_prompt: null
  override_chat_template: null
  shuffle: true
  seed: 1
  min_pixels: 262144
  max_pixels: 4194304
  filter_overlong_prompts: true

algorithm:
  adv_estimator: grpo
  disable_kl: false
  use_kl_loss: true
  kl_penalty: low_var_kl
  kl_coef: 1.0e-2
  online_filtering: false
  filter_key: overall
  filter_low: 0.01
  filter_high: 0.99

worker:
  actor:
    # Larger global batch size
    global_batch_size: 16
    micro_batch_size_per_device_for_update: 4
    micro_batch_size_per_device_for_experience: 4
    max_grad_norm: 1.0
    padding_free: true
    dynamic_batching: true
    ulysses_size: 1
    model:
      model_path: unsloth/DeepSeek-R1-0528-Qwen3-8B
      enable_gradient_checkpointing: true
      trust_remote_code: true
      freeze_vision_tower: false
      lora:
        rank: 32
        alpha: 64
        target_modules: q_proj,k_proj,v_proj,o_proj,gate_proj,up_proj,down_proj
        exclude_modules: null
    optim:
      lr: 5.0e-6
      weight_decay: 0.01
      strategy: adamw
      lr_warmup_ratio: 0.0
    fsdp:
      enable_full_shard: true
      enable_cpu_offload: false
      enable_rank0_init: true
      torch_dtype: bf16
    offload:
      # Disable offload for speed (H800 has ample GPU memory)
      offload_params: false
      offload_optimizer: false

  rollout:
    n: 4
    temperature: 0.6
    top_p: 0.9
    limit_images: 0
    # Higher GPU memory utilization
    gpu_memory_utilization: 0.75
    # Disable eager mode to enable CUDA graph acceleration
    enforce_eager: false
    # Enable chunked prefill for better long-sequence efficiency
    enable_chunked_prefill: true
    tensor_parallel_size: 2
    # Larger max batched tokens to support bigger batches
    max_num_batched_tokens: 8192
    disable_tqdm: false
    val_override_config:
      temperature: 0.6
      top_p: 0.9
      n: 1

  ref:
    fsdp:
      enable_full_shard: true
      enable_cpu_offload: true
      enable_rank0_init: true
      torch_dtype: bf16
    offload:
      offload_params: false

  reward:
    reward_function: /xxx/EasyR1/easyr1/reward/pentest_reward.py:create_reward_function

trainer:
  total_epochs: 2
  max_steps: null
  project_name: pentest
  experiment_name: grpo-h800x2-optimized
  logger: ["console", "swanlab"]
  nnodes: 1
  n_gpus_per_node: 2
  max_try_make_batch: 20
  # Validation settings
  val_freq: 100                    # validate every 100 steps (balances training speed and monitoring frequency)
  val_before_train: true           # validate before training to gauge the initial model
  val_only: false                  # validation-only mode (for standalone evaluation)
  val_generations_to_log: 3        # log 3 generated samples during validation to inspect model outputs
  # Checkpoint settings
  save_freq: 50                    # save a checkpoint every 50 steps
  save_limit: 3                    # keep at most 3 checkpoints; older ones are deleted automatically
  save_model_only: false           # false = save the full checkpoint (incl. optimizer state), true = model only
  save_checkpoint_path: /xxx/EasyR1/checkpoints/pentest-optimized  # checkpoint directory
  load_checkpoint_path: null       # set a path here to resume from a checkpoint
  find_last_checkpoint: true       # automatically find the latest checkpoint
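
The prompt/response length comments above refer to token-length percentiles of the training set. A sketch like the following reproduces that kind of analysis; it assumes the model's tokenizer and the JSONL path from the config, and the percentile figures quoted in the comments come from the author's own data.
python
# prompt_length_stats.py - token-length percentiles for the training prompts
import json
import numpy as np
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("unsloth/DeepSeek-R1-0528-Qwen3-8B", trust_remote_code=True)

lengths = []
with open("/xxx/EasyR1/data/pentest/pentest_grpo_train.jsonl", encoding="utf-8") as f:
    for line in f:
        problem = json.loads(line)["problem"]
        lengths.append(len(tokenizer(problem)["input_ids"]))

print("samples:", len(lengths))
print("95th percentile:", np.percentile(lengths, 95))
print("99th percentile:", np.percentile(lengths, 99))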

7 Training Launch Script

  • In the EasyR1 project root, create the launch script start_training.sh
bash
#!/bin/bash

# Pentest-R1 GRPO training launch script
# Runs in the background via nohup so training survives SSH disconnects

# Environment variables
export CUDA_VISIBLE_DEVICES=0,1
export PYTHONPATH=/root/autodl-fs/EasyR1:$PYTHONPATH

# Create log and checkpoint directories
mkdir -p /root/autodl-fs/EasyR1/logs
mkdir -p /root/autodl-fs/EasyR1/checkpoints/pentest-r1-optimized

# Current timestamp
TIMESTAMP=$(date +"%Y%m%d_%H%M%S")
LOG_FILE="/root/autodl-fs/EasyR1/logs/training_${TIMESTAMP}.log"

echo "========================================"
echo "Starting Pentest-R1 GRPO training"
echo "Config file: examples/pentest_grpo_h800_optimized.yaml"
echo "Log file: ${LOG_FILE}"
echo "Start time: $(date)"
echo "========================================"

# Launch training with nohup
# 2>&1: redirect stderr to stdout
# &: run in the background
nohup python3 -m verl.trainer.main \
    config=examples/pentest_grpo_h800_optimized.yaml \
    trainer.n_gpus_per_node=2 \
    > "${LOG_FILE}" 2>&1 &

# Record the process ID
PID=$!
echo "Training process PID: ${PID}"
echo "${PID}" > /root/autodl-fs/EasyR1/logs/training.pid

echo ""
echo "Training started in the background!"
echo ""
echo "Useful commands:"
echo "  Follow the live log:   tail -f ${LOG_FILE}"
echo "  Check process status:  ps aux | grep verl.trainer"
echo "  Check GPU status:      nvidia-smi"
echo "  Stop training:         kill ${PID}"
echo ""
echo "View the log: tail -f ${LOG_FILE}"
  • Resuming after an interruption: training can continue from a saved checkpoint, as sketched below.
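
With find_last_checkpoint: true in the config, rerunning start_training.sh should pick up the latest checkpoint under save_checkpoint_path automatically. To resume from a specific checkpoint instead, the same key=value override style used above can set the path explicitly. This is a minimal sketch; the checkpoint directory name is a placeholder for an actual saved step.
bash
# Resume from a specific checkpoint directory (placeholder path)
python3 -m verl.trainer.main \
    config=examples/pentest_grpo_h800_optimized.yaml \
    trainer.n_gpus_per_node=2 \
    trainer.load_checkpoint_path=/xxx/EasyR1/checkpoints/pentest-optimized/<step_dir>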