PyTorch核心基础知识点(一)

一、基础工具与安装

  1. Python两大法宝函数
    • dir():快速查看模块结构,例如dir(torch)可查看PyTorch所有子模块
    • help():获取函数/类的详细说明,如help(torch.cuda.is_available)查看GPU检测方法
python 复制代码
   import torch
   print(dir(torch.optim))  # 查看优化器模块
  1. 安装PyTorch
    • CPU版本pip install torch torchvision
    • GPU版本:需额外安装CUDA驱动,官网生成对应命令
    • 验证安装:print(torch.__version__)torch.cuda.is_available()

二、核心数据结构:Tensor

  1. Tensor特性
    • 类似NumPy数组但支持GPU加速
    • 创建方式:torch.tensor(), torch.randn(), torch.zeros()
python 复制代码
   x = torch.rand(2,3)  # 创建2x3随机张量
  1. Tensor操作
    • 索引/切片:x[0,:]
    • 形状变换:x.view(3,2)x.reshape(3,2)
    • 设备切换:x.to ('cuda')实现GPU计算

三、数据处理与可视化

  1. Dataset与DataLoader
    • 自定义数据集需实现__len____getitem__
python 复制代码
   from torch.utils.data import Dataset
   class MyDataset(Dataset):
       def __getitem__(self, index): 
           # 返回单条数据
  1. TensorBoard可视化
    • 安装:pip install tensorboard
    • 基础使用:
python 复制代码
     writer = SummaryWriter('logs')
     writer.add_scalar('Loss', loss, epoch)  # 记录标量
     writer.add_image('Input', img_tensor)   # 记录图像[[3,7,10]]
  • 启动服务:tensorboard --logdir=logs --port=6006
  1. Transforms数据增强
    • 将PIL图像转为Tensor并进行归一化:
python 复制代码
     transform = transforms.Compose([
         transforms.Resize(256),
         transforms.ToTensor()  # 范围[0,1][[4,5]]
     ])

四、神经网络搭建

  1. 模型定义
    • 继承nn.Module并实现forward方法
python 复制代码
   class Net(nn.Module):
       def __init__(self):
           super().__init__()
           self.fc = nn.Linear(784, 10)
       def forward(self, x):
           return self.fc(x)
  1. 训练流程
python 复制代码
   optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
   criterion = nn.CrossEntropyLoss()
   
   for epoch in range(10):
       output = model(input)
       loss = criterion(output, target)
       optimizer.zero_grad()
       loss.backward()
       optimizer.step()  # 参数更新[[4,12,16]]

五、实战建议

  1. 学习路径

  2. 常见问题

    • GPU不可用:检查CUDA版本与PyTorch是否匹配
    • 维度错误 :使用tensor.shape检查维度,unsqueeze()/squeeze()调整

通过结合官方文档与实战项目(如FashionMNIST分类),可快速提升PyTorch应用能力。建议在学习过程中多用print()和TensorBoard观察中间结果,加深对计算图的理解。

相关推荐
golang学习记1 分钟前
Zed 编辑器的 6 个隐藏技巧:提升开发效率的「冷知识」整理
人工智能
清水白石0086 分钟前
深入 Python 的底层世界:从 C 扩展到 ctypes 与 Cython 的本质差异全解析
c语言·python·neo4j
武汉大学-王浩宇9 分钟前
LLaMa-Factory的继续训练(Resume Training)
人工智能·机器学习
weisian15112 分钟前
入门篇--知名企业-28-字节跳动-2--字节跳动的AI宇宙:从技术赋能到生态共建的深度布局
人工智能·字节跳动·扣子·豆包
Amelia11111113 分钟前
day49
python
NGBQ1213822 分钟前
原创餐饮店铺图片数据集:344张高质量店铺图像助力商业空间识别与智能分析的专业数据集
人工智能
FIT2CLOUD飞致云23 分钟前
应用升级为智能体,模板中心上线,MaxKB开源企业级智能体平台v2.5.0版本发布
人工智能·ai·开源·1panel·maxkb
haiyu_y30 分钟前
Day 58 经典时序模型 2(ARIMA / 季节性 / 残差诊断)
人工智能·深度学习·ar
IT=>小脑虎31 分钟前
2026版 Python零基础小白学习知识点【基础版详解】
开发语言·python·学习
我想吃烤肉肉34 分钟前
Playwright中page.locator和Selenium中find_element区别
爬虫·python·测试工具·自动化