pytorch Loss Functions

1. pytorch中loss函数使用方法示例

python 复制代码
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable

# 定义网络时需要继承nn.Module并实现它的forward方法,将网络中具有可学习参数的层放在构造函数__init__中
# 不具有可学习参数的层(如ReLu)既可以放在构造函数中也可以不放

# torch.nn.MaxPool2d和torch.nn.functional.max_pool2d,在pytorch构建模型中,都可以作为最大池化层的引入,但前者为类模块,后者为函数,在使用上存在不同。
# torch.nn.functional.max_pool2d是函数,可以直接调用;torch.nn.MaxPool2d是类模块,要先实例化,再调用其函数。
# torch.nn中其它模块跟torch.nn.functional中其它对应的函数也是类似的用法。
class myNet(torch.nn.Module):
    def __init__(self):
        super(myNet, self).__init__()

        self.conv1 = torch.nn.Conv2d(1,6,5)
        self.conv2 = torch.nn.Conv2d(6,16,5)

        self.fc1 = torch.nn.Linear(16*5*5,120)
        self.fc2 = torch.nn.Linear(120, 84)
        self.fc3 = torch.nn.Linear(84, 10)

        self.pooling = torch.nn.MaxPool2d(2)
        self.activate = torch.nn.ReLU()

    def forward(self, x):
        x = self.pooling(self.activate(self.conv1(x)))
        x = self.pooling(self.activate(self.conv2(x)))
        x = x.view(x.size()[0], -1)
        x = self.activate(self.fc1(x))
        x = self.activate(self.fc2(x))
        x = self.fc3(x)

        return x

input = Variable(torch.randn(1,1,32,32))
net = myNet()          # 创建myNet()对象
output = net(input)    # 调用myNet()对象的forward()方法,有点类似C++中的operator()()
target = Variable(torch.arange(0, 10))
citerion = torch.nn.MSELoss()                    # 创建MSELoss()对象
loss = citerion(output.float(), target.float())  # 调用loss函数
print(loss)

print('*'*30)

net.zero_grad()   # 把net中所有可学习参数的梯度清零
print(net.conv1.bias.grad)
loss.backward()
print(net.conv1.bias.grad)

输出结果:

bash 复制代码
tensor(28.6363, grad_fn=<MseLossBackward0>)
******************************
None
tensor([ 0.1782, -0.0815, -0.0902, -0.0140,  0.0267,  0.0015])

2. pytorch官方支持的loss

https://pytorch.org/docs/stable/nn.html#loss-functions

相关推荐
waterHBO1 小时前
python 爬虫 selenium 笔记
爬虫·python·selenium
编程零零七2 小时前
Python数据分析工具(三):pymssql的用法
开发语言·前端·数据库·python·oracle·数据分析·pymssql
AIAdvocate4 小时前
Pandas_数据结构详解
数据结构·python·pandas
小言从不摸鱼4 小时前
【AI大模型】ChatGPT模型原理介绍(下)
人工智能·python·深度学习·机器学习·自然语言处理·chatgpt
FreakStudio6 小时前
全网最适合入门的面向对象编程教程:50 Python函数方法与接口-接口和抽象基类
python·嵌入式·面向对象·电子diy
redcocal7 小时前
地平线秋招
python·嵌入式硬件·算法·fpga开发·求职招聘
artificiali7 小时前
Anaconda配置pytorch的基本操作
人工智能·pytorch·python
RaidenQ8 小时前
2024.9.13 Python与图像处理新国大EE5731课程大作业,索贝尔算子计算边缘,高斯核模糊边缘,Haar小波计算边缘
图像处理·python·算法·课程设计
花生了什么树~.8 小时前
python基础知识(六)--字典遍历、公共运算符、公共方法、函数、变量分类、参数分类、拆包、引用
开发语言·python
Trouvaille ~8 小时前
【Python篇】深度探索NumPy(下篇):从科学计算到机器学习的高效实战技巧
图像处理·python·机器学习·numpy·信号处理·时间序列分析·科学计算