说明:
函数的功能是生成网格,可以用于生成坐标。
函数输入:
输入两个一维tensor数据,且两个tensor数据类型相同,也可以输入三个一维tensor数据
函数输出:
输出两个tensor数据(两个tensor的行数为第一个输入张量的元素个数,列数为第二个输入张量的元素个数)或者三个tensor数据(三个tensor第一维度大小为第一个输入张量的元素个数,第二维度大小为第二个输入张量的元素个数,第三维度为第三个输入张量元素个数)
报错:
当两个输入tensor数据类型不同或维度不是一维时会报错。
结果理解:
输入两个一维张量的元素个数分别为n1,n2,则输出两个张量是二维的,且行和列个数均为n1,n2,输出第一个张量行相同(对应第一个输入张量),输出第二个张量列相同(对应第二个输入张量),其中第一个输出张量填充第一个输入张量中的元素,各行元素相同 ;第二个输出张量填充第二个输入张量中的元素,各列元素相同。
若输入是三个一维张量,元素个数分别为n1,n2,n3,则输出的三个张量都是三维的,且输出的三个张量的三个维度均相等,分别为n1,n2,n3。
输入为两个张量:
python
import torch
import torch.nn as nn
a1 = torch.tensor([1,3])
b1 = torch.tensor([2,4,6])
x1,y1 = torch.meshgrid(a1,b1)
print(x1)
print(y1)
输出:
tensor([[1, 1, 1],
[3, 3, 3]])
tensor([[2, 4, 6],
[2, 4, 6]])
输入为三个张量:
python
import torch
import torch.nn as nn
a2 = torch.tensor([1,3])
b2 = torch.tensor([2,4,6])
c2 = torch.tensor([7,8,9,10])
x2,y2,z2 = torch.meshgrid(a2,b2,c2)
print(x2)
print(x2.shape)
print(y2)
print(y2.shape)
print(z2)
print(z2.shape)
输出:
tensor([[[1, 1, 1, 1],
[1, 1, 1, 1],
[1, 1, 1, 1]],
[[3, 3, 3, 3],
[3, 3, 3, 3],
[3, 3, 3, 3]]])
torch.Size([2, 3, 4])
tensor([[[2, 2, 2, 2],
[4, 4, 4, 4],
[6, 6, 6, 6]],
[[2, 2, 2, 2],
[4, 4, 4, 4],
[6, 6, 6, 6]]])
torch.Size([2, 3, 4])
tensor([[[ 7, 8, 9, 10],
[ 7, 8, 9, 10],
[ 7, 8, 9, 10]],
[[ 7, 8, 9, 10],
[ 7, 8, 9, 10],
[ 7, 8, 9, 10]]])
torch.Size([2, 3, 4])