StarCoder-3B微调和RAG的技术原理

一、StarCoder-3B模型微调:让模型学会G代码语法

1.1 为什么选择StarCoder-3B?

StarCoder系列模型是在超过80种编程语言的代码数据上预训练的,本身就具备较强的代码生成能力。论文选择3B参数版本,是出于计算效率和性能的平衡------相比7B或15B模型,3B模型更容易在单卡或有限GPU资源上进行微调。

1.2 微调技术:PEFT + LoRA

论文提到使用参数高效微调(PEFT)混合精度训练,这是目前微调大语言模型的主流做法:

  • LoRA(低秩适配) :在预训练模型的某些层(通常是注意力层的权重矩阵)旁边添加小的可训练"适配器"模块,只微调这些适配器的参数,而冻结原始模型参数。这能大幅减少可训练参数量,论文中trainable%仅约0.23%(3555万可训练参数 vs 155亿总参数)。

  • 混合精度训练:使用fp16或bf16精度进行训练,减少显存占用,加速计算。

1.3 训练数据:从Stack数据集中提取G代码

论文使用Stack数据集(GitHub公共代码仓库集合)中的G代码数据进行微调。具体做法是:

  1. 数据过滤 :从Stack数据集中筛选出包含G代码的文件(通常扩展名为.nc.gcode或包含G/M代码的文件)。

  2. 指令微调格式 :将数据组织成**"指令-输出"**对的形式。例如:

bash 复制代码
指令:生成一个铣削边长为10mm的正方形的G代码,刀具起始点为(0,0,0)。
输出:G00 X0 Y0 Z0\nG01 X10 Y0 F100\n...

数据集规模:虽然没有给出具体数量,但通常需要数千到数万个样本才能使模型较好地掌握领域特定语法-

8

1.4 微调后的模型能力

经过微调后,FT-StarCoder-3B能够:

• 理解自然语言描述的加工任务(如"铣削一个带有两个圆形岛状区域的矩形槽")。

• 生成符合G代码语法规则的指令序列-

9

二、检索增强生成(RAG):为模型注入外部知识

2.1 RAG的核心思想

RAG是一种"开卷考试"模式-

3

:在模型回答问题前,先从外部知识库中检索出最相关的信息,然后将这些信息作为上下文提供给模型,让模型基于给定的材料生成答案。这能有效解决LLM的知识截止问题和"幻觉"问题-

3

9

2.2 GLLM中的RAG实现

论文中RAG的作用是增强模型对CNC特定知识的理解,包括:

• G代码命令的完整定义

• 目标CNC机床的详细文档(如行程范围、支持的M代码等)-

9

具体实现流程:

  1. 知识库构建:

◦ 收集CNC领域的PDF文档(如机床手册、G代码参考指南)-

2

9

◦ 加载文档内容,进行分片(Chunking)------将长文档切分成较小的、语义完整的文本块-

3

◦ 使用嵌入模型(如OpenAI的text-embedding-3-small或其他开源嵌入模型)将每个文本块转换为向量,存储在向量数据库中(论文使用FAISS)-

9

  1. 检索流程:

◦ 当用户输入自然语言指令时,系统将指令同样转换为向量。

◦ 在FAISS中进行相似性搜索,召回最相关的k个文本块(例如:相关的G代码命令说明、机床参数)-

2

3

◦ 将召回的文本块作为"上下文",与用户的原始指令一起构建成新的提示词,输入给LLM-

3

6

  1. RAG的集成位置:

◦ 发生在参数提取之后、G代码生成之前(见图2架构图)-

9

◦ 与提示工程结合:提取的参数和RAG检索到的知识一起,填充到预定义的提示模板中。

2.3 实验中的意外发现

论文实验显示:在非结构化提示下使用RAG反而降低了模型性能-

9

。可能的原因:

• 非结构化提示缺乏明确的指令框架,RAG引入的额外信息可能增加了输入的复杂性,干扰了模型的理解。

• 模型本身已经通过微调掌握了足够的G代码知识,外部知识在某些简单任务上成为"噪音"。

这说明RAG的集成方式需要优化,可能更适合与结构化提示结合,或者只在对特定机床参数有疑问时才触发检索。

三、如何复现GLLM工作

3.1 环境准备

硬件要求:

