分层图像金字塔变压器

文章来源:hierarchical-image-pyramid-transformers

2024 年 2 月 5 日

本文介绍了分层图像金字塔变换器 (HIPT),这是一种新颖的视觉变换器 (ViT) 架构,设计用于分析计算病理学中的十亿像素全幻灯片图像 (WSI)。 HIPT 利用 WSI 固有的层次结构通过自我监督学习来学习高分辨率图像表示。 HIPT 在涵盖 33 种癌症类型的大型数据集上进行预训练,并在多个幻灯片级任务中进行评估,在癌症亚型分型和生存预测方面表现出卓越的性能,展示了自我监督学习模型在捕获肿瘤微环境中关键的归纳偏差和表型方面的潜力。

本图展示了计算病理学中使用的全切片图像 (WSI) 的分层结构。左图显示的是多层次方法,在这种方法中,大型组织图像(150,000 x 150,000 像素)被分解成更小、更易于管理的部分:首先是显示组织表型的 4096 x 4096 区域,然后是 256 x 256 细胞组织斑块,最后是最小的 16 x 16 细胞特征。右图展示了 256 x 256 图像是如何由 256 个较小的 16 x 16 标记序列组成的,反过来,每个 256 x 256 图像又是如何成为 4096 x 4096 区域内 256 x 256 标记的更大的不连续序列的一部分。这种分层标记化方法可以处理和分析不同分辨率和比例的超大图像。

该模型由三个阶段的分层聚合组成,首先是自下而上地聚合各自 256x256 和 4096x4096 窗口中的 16x16 视觉标记,最终形成幻灯片级表示。HIPT 模型的主要组成部分可写如下:

  1. 分层聚合: HIPT 在细胞、斑块和区域层面聚合视觉标记,形成幻灯片表征。这种分层方法是受自然语言处理中使用分层表示法的启发,在自然语言处理中,嵌入可以在不同层次上聚合,形成文档表示法。同样,在 WSI 的背景下,分层聚合允许模型捕捉不同粒度级别的信息,从单个细胞到更广泛的组织结构。

  2. Transformer自注意力: 为了在聚合的每个阶段对视觉概念之间的重要依赖关系进行建模,HIPT 将 Transformer 自注意力调整为包络变换聚合层。这样,该模型就能捕捉视觉标记之间的复杂关系,并学习能编码图像中局部和全局上下文的表征。

  3. 预训练和自我监督学习: HIPT 采用自我监督学习的方式对 33 种癌症类型的千兆像素 WSI 大数据集进行预训练。该模型利用两个层次的自我监督学习来学习高分辨率图像表征,并利用学生-教师知识提炼来对每个聚合层进行预训练,对大至 4096x4096 的区域进行自我监督学习。

  4. 性能和应用: 研究结果表明,采用分层预训练的 HIPT 在幻灯片级任务上的表现优于目前最先进的方法。该模型的性能在包括癌症亚型和生存预测在内的 9 项幻灯片级任务上进行了评估,并显示其在捕捉组织微环境中更广泛的预后特征方面表现出色。

图中从左到右显示了三个聚合级别:

  1. 细胞级聚合: 单个细胞由 16 px tokens表示,然后使用 ViT256-16 模型将其聚合为片段级表示,再进行全局池化以获得单一矢量表示。
  2. 斑块级聚合: 使用专为 256 px 输入设计的更大 ViT 变体来处理 256 px 补丁,然后再次使用池化层将补丁级特征汇总为区域级表示。
  3. 区域级聚合: 最后,对 4096 px 的区域进行聚合,这一次使用的是将整个区域作为输入的 ViT,从而形成一个全局注意力汇集层,提供幻灯片级表示。

这一分层过程将问题分解为易于处理的部分,并关注从细胞到组织结构等不同层次的细节,从而使模型能够处理规模巨大的 WSI。

