huggingface/pytorch-image-models
1. Training commands
Single GPU:
```bash
python train.py --pretrained --input-size 3 224 224 --mean 0 0 0 --std 1 1 1 --batch-size 128 --validation-batch-size 128 --color-jitter-prob 0.2 --grayscale-prob 0.2 --gaussian-blur-prob 0.2 --save-images
```
Multi-GPU; the leading `4` in the command below means 4 GPUs train together:
```bash
sh distributed_train.sh 4 --pretrained --input-size 3 224 224 --mean 0 0 0 --std 1 1 1 --batch-size 64 --validation-batch-size 64 --color-jitter-prob 0.5 --grayscale-prob 0.2 --gaussian-blur-prob 0.2 --save-images
```
Another multi-GPU form, which also lets you change the port the launcher listens on:
```bash
python -m torch.distributed.launch --nproc_per_node=3 --master_port=29501 train_v2.py --pretrained --input-size 3 224 224 --mean 0 0 0 --std 1 1 1 --batch-size 64 --validation-batch-size 64 --color-jitter-prob 0.5 --grayscale-prob 0.2 --gaussian-blur-prob 0.2 --save-images
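```

On recent PyTorch versions, `torchrun` is the recommended replacement for `torch.distributed.launch`; an equivalent launch (assuming the same script and arguments) would look like:

```bash
torchrun --nproc_per_node=3 --master_port=29501 train_v2.py --pretrained --input-size 3 224 224 --mean 0 0 0 --std 1 1 1 --batch-size 64 --validation-batch-size 64 --color-jitter-prob 0.5 --grayscale-prob 0.2 --gaussian-blur-prob 0.2 --save-images
```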
2. Exporting the model to ONNX
```bash
python onnx_export.py huggingface/pytorch-image-models/output/train/20240529-132242-vit_base_patch16_224-224/model_best.onnx --mean 0 0 0 --std 1 1 1 --img-size 224 --checkpoint huggingface/pytorch-image-models/output/train/20240529-132242-vit_base_patch16_224-224/model_best.pth.tar
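```

A quick sanity check of the exported file can be run with onnxruntime. This is a hypothetical sketch; the model path is illustrative, and the input shape / class count match the 224x224, 3000-class setup used above:

```python
import numpy as np
import onnxruntime as ort

# load the exported graph on CPU and push one random input through it
sess = ort.InferenceSession('model_best.onnx', providers=['CPUExecutionProvider'])
inp = sess.get_inputs()[0]
x = np.random.rand(1, 3, 224, 224).astype(np.float32)
logits = sess.run(None, {inp.name: x})[0]
print(logits.shape)  # expected (1, 3000) with --num-classes 3000
```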
3. Classification dataset
The training set is organized the same way as for yolov8_cls:
```bash
imagenet/
├──train/
│ ├── n01440764
│ │ ├── n01440764_10026.JPEG
│ │ ├── n01440764_10027.JPEG
│ │ ├── ......
│ ├── ......
├──val/
│ ├── n01440764
│ │ ├── ILSVRC2012_val_00000293.JPEG
│ │ ├── ILSVRC2012_val_00002138.JPEG
│ │ ├── ......
│ ├── ......
```
Training is then launched as before, e.g.:
```bash
sh distributed_train_v2.sh 4 --pretrained --input-size 3 224 224 --mean 0 0 0 --std 1 1 1 --batch-size 64 --validation-batch-size 64 --color-jitter-prob 0.5 --grayscale-prob 0.2 --gaussian-blur-prob 0.2 --save-images
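```

Before a long run, it can be worth confirming that the folder tree is picked up correctly. A minimal sketch using timm's dataset factory (the root path here is illustrative, not from this repo):

```python
from timm.data import create_dataset

# an empty dataset name selects a plain image-folder dataset,
# the same default path train.py takes
dataset = create_dataset('', root='/path/to/imagenet', split='train')
print(f'{len(dataset)} training samples, first target: {dataset[0][1]}')
```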
4. Modified classification train.py
```python
#!/usr/bin/env python3
""" ImageNet Training Script
This is intended to be a lean and easily modifiable ImageNet training script that reproduces ImageNet
training results with some of the latest networks and training techniques. It favours canonical PyTorch
and standard Python style over trying to be able to 'do it all.' That said, it offers quite a few speed
and training result improvements over the usual PyTorch example scripts. Repurpose as you see fit.
This script was started from an early version of the PyTorch ImageNet example
(https://github.com/pytorch/examples/tree/master/imagenet)
NVIDIA CUDA specific speedups adopted from NVIDIA Apex examples
(https://github.com/NVIDIA/apex/tree/master/examples/imagenet)
Hacked together by / Copyright 2020 Ross Wightman (https://github.com/rwightman)
"""
import argparse
import importlib
import json
import logging
import os
import time
from collections import OrderedDict
from contextlib import suppress
from datetime import datetime
from functools import partial
import torch
import torch.nn as nn
import torchvision.utils
import yaml
from torch.nn.parallel import DistributedDataParallel as NativeDDP
from timm import utils
from timm.data import create_dataset, create_loader, resolve_data_config, Mixup, FastCollateMixup, AugMixDataset
from timm.layers import convert_splitbn_model, convert_sync_batchnorm, set_fast_norm
from timm.loss import JsdCrossEntropy, SoftTargetCrossEntropy, BinaryCrossEntropy, LabelSmoothingCrossEntropy
from timm.models import create_model, safe_model_name, resume_checkpoint, load_checkpoint, model_parameters
from timm.optim import create_optimizer_v2, optimizer_kwargs
from timm.scheduler import create_scheduler_v2, scheduler_kwargs
from timm.utils import ApexScaler, NativeScaler
# Local modification: limit training to these GPUs ('os' is already imported above).
os.environ["CUDA_VISIBLE_DEVICES"] = '0,1,2,3'
try:
from apex import amp
from apex.parallel import DistributedDataParallel as ApexDDP
from apex.parallel import convert_syncbn_model
has_apex = True
except ImportError:
has_apex = False
has_native_amp = False
try:
if getattr(torch.cuda.amp, 'autocast') is not None:
has_native_amp = True
except AttributeError:
pass
try:
import wandb
has_wandb = True
except ImportError:
has_wandb = False
try:
from functorch.compile import memory_efficient_fusion
has_functorch = True
except ImportError as e:
has_functorch = False
has_compile = hasattr(torch, 'compile')
_logger = logging.getLogger('train')
# The first arg parser parses out only the --config argument, this argument is used to
# load a yaml file containing key-values that override the defaults for the main parser below
config_parser = parser = argparse.ArgumentParser(description='Training Config', add_help=False)
parser.add_argument('-c', '--config', default='', type=str, metavar='FILE',
help='YAML config file specifying default arguments')
parser = argparse.ArgumentParser(description='PyTorch ImageNet Training')
# Dataset parameters
group = parser.add_argument_group('Dataset parameters')
# Keep this argument outside the dataset group because it is positional.
parser.add_argument('data', nargs='?', metavar='DIR', const=None,
help='path to dataset (positional is *deprecated*, use --data-dir)')
parser.add_argument('--data-dir', metavar='DIR', default=r'/media/lg/C2032F933B04C4E6/00.Data/009.Uniform/81.version-2024.05.25/00.train_224_224',
help='path to dataset (root dir)')
parser.add_argument('--dataset', metavar='NAME', default='',
help='dataset type + name ("<type>/<name>") (default: ImageFolder or ImageTar if empty)')
group.add_argument('--train-split', metavar='NAME', default='train',
help='dataset train split (default: train)')
group.add_argument('--val-split', metavar='NAME', default='validation',
help='dataset validation split (default: validation)')
parser.add_argument('--train-num-samples', default=None, type=int,
metavar='N', help='Manually specify num samples in train split, for IterableDatasets.')
parser.add_argument('--val-num-samples', default=None, type=int,
metavar='N', help='Manually specify num samples in validation split, for IterableDatasets.')
group.add_argument('--dataset-download', action='store_true', default=False,
help='Allow download of dataset for torch/ and tfds/ datasets that support it.')
group.add_argument('--class-map', default='', type=str, metavar='FILENAME',
help='path to class to idx mapping file (default: "")')
group.add_argument('--input-img-mode', default=None, type=str,
help='Dataset image conversion mode for input images.')
group.add_argument('--input-key', default=None, type=str,
help='Dataset key for input images.')
group.add_argument('--target-key', default=None, type=str,
help='Dataset key for target labels.')
# Model parameters
group = parser.add_argument_group('Model parameters')
group.add_argument('--model', default='vit_base_patch16_224', type=str, metavar='MODEL',
help='Name of model to train (default: "vit_base_patch16_224")')
group.add_argument('--pretrained', action='store_true', default=False,
help='Start with pretrained version of specified network (if avail)')
group.add_argument('--pretrained-path', default='/home/test/pytorch-image-models/output/train/20240528-142446-vit_base_patch16_224-224/last.pth.tar', type=str,
help='Load this checkpoint as if they were the pretrained weights (with adaptation).')
group.add_argument('--initial-checkpoint', default='', type=str, metavar='PATH',
help='Load this checkpoint into model after initialization (default: none)')
group.add_argument('--resume', default='', type=str, metavar='PATH',
help='Resume full model and optimizer state from checkpoint (default: none)')
group.add_argument('--no-resume-opt', action='store_true', default=False,
help='prevent resume of optimizer state when resuming model')
group.add_argument('--num-classes', type=int, default=3000, metavar='N',
help='number of label classes (Model default if None)')
group.add_argument('--gp', default=None, type=str, metavar='POOL',
help='Global pool type, one of (fast, avg, max, avgmax, avgmaxc). Model default if None.')
group.add_argument('--img-size', type=int, default=None, metavar='N',
help='Image size (default: None => model default)')
group.add_argument('--in-chans', type=int, default=None, metavar='N',
help='Image input channels (default: None => 3)')
group.add_argument('--input-size', default=None, nargs=3, type=int,
metavar='N N N',
help='Input all image dimensions (d h w, e.g. --input-size 3 224 224), uses model default if empty')
group.add_argument('--crop-pct', default=1.0, type=float,
metavar='N', help='Input image center crop percent (for validation only)')
group.add_argument('--mean', type=float, nargs='+', default=None, metavar='MEAN',
help='Override mean pixel value of dataset')
group.add_argument('--std', type=float, nargs='+', default=None, metavar='STD',
help='Override std deviation of dataset')
group.add_argument('--interpolation', default='', type=str, metavar='NAME',
help='Image resize interpolation type (overrides model)')
group.add_argument('-b', '--batch-size', type=int, default=128, metavar='N',
help='Input batch size for training (default: 128)')
group.add_argument('-vb', '--validation-batch-size', type=int, default=128, metavar='N',
help='Validation batch size override (default: 128)')
group.add_argument('--channels-last', action='store_true', default=False,
help='Use channels_last memory layout')
group.add_argument('--fuser', default='', type=str,
help="Select jit fuser. One of ('', 'te', 'old', 'nvfuser')")
group.add_argument('--grad-accum-steps', type=int, default=1, metavar='N',
help='The number of steps to accumulate gradients (default: 1)')
group.add_argument('--grad-checkpointing', action='store_true', default=False,
help='Enable gradient checkpointing through model blocks/stages')
group.add_argument('--fast-norm', default=False, action='store_true',
help='enable experimental fast-norm')
group.add_argument('--model-kwargs', nargs='*', default={}, action=utils.ParseKwargs)
group.add_argument('--head-init-scale', default=None, type=float,
help='Head initialization scale')
group.add_argument('--head-init-bias', default=None, type=float,
help='Head initialization bias value')
# scripting / codegen
scripting_group = group.add_mutually_exclusive_group()
scripting_group.add_argument('--torchscript', dest='torchscript', action='store_true',
help='torch.jit.script the full model')
scripting_group.add_argument('--torchcompile', nargs='?', type=str, default=None, const='inductor',
help="Enable compilation w/ specified backend (default: inductor).")
# Device & distributed
group = parser.add_argument_group('Device parameters')
group.add_argument('--device', default='cuda', type=str,
help="Device (accelerator) to use.")
group.add_argument('--amp', action='store_true', default=False,
help='use NVIDIA Apex AMP or Native AMP for mixed precision training')
group.add_argument('--amp-dtype', default='float16', type=str,
help='lower precision AMP dtype (default: float16)')
group.add_argument('--amp-impl', default='native', type=str,
help='AMP impl to use, "native" or "apex" (default: native)')
group.add_argument('--no-ddp-bb', action='store_true', default=False,
help='Force broadcast buffers for native DDP to off.')
group.add_argument('--synchronize-step', action='store_true', default=False,
help='torch.cuda.synchronize() end of each step')
group.add_argument("--local_rank", default=0, type=int)
parser.add_argument('--device-modules', default=None, type=str, nargs='+',
help="Python imports for device backend modules.")
# Optimizer parameters
group = parser.add_argument_group('Optimizer parameters')
group.add_argument('--opt', default='sgd', type=str, metavar='OPTIMIZER',
help='Optimizer (default: "sgd")')
group.add_argument('--opt-eps', default=None, type=float, metavar='EPSILON',
help='Optimizer Epsilon (default: None, use opt default)')
group.add_argument('--opt-betas', default=None, type=float, nargs='+', metavar='BETA',
help='Optimizer Betas (default: None, use opt default)')
group.add_argument('--momentum', type=float, default=0.9, metavar='M',
help='Optimizer momentum (default: 0.9)')
group.add_argument('--weight-decay', type=float, default=2e-5,
help='weight decay (default: 2e-5)')
group.add_argument('--clip-grad', type=float, default=None, metavar='NORM',
help='Clip gradient norm (default: None, no clipping)')
group.add_argument('--clip-mode', type=str, default='norm',
help='Gradient clipping mode. One of ("norm", "value", "agc")')
group.add_argument('--layer-decay', type=float, default=None,
help='layer-wise learning rate decay (default: None)')
group.add_argument('--opt-kwargs', nargs='*', default={}, action=utils.ParseKwargs)
# Learning rate schedule parameters
group = parser.add_argument_group('Learning rate schedule parameters')
group.add_argument('--sched', type=str, default='cosine', metavar='SCHEDULER',
help='LR scheduler (default: "cosine")')
group.add_argument('--sched-on-updates', action='store_true', default=False,
help='Apply LR scheduler step on update instead of epoch end.')
group.add_argument('--lr', type=float, default=None, metavar='LR',
help='learning rate, overrides lr-base if set (default: None)')
group.add_argument('--lr-base', type=float, default=0.1, metavar='LR',
help='base learning rate: lr = lr_base * global_batch_size / base_size')
group.add_argument('--lr-base-size', type=int, default=256, metavar='DIV',
help='base learning rate batch size (divisor, default: 256).')
group.add_argument('--lr-base-scale', type=str, default='', metavar='SCALE',
help='base learning rate vs batch_size scaling ("linear", "sqrt", based on opt if empty)')
group.add_argument('--lr-noise', type=float, nargs='+', default=None, metavar='pct, pct',
help='learning rate noise on/off epoch percentages')
group.add_argument('--lr-noise-pct', type=float, default=0.67, metavar='PERCENT',
help='learning rate noise limit percent (default: 0.67)')
group.add_argument('--lr-noise-std', type=float, default=1.0, metavar='STDDEV',
help='learning rate noise std-dev (default: 1.0)')
group.add_argument('--lr-cycle-mul', type=float, default=1.0, metavar='MULT',
help='learning rate cycle len multiplier (default: 1.0)')
group.add_argument('--lr-cycle-decay', type=float, default=0.5, metavar='MULT',
help='amount to decay each learning rate cycle (default: 0.5)')
group.add_argument('--lr-cycle-limit', type=int, default=1, metavar='N',
help='learning rate cycle limit, cycles enabled if > 1')
group.add_argument('--lr-k-decay', type=float, default=1.0,
help='learning rate k-decay for cosine/poly (default: 1.0)')
group.add_argument('--warmup-lr', type=float, default=1e-5, metavar='LR',
help='warmup learning rate (default: 1e-5)')
group.add_argument('--min-lr', type=float, default=0, metavar='LR',
help='lower lr bound for cyclic schedulers that hit 0 (default: 0)')
group.add_argument('--epochs', type=int, default=300, metavar='N',
help='number of epochs to train (default: 300)')
group.add_argument('--epoch-repeats', type=float, default=0., metavar='N',
help='epoch repeat multiplier (number of times to repeat dataset epoch per train epoch).')
group.add_argument('--start-epoch', default=None, type=int, metavar='N',
help='manual epoch number (useful on restarts)')
group.add_argument('--decay-milestones', default=[90, 180, 270], type=int, nargs='+', metavar="MILESTONES",
help='list of decay epoch indices for multistep lr. must be increasing')
group.add_argument('--decay-epochs', type=float, default=90, metavar='N',
help='epoch interval to decay LR')
group.add_argument('--warmup-epochs', type=int, default=5, metavar='N',
help='epochs to warmup LR, if scheduler supports')
group.add_argument('--warmup-prefix', action='store_true', default=False,
help='Exclude warmup period from decay schedule.'),
group.add_argument('--cooldown-epochs', type=int, default=0, metavar='N',
help='epochs to cooldown LR at min_lr, after cyclic schedule ends')
group.add_argument('--patience-epochs', type=int, default=10, metavar='N',
help='patience epochs for Plateau LR scheduler (default: 10)')
group.add_argument('--decay-rate', '--dr', type=float, default=0.1, metavar='RATE',
help='LR decay rate (default: 0.1)')
# Augmentation & regularization parameters
group = parser.add_argument_group('Augmentation and regularization parameters')
group.add_argument('--no-aug', action='store_true', default=False,
help='Disable all training augmentation, override other train aug args')
group.add_argument('--train-crop-mode', type=str, default=None,
help='Crop-mode in train'),
group.add_argument('--scale', type=float, nargs='+', default=[0.5, 1.0], metavar='PCT',
help='Random resize scale (default: 0.5 1.0)')
group.add_argument('--ratio', type=float, nargs='+', default=[3. / 4., 4. / 3.], metavar='RATIO',
help='Random resize aspect ratio (default: 0.75 1.33)')
group.add_argument('--hflip', type=float, default=0.5,
help='Horizontal flip training aug probability')
group.add_argument('--vflip', type=float, default=0.5,
help='Vertical flip training aug probability')
group.add_argument('--color-jitter', type=float, default=0.4, metavar='PCT',
help='Color jitter factor (default: 0.4)')
group.add_argument('--color-jitter-prob', type=float, default=None, metavar='PCT',
help='Probability of applying any color jitter.')
group.add_argument('--grayscale-prob', type=float, default=None, metavar='PCT',
help='Probability of applying random grayscale conversion.')
group.add_argument('--gaussian-blur-prob', type=float, default=None, metavar='PCT',
help='Probability of applying gaussian blur.')
group.add_argument('--aa', type=str, default=None, metavar='NAME',
help='Use AutoAugment policy. "v0" or "original". (default: None)'),
group.add_argument('--aug-repeats', type=float, default=0,
help='Number of augmentation repetitions (distributed training only) (default: 0)')
group.add_argument('--aug-splits', type=int, default=0,
help='Number of augmentation splits (default: 0, valid: 0 or >=2)')
group.add_argument('--jsd-loss', action='store_true', default=False,
help='Enable Jensen-Shannon Divergence + CE loss. Use with `--aug-splits`.')
group.add_argument('--bce-loss', action='store_true', default=False,
help='Enable BCE loss w/ Mixup/CutMix use.')
group.add_argument('--bce-sum', action='store_true', default=False,
help='Sum over classes when using BCE loss.')
group.add_argument('--bce-target-thresh', type=float, default=None,
help='Threshold for binarizing softened BCE targets (default: None, disabled).')
group.add_argument('--bce-pos-weight', type=float, default=None,
help='Positive weighting for BCE loss.')
group.add_argument('--reprob', type=float, default=0., metavar='PCT',
help='Random erase prob (default: 0.)')
group.add_argument('--remode', type=str, default='pixel',
help='Random erase mode (default: "pixel")')
group.add_argument('--recount', type=int, default=1,
help='Random erase count (default: 1)')
group.add_argument('--resplit', action='store_true', default=False,
help='Do not random erase first (clean) augmentation split')
group.add_argument('--mixup', type=float, default=0.0,
help='mixup alpha, mixup enabled if > 0. (default: 0.)')
group.add_argument('--cutmix', type=float, default=0.0,
help='cutmix alpha, cutmix enabled if > 0. (default: 0.)')
group.add_argument('--cutmix-minmax', type=float, nargs='+', default=None,
help='cutmix min/max ratio, overrides alpha and enables cutmix if set (default: None)')
group.add_argument('--mixup-prob', type=float, default=1.0,
help='Probability of performing mixup or cutmix when either/both is enabled')
group.add_argument('--mixup-switch-prob', type=float, default=0.5,
help='Probability of switching to cutmix when both mixup and cutmix enabled')
group.add_argument('--mixup-mode', type=str, default='batch',
help='How to apply mixup/cutmix params. Per "batch", "pair", or "elem"')
group.add_argument('--mixup-off-epoch', default=0, type=int, metavar='N',
help='Turn off mixup after this epoch, disabled if 0 (default: 0)')
group.add_argument('--smoothing', type=float, default=0.1,
help='Label smoothing (default: 0.1)')
group.add_argument('--train-interpolation', type=str, default='random',
help='Training interpolation (random, bilinear, bicubic default: "random")')
group.add_argument('--drop', type=float, default=0.0, metavar='PCT',
help='Dropout rate (default: 0.)')
group.add_argument('--drop-connect', type=float, default=None, metavar='PCT',
help='Drop connect rate, DEPRECATED, use drop-path (default: None)')
group.add_argument('--drop-path', type=float, default=None, metavar='PCT',
help='Drop path rate (default: None)')
group.add_argument('--drop-block', type=float, default=None, metavar='PCT',
help='Drop block rate (default: None)')
# Batch norm parameters (only works with gen_efficientnet based models currently)
group = parser.add_argument_group('Batch norm parameters', 'Only works with gen_efficientnet based models currently.')
group.add_argument('--bn-momentum', type=float, default=None,
help='BatchNorm momentum override (if not None)')
group.add_argument('--bn-eps', type=float, default=None,
help='BatchNorm epsilon override (if not None)')
group.add_argument('--sync-bn', action='store_true',
help='Enable NVIDIA Apex or Torch synchronized BatchNorm.')
group.add_argument('--dist-bn', type=str, default='reduce',
help='Distribute BatchNorm stats between nodes after each epoch ("broadcast", "reduce", or "")')
group.add_argument('--split-bn', action='store_true',
help='Enable separate BN layers per augmentation split.')
# Model Exponential Moving Average
group = parser.add_argument_group('Model exponential moving average parameters')
group.add_argument('--model-ema', action='store_true', default=False,
help='Enable tracking moving average of model weights.')
group.add_argument('--model-ema-force-cpu', action='store_true', default=False,
help='Force ema to be tracked on CPU, rank=0 node only. Disables EMA validation.')
group.add_argument('--model-ema-decay', type=float, default=0.9998,
help='Decay factor for model weights moving average (default: 0.9998)')
group.add_argument('--model-ema-warmup', action='store_true',
help='Enable warmup for model EMA decay.')
# Misc
group = parser.add_argument_group('Miscellaneous parameters')
group.add_argument('--seed', type=int, default=42, metavar='S',
help='random seed (default: 42)')
group.add_argument('--worker-seeding', type=str, default='all',
help='worker seed mode (default: all)')
group.add_argument('--log-interval', type=int, default=50, metavar='N',
help='how many batches to wait before logging training status')
group.add_argument('--recovery-interval', type=int, default=0, metavar='N',
help='how many batches to wait before writing recovery checkpoint')
group.add_argument('--checkpoint-hist', type=int, default=10, metavar='N',
help='number of checkpoints to keep (default: 10)')
group.add_argument('-j', '--workers', type=int, default=4, metavar='N',
help='how many training processes to use (default: 4)')
group.add_argument('--save-images', action='store_true', default=False,
help='save images of input batches every log interval for debugging')
group.add_argument('--pin-mem', action='store_true', default=False,
help='Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.')
group.add_argument('--no-prefetcher', action='store_true', default=False,
help='disable fast prefetcher')
group.add_argument('--output', default='', type=str, metavar='PATH',
help='path to output folder (default: none, current dir)')
group.add_argument('--experiment', default='', type=str, metavar='NAME',
help='name of train experiment, name of sub-folder for output')
group.add_argument('--eval-metric', default='top1', type=str, metavar='EVAL_METRIC',
help='Best metric (default: "top1")')
group.add_argument('--tta', type=int, default=0, metavar='N',
help='Test/inference time augmentation (oversampling) factor. 0=None (default: 0)')
group.add_argument('--use-multi-epochs-loader', action='store_true', default=False,
help='use the multi-epochs-loader to save time at the beginning of every epoch')
group.add_argument('--log-wandb', action='store_true', default=False,
help='log training and validation metrics to wandb')
def _parse_args():
# Do we have a config file to parse?
args_config, remaining = config_parser.parse_known_args()
if args_config.config:
with open(args_config.config, 'r') as f:
cfg = yaml.safe_load(f)
parser.set_defaults(**cfg)
# The main arg parser parses the rest of the args, the usual
# defaults will have been overridden if config file specified.
args = parser.parse_args(remaining)
# Cache the args as a text string to save them in the output dir later
args_text = yaml.safe_dump(args.__dict__, default_flow_style=False)
return args, args_text
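# Example of the config-file flow above (hypothetical file passed via -c/--config):
# keys in the YAML simply become new defaults of the main parser, so explicit
# CLI arguments still take precedence. e.g. a config.yaml such as:
#   model: vit_base_patch16_224
#   batch_size: 64
#   mean: [0.0, 0.0, 0.0]
#   std: [1.0, 1.0, 1.0]
# would be applied with: python train.py -c config.yaml --batch-size 128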
def main():
utils.setup_default_logging()
args, args_text = _parse_args()
if args.device_modules:
for module in args.device_modules:
importlib.import_module(module)
if torch.cuda.is_available():
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.benchmark = True
args.prefetcher = not args.no_prefetcher
args.grad_accum_steps = max(1, args.grad_accum_steps)
device = utils.init_distributed_device(args)
if args.distributed:
_logger.info(
'Training in distributed mode with multiple processes, 1 device per process.'
f'Process {args.rank}, total {args.world_size}, device {args.device}.')
else:
_logger.info(f'Training with a single process on 1 device ({args.device}).')
assert args.rank >= 0
# resolve AMP arguments based on PyTorch / Apex availability
use_amp = None
amp_dtype = torch.float16
if args.amp:
if args.amp_impl == 'apex':
assert has_apex, 'AMP impl specified as APEX but APEX is not installed.'
use_amp = 'apex'
assert args.amp_dtype == 'float16'
else:
assert has_native_amp, 'Please update PyTorch to a version with native AMP (or use APEX).'
use_amp = 'native'
assert args.amp_dtype in ('float16', 'bfloat16')
if args.amp_dtype == 'bfloat16':
amp_dtype = torch.bfloat16
utils.random_seed(args.seed, args.rank)
if args.fuser:
utils.set_jit_fuser(args.fuser)
if args.fast_norm:
set_fast_norm()
in_chans = 3
if args.in_chans is not None:
in_chans = args.in_chans
elif args.input_size is not None:
in_chans = args.input_size[0]
factory_kwargs = {}
if args.pretrained_path:
# merge with pretrained_cfg of model, 'file' has priority over 'url' and 'hf_hub'.
factory_kwargs['pretrained_cfg_overlay'] = dict(
file=args.pretrained_path,
num_classes=-1, # force head adaptation
)
model = create_model(
args.model,
pretrained=args.pretrained,
in_chans=in_chans,
num_classes=args.num_classes,
drop_rate=args.drop,
drop_path_rate=args.drop_path,
drop_block_rate=args.drop_block,
global_pool=args.gp,
bn_momentum=args.bn_momentum,
bn_eps=args.bn_eps,
scriptable=args.torchscript,
checkpoint_path=args.initial_checkpoint,
**factory_kwargs,
**args.model_kwargs,
)
if args.head_init_scale is not None:
with torch.no_grad():
model.get_classifier().weight.mul_(args.head_init_scale)
model.get_classifier().bias.mul_(args.head_init_scale)
if args.head_init_bias is not None:
nn.init.constant_(model.get_classifier().bias, args.head_init_bias)
if args.num_classes is None:
assert hasattr(model, 'num_classes'), 'Model must have `num_classes` attr if not set on cmd line/config.'
args.num_classes = model.num_classes # FIXME handle model default vs config num_classes more elegantly
if args.grad_checkpointing:
model.set_grad_checkpointing(enable=True)
if utils.is_primary(args):
_logger.info(
f'Model {safe_model_name(args.model)} created, param count:{sum([m.numel() for m in model.parameters()])}')
data_config = resolve_data_config(vars(args), model=model, verbose=utils.is_primary(args))
# setup augmentation batch splits for contrastive loss or split bn
num_aug_splits = 0
if args.aug_splits > 0:
assert args.aug_splits > 1, 'A split of 1 makes no sense'
num_aug_splits = args.aug_splits
# enable split bn (separate bn stats per batch-portion)
if args.split_bn:
assert num_aug_splits > 1 or args.resplit
model = convert_splitbn_model(model, max(num_aug_splits, 2))
# move model to GPU, enable channels last layout if set
model.to(device=device)
if args.channels_last:
model.to(memory_format=torch.channels_last)
# setup synchronized BatchNorm for distributed training
if args.distributed and args.sync_bn:
args.dist_bn = '' # disable dist_bn when sync BN active
assert not args.split_bn
if has_apex and use_amp == 'apex':
# Apex SyncBN used with Apex AMP
# WARNING this won't currently work with models using BatchNormAct2d
model = convert_syncbn_model(model)
else:
model = convert_sync_batchnorm(model)
if utils.is_primary(args):
_logger.info(
'Converted model to use Synchronized BatchNorm. WARNING: You may have issues if using '
'zero initialized BN layers (enabled by default for ResNets) while sync-bn enabled.')
if args.torchscript:
assert not args.torchcompile
assert not use_amp == 'apex', 'Cannot use APEX AMP with torchscripted model'
assert not args.sync_bn, 'Cannot use SyncBatchNorm with torchscripted model'
model = torch.jit.script(model)
if not args.lr:
global_batch_size = args.batch_size * args.world_size * args.grad_accum_steps
batch_ratio = global_batch_size / args.lr_base_size
if not args.lr_base_scale:
on = args.opt.lower()
args.lr_base_scale = 'sqrt' if any([o in on for o in ('ada', 'lamb')]) else 'linear'
if args.lr_base_scale == 'sqrt':
batch_ratio = batch_ratio ** 0.5
args.lr = args.lr_base * batch_ratio
if utils.is_primary(args):
_logger.info(
f'Learning rate ({args.lr}) calculated from base learning rate ({args.lr_base}) '
f'and effective global batch size ({global_batch_size}) with {args.lr_base_scale} scaling.')
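# Worked example of the scaling above (numbers assume the 4-GPU command earlier):
# global_batch_size = 64 * 4 * 1 = 256; 'sgd' contains neither 'ada' nor 'lamb',
# so the scale is 'linear' and lr = 0.1 * 256 / 256 = 0.1. With 'sqrt' scaling
# it would be 0.1 * (256 / 256) ** 0.5 instead.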
optimizer = create_optimizer_v2(
model,
**optimizer_kwargs(cfg=args),
**args.opt_kwargs,
)
# setup automatic mixed-precision (AMP) loss scaling and op casting
amp_autocast = suppress # do nothing
loss_scaler = None
if use_amp == 'apex':
assert device.type == 'cuda'
model, optimizer = amp.initialize(model, optimizer, opt_level='O1')
loss_scaler = ApexScaler()
if utils.is_primary(args):
_logger.info('Using NVIDIA APEX AMP. Training in mixed precision.')
elif use_amp == 'native':
try:
amp_autocast = partial(torch.autocast, device_type=device.type, dtype=amp_dtype)
except (AttributeError, TypeError):
# fallback to CUDA only AMP for PyTorch < 1.10
assert device.type == 'cuda'
amp_autocast = torch.cuda.amp.autocast
if device.type == 'cuda' and amp_dtype == torch.float16:
# loss scaler only used for float16 (half) dtype, bfloat16 does not need it
loss_scaler = NativeScaler()
if utils.is_primary(args):
_logger.info('Using native Torch AMP. Training in mixed precision.')
else:
if utils.is_primary(args):
_logger.info('AMP not enabled. Training in float32.')
# optionally resume from a checkpoint
resume_epoch = None
if args.resume:
resume_epoch = resume_checkpoint(
model,
args.resume,
optimizer=None if args.no_resume_opt else optimizer,
loss_scaler=None if args.no_resume_opt else loss_scaler,
log_info=utils.is_primary(args),
)
# setup exponential moving average of model weights, SWA could be used here too
model_ema = None
if args.model_ema:
# Important to create EMA model after cuda(), DP wrapper, and AMP but before DDP wrapper
model_ema = utils.ModelEmaV3(
model,
decay=args.model_ema_decay,
use_warmup=args.model_ema_warmup,
device='cpu' if args.model_ema_force_cpu else None,
)
if args.resume:
load_checkpoint(model_ema.module, args.resume, use_ema=True)
if args.torchcompile:
model_ema = torch.compile(model_ema, backend=args.torchcompile)
# setup distributed training
if args.distributed:
if has_apex and use_amp == 'apex':
# Apex DDP preferred unless native amp is activated
if utils.is_primary(args):
_logger.info("Using NVIDIA APEX DistributedDataParallel.")
model = ApexDDP(model, delay_allreduce=True)
else:
if utils.is_primary(args):
_logger.info("Using native Torch DistributedDataParallel.")
model = NativeDDP(model, device_ids=[device], broadcast_buffers=not args.no_ddp_bb)
# NOTE: EMA model does not need to be wrapped by DDP
if args.torchcompile:
# torch compile should be done after DDP
assert has_compile, 'A version of torch w/ torch.compile() is required for --compile, possibly a nightly.'
model = torch.compile(model, backend=args.torchcompile)
# create the train and eval datasets
if args.data and not args.data_dir:
args.data_dir = args.data
if args.input_img_mode is None:
input_img_mode = 'RGB' if data_config['input_size'][0] == 3 else 'L'
else:
input_img_mode = args.input_img_mode
dataset_train = create_dataset(
args.dataset,
root=args.data_dir,
split=args.train_split,
is_training=True,
class_map=args.class_map,
download=args.dataset_download,
batch_size=args.batch_size,
seed=args.seed,
repeats=args.epoch_repeats,
input_img_mode=input_img_mode,
input_key=args.input_key,
target_key=args.target_key,
num_samples=args.train_num_samples,
)
if args.val_split:
dataset_eval = create_dataset(
args.dataset,
root=args.data_dir,
split=args.val_split,
is_training=False,
class_map=args.class_map,
download=args.dataset_download,
batch_size=args.batch_size,
input_img_mode=input_img_mode,
input_key=args.input_key,
target_key=args.target_key,
num_samples=args.val_num_samples,
)
# setup mixup / cutmix
collate_fn = None
mixup_fn = None
mixup_active = args.mixup > 0 or args.cutmix > 0. or args.cutmix_minmax is not None
if mixup_active:
mixup_args = dict(
mixup_alpha=args.mixup,
cutmix_alpha=args.cutmix,
cutmix_minmax=args.cutmix_minmax,
prob=args.mixup_prob,
switch_prob=args.mixup_switch_prob,
mode=args.mixup_mode,
label_smoothing=args.smoothing,
num_classes=args.num_classes
)
if args.prefetcher:
assert not num_aug_splits # collate conflict (need to support de-interleaving in collate mixup)
collate_fn = FastCollateMixup(**mixup_args)
else:
mixup_fn = Mixup(**mixup_args)
# wrap dataset in AugMix helper
if num_aug_splits > 1:
dataset_train = AugMixDataset(dataset_train, num_splits=num_aug_splits)
# create data loaders w/ augmentation pipeline
train_interpolation = args.train_interpolation
if args.no_aug or not train_interpolation:
train_interpolation = data_config['interpolation']
loader_train = create_loader(
dataset_train,
input_size=data_config['input_size'],
batch_size=args.batch_size,
is_training=True,
no_aug=args.no_aug,
re_prob=args.reprob,
re_mode=args.remode,
re_count=args.recount,
re_split=args.resplit,
train_crop_mode=args.train_crop_mode,
scale=args.scale,
ratio=args.ratio,
hflip=args.hflip,
vflip=args.vflip,
color_jitter=args.color_jitter,
color_jitter_prob=args.color_jitter_prob,
grayscale_prob=args.grayscale_prob,
gaussian_blur_prob=args.gaussian_blur_prob,
auto_augment=args.aa,
num_aug_repeats=args.aug_repeats,
num_aug_splits=num_aug_splits,
interpolation=train_interpolation,
mean=data_config['mean'],
std=data_config['std'],
num_workers=args.workers,
distributed=args.distributed,
collate_fn=collate_fn,
pin_memory=args.pin_mem,
device=device,
use_prefetcher=args.prefetcher,
use_multi_epochs_loader=args.use_multi_epochs_loader,
worker_seeding=args.worker_seeding,
)
loader_eval = None
if args.val_split:
eval_workers = args.workers
if args.distributed and ('tfds' in args.dataset or 'wds' in args.dataset):
# FIXME reduces validation padding issues when using TFDS, WDS w/ workers and distributed training
eval_workers = min(2, args.workers)
loader_eval = create_loader(
dataset_eval,
input_size=data_config['input_size'],
batch_size=args.validation_batch_size or args.batch_size,
is_training=False,
interpolation=data_config['interpolation'],
mean=data_config['mean'],
std=data_config['std'],
num_workers=eval_workers,
distributed=args.distributed,
crop_pct=data_config['crop_pct'],
pin_memory=args.pin_mem,
device=device,
use_prefetcher=args.prefetcher,
)
# setup loss function
if args.jsd_loss:
assert num_aug_splits > 1 # JSD only valid with aug splits set
train_loss_fn = JsdCrossEntropy(num_splits=num_aug_splits, smoothing=args.smoothing)
elif mixup_active:
# smoothing is handled with mixup target transform which outputs sparse, soft targets
if args.bce_loss:
train_loss_fn = BinaryCrossEntropy(
target_threshold=args.bce_target_thresh,
sum_classes=args.bce_sum,
pos_weight=args.bce_pos_weight,
)
else:
train_loss_fn = SoftTargetCrossEntropy()
elif args.smoothing:
if args.bce_loss:
train_loss_fn = BinaryCrossEntropy(
smoothing=args.smoothing,
target_threshold=args.bce_target_thresh,
sum_classes=args.bce_sum,
pos_weight=args.bce_pos_weight,
)
else:
train_loss_fn = LabelSmoothingCrossEntropy(smoothing=args.smoothing)
else:
train_loss_fn = nn.CrossEntropyLoss()
train_loss_fn = train_loss_fn.to(device=device)
validate_loss_fn = nn.CrossEntropyLoss().to(device=device)
# setup checkpoint saver and eval metric tracking
eval_metric = args.eval_metric if loader_eval is not None else 'loss'
decreasing_metric = eval_metric == 'loss'
best_metric = None
best_epoch = None
saver = None
output_dir = None
if utils.is_primary(args):
if args.experiment:
exp_name = args.experiment
else:
exp_name = '-'.join([
datetime.now().strftime("%Y%m%d-%H%M%S"),
safe_model_name(args.model),
str(data_config['input_size'][-1])
])
output_dir = utils.get_outdir(args.output if args.output else './output/train', exp_name)
saver = utils.CheckpointSaver(
model=model,
optimizer=optimizer,
args=args,
model_ema=model_ema,
amp_scaler=loss_scaler,
checkpoint_dir=output_dir,
recovery_dir=output_dir,
decreasing=decreasing_metric,
max_history=args.checkpoint_hist
)
with open(os.path.join(output_dir, 'args.yaml'), 'w') as f:
f.write(args_text)
if utils.is_primary(args) and args.log_wandb:
if has_wandb:
wandb.init(project=args.experiment, config=args)
else:
_logger.warning(
"You've requested to log metrics to wandb but package not found. "
"Metrics not being logged to wandb, try `pip install wandb`")
# setup learning rate schedule and starting epoch
updates_per_epoch = (len(loader_train) + args.grad_accum_steps - 1) // args.grad_accum_steps
lr_scheduler, num_epochs = create_scheduler_v2(
optimizer,
**scheduler_kwargs(args, decreasing_metric=decreasing_metric),
updates_per_epoch=updates_per_epoch,
)
start_epoch = 0
if args.start_epoch is not None:
# a specified start_epoch will always override the resume epoch
start_epoch = args.start_epoch
elif resume_epoch is not None:
start_epoch = resume_epoch
if lr_scheduler is not None and start_epoch > 0:
if args.sched_on_updates:
lr_scheduler.step_update(start_epoch * updates_per_epoch)
else:
lr_scheduler.step(start_epoch)
if utils.is_primary(args):
_logger.info(
f'Scheduled epochs: {num_epochs}. LR stepped per {"epoch" if lr_scheduler.t_in_epochs else "update"}.')
results = []
try:
for epoch in range(start_epoch, num_epochs):
if hasattr(dataset_train, 'set_epoch'):
dataset_train.set_epoch(epoch)
elif args.distributed and hasattr(loader_train.sampler, 'set_epoch'):
loader_train.sampler.set_epoch(epoch)
train_metrics = train_one_epoch(
epoch,
model,
loader_train,
optimizer,
train_loss_fn,
args,
lr_scheduler=lr_scheduler,
saver=saver,
output_dir=output_dir,
amp_autocast=amp_autocast,
loss_scaler=loss_scaler,
model_ema=model_ema,
mixup_fn=mixup_fn,
num_updates_total=num_epochs * updates_per_epoch,
)
if args.distributed and args.dist_bn in ('broadcast', 'reduce'):
if utils.is_primary(args):
_logger.info("Distributing BatchNorm running means and vars")
utils.distribute_bn(model, args.world_size, args.dist_bn == 'reduce')
if loader_eval is not None:
eval_metrics = validate(
model,
loader_eval,
validate_loss_fn,
args,
device=device,
amp_autocast=amp_autocast,
)
if model_ema is not None and not args.model_ema_force_cpu:
if args.distributed and args.dist_bn in ('broadcast', 'reduce'):
utils.distribute_bn(model_ema, args.world_size, args.dist_bn == 'reduce')
ema_eval_metrics = validate(
model_ema,
loader_eval,
validate_loss_fn,
args,
device=device,
amp_autocast=amp_autocast,
log_suffix=' (EMA)',
)
eval_metrics = ema_eval_metrics
else:
eval_metrics = None
if output_dir is not None:
lrs = [param_group['lr'] for param_group in optimizer.param_groups]
utils.update_summary(
epoch,
train_metrics,
eval_metrics,
filename=os.path.join(output_dir, 'summary.csv'),
lr=sum(lrs) / len(lrs),
write_header=best_metric is None,
log_wandb=args.log_wandb and has_wandb,
)
if eval_metrics is not None:
latest_metric = eval_metrics[eval_metric]
else:
latest_metric = train_metrics[eval_metric]
if saver is not None:
# save proper checkpoint with eval metric
best_metric, best_epoch = saver.save_checkpoint(epoch, metric=latest_metric)
if lr_scheduler is not None:
# step LR for next epoch
lr_scheduler.step(epoch + 1, latest_metric)
results.append({
'epoch': epoch,
'train': train_metrics,
'validation': eval_metrics,
})
except KeyboardInterrupt:
pass
results = {'all': results}
if best_metric is not None:
results['best'] = results['all'][best_epoch - start_epoch]
_logger.info('*** Best metric: {0} (epoch {1})'.format(best_metric, best_epoch))
print(f'--result\n{json.dumps(results, indent=4)}')
def train_one_epoch(
epoch,
model,
loader,
optimizer,
loss_fn,
args,
device=torch.device('cuda'),
lr_scheduler=None,
saver=None,
output_dir=None,
amp_autocast=suppress,
loss_scaler=None,
model_ema=None,
mixup_fn=None,
num_updates_total=None,
):
if args.mixup_off_epoch and epoch >= args.mixup_off_epoch:
if args.prefetcher and loader.mixup_enabled:
loader.mixup_enabled = False
elif mixup_fn is not None:
mixup_fn.mixup_enabled = False
second_order = hasattr(optimizer, 'is_second_order') and optimizer.is_second_order
has_no_sync = hasattr(model, "no_sync")
update_time_m = utils.AverageMeter()
data_time_m = utils.AverageMeter()
losses_m = utils.AverageMeter()
model.train()
accum_steps = args.grad_accum_steps
last_accum_steps = len(loader) % accum_steps
updates_per_epoch = (len(loader) + accum_steps - 1) // accum_steps
num_updates = epoch * updates_per_epoch
last_batch_idx = len(loader) - 1
last_batch_idx_to_accum = len(loader) - last_accum_steps
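# Worked example of the bookkeeping above (illustrative numbers): with
# len(loader) = 10 and accum_steps = 4, last_accum_steps = 2,
# updates_per_epoch = 3 and last_batch_idx_to_accum = 8; batches 0-3 and 4-7
# each produce one optimizer update, and batches 8-9 form the final, shorter
# accumulation window.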
data_start_time = update_start_time = time.time()
optimizer.zero_grad()
update_sample_count = 0
for batch_idx, (input, target) in enumerate(loader):
last_batch = batch_idx == last_batch_idx
need_update = last_batch or (batch_idx + 1) % accum_steps == 0
update_idx = batch_idx // accum_steps
if batch_idx >= last_batch_idx_to_accum:
accum_steps = last_accum_steps
if not args.prefetcher:
input, target = input.to(device), target.to(device)
if mixup_fn is not None:
input, target = mixup_fn(input, target)
if args.channels_last:
input = input.contiguous(memory_format=torch.channels_last)
# multiply by accum steps to get equivalent for full update
data_time_m.update(accum_steps * (time.time() - data_start_time))
def _forward():
with amp_autocast():
output = model(input)
loss = loss_fn(output, target)
if accum_steps > 1:
loss /= accum_steps
return loss
def _backward(_loss):
if loss_scaler is not None:
loss_scaler(
_loss,
optimizer,
clip_grad=args.clip_grad,
clip_mode=args.clip_mode,
parameters=model_parameters(model, exclude_head='agc' in args.clip_mode),
create_graph=second_order,
need_update=need_update,
)
else:
_loss.backward(create_graph=second_order)
if need_update:
if args.clip_grad is not None:
utils.dispatch_clip_grad(
model_parameters(model, exclude_head='agc' in args.clip_mode),
value=args.clip_grad,
mode=args.clip_mode,
)
optimizer.step()
if has_no_sync and not need_update:
with model.no_sync():
loss = _forward()
_backward(loss)
else:
loss = _forward()
_backward(loss)
if not args.distributed:
losses_m.update(loss.item() * accum_steps, input.size(0))
update_sample_count += input.size(0)
if not need_update:
data_start_time = time.time()
continue
num_updates += 1
optimizer.zero_grad()
if model_ema is not None:
model_ema.update(model, step=num_updates)
if args.synchronize_step and device.type == 'cuda':
torch.cuda.synchronize()
time_now = time.time()
update_time_m.update(time.time() - update_start_time)
update_start_time = time_now
if update_idx % args.log_interval == 0:
lrl = [param_group['lr'] for param_group in optimizer.param_groups]
lr = sum(lrl) / len(lrl)
if args.distributed:
reduced_loss = utils.reduce_tensor(loss.data, args.world_size)
losses_m.update(reduced_loss.item() * accum_steps, input.size(0))
update_sample_count *= args.world_size
if utils.is_primary(args):
_logger.info(
f'Train: {epoch} [{update_idx:>4d}/{updates_per_epoch} '
f'({100. * update_idx / (updates_per_epoch - 1):>3.0f}%)] '
f'Loss: {losses_m.val:#.3g} ({losses_m.avg:#.3g}) '
f'Time: {update_time_m.val:.3f}s, {update_sample_count / update_time_m.val:>7.2f}/s '
f'({update_time_m.avg:.3f}s, {update_sample_count / update_time_m.avg:>7.2f}/s) '
f'LR: {lr:.3e} '
f'Data: {data_time_m.val:.3f} ({data_time_m.avg:.3f})'
)
if args.save_images and output_dir:
torchvision.utils.save_image(
input,
os.path.join(output_dir, 'train-batch-%d.jpg' % batch_idx),
padding=0,
normalize=True
)
if saver is not None and args.recovery_interval and (
(update_idx + 1) % args.recovery_interval == 0):
saver.save_recovery(epoch, batch_idx=update_idx)
if lr_scheduler is not None:
lr_scheduler.step_update(num_updates=num_updates, metric=losses_m.avg)
update_sample_count = 0
data_start_time = time.time()
# end for
if hasattr(optimizer, 'sync_lookahead'):
optimizer.sync_lookahead()
return OrderedDict([('loss', losses_m.avg)])
def validate(
model,
loader,
loss_fn,
args,
device=torch.device('cuda'),
amp_autocast=suppress,
log_suffix=''
):
batch_time_m = utils.AverageMeter()
losses_m = utils.AverageMeter()
top1_m = utils.AverageMeter()
top5_m = utils.AverageMeter()
model.eval()
end = time.time()
last_idx = len(loader) - 1
with torch.no_grad():
for batch_idx, (input, target) in enumerate(loader):
last_batch = batch_idx == last_idx
if not args.prefetcher:
input = input.to(device)
target = target.to(device)
if args.channels_last:
input = input.contiguous(memory_format=torch.channels_last)
with amp_autocast():
output = model(input)
if isinstance(output, (tuple, list)):
output = output[0]
# augmentation reduction
reduce_factor = args.tta
if reduce_factor > 1:
output = output.unfold(0, reduce_factor, reduce_factor).mean(dim=2)
target = target[0:target.size(0):reduce_factor]
loss = loss_fn(output, target)
acc1, acc5 = utils.accuracy(output, target, topk=(1, 5))
if args.distributed:
reduced_loss = utils.reduce_tensor(loss.data, args.world_size)
acc1 = utils.reduce_tensor(acc1, args.world_size)
acc5 = utils.reduce_tensor(acc5, args.world_size)
else:
reduced_loss = loss.data
if device.type == 'cuda':
torch.cuda.synchronize()
losses_m.update(reduced_loss.item(), input.size(0))
top1_m.update(acc1.item(), output.size(0))
top5_m.update(acc5.item(), output.size(0))
batch_time_m.update(time.time() - end)
end = time.time()
if utils.is_primary(args) and (last_batch or batch_idx % args.log_interval == 0):
log_name = 'Test' + log_suffix
_logger.info(
f'{log_name}: [{batch_idx:>4d}/{last_idx}] '
f'Time: {batch_time_m.val:.3f} ({batch_time_m.avg:.3f}) '
f'Loss: {losses_m.val:>7.3f} ({losses_m.avg:>6.3f}) '
f'Acc@1: {top1_m.val:>7.3f} ({top1_m.avg:>7.3f}) '
f'Acc@5: {top5_m.val:>7.3f} ({top5_m.avg:>7.3f})'
)
metrics = OrderedDict([('loss', losses_m.avg), ('top1', top1_m.avg), ('top5', top5_m.avg)])
return metrics
if __name__ == '__main__':
main()
```

5. Supported networks
Optionally, change the model used for training in train.py:
```python
group.add_argument('--model', default='vit_base_patch16_224', type=str, metavar='MODEL',
help='Name of model to train (default: "vit_base_patch16_224")')
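```

The full set of available names can also be queried at runtime with the standard timm API:

```python
import timm

# list every registered model name, optionally filtered by a wildcard pattern
print(timm.list_models('vit_base_patch16_*'))
```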
The supported networks (selected by name) are as follows:
```bash
beit_base_patch16_224
beit_base_patch16_384
beit_large_patch16_224
beit_large_patch16_384
beit_large_patch16_512
beitv2_base_patch16_224
beitv2_large_patch16_224
gernet_l
gernet_m
gernet_s
repvgg_a0
repvgg_a1
repvgg_a2
repvgg_b0
repvgg_b1
repvgg_b1g4
repvgg_b2
repvgg_b2g4
repvgg_b3
repvgg_b3g4
repvgg_d2se
resnet51q
resnet61q
resnext26ts
gcresnext26ts
seresnext26ts
eca_resnext26ts
bat_resnext26ts
resnet32ts
resnet33ts
gcresnet33ts
seresnet33ts
eca_resnet33ts
gcresnet50t
gcresnext50ts
regnetz_b16
regnetz_c16
regnetz_d32
regnetz_d8
regnetz_e8
regnetz_b16_evos
regnetz_c16_evos
regnetz_d8_evos
mobileone_s0
mobileone_s1
mobileone_s2
mobileone_s3
mobileone_s4
botnet26t_256
sebotnet33ts_256
botnet50ts_256
eca_botnext26ts_256
halonet_h1
halonet26t
sehalonet33ts
halonet50ts
eca_halonext26ts
lambda_resnet26t
lambda_resnet50ts
lambda_resnet26rpt_256
haloregnetz_b
lamhalobotnet50ts_256
halo2botnet50ts_256
cait_xxs24_224
cait_xxs24_384
cait_xxs36_224
cait_xxs36_384
cait_xs24_384
cait_s24_224
cait_s24_384
cait_s36_384
cait_m36_384
cait_m48_448
coat_tiny
coat_mini
coat_small
coat_lite_tiny
coat_lite_mini
coat_lite_small
coat_lite_medium
coat_lite_medium_384
resnet10t
resnet14t
resnet18
resnet18d
resnet34
resnet34d
resnet26
resnet26t
resnet26d
resnet50
resnet50c
resnet50d
resnet50s
resnet50t
resnet101
resnet101c
resnet101d
resnet101s
resnet152
resnet152c
resnet152d
resnet152s
resnet200
resnet200d
wide_resnet50_2
wide_resnet101_2
resnet50_gn
resnext50_32x4d
resnext50d_32x4d
resnext101_32x4d
resnext101_32x8d
resnext101_32x16d
resnext101_32x32d
resnext101_64x4d
ecaresnet26t
ecaresnet50d
ecaresnet50d_pruned
ecaresnet50t
ecaresnetlight
ecaresnet101d
ecaresnet101d_pruned
ecaresnet200d
ecaresnet269d
ecaresnext26t_32x4d
ecaresnext50t_32x4d
seresnet18
seresnet34
seresnet50
seresnet50t
seresnet101
seresnet152
seresnet152d
seresnet200d
seresnet269d
seresnext26d_32x4d
seresnext26t_32x4d
seresnext50_32x4d
seresnext101_32x4d
seresnext101_32x8d
seresnext101d_32x8d
seresnext101_64x4d
senet154
resnetblur18
resnetblur50
resnetblur50d
resnetblur101d
resnetaa34d
resnetaa50
resnetaa50d
resnetaa101d
seresnetaa50d
seresnextaa101d_32x8d
seresnextaa201d_32x8d
resnetrs50
resnetrs101
resnetrs152
resnetrs200
resnetrs270
resnetrs350
resnetrs420
tv_resnet34
tv_resnet50
tv_resnet101
tv_resnet152
tv_resnext50_32x4d
ig_resnext101_32x8d
ig_resnext101_32x16d
ig_resnext101_32x32d
ig_resnext101_32x48d
ssl_resnet18
ssl_resnet50
ssl_resnext50_32x4d
ssl_resnext101_32x4d
ssl_resnext101_32x8d
ssl_resnext101_32x16d
swsl_resnet18
swsl_resnet50
swsl_resnext50_32x4d
swsl_resnext101_32x4d
swsl_resnext101_32x8d
swsl_resnext101_32x16d
gluon_resnet18_v1b
gluon_resnet34_v1b
gluon_resnet50_v1b
gluon_resnet101_v1b
gluon_resnet152_v1b
gluon_resnet50_v1c
gluon_resnet101_v1c
gluon_resnet152_v1c
gluon_resnet50_v1d
gluon_resnet101_v1d
gluon_resnet152_v1d
gluon_resnet50_v1s
gluon_resnet101_v1s
gluon_resnet152_v1s
gluon_resnext50_32x4d
gluon_resnext101_32x4d
gluon_resnext101_64x4d
gluon_seresnext50_32x4d
gluon_seresnext101_32x4d
gluon_seresnext101_64x4d
gluon_senet154
seresnext26tn_32x4d
resnetv2_50x1_bit
resnetv2_50x3_bit
resnetv2_101x1_bit
resnetv2_101x3_bit
resnetv2_152x2_bit
resnetv2_152x4_bit
resnetv2_50
resnetv2_50d
resnetv2_50t
resnetv2_101
resnetv2_101d
resnetv2_152
resnetv2_152d
resnetv2_50d_gn
resnetv2_50d_evos
resnetv2_50d_frn
resnetv2_50x1_bitm
resnetv2_50x3_bitm
resnetv2_101x1_bitm
resnetv2_101x3_bitm
resnetv2_152x2_bitm
resnetv2_152x4_bitm
resnetv2_50x1_bitm_in21k
resnetv2_50x3_bitm_in21k
resnetv2_101x1_bitm_in21k
resnetv2_101x3_bitm_in21k
resnetv2_152x2_bitm_in21k
resnetv2_152x4_bitm_in21k
resnetv2_50x1_bit_distilled
resnetv2_152x2_bit_teacher
resnetv2_152x2_bit_teacher_384
vit_tiny_patch16_224
vit_tiny_patch16_384
vit_small_patch32_224
vit_small_patch32_384
vit_small_patch16_224
vit_small_patch16_384
vit_small_patch8_224
vit_base_patch32_224
vit_base_patch32_384
vit_base_patch16_224
vit_base_patch16_384
vit_base_patch8_224
vit_large_patch32_224
vit_large_patch32_384
vit_large_patch16_224
vit_large_patch16_384
vit_large_patch14_224
vit_huge_patch14_224
vit_giant_patch14_224
vit_gigantic_patch14_224
vit_base_patch16_224_miil
vit_medium_patch16_gap_240
vit_medium_patch16_gap_256
vit_medium_patch16_gap_384
vit_betwixt_patch16_gap_256
vit_base_patch16_gap_224
vit_huge_patch14_gap_224
vit_huge_patch16_gap_448
vit_giant_patch16_gap_224
vit_xsmall_patch16_clip_224
vit_medium_patch32_clip_224
vit_medium_patch16_clip_224
vit_betwixt_patch32_clip_224
vit_base_patch32_clip_224
vit_base_patch32_clip_256
vit_base_patch32_clip_384
vit_base_patch32_clip_448
vit_base_patch16_clip_224
vit_base_patch16_clip_384
vit_large_patch14_clip_224
vit_large_patch14_clip_336
vit_huge_patch14_clip_224
vit_huge_patch14_clip_336
vit_huge_patch14_clip_378
vit_giant_patch14_clip_224
vit_gigantic_patch14_clip_224
vit_base_patch32_clip_quickgelu_224
vit_base_patch16_clip_quickgelu_224
vit_large_patch14_clip_quickgelu_224
vit_large_patch14_clip_quickgelu_336
vit_huge_patch14_clip_quickgelu_224
vit_huge_patch14_clip_quickgelu_378
vit_base_patch32_plus_256
vit_base_patch16_plus_240
vit_base_patch16_rpn_224
vit_small_patch16_36x1_224
vit_small_patch16_18x2_224
vit_base_patch16_18x2_224
eva_large_patch14_196
eva_large_patch14_336
flexivit_small
flexivit_base
flexivit_large
vit_base_patch16_xp_224
vit_large_patch14_xp_224
vit_huge_patch14_xp_224
vit_small_patch14_dinov2
vit_base_patch14_dinov2
vit_large_patch14_dinov2
vit_giant_patch14_dinov2
vit_small_patch14_reg4_dinov2
vit_base_patch14_reg4_dinov2
vit_large_patch14_reg4_dinov2
vit_giant_patch14_reg4_dinov2
vit_base_patch16_siglip_224
vit_base_patch16_siglip_256
vit_base_patch16_siglip_384
vit_base_patch16_siglip_512
vit_large_patch16_siglip_256
vit_large_patch16_siglip_384
vit_so400m_patch14_siglip_224
vit_so400m_patch14_siglip_384
vit_base_patch16_siglip_gap_224
vit_base_patch16_siglip_gap_256
vit_base_patch16_siglip_gap_384
vit_base_patch16_siglip_gap_512
vit_large_patch16_siglip_gap_256
vit_large_patch16_siglip_gap_384
vit_so400m_patch14_siglip_gap_224
vit_so400m_patch14_siglip_gap_384
vit_so400m_patch14_siglip_gap_448
vit_so400m_patch14_siglip_gap_896
vit_wee_patch16_reg1_gap_256
vit_pwee_patch16_reg1_gap_256
vit_little_patch16_reg1_gap_256
vit_little_patch16_reg4_gap_256
vit_medium_patch16_reg1_gap_256
vit_medium_patch16_reg4_gap_256
vit_mediumd_patch16_reg4_gap_256
vit_betwixt_patch16_reg1_gap_256
vit_betwixt_patch16_reg4_gap_256
vit_base_patch16_reg4_gap_256
vit_so150m_patch16_reg4_map_256
vit_so150m_patch16_reg4_gap_256
vit_tiny_patch16_224_in21k
vit_small_patch32_224_in21k
vit_small_patch16_224_in21k
vit_base_patch32_224_in21k
vit_base_patch16_224_in21k
vit_base_patch8_224_in21k
vit_large_patch32_224_in21k
vit_large_patch16_224_in21k
vit_huge_patch14_224_in21k
vit_base_patch32_224_sam
vit_base_patch16_224_sam
vit_small_patch16_224_dino
vit_small_patch8_224_dino
vit_base_patch16_224_dino
vit_base_patch8_224_dino
vit_base_patch16_224_miil_in21k
vit_base_patch32_224_clip_laion2b
vit_large_patch14_224_clip_laion2b
vit_huge_patch14_224_clip_laion2b
vit_giant_patch14_224_clip_laion2b
vit_tiny_r_s16_p8_224
vit_tiny_r_s16_p8_384
vit_small_r26_s32_224
vit_small_r26_s32_384
vit_base_r26_s32_224
vit_base_r50_s16_224
vit_base_r50_s16_384
vit_large_r50_s32_224
vit_large_r50_s32_384
vit_small_resnet26d_224
vit_small_resnet50d_s16_224
vit_base_resnet26d_224
vit_base_resnet50d_224
vit_tiny_r_s16_p8_224_in21k
vit_small_r26_s32_224_in21k
vit_base_r50_s16_224_in21k
vit_base_resnet50_224_in21k
vit_large_r50_s32_224_in21k
vit_base_resnet50_384
convit_tiny
convit_small
convit_base
convmixer_1536_20
convmixer_768_32
convmixer_1024_20_ks9_p14
convnext_atto
convnext_atto_ols
convnext_femto
convnext_femto_ols
convnext_pico
convnext_pico_ols
convnext_nano
convnext_nano_ols
convnext_tiny_hnf
convnext_tiny
convnext_small
convnext_base
convnext_large
convnext_large_mlp
convnext_xlarge
convnext_xxlarge
convnextv2_atto
convnextv2_femto
convnextv2_pico
convnextv2_nano
convnextv2_tiny
convnextv2_small
convnextv2_base
convnextv2_large
convnextv2_huge
convnext_tiny_in22ft1k
convnext_small_in22ft1k
convnext_base_in22ft1k
convnext_large_in22ft1k
convnext_xlarge_in22ft1k
convnext_tiny_384_in22ft1k
convnext_small_384_in22ft1k
convnext_base_384_in22ft1k
convnext_large_384_in22ft1k
convnext_xlarge_384_in22ft1k
convnext_tiny_in22k
convnext_small_in22k
convnext_base_in22k
convnext_large_in22k
convnext_xlarge_in22k
crossvit_tiny_240
crossvit_small_240
crossvit_base_240
crossvit_9_240
crossvit_15_240
crossvit_18_240
crossvit_9_dagger_240
crossvit_15_dagger_240
crossvit_15_dagger_408
crossvit_18_dagger_240
crossvit_18_dagger_408
cspresnet50
cspresnet50d
cspresnet50w
cspresnext50
cspdarknet53
darknet17
darknet21
sedarknet21
darknet53
darknetaa53
cs3darknet_s
cs3darknet_m
cs3darknet_l
cs3darknet_x
cs3darknet_focus_s
cs3darknet_focus_m
cs3darknet_focus_l
cs3darknet_focus_x
cs3sedarknet_l
cs3sedarknet_x
cs3sedarknet_xdw
cs3edgenet_x
cs3se_edgenet_x
davit_tiny
davit_small
davit_base
davit_large
davit_huge
davit_giant
deit_tiny_patch16_224
deit_small_patch16_224
deit_base_patch16_224
deit_base_patch16_384
deit_tiny_distilled_patch16_224
deit_small_distilled_patch16_224
deit_base_distilled_patch16_224
deit_base_distilled_patch16_384
deit3_small_patch16_224
deit3_small_patch16_384
deit3_medium_patch16_224
deit3_base_patch16_224
deit3_base_patch16_384
deit3_large_patch16_224
deit3_large_patch16_384
deit3_huge_patch14_224
deit3_small_patch16_224_in21ft1k
deit3_small_patch16_384_in21ft1k
deit3_medium_patch16_224_in21ft1k
deit3_base_patch16_224_in21ft1k
deit3_base_patch16_384_in21ft1k
deit3_large_patch16_224_in21ft1k
deit3_large_patch16_384_in21ft1k
deit3_huge_patch14_224_in21ft1k
densenet121
densenetblur121d
densenet169
densenet201
densenet161
densenet264d
tv_densenet121
dla60_res2net
dla60_res2next
dla34
dla46_c
dla46x_c
dla60x_c
dla60
dla60x
dla102
dla102x
dla102x2
dla169
dpn48b
dpn68
dpn68b
dpn92
dpn98
dpn131
dpn107
edgenext_xx_small
edgenext_x_small
edgenext_small
edgenext_base
edgenext_small_rw
efficientformer_l1
efficientformer_l3
efficientformer_l7
efficientformerv2_s0
efficientformerv2_s1
efficientformerv2_s2
efficientformerv2_l
mnasnet_050
mnasnet_075
mnasnet_100
mnasnet_140
semnasnet_050
semnasnet_075
semnasnet_100
semnasnet_140
mnasnet_small
mobilenetv2_035
mobilenetv2_050
mobilenetv2_075
mobilenetv2_100
mobilenetv2_140
mobilenetv2_110d
mobilenetv2_120d
fbnetc_100
spnasnet_100
efficientnet_b0
efficientnet_b1
efficientnet_b2
efficientnet_b3
efficientnet_b4
efficientnet_b5
efficientnet_b6
efficientnet_b7
efficientnet_b8
efficientnet_l2
efficientnet_b0_gn
efficientnet_b0_g8_gn
efficientnet_b0_g16_evos
efficientnet_b3_gn
efficientnet_b3_g8_gn
efficientnet_es
efficientnet_es_pruned
efficientnet_em
efficientnet_el
efficientnet_el_pruned
efficientnet_cc_b0_4e
efficientnet_cc_b0_8e
efficientnet_cc_b1_8e
efficientnet_lite0
efficientnet_lite1
efficientnet_lite2
efficientnet_lite3
efficientnet_lite4
efficientnet_b1_pruned
efficientnet_b2_pruned
efficientnet_b3_pruned
efficientnetv2_rw_t
gc_efficientnetv2_rw_t
efficientnetv2_rw_s
efficientnetv2_rw_m
efficientnetv2_s
efficientnetv2_m
efficientnetv2_l
efficientnetv2_xl
tf_efficientnet_b0
tf_efficientnet_b1
tf_efficientnet_b2
tf_efficientnet_b3
tf_efficientnet_b4
tf_efficientnet_b5
tf_efficientnet_b6
tf_efficientnet_b7
tf_efficientnet_b8
tf_efficientnet_l2
tf_efficientnet_es
tf_efficientnet_em
tf_efficientnet_el
tf_efficientnet_cc_b0_4e
tf_efficientnet_cc_b0_8e
tf_efficientnet_cc_b1_8e
tf_efficientnet_lite0
tf_efficientnet_lite1
tf_efficientnet_lite2
tf_efficientnet_lite3
tf_efficientnet_lite4
tf_efficientnetv2_s
tf_efficientnetv2_m
tf_efficientnetv2_l
tf_efficientnetv2_xl
tf_efficientnetv2_b0
tf_efficientnetv2_b1
tf_efficientnetv2_b2
tf_efficientnetv2_b3
mixnet_s
mixnet_m
mixnet_l
mixnet_xl
mixnet_xxl
tf_mixnet_s
tf_mixnet_m
tf_mixnet_l
tinynet_a
tinynet_b
tinynet_c
tinynet_d
tinynet_e
tf_efficientnet_b0_ap
tf_efficientnet_b1_ap
tf_efficientnet_b2_ap
tf_efficientnet_b3_ap
tf_efficientnet_b4_ap
tf_efficientnet_b5_ap
tf_efficientnet_b6_ap
tf_efficientnet_b7_ap
tf_efficientnet_b8_ap
tf_efficientnet_b0_ns
tf_efficientnet_b1_ns
tf_efficientnet_b2_ns
tf_efficientnet_b3_ns
tf_efficientnet_b4_ns
tf_efficientnet_b5_ns
tf_efficientnet_b6_ns
tf_efficientnet_b7_ns
tf_efficientnet_l2_ns_475
tf_efficientnet_l2_ns
tf_efficientnetv2_s_in21ft1k
tf_efficientnetv2_m_in21ft1k
tf_efficientnetv2_l_in21ft1k
tf_efficientnetv2_xl_in21ft1k
tf_efficientnetv2_s_in21k
tf_efficientnetv2_m_in21k
tf_efficientnetv2_l_in21k
tf_efficientnetv2_xl_in21k
efficientnet_b2a
efficientnet_b3a
mnasnet_a1
mnasnet_b1
efficientvit_b0
efficientvit_b1
efficientvit_b2
efficientvit_b3
efficientvit_l1
efficientvit_l2
efficientvit_l3
efficientvit_m0
efficientvit_m1
efficientvit_m2
efficientvit_m3
efficientvit_m4
efficientvit_m5
eva_giant_patch14_224
eva_giant_patch14_336
eva_giant_patch14_560
eva02_tiny_patch14_224
eva02_small_patch14_224
eva02_base_patch14_224
eva02_large_patch14_224
eva02_tiny_patch14_336
eva02_small_patch14_336
eva02_base_patch14_448
eva02_large_patch14_448
eva_giant_patch14_clip_224
eva02_base_patch16_clip_224
eva02_large_patch14_clip_224
eva02_large_patch14_clip_336
eva02_enormous_patch14_clip_224
vit_medium_patch16_rope_reg1_gap_256
vit_mediumd_patch16_rope_reg1_gap_256
vit_betwixt_patch16_rope_reg4_gap_256
vit_base_patch16_rope_reg1_gap_256
fastvit_t8
fastvit_t12
fastvit_s12
fastvit_sa12
fastvit_sa24
fastvit_sa36
fastvit_ma36
focalnet_tiny_srf
focalnet_small_srf
focalnet_base_srf
focalnet_tiny_lrf
focalnet_small_lrf
focalnet_base_lrf
focalnet_large_fl3
focalnet_large_fl4
focalnet_xlarge_fl3
focalnet_xlarge_fl4
focalnet_huge_fl3
focalnet_huge_fl4
gcvit_xxtiny
gcvit_xtiny
gcvit_tiny
gcvit_small
gcvit_base
ghostnet_050
ghostnet_100
ghostnet_130
ghostnetv2_100
ghostnetv2_130
ghostnetv2_160
mobilenetv3_large_075
mobilenetv3_large_100
mobilenetv3_small_050
mobilenetv3_small_075
mobilenetv3_small_100
mobilenetv3_rw
tf_mobilenetv3_large_075
tf_mobilenetv3_large_100
tf_mobilenetv3_large_minimal_100
tf_mobilenetv3_small_075
tf_mobilenetv3_small_100
tf_mobilenetv3_small_minimal_100
fbnetv3_b
fbnetv3_d
fbnetv3_g
lcnet_035
lcnet_050
lcnet_075
lcnet_100
lcnet_150
mobilenetv3_large_100_miil
mobilenetv3_large_100_miil_in21k
hardcorenas_a
hardcorenas_b
hardcorenas_c
hardcorenas_d
hardcorenas_e
hardcorenas_f
hgnet_tiny
hgnet_small
hgnet_base
hgnetv2_b0
hgnetv2_b1
hgnetv2_b2
hgnetv2_b3
hgnetv2_b4
hgnetv2_b5
hgnetv2_b6
hiera_tiny_224
hiera_small_224
hiera_base_224
hiera_base_plus_224
hiera_large_224
hiera_huge_224
hrnet_w18_small
hrnet_w18_small_v2
hrnet_w18
hrnet_w30
hrnet_w32
hrnet_w40
hrnet_w44
hrnet_w48
hrnet_w64
hrnet_w18_ssld
hrnet_w48_ssld
inception_next_tiny
inception_next_small
inception_next_base
inception_resnet_v2
ens_adv_inception_resnet_v2
inception_v3
tf_inception_v3
adv_inception_v3
gluon_inception_v3
inception_v4
levit_128s
levit_128
levit_192
levit_256
levit_384
levit_384_s8
levit_512_s8
levit_512
levit_256d
levit_512d
levit_conv_128s
levit_conv_128
levit_conv_192
levit_conv_256
levit_conv_384
levit_conv_384_s8
levit_conv_512_s8
levit_conv_512
levit_conv_256d
levit_conv_512d
coatnet_pico_rw_224
coatnet_nano_rw_224
coatnet_0_rw_224
coatnet_1_rw_224
coatnet_2_rw_224
coatnet_3_rw_224
coatnet_bn_0_rw_224
coatnet_rmlp_nano_rw_224
coatnet_rmlp_0_rw_224
coatnet_rmlp_1_rw_224
coatnet_rmlp_1_rw2_224
coatnet_rmlp_2_rw_224
coatnet_rmlp_2_rw_384
coatnet_rmlp_3_rw_224
coatnet_nano_cc_224
coatnext_nano_rw_224
coatnet_0_224
coatnet_1_224
coatnet_2_224
coatnet_3_224
coatnet_4_224
coatnet_5_224
maxvit_pico_rw_256
maxvit_nano_rw_256
maxvit_tiny_rw_224
maxvit_tiny_rw_256
maxvit_rmlp_pico_rw_256
maxvit_rmlp_nano_rw_256
maxvit_rmlp_tiny_rw_256
maxvit_rmlp_small_rw_224
maxvit_rmlp_small_rw_256
maxvit_rmlp_base_rw_224
maxvit_rmlp_base_rw_384
maxvit_tiny_pm_256
maxxvit_rmlp_nano_rw_256
maxxvit_rmlp_tiny_rw_256
maxxvit_rmlp_small_rw_256
maxxvitv2_nano_rw_256
maxxvitv2_rmlp_base_rw_224
maxxvitv2_rmlp_base_rw_384
maxxvitv2_rmlp_large_rw_224
maxvit_tiny_tf_224
maxvit_tiny_tf_384
maxvit_tiny_tf_512
maxvit_small_tf_224
maxvit_small_tf_384
maxvit_small_tf_512
maxvit_base_tf_224
maxvit_base_tf_384
maxvit_base_tf_512
maxvit_large_tf_224
maxvit_large_tf_384
maxvit_large_tf_512
maxvit_xlarge_tf_224
maxvit_xlarge_tf_384
maxvit_xlarge_tf_512
poolformer_s12
poolformer_s24
poolformer_s36
poolformer_m36
poolformer_m48
poolformerv2_s12
poolformerv2_s24
poolformerv2_s36
poolformerv2_m36
poolformerv2_m48
convformer_s18
convformer_s36
convformer_m36
convformer_b36
caformer_s18
caformer_s36
caformer_m36
caformer_b36
mixer_s32_224
mixer_s16_224
mixer_b32_224
mixer_b16_224
mixer_l32_224
mixer_l16_224
gmixer_12_224
gmixer_24_224
resmlp_12_224
resmlp_24_224
resmlp_36_224
resmlp_big_24_224
gmlp_ti16_224
gmlp_s16_224
gmlp_b16_224
mixer_b16_224_in21k
mixer_l16_224_in21k
mixer_b16_224_miil
mixer_b16_224_miil_in21k
resmlp_12_distilled_224
resmlp_24_distilled_224
resmlp_36_distilled_224
resmlp_big_24_distilled_224
resmlp_big_24_224_in22ft1k
resmlp_12_224_dino
resmlp_24_224_dino
mobilevit_xxs
mobilevit_xs
mobilevit_s
mobilevitv2_050
mobilevitv2_075
mobilevitv2_100
mobilevitv2_125
mobilevitv2_150
mobilevitv2_175
mobilevitv2_200
mobilevitv2_150_in22ft1k
mobilevitv2_175_in22ft1k
mobilevitv2_200_in22ft1k
mobilevitv2_150_384_in22ft1k
mobilevitv2_175_384_in22ft1k
mobilevitv2_200_384_in22ft1k
mvitv2_tiny
mvitv2_small
mvitv2_base
mvitv2_large
mvitv2_small_cls
mvitv2_base_cls
mvitv2_large_cls
mvitv2_huge_cls
nasnetalarge
nest_base
nest_small
nest_tiny
nest_base_jx
nest_small_jx
nest_tiny_jx
jx_nest_base
jx_nest_small
jx_nest_tiny
nextvit_small
nextvit_base
nextvit_large
dm_nfnet_f0
dm_nfnet_f1
dm_nfnet_f2
dm_nfnet_f3
dm_nfnet_f4
dm_nfnet_f5
dm_nfnet_f6
nfnet_f0
nfnet_f1
nfnet_f2
nfnet_f3
nfnet_f4
nfnet_f5
nfnet_f6
nfnet_f7
nfnet_l0
eca_nfnet_l0
eca_nfnet_l1
eca_nfnet_l2
eca_nfnet_l3
nf_regnet_b0
nf_regnet_b1
nf_regnet_b2
nf_regnet_b3
nf_regnet_b4
nf_regnet_b5
nf_resnet26
nf_resnet50
nf_resnet101
nf_seresnet26
nf_seresnet50
nf_seresnet101
nf_ecaresnet26
nf_ecaresnet50
nf_ecaresnet101
pit_b_224
pit_s_224
pit_xs_224
pit_ti_224
pit_b_distilled_224
pit_s_distilled_224
pit_xs_distilled_224
pit_ti_distilled_224
pnasnet5large
pvt_v2_b0
pvt_v2_b1
pvt_v2_b2
pvt_v2_b3
pvt_v2_b4
pvt_v2_b5
pvt_v2_b2_li
regnetx_002
regnetx_004
regnetx_004_tv
regnetx_006
regnetx_008
regnetx_016
regnetx_032
regnetx_040
regnetx_064
regnetx_080
regnetx_120
regnetx_160
regnetx_320
regnety_002
regnety_004
regnety_006
regnety_008
regnety_008_tv
regnety_016
regnety_032
regnety_040
regnety_064
regnety_080
regnety_080_tv
regnety_120
regnety_160
regnety_320
regnety_640
regnety_1280
regnety_2560
regnety_040_sgn
regnetv_040
regnetv_064
regnetz_005
regnetz_040
regnetz_040_h
regnetz_040h
repghostnet_050
repghostnet_058
repghostnet_080
repghostnet_100
repghostnet_111
repghostnet_130
repghostnet_150
repghostnet_200
repvit_m1
repvit_m2
repvit_m3
repvit_m0_9
repvit_m1_0
repvit_m1_1
repvit_m1_5
repvit_m2_3
res2net50_26w_4s
res2net101_26w_4s
res2net50_26w_6s
res2net50_26w_8s
res2net50_48w_2s
res2net50_14w_8s
res2next50
res2net50d
res2net101d
resnest14d
resnest26d
resnest50d
resnest101e
resnest200e
resnest269e
resnest50d_4s2x40d
resnest50d_1s4x24d
rexnet_100
rexnet_130
rexnet_150
rexnet_200
rexnet_300
rexnetr_100
rexnetr_130
rexnetr_150
rexnetr_200
rexnetr_300
selecsls42
selecsls42b
selecsls60
selecsls60b
selecsls84
legacy_seresnet18
legacy_seresnet34
legacy_seresnet50
legacy_seresnet101
legacy_seresnet152
legacy_senet154
legacy_seresnext26_32x4d
legacy_seresnext50_32x4d
legacy_seresnext101_32x4d
sequencer2d_s
sequencer2d_m
sequencer2d_l
skresnet18
skresnet34
skresnet50
skresnet50d
skresnext50_32x4d
swin_tiny_patch4_window7_224
swin_small_patch4_window7_224
swin_base_patch4_window7_224
swin_base_patch4_window12_384
swin_large_patch4_window7_224
swin_large_patch4_window12_384
swin_s3_tiny_224
swin_s3_small_224
swin_s3_base_224
swin_base_patch4_window7_224_in22k
swin_base_patch4_window12_384_in22k
swin_large_patch4_window7_224_in22k
swin_large_patch4_window12_384_in22k
swinv2_tiny_window16_256
swinv2_tiny_window8_256
swinv2_small_window16_256
swinv2_small_window8_256
swinv2_base_window16_256
swinv2_base_window8_256
swinv2_base_window12_192
swinv2_base_window12to16_192to256
swinv2_base_window12to24_192to384
swinv2_large_window12_192
swinv2_large_window12to16_192to256
swinv2_large_window12to24_192to384
swinv2_base_window12_192_22k
swinv2_base_window12to16_192to256_22kft1k
swinv2_base_window12to24_192to384_22kft1k
swinv2_large_window12_192_22k
swinv2_large_window12to16_192to256_22kft1k
swinv2_large_window12to24_192to384_22kft1k
swinv2_cr_tiny_384
swinv2_cr_tiny_224
swinv2_cr_tiny_ns_224
swinv2_cr_small_384
swinv2_cr_small_224
swinv2_cr_small_ns_224
swinv2_cr_small_ns_256
swinv2_cr_base_384
swinv2_cr_base_224
swinv2_cr_base_ns_224
swinv2_cr_large_384
swinv2_cr_large_224
swinv2_cr_huge_384
swinv2_cr_huge_224
swinv2_cr_giant_384
swinv2_cr_giant_224
tiny_vit_5m_224
tiny_vit_11m_224
tiny_vit_21m_224
tiny_vit_21m_384
tiny_vit_21m_512
tnt_s_patch16_224
tnt_b_patch16_224
tresnet_m
tresnet_l
tresnet_xl
tresnet_v2_l
tresnet_m_miil_in21k
tresnet_m_448
tresnet_l_448
tresnet_xl_448
twins_pcpvt_small
twins_pcpvt_base
twins_pcpvt_large
twins_svt_small
twins_svt_base
twins_svt_large
vgg11
vgg11_bn
vgg13
vgg13_bn
vgg16
vgg16_bn
vgg19
vgg19_bn
visformer_tiny
visformer_small
vit_relpos_base_patch32_plus_rpn_256
vit_relpos_base_patch16_plus_240
vit_relpos_small_patch16_224
vit_relpos_medium_patch16_224
vit_relpos_base_patch16_224
vit_srelpos_small_patch16_224
vit_srelpos_medium_patch16_224
vit_relpos_medium_patch16_cls_224
vit_relpos_base_patch16_cls_224
vit_relpos_base_patch16_clsgap_224
vit_relpos_small_patch16_rpn_224
vit_relpos_medium_patch16_rpn_224
vit_relpos_base_patch16_rpn_224
samvit_base_patch16
samvit_large_patch16
samvit_huge_patch16
samvit_base_patch16_224
volo_d1_224
volo_d1_384
volo_d2_224
volo_d2_384
volo_d3_224
volo_d3_448
volo_d4_224
volo_d4_448
volo_d5_224
volo_d5_448
volo_d5_512
vovnet39a
vovnet57a
ese_vovnet19b_slim_dw
ese_vovnet19b_dw
ese_vovnet19b_slim
ese_vovnet39b
ese_vovnet57b
ese_vovnet99b
eca_vovnet39b
ese_vovnet39b_evos
legacy_xception
xception
xception41
xception65
xception71
xception41p
xception65p
xcit_nano_12_p16_224
xcit_nano_12_p16_384
xcit_tiny_12_p16_224
xcit_tiny_12_p16_384
xcit_small_12_p16_224
xcit_small_12_p16_384
xcit_tiny_24_p16_224
xcit_tiny_24_p16_384
xcit_small_24_p16_224
xcit_small_24_p16_384
xcit_medium_24_p16_224
xcit_medium_24_p16_384
xcit_large_24_p16_224
xcit_large_24_p16_384
xcit_nano_12_p8_224
xcit_nano_12_p8_384
xcit_tiny_12_p8_224
xcit_tiny_12_p8_384
xcit_small_12_p8_224
xcit_small_12_p8_384
xcit_tiny_24_p8_224
xcit_tiny_24_p8_384
xcit_small_24_p8_224
xcit_small_24_p8_384
xcit_medium_24_p8_224
xcit_medium_24_p8_384
xcit_large_24_p8_224
xcit_large_24_p8_384
xcit_nano_12_p16_224_dist
xcit_nano_12_p16_384_dist
xcit_tiny_12_p16_224_dist
xcit_tiny_12_p16_384_dist
xcit_tiny_24_p16_224_dist
xcit_tiny_24_p16_384_dist
xcit_small_12_p16_224_dist
xcit_small_12_p16_384_dist
xcit_small_24_p16_224_dist
xcit_small_24_p16_384_dist
xcit_medium_24_p16_224_dist
xcit_medium_24_p16_384_dist
xcit_large_24_p16_224_dist
xcit_large_24_p16_384_dist
xcit_nano_12_p8_224_dist
xcit_nano_12_p8_384_dist
xcit_tiny_12_p8_224_dist
xcit_tiny_12_p8_384_dist
xcit_tiny_24_p8_224_dist
xcit_tiny_24_p8_384_dist
xcit_small_12_p8_224_dist
xcit_small_12_p8_384_dist
xcit_small_24_p8_224_dist
xcit_small_24_p8_384_dist
xcit_medium_24_p8_224_dist
xcit_medium_24_p8_384_dist
xcit_large_24_p8_224_dist
xcit_large_24_p8_384_dist
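The names above are the architectures registered in timm. As a minimal sketch (assuming a reasonably recent `timm` install), such a list can be reproduced and any entry instantiated with `timm.list_models` / `timm.create_model`; the wildcard pattern and `num_classes` value below are illustrative only:
python
import timm

# Enumerate every registered architecture name (a list like the one
# above); pass a wildcard pattern to filter, e.g. all ConvNeXt variants.
all_models = timm.list_models()
convnexts = timm.list_models('convnext*')
print(len(all_models), convnexts[:5])

# Restrict to names that have pretrained weights available.
pretrained_models = timm.list_models(pretrained=True)

# Instantiate one of the listed names; num_classes replaces the
# classifier head, e.g. for a custom imagenet-style dataset.
model = timm.create_model('vit_base_patch16_224', pretrained=True, num_classes=10)
Any of these names can also be passed to train.py via --model (and to onnx_export.py) to select the architecture for training or export.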