torch的pos[:, group] -= offset.unsqueeze(1) 飞桨api怎样实现?

以前碰到过飞桨需要索引赋值的情况,issue:【论文复现】张量的bool索引和切片直接赋值问题 · Issue #37733 · PaddlePaddle/Paddle

也就是需要的功能是类似torch这样的语句:pos[:, group] -= offset.unsqueeze(1)

先上结论,飞桨现在支持索引切片直接赋值了。

最一开始飞桨没有相应的功能,所以

这4句话,实现的功能是torch的pos[:, group] -= offset.unsqueeze(1) ,有飞桨的api直接实现吗?

python 复制代码
                pos = pos.numpy()
                offset = offset.unsqueeze(1).numpy()
                pos[:, group] -= offset  # 尝试使用numpy来处理 前后相关语句共4句
                pos = paddle.to_tensor(pos)

飞桨的实现

python 复制代码
import paddle

def tensor_index_subtract(pos, group, offset):
    """
    实现类似PyTorch的pos[:, group] -= offset.unsqueeze(1)操作
    支持动态图和静态图模式
    """
    # 1. 扩展offset以匹配pos的批次维度
    offset_expanded = offset.unsqueeze(1).expand([pos.shape[0], -1, -1])
    
    # 2. 使用index_select提取group对应的列
    selected_columns = paddle.index_select(pos, group, axis=2)
    
    # 3. 执行减法操作
    result = pos - selected_columns * offset_expanded
    
    return result

第三种方法,直接执行减法最简单

python 复制代码
import paddle

def tensor_index_subtract(pos, group, offset):
    # 直接使用飞桨API实现索引减法
    pos = pos - paddle.index_select(offset, group, axis=1).unsqueeze(1)
    return pos

现在的问题是,要测试这样的减法是否能对齐

还有这句

python 复制代码
 result = pos - selected_columns * offset_expanded

codebuddy的解决方案

实现方案总结

对于 PyTorch 中的 pos[:, group] -= offset.unsqueeze(1) 操作,我提供了几种 PaddlePaddle 的实现方法:

🎯 方法1: 直接索引赋值(推荐)

python

复制代码
result = pos.clone()
result[:, group] -= offset.unsqueeze(1)

优点: 最接近原始 PyTorch 语法,性能最好,完全等价

🎯 方法2: 矩阵乘法(你提到的)

python

复制代码
# 创建选择矩阵
selected_columns = paddle.zeros([pos.shape[0], len(group), pos.shape[1]])
for i, atom_idx in enumerate(group):
    selected_columns[:, i, :] = 1.0

# 扩展偏移并应用
offset_expanded = offset.unsqueeze(1)
result = pos - selected_columns * offset_expanded

🎯 方法3: 掩码操作

python

复制代码
result = pos.clone()
mask = paddle.zeros([pos.shape[1]], dtype='bool')
mask[group] = True
result[:, mask] -= offset.unsqueeze(1)

🔍 关键点分析

  1. 维度理解:

    • pos[:, group] 选择指定列的原子
    • offset.unsqueeze(1) 将偏移从 [batch, 3] 扩展为 [batch, 1, 3]
  2. 操作本质:

    • 只有组内原子的位置被修改
    • 每个组内原子减去相同的偏移向量
    • 组外原子位置不变
  3. 矩阵乘法方法逻辑:

    • 创建选择矩阵:组内原子对应位置为1,其他为0
    • selected_columns * offset_expanded 只在组内原子位置产生偏移
    • pos - ... 实现减法操作

推荐使用方法1

最直接且高效的实现是方法1(直接索引赋值),它:

  • 完全等价于 PyTorch 原始操作
  • 性能最优
  • 代码最简洁
  • 避免了额外的矩阵创建
python 复制代码
>>> # 创建测试数据
>>> pos_torch = torch.tensor([[1, 2, 3], [4, 5, 6]], dtype=torch.float32)  
>>> offset_torch = torch.tensor([1, 2], dtype=torch.float32)
>>> pos_torch[:, group_torch] -= offset_torch.unsqueeze(1)
>>> print("PyTorch 结果:", pos_torch.numpy())
PyTorch 结果: [[0. 2. 2.]
 [2. 5. 4.]]
>>>
>>> # 飞桨实现
>>> pos_paddle = pos.clone()
>>> pos_paddle[:, group] -= offset.unsqueeze(1)
>>> print("飞桨结果:", pos_paddle.numpy())
飞桨结果: [[0. 2. 2.]
 [2. 5. 4.]]

