Python实例题:基于联邦学习的隐私保护 AI 系统(分布式学习、隐私计算)

目录

Python实例题

题目

问题描述

解题思路

关键代码框架

难点分析

扩展方向

Python实例题

题目

基于联邦学习的隐私保护 AI 系统(分布式学习、隐私计算)

问题描述

开发一个基于联邦学习的隐私保护 AI 系统,包含以下功能:

  • 联邦学习框架:支持多种机器学习模型的联邦训练
  • 隐私保护机制:差分隐私、同态加密等技术保护数据隐私
  • 模型聚合:安全聚合各参与方的模型参数
  • 客户端管理:管理和协调多个参与训练的客户端
  • 评估与部署:评估联邦模型性能并部署到生产环境

解题思路

  • 采用横向或纵向联邦学习架构
  • 实现安全聚合协议(如 FedAvg、FedProx)
  • 应用差分隐私或同态加密保护数据隐私
  • 设计客户端 - 服务器通信协议
  • 开发模型评估和部署工具

关键代码框架

python 复制代码
# 联邦学习服务器端
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
import numpy as np
import json
import logging
from typing import List, Dict, Any, Tuple
from cryptography.fernet import Fernet

# 配置日志
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

class FedAvgServer:
    def __init__(self, model: nn.Module, clients: List[str], config: Dict[str, Any]):
        self.model = model
        self.clients = clients
        self.config = config
        self.global_round = 0
        self.client_models = {client: None for client in clients}
        self.client_weights = {client: 1.0 for client in clients}  # 客户端权重
        
        # 初始化加密密钥
        self.encryption_key = Fernet.generate_key()
        self.cipher_suite = Fernet(self.encryption_key)
        
        # 初始化优化器
        self.optimizer = optim.SGD(self.model.parameters(), lr=config['learning_rate'])
        
    def aggregate_models(self) -> None:
        """聚合客户端模型"""
        logger.info(f"开始第 {self.global_round} 轮模型聚合")
        
        # 检查是否所有客户端都提交了模型
        for client, model_params in self.client_models.items():
            if model_params is None:
                logger.warning(f"客户端 {client} 未提交模型,跳过此轮")
                return
        
        # 计算总权重
        total_weight = sum(self.client_weights.values())
        
        # 初始化全局模型参数
        global_params = {}
        for name, param in self.model.named_parameters():
            global_params[name] = torch.zeros_like(param.data)
        
        # 加权聚合
        for client, model_params in self.client_models.items():
            weight = self.client_weights[client] / total_weight
            
            for name, param in model_params.items():
                global_params[name] += param * weight
        
        # 更新全局模型
        with torch.no_grad():
            for name, param in self.model.named_parameters():
                param.data.copy_(global_params[name])
        
        # 增加全局轮次
        self.global_round += 1
        
        # 重置客户端模型
        self.client_models = {client: None for client in self.clients}
        
        logger.info(f"第 {self.global_round-1} 轮模型聚合完成")
    
    def encrypt_model(self, model_params: Dict[str, torch.Tensor]) -> bytes:
        """加密模型参数"""
        # 将模型参数转换为numpy数组并序列化为JSON
        model_dict = {name: param.numpy().tolist() for name, param in model_params.items()}
        model_json = json.dumps(model_dict).encode('utf-8')
        
        # 加密
        encrypted_data = self.cipher_suite.encrypt(model_json)
        
        return encrypted_data
    
    def decrypt_model(self, encrypted_data: bytes) -> Dict[str, torch.Tensor]:
        """解密模型参数"""
        # 解密
        decrypted_data = self.cipher_suite.decrypt(encrypted_data)
        model_dict = json.loads(decrypted_data.decode('utf-8'))
        
        # 转换回PyTorch张量
        model_params = {name: torch.tensor(param) for name, param in model_dict.items()}
        
        return model_params
    
    def receive_client_model(self, client_id: str, encrypted_model: bytes, client_weight: float) -> None:
        """接收客户端模型"""
        if client_id not in self.clients:
            logger.warning(f"未知客户端: {client_id}")
            return
        
        try:
            # 解密模型
            model_params = self.decrypt_model(encrypted_model)
            
            # 存储客户端模型
            self.client_models[client_id] = model_params
            self.client_weights[client_id] = client_weight
            
            logger.info(f"收到客户端 {client_id} 的模型,权重: {client_weight}")
        except Exception as e:
            logger.error(f"接收客户端模型失败: {e}")
    
    def send_global_model(self, client_id: str) -> bytes:
        """向客户端发送全局模型"""
        if client_id not in self.clients:
            logger.warning(f"未知客户端: {client_id}")
            return None
        
        # 获取当前全局模型参数
        model_params = {name: param.data for name, param in self.model.named_parameters()}
        
        # 加密并发送
        return self.encrypt_model(model_params)
    
    def evaluate_model(self, test_loader: DataLoader) -> Tuple[float, float]:
        """评估模型性能"""
        self.model.eval()
        test_loss = 0
        correct = 0
        total = 0
        
        with torch.no_grad():
            for inputs, targets in test_loader:
                outputs = self.model(inputs)
                loss = nn.CrossEntropyLoss()(outputs, targets)
                
                test_loss += loss.item()
                _, predicted = outputs.max(1)
                total += targets.size(0)
                correct += predicted.eq(targets).sum().item()
        
        accuracy = 100.0 * correct / total
        avg_loss = test_loss / len(test_loader)
        
        logger.info(f"模型评估结果: 准确率 = {accuracy:.2f}%, 平均损失 = {avg_loss:.4f}")
        
        return accuracy, avg_loss
    
    def save_model(self, path: str) -> None:
        """保存模型"""
        torch.save(self.model.state_dict(), path)
        logger.info(f"模型已保存到: {path}")