下面的脚本利用了专门用于高分辨率图像分析的视觉转换器(ViTs),并结合了几种先进的功能和技术:

  1. 截断法线初始化: 这是一种用于初始化神经网络权重的技术,可避免与平均值产生较大偏差,从而确保早期训练阶段的稳定性。

  2. Drop Path: 一种正则化方法,在训练过程中随机丢弃网络中的路径,通过模拟更薄的网络来提高泛化效果,类似于dropout,但针对的是残余连接。

  3. 多层感知器(MLP)模块: 定义一个简单的双层 MLP,具有 GELU 激活函数和滤除功能,用于在转换器模块中处理特征。

  4. 注意机制:采用可选偏置和缩放的自注意机制,这对捕捉输入数据中的全局依赖性至关重要。

  5. Transformer模块: 将规范层、注意机制和 MLP 组合成一个内聚块,并可选择路径剔除进行正则化。

  6. VisionTransformer4K:Vision Transformer 的专用版本,专为超高分辨率图像而设计,采用了位置嵌入插值等技术,以适应不同的图像尺寸,其结构也针对处理大规模图像进行了优化。

  7. 实用功能: 包括用于截断法线权重初始化、下落路径模拟和参数计算的函数,以帮助进行模型设置和分析。

    import argparse
    import os
    import sys
    import datetime
    import time
    import math
    import json
    from pathlib import Path
    import numpy as np
    from PIL import Image
    import torch
    import torch.nn as nn
    import torch.distributed as dist
    import torch.backends.cudnn as cudnn
    import torch.nn.functional as F
    from torchvision import datasets, transforms
    from torchvision import models as torchvision_models
    import vision_transformer as vits
    from vision_transformer import DINOHead
    import math
    from functools import partial
    import torch
    import torch.nn as nn
    def no_grad_trunc_normal(tensor, mean, std, a, b):
    # Cut & paste from PyTorch official master until it's in a few official releases - RW
    # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
    def norm_cdf(x):
    # Computes standard normal cumulative distribution function
    return (1. + math.erf(x / math.sqrt(2.))) / 2.
    if (mean < a - 2 * std) or (mean > b + 2 * std):
    warnings.warn("mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
    "The distribution of values may be incorrect.",
    stacklevel=2)
    with torch.no_grad():
    # Values are generated by using a truncated uniform distribution and
    # then using the inverse CDF for the normal distribution.
    # Get upper and lower cdf values
    l = norm_cdf((a - mean) / std)
    u = norm_cdf((b - mean) / std)
    # Uniformly fill tensor with values from [l, u], then translate to
    # [2l-1, 2u-1].
    tensor.uniform_(2 * l - 1, 2 * u - 1)
    # Use inverse cdf transform for normal distribution to get truncated
    # standard normal
    tensor.erfinv_()
    # Transform to proper mean, std
    tensor.mul_(std * math.sqrt(2.))
    tensor.add_(mean)
    # Clamp to ensure it's in the proper range
    tensor.clamp_(min=a, max=b)
    return tensor
    def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.):
    # type: (Tensor, float, float, float, float) -> Tensor
    return no_grad_trunc_normal(tensor, mean, std, a, b)

    def drop_path(x, drop_prob: float = 0., training: bool = False):
    if drop_prob == 0. or not training:
    return x
    keep_prob = 1 - drop_prob
    shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets
    random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
    random_tensor.floor_() # binarize
    output = x.div(keep_prob) * random_tensor
    return output
    class DropPath(nn.Module):
    """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
    """
    def init(self, drop_prob=None):
    super(DropPath, self).init()
    self.drop_prob = drop_prob
    def forward(self, x):
    return drop_path(x, self.drop_prob, self.training)
    class Mlp(nn.Module):
    def init(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
    super().init()
    out_features = out_features or in_features
    hidden_features = hidden_features or in_features
    self.fc1 = nn.Linear(in_features, hidden_features)
    self.act = act_layer()
    self.fc2 = nn.Linear(hidden_features, out_features)
    self.drop = nn.Dropout(drop)
    def forward(self, x):
    x = self.fc1(x)
    x = self.act(x)
    x = self.drop(x)
    x = self.fc2(x)
    x = self.drop(x)
    return x
    class Attention(nn.Module):
    def init(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.):
    super().init()
    self.num_heads = num_heads
    head_dim = dim // num_heads
    self.scale = qk_scale or head_dim ** -0.5
    self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
    self.attn_drop = nn.Dropout(attn_drop)
    self.proj = nn.Linear(dim, dim)
    self.proj_drop = nn.Dropout(proj_drop)
    def forward(self, x):
    B, N, C = x.shape
    qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
    q, k, v = qkv[0], qkv[1], qkv[2]
    attn = (q @ k.transpose(-2, -1)) * self.scale
    attn = attn.softmax(dim=-1)
    attn = self.attn_drop(attn)
    x = (attn @ v).transpose(1, 2).reshape(B, N, C)
    x = self.proj(x)
    x = self.proj_drop(x)
    return x, attn
    class Block(nn.Module):
    def init(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
    drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm):
    super().init()
    self.norm1 = norm_layer(dim)
    self.attn = Attention(
    dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
    self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
    self.norm2 = norm_layer(dim)
    mlp_hidden_dim = int(dim * mlp_ratio)
    self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
    def forward(self, x, return_attention=False):
    y, attn = self.attn(self.norm1(x))
    if return_attention:
    return attn
    x = x + self.drop_path(y)
    x = x + self.drop_path(self.mlp(self.norm2(x)))
    return x
    class VisionTransformer4K(nn.Module):
    """ Vision Transformer 4K """
    def init(self, num_classes=0, img_size=[224], input_embed_dim=384, output_embed_dim = 192,
    depth=12, num_heads=12, mlp_ratio=4., qkv_bias=False, qk_scale=None,
    drop_rate=0., attn_drop_rate=0., drop_path_rate=0., norm_layer=nn.LayerNorm, num_prototypes=64, **kwargs):
    super().init()
    embed_dim = output_embed_dim
    self.num_features = self.embed_dim = embed_dim
    self.phi = nn.Sequential(*[nn.Linear(input_embed_dim, output_embed_dim), nn.GELU(), nn.Dropout(p=drop_rate)])
    num_patches = int(img_size[0] // 16)**2
    print("# of Patches:", num_patches)

         self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
         self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
         self.pos_drop = nn.Dropout(p=drop_rate)
         dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]  # stochastic depth decay rule
         self.blocks = nn.ModuleList([
             Block(
                 dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
                 drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer)
             for i in range(depth)])
         self.norm = norm_layer(embed_dim)
         # Classifier head
         self.head = nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity()
         trunc_normal_(self.pos_embed, std=.02)
         trunc_normal_(self.cls_token, std=.02)
         self.apply(self._init_weights)
     def _init_weights(self, m):
         if isinstance(m, nn.Linear):
             trunc_normal_(m.weight, std=.02)
             if isinstance(m, nn.Linear) and m.bias is not None:
                 nn.init.constant_(m.bias, 0)
         elif isinstance(m, nn.LayerNorm):
             nn.init.constant_(m.bias, 0)
             nn.init.constant_(m.weight, 1.0)
     def interpolate_pos_encoding(self, x, w, h):
         npatch = x.shape[1] - 1
         N = self.pos_embed.shape[1] - 1
         if npatch == N and w == h:
             return self.pos_embed
         class_pos_embed = self.pos_embed[:, 0]
         patch_pos_embed = self.pos_embed[:, 1:]
         dim = x.shape[-1]
         w0 = w // 1
         h0 = h // 1
         # we add a small number to avoid floating point error in the interpolation
         # see discussion at https://github.com/facebookresearch/dino/issues/8
         w0, h0 = w0 + 0.1, h0 + 0.1
         patch_pos_embed = nn.functional.interpolate(
             patch_pos_embed.reshape(1, int(math.sqrt(N)), int(math.sqrt(N)), dim).permute(0, 3, 1, 2),
             scale_factor=(w0 / math.sqrt(N), h0 / math.sqrt(N)),
             mode='bicubic',
         )
         assert int(w0) == patch_pos_embed.shape[-2] and int(h0) == patch_pos_embed.shape[-1]
         patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
         return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1)
     def prepare_tokens(self, x):
         #print('preparing tokens (after crop)', x.shape)
         self.mpp_feature = x
         B, embed_dim, w, h = x.shape
         x = x.flatten(2, 3).transpose(1,2)
         x = self.phi(x)
         # add the [CLS] token to the embed patch tokens
         cls_tokens = self.cls_token.expand(B, -1, -1)
         x = torch.cat((cls_tokens, x), dim=1)
         # add positional encoding to each token
         x = x + self.interpolate_pos_encoding(x, w, h)
         return self.pos_drop(x)
     def forward(self, x):
         x = self.prepare_tokens(x)
         for blk in self.blocks:
             x = blk(x)
         x = self.norm(x)
         return x[:, 0]
     def get_last_selfattention(self, x):
         x = self.prepare_tokens(x)
         for i, blk in enumerate(self.blocks):
             if i < len(self.blocks) - 1:
                 x = blk(x)
             else:
                 # return attention of the last block
                 return blk(x, return_attention=True)
     def get_intermediate_layers(self, x, n=1):
         x = self.prepare_tokens(x)
         # we return the output tokens from the `n` last blocks
         output = []
         for i, blk in enumerate(self.blocks):
             x = blk(x)
             if len(self.blocks) - i <= n:
                 output.append(self.norm(x))
         return output
    

    def vit4k_xs(patch_size=16, **kwargs):
    model = VisionTransformer4K(
    patch_size=patch_size, input_embed_dim=384, output_embed_dim=192,
    depth=6, num_heads=6, mlp_ratio=4,
    qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
    return model
    def count_parameters(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)

下面的代码脚本概述了加载和评估用于图像分析的 Vision Transformer (ViT) 模型的实现过程,该模型专为计算病理学中的高分辨率图像而设计。它定义了以下功能

  1. 加载预训练的 ViT 模型(`get_vit256` 和 `get_vit4k`),并为不同的架构和设备设置提供选项,在不进行梯度计算的评估模式下对其进行初始化。

  2. 为模型评估应用变换 (`eval_transforms`),以特定的平均值和标准偏差对图像进行归一化处理。

  3. 将成批的图像张量转换为单个 PIL 图像(`roll_batch2img`)或 numpy 数组(`tensorbatch2im`),便于处理图像数据,以实现可视化或进一步处理。

    Dependencies

    Base Dependencies

    import argparse
    import colorsys
    from io import BytesIO
    import os
    import random
    import requests
    import sys

    LinAlg / Stats / Plotting Dependencies

    import cv2
    import h5py
    import matplotlib
    import matplotlib.pyplot as plt
    from matplotlib.patches import Polygon
    import numpy as np
    from PIL import Image
    from PIL import ImageFont
    from PIL import ImageDraw
    from scipy.stats import rankdata
    import skimage.io
    from skimage.measure import find_contours
    from tqdm import tqdm
    import webdataset as wds

    Torch Dependencies

    import torch
    import torch.multiprocessing
    import torchvision
    from torchvision import transforms
    from einops import rearrange, repeat
    torch.multiprocessing.set_sharing_strategy('file_system')

    Local Dependencies

    import vision_transformer as vits
    import vision_transformer4k as vits4k
    def get_vit256(pretrained_weights, arch='vit_small', device=torch.device('cuda:0')):
    r"""
    Builds ViT-256 Model.

     Args:
     - pretrained_weights (str): Path to ViT-256 Model Checkpoint.
     - arch (str): Which model architecture.
     - device (torch): Torch device to save model.
     
     Returns:
     - model256 (torch.nn): Initialized model.
     """
     
     checkpoint_key = 'teacher'
     device = torch.device("cpu")
     model256 = vits.__dict__[arch](patch_size=16, num_classes=0)
     for p in model256.parameters():
         p.requires_grad = False
     model256.eval()
     model256.to(device)
     if os.path.isfile(pretrained_weights):
         state_dict = torch.load(pretrained_weights, map_location="cpu")
         if checkpoint_key is not None and checkpoint_key in state_dict:
             print(f"Take key {checkpoint_key} in provided checkpoint dict")
             state_dict = state_dict[checkpoint_key]
         # remove `module.` prefix
         state_dict = {k.replace("module.", ""): v for k, v in state_dict.items()}
         # remove `backbone.` prefix induced by multicrop wrapper
         state_dict = {k.replace("backbone.", ""): v for k, v in state_dict.items()}
         msg = model256.load_state_dict(state_dict, strict=False)
         print('Pretrained weights found at {} and loaded with msg: {}'.format(pretrained_weights, msg))
         
     return model256
    

    def get_vit4k(pretrained_weights, arch='vit4k_xs', device=torch.device('cuda:1')):
    r"""
    Builds ViT-4K Model.

     Args:
     - pretrained_weights (str): Path to ViT-4K Model Checkpoint.
     - arch (str): Which model architecture.
     - device (torch): Torch device to save model.
     
     Returns:
     - model256 (torch.nn): Initialized model.
     """
     
     checkpoint_key = 'teacher'
     device = torch.device("cpu")
     model4k = vits4k.__dict__[arch](num_classes=0)
     for p in model4k.parameters():
         p.requires_grad = False
     model4k.eval()
     model4k.to(device)
     if os.path.isfile(pretrained_weights):
         state_dict = torch.load(pretrained_weights, map_location="cpu")
         if checkpoint_key is not None and checkpoint_key in state_dict:
             print(f"Take key {checkpoint_key} in provided checkpoint dict")
             state_dict = state_dict[checkpoint_key]
         # remove `module.` prefix
         state_dict = {k.replace("module.", ""): v for k, v in state_dict.items()}
         # remove `backbone.` prefix induced by multicrop wrapper
         state_dict = {k.replace("backbone.", ""): v for k, v in state_dict.items()}
         msg = model4k.load_state_dict(state_dict, strict=False)
         print('Pretrained weights found at {} and loaded with msg: {}'.format(pretrained_weights, msg))
         
     return model4k
    

    def eval_transforms():
    """
    """
    mean, std = (0.5, 0.5, 0.5), (0.5, 0.5, 0.5)
    eval_t = transforms.Compose([transforms.ToTensor(), transforms.Normalize(mean = mean, std = std)])
    return eval_t
    def roll_batch2img(batch: torch.Tensor, w: int, h: int, patch_size=256):
    """
    Rolls an image tensor batch (batch of [256 x 256] images) into a [W x H] Pil.Image object.
    Args:
    batch (torch.Tensor): [B x 3 x 256 x 256] image tensor batch.
    Return:
    Image.PIL: [W x H X 3] Image.
    """
    batch = batch.reshape(w, h, 3, patch_size, patch_size)
    img = rearrange(batch, 'p1 p2 c w h-> c (p1 w) (p2 h)').unsqueeze(dim=0)
    return Image.fromarray(tensorbatch2im(img)[0])
    def tensorbatch2im(input_image, imtype=np.uint8):
    r""""
    Converts a Tensor array into a numpy image array.

     Args:
         - input_image (torch.Tensor): (B, C, W, H) Torch Tensor.
         - imtype (type): the desired type of the converted numpy array
         
     Returns:
         - image_numpy (np.array): (B, W, H, C) Numpy Array.
     """
     if not isinstance(input_image, np.ndarray):
         image_numpy = input_image.cpu().float().numpy()  # convert it into a numpy array
         #if image_numpy.shape[0] == 1:  # grayscale to RGB
         #    image_numpy = np.tile(image_numpy, (3, 1, 1))
         image_numpy = (np.transpose(image_numpy, (0, 2, 3, 1)) + 1) / 2.0 * 255.0  # post-processing: tranpose and scaling
     else:  # if it is a numpy array, do nothing
         image_numpy = input_image
     return image_numpy.astype(imtype)
    

该脚本定义了 HIPT_4K 模型,整合了用于处理高分辨率图像的视觉转换器模型。它为 256x256 和 4K 分辨率加载预先训练好的 ViT 模型,并将其分层应用于输入图像。这一过程包括将输入图像裁剪成 256x256 的补丁,使用 ViT_256 从每个补丁中提取特征,然后将这些特征输入 ViT_4K,以获得全局表示。这种分层方法能有效处理非方形的高分辨率图像,优化局部和全局尺度的详细特征提取,与本文利用分层结构进行图像分析的方法相一致。

import torch
from einops import rearrange, repeat
from HIPT_4K.hipt_model_utils import get_vit256, get_vit4k
class HIPT_4K(torch.nn.Module):
    """
    HIPT Model (ViT_4K-256) for encoding non-square images (with [256 x 256] patch tokens), with 
    [256 x 256] patch tokens encoded via ViT_256-16 using [16 x 16] patch tokens.
    """
    def __init__(self, 
        model256_path: str = 'path/to/Checkpoints/vit256_small_dino.pth',
        model4k_path: str = 'path/to/Checkpoints/vit4k_xs_dino.pth', 
        device256=torch.device('cuda:0'), 
        device4k=torch.device('cuda:1')):
        super().__init__()
        self.model256 = get_vit256(pretrained_weights=model256_path).to(device256)
        self.model4k = get_vit4k(pretrained_weights=model4k_path).to(device4k)
        self.device256 = device256
        self.device4k = device4k
        self.patch_filter_params = patch_filter_params
 
    def forward(self, x):
        """
        Forward pass of HIPT (given an image tensor x), outputting the [CLS] token from ViT_4K.
        1. x is center-cropped such that the W / H is divisible by the patch token size in ViT_4K (e.g. - 256 x 256).
        2. x then gets unfolded into a "batch" of [256 x 256] images.
        3. A pretrained ViT_256-16 model extracts the CLS token from each [256 x 256] image in the batch.
        4. These batch-of-features are then reshaped into a 2D feature grid (of width "w_256" and height "h_256".)
        5. This feature grid is then used as the input to ViT_4K-256, outputting [CLS]_4K.
        Args:
          - x (torch.Tensor): [1 x C x W' x H'] image tensor.
        Return:
          - features_cls4k (torch.Tensor): [1 x 192] cls token (d_4k = 192 by default).
        """
        batch_256, w_256, h_256 = self.prepare_img_tensor(x)                    # 1. [1 x 3 x W x H].
        batch_256 = batch_256.unfold(2, 256, 256).unfold(3, 256, 256)           # 2. [1 x 3 x w_256 x h_256 x 256 x 256] 
        batch_256 = rearrange(batch_256, 'b c p1 p2 w h -> (b p1 p2) c w h')    # 2. [B x 3 x 256 x 256], where B = (1*w_256*h_256)
        features_cls256 = []
        for mini_bs in range(0, batch_256.shape[0], 256):                       # 3. B may be too large for ViT_256. We further take minibatches of 256.
            minibatch_256 = batch_256[mini_bs:mini_bs+256].to(self.device256, non_blocking=True)
            features_cls256.append(self.model256(minibatch_256).detach().cpu()) # 3. Extracting ViT_256 features from [256 x 3 x 256 x 256] image batches.
        features_cls256 = torch.vstack(features_cls256)                         # 3. [B x 384], where 384 == dim of ViT-256 [ClS] token.
        features_cls256 = features_cls256.reshape(w_256, h_256, 384).transpose(0,1).transpose(0,2).unsqueeze(dim=0) 
        features_cls256 = features_cls256.to(self.device4k, non_blocking=True)  # 4. [1 x 384 x w_256 x h_256]
        features_cls4k = self.model4k.forward(features_cls256)                  # 5. [1 x 192], where 192 == dim of ViT_4K [ClS] token.
        return features_cls4k
相关推荐
blammmp23 分钟前
Java:数据结构-枚举
java·开发语言·数据结构
昂子的博客1 小时前
基础数据结构——队列(链表实现)
数据结构
用户691581141651 小时前
Ascend Extension for PyTorch的源码解析
人工智能
Chef_Chen1 小时前
从0开始学习机器学习--Day13--神经网络如何处理复杂非线性函数
神经网络·学习·机器学习
Troc_wangpeng1 小时前
R language 关于二维平面直角坐标系的制作
开发语言·机器学习
用户691581141651 小时前
Ascend C的编程模型
人工智能
-Nemophilist-1 小时前
机器学习与深度学习-1-线性回归从零开始实现
深度学习·机器学习·线性回归
lulu_gh_yu1 小时前
数据结构之排序补充
c语言·开发语言·数据结构·c++·学习·算法·排序算法
成富2 小时前
文本转SQL(Text-to-SQL),场景介绍与 Spring AI 实现
数据库·人工智能·sql·spring·oracle
CSDN云计算2 小时前
如何以开源加速AI企业落地,红帽带来新解法
人工智能·开源·openshift·红帽·instructlab