联邦大模型微调

微调(Fine-tuning)是一种迁移学习的技术,用于在一个已经预训练好的模型基础上,通过进一步训练来适应特定的任务或数据集。微调可以在具有相似特征的任务之间共享知识,从而加快训练速度并提高模型性能。

微调步骤:

  1. 选择预训练模型:选择一个在大规模数据集上预训练好的模型,如ImageNet上的预训练的卷积神经网络(如ResNet、VGG等)。这些模型通常具有良好的特征提取能力。
  2. 冻结底层权重:将预训练模型的底层权重(通常是卷积层)固定住,不进行训练。这是因为底层权重通常学习到了通用的特征,可以被用于许多不同的任务。
  3. 替换顶层分类器:将预训练模型的顶层分类器(通常是全连接层)替换为适合特定任务的新的分类器。新的分类器的输出节点数量应该与任务的类别数相匹配。
  4. 解冻部分权重(可选):根据任务的复杂性和可用的训练数据量,可以选择解冻一些底层权重,以便更好地适应新的任务。这样可以允许底层权重进行微小的调整,以更好地适应新任务的特征。
  5. 进行训练:使用特定任务的训练数据集对新的分类器进行训练。可以使用较小的学习率进行训练,以避免对预训练模型的权重进行过大的更新。
  6. 评估和调整:在训练完成后,使用验证集或测试集评估模型的性能。根据评估结果,可以进行调整,如调整学习率、调整模型结构等。

PEFT:

PEFT(Performance Estimation and Modeling for Fine-Tuning)是一种用于微调任务的性能估计和建模方法。它的主要目的是帮助研究人员和从业者在微调过程中更好地理解和预测模型的性能,并进行更有效的模型选择和调优。

使用场景:

  1. 模型选择:在微调之前,通常需要选择一个合适的预训练模型。PEFT可以帮助评估和比较不同预训练模型在特定任务上的性能,从而选择最适合的模型。
  2. 超参数调优:微调过程中可能涉及到一些超参数的选择,如学习率、批量大小等。PEFT可以帮助预估不同超参数设置下模型的性能,并指导超参数的调优。
  3. 计算资源规划:微调通常需要大量的计算资源,如显存、GPU时间等。PEFT可以帮助估计不同模型和数据集规模下的计算资源需求,以便更好地规划和分配资源。
  4. 模型压缩和加速:在一些场景下,需要将模型压缩或加速,以便在资源受限的设备上进行推理。PEFT可以帮助评估不同压缩和加速技术对模型性能的影响,并指导模型优化的方向。

PEFT的关键步骤:

  1. 数据采样:从原始数据集中采样一小部分数据用于性能估计。这样可以减少计算开销,同时保持采样数据与原始数据集的分布一致性。
  2. 特征提取:使用预训练模型提取采样数据的特征表示。这些特征通常具有很好的表达能力,可以用于性能估计。
  3. 性能估计模型:基于采样数据的特征表示,建立一个性能估计模型。这个模型可以是简单的线性回归模型,也可以是更复杂的神经网络模型。
  4. 性能预测:使用性能估计模型对未知数据的性能进行预测。通过输入微调任务的特征表示,模型可以输出预测的性能指标,如准确率、F1分数等。

代码:

python 复制代码
import warnings
warnings.filterwarnings("ignore", category=UserWarning)
warnings.filterwarnings("ignore", category=FutureWarning)
import flwr as fl
from flwr_datasets import FederatedDataset
from flwr_datasets.partitioner import IidPartitioner
from datasets import load_dataset
from flwr.client.mod import fixedclipping_mod
from flwr.server.strategy import (
    DifferentialPrivacyClientSideFixedClipping
)
from utils.utils import * 
from utils.LLM import LLM_fl
from utils.LLM import get_fireworks_api_key,load_env
python 复制代码
cfg = get_config("federated")

print_config(cfg)
python 复制代码
partitioner = IidPartitioner(num_partitions=cfg.flower.num_clients)
fds = FederatedDataset(
    dataset=cfg.dataset.name,
    partitioners={"train": partitioner}
)

partition_zero = fds.load_partition(0) 

format_dataset(partition_zero)

(20个客户端均分数据集)

python 复制代码
visualize_partitions(fds)

获取分词器(tokenizer)、数据整理器(data collator)和格式化提示函数(formatting prompts function):

python 复制代码
(
tokenizer,
data_collator,
formatting_prompts_func,
) = get_tokenizer_and_data_collator_and_propt_formatting(
    cfg.model.name, 
    cfg.model.use_fast_tokenizer,
    cfg.train.padding_side,
)

预训练的分词器:EleutherAI/pythia-70m

python 复制代码
save_path = "./my_fl_model"
client = fl.client.ClientApp(
    client_fn=gen_client_fn(
        fds,
        tokenizer,
        formatting_prompts_func,
        data_collator,
        cfg.model, 
        cfg.train, 
        save_path,
    ),
    mods=[fixedclipping_mod] 
)

fds:

