【CenterFusion】模型的创建、导入、保存CenterFusion/src/lib/model/model.py

文件内容:CenterFusion/src/lib/model/model.py

文件作用:模型的创建、导入、保存

model.py 具体内容如下:

python 复制代码
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import torchvision.models as models
import torch
import torch.nn as nn
import os

from .networks.dla import DLASeg
from .networks.resdcn import PoseResDCN
from .networks.resnet import PoseResNet
from .networks.dlav0 import DLASegv0
from .networks.generic_network import GenericNetwork

_network_factory = {
  'resdcn': PoseResDCN,
  'dla': DLASeg,
  'res': PoseResNet,
  'dlav0': DLASegv0,
  'generic': GenericNetwork
}

def create_model(arch, head, head_conv, opt=None):

  num_layers = int(arch[arch.find('_') + 1:]) if '_' in arch else 0
  '''
  处理字符串 arch = dla_34 ,将下划线后半部分取出
  最后 num_layers = 34
  '''

  arch = arch[:arch.find('_')] if '_' in arch else arch
  '''
  将 arch = dla_34 中下划线前半部分取出
  最后 arch = 'dla'
  '''

  model_class = _network_factory[arch]
  '''
  根据 arch = 'dla' 获取 _network_factory 中的值
  最后 model_class = DLASeg
  DLASeg 类定义在 CenterFusion/src/lib/model/networks/dla.py 第 594 行
  '''

  model = model_class(num_layers, heads=head, head_convs=head_conv, opt=opt)
  '''
  配置模型
  '''

  return model

def load_model(model, model_path, opt, optimizer=None):

  start_epoch = 0
  '''
  设定初始轮次 = 0
  '''

  checkpoint = torch.load(model_path, map_location=lambda storage, loc: storage)
  print('loaded {}, epoch {}'.format(model_path, checkpoint['epoch']))
  '''
  torch.load() 函数:用来加载 torch.save() 保存的模型文件
  '''

  state_dict_ = checkpoint['state_dict']
  '''
  获取 checkpoint 模型文件中的 state_dict 属性
  这个属性存放训练过程中需要学习的权重和偏执系数
  state_dict 作为 python 的字典对象将每一层的参数映射成 tensor 张量
  需要注意的是 torch.nn.Module 模块中的 state_dict 只包含卷积层和全连接层的参数
  '''

  state_dict = {}

  for k in state_dict_:
    if k.startswith('module') and not k.startswith('module_list'):
      state_dict[k[7:]] = state_dict_[k]
    else:
      state_dict[k] = state_dict_[k]
  '''
  startswith(str) 函数:检测字符串 str,检测到返回 True,否则返回 False
  这里只执行了 else 语句,相当于保存导入模型的网络参数
  '''
  
  model_state_dict = model.state_dict()
  '''
  浅拷贝 main.py 中创建的新模型 DLA 的网络参数
  '''

  for k in state_dict:
    '''
    遍历导入的模型中的每层网络参数
    '''
    if k in model_state_dict:
      '''
      判断新模型的网络参数中是否有导入的模型的参数
      是有的,因为导入的模型也是 DLA 模型
      '''
      if (state_dict[k].shape != model_state_dict[k].shape) or \
        (opt.reset_hm and k.startswith('hm') and (state_dict[k].shape[0] in [80, 1])):
        '''
        第一个条件为 True
        其余条件全部为 False
        '''
        if opt.reuse_hm:
          '''
          不执行
          '''
          print('Reusing parameter {}, required shape{}, '\
                'loaded shape{}.'.format(
            k, model_state_dict[k].shape, state_dict[k].shape))
          # todo: bug in next line: both sides of < are the same
          if state_dict[k].shape[0] < state_dict[k].shape[0]:
            model_state_dict[k][:state_dict[k].shape[0]] = state_dict[k]
          else:
            model_state_dict[k] = state_dict[k][:model_state_dict[k].shape[0]]
          state_dict[k] = model_state_dict[k]
        
        elif opt.warm_start_weights:
          '''
          不执行
          '''
          try:
            print('Partially loading parameter {}, required shape{}, '\
                  'loaded shape{}.'.format(
              k, model_state_dict[k].shape, state_dict[k].shape))
            if state_dict[k].shape[1] < model_state_dict[k].shape[1]:
              model_state_dict[k][:,:state_dict[k].shape[1]] = state_dict[k]
            else:
              model_state_dict[k] = state_dict[k][:,:model_state_dict[k].shape[1]]
            state_dict[k] = model_state_dict[k]
          except:
            print('Skip loading parameter {}, required shape{}, '\
                'loaded shape{}.'.format(
                k, model_state_dict[k].shape, state_dict[k].shape))
            state_dict[k] = model_state_dict[k]
        
        else:
          '''
          执行该 else 中的语句
          '''
          print('Skip loading parameter {}, required shape{}, '\
                'loaded shape{}.'.format(
            k, model_state_dict[k].shape, state_dict[k].shape))
          state_dict[k] = model_state_dict[k]
          '''
          将新模型的网络参数赋值给导入的模型中
          '''
    else:
      print('Drop parameter {}.'.format(k))

  for k in model_state_dict:
    if not (k in state_dict):
      print('No param {}.'.format(k))
      state_dict[k] = model_state_dict[k]
  '''
  给导入的模型添加没有的参数
  '''
  
  model.load_state_dict(state_dict, strict=False)
  '''
  使用 state_dict 反序列化模型参数字字典,用来加载模型参数
  将 state_dict 中的 parameters 和 buffers 复制到此 module 及其子节点中
  简述:给模型对象加载训练好的模型参数,即加载模型参数
  '''

 #冻结骨干网,没有执行
  if opt.freeze_backbone:
    for (name, module) in model.named_children():
      if name in opt.layers_to_freeze:
        for (name, layer) in module.named_children():
          for param in layer.parameters():
            param.requires_grad = False

  # 恢复优化器参数,没有执行
  if optimizer is not None and opt.resume:
    if 'optimizer' in checkpoint:
      start_epoch = checkpoint['epoch']
      start_lr = opt.lr
      for step in opt.lr_step:
        if start_epoch >= step:
          start_lr *= 0.1
      for param_group in optimizer.param_groups:
        param_group['lr'] = start_lr
      print('Resumed optimizer with start lr', start_lr)
    else:
      print('No optimizer parameters in checkpoint.')
  if optimizer is not None:
    '''
    执行该 if 语句
    '''
    return model, optimizer, start_epoch
  else:
    return model

