Pytorch中张量的维度扩张与广播操作示例

广播操作 允许你对不同形状的张量执行逐元素操作,而无需显式循环

一个关于分子坐标离散格点化的实战例子

python 复制代码
def cdists(mols, grid):
    '''
    Calculates the pairwise Euclidean distances between a set of molecules and a list
    of positions on a grid (uses inplace operations to minimize memory demands).

    Args:
        mols (torch.Tensor): data set (of molecules) with shape
            (batch_size x n_atoms x n_dims)
        grid (torch.Tensor): array (of positions) with shape (n_positions x n_dims)

    Returns:
        torch.Tensor: batch of distance matrices (batch_size x n_atoms x n_positions)
    '''
    if len(mols.size()) == len(grid.size())+1:
        grid = grid.unsqueeze(0)  # add batch dimension
    return F.relu(torch.sum((mols[:, :, None, :] - grid[:, None, :, :]).pow_(2), -1),
                  inplace=True).sqrt_()

那么,上面的代码中为什么要进行: (mols[:, :, None, :] - grid[:, None, :, :]) 这样的操作呢?

这段代码用于计算一组分子(mols)与一个网格上的一组位置格点grid)之间的欧几里得距离。这里的(mols[:, :, None, :] - grid[:, None, :, :])操作涉及到张量的广播操作,它的目的是计算每个分子的每个原子与每个位置之间的距离。

这段代码的工作原理:

  1. mols 张量的形状为 (batch_size x n_atoms x n_dims),其中 batch_size 是批次大小,n_atoms 是原子数量,n_dims 是原子的坐标维度(通常是3维,表示xyz坐标)。
  2. grid 张量的形状为 (n_positions x n_dims),其中 n_positions 是位置格点的数量,n_dims 同样是坐标的维度。

首先,如果 mols 张量的维度比 grid 张量的维度多1,代码会通过 grid.unsqueeze(0) 添加一个额外的维度,以匹配 mols 张量的 batch 维度。这是为了使广播操作生效。

接下来,代码使用广播操作计算每个分子的每个原子与每个位置之间的距离。广播操作允许你对不具有相同形状的张量执行逐元素操作,而无需显式循环。

  • mols[:, :, None, :] 的形状变成 (batch_size x n_atoms x 1 x n_dims)。这个操作在 n_atoms 维度上添加了一个额外的维度,以便与 grid[:, None, :, :] 进行广播。

  • grid[:, None, :, :] 的形状变成 (1 x 1 x n_positions x n_dims)。这个操作在 batch 维度和 n_atoms 维度上添加了额外的维度,以便与 mols[:, :, None, :] 进行广播。

  • 然后,这两个张量进行逐元素的减法操作,计算了每个分子的每个原子与每个位置格点之间的差异。结果是一个张量,其形状为 (batch_size x n_atoms x n_positions x n_dims)

  • 最后,使用 .pow_(2).sqrt_() 操作,计算了每个分子的每个原子与每个位置格点之间的欧几里得距离。

总之,这段代码通过广播操作高效地计算了每个分子的每个原子与每个位置之间的距离,而无需显式的循环操作。这有助于提高计算效率,特别是在处理大规模数据时。

那么带入一个随机生成的数据进行举例:

grid_test.py :

python 复制代码
import torch
import torch.nn.functional as F

# 示例数据
batch_size = 2
n_atoms = 5
n_dims = 3
n_positions = 10

mols = torch.rand(batch_size, n_atoms, n_dims)  # 随机生成分子坐标数据
grid = torch.rand(n_positions, n_dims)  # 随机生成网格位置数据

print("batch_size = ", batch_size,"    n_atoms = ", n_atoms,"    n_dims = ",n_dims,"    n_positions = ",n_positions)

# 打印示例数据
print("示例数据 mols:")
print("mols = torch.rand(batch_size, n_atoms, n_dims)")
print(mols)
print(mols.shape)

print("\n示例数据 grid:")
print("grid = torch.rand(n_positions, n_dims)")
print(grid)
print(grid.shape)

