如何快速看懂并修改神经网络


前言:个人之见,一个神经网络网络源码出现,你先看数据集的输入和输出,而这数据集肯定要包括数据增加和制作数据集,第二 看模型的输入和输出(至于模型内部可以自己看论文 无非就是加了几个组件),然后根据输出选择的损失函数。至于学习率和优化器 差不多都是余弦退火和admw的优化器


1.数据集

直接实战,首先你看它的readme,它一般由标注文件的格式(一般都是 文件路径 + 对应的标签数字)(要求自己制作)

输入一般都是这个标注文件,输出一般都是元组或者字典。

数据增强一般包含在数据集的制作当中

actionclip

数据增强(空间剪裁)

数据增强源码

python 复制代码
from datasets.transforms_ss import *
from RandAugment import RandAugment

class GroupTransform(object):
    def __init__(self, transform):
        self.worker = transform

    def __call__(self, img_group):
        return [self.worker(img) for img in img_group]

def get_augmentation(training, config):
    input_mean = [0.48145466, 0.4578275, 0.40821073]
    input_std = [0.26862954, 0.26130258, 0.27577711]
    scale_size = config.data.input_size * 256 // 224
    if training:

        unique = torchvision.transforms.Compose([GroupMultiScaleCrop(config.data.input_size, [1, .875, .75, .66]),
                                                 GroupRandomHorizontalFlip(is_sth='some' in config.data.dataset),
                                                 GroupRandomColorJitter(p=0.8, brightness=0.4, contrast=0.4,
                                                                        saturation=0.2, hue=0.1),
                                                 GroupRandomGrayscale(p=0.2),
                                                 GroupGaussianBlur(p=0.0),
                                                 GroupSolarization(p=0.0)]
                                                )
    else:
        unique = torchvision.transforms.Compose([GroupScale(scale_size),
                                                 GroupCenterCrop(config.data.input_size)])

    common = torchvision.transforms.Compose([Stack(roll=False),
                                             ToTorchFormatTensor(div=True),
                                             GroupNormalize(input_mean,
                                                            input_std)])
    return torchvision.transforms.Compose([unique, common])

def randAugment(transform_train,config):
    print('Using RandAugment!')
    transform_train.transforms.insert(0, GroupTransform(RandAugment(config.data.randaug.N, config.data.randaug.M)))
    return transform_train

这个数据增强 你可以直接 参考()

一般直接蕴含在数据集

python 复制代码
 def __init__(self, list_file, labels_file,
                 num_segments=1, new_length=1,
                 image_tmpl='img_{:05d}.jpg', transform=None,
                 random_shift=True, test_mode=False, index_bias=1):
 
    def get(self, record, indices):
        images = list()
        for i, seg_ind in enumerate(indices):
            p = int(seg_ind)
            try:
                seg_imgs = self._load_image(record.path, p)
            except OSError:
                print('ERROR: Could not read image "{}"'.format(record.path))
                print('invalid indices: {}'.format(indices))
                raise
            images.extend(seg_imgs)
        process_data = self.transform(images)
        return process_data, record.label
  • 空间剪裁 无疑就是进行多少词crop 你得了解一手 ranaugment函数
