1. 二元交叉熵损失函数 BCE
1.1 介绍和理解
计算公式:
L=−1N∑i=1N[yilog(y^i)+(1−yi)log(1−y^i)] \mathcal{L} = -\frac{1}{N} \sum_{i=1}^{N} \left[ y_i \log(\hat{y}_i) + (1-y_i) \log(1-\hat{y}_i) \right] L=−N1i=1∑N[yilog(y^i)+(1−yi)log(1−y^i)]
前面的 y 是预测值,后面的 y 是目标值。
可以把这个损失函数理解为一个严厉的老师,会对猜错的行为进行严厉处罚。比如:
- 如果模型预测的概率是1,而目标值其实是0,这就意味着模型猜错了,带入公式会发现只用计算前一半,得到的结果会非常大,结果非常大也就意味着惩罚很强。
- 同理,如果猜对了,结果会非常接近0。
1.2 代码实现
python
import torch
import torch.nn as nn
# 定义输入
input = torch.randn(3, 2)
# 二分类的情况需要将值控制在 0 - 1
pred = torch.sigmoid(input)
# 定义目标值
target = torch.tensor([[0, 1], [1, 0], [0, 1]], dtype=torch.float32)
loss = nn.BCELoss()
print(loss(pred, target))
2. 多分类交叉熵损失函数
2.1 概念和理解
计算公式为:
L=−1N∑i=1N∑c=1Cyi,clog(y^i,c) \mathcal{L} = - \frac{1}{N} \sum_{i=1}^{N} \sum_{c=1}^{C} y_{i,c} \log(\hat{y}_{i,c}) L=−N1i=1∑Nc=1∑Cyi,clog(y^i,c) 原理和理解同二分类交叉熵损失函数。
2.2 代码实现
python
import torch
import torch.nn as nn
# 多分类交叉熵损失函数的使用
# 定义输入值,六分类
input = torch.randn(5, 6)
# 其实这一步可以省略,因为损失函数底层默认经过了softmax处理
pred = input.softmax(dim=1)
# 定义目标值,情况一:目标值用顺序标签标识,即一组数据直接对应分类的目标值
target = torch.tensor([1, 3, 5, 0, 4])
loss = nn.CrossEntropyLoss()
print(loss(pred, target))
# 情况二:目标值是一组概率值,即每个目标值都是一组概率,实际中可使用独热编码的形式,这里为了方便演示直接使用一组处理后的概率数据
target2 = torch.randn(5, 6).softmax(dim=1)
print(loss(pred, target2))
3. 使用损失函数的简单模型练习
python
import torch
import torch.nn as nn
# 定义一个简单模型
class Net(nn.Module):
def __init__(self):
super(Net, self).__init__()
# 只定义一层
self.linear = nn.Linear(5, 3)
# 初始化权重
self.linear.weight.data.normal_(mean=0.0, std=0.01)
print(self.linear.weight)
# 初始化偏置
self.linear.bias.data.normal_(mean=0.0, std=0.01)
print(self.linear.bias)
def forward(self, x):
x = self.linear(x)
return x
# 创建模型
model = Net()
# 定义输入
input = torch.randn(2, 5)
# 定义目标值
target = torch.zeros(2, 3)
# 定义损失函数
loss = nn.MSELoss()
# 前向传播
output = model(input)
# 计算损失
loss = loss(output, target)
# 反向传播
loss.backward()
# 定义参数更新方式,即定义优化器
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
# 仅仅更新一步,演示
optimizer.step()
optimizer.zero_grad()
# 打印参数
for param in model.state_dict():
print(param)
print(model.state_dict()[param])