# 如果 mols 张量的维度比 grid 张量的维度多1,添加一个额外的维度
if len(mols.size()) == len(grid.size()) + 1:
    grid = grid.unsqueeze(0)
    print("\n添加额外维度后的 grid:")
    print(grid)
    print(grid.shape)

print("\nmols[:, :, None, :]")
print(mols[:, :, None, :])
print(mols[:, :, None, :].shape)

print("\ngrid[:, None, :, :]")
print(grid[:, None, :, :])
print(grid[:, None, :, :].shape)

print("\nmols[:, :, None, :] - grid[:, None, :, :]")
print(mols[:, :, None, :] - grid[:, None, :, :])
print((mols[:, :, None, :] - grid[:, None, :, :]).shape)

print("\n(mols[:, :, None, :] - grid[:, None, :, :]).pow_(2)")
print((mols[:, :, None, :] - grid[:, None, :, :]).pow_(2))

print("\ntorch.sum((mols[:, :, None, :] - grid[:, None, :, :]).pow_(2), -1)")
print(torch.sum((mols[:, :, None, :] - grid[:, None, :, :]).pow_(2), -1))
print((torch.sum((mols[:, :, None, :] - grid[:, None, :, :]).pow_(2), -1)).shape)

print("\nF.relu(torch.sum((mols[:, :, None, :] - grid[:, None, :, :]).pow_(2), -1), inplace=True)")
print(F.relu(torch.sum((mols[:, :, None, :] - grid[:, None, :, :]).pow_(2), -1), inplace=True))
print((F.relu(torch.sum((mols[:, :, None, :] - grid[:, None, :, :]).pow_(2), -1), inplace=True)).shape)

print("\nF.relu(torch.sum((mols[:, :, None, :] - grid[:, None, :, :]).pow_(2), -1), inplace=True).sqrt_()")
print(F.relu(torch.sum((mols[:, :, None, :] - grid[:, None, :, :]).pow_(2), -1), inplace=True).sqrt_())
print((F.relu(torch.sum((mols[:, :, None, :] - grid[:, None, :, :]).pow_(2), -1), inplace=True).sqrt_()).shape)

# 计算每个分子的每个原子与每个位置之间的距离
result = F.relu(torch.sum((mols[:, :, None, :] - grid[:, None, :, :]).pow_(2), -1), inplace=True).sqrt_()

# 打印计算结果
print("\n计算结果:")
print(result)

输出结果:

bash 复制代码
$ python grid_test.py 
bash 复制代码
batch_size =  2     n_atoms =  5     n_dims =  3     n_positions =  10
示例数据 mols:
mols = torch.rand(batch_size, n_atoms, n_dims)
tensor([[[0.3787, 0.1093, 0.5062],
         [0.3149, 0.4295, 0.1202],
         [0.6499, 0.6533, 0.6489],
         [0.9395, 0.5027, 0.7664],
         [0.5991, 0.5733, 0.6474]],

        [[0.1370, 0.3499, 0.7365],
         [0.3564, 0.4096, 0.1820],
         [0.2576, 0.4737, 0.2487],
         [0.3169, 0.5875, 0.0414],
         [0.9958, 0.2101, 0.3953]]])
torch.Size([2, 5, 3])

示例数据 grid:
grid = torch.rand(n_positions, n_dims)
tensor([[0.0985, 0.5795, 0.3998],
        [0.3772, 0.8160, 0.7968],
        [0.9571, 0.0205, 0.5068],
        [0.7847, 0.9675, 0.3421],
        [0.9007, 0.0692, 0.3701],
        [0.8763, 0.4045, 0.2783],
        [0.5665, 0.1797, 0.8626],
        [0.4253, 0.6738, 0.3789],
        [0.3690, 0.3504, 0.3530],
        [0.1773, 0.4790, 0.9227]])
torch.Size([10, 3])

添加额外维度后的 grid:
tensor([[[0.0985, 0.5795, 0.3998],
         [0.3772, 0.8160, 0.7968],
         [0.9571, 0.0205, 0.5068],
         [0.7847, 0.9675, 0.3421],
         [0.9007, 0.0692, 0.3701],
         [0.8763, 0.4045, 0.2783],
         [0.5665, 0.1797, 0.8626],
         [0.4253, 0.6738, 0.3789],
         [0.3690, 0.3504, 0.3530],
         [0.1773, 0.4790, 0.9227]]])
