联邦经典算法Fedavg实现

一、简介

本文以多层感知机MLP在CIFAR-10上进行图像分类为例,详细阐述了Fedavg这一经典联邦学习算法的应用,为各位读者打开初步学习联邦学习的大门。

(Fedavg算法伪代码)

在联邦学习中,Server 不训练模型,只做调度与聚合;Client 不共享数据,只做本地训练。

(上图为联邦聚合算法Fedavg的整体训练流程)

因为联邦学习是分布式的机器学习算法,所以在本文在模拟时按照多文件的代码结构梳理,读者想要复现的话请按照下面的项目格式进行整理,简单易上手,将下面的代码按照以下结构进行梳理后直接按照项目启动的说明就可以运行了。

fedavg_cifar10/

├── README.md

├── run.py # 主入口

├── server.py # 联邦服务器

├── client.py # 客户端逻辑

├── fedavg.py # FedAvg 聚合

├── model.py # MLP 模型

├── data.py # CIFAR-10 数据 + 客户端划分

├── utils.py # 评估工具

二、项目启动

1️⃣ 创建 Conda 环境

conda create -n fedavg-demo python=3.9 -y

conda activate fedavg-demo

2️⃣ 安装 PyTorch

pip install torch torchvision torchaudio

3️⃣ 安装其余依赖

pip install numpy tqdm

4️⃣启动训练

python run.py

第一次运行的时候会首先下载数据集,大概要20-30mins左右,大家耐心等待就好啦~

最后应该会输出类似的内容:

===== Round 0 =====

Global Accuracy after round 0: 0.312

===== Round 1 =====

Global Accuracy after round 1: 0.418

...

三、项目说明

下面对每个模块做详细解读,想要直接运行的朋友可以直接复制粘贴:

1️⃣ model.py ------ 模型定义("被联邦的对象")

python 复制代码
# model.py
import torch
import torch.nn as nn

class MLP(nn.Module):
    def __init__(self):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(32 * 32 * 3, 512),
            nn.ReLU(),
            nn.Linear(512, 256),
            nn.ReLU(),
            nn.Linear(256, 10)
        )

    def forward(self, x):
        x = x.view(x.size(0), -1)  # flatten
        return self.net(x)
  • CIFAR-10 输入:32×32×3

  • 全是 LinearFedAvg 非常直观

2️⃣ data.py ------ CIFAR-10 + 客户端数据划分(IID)

python 复制代码
# data.py
import torch
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, Subset

def load_cifar10():
    transform = transforms.Compose([
        transforms.ToTensor()
    ])
    train_set = datasets.CIFAR10(
        root="./data", train=True, download=True, transform=transform
    )
    test_set = datasets.CIFAR10(
        root="./data", train=False, download=True, transform=transform
    )
    return train_set, test_set

def split_dataset(dataset, num_clients):
    """IID 划分"""
    num_items = len(dataset) // num_clients
    indices = torch.randperm(len(dataset))

    client_subsets = []
    for i in range(num_clients):
        subset_idx = indices[i * num_items:(i + 1) * num_items]
        client_subsets.append(Subset(dataset, subset_idx))

    return client_subsets

def get_dataloader(dataset, batch_size=64, shuffle=True):
    return DataLoader(dataset, batch_size=batch_size, shuffle=shuffle)

3️⃣ client.py ------ 客户端本地训练(关键)

python 复制代码
# client.py
import torch
from model import MLP

def local_train(
    client_id,
    global_state_dict,
    train_loader,
    device,
    local_epochs=1,
    lr=1e-2
):
    model = MLP().to(device)
    model.load_state_dict(global_state_dict)
    model.train()

    optimizer = torch.optim.SGD(model.parameters(), lr=lr)
    criterion = torch.nn.CrossEntropyLoss()

    total_samples = 0

    for _ in range(local_epochs):
        for x, y in train_loader:
            x, y = x.to(device), y.to(device)

            optimizer.zero_grad()
            loss = criterion(model(x), y)
            loss.backward()
            optimizer.step()

            total_samples += x.size(0)

    return {
        "state_dict": model.state_dict(),
        "num_samples": total_samples
    }

这一步就是在模拟真实联邦客户端,每个客户端client 不知道别的 client,客户端client之间 只看到 服务器端聚合后分发下来的全局模型global model,客户端client 只上传参数,不上传数据。

