6、PyTorch中搭建分类网络实例

1. 重要类

  • nn.Module
  • nn.flatten
  • nn.linear
  • nn.relu
  • to.device
  • torch.cuda.is_available
  • nn.softmax
  • nn.argmax
  • nn.sequential
  • nn.conv2d
  • add_module
  • buffer
  • load_state_dict
  • named_parameters
  • requires_grad
  • save_check_points

2. 代码测试

python 复制代码
import torch
from torch import nn
from torch.nn import Module

torch.set_printoptions(precision=3)


class MyModelTest(Module):
    def __init__(self):
        super(MyModelTest, self).__init__()
        self.linear_1 = nn.Linear(3, 4)
        self.relu = nn.ReLU()
        self.linear_2 = nn.Linear(4, 5)

    def forward(self, x):
        x = self.linear_1(x)
        x = self.relu(x)
        y = self.linear_2(x)
        return y


if __name__ == "__main__":
    matrix = torch.arange(3,dtype=torch.float)
    my_softmax = nn.Softmax(dim=0)
    output = my_softmax(matrix)
    print(f"matrix=\n{matrix}")
    print(f"output=\n{output}")
    my_model = MyModelTest()
    for name, param in my_model.named_parameters():
        print(f"layer:{name}\n|size:{param.size()}\n|values:{param[:2]}\n")
  • 结果:
python 复制代码
matrix=
tensor([0., 1., 2.])
output=
tensor([0.090, 0.245, 0.665])
layer:linear_1.weight
|size:torch.Size([4, 3])
|values:tensor([[-0.544, -0.492,  0.190],
        [-0.424, -0.068,  0.134]], grad_fn=<SliceBackward0>)

layer:linear_1.bias
|size:torch.Size([4])
|values:tensor([0.295, 0.306], grad_fn=<SliceBackward0>)

layer:linear_2.weight
|size:torch.Size([5, 4])
|values:tensor([[ 0.489,  0.018,  0.314,  0.497],
        [ 0.364, -0.455,  0.047, -0.215]], grad_fn=<SliceBackward0>)

layer:linear_2.bias
|size:torch.Size([5])
|values:tensor([-0.027,  0.190], grad_fn=<SliceBackward0>)
相关推荐
大侠区块链10 小时前
我面试了上百个想进 AI 公司的人,发现他们都搞错了一件事--深度精读 | 对话 Anthropic Claude Code 产品负责人 Cat Wu
人工智能·面试·职场和发展
绿虫光伏运维10 小时前
光伏运维精细化管理,解锁电站收益最大化
大数据·运维·人工智能·光伏业务
小仙女的小稀罕10 小时前
适合销售从业者会议整理使用的销售录音转任务工具
大数据·人工智能·学习·自然语言处理·语音识别
GitCode官方10 小时前
头号 Builder 集结|出海 Agent 开造!大疆 Pocket4 等你赢!
人工智能·agent·atomgit
CIO_Alliance10 小时前
2026年生成式引擎优化(GEO)解决方案选型指南|幂链科技的实战可验证与全链路合规
人工智能·geo·deepseek·ai搜索优化·幂链geo·豆包ai
DreamWear10 小时前
Claude Context:让 AI 编程助手真正"看见"整个代码库
人工智能·agent
HalukiSan10 小时前
VLLM部署Qwen3-30B-A3B-Instruct-2507-FP8
人工智能
DreamWear10 小时前
Agent Skills:给 AI 编码代理装上高级工程师的工作纪律
人工智能·agent
水月天涯10 小时前
ClaudeCode入门00-初识神器(小白入门-到底什么是ClaudeCode,为什么大家都说好用?)
人工智能·ai编程
HalukiSan10 小时前
一些关于AI训练的基础概念
人工智能