torch.Size([1, 10, 3])

mols[:, :, None, :]
tensor([[[[0.3787, 0.1093, 0.5062]],

         [[0.3149, 0.4295, 0.1202]],

         [[0.6499, 0.6533, 0.6489]],

         [[0.9395, 0.5027, 0.7664]],

         [[0.5991, 0.5733, 0.6474]]],


        [[[0.1370, 0.3499, 0.7365]],

         [[0.3564, 0.4096, 0.1820]],

         [[0.2576, 0.4737, 0.2487]],

         [[0.3169, 0.5875, 0.0414]],

         [[0.9958, 0.2101, 0.3953]]]])
torch.Size([2, 5, 1, 3])

grid[:, None, :, :]
tensor([[[[0.0985, 0.5795, 0.3998],
          [0.3772, 0.8160, 0.7968],
          [0.9571, 0.0205, 0.5068],
          [0.7847, 0.9675, 0.3421],
          [0.9007, 0.0692, 0.3701],
          [0.8763, 0.4045, 0.2783],
          [0.5665, 0.1797, 0.8626],
          [0.4253, 0.6738, 0.3789],
          [0.3690, 0.3504, 0.3530],
          [0.1773, 0.4790, 0.9227]]]])
torch.Size([1, 1, 10, 3])

mols[:, :, None, :] - grid[:, None, :, :]
tensor([[[[ 2.8014e-01, -4.7024e-01,  1.0644e-01],
          [ 1.4571e-03, -7.0671e-01, -2.9065e-01],
          [-5.7839e-01,  8.8782e-02, -5.7650e-04],
          [-4.0609e-01, -8.5820e-01,  1.6411e-01],
          [-5.2208e-01,  4.0115e-02,  1.3611e-01],
          [-4.9762e-01, -2.9525e-01,  2.2792e-01],
          [-1.8779e-01, -7.0417e-02, -3.5645e-01],
          [-4.6644e-02, -5.6448e-01,  1.2727e-01],
          [ 9.6272e-03, -2.4109e-01,  1.5320e-01],
          [ 2.0133e-01, -3.6968e-01, -4.1653e-01]],

         [[ 2.1635e-01, -1.4999e-01, -2.7958e-01],
          [-6.2337e-02, -3.8646e-01, -6.7667e-01],
          [-6.4219e-01,  4.0903e-01, -3.8660e-01],
          [-4.6988e-01, -5.3796e-01, -2.2191e-01],
          [-5.8587e-01,  3.6036e-01, -2.4991e-01],
          [-5.6142e-01,  2.4999e-02, -1.5810e-01],
          [-2.5159e-01,  2.4983e-01, -7.4247e-01],
          [-1.1044e-01, -2.4423e-01, -2.5875e-01],
          [-5.4167e-02,  7.9156e-02, -2.3282e-01],
          [ 1.3754e-01, -4.9439e-02, -8.0255e-01]],

         [[ 5.5133e-01,  7.3810e-02,  2.4912e-01],
          [ 2.7265e-01, -1.6266e-01, -1.4796e-01],
          [-3.0720e-01,  6.3283e-01,  1.4211e-01],
          [-1.3490e-01, -3.1416e-01,  3.0679e-01],
          [-2.5089e-01,  5.8416e-01,  2.7880e-01],
          [-2.2644e-01,  2.4880e-01,  3.7061e-01],
          [ 8.3396e-02,  4.7363e-01, -2.1377e-01],
          [ 2.2454e-01, -2.0431e-02,  2.6995e-01],
          [ 2.8082e-01,  3.0296e-01,  2.9588e-01],
          [ 4.7252e-01,  1.7436e-01, -2.7384e-01]],

         [[ 8.4102e-01, -7.6851e-02,  3.6667e-01],
          [ 5.6233e-01, -3.1332e-01, -3.0420e-02],
          [-1.7516e-02,  4.8217e-01,  2.5965e-01],
          [ 1.5479e-01, -4.6482e-01,  4.2434e-01],
          [ 3.8800e-02,  4.3350e-01,  3.9634e-01],
          [ 6.3253e-02,  9.8140e-02,  4.8815e-01],
          [ 3.7309e-01,  3.2297e-01, -9.6225e-02],
          [ 5.1423e-01, -1.7109e-01,  3.8750e-01],
          [ 5.7050e-01,  1.5230e-01,  4.1342e-01],
          [ 7.6221e-01,  2.3702e-02, -1.5630e-01]],

         [[ 5.0062e-01, -6.2699e-03,  2.4768e-01],
          [ 2.2193e-01, -2.4274e-01, -1.4941e-01],
          [-3.5792e-01,  5.5275e-01,  1.4067e-01],
          [-1.8561e-01, -3.9424e-01,  3.0535e-01],
          [-3.0160e-01,  5.0408e-01,  2.7735e-01],
          [-2.7715e-01,  1.6872e-01,  3.6916e-01],
          [ 3.2683e-02,  3.9355e-01, -2.1521e-01],
          [ 1.7383e-01, -1.0051e-01,  2.6851e-01],
          [ 2.3010e-01,  2.2288e-01,  2.9444e-01],
          [ 4.2181e-01,  9.4283e-02, -2.7528e-01]]],


        [[[ 3.8479e-02, -2.2963e-01,  3.3670e-01],
          [-2.4021e-01, -4.6610e-01, -6.0388e-02],
          [-8.2006e-01,  3.2939e-01,  2.2968e-01],
          [-6.4775e-01, -6.1759e-01,  3.9437e-01],
          [-7.6374e-01,  2.8072e-01,  3.6637e-01],
          [-7.3929e-01, -5.4636e-02,  4.5818e-01],
          [-4.2946e-01,  1.7019e-01, -1.2619e-01],
          [-2.8831e-01, -3.2387e-01,  3.5753e-01],
          [-2.3204e-01, -4.7952e-04,  3.8346e-01],
          [-4.0334e-02, -1.2907e-01, -1.8627e-01]],

         [[ 2.5793e-01, -1.6992e-01, -2.1771e-01],
          [-2.0759e-02, -4.0639e-01, -6.1480e-01],
          [-6.0061e-01,  3.8910e-01, -3.2472e-01],
          [-4.2830e-01, -5.5788e-01, -1.6004e-01],
          [-5.4429e-01,  3.4043e-01, -1.8804e-01],
          [-5.1984e-01,  5.0723e-03, -9.6227e-02],
          [-2.1001e-01,  2.2990e-01, -6.8060e-01],
          [-6.8861e-02, -2.6416e-01, -1.9688e-01],
          [-1.2589e-02,  5.9229e-02, -1.7095e-01],
          [ 1.7911e-01, -6.9366e-02, -7.4067e-01]],

         [[ 1.5904e-01, -1.0583e-01, -1.5102e-01],
          [-1.1965e-01, -3.4230e-01, -5.4810e-01],
          [-6.9950e-01,  4.5319e-01, -2.5803e-01],
          [-5.2719e-01, -4.9380e-01, -9.3347e-02],
          [-6.4318e-01,  4.0452e-01, -1.2134e-01],
          [-6.1873e-01,  6.9158e-02, -2.9533e-02],
          [-3.0890e-01,  2.9399e-01, -6.1391e-01],
          [-1.6775e-01, -2.0007e-01, -1.3018e-01],
          [-1.1148e-01,  1.2331e-01, -1.0426e-01],
          [ 8.0225e-02, -5.2802e-03, -6.7398e-01]],

         [[ 2.1837e-01,  8.0081e-03, -3.5833e-01],
          [-6.0314e-02, -2.2846e-01, -7.5541e-01],
          [-6.4016e-01,  5.6703e-01, -4.6534e-01],
          [-4.6786e-01, -3.7996e-01, -3.0066e-01],
          [-5.8385e-01,  5.1836e-01, -3.2865e-01],
          [-5.5940e-01,  1.8300e-01, -2.3684e-01],
          [-2.4956e-01,  4.0783e-01, -8.2122e-01],
          [-1.0842e-01, -8.6233e-02, -3.3749e-01],
          [-5.2144e-02,  2.3716e-01, -3.1157e-01],
          [ 1.3956e-01,  1.0856e-01, -8.8129e-01]],

         [[ 8.9723e-01, -3.6938e-01, -4.4795e-03],
          [ 6.1855e-01, -6.0585e-01, -4.0157e-01],
          [ 3.8695e-02,  1.8964e-01, -1.1149e-01],
          [ 2.1100e-01, -7.5734e-01,  5.3189e-02],
          [ 9.5011e-02,  1.4097e-01,  2.5192e-02],
          [ 1.1946e-01, -1.9439e-01,  1.1700e-01],
          [ 4.2930e-01,  3.0443e-02, -4.6737e-01],
          [ 5.7044e-01, -4.6362e-01,  1.6351e-02],
          [ 6.2672e-01, -1.4023e-01,  4.2277e-02],
          [ 8.1842e-01, -2.6882e-01, -5.2744e-01]]]])
