import torch
import torch.nn as nn
import torch.nn.functional as F
batchsize = 2
num_class = 4
logits = torch.randn(batchsize, num_class)
target = torch.randint(num_class, size=(batchsize,))  # delta (hard) target: class indices
target_logits = torch.randn(batchsize, num_class)  # non-delta (soft) target: logits of a target distribution
Cross Entropy: CrossEntropyLoss
## 1. CE Loss (cross entropy)
ce_loss_fn = torch.nn.CrossEntropyLoss()
# hard (delta) target given as class indices
ce_loss = ce_loss_fn(logits, target)
print("ce_loss1:", ce_loss)
# soft target given as a probability distribution (supported natively since PyTorch 1.10)
ce_loss = ce_loss_fn(logits, torch.softmax(target_logits, dim=-1))
print("ce_loss2:", ce_loss)
Negative Log-Likelihood: NLLLoss
## 2. NLL Loss (negative log-likelihood)
nll_fn = torch.nn.NLLLoss()
# NLLLoss expects log-probabilities; F.log_softmax(logits, dim=-1) is the numerically stable equivalent
nll_loss = nll_fn(torch.log(torch.softmax(logits, dim=-1) + 1e-7), target)
print("nll_loss:", nll_loss)
#### CE loss value == NLL loss value: cross entropy is log-softmax followed by NLL
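As a quick illustrative check of the equality above: CrossEntropyLoss applied to raw logits equals NLLLoss applied to the exact log-softmax of those logits (the 1e-7 epsilon used in the code introduces only a tiny numerical difference).

# CrossEntropyLoss(logits, target) == NLLLoss(log_softmax(logits), target)
print(torch.allclose(ce_loss_fn(logits, target),
                     nll_fn(F.log_softmax(logits, dim=-1), target)))  # True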
KL Divergence: KLDivLoss
## 3. KL Loss (KL divergence)
# the input must be log-probabilities and the target must be probabilities;
# the default reduction="mean" averages over every element, while reduction="batchmean" matches the mathematical definition
kl_loss_fn = torch.nn.KLDivLoss()
kl_loss = kl_loss_fn(torch.log(torch.softmax(logits, dim=-1) + 1e-7), torch.softmax(target_logits, dim=-1))
print("kl_loss:", kl_loss)