我不明白为什么要加一句pos.clone

专门测试一下clone是否可以去掉

python 复制代码
cd "g:/codebuddy/md" && python simple_clone_test.py

开始测试pos.clone()的影响...
原始pos:
Tensor(shape=[5, 3], dtype=float32, place=Place(cpu), stop_gradient=True,
       [[ 0.30147934,  0.28241217, -0.40545371],
        [-0.46399701, -0.23090385,  0.27370819],
        [ 1.19101310,  0.18119779,  1.70754528],
        [ 0.54156667, -1.72120285, -1.54957402],
        [-0.26763844,  1.14153934, -0.02156020]])
group: Tensor(shape=[2], dtype=int64, place=Place(cpu), stop_gradient=True,
       [0, 2])
offset: Tensor(shape=[1], dtype=float32, place=Place(cpu), stop_gradient=True,
       [0.50000000])

=== 方法1:使用clone ===
复制后: Tensor(shape=[5, 3], dtype=float32, place=Place(cpu), stop_gradient=True,
       [[ 0.30147934,  0.28241217, -0.40545371],
        [-0.46399701, -0.23090385,  0.27370819],
        [ 1.19101310,  0.18119779,  1.70754528],
        [ 0.54156667, -1.72120285, -1.54957402],
        [-0.26763844,  1.14153934, -0.02156020]])
修改后: Tensor(shape=[5, 3], dtype=float32, place=Place(cpu), stop_gradient=True,
       [[-0.19852066,  0.28241217, -0.90545368],
        [-0.96399701, -0.23090385, -0.22629181],
        [ 0.69101310,  0.18119779,  1.20754528],
        [ 0.04156667, -1.72120285, -2.04957390],
        [-0.76763844,  1.14153934, -0.52156019]])

=== 方法2:不使用clone ===
直接使用: Tensor(shape=[5, 3], dtype=float32, place=Place(cpu), stop_gradient=True,
       [[ 0.30147934,  0.28241217, -0.40545371],
        [-0.46399701, -0.23090385,  0.27370819],
        [ 1.19101310,  0.18119779,  1.70754528],
        [ 0.54156667, -1.72120285, -1.54957402],
        [-0.26763844,  1.14153934, -0.02156020]])
修改后: Tensor(shape=[5, 3], dtype=float32, place=Place(cpu), stop_gradient=True,
       [[-0.19852066,  0.28241217, -0.90545368],
        [-0.96399701, -0.23090385, -0.22629181],
        [ 0.69101310,  0.18119779,  1.20754528],
        [ 0.04156667, -1.72120285, -2.04957390],
        [-0.76763844,  1.14153934, -0.52156019]])

=== 结果比较 ===
两种方法结果是否一致: True
差异的最大绝对值: 0.0

所以这个clone是可以去掉的。

paddlemd里最后使用的代码

最后forces.py的代码如下:

python 复制代码
import paddle