torch.Size([2, 5, 10, 3])

(mols[:, :, None, :] - grid[:, None, :, :]).pow_(2)
tensor([[[[7.8481e-02, 2.2112e-01, 1.1329e-02],
          [2.1231e-06, 4.9944e-01, 8.4477e-02],
          [3.3454e-01, 7.8823e-03, 3.3235e-07],
          [1.6491e-01, 7.3651e-01, 2.6931e-02],
          [2.7257e-01, 1.6092e-03, 1.8526e-02],
          [2.4763e-01, 8.7170e-02, 5.1948e-02],
          [3.5266e-02, 4.9586e-03, 1.2706e-01],
          [2.1757e-03, 3.1863e-01, 1.6198e-02],
          [9.2683e-05, 5.8124e-02, 2.3469e-02],
          [4.0534e-02, 1.3667e-01, 1.7349e-01]],

         [[4.6807e-02, 2.2498e-02, 7.8165e-02],
          [3.8859e-03, 1.4935e-01, 4.5788e-01],
          [4.1240e-01, 1.6730e-01, 1.4946e-01],
          [2.2079e-01, 2.8940e-01, 4.9245e-02],
          [3.4325e-01, 1.2986e-01, 6.2455e-02],
          [3.1519e-01, 6.2496e-04, 2.4995e-02],
          [6.3296e-02, 6.2414e-02, 5.5127e-01],
          [1.2197e-02, 5.9650e-02, 6.6951e-02],
          [2.9340e-03, 6.2657e-03, 5.4207e-02],
          [1.8916e-02, 2.4442e-03, 6.4408e-01]],

         [[3.0397e-01, 5.4479e-03, 6.2063e-02],
          [7.4335e-02, 2.6459e-02, 2.1893e-02],
          [9.4375e-02, 4.0047e-01, 2.0195e-02],
          [1.8198e-02, 9.8694e-02, 9.4122e-02],
          [6.2946e-02, 3.4124e-01, 7.7727e-02],
          [5.1273e-02, 6.1902e-02, 1.3735e-01],
          [6.9548e-03, 2.2433e-01, 4.5697e-02],
          [5.0420e-02, 4.1742e-04, 7.2876e-02],
          [7.8857e-02, 9.1784e-02, 8.7545e-02],
          [2.2327e-01, 3.0403e-02, 7.4989e-02]],

         [[7.0732e-01, 5.9061e-03, 1.3445e-01],
          [3.1622e-01, 9.8171e-02, 9.2536e-04],
          [3.0679e-04, 2.3249e-01, 6.7419e-02],
          [2.3960e-02, 2.1606e-01, 1.8006e-01],
          [1.5054e-03, 1.8792e-01, 1.5708e-01],
          [4.0009e-03, 9.6314e-03, 2.3829e-01],
          [1.3919e-01, 1.0431e-01, 9.2593e-03],
          [2.6444e-01, 2.9273e-02, 1.5016e-01],
          [3.2548e-01, 2.3194e-02, 1.7092e-01],
          [5.8096e-01, 5.6178e-04, 2.4429e-02]],

         [[2.5062e-01, 3.9311e-05, 6.1346e-02],
          [4.9254e-02, 5.8923e-02, 2.2322e-02],
          [1.2810e-01, 3.0553e-01, 1.9787e-02],
          [3.4452e-02, 1.5542e-01, 9.3239e-02],
          [9.0964e-02, 2.5410e-01, 7.6924e-02],
          [7.6812e-02, 2.8467e-02, 1.3628e-01],
          [1.0682e-03, 1.5488e-01, 4.6316e-02],
          [3.0217e-02, 1.0102e-02, 7.2099e-02],
          [5.2947e-02, 4.9675e-02, 8.6694e-02],
          [1.7792e-01, 8.8893e-03, 7.5782e-02]]],


        [[[1.4806e-03, 5.2729e-02, 1.1337e-01],
          [5.7700e-02, 2.1725e-01, 3.6467e-03],
          [6.7250e-01, 1.0850e-01, 5.2755e-02],
          [4.1958e-01, 3.8142e-01, 1.5553e-01],
          [5.8330e-01, 7.8806e-02, 1.3423e-01],
          [5.4655e-01, 2.9851e-03, 2.0993e-01],
          [1.8443e-01, 2.8965e-02, 1.5925e-02],
          [8.3122e-02, 1.0489e-01, 1.2783e-01],
          [5.3842e-02, 2.2994e-07, 1.4704e-01],
          [1.6269e-03, 1.6660e-02, 3.4695e-02]],

         [[6.6527e-02, 2.8872e-02, 4.7397e-02],
          [4.3095e-04, 1.6515e-01, 3.7797e-01],
          [3.6073e-01, 1.5140e-01, 1.0545e-01],
          [1.8344e-01, 3.1124e-01, 2.5613e-02],
          [2.9626e-01, 1.1589e-01, 3.5358e-02],
          [2.7023e-01, 2.5728e-05, 9.2596e-03],
          [4.4104e-02, 5.2854e-02, 4.6322e-01],
          [4.7418e-03, 6.9780e-02, 3.8761e-02],
          [1.5849e-04, 3.5081e-03, 2.9225e-02],
          [3.2082e-02, 4.8116e-03, 5.4860e-01]],

         [[2.5293e-02, 1.1201e-02, 2.2806e-02],
          [1.4316e-02, 1.1717e-01, 3.0042e-01],
          [4.8930e-01, 2.0538e-01, 6.6580e-02],
          [2.7793e-01, 2.4384e-01, 8.7136e-03],
          [4.1369e-01, 1.6363e-01, 1.4724e-02],
          [3.8283e-01, 4.7828e-03, 8.7222e-04],
          [9.5418e-02, 8.6428e-02, 3.7688e-01],
          [2.8140e-02, 4.0030e-02, 1.6948e-02],
          [1.2428e-02, 1.5206e-02, 1.0870e-02],
          [6.4360e-03, 2.7880e-05, 4.5425e-01]],

         [[4.7686e-02, 6.4129e-05, 1.2840e-01],
          [3.6378e-03, 5.2195e-02, 5.7065e-01],
          [4.0981e-01, 3.2152e-01, 2.1654e-01],
          [2.1889e-01, 1.4437e-01, 9.0394e-02],
          [3.4088e-01, 2.6870e-01, 1.0801e-01],
          [3.1292e-01, 3.3489e-02, 5.6095e-02],
          [6.2282e-02, 1.6632e-01, 6.7440e-01],
          [1.1754e-02, 7.4361e-03, 1.1390e-01],
          [2.7190e-03, 5.6243e-02, 9.7075e-02],
          [1.9477e-02, 1.1786e-02, 7.7667e-01]],

         [[8.0503e-01, 1.3644e-01, 2.0066e-05],
          [3.8260e-01, 3.6705e-01, 1.6126e-01],
          [1.4973e-03, 3.5964e-02, 1.2431e-02],
          [4.4522e-02, 5.7357e-01, 2.8291e-03],
          [9.0270e-03, 1.9874e-02, 6.3462e-04],
          [1.4272e-02, 3.7786e-02, 1.3690e-02],
          [1.8430e-01, 9.2680e-04, 2.1844e-01],
          [3.2541e-01, 2.1494e-01, 2.6736e-04],
          [3.9277e-01, 1.9664e-02, 1.7874e-03],
          [6.6981e-01, 7.2266e-02, 2.7820e-01]]]])

