【Finetune】(一)、transformers之BitFit微调

文章目录

0、参数微调简介

参数微调方法是仅对模型的一小部分的参数(这一小部分可能是模型自身的,也可能是外部引入的)进行训练,便可以为模型带来显著的性能变化,在一些场景下甚至不输于全量微调。

由于训练一小部分参数,极大程度降低了训练大模型的算力需求,不需要多机多卡,单卡就可以完成对一些大模型的训练。不仅如此,少量的训练参数,对存储的要求同样降低很多,大多数的参数微调方法只需要保存训练部分的参数,与动辄几十GB的原始大模型相比,几乎可以忽略。

1、常见的微调方法

常见的微调方法如图所示:

Lialin, Vladislav, Vijeta Deshpande, and Anna Rumshisky. "Scaling down to scale up: A guide to parameter-efficient fine-tuning." arXiv preprint arXiv:2303.15647 (2023).

2、代码实战

  • 模型------bloom-389m-zh
  • 数据集------alpaca_data_zh

2.1、导包

python 复制代码
from datasets import load_dataset, Dataset
from transformers import AutoTokenizer, AutoModelForCausalLM, DataCollatorForSeq2Seq, TrainingArguments, Trainer

2.2、加载数据集

python 复制代码
ds = Dataset.load_from_disk("./alpaca_data_zh/")

2.3、数据集处理

python 复制代码
tokenizer = AutoTokenizer.from_pretrained("../Model/bloom-389m-zh")
tokenizer
python 复制代码
def process_func(example):
    MAX_LENGTH = 256
    input_ids, attention_mask, labels = [], [], []
    instruction = tokenizer("\n".join(["Human: " + example["instruction"], example["input"]]).strip() + "\n\nAssistant: ")
    response = tokenizer(example["output"] + tokenizer.eos_token)
    input_ids = instruction["input_ids"] + response["input_ids"]
    attention_mask = instruction["attention_mask"] + response["attention_mask"]
    labels = [-100] * len(instruction["input_ids"]) + response["input_ids"]
    if len(input_ids) > MAX_LENGTH:
        input_ids = input_ids[:MAX_LENGTH]
        attention_mask = attention_mask[:MAX_LENGTH]
        labels = labels[:MAX_LENGTH]
    return {
        "input_ids": input_ids,
        "attention_mask": attention_mask,
        "labels": labels
    }
python 复制代码
tokenized_ds = ds.map(process_func, remove_columns=ds.column_names)
tokenized_ds

2.4、创建模型

python 复制代码
model = AutoModelForCausalLM.from_pretrained("../Model/bloom-389m-zh",low_cpu_mem_usage=True)

2.5、BitFit微调*

python 复制代码
#选择模型参数里面的所有bias部分
#非bias部分冻结
num_param = 0
for name,param in model.named_parameters():
    if 'bias' not in name:
        param.requires_grad = False
    else:
        num_param+=param.numel()
num_param

2.6、配置模型参数

python 复制代码
args = TrainingArguments(
    output_dir="./chatbot",
    per_device_train_batch_size=1,
    gradient_accumulation_steps=4,
    logging_steps=10,
    num_train_epochs=1
)

2.7、创建训练器

python 复制代码
trainer = Trainer(
    args=args,
    model=model,
    train_dataset=tokenized_ds,
    data_collator=DataCollatorForSeq2Seq(tokenizer, padding=True, )
)

2.8、模型训练

python 复制代码
trainer.train()

2.9、模型推理

python 复制代码
from transformers import pipeline

pipe = pipeline("text-generation", model=model, tokenizer=tokenizer, device=0)
python 复制代码
ipt = "Human: {}\n{}".format("考试有哪些技巧?", "").strip() + "\n\nAssistant: "
pipe(ipt, max_length=256, do_sample=True, temperature=0.5)
相关推荐
测试员周周4 小时前
【Appium 系列】第04节-Page Object 模式 — BasePage 基类设计
开发语言·数据库·人工智能·python·语言模型·appium·web app
Marry Andy5 小时前
Atlas 800T A2部署qwen3-32b
linux·人工智能·语言模型·自然语言处理
玖日大大5 小时前
2026十大LLM研究突破:扩散语言模型挑战自回归、Unicode隐形注入、AI操纵性评估 — 大模型从狂飙走向可控
人工智能·语言模型·回归·llm·论文解读·ai agent·ai安全
龙侠九重天5 小时前
大型语言模型结构化输出:用 JSON Schema 约束大模型输出
人工智能·语言模型·自然语言处理·大模型·json
kishu_iOS&AI5 小时前
NLP —— 迁移学习 FastText
人工智能·自然语言处理·迁移学习
司南OpenCompass5 小时前
GPT领跑,头部模型“错位竞争”,强Agent能力成下一战场丨大语言模型4月最新榜单揭晓
人工智能·gpt·语言模型·大模型·大模型评测·司南评测
财经资讯数据_灵砚智能5 小时前
基于全球经济类多源新闻的NLP情感分析与数据可视化(日间)2026年5月14日
大数据·人工智能·python·信息可视化·自然语言处理
财经资讯数据_灵砚智能19 小时前
基于全球经济类多源新闻的NLP情感分析与数据可视化(夜间-次晨)2026年5月12日
人工智能·python·信息可视化·自然语言处理·ai编程
莽撞的大地瓜21 小时前
多模态内容校对智能体新突破:蜜度校对通以全流程自动化重塑校对标准
自然语言处理·全文检索·中文分词
纤纡.21 小时前
从零搭建 AI 智能 PDF 问答工具:Streamlit+LangChain + 千问大模型实战
人工智能·阿里云·语言模型·langchain