PyTorch Lightning教程四:超参数的使用

如果需要和命令行接口进行交互,可以使用Python中的argparse包,快捷方便,对于Lightning而言,可以利用它,在命令行窗口中,直接配置超参数等操作,但也可以使用LightningCLI的方法,更加轻便简单。

ArgumentParser

ArgumentParser是Python的内置特性,进而构建CLI程序,我们可以使用它在命令行中设置超参数和其他训练设置。

python 复制代码
from argparse import ArgumentParser

parser = ArgumentParser()
# 训练方式(GPU or CPU or 其他)
parser.add_argument("--devices", type=int, default=2)
# 超参数
parser.add_argument("--layer_1_dim", type=int, default=128)
# 解析用户输入和默认值 (returns argparse.Namespace)
args = parser.parse_args()

# 在程序中使用解析后的参数
trainer = Trainer(devices=args.devices)
model = MyModel(layer_1_dim=args.layer_1_dim)

然后在命令行中如此调用

shell 复制代码
python trainer.py --layer_1_dim 64 --devices 1

Python的参数解析器在简单的用例中工作得很好,但在大型项目中维护它可能会变得很麻烦。例如,每次在模型中添加、更改或删除参数时,都必须添加、编辑或删除相应的add_argument。Lightning CLI提供了与Trainer和LightningModule的无缝集成,为您自动生成CLI参数。

LightningCLI

shell 复制代码
pip install "jsonargparse[signatures]"

执行起来很简单,例如

python 复制代码
# main.py
from lightning.pytorch.cli import LightningCLI
from lightning.pytorch.demos.boring_classes import DemoModel, BoringDataModule

def cli_main():
    # 只需要写这一行即可,两个参数,对应模型和数据
    cli = LightningCLI(DemoModel, BoringDataModule)	
    # 注意: 别写.fit

if __name__ == "__main__":
    cli_main()  # 在函数中实现CLI并在主if块中调用它是一种很好的做法

然后在命令行中执行help,进行文档查询

shell 复制代码
python main.py --help

执行结果

shell 复制代码
usage: main.py [-h] [-c CONFIG] [--print_config[=flags]]
               {fit,validate,test,predict,tune} ...

pytorch-lightning trainer command line tool

optional arguments:
  -h, --help            Show this help message and exit.
  -c CONFIG, --config CONFIG
                        Path to a configuration file in json or yaml format.
  --print_config[=flags]
                        Print the configuration after applying all other
                        arguments and exit. The optional flags customizes the
                        output and are one or more keywords separated by
                        comma. The supported flags are: comments,
                        skip_default, skip_null.

subcommands:
  For more details of each subcommand, add it as an argument followed by
  --help.

  {fit,validate,test,predict,tune}
    fit                 Runs the full optimization routine.
    validate            Perform one evaluation epoch over the validation set.
    test                Perform one evaluation epoch over the test set.
    predict             Run inference on your data.
    tune                Runs routines to tune hyperparameters before training.

因此可以使用如下方法:

shell 复制代码
$ python main.py fit		# 训练
$ python main.py validate	# 验证
$ python main.py test		# 测试
$ python main.py predict	# 预测

例如训练过程,可以通过以下方法具体调参数

shell 复制代码
# learning_rate
python main.py fit --model.learning_rate 0.1

# output dimensions
python main.py fit --model.out_dim 10 --model.learning_rate 0.1

# trainer 和 data arguments
python main.py fit --model.out_dim 2 --model.learning_rate 0.1 --data.data_dir '~/' --trainer.logger False
相关推荐
明天好,会的13 分钟前
分形生成实验:在有限上下文中构建可组合的强类型单元
人工智能
小智RE0-走在路上13 分钟前
Python学习笔记(13) --Mysql,Python关联数据库
数据库·python·学习
All The Way North-13 分钟前
从0到1,构建自己的全连接神经网络
人工智能·pytorch·深度学习·全连接神经网络
week_泽15 分钟前
6、OpenCV SURF特征检测笔记
人工智能·笔记·opencv
YJlio15 分钟前
杨利杰YJlio|博客导航目录(专栏总览 + 推荐阅读路线)
开发语言·python·pdf
Swizard16 分钟前
数据不够代码凑?用 Albumentations 让你的 AI 模型“看”得更广,训练快 10 倍!
python·算法·ai·训练
AI即插即用17 分钟前
即插即用系列 | CVPR 2025 DICMP:基于深度信息辅助的图像去雾与深度估计双任务协同互促网络
图像处理·人工智能·深度学习·神经网络·计算机视觉·视觉检测
Coder_Boy_18 分钟前
基于SpringAI的智能平台基座开发-(五)
java·人工智能·spring boot·langchain·springai
AI即插即用19 分钟前
即插即用系列 | WACV 2024 CSAM:面向各向异性医学图像分割的 2.5D 跨切片注意力模块
图像处理·人工智能·深度学习·神经网络·目标检测·计算机视觉·视觉检测
今夕资源网19 分钟前
仙宫云自动抢算力工具可后台运行,仙宫云自动抢卡,仙宫云自动抢显卡,AI云平台抢算力
人工智能·后台·仙宫云·抢算力·抢显卡·抢gpu