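1 Background
Some time ago I trained ASR models with NVIDIA's nemo-toolkit 2.7.0 framework and ran into several practical problems along the way. This post collects them for future reference, both for others and for myself.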
2 Issues
2.1 Problem loading a .ckpt file
1) Symptom
When loading a .ckpt checkpoint file through train.py to continue training, the following error was raised:
Traceback (most recent call last):
File "/usr/local/anaconda3/envs/nemo270/lib/python3.10/tarfile.py", line 1877, in gzopen
t = cls.taropen(name, mode, fileobj, **kwargs)
File "/usr/local/anaconda3/envs/nemo270/lib/python3.10/tarfile.py", line 1854, in taropen
return cls(name, mode, fileobj, **kwargs)
File "/usr/local/anaconda3/envs/nemo270/lib/python3.10/tarfile.py", line 1714, in __init__
self.firstmember = self.next()
File "/usr/local/anaconda3/envs/nemo270/lib/python3.10/tarfile.py", line 2629, in next
raise e
File "/usr/local/anaconda3/envs/nemo270/lib/python3.10/tarfile.py", line 2602, in next
tarinfo = self.tarinfo.fromtarfile(self)
File "/usr/local/anaconda3/envs/nemo270/lib/python3.10/tarfile.py", line 1292, in fromtarfile
buf = tarfile.fileobj.read(BLOCKSIZE)
File "/usr/local/anaconda3/envs/nemo270/lib/python3.10/gzip.py", line 301, in read
return self._buffer.read(size)
File "/usr/local/anaconda3/envs/nemo270/lib/python3.10/_compression.py", line 68, in readinto
data = self.read(len(byte_view))
File "/usr/local/anaconda3/envs/nemo270/lib/python3.10/gzip.py", line 488, in read
if not self._read_gzip_header():
File "/usr/local/anaconda3/envs/nemo270/lib/python3.10/gzip.py", line 436, in _read_gzip_header
raise BadGzipFile('Not a gzipped file (%r)' % magic)
gzip.BadGzipFile: Not a gzipped file (b'PK')
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
File "/home/tigerp/morsecodetoolkit/examples/morse_to_text/train.py", line 40, in main
model.maybe_init_from_pretrained_checkpoint(cfg)
File "/usr/local/anaconda3/envs/nemo270/lib/python3.10/site-packages/lightning_utilities/core/rank_zero.py", line 42, in wrapped_fn
return fn(*args, **kwargs)
File "/usr/local/anaconda3/envs/nemo270/lib/python3.10/site-packages/nemo/core/classes/modelPT.py", line 1378, in maybe_init_from_pretrained_checkpoint
restored_model = self.restore_from(
File "/usr/local/anaconda3/envs/nemo270/lib/python3.10/site-packages/nemo/core/classes/modelPT.py", line 501, in restore_from
instance = cls._save_restore_connector.restore_from(
File "/usr/local/anaconda3/envs/nemo270/lib/python3.10/site-packages/nemo/core/connectors/save_restore_connector.py", line 270, in restore_from
loaded_params = self.load_config_and_state_dict(
File "/usr/local/anaconda3/envs/nemo270/lib/python3.10/site-packages/nemo/core/connectors/save_restore_connector.py", line 157, in load_config_and_state_dict
members = self._filtered_tar_info(restore_path, filter_fn=filter_fn)
File "/usr/local/anaconda3/envs/nemo270/lib/python3.10/site-packages/nemo/core/connectors/save_restore_connector.py", line 659, in _filtered_tar_info
with SaveRestoreConnector._tar_open(tar_path) as tar:
File "/usr/local/anaconda3/envs/nemo270/lib/python3.10/contextlib.py", line 135, in __enter__
return next(self.gen)
File "/usr/local/anaconda3/envs/nemo270/lib/python3.10/site-packages/nemo/core/connectors/save_restore_connector.py", line 698, in _tar_open
tar = tarfile.open(path2file, tar_header)
File "/usr/local/anaconda3/envs/nemo270/lib/python3.10/tarfile.py", line 1824, in open
return func(name, filemode, fileobj, **kwargs)
File "/usr/local/anaconda3/envs/nemo270/lib/python3.10/tarfile.py", line 1881, in gzopen
raise ReadError("not a gzip file") from e
tarfile.ReadError: not a gzip file
2) Cause
In the NVIDIA NeMo framework, .nemo and .ckpt are two model file formats that serve different purposes.
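A .nemo file is the final format NeMo recommends for sharing, evaluating, fine-tuning and running inference with a model. It is essentially a tar archive that bundles everything needed to run the model: the model configuration (model_config.yaml), the model weights (model_weights.ckpt) and other artifacts (for example the tokenizer's vocabulary file vocab.json). A .nemo file can be loaded in a single call with model.restore_from('your_model.nemo'), either for inference or for further fine-tuning.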
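To see this structure for yourself, a .nemo archive can be opened with Python's standard tarfile module. The following is a minimal sketch, assuming the member names listed above (model_config.yaml, model_weights.ckpt); the helpers list_nemo_members and looks_like_nemo are hypothetical, not NeMo APIs. Member names inside a real .nemo may carry a leading "./" or sit in subdirectories, which the sketch flattens.

```python
import tarfile
from pathlib import Path
from typing import List

# Members described above; tokenizer files (e.g. vocab.json) vary by model, so they are not required here.
EXPECTED_MEMBERS = ("model_config.yaml", "model_weights.ckpt")


def list_nemo_members(path: str) -> List[str]:
    """Return the file names stored in a .nemo archive (plain or gzip-compressed tar)."""
    # mode "r:*" lets tarfile auto-detect compression; it raises tarfile.ReadError
    # if the file is not a tar archive at all (e.g. a ZIP-format .ckpt).
    with tarfile.open(path, mode="r:*") as tar:
        # Keep only regular files and drop directory prefixes such as "./".
        return [Path(member.name).name for member in tar.getmembers() if member.isfile()]


def looks_like_nemo(path: str) -> bool:
    """Heuristic check: is `path` a tar archive containing the expected members?"""
    try:
        names = set(list_nemo_members(path))
    except tarfile.ReadError:
        return False
    return all(name in names for name in EXPECTED_MEMBERS)
```

Running looks_like_nemo on the .ckpt from the traceback would be expected to return False, since that file is not a tar archive.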
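A .ckpt file is a checkpoint that PyTorch Lightning writes automatically during training. It is a PyTorch-serialized file that records the complete training state at a given moment: besides the model weights, it usually also holds the information needed to resume training, such as optimizer and learning-rate scheduler states, the current epoch and the global step. A .ckpt file is normally loaded by reading it with torch.load, taking its state_dict and passing that to model.load_state_dict.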
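The error above, gzip.BadGzipFile: Not a gzipped file (b'PK'), shows that the program tried to read the file as a gzip-compressed archive (i.e., it expected a .nemo file), while the actual input was a .ckpt file, so loading the model failed. The bytes b'PK' are the ZIP file signature: since PyTorch 1.6, torch.save writes a ZIP-based format by default, which is why a Lightning .ckpt begins with them. The relevant code is: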
trainer = Trainer(**cfg.trainer)
exp_manager(trainer, cfg.get("exp_manager", None))
model: MorseEncDecCTCModel = MorseEncDecCTCModel(cfg=cfg.model, trainer=trainer)
model.maybe_init_from_pretrained_checkpoint(cfg)  # expects a .nemo file here
trainer.fit(model)
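Because the failure depends on the file's content rather than its extension, checking the leading bytes tells the two formats apart before any loader runs. A minimal sketch, assuming a .ckpt written in torch.save's default ZIP format and a .nemo that is either a plain or a gzip-compressed tar; classify_header and classify_file are hypothetical helpers, and the check is a heuristic, not a full validation.

```python
from pathlib import Path

# File signatures ("magic bytes") relevant to the .ckpt / .nemo mix-up.
ZIP_MAGIC = b"PK\x03\x04"   # ZIP local-file header; torch.save's default format since PyTorch 1.6
GZIP_MAGIC = b"\x1f\x8b"    # gzip stream, e.g. a gzip-compressed tar archive
TAR_MAGIC = b"ustar"        # POSIX/GNU tar headers store this at byte offset 257
TAR_MAGIC_OFFSET = 257


def classify_header(head: bytes) -> str:
    """Classify a file from its leading bytes (a heuristic, not a full validation)."""
    if head.startswith(ZIP_MAGIC):
        return "zip"    # typical for a Lightning .ckpt written by torch.save
    if head.startswith(GZIP_MAGIC):
        return "gzip"   # compressed archive, e.g. a gzipped .nemo
    if head[TAR_MAGIC_OFFSET:TAR_MAGIC_OFFSET + len(TAR_MAGIC)] == TAR_MAGIC:
        return "tar"    # uncompressed tar archive, e.g. an uncompressed .nemo
    return "unknown"


def classify_file(path: str) -> str:
    """Read just enough bytes from `path` to run classify_header on them."""
    with Path(path).open("rb") as f:
        return classify_header(f.read(TAR_MAGIC_OFFSET + len(TAR_MAGIC)))
```

If classify_file reports "zip" for the file passed as init_from_nemo_model, restore_from is the wrong loader for it.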
3) Solution
Add the following code so that .ckpt files can be loaded as well:
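# (at the top of train.py)
import torch
from nemo.utils import logging

# Warm start if a pre-trained model is specified.
pretrained_path = cfg.get("init_from_nemo_model", None)
if pretrained_path and str(pretrained_path).endswith(".ckpt"):
    # Load weights from a Lightning .ckpt checkpoint file for fine-tuning.
    logging.info(f"Loading pretrained weights from checkpoint: {pretrained_path}")
    # weights_only=False: Lightning checkpoints also hold non-tensor objects (e.g. hyper-parameters),
    # which PyTorch >= 2.6 rejects by default. Only use this on checkpoints you trust.
    checkpoint = torch.load(pretrained_path, map_location="cpu", weights_only=False)
    # Lightning stores the weights under "state_dict"; fall back to a bare state dict.
    state_dict = checkpoint.get("state_dict", checkpoint)
    model.load_state_dict(state_dict, strict=False)
else:
    # Default: use NeMo's built-in method for .nemo files
    model.maybe_init_from_pretrained_checkpoint(cfg)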
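Note that strict=False only tolerates missing or unexpected keys; PyTorch's load_state_dict still raises a RuntimeError on a size mismatch, which can happen, for example, when fine-tuning with a different output vocabulary. A minimal sketch of a pre-check, assuming both state dicts are first reduced to {name: shape} mappings; diff_state_dicts is a hypothetical helper, not part of NeMo or PyTorch.

```python
from typing import Dict, List, Tuple

Shape = Tuple[int, ...]


def diff_state_dicts(
    model_shapes: Dict[str, Shape], ckpt_shapes: Dict[str, Shape]
) -> Dict[str, List[str]]:
    """Compare parameter names and shapes of a model and a checkpoint before loading.

    load_state_dict(..., strict=False) skips missing/unexpected keys silently,
    but still raises on shape mismatches, so mismatched keys must be dropped first.
    """
    loadable, mismatched = [], []
    for key, shape in ckpt_shapes.items():
        if key not in model_shapes:
            continue  # reported below as "unexpected"
        (loadable if model_shapes[key] == shape else mismatched).append(key)
    return {
        "loadable": sorted(loadable),
        "mismatched": sorted(mismatched),
        "missing": sorted(set(model_shapes) - set(ckpt_shapes)),
        "unexpected": sorted(set(ckpt_shapes) - set(model_shapes)),
    }
```

In train.py the inputs would be built as {k: tuple(v.shape) for k, v in model.state_dict().items()} and the same for the checkpoint's state_dict; only the "loadable" keys would then be passed to model.load_state_dict(..., strict=False), and the other lists logged so a partial warm start does not go unnoticed.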
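2.2 Missing nv_one_logger module
1) Symptom
After installing nemo-toolkit 2.7.0 and the related packages, running the training code failed with an error saying the nv_one_logger module was missing: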
Traceback (most recent call last):
File "/home/tigerp/se/gmo/train.py", line 28, in <module>
from pubs.common import (
File "/home/tigerp/se/gmo/pubs/common.py", line 23, in <module>
from nemo.collections.asr.metrics.wer import word_error_rate
File "/home/tigerp/anaconda3/envs/matcha/lib/python3.10/site-packages/nemo/collections/asr/__init__.py", line 15, in <module>
from nemo.collections.asr import data, losses, models, modules
File "/home/tigerp/anaconda3/envs/matcha/lib/python3.10/site-packages/nemo/collections/asr/losses/__init__.py", line 15, in <module>
from nemo.collections.asr.losses.angularloss import AngularSoftmaxLoss
File "/home/tigerp/anaconda3/envs/matcha/lib/python3.10/site-packages/nemo/collections/asr/losses/angularloss.py", line 18, in <module>
from nemo.core.classes import Loss, Typing, typecheck
File "/home/tigerp/anaconda3/envs/matcha/lib/python3.10/site-packages/nemo/core/__init__.py", line 16, in <module>
from nemo.core.classes import *
File "/home/tigerp/anaconda3/envs/matcha/lib/python3.10/site-packages/nemo/core/classes/__init__.py", line 33, in <module>
from nemo.core.classes.modelPT import ModelPT
File "/home/tigerp/anaconda3/envs/matcha/lib/python3.10/site-packages/nemo/core/classes/modelPT.py", line 51, in <module>
from nemo.lightning.callback_group import CallbackGroup
File "/home/tigerp/anaconda3/envs/matcha/lib/python3.10/site-packages/nemo/lightning/__init__.py", line 28, in <module>
from nemo.lightning.fabric.strategies import FabricMegatronStrategy
File "/home/tigerp/anaconda3/envs/matcha/lib/python3.10/site-packages/nemo/lightning/fabric/strategies.py", line 70, in <module>
from nemo.lightning.pytorch.strategies import MegatronStrategy
File "/home/tigerp/anaconda3/envs/matcha/lib/python3.10/site-packages/nemo/lightning/pytorch/strategies/__init__.py", line 15, in <module>
from nemo.lightning.pytorch.strategies.fsdp2_strategy import FSDP2Strategy
File "/home/tigerp/anaconda3/envs/matcha/lib/python3.10/site-packages/nemo/lightning/pytorch/strategies/fsdp2_strategy.py", line 37, in <module>
from nemo.lightning.pytorch.strategies.utils import (
File "/home/tigerp/anaconda3/envs/matcha/lib/python3.10/site-packages/nemo/lightning/pytorch/strategies/utils.py", line 51, in <module>
from nemo.lightning.pytorch.callbacks import MegatronProgressBar, ProgressPrinter
File "/home/tigerp/anaconda3/envs/matcha/lib/python3.10/site-packages/nemo/lightning/pytorch/callbacks/__init__.py", line 22, in <module>
from nemo.lightning.pytorch.callbacks.model_checkpoint import ModelCheckpoint
File "/home/tigerp/anaconda3/envs/matcha/lib/python3.10/site-packages/nemo/lightning/pytorch/callbacks/model_checkpoint.py", line 31, in <module>
from nemo.lightning.callback_group import CallbackGroup
File "/home/tigerp/anaconda3/envs/matcha/lib/python3.10/site-packages/nemo/lightning/callback_group.py", line 21, in <module>
from nemo.lightning.one_logger_callback import OneLoggerNeMoCallback
File "/home/tigerp/anaconda3/envs/matcha/lib/python3.10/site-packages/nemo/lightning/one_logger_callback.py", line 25, in <module>
from nv_one_logger.api.config import OneLoggerConfig
ModuleNotFoundError: No module named 'nv_one_logger'
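2) Cause
nv_one_logger is a lightweight logging library from NVIDIA. Its main purpose is to help developers feed log data produced during training (such as loss and accuracy) seamlessly into NVIDIA's own experiment-management platform on NGC (the NGC Catalog).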
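The error shows that the environment lacks nv_one_logger, a library NeMo depends on. Because it is imported at module level along the chain modelPT.py → nemo.lightning → ... → callback_group.py → one_logger_callback.py, even an unrelated import such as from nemo.collections.asr.metrics.wer import word_error_rate fails. Installing it directly with pip failed (the package could not be found), and both web searches and an AI assistant indicated that nv_one_logger is a non-public NVIDIA library. Out of options for the moment, I downloaded the NeMo 2.7.0 source code from the NeMo GitHub repository to investigate further. In the end, since the project does not actually need the nv_one_logger logging tool, I simply modified the source to disable the calls that use it.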
3) Solution
Edit nemo/lightning/callback_group.py in the locally installed NeMo package and disable the relevant import and call as follows (one possible approach):
# Disable the import that requires nv_one_logger
#from nemo.lightning.one_logger_callback import OneLoggerNeMoCallback
... ...
def __init__(self) -> None:
    #self._callbacks: List[BaseCallback] = [OneLoggerNeMoCallback()]
    self._callbacks: List[BaseCallback] = []
    # Ensure application-end is emitted at most once per process
    self._app_end_emitted: bool = False
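Commenting the import out works, but the patch must be reapplied after every reinstall and keeps the logger disabled even if nv_one_logger becomes available later. A guarded import is one alternative. The sketch below is a minimal, hypothetical helper (optional_import is not a NeMo API); the commented lines show how it might be applied to callback_group.py, assuming the OneLoggerNeMoCallback name from the traceback.

```python
import importlib
from types import ModuleType
from typing import Optional


def optional_import(module_name: str) -> Optional[ModuleType]:
    """Import `module_name` if it (and its own dependencies) can be imported, else return None."""
    try:
        return importlib.import_module(module_name)
    except ModuleNotFoundError:
        # Also triggered when the module exists but one of its imports is missing,
        # e.g. nemo.lightning.one_logger_callback failing on `import nv_one_logger`.
        return None


# Hypothetical use inside nemo/lightning/callback_group.py (not executed here):
# _one_logger = optional_import("nemo.lightning.one_logger_callback")
# ...
# def __init__(self) -> None:
#     self._callbacks: List[BaseCallback] = (
#         [_one_logger.OneLoggerNeMoCallback()] if _one_logger is not None else []
#     )
```

Calling optional_import("nv_one_logger") in the training environment is also a quick way to check whether the module is installed: it returns None when it is not.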