class Wrapper:
    def __init__(self, natoms, bonds):
        # 不传入device参数,保持与torchmd的wrapper兼容
        self.groups, self.nongrouped = calculate_molecule_groups(natoms, bonds)

    def wrap(self, pos, box, wrapidx=None):
        nmol = len(self.groups)
        original_box_shape = box.shape
        batch_size = pos.shape[0]
        
        # 正确处理box,确保batch维度匹配
        if len(box.shape) == 3:  # [batch, 3, 3]
            # 提取对角线元素
            box_diag = paddle.diagonal(box, offset=0, axis1=1, axis2=2)  # [batch, 3]
            # 确保box的batch大小与pos一致
            if box_diag.shape[0] != batch_size:
                if box_diag.shape[0] == 1:  # 如果box只有1个batch,扩展到pos的batch大小
                    box = box_diag.expand([batch_size, 3])
                else:  # 否则,取第一个batch的box
                    box = box_diag[0:1, :].expand([batch_size, 3])
            else:
                box = box_diag
        elif len(box.shape) == 2:  # [3, 3]
            # 提取对角线元素,扩展为batch维度
            box_diag = paddle.diagonal(box, offset=0, axis1=0, axis2=1)  # [3]
            box = box_diag.unsqueeze(0).expand([batch_size, 3])  # [batch, 3]
        elif len(box.shape) == 1:  # [3]
            # 扩展为batch维度
            box = box.unsqueeze(0).expand([batch_size, 3])  # [batch, 3]
        elif len(box.shape) == 2 and box.shape[0] == batch_size:  # [batch, 3]
            # 已经是[batch, 3]形状,直接使用
            box = box
        else:
            raise ValueError(f"不支持的box形状: {original_box_shape}")
        
        if paddle.all(box == 0):
            return

        if wrapidx is not None:
            # Get COM of wrapping center group
            com = paddle.sum(pos[:, wrapidx], axis=1) / len(wrapidx)  # [batch, 3]
            # Subtract COM from all atoms so that the center mol is at [box/2, box/2, box/2]
            com_expanded = com.unsqueeze(1)  # [batch, 1, 3]
            box_half_expanded = (box / 2).unsqueeze(1)  # [batch, 1, 3]
            pos = (pos - com_expanded) + box_half_expanded

        if nmol != 0:
            # Work out the COMs and offsets of every group and move group to [0, box] range
            for i, group in enumerate(self.groups):
                selected_columns = pos[:, group]  # [batch, group_size, 3]
                tmp_com = paddle.sum(selected_columns, axis=1) / len(group)  # [batch, 3]
                # Calculate offset with proper broadcasting (与torchmd一致)
                tmp_com_expanded = tmp_com.unsqueeze(1)  # [batch, 1, 3]
                box_expanded = box.unsqueeze(1)  # [batch, 1, 3] (适配batch情况)
                offset = paddle.floor(tmp_com_expanded / box_expanded) * box_expanded  # [batch, 1, 3]
                pos[:, group] -= offset

        # Move non-grouped atoms
        if len(self.nongrouped):
            selected_columns = pos[:, self.nongrouped]  # [batch, num_atoms, 3]
            # For non-grouped atoms, calculate offset for each atom individually
            box_expanded = box.unsqueeze(1)  # [batch, 1, 3] (适配batch情况)
            offset = paddle.floor(selected_columns / box_expanded) * box_expanded  # [batch, num_atoms, 3]
            pos[:, self.nongrouped] -= offset


def calculate_molecule_groups(natoms, bonds):
    import networkx as nx

    # Calculate molecule groups and non-bonded / non-grouped atoms
    if bonds is not None and len(bonds):
        bondGraph = nx.Graph()
        bondGraph.add_nodes_from(range(natoms))
        # 移除numpy依赖,直接处理bonds
        # 处理不同类型的bonds输入(numpy数组或paddle张量)
        if hasattr(bonds, 'tolist'):
            # 直接转换为列表
            bond_edges = bonds.tolist()
        else:
            # 如果是普通列表,直接使用
            bond_edges = bonds
        # 确保每个边的元素都是整数
        bond_edges = [[int(edge[0]), int(edge[1])] for edge in bond_edges]
        bondGraph.add_edges_from(bond_edges)
        molgroups = list(nx.connected_components(bondGraph))

        nongrouped = paddle.to_tensor(
            [list(group)[0] for group in molgroups if len(group) == 1]
        )
        molgroups = [
            paddle.to_tensor(list(group))
            for group in molgroups
            if len(group) > 1
        ]
    else:
        molgroups = []
        nongrouped = paddle.arange(0, natoms)
    return molgroups, nongrouped
相关推荐
小程故事多_801 小时前
AI Agent进阶架构:用渐进式披露驯服复杂性
人工智能·架构
人工智能AI技术2 小时前
【Agent从入门到实践】10 决策模块:Agent如何“思考问题”
人工智能
qq_527887872 小时前
联邦经典算法Fedavg实现
人工智能·深度学习
天天讯通2 小时前
数据公司与AI五大主流合作模式
人工智能
Clarence Liu3 小时前
AI Agent开发(2) - 深入解析 A2A 协议与 Go 实战指南
开发语言·人工智能·golang
综合热讯3 小时前
AUS GLOBAL 荣耀赞助 2026 LIL TOUR 高尔夫嘉年华
人工智能
小饼干超人3 小时前
详解向量数据库中的PQ算法(Product Quantization)
人工智能·算法·机器学习
砚边数影4 小时前
AI数学基础(一):线性代数核心,向量/矩阵运算的Java实现
java·数据库·人工智能·线性代数·矩阵·ai编程·金仓数据库
互联网科技看点4 小时前
诸葛io获认可:金融分析智能体赛道领航者
大数据·人工智能·金融
engchina4 小时前
自然语言转 SQL 并不是“魔法”
数据库·人工智能·sql·text2sql·nl2sql·自然语言转sql