植物分类-PlantsClassification

一、模型配置

一、backbone

resnet50

二、neck

GlobalAveragePooling

三、head

fc

四、loss

type='LabelSmoothLoss',

label_smooth_val=0.1,

num_classes=30,

reduction='mean',

loss_weight=1.0

五、optimizer

lr=0.1, momentum=0.9, type='SGD', weight_decay=0.0001

六、scheduler

T_max=260, begin=20, by_epoch=True, end=300, type='CosineAnnealingLR

七、evaluator

topk=(1, 5 ), type='Accuracy'

八、max_epochs

300

九、Config

python 复制代码
auto_scale_lr = dict(base_batch_size=256)
data_preprocessor = dict(
    mean=[
        123.675,
        116.28,
        103.53,
    ],
    num_classes=30,
    std=[
        58.395,
        57.12,
        57.375,
    ],
    to_rgb=True)
dataset_type = 'ImageNet'
data_root = 'data/PlantsClassification'
default_hooks = dict(
    checkpoint=dict(interval=1, type='CheckpointHook', max_keep_ckpts=2, save_best="auto"),
    logger=dict(interval=100, type='LoggerHook'),
    param_scheduler=dict(type='ParamSchedulerHook'),
    sampler_seed=dict(type='DistSamplerSeedHook'),
    timer=dict(type='IterTimerHook'),
    visualization=dict(enable=False, type='VisualizationHook'))
default_scope = 'mmpretrain'
env_cfg = dict(
    cudnn_benchmark=False,
    dist_cfg=dict(backend='nccl'),
    mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0))
launcher = 'none'
load_from = './work_dirs/resnet50_8xb32-coslr_in1k/resnet50_8xb32_in1k_20210831-ea4938fc.pth'
log_level = 'INFO'
model = dict(
    backbone=dict(
        depth=50,
        num_stages=4,
        out_indices=(3,),
        style='pytorch',
        type='ResNet'),
    head=dict(
        in_channels=2048,
        # loss=dict(loss_weight=1.0, type='CrossEntropyLoss'),
        loss=dict(
                    type='LabelSmoothLoss',
                    label_smooth_val=0.1,
                    num_classes=30,
                    reduction='mean',
                    loss_weight=1.0),
        num_classes=30,
        topk=(
            1,
            5,
        ),
        type='LinearClsHead'),
    data_preprocessor=data_preprocessor,
    neck=dict(type='GlobalAveragePooling'),
    type='ImageClassifier')
train_cfg = dict(by_epoch=True, max_epochs=300, val_interval=1)
optim_wrapper = dict(
    optimizer=dict(lr=0.1, momentum=0.9, type='SGD', weight_decay=0.0001))
param_scheduler = dict(
    T_max=260, begin=20, by_epoch=True, end=300, type='CosineAnnealingLR')
randomness = dict(deterministic=False, seed=None)
resume = False
test_cfg = dict()
test_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(edge='short', scale=256, type='ResizeEdge'),
    dict(crop_size=224, type='CenterCrop'),
    dict(type='PackInputs'),
]
test_dataloader = dict(
    batch_size=32,
    collate_fn=dict(type='default_collate'),
    dataset=dict(
        data_root=data_root,
        pipeline=test_pipeline,
        split='test',
        ann_file='test.txt',
        type=dataset_type),
    num_workers=1,
    persistent_workers=True,
    pin_memory=True,
    sampler=dict(shuffle=False, type='DefaultSampler'))
test_evaluator = dict(
    topk=(
        1,
        5,
    ), type='Accuracy')

train_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(scale=224, type='RandomResizedCrop'),
    dict(direction='horizontal', prob=0.5, type='RandomFlip'),
    dict(type='PackInputs'),
]
train_dataloader = dict(
    batch_size=45,
    collate_fn=dict(type='default_collate'),
    dataset=dict(
        data_root=data_root,
        pipeline=train_pipeline,
        split='train',
        ann_file='train.txt',
        type=dataset_type),
    num_workers=1,
    persistent_workers=True,
    pin_memory=True,
    sampler=dict(shuffle=True, type='DefaultSampler'))

val_cfg = dict()
val_dataloader = dict(
    batch_size=45,
    collate_fn=dict(type='default_collate'),
    dataset=dict(
        data_root=data_root,
        pipeline=test_pipeline,
        split='val',
        ann_file='val.txt',
        type=dataset_type),
    num_workers=1,
    persistent_workers=True,
    pin_memory=True,
    sampler=dict(shuffle=False, type='DefaultSampler'))
val_evaluator = test_evaluator
vis_backends = [
    dict(type='LocalVisBackend'),
]
visualizer = dict(
    type='UniversalVisualizer', vis_backends=[
        dict(type='LocalVisBackend'),
    ])
work_dir = './work_dirs\\resnet50_8xb32-coslr_in1k'

二、训练结果

采用kaggle植物分类数据集,30分类,标签:

IMAGENET_CATEGORIES = ['aloevera', 'banana', 'bilimbi', 'cantaloupe', 'cassava', 'coconut', 'corn', 'cucumber',

'curcuma', 'eggplant', 'galangal', 'ginger', 'guava', 'kale', 'longbeans', 'mango', 'melon',

'orange', 'paddy', 'papaya', 'peperchili', 'pineapple', 'pomelo', 'shallot', 'soybeans',

'spinach', 'sweetpotatoes', 'tobacco', 'waterapple', 'watermelon']

"accuracy/top1": 90.00000762939453, "accuracy/top5": 98.0666732788086

三、结果分析

分析分类结果发现,更改不同的训练策略,结果不会增加,且总有个别类别分类错误,仔细分析数据发现,引起该问题的主要原因是数据集本身引起的,在个别类别中混入了其他类别的图片,甚至出现两个类别使用的数据完全一致的情况,比如melon和cantaloupe使用的是相同的数据集,且内部混入了西瓜的数据集

四、预测测试
























五、总结

在训练前,必须检查数据集

相关推荐
新缸中之脑4 分钟前
Llama 3.2 安卓手机安装教程
前端·人工智能·算法
人工智障调包侠5 分钟前
基于深度学习多层感知机进行手机价格预测
人工智能·python·深度学习·机器学习·数据分析
开始King1 小时前
Tensorflow2.0
人工智能·tensorflow
Elastic 中国社区官方博客1 小时前
Elasticsearch 开放推理 API 增加了对 Google AI Studio 的支持
大数据·数据库·人工智能·elasticsearch·搜索引擎
infominer1 小时前
RAGFlow 0.12 版本功能导读
人工智能·开源·aigc·ai-native
涩即是Null1 小时前
如何构建LSTM神经网络模型
人工智能·rnn·深度学习·神经网络·lstm
本本的小橙子1 小时前
第十四周:机器学习
人工智能·机器学习
励志成为美貌才华为一体的女子2 小时前
《大规模语言模型从理论到实践》第一轮学习--第四章分布式训练
人工智能·分布式·语言模型
学步_技术2 小时前
自动驾驶系列—自动驾驶背后的数据通道:通信总线技术详解与应用场景分析
人工智能·机器学习·自动驾驶·通信总线
winds~2 小时前
自动驾驶-问题笔记-待解决
人工智能·笔记·自动驾驶