第一,网络架构不能变
第二,改造测试函数为一张图片识别
以第二个pytorch程序训练cifar100为例:
我的第二个pytorch人工智能程序(最简单的方式训练cifar100)-CSDN博客
重点说第二:
a,打开图片,transform,变成nchw=【1,3,32,32】
transform = transforms.Compose(
[transforms.Resize((32, 32)), # 首先需resize成跟训练集图像一样的大小
transforms.ToTensor(),
transforms.Normalize((0, 0, 0), (1, 1, 1))])
im = Image.open("mifeng.png").convert("RGB")#来自网络,随便下载一张蜜蜂图片
im = transform(im) # C, H, W
im = torch.unsqueeze(im, dim=0) # 对数据增加一个新维度,因为tensor的参数是batch, channel, height, width
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")#这句话似乎多余了!
b,使用网络模型,加载训练好的pth文件,枚举类别
net = Model()#参考前面程序训练模型架构
net.load_state_dict(torch.load('./cifar_ResNet_simple100.pth'))#参考前面程序训练保存
classes1 = ('apple0苹果', 'aquarium_fish', 'baby 2宝贝', 'bear 3熊', 'beaver',
'bed5 床', 'bee6蜜蜂', 'beetle7甲虫', 'bicycle 8自行车', 'bottle9瓶子',
'bowl10碗', 'boy11男孩', 'bridge12桥', 'bus13公共汽车', 'butterfly14蝴蝶',
'camel15骆驼', 'can16罐子', 'castle17城堡', 'caterpillar18毛毛虫', 'cattle19牛',
'chair 20椅子', 'chimpanzee21黑猩猩', 'clock22时钟', 'cloud23云', 'cockroach24蟑螂 ',
'couch25沙发', 'crab26螃蟹', 'crocodile27鳄鱼', 'cup28杯子', 'dinosaur29恐龙',
'dolphin30海豚', 'elephant31大象', 'flatfish67比目鱼', 'forest33森林', 'fox34狐狸',
'girl35女孩', 'hamster36仓鼠', 'house37房子', 'kangaroo38袋鼠', 'keyboard39键盘',
'lamp40台灯', 'lawn_mower41割草机', 'leopard42豹', 'lion43狮子', 'lizard 44蜥蜴',
'lobster45龙虾', 'man46男人', 'maple_tree47枫树', 'motorcycle48摩托车', 'mountain49山',
'mouse50老鼠', 'mushroom51蘑菇', 'oak_tree52橡树', 'orange 53橘子', 'orchid54兰花',
'otter55水獭', 'palm_tree 56棕榈', 'pear 57梨', 'pickup_truck58皮卡车', 'pine_tree59松树',
'plain60平原', 'plate61盘子', 'poppy62罂粟花', 'porcupine63刺猬豪猪' , 'possum64负鼠',
'rabbit兔子 65', 'raccoon66浣熊', 'ray67射线', 'road68路', 'rocket69火箭',
'rose70玫瑰','sea71海', 'seal72封印海豹', 'shark73鲨鱼', 'shrew74泼妇母老虎',
'skunk75臭鼬', 'skyscraper76摩天大楼', 'snail77蜗牛', 'snake78蛇', 'spider79蜘蛛',
'squirrel80松鼠','streetcar81有轨电车', 'sunflower82向日葵', 'sweet_pepper83甜椒', 'table84 桌子',
'tank85坦克', 'telephone86电话机', 'television87电视机', 'tiger88老虎', 'tractor89拖拉',
'train90火车','trout91鳟鱼', 'tulip92郁金香', 'turtle93乌龟', 'wardrobe94衣柜',
'whale95鲸鱼 ', 'willow_tree96柳树', 'wolf97狼', 'woman98女人', 'worm99蠕虫')
c,改造测试函数:(用cpu方式识别)
with torch.no_grad():
all_preds = \[\]
outputs = net(im)
predict = torch.max(outputs, dim=1)1.data.numpy()
_, predicted = torch.max(outputs, 1)
all_preds.extend(predicted.cpu().numpy())
plt.figure(figsize=(1, 1))
plt.imshow(data0i.numpy().transpose(1, 2, 0))
plt.imshow(im0.permute(1, 2, 0))
plt.show()
#print(classes1int(predict))
print(classes1int(predicted))
#类别正确了,程序训练最好63,识别仍然一般,关键是可以一幅图识别程序成功了
