I previously ran into a case where PaddlePaddle needed index assignment; see the issue: 【论文复现】张量的bool索引和切片直接赋值问题 · Issue #37733 · PaddlePaddle/Paddle
What I needed was the equivalent of this torch statement: `pos[:, group] -= offset.unsqueeze(1)`
Conclusion first: PaddlePaddle now supports direct assignment to index slices.
In the very beginning Paddle had no such feature, so I worked around it with numpy.
The following 4 lines implement torch's `pos[:, group] -= offset.unsqueeze(1)` — is there a Paddle API that does this directly?
```python
pos = pos.numpy()
offset = offset.unsqueeze(1).numpy()
pos[:, group] -= offset  # numpy workaround; 4 related lines in total
pos = paddle.to_tensor(pos)
```
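The indexing semantics this workaround relies on can be checked in plain numpy, since torch's fancy indexing matches numpy's. A small sketch (all names and shapes below are illustrative, not from the original script):

```python
import numpy as np

# Hypothetical shapes: pos is [batch, natoms, 3], offset is [batch, 3],
# group lists the atom indices to shift.
pos = np.arange(24, dtype=np.float64).reshape(2, 4, 3)  # 2 batches, 4 atoms
group = [1, 3]
offset = np.array([[1.0, 1.0, 1.0], [2.0, 2.0, 2.0]])

# Reference: shift each grouped atom of each batch by that batch's offset
expected = pos.copy()
for b in range(pos.shape[0]):
    for g in group:
        expected[b, g] -= offset[b]

# The one-line fancy-index version; offset[:, None, :] == offset.unsqueeze(1)
pos[:, group] -= offset[:, None, :]
print(np.array_equal(pos, expected))  # True
```

The broadcast of `[batch, 1, 3]` against the selected `[batch, len(group), 3]` slice is exactly what `unsqueeze(1)` achieves on the torch side.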
The Paddle implementation
```python
import paddle

def tensor_index_subtract(pos, group, offset):
    """
    Implements the PyTorch-style pos[:, group] -= offset.unsqueeze(1).
    Intended to work in both dynamic- and static-graph mode.
    """
    # 1. Expand offset to match the batch dimension of pos
    offset_expanded = offset.unsqueeze(1).expand([pos.shape[0], -1, -1])
    # 2. Use index_select to pick out the columns listed in group
    selected_columns = paddle.index_select(pos, group, axis=2)
    # 3. Perform the subtraction
    result = pos - selected_columns * offset_expanded
    return result
```
A third method: performing the subtraction directly is simplest
```python
import paddle

def tensor_index_subtract(pos, group, offset):
    # Index-wise subtraction using Paddle APIs directly
    pos = pos - paddle.index_select(offset, group, axis=1).unsqueeze(1)
    return pos
```
The open question is whether this kind of subtraction actually matches the PyTorch semantics — and likewise this line:
```python
result = pos - selected_columns * offset_expanded
```
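Whether a bare `pos - ...` can match the PyTorch statement is easy to probe in numpy. The sketch below (toy data, hypothetical shapes) shows why subtracting an offset from every atom is not equivalent to the original operation, which touches only the atoms in `group`:

```python
import numpy as np

pos = np.ones((1, 3, 3))        # [batch, natoms, 3], illustrative values
group = [0, 2]
offset = np.array([[0.5, 0.5, 0.5]])  # [batch, 3]

# Reference semantics: only the grouped atoms are shifted
target = pos.copy()
target[:, group] -= offset[:, None, :]

# What a plain subtraction without any selection does: shifts ALL atoms
subtract_all = pos - offset[:, None, :]

print(np.array_equal(subtract_all, target))  # False: atom 1 was also shifted
```

So any subtraction-based rewrite must first zero out the offset at non-group positions, which is what the selection-matrix idea below tries to do.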
codebuddy's solution
Implementation summary
For the PyTorch operation `pos[:, group] -= offset.unsqueeze(1)`, I provide several PaddlePaddle implementations:
🎯 Method 1: direct index assignment (recommended)

```python
result = pos.clone()
result[:, group] -= offset.unsqueeze(1)
```

Advantages: closest to the original PyTorch syntax, best performance, fully equivalent
🎯 Method 2: matrix multiplication (the one you mentioned)

```python
# Build a selection matrix
selected_columns = paddle.zeros([pos.shape[0], len(group), pos.shape[1]])
for i, atom_idx in enumerate(group):
    selected_columns[:, i, :] = 1.0
# Expand the offset and apply it
offset_expanded = offset.unsqueeze(1)
result = pos - selected_columns * offset_expanded
```
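For reference, a corrected numpy sketch of the selection-matrix idea (shapes assumed to be `pos` `[batch, natoms, 3]`, `offset` `[batch, 3]`). Note the quoted code above sets entire rows to 1.0, whereas the mask has to be 1 only at the grouped atoms and broadcast over the coordinate axis:

```python
import numpy as np

batch, natoms = 2, 4
pos = np.random.rand(batch, natoms, 3)
group = [1, 3]
offset = np.random.rand(batch, 3)

# Per-atom mask: 1 for atoms in the group, 0 elsewhere; the trailing
# singleton axis lets it broadcast over x/y/z
mask = np.zeros((natoms, 1))
mask[group] = 1.0

# The offset is applied only where the mask is 1
result = pos - mask * offset[:, None, :]

# Reference: direct fancy-index subtraction
target = pos.copy()
target[:, group] -= offset[:, None, :]
print(np.allclose(result, target))  # True
```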
🎯 Method 3: mask operation

```python
result = pos.clone()
mask = paddle.zeros([pos.shape[1]], dtype='bool')
mask[group] = True
result[:, mask] -= offset.unsqueeze(1)
```
🔍 Key points

- Dimensions: `pos[:, group]` selects the atoms in the specified columns; `offset.unsqueeze(1)` expands the offset from [batch, 3] to [batch, 1, 3].
- What the operation does:
  - only the positions of atoms inside the group are modified
  - every atom in the group is shifted by the same offset vector
  - atoms outside the group are left unchanged
- Logic of the matrix-multiplication method:
  - build a selection matrix that is 1 at in-group atoms and 0 elsewhere
  - `selected_columns * offset_expanded` then produces an offset only at in-group positions, and `pos - ...` performs the subtraction
✅ Method 1 is recommended

The most direct and efficient implementation is method 1 (direct index assignment). It:
- is fully equivalent to the original PyTorch operation
- has the best performance
- is the most concise
- avoids building an extra matrix
```python
>>> # Build test data
>>> pos_torch = torch.tensor([[1, 2, 3], [4, 5, 6]], dtype=torch.float32)
>>> offset_torch = torch.tensor([1, 2], dtype=torch.float32)
>>> pos_torch[:, group_torch] -= offset_torch.unsqueeze(1)
>>> print("PyTorch result:", pos_torch.numpy())
PyTorch result: [[0. 2. 2.]
 [2. 5. 4.]]
>>>
>>> # Paddle implementation
>>> pos_paddle = pos.clone()
>>> pos_paddle[:, group] -= offset.unsqueeze(1)
>>> print("Paddle result:", pos_paddle.numpy())
Paddle result: [[0. 2. 2.]
 [2. 5. 4.]]
```
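The comparison above can be reproduced in plain numpy. `group_torch` is not shown in the snippet, but `[0, 2]` is the value consistent with the printed results:

```python
import numpy as np

pos = np.array([[1, 2, 3], [4, 5, 6]], dtype=np.float32)
group = [0, 2]                     # inferred from the printed output
offset = np.array([1, 2], dtype=np.float32)

pos[:, group] -= offset[:, None]   # same fancy-index semantics as torch/paddle
print(pos)
# [[0. 2. 2.]
#  [2. 5. 4.]]
```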
I did not understand why the extra `pos.clone()` line is needed, so I specifically tested whether the clone can be dropped.
```
cd "g:/codebuddy/md" && python simple_clone_test.py
Starting the pos.clone() test...
Original pos:
Tensor(shape=[5, 3], dtype=float32, place=Place(cpu), stop_gradient=True,
[[ 0.30147934, 0.28241217, -0.40545371],
[-0.46399701, -0.23090385, 0.27370819],
[ 1.19101310, 0.18119779, 1.70754528],
[ 0.54156667, -1.72120285, -1.54957402],
[-0.26763844, 1.14153934, -0.02156020]])
group: Tensor(shape=[2], dtype=int64, place=Place(cpu), stop_gradient=True,
[0, 2])
offset: Tensor(shape=[1], dtype=float32, place=Place(cpu), stop_gradient=True,
[0.50000000])
=== Method 1: with clone ===
After copying: Tensor(shape=[5, 3], dtype=float32, place=Place(cpu), stop_gradient=True,
[[ 0.30147934, 0.28241217, -0.40545371],
[-0.46399701, -0.23090385, 0.27370819],
[ 1.19101310, 0.18119779, 1.70754528],
[ 0.54156667, -1.72120285, -1.54957402],
[-0.26763844, 1.14153934, -0.02156020]])
After modification: Tensor(shape=[5, 3], dtype=float32, place=Place(cpu), stop_gradient=True,
[[-0.19852066, 0.28241217, -0.90545368],
[-0.96399701, -0.23090385, -0.22629181],
[ 0.69101310, 0.18119779, 1.20754528],
[ 0.04156667, -1.72120285, -2.04957390],
[-0.76763844, 1.14153934, -0.52156019]])
=== Method 2: without clone ===
Used directly: Tensor(shape=[5, 3], dtype=float32, place=Place(cpu), stop_gradient=True,
[[ 0.30147934, 0.28241217, -0.40545371],
[-0.46399701, -0.23090385, 0.27370819],
[ 1.19101310, 0.18119779, 1.70754528],
[ 0.54156667, -1.72120285, -1.54957402],
[-0.26763844, 1.14153934, -0.02156020]])
After modification: Tensor(shape=[5, 3], dtype=float32, place=Place(cpu), stop_gradient=True,
[[-0.19852066, 0.28241217, -0.90545368],
[-0.96399701, -0.23090385, -0.22629181],
[ 0.69101310, 0.18119779, 1.20754528],
[ 0.04156667, -1.72120285, -2.04957390],
[-0.76763844, 1.14153934, -0.52156019]])
=== Comparison ===
Both methods give identical results: True
Maximum absolute difference: 0.0
```
So the clone can indeed be removed.
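For completeness, what the `clone` would actually buy: the final results are identical either way (which is all the script above compares), but in-place index subtraction mutates its target, so the copy only matters when the original tensor must be preserved. A numpy sketch, using `copy()` as the analogue of `clone()`:

```python
import numpy as np

pos = np.zeros((2, 3))
group = [0]

# With a copy: the original pos is untouched
with_copy = pos.copy()
with_copy[:, group] -= 1.0
print(pos[0, 0], with_copy[0, 0])  # 0.0 -1.0

# Without a copy: the original is modified in place
pos[:, group] -= 1.0
print(pos[0, 0])                   # -1.0
```

In the wrapper below the in-place mutation is exactly what we want, so dropping the clone is correct there.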
The code finally used in paddlemd

The final forces.py code:
```python
import paddle


class Wrapper:
    def __init__(self, natoms, bonds):
        # No device argument, to stay compatible with torchmd's wrapper
        self.groups, self.nongrouped = calculate_molecule_groups(natoms, bonds)

    def wrap(self, pos, box, wrapidx=None):
        nmol = len(self.groups)
        original_box_shape = box.shape
        batch_size = pos.shape[0]
        # Normalize box so that its batch dimension matches pos
        if len(box.shape) == 3:  # [batch, 3, 3]
            # Extract the diagonal elements
            box_diag = paddle.diagonal(box, offset=0, axis1=1, axis2=2)  # [batch, 3]
            # Make sure box's batch size matches pos
            if box_diag.shape[0] != batch_size:
                if box_diag.shape[0] == 1:  # box has a single batch: expand it
                    box = box_diag.expand([batch_size, 3])
                else:  # otherwise take the first batch's box
                    box = box_diag[0:1, :].expand([batch_size, 3])
            else:
                box = box_diag
        elif len(box.shape) == 2 and box.shape[0] == batch_size and box.shape[1] == 3:
            # Already [batch, 3]; use as-is. This must be checked before the
            # generic [3, 3] case, otherwise the branch is unreachable.
            pass
        elif len(box.shape) == 2:  # [3, 3]
            # Extract the diagonal and add a batch dimension
            box_diag = paddle.diagonal(box, offset=0, axis1=0, axis2=1)  # [3]
            box = box_diag.unsqueeze(0).expand([batch_size, 3])  # [batch, 3]
        elif len(box.shape) == 1:  # [3]
            # Add a batch dimension
            box = box.unsqueeze(0).expand([batch_size, 3])  # [batch, 3]
        else:
            raise ValueError(f"Unsupported box shape: {original_box_shape}")
        if paddle.all(box == 0):
            return
        if wrapidx is not None:
            # Get COM of wrapping center group
            com = paddle.sum(pos[:, wrapidx], axis=1) / len(wrapidx)  # [batch, 3]
            # Subtract COM from all atoms so that the center mol is at [box/2, box/2, box/2]
            com_expanded = com.unsqueeze(1)  # [batch, 1, 3]
            box_half_expanded = (box / 2).unsqueeze(1)  # [batch, 1, 3]
            # Assign in place so the caller's pos tensor is updated
            pos[:] = (pos - com_expanded) + box_half_expanded
        if nmol != 0:
            # Work out the COMs and offsets of every group and move group to [0, box] range
            for i, group in enumerate(self.groups):
                selected_columns = pos[:, group]  # [batch, group_size, 3]
                tmp_com = paddle.sum(selected_columns, axis=1) / len(group)  # [batch, 3]
                # Calculate offset with proper broadcasting (consistent with torchmd)
                tmp_com_expanded = tmp_com.unsqueeze(1)  # [batch, 1, 3]
                box_expanded = box.unsqueeze(1)  # [batch, 1, 3] (handles the batch case)
                offset = paddle.floor(tmp_com_expanded / box_expanded) * box_expanded  # [batch, 1, 3]
                pos[:, group] -= offset
        # Move non-grouped atoms
        if len(self.nongrouped):
            selected_columns = pos[:, self.nongrouped]  # [batch, num_atoms, 3]
            # For non-grouped atoms, calculate offset for each atom individually
            box_expanded = box.unsqueeze(1)  # [batch, 1, 3] (handles the batch case)
            offset = paddle.floor(selected_columns / box_expanded) * box_expanded  # [batch, num_atoms, 3]
            pos[:, self.nongrouped] -= offset


def calculate_molecule_groups(natoms, bonds):
    import networkx as nx

    # Calculate molecule groups and non-bonded / non-grouped atoms
    if bonds is not None and len(bonds):
        bondGraph = nx.Graph()
        bondGraph.add_nodes_from(range(natoms))
        # Handle the different bond input types (numpy array or paddle tensor)
        # without depending on numpy directly
        if hasattr(bonds, 'tolist'):
            bond_edges = bonds.tolist()  # convert straight to a list
        else:
            bond_edges = bonds  # already a plain list
        # Make sure both endpoints of every edge are ints
        bond_edges = [[int(edge[0]), int(edge[1])] for edge in bond_edges]
        bondGraph.add_edges_from(bond_edges)
        molgroups = list(nx.connected_components(bondGraph))
        nongrouped = paddle.to_tensor(
            [list(group)[0] for group in molgroups if len(group) == 1]
        )
        molgroups = [
            paddle.to_tensor(list(group))
            for group in molgroups
            if len(group) > 1
        ]
    else:
        molgroups = []
        nongrouped = paddle.arange(0, natoms)
    return molgroups, nongrouped
```
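The wrapping step in `wrap` relies on `offset = floor(com / box) * box` to translate each molecule's centre of mass back into [0, box). A minimal numeric sketch with toy values:

```python
import numpy as np

box = np.array([10.0, 10.0, 10.0])
com = np.array([23.0, -4.0, 7.0])   # COM outside the box on x and y

# floor(com / box) counts how many whole box lengths to remove
offset = np.floor(com / box) * box
wrapped = com - offset

print(offset.tolist())   # [20.0, -10.0, 0.0]
print(wrapped.tolist())  # [3.0, 6.0, 7.0]
```

Note that `floor` (not truncation) is what makes the negative coordinate land at 6.0 rather than -4.0, so every component of `wrapped` ends up in [0, box).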