torch.sum((mols[:, :, None, :] - grid[:, None, :, :]).pow_(2), -1)
tensor([[[0.3109, 0.5839, 0.3424, 0.9284, 0.2927, 0.3867, 0.1673, 0.3370,
          0.0817, 0.3507],
         [0.1475, 0.6111, 0.7292, 0.5594, 0.5356, 0.3408, 0.6770, 0.1388,
          0.0634, 0.6654],
         [0.3715, 0.1227, 0.5150, 0.2110, 0.4819, 0.2505, 0.2770, 0.1237,
          0.2582, 0.3287],
         [0.8477, 0.4153, 0.3002, 0.4201, 0.3465, 0.2519, 0.2528, 0.4439,
          0.5196, 0.6060],
         [0.3120, 0.1305, 0.4534, 0.2831, 0.4220, 0.2416, 0.2023, 0.1124,
          0.1893, 0.2626]],

        [[0.1676, 0.2786, 0.8337, 0.9565, 0.7963, 0.7595, 0.2293, 0.3158,
          0.2009, 0.0530],
         [0.1428, 0.5436, 0.6176, 0.5203, 0.4475, 0.2795, 0.5602, 0.1133,
          0.0329, 0.5855],
         [0.0593, 0.4319, 0.7613, 0.5305, 0.5920, 0.3885, 0.5587, 0.0851,
          0.0385, 0.4607],
         [0.1761, 0.6265, 0.9479, 0.4537, 0.7176, 0.4025, 0.9030, 0.1331,
          0.1560, 0.8079],
         [0.9415, 0.9109, 0.0499, 0.6209, 0.0295, 0.0657, 0.4037, 0.5406,
          0.4142, 1.0203]]])