复制代码
<flwr_datasets.federated_dataset.FederatedDataset object at 0x7fa271c2f790>

tokenizer:

为联邦学习服务器配置了一个 FedAvg 策略,并添加了差分隐私机制,同时指定了服务器运行的轮数和其他相关参数:

python 复制代码
def server_fn(context: Context):

    # Define the Strategy
    strategy = fl.server.strategy.FedAvg(
        min_available_clients=cfg.flower.num_clients, # total clients
        fraction_fit=cfg.flower.fraction_fit, # ratio of clients to sample
        fraction_evaluate=0.0, # No federated evaluation
        # A (optional) function used to configure a "fit()" round
        on_fit_config_fn=get_on_fit_config(),
        # A (optional) function to aggregate metrics sent by clients
        fit_metrics_aggregation_fn=fit_weighted_average,
        # A (optional) function to execute on the server after each round. 
        # In this example the function only saves the global model.
        evaluate_fn=get_evaluate_fn( 
            cfg.model,
            cfg.train.save_every_round,
            cfg.flower.num_rounds,
            save_path
        ),
    )

    # Add Differential Privacy
    sampled_clients = cfg.flower.num_clients*strategy.fraction_fit
    strategy = DifferentialPrivacyClientSideFixedClipping(
        strategy, 
        noise_multiplier=cfg.flower.dp.noise_mult,
        clipping_norm=cfg.flower.dp.clip_norm, 
        num_sampled_clients=sampled_clients
    )

    # Number of rounds to run the simulation
    num_rounds = cfg.flower.num_rounds
    config = fl.server.ServerConfig(num_rounds=num_rounds)
    
    return fl.server.ServerAppComponents(strategy=strategy, config=config) 

配置和启动一个联邦学习模拟,其中包含了服务器和客户端的应用,以及它们运行所需的资源。通过调用 run_simulation 函数,模拟开始执行,客户端和服务器之间将进行多轮通信,以协作训练一个共享的模型,同时保护数据的隐私:

python 复制代码
client_resources = dict(cfg.flower.client_resources)
fl.simulation.run_simulation(
    server_app=server,
    client_app=client,
    num_supernodes=cfg.flower.num_clients,
    backend_config={"client_resources": client_resources,
                    "init_args": backend_setup}
)

运行微调后的模型:

python 复制代码
# Load the checkpoint
llm_eval = LLM_fl()

# Load dataset
train_dataset = load_dataset(cfg.dataset.name, split='train')
train_dataset = format_dataset(train_dataset)

# Select training example
example_index = 6

data_point = train_dataset[example_index]

# Print the prompt
llm_eval.eval(data_point['instruction'], verbose=True)
python 复制代码
# Print the fine-tuned LLM response
llm_eval.print_response()
python 复制代码
# Print the expected output from the medAlpaca dataset
ex_response = format_string(data_point['response'])
print(f"Expected output:\n\t{ex_response}")

结果可视化:

python 复制代码
visualize_results(
    results=['7b/pretrained', '7b/cen_10', '7b/fl'])

向集中式精细调整模型提供相同数量的数据:

python 复制代码
visualize_results(
    results=['7b/pretrained', '7b/cen_10',
             '7b/cen_full', '7b/fl'],
    compact=True)

计算花销:

python 复制代码
cfg = get_config("federated")

compute_communication_costs(cfg, comm_bw_mbps=20)

参考:

【1】LLMs_interview_notes/大模型(LLMs)参数高效微调(PEFT)面/大模型(LLMs)参数高效微调(PEFT)面.md at main · naginoa/LLMs_interview_notes · GitHub

【2】

相关推荐
AIGC大时代2 小时前
如何判断一个学术论文是否具有真正的科研价值?ChatGPT如何提供帮助?
大数据·人工智能·物联网·chatgpt·aigc
岁月如歌,青春不败2 小时前
HMSC联合物种分布模型
开发语言·人工智能·python·深度学习·r语言
海域云赵从友3 小时前
香港 GPU 服务器托管引领 AI 创新,助力 AI 发展
运维·服务器·人工智能
四口鲸鱼爱吃盐3 小时前
Pytorch | 利用GRA针对CIFAR10上的ResNet分类器进行对抗攻击
人工智能·pytorch·python·深度学习·计算机视觉
蓝天星空3 小时前
制作一个类似ChatGPT的AI对话网站,模型能力使用ChatGPT
人工智能
汤姆和佩琦3 小时前
24-12-28-pytorch深度学习中音频I/O 中遇到的问题汇总
人工智能·pytorch·python·深度学习·音视频·i/o
静静AI学堂3 小时前
SCSA:探索空间与通道注意力之间的协同效应
人工智能·深度学习·yolo·目标跟踪
一只敲代码的猪3 小时前
Llama 3 后训练(三)
深度学习·神经网络·机器学习·llama
小成晓程4 小时前
PyQt6+OpenCV 项目练习
人工智能·opencv·计算机视觉
奔波儿灞爱霸波尔奔5 小时前
人工智能之基于阿里云进行人脸特征检测部署
人工智能·阿里云·云计算