cifar100下载使用python数据集:cifar-100-python
高版本的pytorch2.01(cuda11.6)加载cifar100与cifar10没什么两样:
1 数据
transform_train = transforms.Compose([
transforms.RandomCrop(32, padding=4), #先四周填充0,再把图像随机裁剪成32*32
transforms.RandomHorizontalFlip(), #图像一半的概率翻转,一半的概率不翻转
transforms.ToTensor(),
transforms.Normalize((0, 0, 0), (1, 1, 1)), #R,G,B每层的归一化用到的均值和方差
])
transform = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize( (0.1307,), (0.3081,) )
transforms.Normalize((0, 0, 0), (1,1, 1))
])
train_set = datasets.CIFAR100(root = './data',
train=True,
transform = transform_train,
download=True)
train_loader = DataLoader( train_set, shuffle=True, batch_size=batch_size)
当初用低版本pytorch1.01(cuda10.2)花了很长时间都搞不定!
一个是加载不了,一个是类别搞不清(研究结果:存储类别用了两个字节,其中一个字节代表20个大类,另一个字节才是100分类)
加载不了,甚至一度用二进制binarycifar100数据集,写了一个类来加载服务pytorch,下面看看我猜中的类别,32*32的图片太难认(人眼识别),当然现在不必如此!
classes1猜中 = (' 0苹果 ', ' 1水族馆的鱼 ',' 2宝贝 ',' 3熊 ','4豪猪 ', ' 5 床 ', ' 6蜜蜂 ',' 7甲虫 ',' 8自行车 ',' 9瓶子 ',
'10碗', ' 11男孩', '12桥',' 13公共汽车 ', '14蝴蝶 ','15骆驼 ', '16罐子','17城堡 ', '18 蠕虫 ', ' 19牛 ',
' 20椅子 ',' 21黑猩猩','22时钟 ','23云 ', '24蟑螂 ', '25沙发 ',' 26螃蟹',' 27鳄鱼', '28杯子', ' 29恐龙',
' 30海豚 ',' 31大象', '32母老虎 ', ' 33森林','34狐狸',' 35女孩','36负鼠 ', ' 37房子 ', '38水獭 ',' 39键盘 ',
' 40台灯 ',' 41割草机 ',' 42豹', ' 43狮子 ', ' 44蜥蜴',' 45龙虾','46男人', '47枫树','48摩托车',' 49山 ',
' 50老鼠 ', '51蘑菇', '52橡树 ',' 53橘子',' 54兰花',' 55袋鼠', ' 56棕榈',' 57梨', ' 58皮卡车 ',' 59松树 ',
'60平原', '61盘子','62罂粟花 ', ' 63仓鼠 ', '64浣熊 ',' 兔子 65 ',' 66臭鼬 ', '67比目鱼', ' 68路','69火箭',
' 70玫瑰', ' 71海', ' 72海豹',' 73鲨鱼 ',' 74毛虫 ', ' 75海狸',' 76摩天大楼 ','77蜗牛','78蛇',' 79蜘蛛',
'80松鼠', ' 81有轨电车',' 82向日葵','83甜椒', '84 桌子 ','85坦克', '86电话机', '87电视机','88老虎','89拖拉',
' 90火车','91鳟鱼','92郁金香','93乌龟',' 94衣柜', '95鲸鱼 ', '96柳树',' 97狼',' 98女人', '99射线 ' )
classes1正确= ('apple0苹果', 'aquarium_fish', 'baby 2宝贝', 'bear 3熊', 'beaver',
'bed5 床', 'bee6蜜蜂', 'beetle7甲虫', 'bicycle 8自行车', 'bottle9瓶子',
'bowl10碗', 'boy11男孩', 'bridge12桥', 'bus13公共汽车', 'butterfly14蝴蝶',
'camel15骆驼', 'can16罐子', 'castle17城堡', 'caterpillar18毛毛虫', 'cattle19牛',
'chair 20椅子', 'chimpanzee21黑猩猩', 'clock22时钟', 'cloud23云', 'cockroach24蟑螂 ',
'couch25沙发', 'crab26螃蟹', 'crocodile27鳄鱼', 'cup28杯子', 'dinosaur29恐龙',
'dolphin30海豚', 'elephant31大象', 'flatfish67比目鱼', 'forest33森林', 'fox34狐狸',
'girl35女孩', 'hamster36仓鼠', 'house37房子', 'kangaroo38袋鼠', 'keyboard39键盘',
'lamp40台灯', 'lawn_mower41割草机', 'leopard42豹', 'lion43狮子', 'lizard 44蜥蜴',
'lobster45龙虾', 'man46男人', 'maple_tree47枫树', 'motorcycle48摩托车', 'mountain49山',
'mouse50老鼠', 'mushroom51蘑菇', 'oak_tree52橡树', 'orange 53橘子', 'orchid54兰花',
'otter55水獭', 'palm_tree 56棕榈', 'pear 57梨', 'pickup_truck58皮卡车', 'pine_tree59松树',
'plain60平原', 'plate61盘子', 'poppy62罂粟花', 'porcupine63刺猬豪猪' , 'possum64负鼠',
'rabbit兔子 65', 'raccoon66浣熊', 'ray67射线', 'road68路', 'rocket69火箭',
'rose70玫瑰','sea71海', 'seal72封印海豹', 'shark73鲨鱼', 'shrew74泼妇母老虎',
'skunk75臭鼬', 'skyscraper76摩天大楼', 'snail77蜗牛', 'snake78蛇', 'spider79蜘蛛',
'squirrel80松鼠','streetcar81有轨电车', 'sunflower82向日葵', 'sweet_pepper83甜椒', 'table84 桌子',
'tank85坦克', 'telephone86电话机', 'television87电视机', 'tiger88老虎', 'tractor89拖拉',
'train90火车','trout91鳟鱼', 'tulip92郁金香', 'turtle93乌龟', 'wardrobe94衣柜',
'whale95鲸鱼 ', 'willow_tree96柳树', 'wolf97狼', 'woman98女人', 'worm99蠕虫')
现在才知道pytorch中打印类别很好的方法:
使用: print()函数,举例: print(classes1int(predicted)),这个是我最先学会的,竟然有更好的方法:print(labels),抄的程序多了,很容易进步,学会很多pytorch的tricks(技巧)。
奇怪的是,python版cifar100数据集的10000张测试集到哪里去了,我还没找到!
现有5万张,拆成4万训练,1万test,如下:
trainset = datasets.CIFAR100(root='./Data_CIFAR100', train=True, download=True, transform=transform_train)
config = {
"train_size_perc": 0.8,
"batch_size": 128,
"learning_rate": 0.001,
"epochs": 60,
"lr_decay_step": 20,
"lr_decay_gamma": 0.2, # 衰减系数
"save_path": "model_save/SeRes_cifar100_model.pth"
}
设置训练集和验证集的比例
train_size = int(config"train_size_perc" * len(trainset)) # 80%用于训练
val_size = len(trainset) - train_size # 20%用于验证
train_data, val_data = random_split(trainset, train_size, val_size)
trainloader = DataLoader(dataset=train_data, batch_size=128, shuffle=True, num_workers=2)
testloader = DataLoader(dataset=val_data, batch_size=100, shuffle=False, num_workers=2)
从前从cpu一路走过来,发现与pytorch(python)差距太大了!
刚刚玩转的cudnn训练(gpu),用c++实现的2个残差+6个bn网络,在pytorch面前,连尾灯都看不见!
好处是,现在学pytorch,不迷茫!
因为九层之台,已经在脚下了!