torch.Size([2, 5, 10])

F.relu(torch.sum((mols[:, :, None, :] - grid[:, None, :, :]).pow_(2), -1), inplace=True)
tensor([[[0.3109, 0.5839, 0.3424, 0.9284, 0.2927, 0.3867, 0.1673, 0.3370,
          0.0817, 0.3507],
         [0.1475, 0.6111, 0.7292, 0.5594, 0.5356, 0.3408, 0.6770, 0.1388,
          0.0634, 0.6654],
         [0.3715, 0.1227, 0.5150, 0.2110, 0.4819, 0.2505, 0.2770, 0.1237,
          0.2582, 0.3287],
         [0.8477, 0.4153, 0.3002, 0.4201, 0.3465, 0.2519, 0.2528, 0.4439,
          0.5196, 0.6060],
         [0.3120, 0.1305, 0.4534, 0.2831, 0.4220, 0.2416, 0.2023, 0.1124,
          0.1893, 0.2626]],

        [[0.1676, 0.2786, 0.8337, 0.9565, 0.7963, 0.7595, 0.2293, 0.3158,
          0.2009, 0.0530],
         [0.1428, 0.5436, 0.6176, 0.5203, 0.4475, 0.2795, 0.5602, 0.1133,
          0.0329, 0.5855],
         [0.0593, 0.4319, 0.7613, 0.5305, 0.5920, 0.3885, 0.5587, 0.0851,
          0.0385, 0.4607],
         [0.1761, 0.6265, 0.9479, 0.4537, 0.7176, 0.4025, 0.9030, 0.1331,
          0.1560, 0.8079],
         [0.9415, 0.9109, 0.0499, 0.6209, 0.0295, 0.0657, 0.4037, 0.5406,
          0.4142, 1.0203]]])
