深入浅出Pytorch函数——torch.nn.Module

分类目录:《深入浅出Pytorch函数》总目录


torch.nn.Module是所有Pytorch中所有神经网络模型的基类,我们的神经网络模型也应该继承这个类。Modules可以包含其它Modules,也允许使用树结构嵌入他们,还可以将子模块赋值给模型属性。

实例

复制代码
import torch.nn as nn
import torch.nn.functional as F

class Model(nn.Module):
    def __init__(self):
        super(Model, self).__init__()
        self.conv1 = nn.Conv2d(1, 20, 5)       # submodule: Conv2d
        self.conv2 = nn.Conv2d(20, 20, 5)

    def forward(self, x):
       x = F.relu(self.conv1(x))
       return F.relu(self.conv2(x))

通过上面方式赋值的submodule会被注册。当调用.cuda()的时候,submodule的参数也会转换为cuda Tensor

函数

eval()

将模块设置为evaluation模式,相当于self.train(False)。这个函数仅当模型中有DropoutBatchNorm时才会有影响。

复制代码
def eval(self: T) -> T:
        r"""Sets the module in evaluation mode.

        This has any effect only on certain modules. See documentations of
        particular modules for details of their behaviors in training/evaluation
        mode, if they are affected, e.g. :class:`Dropout`, :class:`BatchNorm`,
        etc.

        This is equivalent with :meth:`self.train(False) <torch.nn.Module.train>`.

        See :ref:`locally-disable-grad-doc` for a comparison between
        `.eval()` and several similar mechanisms that may be confused with it.

        Returns:
            Module: self
        """
        return self.train(False)
相关推荐
人工智能AI技术4 小时前
YOLOv9目标检测实战:用Python搭建你的第一个实时交通监控系统
人工智能
小雨中_4 小时前
2.7 强化学习分类
人工智能·python·深度学习·机器学习·分类·数据挖掘
拯救HMI的工程师4 小时前
【拯救HMI】工业HMI字体选择:拒绝“通用字体”,适配工业场景3大要求
人工智能
lczdyx5 小时前
【胶囊网络】01-2 胶囊网络发展历史与研究现状
人工智能·深度学习·机器学习·ai·大模型·反向传播
AomanHao5 小时前
【ISP】基于暗通道先验改进的红外图像透雾
图像处理·人工智能·算法·计算机视觉·图像增强·红外图像
AI智能观察5 小时前
从数据中心到服务大厅:数字人智能体如何革新电力行业服务模式
人工智能·数字人·智慧展厅·智能体·数字展厅
AI智能观察5 小时前
生成式AI驱动信息分发变革:GEO跃迁方向、价值锚点与企业生存指南
人工智能·流量运营·geo·ai搜索·智能营销·geo工具·geo平台
苏渡苇5 小时前
轻量化AI落地:Java + Spring Boot 实现设备异常预判
java·人工智能·spring boot·后端·网络协议·tcp/ip·spring
大熊背5 小时前
APEX系统中为什么 不用与EV0的差值计算曝光参数调整量
人工智能·算法·apex·自动曝光
小雨中_5 小时前
2.4 贝尔曼方程与蒙特卡洛方法
人工智能·python·深度学习·机器学习·自然语言处理