4️⃣ fedavg.py ------ FedAvg 代码实现

python 复制代码
# fedavg.py
import torch

def fedavg(client_results):
    total_samples = sum(r["num_samples"] for r in client_results)
    global_state = {}

    for key in client_results[0]["state_dict"].keys():
        global_state[key] = sum(
            r["state_dict"][key] * (r["num_samples"] / total_samples)
            for r in client_results
        )

    return global_state

上述代码的作用如下:

  • 遍历每一个参数(weight / bias)

  • 按客户端样本数加权

  • 做逐元素平均

5️⃣ utils.py ------ 测试准确率

python 复制代码
# utils.py
import torch

def evaluate(model, dataloader, device):
    model.eval()
    correct, total = 0, 0

    with torch.no_grad():
        for x, y in dataloader:
            x, y = x.to(device), y.to(device)
            preds = model(x).argmax(dim=1)
            correct += (preds == y).sum().item()
            total += y.size(0)

    return correct / total

6️⃣ server.py ------ 联邦服务器(调度核心)

python 复制代码
# server.py
import copy
import torch
from model import MLP
from client import local_train
from fedavg import fedavg

class FedAvgServer:
    def __init__(self, client_loaders, test_loader, device):
        self.client_loaders = client_loaders
        self.test_loader = test_loader
        self.device = device
        self.global_model = MLP().to(device)

    def run_round(self, round_id):
        print(f"\n===== Round {round_id} =====")

        global_state = copy.deepcopy(self.global_model.state_dict())
        client_results = []

        for cid, loader in enumerate(self.client_loaders):
            result = local_train(
                cid, global_state, loader, self.device
            )
            client_results.append(result)

        new_state = fedavg(client_results)
        self.global_model.load_state_dict(new_state)

    def train(self, rounds):
        from utils import evaluate

        for r in range(rounds):
            self.run_round(r)
            acc = evaluate(self.global_model, self.test_loader, self.device)
            print(f"Global Accuracy after round {r}: {acc:.4f}")

我们上面说过,服务器不训练模型,只做模型的调度和聚合。server端只干三件事:下发模型、收集客户端参数和调用Fedavg进行聚合。

7️⃣ run.py ------ 一键启动

python 复制代码
# run.py
import torch
from data import load_cifar10, split_dataset, get_dataloader
from server import FedAvgServer

def main():
    device = "cuda" if torch.cuda.is_available() else "cpu"
    num_clients = 3      # 改成 6 试试
    rounds = 10

    train_set, test_set = load_cifar10()
    client_datasets = split_dataset(train_set, num_clients)

    client_loaders = [
        get_dataloader(ds, batch_size=64)
        for ds in client_datasets
    ]
    test_loader = get_dataloader(test_set, batch_size=256, shuffle=False)

    server = FedAvgServer(client_loaders, test_loader, device)
    server.train(rounds)

if __name__ == "__main__":
    main()

所有的联邦变量都将在这里控制:客户端数量、通信轮数、batch size和CPU或者GPU的设备。

相关推荐
天天讯通2 小时前
数据公司与AI五大主流合作模式
人工智能
Clarence Liu2 小时前
AI Agent开发(2) - 深入解析 A2A 协议与 Go 实战指南
开发语言·人工智能·golang
综合热讯2 小时前
AUS GLOBAL 荣耀赞助 2026 LIL TOUR 高尔夫嘉年华
人工智能
小饼干超人2 小时前
详解向量数据库中的PQ算法(Product Quantization)
人工智能·算法·机器学习
哥布林学者3 小时前
吴恩达深度学习课程五:自然语言处理 第二周:词嵌入(三)Word2Vec
深度学习·ai
砚边数影3 小时前
AI数学基础(一):线性代数核心,向量/矩阵运算的Java实现
java·数据库·人工智能·线性代数·矩阵·ai编程·金仓数据库
互联网科技看点3 小时前
诸葛io获认可:金融分析智能体赛道领航者
大数据·人工智能·金融
engchina3 小时前
自然语言转 SQL 并不是“魔法”
数据库·人工智能·sql·text2sql·nl2sql·自然语言转sql
一行注释也不写3 小时前
【卷积层和池化层在CNN中的作用】
深度学习·计算机视觉·cnn