torch.Size([2, 5, 10])

F.relu(torch.sum((mols[:, :, None, :] - grid[:, None, :, :]).pow_(2), -1), inplace=True).sqrt_()
tensor([[[0.5576, 0.7641, 0.5852, 0.9635, 0.5410, 0.6219, 0.4090, 0.5805,
          0.2858, 0.5922],
         [0.3840, 0.7817, 0.8539, 0.7480, 0.7318, 0.5838, 0.8228, 0.3726,
          0.2518, 0.8157],
         [0.6095, 0.3503, 0.7177, 0.4594, 0.6942, 0.5005, 0.5263, 0.3517,
          0.5081, 0.5733],
         [0.9207, 0.6445, 0.5479, 0.6481, 0.5887, 0.5019, 0.5028, 0.6662,
          0.7208, 0.7784],
         [0.5586, 0.3612, 0.6734, 0.5321, 0.6496, 0.4915, 0.4497, 0.3353,
          0.4351, 0.5124]],

        [[0.4094, 0.5278, 0.9131, 0.9780, 0.8924, 0.8715, 0.4789, 0.5620,
          0.4482, 0.2302],
         [0.3779, 0.7373, 0.7859, 0.7213, 0.6690, 0.5287, 0.7484, 0.3366,
          0.1814, 0.7652],
         [0.2435, 0.6572, 0.8725, 0.7283, 0.7694, 0.6233, 0.7475, 0.2918,
          0.1962, 0.6788],
         [0.4197, 0.7915, 0.9736, 0.6735, 0.8471, 0.6344, 0.9503, 0.3648,
          0.3950, 0.8989],
         [0.9703, 0.9544, 0.2234, 0.7880, 0.1719, 0.2564, 0.6353, 0.7353,
          0.6436, 1.0101]]])