数据集的制作(时间剪裁以及帧数实现)
  • 输入
    actionclip的标注文件为:

    /public/datasets/kinetics400/data2/extracted_train_frames/bowling/HfI4vN2vbHU_000000_000010 289 31
    /public/datasets/kinetics400/data2/extracted_train_frames/bookbinding/B8FXlmO5zk4_000079_000089 240 29
    /public/datasets/kinetics400/data2/extracted_train_frames/abseiling/XsEw1vd32l8_000052_000062 300 0
    /public/datasets/kinetics400/data2/extracted_train_frames/belly_dancing/r61D2lDCHsM_000268_000278 240 18
    /public/datasets/kinetics400/data2/extracted_train_frames/abseiling/4sCQ-EX6cIg_000021_000031 300 0
    /public/datasets/kinetics400/data2/extracted_train_frames/bowling/N9mQC7MeZCk_000008_000018 300 31
    /public/datasets/kinetics400/data2/extracted_train_frames/air_drumming/fzVhIrMnY-E_000322_000332 250 1
    /public/datasets/kinetics400/data2/extracted_train_frames/blasting_sand/6dLNI2BPTY0_000057_000067 250 23
    /public/datasets/kinetics400/data2/extracted_train_frames/bookbinding/othYtMhFdOU_000020_000030 250 29
    /public/datasets/kinetics400/data2/extracted_train_frames/belly_dancing/JVSxlojnBYk_000047_000057 300 18
    /public/datasets/kinetics400/data2/extracted_train_frames/air_drumming/8jO9DeYLruU_000003_000013 300 1
    /public/datasets/kinetics400/data2/extracted_train_frames/belly_dancing/pU12_c-XvU_000045_000055 300 18
    /public/datasets/kinetics400/data2/extracted_train_frames/belly_dancing/x6rP9b1V7sQ_000060_000070 250 18
    /public/datasets/kinetics400/data2/extracted_train_frames/blasting_sand/jqC2SnFAvoM_000092_000102 300 23
    /public/datasets/kinetics400/data2/extracted_train_frames/bowling/ri6AwOp59yA_000009_000019 250 31
    /public/datasets/kinetics400/data2/extracted_train_frames/air_drumming/wRaacvxMoc8_000014_000024 150 1
    /public/datasets/kinetics400/data2/extracted_train_frames/abseiling/7kbO0v4hag_000107_000117 300 0
    /public/datasets/kinetics400/data2/extracted_train_frames/bookbinding/GjtR9KZbV3Y_000494_000504 300 29
    /public/datasets/kinetics400/data2/extracted_train_frames/abseiling/hwUQqFadvE_000048_000058 250 0
    /public/datasets/kinetics400/data2/extracted_train_frames/bookbinding/vXmgE41UnBk_000844_000854 300 29
    /public/datasets/kinetics400/data2/extracted_train_frames/air_drumming/dglCzcubsw_000246_000256 159 1
    /public/datasets/kinetics400/data2/extracted_train_frames/bowling/ri1H0ygN3Us_000768_000778 300 31
    /public/datasets/kinetics400/data2/extracted_train_frames/belly_dancing/n24zV9OtorU_000257_000267 300 18
    /public/datasets/kinetics400/data2/extracted_train_frames/abseiling/nKoqxSJcZn8_000071_000081 250 0
    /public/datasets/kinetics400/data2/extracted_train_frames/air_drumming/pT2byS0qiZM_000001_000011 150 1
    /public/datasets/kinetics400/data2/extracted_train_frames/bookbinding/CMo6AJhtZo_000075_000085 250 29

视频提起帧 视频总帧数 对应的标签数字

  • 输出
    一般看__getitem_
python 复制代码
    def __getitem__(self, index):
        record = self.video_list[index]
        segment_indices = self._sample_indices(record) if self.random_shift else self._get_val_indices(record)
        return self.get(record, segment_indices)

    def __call__(self, img_group):
        return [self.worker(img) for img in img_group]


    def get(self, record, indices):
        images = list()
        for i, seg_ind in enumerate(indices):
            p = int(seg_ind)
            try:
                seg_imgs = self._load_image(record.path, p)
            except OSError:
                print('ERROR: Could not read image "{}"'.format(record.path))
                print('invalid indices: {}'.format(indices))
                raise
            images.extend(seg_imgs)
        process_data = self.transform(images)
        return process_data, record.label

返回元组 (images,labes)

  • 帧数 一般num_segment由这个决定 为什么?
    因为我看顶刊 基本上 一个片段抽一政数,这个无疑由片段决定
  • 时间剪裁
    时间剪裁指的是从视频的时间维度上选取特定的帧(验证数据集)