python 复制代码
# 联邦学习客户端
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
import numpy as np
import json
import logging
from typing import Dict, Any, List, Tuple
from cryptography.fernet import Fernet

# 配置日志
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

class FedAvgClient:
    def __init__(self, client_id: str, model: nn.Module, train_data: Dataset, config: Dict[str, Any]):
        self.client_id = client_id
        self.model = model
        self.train_data = train_data
        self.config = config
        
        # 创建数据加载器
        self.train_loader = DataLoader(
            train_data, 
            batch_size=config['batch_size'], 
            shuffle=True
        )
        
        # 初始化优化器
        self.optimizer = optim.SGD(self.model.parameters(), lr=config['learning_rate'])
        
        # 加密工具
        self.encryption_key = None  # 将从服务器接收
        self.cipher_suite = None
    
    def set_encryption_key(self, key: bytes) -> None:
        """设置加密密钥"""
        self.encryption_key = key
        self.cipher_suite = Fernet(key)
    
    def encrypt_model(self, model_params: Dict[str, torch.Tensor]) -> bytes:
        """加密模型参数"""
        if self.cipher_suite is None:
            raise ValueError("未设置加密密钥")
        
        # 将模型参数转换为numpy数组并序列化为JSON
        model_dict = {name: param.numpy().tolist() for name, param in model_params.items()}
        model_json = json.dumps(model_dict).encode('utf-8')
        
        # 加密
        encrypted_data = self.cipher_suite.encrypt(model_json)
        
        return encrypted_data
    
    def decrypt_model(self, encrypted_data: bytes) -> Dict[str, torch.Tensor]:
        """解密模型参数"""
        if self.cipher_suite is None:
            raise ValueError("未设置加密密钥")
        
        # 解密
        decrypted_data = self.cipher_suite.decrypt(encrypted_data)
        model_dict = json.loads(decrypted_data.decode('utf-8'))
        
        # 转换回PyTorch张量
        model_params = {name: torch.tensor(param) for name, param in model_dict.items()}
        
        return model_params
    
    def update_model(self, encrypted_global_model: bytes) -> None:
        """更新本地模型为全局模型"""
        try:
            # 解密全局模型
            global_params = self.decrypt_model(encrypted_global_model)
            
            # 更新本地模型
            with torch.no_grad():
                for name, param in self.model.named_parameters():
                    param.data.copy_(global_params[name])
            
            logger.info(f"客户端 {self.client_id} 模型已更新")
        except Exception as e:
            logger.error(f"更新模型失败: {e}")
    
    def train(self, epochs: int) -> Tuple[Dict[str, torch.Tensor], float]:
        """本地训练模型"""
        self.model.train()
        
        for epoch in range(epochs):
            epoch_loss = 0
            batches = 0
            
            for inputs, targets in self.train_loader:
                self.optimizer.zero_grad()
                outputs = self.model(inputs)
                loss = nn.CrossEntropyLoss()(outputs, targets)
                loss.backward()
                self.optimizer.step()
                
                epoch_loss += loss.item()
                batches += 1
            
            avg_loss = epoch_loss / batches
            logger.info(f"客户端 {self.client_id}, 轮次 {epoch+1}/{epochs}, 平均损失: {avg_loss:.4f}")
        
        # 获取训练后的模型参数
        model_params = {name: param.data for name, param in self.model.named_parameters()}
        
        # 返回模型参数和样本数量(作为权重)
        return model_params, len(self.train_data)
    
    def get_encrypted_model(self, epochs: int = 1) -> bytes:
        """训练并返回加密的模型参数"""
        model_params, weight = self.train(epochs)
        encrypted_model = self.encrypt_model(model_params)
        
        return encrypted_model, weight