torch.Size([2, 5, 10])

计算结果:
tensor([[[0.5576, 0.7641, 0.5852, 0.9635, 0.5410, 0.6219, 0.4090, 0.5805,
          0.2858, 0.5922],
         [0.3840, 0.7817, 0.8539, 0.7480, 0.7318, 0.5838, 0.8228, 0.3726,
          0.2518, 0.8157],
         [0.6095, 0.3503, 0.7177, 0.4594, 0.6942, 0.5005, 0.5263, 0.3517,
          0.5081, 0.5733],
         [0.9207, 0.6445, 0.5479, 0.6481, 0.5887, 0.5019, 0.5028, 0.6662,
          0.7208, 0.7784],
         [0.5586, 0.3612, 0.6734, 0.5321, 0.6496, 0.4915, 0.4497, 0.3353,
          0.4351, 0.5124]],

        [[0.4094, 0.5278, 0.9131, 0.9780, 0.8924, 0.8715, 0.4789, 0.5620,
          0.4482, 0.2302],
         [0.3779, 0.7373, 0.7859, 0.7213, 0.6690, 0.5287, 0.7484, 0.3366,
          0.1814, 0.7652],
         [0.2435, 0.6572, 0.8725, 0.7283, 0.7694, 0.6233, 0.7475, 0.2918,
          0.1962, 0.6788],
         [0.4197, 0.7915, 0.9736, 0.6735, 0.8471, 0.6344, 0.9503, 0.3648,
          0.3950, 0.8989],
         [0.9703, 0.9544, 0.2234, 0.7880, 0.1719, 0.2564, 0.6353, 0.7353,
          0.6436, 1.0101]]])
相关推荐
笃励11 分钟前
Java面试题二
java·开发语言·python
infominer20 分钟前
RAGFlow 0.12 版本功能导读
人工智能·开源·aigc·ai-native
涩即是Null22 分钟前
如何构建LSTM神经网络模型
人工智能·rnn·深度学习·神经网络·lstm
本本的小橙子25 分钟前
第十四周:机器学习
人工智能·机器学习
励志成为美貌才华为一体的女子40 分钟前
《大规模语言模型从理论到实践》第一轮学习--第四章分布式训练
人工智能·分布式·语言模型
学步_技术44 分钟前
自动驾驶系列—自动驾驶背后的数据通道:通信总线技术详解与应用场景分析
人工智能·机器学习·自动驾驶·通信总线
winds~1 小时前
自动驾驶-问题笔记-待解决
人工智能·笔记·自动驾驶
学步_技术1 小时前
自动驾驶系列—LDW(车道偏离预警):智能驾驶的安全守护者
人工智能·安全·自动驾驶·ldw
青云交1 小时前
大数据新视界 --大数据大厂之 Kafka 性能优化的进阶之道:应对海量数据的高效传输
大数据·数据库·人工智能·性能优化·kafka·数据压缩·分区策略·磁盘 i/o
一颗星星辰1 小时前
Python | 第九章 | 排序和查找
服务器·网络·python