Python code:
import torch
import torch.nn as nn
import torch.nn.functional as F
torch.set_printoptions(precision=3, sci_mode=False)


def compare_cross_entropy_reductions(batch_size=2, seq_length=3, vocab_size=4):
    """Print and return per-position cross-entropy under both reductions.

    Draws random logits of shape (batch_size, seq_length, vocab_size) and
    random integer labels of shape (batch_size, seq_length), then evaluates
    ``F.cross_entropy`` once with ``reduction="none"`` and once with the
    default reduction, printing every intermediate tensor along the way.

    Args:
        batch_size: Number of sequences in the batch.
        seq_length: Number of positions in each sequence.
        vocab_size: Number of classes scored at each position.

    Returns:
        A ``(result_none, result_mean)`` tuple: the unreduced losses of shape
        (batch_size, seq_length) and the 0-dim loss from the default
        reduction.
    """
    logits = torch.randn(batch_size, seq_length, vocab_size)
    print(f"logits=\n{logits}")

    # F.cross_entropy wants the class axis at index 1, i.e. (N, C, d1), so
    # the vocab axis is swapped in front of the sequence axis.
    logits_t = logits.transpose(-1, -2)
    print(f"logits_t=\n{logits_t}")

    label = torch.randint(0, vocab_size, (batch_size, seq_length))
    print(f"label=\n{label}")

    # One loss value per (batch, position) pair.
    result_none = F.cross_entropy(logits_t, label, reduction="none")
    print(f"result_none=\n{result_none}")

    # Averaging the unreduced losses by hand should reproduce the value the
    # default reduction returns; both are printed for comparison.
    result_none_mean = torch.mean(result_none)
    result_mean = F.cross_entropy(logits_t, label)
    print(f"result_mean=\n{result_mean}")
    print(f"result_none_mean={result_none_mean}")

    return result_none, result_mean


if __name__ == "__main__":
    compare_cross_entropy_reductions(batch_size=2, seq_length=3, vocab_size=4)
Output:
logits=
tensor([[[ 0.477, 2.017, 1.016, -0.299],
[-0.189, 0.321, -0.885, 1.418],
[ 0.027, -0.606, 0.079, -0.491]],
[[ 1.911, 1.643, -0.327, 0.185],
[-0.031, -1.463, -0.073, 1.391],
[-0.710, 0.811, 1.521, 0.033]]])
logits_t=
tensor([[[ 0.477, -0.189, 0.027],
[ 2.017, 0.321, -0.606],
[ 1.016, -0.885, 0.079],
[-0.299, 1.418, -0.491]],
[[ 1.911, -0.031, -0.710],
[ 1.643, -1.463, 0.811],
[-0.327, -0.073, 1.521],
[ 0.185, 1.391, 0.033]]])
label=
tensor([[0, 0, 0],
[3, 0, 0]])
result_none=
tensor([[2.059, 2.098, 1.157],
[2.444, 1.848, 2.832]])
result_mean=
2.0730881690979004
result_none_mean=2.0730881690979004