python 复制代码
# 联邦学习主程序
import torch
import torch.nn as nn
from torchvision import datasets, transforms
from torch.utils.data import Subset
import numpy as np
from typing import List, Dict, Any

# 定义简单的CNN模型
class SimpleCNN(nn.Module):
    def __init__(self, num_classes=10):
        super(SimpleCNN, self).__init__()
        self.conv1 = nn.Conv2d(1, 16, kernel_size=3, stride=1, padding=1)
        self.relu1 = nn.ReLU()
        self.pool1 = nn.MaxPool2d(kernel_size=2)
        self.conv2 = nn.Conv2d(16, 32, kernel_size=3, stride=1, padding=1)
        self.relu2 = nn.ReLU()
        self.pool2 = nn.MaxPool2d(kernel_size=2)
        self.fc1 = nn.Linear(32 * 7 * 7, 128)
        self.relu3 = nn.ReLU()
        self.fc2 = nn.Linear(128, num_classes)
    
    def forward(self, x):
        x = self.pool1(self.relu1(self.conv1(x)))
        x = self.pool2(self.relu2(self.conv2(x)))
        x = x.view(-1, 32 * 7 * 7)
        x = self.relu3(self.fc1(x))
        x = self.fc2(x)
        return x

def split_dataset(dataset, num_clients: int, iid: bool = True) -> List[Subset]:
    """分割数据集给多个客户端"""
    num_samples = len(dataset) // num_clients
    client_datasets = []
    
    if iid:
        # IID方式分割(随机分配)
        indices = list(range(len(dataset)))
        np.random.shuffle(indices)
        
        for i in range(num_clients):
            client_indices = indices[i * num_samples : (i + 1) * num_samples]
            client_datasets.append(Subset(dataset, client_indices))
    else:
        # 非IID方式分割(按标签排序后分配)
        # 这里简化处理,实际应用中可能需要更复杂的分割策略
        labels = np.array([dataset[i][1] for i in range(len(dataset))])
        indices = np.argsort(labels)
        
        for i in range(num_clients):
            client_indices = indices[i * num_samples : (i + 1) * num_samples]
            client_datasets.append(Subset(dataset, client_indices))
    
    return client_datasets

