average_weights.py

import argparse

from pathlib import Path

import torch

def _extract_state(ckpt, use_ema=False):

if isinstance(ckpt, dict):

if use_ema and 'ema' in ckpt:

model = ckpt['ema']

return model.float().state_dict(), ckpt, 'ema'

if 'model' in ckpt:

model = ckpt['model']

return model.float().state_dict(), ckpt, 'model'

if 'state_dict' in ckpt:

return ckpt['state_dict'], ckpt, 'state_dict'

if isinstance(ckpt, dict):

return ckpt, ckpt, 'state_dict'

raise ValueError('Unsupported checkpoint format')

def average_checkpoints(model_dir, output, pattern='*.pt', use_ema=False, strict_keys=True):

model_dir = Path(model_dir)

files = sorted(model_dir.glob(pattern))

if len(files) < 2:

raise ValueError(f'Need at least 2 checkpoints, got {len(files)} in {model_dir}')

first_state, first_ckpt, first_key = _extract_state(torch.load(files[0], map_location='cpu'), use_ema)

keys = list(first_state.keys())

avg_state = {}

for k in keys:

v = first_state[k]

avg_state[k] = v.float().clone() if torch.is_floating_point(v) else v

for f in files[1:]:

state, _, _ = _extract_state(torch.load(f, map_location='cpu'), use_ema)

if strict_keys and set(state.keys()) != set(keys):

raise ValueError(f'Key mismatch in {f}')

for k in keys:

if k not in state:

if strict_keys:

raise ValueError(f'Missing key {k} in {f}')

continue

v = state[k]

if torch.is_floating_point(v):

avg_state[k] += v.float()

else:

avg_state[k] = avg_state[k]

for k in keys:

if torch.is_floating_point(avg_state[k]):

avg_state[k] /= float(len(files))

if first_key in ('model', 'ema'):

first_ckpt[first_key].load_state_dict(avg_state, strict=False)

if first_key == 'ema' and 'model' in first_ckpt:

first_ckpt['model'].load_state_dict(avg_state, strict=False)

torch.save(first_ckpt, output)

else:

torch.save(avg_state, output)

return files

def main():

parser = argparse.ArgumentParser()

parser.add_argument('--dir', type=str, required=True, help='directory with checkpoints')

parser.add_argument('--output', type=str, required=True, help='output checkpoint path')

parser.add_argument('--pattern', type=str, default='*.pt', help='checkpoint filename pattern')

parser.add_argument('--use-ema', action='store_true', help='average EMA weights if available')

parser.add_argument('--strict-keys', action='store_true', help='require identical state_dict keys')

args = parser.parse_args()

files = average_checkpoints(

model_dir=args.dir,

output=args.output,

pattern=args.pattern,

use_ema=args.use_ema,

strict_keys=args.strict_keys,

)

print(f'Averaged {len(files)} checkpoints to {args.output}')

if name == 'main':

main()

已新增模型参数平均脚本 average_weights.py,输入一个目录,自动把目录下多个 checkpoint 做参数平均并保存为新权重。支持使用 EMA 权重,且可选择严格检查 key 是否一致。

用法示例:

python average_weights.py --dir "E:\同致电子\动态分辨率\runs\train\exp\weights" --output "E:\同致电子\动态分辨率\avg.pt"

如果你要用 EMA 权重平均:

python average_weights.py --dir "E:\...\weights" --output "E:\...\avg.pt" --use-ema

如果想强制所有模型 key 必须完全一致:

python average_weights.py --dir "E:\...\weights" --output "E:\...\avg.pt" --strict-keys

说明:

  • 默认匹配模式是 *.pt,可用 --pattern 调整。
  • 生成的 avg.pt 会保留与首个 checkpoint 相同的结构(model 或 ema),后续可以直接拿去量化。

如果你希望只平均指定的几个 epoch(不是目录中全部),告诉我你想按什么规则筛选(文件名、epoch 号列表、正则),我可以加筛选参数。

为了使的平均后的模型效果最好,--use-ema和--strict-keys要开启吗

简短结论:

  • --use-ema:建议开启,前提是这些 checkpoint 里确实有 EMA。EMA 往往更稳、更接近最优验证效果,平均 EMA 通常比平均普通权重更好。
  • --strict-keys:建议开启,前提是这些模型结构完全一致(同一 cfg/类数/分割头)。这样能保证平均的是同一层,避免"误对齐"带来的退化。

例外情况:

  • 如果部分模型来自不同结构或类数,--strict-keys 会直接报错,这时要么筛掉不一致的模型,要么不用 strict(但平均效果通常不如严格一致)。
  • 如果某些 checkpoint 没有 EMA,--use-ema 会回退到 model 吗?当前脚本不会自动回退,会按 ema 读取不到就报错,因此要确认这些权重都带 EMA。

最稳妥的组合(同结构同数据):--use-ema --strict-keys。

相关推荐
霖大侠16 分钟前
CPAL: Cross-Prompting Adapter With LoRAs forRGB+X Semantic Segmentation
人工智能·深度学习·算法·机器学习·transformer
飞Link1 小时前
工业级时序异常检测利器:USAD 算法深度解析与实战
人工智能·深度学习·机器学习
白雨青1 小时前
国信 iQuant 自动国债逆回购实战:Python 自动化闲钱理财
python·量化策略·量化交易·国债逆回购
qq_404265831 小时前
用Python批量处理Excel和CSV文件
jvm·数据库·python
才兄说2 小时前
机器人租售效果好吗?任务前对齐需求
python
喵手2 小时前
Python 爬虫实战:构建开源主题模板版本库
爬虫·python·数据采集·爬虫实战·零基础python爬虫教学·开源主题·采集开源主题模版本库
qq_418101772 小时前
使用Scikit-learn进行机器学习模型评估
jvm·数据库·python
2601_953465613 小时前
HLS.js 原生开发!m3u8live.cn打造最贴合项目的 M3U8 在线播放器
开发语言·前端·javascript·python·json·ecmascript·前端开发工具
szcsun53 小时前
python中包、模块的层级关系,以及import、from...import...的相关用法
开发语言·python
高洁013 小时前
数字孪生在航空领域的应用方法及案例
python·深度学习·信息可视化·数据挖掘·transformer