python 复制代码
  def _get_val_indices(self, record):
        if self.num_segments == 1:
            return np.array([record.num_frames //2], dtype=np.int) + self.index_bias
        
        if record.num_frames <= self.total_length:
            if self.loop:
                return np.mod(np.arange(self.total_length), record.num_frames) + self.index_bias
            return np.array([i * record.num_frames // self.total_length
                             for i in range(self.total_length)], dtype=np.int) + self.index_bias
        offset = (record.num_frames / self.num_segments - self.seg_length) / 2.0
        return np.array([i * record.num_frames / self.num_segments + offset + j
                         for i in range(self.num_segments)
                         for j in range(self.seg_length)], dtype=np.int) + self.index_bias

帧数不足时

当 self.loop 为 True 时,通过 np.mod(np.arange(self.total_length), record.num_frames) 循环选取视频帧,确保选取的帧数达到 self.total_length,这是一种时间剪裁方式,通过循环利用现有帧来满足所需的帧数。

当 self.loop 为 False 时,使用 i * record.num_frames // self.total_length 均匀地从视频中选取 self.total_length 帧,同样实现了时间维度上的剪裁。
在视频帧数充足的情况下,先根据 self.num_segments 划分片段,然后在每个片段内选取连续的 self.seg_length 帧。offset 确保每个片段内选取的帧在片段中处于相对居中的位置,通过这种方式实现了在每个片段内的时间剪裁。

x-clip

数据集

1.参考一下这一篇 关于数据集的输入输出

2 讲一下时间剪裁

python 复制代码
val_pipeline = [
        dict(type='DecordInit'),
        dict(type='SampleFrames', clip_len=1, frame_interval=1, num_clips=config.DATA.NUM_FRAMES, test_mode=True),
        dict(type='DecordDecode'),
        dict(type='Resize', scale=(-1, scale_resize)),
        dict(type='CenterCrop', crop_size=config.DATA.INPUT_SIZE),
        dict(type='Normalize', **img_norm_cfg),
        dict(type='FormatShape', input_format='NCHW'),
        dict(type='Collect', keys=['imgs', 'label'], meta_keys=[]),
        dict(type='ToTensor', keys=['imgs'])
    ]
    if config.TEST.NUM_CROP == 3:
        val_pipeline[3] = dict(type='Resize', scale=(-1, config.DATA.INPUT_SIZE))
        val_pipeline[4] = dict(type='ThreeCrop', crop_size=config.DATA.INPUT_SIZE)
    if config.TEST.NUM_CLIP > 1:
        val_pipeline[1] = dict(type='SampleFrames', clip_len=1, frame_interval=1, num_clips=config.DATA.NUM_FRAMES, multiview=config.TEST.NUM_CLIP)
    

multiview=config.TEST.NUM_CLIP)无疑是控制为时间剪裁的数量

3 空间剪裁
val_pipeline[4] = dict(type='ThreeCrop', crop_size=config.DATA.INPUT_SIZE)

这个更加直观了直接剪了三次 所以为3

2 模型

action-clip

从输入而言:

  • 文本
    classes, num_text_aug, text_dict = text_prompt(train_data)
    class为( num_text_augxnum_class,context)
    text_dict为(num_class,context)
    num_text_aug为填充内容长度
python 复制代码
text_id = numpy.random.randint(num_text_aug,size=len(list_id))
            texts = torch.stack([text_dict[j][i,:] for i,j in zip(list_id,text_id)])

分为了(B,context)

  • 图片
python 复制代码
images = images.view((-1,config.data.num_segments,3)+images.size()[-2:])
           b,t,c,h,w = images.size()
  images= images.to(device).view(-1,c,h,w ) 

这个论文严格意义上 是借用 clip的编码器 所以它压缩了

输出也简单

  • 文件
    text_embedding = model_text(texts)(b,d)

  • 图片

    image_embedding = model_image(images)
    image_embedding = image_embedding.view(b,t,-1)
    image_embedding = fusion_model(image_embedding)

关于这个fusion输出x.mean(dim=1, keepdim=False)

会把t压缩 x 变成了 (b,d)

x-clip

  • 文本
    text_labels = generate_text(train_data) 这个为(num_class(k),77)
    (和上面同理),但是它没有转为样本数
  • 图片
    images = images.view((-1, config.DATA.NUM_FRAMES, 3) + images.size()[-2:])
    它内部实现了一个编码器
python 复制代码
    def encode_video(self, image):
        b,t,c,h,w = image.size()
        image = image.reshape(-1,c,h,w)

        cls_features, img_features = self.encode_image(image)
        img_features = self.prompts_visual_ln(img_features)
        img_features = img_features @ self.prompts_visual_proj
        
        cls_features = cls_features.view(b, t, -1)
        img_features = img_features.view(b,t,-1,cls_features.shape[-1])
        
        video_features = self.mit(cls_features)

        return video_features, img_features

image = image.reshape(-1,c,h,w) 内部化了

输出:

复制代码
logit_scale = self.logit_scale.exp()
 logits = torch.einsum("bd,bkd->bk", video_features, logit_scale * text_features)
        
  return logits

返回了一个b k 相似度得分

相关推荐
机器之心2 分钟前
在GSM8K上比GRPO快8倍!厦大提出CPPO,让强化学习快如闪电
人工智能
果冻人工智能6 分钟前
我们的灵魂需要“工作量证明”, 论在人工智能时代的欲望与安逸
人工智能
机器之心6 分钟前
自动学会工具解题,RL扩展催化奥数能力激增17%
人工智能
Shockang7 分钟前
机器学习的一百个概念(6)最小最大缩放
人工智能·机器学习
枉费红笺13 分钟前
目标检测的训练策略
人工智能·目标检测·计算机视觉
进取星辰14 分钟前
PyTorch 深度学习实战(30):模型压缩与量化部署
人工智能·pytorch·深度学习
新智元14 分钟前
吉卜力太火,奥特曼求饶!GPT-4o 免费生图登王座,设计师直呼天塌了
人工智能·openai
新智元19 分钟前
OpenAI 要 Open 了!奥特曼开源首个推理模型,ChatGPT 一小时暴增百万用户
人工智能·openai
新智元24 分钟前
DeepSeek-V3 击败 R1 开源登顶!杭州黑马撼动硅谷 AI 霸主,抹去 1 万亿市值神话
人工智能·openai
wei_shuo2 小时前
DeepSeek-R1 模型现已在亚马逊云科技上推出
人工智能·amazon