Sum-rate计算

1.ZF

python 复制代码
import torch

def calc_sum_rate_corrected(H: torch.Tensor, soft_mask: torch.Tensor, p: float = 80.0, sigma2: float = 1.0) -> torch.Tensor:
    """
    基于 Top-k 软掩码选择的可导 sum-rate 计算 (修正版)。
    此函数修正了原实现中的矩阵维度和SINR计算错误。

    Args:
        H: 信道矩阵, 形状为 (N, K),N=天线数, K=用户数。
        soft_mask: 每个用户的选择概率掩码, 形状为 (K,),值在 [0, 1] 之间。
        p: 总发射功率。
        sigma2: 噪声功率。

    Returns:
        sum_rate: 可导的和速率标量。
    """
    N, K = H.shape
    device = H.device

    # 确保掩码和信道矩阵的数据类型一致
    if soft_mask.dtype != H.dtype:
        soft_mask = soft_mask.to(H.dtype)

    # 1. 根据soft_mask选出Top-k个最可能的用户
    topk = min(N, K)
    # 使用.real以防soft_mask是复数类型
    _, indices = torch.topk(soft_mask.real, topk)
    H_sel = H[:, indices]  # 形状: (N, topk)
    soft_mask_sel = soft_mask[indices]  # 形状: (topk,)

    # 2. 为了可导性,使用soft_mask对选出的信道进行加权
    # 这是您原始设计中有意为之的部分,予以保留
    H_weighted = H_sel * soft_mask_sel.unsqueeze(0)  # 形状: (N, topk)

    try:
        # 3. [修正] 计算ZF权重矩阵 W
        H_H = H_weighted.conj().T  # 形状: (topk, N)
        
        # [修正点 #2] Gram矩阵应为 (topk, topk)
        gram = H_H @ H_weighted
        
        # 为提高数值稳定性,加入微小的正则化项
        reg_eps = 1e-6
        gram_reg = gram + reg_eps * torch.eye(topk, device=device, dtype=gram.dtype)
        
        # [修正点 #3] 正确求解 W,其形状应为 (topk, N)
        # W 的每一行 w_k 是分配给用户k的波束成形向量
        W = torch.linalg.solve(gram_reg, H_H)

        # 4. [修正] 使用向量化方式计算所有用户的SINR
        
        # 分子 (Signal Power)
        # p_u 是分配给每个用户的功率
        p_u = p / topk
        # G是均衡后的等效信道,对角线元素是每个用户的信号增益
        G = W @ H_weighted  # 形状: (topk, topk)
        signal_gains = torch.diag(G)
        num = p_u * torch.abs(signal_gains) ** 2
        
        # 分母 (Noise Power)
        # [修正点 #5] 按行(dim=1)求和,计算每个用户权重向量w_k的模长平方
        noise_power = sigma2 * torch.sum(torch.abs(W) ** 2, dim=1)
        den = noise_power
        
        # 为防止除以零,在分母上增加一个极小值
        sinr = num / (den + 1e-12)
        
        # 5. 计算和速率
        rate = torch.log2(1 + sinr)
        return torch.sum(rate)

    except torch.linalg.LinAlgError:
        # 如果矩阵奇异无法求解,返回0
        return torch.tensor(0.0, device=device)
相关推荐
贾宝玉的玉宝贾几秒前
FreeSWITCH 简单图形化界面52 - 拨号应用 Answer 介绍
python·django·voip·freeswitch·sip·ippbx·jssip
Larry_Yanan几秒前
Qt多进程(十一)Linux下socket通信
linux·开发语言·c++·qt
Hello.Reader几秒前
PyFlink JAR、Python 包、requirements、虚拟环境、模型文件,远程集群怎么一次搞定?
java·python·jar
PeterClerk1 分钟前
深度学习-NLP 常见语料库
人工智能·深度学习·自然语言处理
代码游侠9 分钟前
学习笔记——ESP8266 WiFi模块
服务器·c语言·开发语言·数据结构·算法
0和1的舞者10 分钟前
Python 中四种核心数据结构的用途和嵌套逻辑
数据结构·python·学习·知识
weixin_4624462311 分钟前
Python 使用 PyQt5 + Pandas 实现 Excel(xlsx)批量合并工具(带图形界面)
python·qt·pandas
Hello.Reader13 分钟前
PyFlink Configuration 一次讲透怎么配、配哪些、怎么“调得快且稳”
运维·服务器·python·flink
行者9614 分钟前
Flutter跨平台开发适配OpenHarmony:进度条组件的深度实践
开发语言·前端·flutter·harmonyos·鸿蒙
白日做梦Q14 分钟前
实时语义分割:BiSeNet与Fast-SCNN深度对比与实践启示
人工智能·深度学习·计算机视觉