Google开源Tunix:JAX生态的LLM微调方案来了

JAX生态这两年在LLM训练这块追赶得挺快。PyTorch虽然还是主流但JAX在并行计算、TPU加速和API组合性上确实有些独特的优势。Google今天放出了Tunix这个库,专门做LLM的后训练------微调、强化学习、知识蒸馏这些都能搞。

Tunix是什么

这是个构建在JAX之上的后训练库,和Flax NNX集成得比较紧密。主要解决三类问题:

  • 监督微调(Supervised Fine-Tuning)
  • 强化学习(Reinforcement Learning)
  • 知识蒸馏(Knowledge Distillation)

现在还在早期开发阶段,功能在持续迭代,支持的模型也在慢慢扩展。

核心功能

监督微调:既支持全参数微调,也支持LoRA和Q-LoRA这类参数高效的方法。内存和算力受限的时候,PEFT方案还是挺实用的。

强化学习:实现了几个主流算法:PPO(Proximal Policy Optimization)、GRPO(Group Relative Policy Optimization)、还有token级别的GSPO。另外还有DPO(Direct Preference Optimization)做偏好对齐,这个在RLHF场景用得比较多。

知识蒸馏:支持几种策略,包括基于logit的概率分布匹配、注意力机制的转移和投影、跨架构的特征池化与投影。这几种方法在不同场景下各有用处。

库的设计比较模块化,组件可以自由组合,想扩展自定义流程也不算麻烦。分布式训练支持数据并行(DP)、完全分片数据并行(FSDP)和张量并行(TP),对TPU做了专门优化。

安装

三种装法:

从PyPI装(推荐):

复制代码
 pip install "tunix[prod]"

或者直接从GitHub主分支:

复制代码
 pip install git+https://github.com/google/tunix

开发模式从源码装:

复制代码
 git clone https://github.com/google/tunix.git  
 cd tunix  
 pip install -e".[dev]"

TPU上用QLoRA微调Gemma

拿个英译法的任务来演示。用的是Google的Gemma 2B模型,跑在TPU v5e-8上。

环境准备

复制代码
 pip install -q kagglehub safetensors tensorflow tensorflow_datasets tensorboardX transformers grain datasets  
 pip install -q git+https://github.com/google/tunix  
 pip install -q git+https://github.com/google/qwix  
   
 # Flax需要升级到最新版
 pip uninstall -q -y flax  
 pip install -q git+https://github.com/google/flax.git

完整流程

第一步,从Kaggle拉预训练checkpoint:

复制代码
 import kagglehub  
   
 model_path = "google/gemma/flax/2b"  
 kaggle_ckpt_path = kagglehub.model_download(model_path)

初始化模型和tokenizer:

复制代码
 from flax import nnx  
from tunix.models.gemma import model as gemma_lib, params as params_lib  
from tunix.generate import tokenizer_adapter as tokenizer_lib  

base_model = gemma_lib.Transformer.from_params(  
    params_lib.load_and_format_params(kaggle_ckpt_path, "2b"),  
    version="2b"  
)  
 tokenizer = tokenizer_lib.Tokenizer(tokenizer_path=f"{kaggle_ckpt_path}/tokenizer.model")

挂上QLoRA adapter:

复制代码
 import qwix  

lora_provider = qwix.LoraProvider(  
    module_path=".*(q_einsum|kv_einsum|proj)",  
    rank=16,  
    alpha=2.0,  
    weight_qtype="nf4"  # enable QLoRA quantization
)  
 lora_model = qwix.apply_lora_to_model(base_model, lora_provider)

这里rank设成16,alpha是2.0,weight_qtype指定nf4量化格式。

加载训练数据:

复制代码
 from tunix.examples.data import translation_dataset  

train_ds, validation_ds = translation_dataset.create_datasets(  
    dataset_name="mtnt/en-fr",  
    global_batch_size=16,  
    max_target_length=256,  
    num_train_epochs=3,  
    tokenizer=tokenizer,  
 )

用的是mtnt的英法平行语料,batch size 16,目标序列最长256个token。

开始训练:

复制代码
 from tunix.sft import peft_trainer, utils  
import optax  

trainer=peft_trainer.PeftTrainer(  
    lora_model,  
    optimizer=optax.adamw(1e-3),  
    config=peft_trainer.TrainingConfig(max_steps=100)  
)  
 trainer.train(train_ds, validation_ds)

优化器用AdamW,学习率1e-3,跑100步看看效果。

推理测试:

训练完直接用adapter过的模型做生成。Tunix提供了Sampler工具:

复制代码
 from tunix.generate import sampler as sampler_lib  

# initialize sampler
sampler = sampler_lib.Sampler(  
    transformer=lora_model,  
    tokenizer=tokenizer,  
    cache_config=sampler_lib.CacheConfig(  
        cache_size=256,  
        num_layers=base_model.num_layers,  
        num_kv_heads=base_model.num_kv_heads,  
        head_dim=base_model.head_dim,  
    ),  
)  

# test prompts
input_batch = [  
    "Translate this into French:\nHello, my name is Morgane.\n",  
    "Translate this into French:\nThis dish is delicious!\n",  
    "Translate this into French:\nI am a student.\n",  
    "Translate this into French:\nHow's the weather today?\n",  
]  

# generate predictions
out_data = sampler(  
    input_strings=input_batch,  
    max_generation_steps=20,  
)  

# print results
for input_string, out_string in zip(input_batch, out_data.text):  
    print(f"----------------------")  
    print(f"Prompt:\n{input_string}")  
     print(f"Output:\n{out_string}")

如果用的是QLoRA,把lora_model换成qlora_model就行。生产环境可以考虑把adapter合并回基模型,推理延迟能降下来。

总结

100步训练之后,模型已经能生成一些翻译结果了,虽然质量还不够好。多训练一段时间,准确率会明显提升,而且内存开销和训练速度都保持在不错的水平。

Tunix现在还比较新,但已经能看出一些潜力。TPU优先的设计、模块化的API、LoRA/QLoRA支持、完整的分布式训练策略,这些对做LLM适配研究的人来说都挺有用。

后续应该会继续扩展支持的模型类型和训练算法,值得关注。

地址:https://avoid.overfit.cn/post/c434311d8a894922b6c52ea179cf8d97

作者:Abish Pius

相关推荐
Mintopia21 小时前
OpenClaw 对软件行业产生的影响
人工智能
陈广亮1 天前
构建具有长期记忆的 AI Agent:从设计模式到生产实践
人工智能
会写代码的柯基犬1 天前
DeepSeek vs Kimi vs Qwen —— AI 生成俄罗斯方块代码效果横评
人工智能·llm
Mintopia1 天前
OpenClaw 是什么?为什么节后热度如此之高?
人工智能
爱可生开源社区1 天前
DBA 的未来?八位行业先锋的年度圆桌讨论
人工智能·dba
叁两1 天前
用opencode打造全自动公众号写作流水线,AI 代笔太香了!
前端·人工智能·agent
前端付豪1 天前
LangChain记忆:通过Memory记住上次的对话细节
人工智能·python·langchain
strayCat232551 天前
Clawdbot 源码解读 7: 扩展机制
人工智能·开源
程序员打怪兽1 天前
详解Visual Transformer (ViT)网络模型
深度学习