文章目录
报错信息
bash
RuntimeError: expected scalar type Long but found Float
原因
nn.Linear需要作用于浮点数,这里可能输入了整数类型的张量作为参数。
代码示例
错误版
py
import torch
import torch.nn as nn
a = torch.tensor([1,2,3,4])
lin = nn.Linear(4,2)
b = lin(a)
print(b)
报错:
改正
py
import torch
import torch.nn as nn
a = torch.tensor([1,2,3,4])
lin = nn.Linear(4,2)
b = lin(a.float())
print(b)
把a转为float,结果为:
bash
tensor([-1.1703, 0.0518], grad_fn=<AddBackward0>)