def save_model(path, epoch, model, optimizer=None):

  if isinstance(model, torch.nn.DataParallel):
    '''
    isinstance(object, classinfo) 判断一个函数 object 是否是一个已知的类型 classinfo
    是则返回 True,反之返回 False
    '''
    state_dict = model.module.state_dict()
  else:
    state_dict = model.state_dict()
  '''
  获取模型的参数矩阵
  '''

  data = {'epoch': epoch,
          'state_dict': state_dict}
  
  if not (optimizer is None):
    data['optimizer'] = optimizer.state_dict()
  '''
  获取模型的优化器
  '''

  torch.save(data, path)
  '''
  保存模型
  '''
相关推荐
2401_833033628 分钟前
如何修复固定定位头部容器中悬浮下拉菜单的错位问题
jvm·数据库·python
熊猫钓鱼>_>10 分钟前
当“虾”遇上“马”:QClaw 融合 Hermes 背后的智能体进化论
人工智能·ai·腾讯云·agent·openclaw·qclaw·hermes
深念Y13 分钟前
Denuvo加密被全面攻破?聊聊D加密原理和这次的破解事件
人工智能·游戏·ai·逆向·虚拟机·虚拟·d加密
KKKlucifer17 分钟前
日志审计与行为分析在安全服务中的应用实践
网络·人工智能·安全
SelectDB18 分钟前
Doris & SelectDB for AI 实战:从基础 RAG 到知识图谱增强的完整实现
数据库·人工智能·数据分析
Agent产品评测局20 分钟前
生产排期与MES/ERP系统打通,实操方法详解:2026企业级智能体与超自动化集成实战指南
运维·人工智能·ai·chatgpt·自动化
GitCode官方21 分钟前
一声唤醒 万物响应|AtomGit 首款开源鸿蒙 AI 硬件「小鸿」发布会圆满落幕 定义智能交互新入口
人工智能·开源·harmonyos
互联网志21 分钟前
打通转化通道 赋能产业发展——高校科技成果转化的现状与破局
大数据·人工智能·物联网
z44247532628 分钟前
CSS Grid布局如何实现网格项目的自动增长_设置grid-auto-flow- row
jvm·数据库·python
GeLx28 分钟前
从反爬角度:Playwright CDP 模式、Playwright 传统模式与 DrissionPage 的比较
python·程序人生·playwright·drissionpage·pyppeteer·浏览器自动化控制