def run_federated_learning(config: Dict[str, Any]):
    """运行联邦学习过程"""
    # 设置随机种子
    torch.manual_seed(config['seed'])
    np.random.seed(config['seed'])
    
    # 加载数据集
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,))
    ])
    
    train_dataset = datasets.MNIST('data', train=True, download=True, transform=transform)
    test_dataset = datasets.MNIST('data', train=False, transform=transform)
    
    # 分割训练数据给客户端
    client_datasets = split_dataset(train_dataset, config['num_clients'], config['iid'])
    
    # 创建测试数据加载器
    test_loader = torch.utils.data.DataLoader(
        test_dataset, 
        batch_size=config['batch_size'], 
        shuffle=False
    )
    
    # 初始化服务器和客户端
    global_model = SimpleCNN()
    server = FedAvgServer(global_model, [f"client{i}" for i in range(config['num_clients'])], config)
    
    clients = []
    for i in range(config['num_clients']):
        client_model = SimpleCNN()
        # 初始时客户端模型与全局模型相同
        client_model.load_state_dict(global_model.state_dict())
        client = FedAvgClient(f"client{i}", client_model, client_datasets[i], config)
        clients.append(client)
    
    # 分发加密密钥给客户端
    for client in clients:
        client.set_encryption_key(server.encryption_key)
    
    # 联邦学习训练循环
    for round in range(config['global_rounds']):
        logger.info(f"===== 开始第 {round+1}/{config['global_rounds']} 轮联邦学习 =====")
        
        # 选择参与本轮的客户端
        selected_clients = np.random.choice(
            clients, 
            size=min(config['clients_per_round'], len(clients)), 
            replace=False
        )
        
        # 向客户端发送全局模型
        for client in selected_clients:
            encrypted_global_model = server.send_global_model(client.client_id)
            client.update_model(encrypted_global_model)
        
        # 客户端本地训练
        for client in selected_clients:
            encrypted_model, client_weight = client.get_encrypted_model(config['local_epochs'])
            server.receive_client_model(client.client_id, encrypted_model, client_weight)
        
        # 服务器聚合模型
        server.aggregate_models()
        
        # 评估全局模型
        if (round + 1) % config['eval_every'] == 0:
            accuracy, loss = server.evaluate_model(test_loader)
            logger.info(f"第 {round+1} 轮评估结果: 准确率 = {accuracy:.2f}%, 损失 = {loss:.4f}")
    
    # 保存最终模型
    server.save_model(config['model_save_path'])
    logger.info("联邦学习训练完成")

# 配置参数
config = {
    'seed': 42,
    'num_clients': 10,
    'clients_per_round': 5,
    'global_rounds': 50,
    'local_epochs': 5,
    'batch_size': 64,
    'learning_rate': 0.01,
    'iid': True,  # 是否IID数据分布
    'eval_every': 5,  # 每多少轮评估一次
    'model_save_path': 'federated_model.pth'
}

# 运行联邦学习
if __name__ == "__main__":
    run_federated_learning(config)

难点分析

  • 隐私保护与模型性能平衡:在保护隐私的同时保持模型准确性
  • 通信效率:减少客户端与服务器之间的通信开销
  • 异构设备处理:处理不同性能客户端的参与
  • 安全聚合协议:实现安全的模型参数聚合
  • 恶意参与者检测:识别和处理恶意参与方

扩展方向

  • 实现更高级的隐私保护技术(如差分隐私、同态加密)
  • 添加自适应学习率调整机制
  • 支持增量训练和持续学习
  • 开发联邦学习可视化监控界面
  • 实现跨平台联邦学习(移动端、边缘设备)
相关推荐
tomorrow.hello几秒前
Java并发测试工具
java·开发语言·测试工具
Edward-tan3 分钟前
CCPD 车牌数据集提取标注,并转为标准 YOLO 格式
python
晓131317 分钟前
JavaScript加强篇——第四章 日期对象与DOM节点(基础)
开发语言·前端·javascript
老胖闲聊19 分钟前
Python I/O 库【输入输出】全面详解
开发语言·python
倔强青铜三38 分钟前
苦练Python第18天:Python异常处理锦囊
人工智能·python·面试
倔强青铜三1 小时前
苦练Python第17天:你必须掌握的Python内置函数
人工智能·python·面试
迷路爸爸1801 小时前
让 VSCode 调试器像 PyCharm 一样显示 Tensor Shape、变量形状、变量长度、维度信息
ide·vscode·python·pycharm·debug·调试
她说人狗殊途2 小时前
java.net.InetAddress
java·开发语言
天使day2 小时前
Cursor的使用
java·开发语言·ai
咸鱼鲸2 小时前
【PyTorch】PyTorch中的数据预处理操作
人工智能·pytorch·python