【pytorch】深度学习准备:基本配置

深度学习中常用包

python 复制代码
import os 
import numpy as np 
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import torch.optim as optimizer

超参数设置

2种设置方式:将超参数直接设置在训练的代码中;用yaml、json,dict等文件来存储超参数

python 复制代码
# 批次的大小
batch_size = 16
# 优化器的学习率
lr = 1e-4
# 训练次数
max_epochs = 100

GPU设置

python 复制代码
# 方案一:使用os.environ,这种情况如果使用GPU不需要设置
import os
os.environ['CUDA_VISIBLE_DEVICES'] = '0,1' # 指明调用的GPU为0,1号

# 方案二:使用"device",后续对要使用GPU的变量用.to(device)即可
device = torch.device("cuda:1" if torch.cuda.is_available() else "cpu") # 指明调用的GPU为1号

使用argparse和yaml文件

  1. argparse的使用:
python 复制代码
import argparse
"""
	argparse.ArgumentParser()创建了一个对象
	add_argument()添加参数
	parse_args()将参数封装在opt内,各个参数通过.运算符调用
"""

def main(opt):
    print(opt.num_batches)

if __name__ == '__main__':

    parse = argparse.ArgumentParser()
    parse.add_argument('--num_batches', type=int, default=50, help='the num of batch')
    parse.add_argument('--num_window', type=int, default=5, help='the num of window')
    parse.add_argument('--weight', type=str, default= '../pretrain.pth', help='the path of pretrained model')

    opt = parse.parse_args()
    main(opt)
  1. yaml文件的使用
    下面是一个yaml文件的例子,参数呈现层级结构
yaml 复制代码
device: 'cpu'

data:
    train_path: 'data/train'
    test_path: 'test/train'
    num: 1000

读取yaml文件

python 复制代码
def read_yaml(path):
"""
	read()读入yaml文件中的内容
	safe_load()加载yaml格式的内容并转换为字典
"""
    file = open(path, 'r', encoding='utf-8')
    string = file.read()
    file.close()
    dict = yaml.safe_load(string)

    return dict

path = 'config.yaml'
Dict = read_yaml(path)
device = Dict['device']
print(device)
train_path = Dict['data']['train_path']
print(train_path)
  1. 使用方法
    在yaml文件中给全部参数设置默认值,使用argparse库设置待调参数的值

参考资料

  1. 深度学习代码中的argparse以及yaml文件的使用
  2. datawhale的thorough-pytorch repo
相关推荐
时空无限16 分钟前
说说transformer 中的掩码矩阵以及为什么能掩盖住词语
人工智能·矩阵·transformer
HAH-HAH20 分钟前
【Python 入门】(2)Python 语言基础(变量)
开发语言·python·学习·青少年编程·个人开发·变量·python 语法
查里王22 分钟前
AI 3D 生成工具知识库:当前产品格局与测评总结
人工智能·3d
武子康1 小时前
AI-调查研究-76-具身智能 当机器人走进生活:具身智能对就业与社会结构的深远影响
人工智能·程序人生·ai·职场和发展·机器人·生活·具身智能
递归不收敛1 小时前
PyCharm项目上传GitHub仓库(笔记)
笔记·pycharm·github
小鹿清扫日记1 小时前
从蛮力清扫到 “会看路”:室外清洁机器人的文明进阶
人工智能·ai·机器人·扫地机器人·具身智能·连合直租·有鹿巡扫机器人
技术小黑1 小时前
Transformer系列 | Pytorch复现Transformer
pytorch·深度学习·transformer
递归不收敛1 小时前
一、Java 基础入门:从 0 到 1 认识 Java(详细笔记)
java·开发语言·笔记
fanstuck1 小时前
Prompt提示工程上手指南(六):AI避免“幻觉”(Hallucination)策略下的Prompt
人工智能·语言模型·自然语言处理·nlp·prompt
zhangfeng11332 小时前
win7 R 4.4.0和RStudio1.25的版本兼容性以及系统区域设置有关 导致Plots绘图面板被禁用,但是单独页面显示
开发语言·人工智能·r语言·生物信息