# Custom Dataset

- 1. Load the directory information and resize the images
  - Define the Pokemon class
  - Define `load_csv`
  - Define `__len__`
  - Define `__getitem__`
  - Test run
  - Define `denormalize`
- 2. Define the ResNet network
- 3. Training process
## 1. Load the directory information and resize the images
### Define the Pokemon class
```python
import os, glob, csv, random

import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image


class Pokemon(Dataset):
    # mode: 'train', 'val' or 'test'
    def __init__(self, root, resize, mode):
        super(Pokemon, self).__init__()
        self.root = root
        self.resize = resize

        self.name2label = {}
        # os.listdir lists every file and sub-directory under root
        for name in sorted(os.listdir(os.path.join(root))):
            # skip anything that is not a directory
            if not os.path.isdir(os.path.join(root, name)):
                continue
            # the folder name is the key; the current size of the dict becomes its label,
            # e.g. {'bulbasaur': 0, 'charmander': 1, 'mewtwo': 2, 'pikachu': 3, 'squirtle': 4}
            self.name2label[name] = len(self.name2label.keys())

        # load the image paths and labels
        self.images, self.labels = self.load_csv('images.csv')
        # print(self.images, self.labels)

        if mode == 'train':    # first 60%
            self.images = self.images[:int(0.6 * len(self.images))]
            self.labels = self.labels[:int(0.6 * len(self.labels))]
        elif mode == 'val':    # next 20%
            self.images = self.images[int(0.6 * len(self.images)):int(0.8 * len(self.images))]
            self.labels = self.labels[int(0.6 * len(self.labels)):int(0.8 * len(self.labels))]
        else:                  # remaining 20% for test
            self.images = self.images[int(0.8 * len(self.images)):]
            self.labels = self.labels[int(0.8 * len(self.labels)):]
```
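A minimal usage sketch of the constructor (the root path `'pokeman'` is a placeholder; the class expects one sub-directory per class under that root, and it relies on the methods defined in the next subsections):

```python
# Assumed on-disk layout (illustrative): one folder per class under the root, e.g.
#   pokeman/bulbasaur/*.png, pokeman/charmander/*.png, ..., pokeman/squirtle/*.png
# With that layout the 60/20/20 split in __init__ gives three disjoint views of the data:
train_db = Pokemon('pokeman', resize=64, mode='train')  # first 60% of the shuffled images
val_db   = Pokemon('pokeman', resize=64, mode='val')    # next 20%
test_db  = Pokemon('pokeman', resize=64, mode='test')   # remaining 20%
print(train_db.name2label)  # e.g. {'bulbasaur': 0, 'charmander': 1, ..., 'squirtle': 4}
```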
### Define `load_csv`
Purpose: builds images.csv on the first run by scanning the class folders, then returns the lists of image paths and labels.
```python
    def load_csv(self, filename):
        # if the csv does not exist yet, scan the class folders and create it
        if not os.path.exists(os.path.join(self.root, filename)):
            images = []
            for name in self.name2label.keys():
                # glob collects every image under a class folder,
                # e.g. pokemon\mewtwo\00000001.png
                images += glob.glob(os.path.join(self.root, name, '*.png'))
                images += glob.glob(os.path.join(self.root, name, '*.jpg'))
                images += glob.glob(os.path.join(self.root, name, '*.jpeg'))
            # 1167 images in total
            # print(len(images), images)
            random.shuffle(images)
            with open(os.path.join(self.root, filename), mode='w', newline='') as f:
                writer = csv.writer(f)
                for img in images:
                    # img looks like E:\Deep_Learning\Resource\introductory_program\pokeman\bulbasaur\00000000.png;
                    # the second-to-last path component is the class name, e.g. 'pikachu'
                    name = img.split(os.sep)[-2]
                    # map the class name to its numeric label
                    label = self.name2label[name]
                    # each row: E:\...\pokeman\charmander\00000082.png,1
                    writer.writerow([img, label])
                print('written into csv file:', filename)

        # read image paths and labels back from the csv file
        images, labels = [], []
        with open(os.path.join(self.root, filename)) as f:
            reader = csv.reader(f)
            for row in reader:
                img, label = row
                label = int(label)
                images.append(img)
                labels.append(label)

        assert len(images) == len(labels)
        return images, labels
```
### Define `__len__`
```python
    def __len__(self):
        return len(self.images)
```
### Define `__getitem__`
```python
    def __getitem__(self, idx):
        # idx: an index in [0, len(self.images))
        # img: a path such as E:\...\pokeman\bulbasaur\00000000.png
        # label: an integer in 0..4
        img, label = self.images[idx], self.labels[idx]

        tf = transforms.Compose([
            lambda x: Image.open(x).convert('RGB'),   # string path => image data
            transforms.Resize((int(self.resize * 1.25), int(self.resize * 1.25))),
            transforms.RandomRotation(15),            # random rotation of up to 15 degrees
            transforms.CenterCrop(self.resize),       # center crop, keeps the original background
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])
        ])
        img = tf(img)
        label = torch.tensor(label)
        return img, label
```
### Test run
```python
def main():
    import visdom
    import time

    viz = visdom.Visdom()

    # Alternative: torchvision's ImageFolder can build an equivalent dataset
    # tf = transforms.Compose([
    #     transforms.Resize((64, 64)),
    #     transforms.ToTensor(),
    # ])
    # db = torchvision.datasets.ImageFolder(root='E:\\Deep_Learning\\Resource\\introductory_program\\pokeman', transform=tf)

    db = Pokemon('E:\\Deep_Learning\\Resource\\introductory_program\\pokeman', 64, 'train')

    # visualize a single sample
    x, y = next(iter(db))
    print('sample:', x.shape, y.shape, y)
    viz.image(db.denormalize(x), win='sample_x', opts=dict(title='sample_x'))

    # visualize whole batches
    loader = DataLoader(db, batch_size=32, shuffle=True, num_workers=8)
    for x, y in loader:
        # nrow=8 shows 8 images per row
        viz.images(db.denormalize(x), nrow=8, win='batch', opts=dict(title='batch'))
        viz.text(str(y.numpy()), win='label', opts=dict(title='batch-y'))
        time.sleep(10)


if __name__ == '__main__':
    main()
```
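Note: visdom only renders these windows when a visdom server is already running; it is typically started in a separate terminal with `python -m visdom.server` before `main()` is launched.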
### Define `denormalize`
Purpose: because the images were normalized, displaying them directly looks wrong; applying the inverse operation restores them so they display correctly.
```python
    def denormalize(self, x_hat):
        mean = [0.485, 0.456, 0.406]
        std = [0.229, 0.224, 0.225]
        # normalize:   x_hat = (x - mean) / std
        # denormalize: x = x_hat * std + mean
        # x has shape [c, h, w]; mean and std are [3], so unsqueeze them
        # to [3, 1, 1] so they broadcast over the h and w dimensions
        mean = torch.tensor(mean).unsqueeze(1).unsqueeze(1)
        std = torch.tensor(std).unsqueeze(1).unsqueeze(1)
        x = x_hat * std + mean
        return x
```
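To convince yourself the inverse is correct, the round trip can be checked on a random tensor with the same arithmetic `denormalize` uses (a small self-contained sketch, not part of the original code):

```python
import torch
from torchvision import transforms

mean = [0.485, 0.456, 0.406]
std = [0.229, 0.224, 0.225]

x = torch.rand(3, 64, 64)                       # fake image in [0, 1], shape [c, h, w]
x_hat = transforms.Normalize(mean, std)(x)      # (x - mean) / std, per channel

# undo it exactly as denormalize() does: reshape mean/std to [3, 1, 1] and invert
m = torch.tensor(mean).unsqueeze(1).unsqueeze(1)
s = torch.tensor(std).unsqueeze(1).unsqueeze(1)
x_back = x_hat * s + m

print(torch.allclose(x, x_back, atol=1e-6))     # True: the original pixel values are recovered
```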
## 2. Define the ResNet network
```python
import torch
from torch import nn
from torch.nn import functional as F


class ResBlk(nn.Module):
    """
    ResNet basic block
    """
    def __init__(self, ch_in, ch_out, stride=1):
        super(ResBlk, self).__init__()

        self.conv1 = nn.Conv2d(in_channels=ch_in, out_channels=ch_out, kernel_size=3, stride=stride, padding=1)
        self.bn1 = nn.BatchNorm2d(ch_out)
        self.conv2 = nn.Conv2d(in_channels=ch_out, out_channels=ch_out, kernel_size=3, stride=1, padding=1)
        self.bn2 = nn.BatchNorm2d(ch_out)

        self.extra = nn.Sequential()
        if ch_out != ch_in:
            # project the shortcut to the same shape as the main branch:
            # [b, ch_in, h, w] => [b, ch_out, h, w]
            self.extra = nn.Sequential(
                nn.Conv2d(ch_in, ch_out, kernel_size=1, stride=stride),
                nn.BatchNorm2d(ch_out)
            )

    def forward(self, x):
        """
        :param x: [b, ch, h, w]
        :return:
        """
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        # shortcut: x is [b, ch_in, h, w] while out is [b, ch_out, h, w],
        # so x goes through self.extra before the addition.
        # The identity shortcut is the core idea of ResNet: it makes deep
        # networks easier to optimize and mitigates vanishing/exploding gradients.
        out = F.relu(self.extra(x) + out)
        return out


class ResNet18(nn.Module):
    def __init__(self, num_class):
        super(ResNet18, self).__init__()

        self.conv1 = nn.Sequential(
            nn.Conv2d(3, 16, kernel_size=3, stride=3, padding=0),
            nn.BatchNorm2d(16)
        )
        # followed by 4 residual blocks
        # [b, 16, h, w] => [b, 32, h/3, w/3]
        self.blk1 = ResBlk(16, 32, stride=3)
        # [b, 32, h, w] => [b, 64, h/3, w/3]
        self.blk2 = ResBlk(32, 64, stride=3)
        # [b, 64, h, w] => [b, 128, h/2, w/2]
        self.blk3 = ResBlk(64, 128, stride=2)
        # [b, 128, h, w] => [b, 256, h/2, w/2]
        self.blk4 = ResBlk(128, 256, stride=2)

        # for a 224x224 input the final feature map is [b, 256, 3, 3]
        self.outlayer = nn.Linear(256 * 3 * 3, num_class)

    def forward(self, x):
        x = F.relu(self.conv1(x))

        x = self.blk1(x)
        x = self.blk2(x)
        x = self.blk3(x)
        x = self.blk4(x)

        # x = F.adaptive_avg_pool2d(x, [1, 1])
        x = x.view(x.size(0), -1)
        x = self.outlayer(x)
        return x


def main():
    blk = ResBlk(64, 128, stride=1)
    tmp = torch.randn(2, 64, 224, 224)
    out = blk(tmp)
    print(out.shape)

    model = ResNet18(5)
    tmp = torch.randn(2, 3, 224, 224)
    out = model(tmp)
    print('resnet:', out.shape)

    p = sum(map(lambda p: p.numel(), model.parameters()))
    print('parameters size:', p)


if __name__ == '__main__':
    main()
```
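The `nn.Linear(256*3*3, num_class)` input size only matches a 224x224 input. Each residual block changes the spatial size only in its first (strided) conv; the second conv (kernel 3, stride 1, padding 1) preserves it. A small sketch of that size arithmetic, using the standard conv output formula:

```python
def conv_out(h, kernel=3, stride=1, padding=0):
    # output size of a conv layer: floor((h + 2*padding - kernel) / stride) + 1
    return (h + 2 * padding - kernel) // stride + 1

h = 224
h = conv_out(h, stride=3, padding=0)  # conv1           -> 74
h = conv_out(h, stride=3, padding=1)  # blk1 (stride 3) -> 25
h = conv_out(h, stride=3, padding=1)  # blk2 (stride 3) -> 9
h = conv_out(h, stride=2, padding=1)  # blk3 (stride 2) -> 5
h = conv_out(h, stride=2, padding=1)  # blk4 (stride 2) -> 3
print(h)  # 3, hence nn.Linear(256 * 3 * 3, num_class)
```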
## 3. Training process
This covers loading the dataset and working with the three splits (training, validation, and test). During training, the model with the best validation accuracy is saved, and that best model is then used to evaluate the test set.
```python
import torch
from torch import optim, nn
import visdom
import torchvision
from torch.utils.data import DataLoader

from pokemen import Pokemon            # the dataset class from section 1
from learing_resnet import ResNet18    # the network from section 2

batchz = 32
lr = 1e-3
epochs = 10

device = torch.device('cuda')
torch.manual_seed(1234)

train_db = Pokemon(root='E:\\Deep_Learning\\Resource\\introductory_program\\pokeman', resize=224, mode='train')
val_db = Pokemon(root='E:\\Deep_Learning\\Resource\\introductory_program\\pokeman', resize=224, mode='val')
test_db = Pokemon(root='E:\\Deep_Learning\\Resource\\introductory_program\\pokeman', resize=224, mode='test')
train_loader = DataLoader(train_db, batch_size=batchz, shuffle=True, num_workers=4)
val_loader = DataLoader(val_db, batch_size=batchz, shuffle=True, num_workers=2)
test_loader = DataLoader(test_db, batch_size=batchz, shuffle=True, num_workers=2)

viz = visdom.Visdom()


def evaluate(model, loader):
    model.eval()   # switch BatchNorm layers to inference mode
    correct = 0
    total = len(loader.dataset)
    for x, y in loader:
        x, y = x.to(device), y.to(device)
        with torch.no_grad():
            logits = model(x)
            pred = logits.argmax(dim=1)
        correct += torch.eq(pred, y).sum().float().item()
    return correct / total


def main():
    model = ResNet18(5).to(device)
    optimizer = optim.Adam(model.parameters(), lr=lr)
    criteon = nn.CrossEntropyLoss().to(device)

    best_acc, best_epoch = 0, 0
    global_step = 0
    viz.line([0], [-1], win='loss', opts=dict(title='loss'))
    viz.line([0], [-1], win='val_acc', opts=dict(title='val_acc'))
    for epoch in range(epochs):
        model.train()   # back to training mode after each evaluation
        for step, (x, y) in enumerate(train_loader):
            # x: [b, 3, 224, 224], y: [b]
            x, y = x.to(device), y.to(device)

            logits = model(x)
            loss = criteon(logits, y)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            viz.line([loss.item()], [global_step], win='loss', update='append')
            global_step += 1

        if epoch % 1 == 0:
            val_acc = evaluate(model, val_loader)
            if val_acc > best_acc:
                best_epoch = epoch
                best_acc = val_acc
                # keep the model with the best validation accuracy
                torch.save(model.state_dict(), 'best.mdl')
            viz.line([val_acc], [global_step], win='val_acc', update='append')

    print('best_acc:', best_acc, 'best_epoch:', best_epoch)

    # evaluate the test set with the best model
    model.load_state_dict(torch.load('best.mdl'))
    print('loaded from ckpt!')
    test_acc = evaluate(model, test_loader)
    print('test_acc:', test_acc)


if __name__ == '__main__':
    main()
```
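Once best.mdl has been saved, the same model can classify a single image. A minimal sketch, assuming the training script above (the `predict_one` helper and the image path are illustrative, not part of the original code; the preprocessing is a deterministic variant of the dataset's transform, without the random rotation):

```python
from PIL import Image
from torchvision import transforms

def predict_one(img_path):
    # deterministic preprocessing with the same normalization as the dataset
    tf = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])
    model = ResNet18(5).to(device)
    model.load_state_dict(torch.load('best.mdl'))
    model.eval()

    x = tf(Image.open(img_path).convert('RGB')).unsqueeze(0).to(device)  # [1, 3, 224, 224]
    with torch.no_grad():
        pred = model(x).argmax(dim=1).item()

    # map the numeric label back to a class name via the dataset's name2label dict
    label2name = {v: k for k, v in train_db.name2label.items()}
    return label2name[pred]

# print(predict_one('pokeman/pikachu/00000001.png'))  # path is an illustration only
```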