• GPU:建议至少8GB显存(如NVIDIA RTX 3070/4080或AMD等效显卡),如果需要多卡并行则要求更高-

1

• 内存:16GB以上。

软件环境:

• Python 3.9+

• PyTorch 2.0+

• CUDA 11.8+(如果使用NVIDIA GPU)

bash 复制代码
# 安装必要库
pip install transformers accelerate peft datasets bitsandbytes
pip install langchain langchain-community faiss-cpu openai
pip install streamlit  # 用于构建Web界面

3.2 微调StarCoder-3B的步骤

步骤1:准备数据集

从Hugging Face下载Stack数据集或自行收集G代码文件:

bash 复制代码
from datasets import load_dataset

# 以Stack数据集为例,过滤出G代码相关文件
dataset = load_dataset("bigcode/the-stack", data_dir="data/gcode", split="train")

数据格式示例:

bash 复制代码
{
  "instruction": "铣削一个边长为50mm的正方形,深度2mm,进给速度100mm/min",
  "output": "G00 X0 Y0 Z2\nG01 Z-2 F50\nG01 X50 Y0 F100\nG01 X50 Y50\nG01 X0 Y50\nG01 X0 Y0\nG00 Z2\nM30"
}
步骤2:微调脚本

参考AMD博客的finetune.py脚本或StarCoder2训练指南:

python 复制代码
from transformers import AutoTokenizer, AutoModelForCausalLM, TrainingArguments
from peft import LoraConfig, get_peft_model
from trl import SFTTrainer

# 加载模型和tokenizer
model_name = "bigcode/starcoder2-3b"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(
    model_name, 
    load_in_4bit=True,  # 4bit量化减少显存占用
    device_map="auto"
)

# LoRA配置
lora_config = LoraConfig(
    r=16,  # LoRA秩
    lora_alpha=32,
    target_modules=["q_proj", "v_proj"],  # 要应用LoRA的模块
    lora_dropout=0.05,
    bias="none",
    task_type="CAUSAL_LM"
)
model = get_peft_model(model, lora_config)

# 训练参数
training_args = TrainingArguments(
    output_dir="./starcoder-gcode",
    per_device_train_batch_size=1,
    gradient_accumulation_steps=4,
    learning_rate=2e-4,
    max_steps=1000,
    logging_steps=10,
    save_steps=200,
    fp16=True,
)

# 创建Trainer
trainer = SFTTrainer(
    model=model,
    args=training_args,
    train_dataset=dataset,
    dataset_text_field="output",  # 使用output字段作为训练文本
    max_seq_length=1024,
)

trainer.train()

3.3 RAG系统的搭建

参考阿里云的RAG搭建指南和LangGraph教程:

python 复制代码
from langchain.document_loaders import PyPDFLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.vectorstores import FAISS
from langchain.chains import RetrievalQA

# 1. 加载PDF文档(如机床手册、G代码参考)
loader = PyPDFLoader("cnc_manual.pdf")
documents = loader.load()

# 2. 文档分片
text_splitter = RecursiveCharacterTextSplitter(
    chunk_size=500,
    chunk_overlap=50
)
chunks = text_splitter.split_documents(documents)

# 3. 创建嵌入模型(使用开源模型)
embeddings = HuggingFaceEmbeddings(
    model_name="BAAI/bge-small-zh"  # 或其他嵌入模型
)

# 4. 存入FAISS向量库
vector_store = FAISS.from_documents(chunks, embeddings)

# 5. 检索函数
def retrieve_context(query, k=3):
    docs = vector_store.similarity_search(query, k=k)
    return "\n".join([doc.page_content for doc in docs])

# 6. 与LLM集成
def generate_with_rag(user_input, extracted_params):
    # 构建检索查询
    search_query = f"{user_input} 参数: {extracted_params}"
    context = retrieve_context(search_query)
    
    # 构建提示词
    prompt = f"""根据以下参考信息和用户需求生成G代码:
    
参考信息:
{context}

用户需求:{user_input}
提取参数:{extracted_params}

请生成符合CNC加工规范的G代码:"""
    
    # 调用微调后的模型
    # ...

3.4 自纠错机制实现

论文使用LangGraph构建有向无环图(DAG)来实现迭代优化

python 复制代码
from langgraph.graph import StateGraph

