pytorch基础-nn.linear

python 复制代码
import torch
import torch.nn as nn

# 定义线性层
linear_layer = nn.Linear(in_features=10, out_features=5, bias=True)

# 输入数据
input_data = torch.randn(32, 10)  # (batch_size=32, in_features=10)

# 前向传播
output = linear_layer(input_data)
print(output.shape)  # 输出形状: (32, 5)

维度变化

  • 输入(batch_size, in_features)

  • 输出(batch_size, out_features)

示例

  • 输入形状:(32, 10)

  • 线性层:nn.Linear(10, 5)

  • 输出形状:(32, 5)

实现细节:矩阵乘法

相关推荐
choke2335 分钟前
[特殊字符] Python异常处理
开发语言·python
hqyjzsb5 分钟前
盲目用AI提效?当心陷入“工具奴”陷阱,效率不增反降
人工智能·学习·职场和发展·创业创新·学习方法·业界资讯·远程工作
Eloudy12 分钟前
用 Python 直写 CUDA Kernel的技术,CuTile、TileLang、Triton 与 PyTorch 的深度融合实践
人工智能·pytorch
神的泪水13 分钟前
CANN 实战全景篇:从零构建 LLM 推理引擎(基于 CANN 原生栈)
人工智能
yuanyuan2o214 分钟前
【深度学习】全连接、卷积神经网络
人工智能·深度学习·cnn
八零后琐话19 分钟前
干货:Claude最新大招Cowork避坑!
人工智能
汗流浃背了吧,老弟!37 分钟前
BPE 词表构建与编解码(英雄联盟-托儿索语料)
人工智能·深度学习
软件聚导航1 小时前
从 AI 画马到马年红包封面,我还做了一个小程序
人工智能·chatgpt
啊森要自信1 小时前
CANN ops-cv:AI 硬件端视觉算法推理训练的算子性能调优与实战应用详解
人工智能·算法·cann