pytorch -- torch.nn.Module

  1. 基础

    torch.nn 是 PyTorch 中用于构建神经网络的模块。nn.Module包含网络各层的定义及forward方法。

    在用户自定义神经网络时,需要继承自nn.Module类。通过继承 nn.Module 类,您可以创建自己的神经网络模型,并定义模型的结构和操作。
    torch.nn 模块中常用的一些类和函数

    nn.Linear: 线性层,用于定义全连接层。

    nn.Conv2d: 二维卷积层,用于处理图像数据。

    nn.ReLU: ReLU 激活函数。

    nn.Sigmoid: Sigmoid 激活函数。

    nn.Dropout: Dropout 层,用于正则化和防止过拟合。

    nn.CrossEntropyLoss: 交叉熵损失函数,通常用于多类别分类问题。

    nn.MSELoss: 均方误差损失函数,通常用于回归问题。

    nn.Sequential: 顺序容器,用于按顺序组合多个层。

    还能使用 PyTorch 提供的优化器(如 torch.optim)和损失函数来训练和优化模型。

  2. 使用

python 复制代码
import torch
from torch.nn import Module
class yaya(Module):

    def __init__(self):
        super().__init__()
    def forward(self,input):
        output = input+1
        return output

tu = yaya()
x = torch.tensor(1.0)
output = tu(x)
print(output)
相关推荐
easy_coder2 分钟前
ReAct 进入死循环?用 Harness 把它拉回来
人工智能·架构·云计算
tangweiguo030519872 分钟前
LangChain + RAG + Agent + 多模态 完整实战教程
python·langchain
我是无敌小恐龙11 分钟前
Java SE 零基础入门Day06 方法重载+Debug调试+String字符串全套API详解(超全干货)
java·开发语言·人工智能·python·transformer·无人机·量子计算
aidesignplus13 分钟前
从平方到线性:Mamba如何挑战Transformer的长序列效率瓶颈?
人工智能·python·深度学习·vim·transformer
2301_7735536214 分钟前
Redis怎样优化复制缓冲池大小_调大repl-backlog-size减少频繁的全量同步触发
jvm·数据库·python
三维频道14 分钟前
工业级三维扫描实测:汽车灯具复杂结构件的全尺寸 3D 测量方案分析
java·人工智能·python·数码相机·3d·汽车·汽车轻量化制造
人工智能AI技术15 分钟前
过拟合与欠拟合:机器学习最基础核心问题
人工智能
weixin_3812881818 分钟前
HTML lang 属性的正确取值规范:BCP 47 格式详解与最佳实践
jvm·数据库·python
码农飞哥21 分钟前
从Java后端到AI应用开发,我这两年做了什么
java·开发语言·人工智能
阿荻在肝了22 分钟前
Agent学习七:LangGraph学习-持久化与记忆二
python·学习·agent