class GCodeState:
    def __init__(self):
        self.gcode = ""
        self.errors = []
        self.iteration = 0

# 定义节点
def generate_node(state):
    # 初始生成或根据错误重新生成
    state.gcode = model.generate(state.prompt)
    return state

def syntax_check_node(state):
    errors = check_syntax(state.gcode)  # 实现语法检查
    if errors:
        state.errors.append(("syntax", errors))
    return state

def semantic_check_node(state):
    # 使用Hausdorff距离检查刀具路径
    distance = compute_hausdorff_distance(state.gcode, state.user_path)
    if distance > threshold:
        state.errors.append(("semantic", f"Hausdorff距离={distance}"))
    return state

# 构建图
graph = StateGraph(GCodeState)
graph.add_node("generate", generate_node)
graph.add_node("syntax_check", syntax_check_node)
graph.add_node("semantic_check", semantic_check_node)

# 添加条件边
graph.add_conditional_edges(
    "syntax_check",
    lambda state: "generate" if state.errors else "semantic_check"
)

四、推荐的Task List

如果你想复现或扩展GLLM的工作,可以参考以下任务清单:

阶段1:基础准备

  • 熟悉G代码语法和CNC加工基础(阅读LinuxCNC文档中的示例代码)

  • 收集至少1000个G代码样本(可从开源CNC项目、3D打印切片软件导出)

  • 准备CNC机床手册、G代码参考文档作为RAG知识库

阶段2:模型微调

  • 搭建微调环境(配置GPU驱动、PyTorch、CUDA)

  • 清洗并格式化训练数据(构建指令-输出对)

  • 使用LoRA微调StarCoder-3B或StarCoder2-3B

  • 评估微调效果:测试模型能否生成语法正确的简单G代码(如正方形铣削)

阶段3:RAG系统

  • 实现文档加载和分片(支持PDF/TXT)

  • 选择嵌入模型并构建向量库(FAISS或Chroma)

  • 实现检索功能,测试检索相关性

  • 将RAG集成到提示词构造流程

阶段4:自纠错机制

  • 实现G代码语法检查器(逐行解析,校验命令格式)

  • 实现安全检查(检测切削时的快速移动、钻孔安全高度)

  • 实现语义检查:从G代码解析刀具路径,计算Hausdorff距离

  • 构建迭代优化循环(使用LangGraph或自定义循环)

阶段5:Web界面与评估

  • 使用Streamlit构建用户界面

  • 集成CAMotics或自定义2D可视化

  • 设计测试任务集(至少包含论文中的6类任务)

  • 运行实验,记录成功率和平均迭代次数

  • 对比结构化提示 vs 非结构化提示、有无RAG的效果

可选扩展

  • 尝试不同的嵌入模型和检索策略,优化RAG性能

  • 微调更大的模型(如CodeLlama-7B)进行比较

  • 添加用户反馈回路,让用户可以手动修正生成的G代码

通过以上步骤,你应该能够复现GLLM的核心功能,并在此基础上进行改进和扩展。

相关推荐
就叫你天选之人啦2 小时前
GBDT系列八股(XGBoost、LightGBM)
人工智能·深度学习·学习·机器学习
hans汉斯2 小时前
基于区块链和语义增强的科研诚信智能管控平台
人工智能·算法·yolo·数据挖掘·区块链·汉斯出版社
冷小鱼2 小时前
机器学习极简入门:从外卖预测到AI核心算法
人工智能·算法·机器学习
yinyan13142 小时前
一起学springAI系列一:使用多种聊天模型
java·人工智能·spring boot·后端·spring·springai
冷小鱼2 小时前
Word2Vec 揭秘:如何让计算机“理解“词语?
人工智能·自然语言处理·word2vec
技术小甜甜2 小时前
[Python实战] 用 pathlib 彻底统一文件路径处理,比字符串拼接稳得多
开发语言·人工智能·python·ai·效率化
未来之窗软件服务2 小时前
二次训练中文 NLU小体积[AI人工智能(五十九)]—东方仙盟
人工智能·仙盟创梦ide·东方仙盟
landuochong2002 小时前
用 Telegram 远程控制你本地的 Claude Code
人工智能·架构·claudecode
Westward-sun.2 小时前
OpenCV图像透视变换:自动矫正倾斜的发票
人工智能·opencv·计算机视觉