在 https://blog.csdn.net/fengbingchun/article/details/112709281 中介绍了AlexNet网络,这里使用PyTorch中提供的AlexNet预训练模型对新数据集进行训练,然后使用生成的模型进行预测。主要包括三部分:新数据集自动拆分、训练、预测
1.新数据集自动拆分:这里使用从网络下载的西瓜、冬瓜图像作为新数据集,只有2类。西瓜存放在单独的目录下,名为watermelon;冬瓜存放在单独的目录下,名为wintermelon。图像总数为264张。
以下为自动拆分数据集的实现代码:
python
import cv2
import os
import random
import shutil
import numpy as np
class SplitClassifyDataset:
    """Split an image-classification dataset into train/val/test directories.

    The source directory is expected to contain one sub-directory per class.
    Calling the instance creates ``<path_dst>/{train,val,test}/<class>``,
    randomly distributes the images of every class according to ``ratios``
    (optionally letterbox-resizing them) and finally computes the per-channel
    mean and standard deviation of the resulting training set.
    """

    def __init__(self, path_src, path_dst, ratios=(0.8, 0.1, 0.1)):
        """
        path_src: source dataset path
        path_dst: the path to the split dataset
        ratios: they are the ratios of train set, validation set, and test set, respectively;
                the test ratio may be 0, in which case no test directory is created
        """
        assert len(ratios) == 3, f"the length of ratios is not 3: {len(ratios)}"
        assert abs(ratios[0] + ratios[1] + ratios[2] - 1) < 1e-05, f"ratios sum must be 1: {ratios[0]}, {ratios[1]}, {ratios[2]}"

        self.path_src = path_src
        self.path_dst = path_dst
        self.ratio_train = ratios[0]
        self.ratio_val = ratios[1]
        self.ratio_test = ratios[2]

        # resize settings, filled in by resize()
        self.is_resize = False
        self.fill_value = None
        self.shape = None

        # results, filled in while the split runs
        self.length_total = None
        self.classes = None
        self.mean = None
        self.std = None

        # lower-case extensions; matching is case-insensitive
        self.supported_img_formats = (".bmp", ".jpeg", ".jpg", ".png", ".webp")

    def resize(self, value=(114,114,114), shape=(256,256)):
        """Enable letterbox resizing of every image while it is copied.

        value: fill value used for the padding border
        shape: the scaled shape, as (height, width)
        """
        self.is_resize = True
        self.fill_value = value
        self.shape = shape

    def _create_dir(self):
        """Discover the class names and create the destination directories.

        Raises ValueError if a destination directory already exists, so that
        an earlier split is never silently mixed with a new one.
        """
        self.classes = [name for name in os.listdir(self.path_src) if os.path.isdir(os.path.join(self.path_src, name))]

        splits = ["train", "val"]
        if self.ratio_test != 0:
            splits.append("test")

        for name in self.classes:
            for split in splits:
                directory = self.path_dst + "/" + split + "/" + name
                if os.path.exists(directory):
                    raise ValueError(f"{directory} directory already exists, delete it")
                os.makedirs(directory, exist_ok=True)

    def _get_images(self):
        """Return {class name: [image file names]} and update length_total."""
        image_names = {}
        self.length_total = 0

        for class_name in self.classes:
            imgs = []
            for root, dirs, files in os.walk(os.path.join(self.path_src, class_name)):
                for file in files:
                    _, extension = os.path.splitext(file)
                    # compare in lower case so that e.g. "IMG_01.JPG" is accepted
                    if extension.lower() in self.supported_img_formats:
                        imgs.append(file)
                    else:
                        print(f"Warning: {self.path_src+'/'+class_name+'/'+file} is an unsupported file")

            image_names[class_name] = imgs
            self.length_total += len(imgs)

        return image_names

    def _get_random_sequence(self, image_names):
        """Randomly partition the image indices of every class.

        Returns {class name: [train indices, val indices, test indices]}.
        When the test ratio is 0 every index that is not used for training
        goes to the validation set.
        """
        image_sequences = {}

        for name in self.classes:
            length = len(image_names[name])
            numbers = list(range(0, length))

            train_sequence = random.sample(numbers, int(length*self.ratio_train))
            train_set = set(train_sequence)  # O(1) membership tests

            val_sequence = [x for x in numbers if x not in train_set]

            if self.ratio_test != 0:
                val_sequence = random.sample(val_sequence, int(length*self.ratio_val))
                used = train_set.union(val_sequence)
                test_sequence = [x for x in numbers if x not in used]
            else:
                test_sequence = []

            image_sequences[name] = [train_sequence, val_sequence, test_sequence]

        return image_sequences

    def _letterbox(self, img):
        """Scale img to fit self.shape keeping its aspect ratio, then pad with fill_value."""
        shape = img.shape[:2] # current shape: [height, width, channel]
        new_shape = [self.shape[0], self.shape[1]]

        # scale ratio (new / old)
        r = min(new_shape[0] / shape[0], new_shape[1] / shape[1])

        # compute padding
        new_unpad = int(round(shape[1] * r)), int(round(shape[0] * r)) # (width, height), as cv2.resize expects
        dw, dh = new_shape[1] - new_unpad[0], new_shape[0] - new_unpad[1] # wh padding
        dw /= 2 # divide padding into 2 sides
        dh /= 2

        if shape[::-1] != new_unpad: # resize
            img = cv2.resize(img, new_unpad, interpolation=cv2.INTER_LINEAR)

        # the +-0.1 makes an odd padding split as n / n+1 instead of being rounded twice the same way
        top, bottom = int(round(dh - 0.1)), int(round(dh + 0.1))
        left, right = int(round(dw - 0.1)), int(round(dw + 0.1))
        img = cv2.copyMakeBorder(img, top, bottom, left, right, cv2.BORDER_CONSTANT, value=self.fill_value) # add border

        return img

    def _copy_image(self):
        """Copy (or letterbox-resize and write) every image into its split directory."""
        image_names = self._get_images()
        image_sequences = self._get_random_sequence(image_names) # train, val, test

        # sanity check: every image must be assigned to exactly one split
        total = 0
        for name in self.classes:
            for i in range(3):
                total += len(image_sequences[name][i])
        assert self.length_total == total, f"the length before and afeter the split must be equal: {self.length_total}:{total}"

        dirname = ["train", "val", "test"]
        index = [0, 1, 2]
        if self.ratio_test == 0:
            index = [0, 1]

        for name in self.classes:
            for idx in index:
                dst_dir_name = self.path_dst + "/" + dirname[idx] + "/" + name
                for i in image_sequences[name][idx]:
                    image_name = self.path_src + "/" + name + "/" + image_names[name][i]

                    if not self.is_resize: # only copy
                        shutil.copy(image_name, dst_dir_name)
                    else: # resize, scale the image proportionally
                        img = cv2.imread(image_name) # BGR
                        if img is None:
                            raise FileNotFoundError(f"image not found: {image_name}")

                        img = self._letterbox(img)
                        dst_image_name = dst_dir_name+"/"+image_names[name][i]
                        # cv2.imwrite reports failure through its return value only
                        if not cv2.imwrite(dst_image_name, img):
                            raise OSError(f"failed to write image: {dst_image_name}")

    def _cal_mean_std(self):
        """Compute the per-channel mean and std of the training set, scaled to [0, 1].

        mean is the mean over all pixels of all training images; std is the
        average of the per-image standard deviations. Images are accumulated
        one at a time, so they do not need to share a shape.

        NOTE: images are read with cv2.imread, so both results are in B,G,R
        channel order.
        """
        channel_sum = np.zeros(3, dtype=np.float64)
        pixel_count = 0
        std_channels = [[], [], []] # per-image std of channel 0, 1, 2 (B, G, R)

        for name in self.classes:
            dst_dir_name = self.path_dst + "/train/" + name + "/"
            for root, dirs, files in os.walk(dst_dir_name):
                for file in files:
                    img = cv2.imread(dst_dir_name+file)
                    if img is None:
                        raise FileNotFoundError(f"image not found: {dst_dir_name}{file}")

                    for channel in range(3):
                        plane = img[:,:,channel]
                        channel_sum[channel] += plane.sum(dtype=np.float64)
                        std_channels[channel].append(np.std(plane))
                    pixel_count += img.shape[0] * img.shape[1]

        if pixel_count == 0:
            raise ValueError(f"no training images found under {self.path_dst}/train")

        self.mean = channel_sum / pixel_count / 255
        self.std = [np.mean(std_channels[0]) / 255, np.mean(std_channels[1]) / 255, np.mean(std_channels[2]) / 255] # B,G,R

    def __call__(self):
        """Run the whole pipeline: create directories, split the images, compute statistics."""
        self._create_dir()
        self._copy_image()
        self._cal_mean_std()

    def get_mean_std(self):
        """get the mean and standard deviation of the training set (B,G,R order, scaled to [0, 1])"""
        return self.mean, self.std
if __name__ == "__main__":
    # Stand-alone demo: split the melon dataset with the default 0.8/0.1/0.1
    # ratios, letterbox every image to 256x256 and report the statistics.
    splitter = SplitClassifyDataset(
        path_src="../../data/database/classify/melon",
        path_dst="datasets/melon_new_classify",
    )
    splitter.resize(shape=(256,256))
    splitter()

    mean, std = splitter.get_mean_std()
    print(f"mean: {mean}; std: {std}")

    print("====== execution completed ======")
说明如下:
(1).实现类名为SplitClassifyDataset,供外层调用;
(2).接收的参数包括:源数据集目录(watermelon目录和wintermelon目录所在的目录);结果存放目录;拆分时训练集、验证集、测试集的比例;图像resize后的大小(图像不能变形,默认使用(114,114,114)填充);
(3).图像会随机被拆分,即每次执行后结果图像会不同;
(4).拆分后会计算训练集的均值和标准差(均已除以255归一化到0~1)。注意:图像由cv2.imread读取,因此输出的均值和标准差的通道顺序为B、G、R。
以下为外层代码调用实现:
python
def split_dataset(src_dataset_path, dst_dataset_path, resize, ratios):
    """Split the source dataset and print the training-set mean and std.

    src_dataset_path: directory that contains one sub-directory per class
    dst_dataset_path: directory that receives the train/val/test split
    resize: target size, either a tuple or its string form such as "(256,256)"
    ratios: train/val/test ratios, either a tuple or its string form such as "(0.8,0.1,0.1)"
    """
    def _to_tuple(value):
        # command-line values arrive as strings, while argparse defaults may
        # already be tuples; ast.literal_eval only accepts the former
        return ast.literal_eval(value) if isinstance(value, str) else value

    split = SplitClassifyDataset(path_src=src_dataset_path, path_dst=dst_dataset_path, ratios=_to_tuple(ratios))
    split.resize(shape=_to_tuple(resize))
    split()

    mean, std = split.get_mean_std()
    print(f"mean: {mean}; std: {std}")
执行后结果如下图所示:输出均值和标准差(后面训练和预测时都需要);新生成的目录组织结构(每个目录下存放一类)满足PyTorch的要求
2.训练:
(1).下载AlexNet预训练模型:仅第一次执行时会从网络下载
python
def load_pretraind_model():
    """Return AlexNet initialised with the ImageNet (IMAGENET1K_V1) weights.

    The first execution downloads alexnet-owt-7be5be79.pth into the torch hub
    cache, e.g. C:\\Users\\xxxxxx/.cache\\torch\\hub\\checkpoints.
    """
    pretrained_weights = models.AlexNet_Weights.IMAGENET1K_V1
    return models.alexnet(weights=pretrained_weights)
(2).加载拆分后的数据集:
python
def load_dataset(dataset_path, mean, std, labels_file):
    """Build the train/val data loaders and write the index-to-label file.

    dataset_path: root of the split dataset, containing "train" and "val"
    mean, std: string forms of the normalisation tuples, e.g. "(0.5,0.5,0.5)"
    labels_file: file that receives one "index class_name" line per class

    Returns (number of classes, train set size, val set size, train loader, val loader).
    """
    norm_mean = ast.literal_eval(mean) # str to tuple
    norm_std = ast.literal_eval(std)

    def _common_steps():
        # crop to the 224x224 network input, convert to tensor, normalise
        return [
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            transforms.Normalize(mean=norm_mean, std=norm_std),
        ]

    # training set: additionally flipped at random as augmentation
    train_transform = transforms.Compose([transforms.RandomHorizontalFlip(p=0.5)] + _common_steps())
    train_dataset = ImageFolder(root=dataset_path+"/train", transform=train_transform)
    train_classes = train_dataset.class_to_idx
    print(f"train dataset length: {len(train_dataset)}; classes: {train_dataset.class_to_idx}; number of categories: {len(train_dataset.class_to_idx)}")
    train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True, num_workers=0)

    # validation set: no augmentation
    val_transform = transforms.Compose(_common_steps())
    val_dataset = ImageFolder(root=dataset_path+"/val", transform=val_transform)
    print(f"val dataset length: {len(val_dataset)}; classes: {val_dataset.class_to_idx}")

    assert len(train_dataset.class_to_idx) == len(val_dataset.class_to_idx), f"the number of categories int the train set must be equal to the number of categories in the validation set: {len(train_dataset.class_to_idx)} : {len(val_dataset.class_to_idx)}"

    val_loader = DataLoader(val_dataset, batch_size=32, shuffle=True, num_workers=0)

    write_labels(train_classes, labels_file)

    return len(train_classes), len(train_dataset), len(val_dataset), train_loader, val_loader
(3).将对应的索引和标签写入文件:索引从0开始,依次加1;标签即拆分后的目录名;后面预测时会需要此文件
python
def write_labels(class_to_idx, labels_file):
    """Write one "index class_name" line per class to labels_file.

    class_to_idx: mapping from class name to its integer index
    labels_file: path of the text file to (over)write
    """
    with open(labels_file, "w") as out:
        for class_name, index in class_to_idx.items():
            out.write(f"{int(index)} {class_name}\n")
(4).可视化训练过程中训练集和验证集的Loss和Accuracy:
python
def draw_graph(train_losses, train_accuracies, val_losses, val_accuracies):
    """Plot the per-epoch loss and accuracy curves side by side and show them.

    Each argument is a sequence with one value per epoch; training curves are
    drawn in blue and validation curves in red.
    """
    panels = (
        (1, "Loss curve", "Loss", train_losses, val_losses, ["Train Loss", "Val Loss"]),
        (2, "Accuracy curve", "Accuracy", train_accuracies, val_accuracies, ["Train Accuracy", "Val Accuracy"]),
    )

    for position, title, y_label, train_series, val_series, legend in panels:
        plt.subplot(1, 2, position)
        plt.title(title)
        plt.xlabel("Epoch Number")
        plt.ylabel(y_label)
        plt.plot(train_series, color="blue")
        plt.plot(val_series, color="red")
        plt.legend(legend)

    plt.show()
某次的执行结果如下图所示:
(5).主体代码:代码中加入了提前终止训练的判断条件;生成的最终模型名由--model_name参数指定,本文示例中为best.pth
python
def _run_epoch(model, loader, criterion, device, optimizer=None):
    """Run one pass over loader and return (summed loss, number of correct predictions).

    The model is trained when an optimizer is given; otherwise it is only
    evaluated, with gradient tracking disabled.
    """
    is_training = optimizer is not None
    model.train(is_training) # train(False) is the same as eval()

    total_loss = 0.0
    total_correct = 0

    with torch.set_grad_enabled(is_training):
        for inputs, labels in loader:
            inputs = inputs.to(device)
            labels = labels.to(device)

            if is_training:
                optimizer.zero_grad() # clean existing gradients

            outputs = model(inputs) # forward pass
            loss = criterion(outputs, labels) # compute loss

            if is_training:
                loss.backward() # backpropagate the gradients
                optimizer.step() # update the parameters

            # the criterion returns the batch mean, so weight it by the batch size
            total_loss += loss.item() * inputs.size(0)
            total_correct += (outputs.argmax(dim=1) == labels).sum().item()

    return total_loss, total_correct


def train(dataset_path, epochs, mean, std, model_name, labels_file):
    """Fine-tune the pretrained AlexNet on the split dataset.

    dataset_path: root of the split dataset (contains "train" and "val")
    epochs: maximum number of epochs
    mean, std: string forms of the normalisation tuples
    model_name: file that receives the state dict of the best model
    labels_file: file that receives the index-to-label mapping

    The model is saved whenever an epoch improves both the validation accuracy
    and the validation loss. Training stops early once the validation set is
    classified (almost) perfectly. The loss/accuracy curves are shown at the end.
    """
    classes_num, train_dataset_num, val_dataset_num, train_loader, val_loader = load_dataset(dataset_path, mean, std, labels_file)
    model = load_pretraind_model()

    in_features = model.classifier[6].in_features
    model.classifier[6] = nn.Linear(in_features, classes_num) # modify the number of categories

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)

    optimizer = optim.Adam(model.parameters(), lr=0.0002) # set the optimizer
    criterion = nn.CrossEntropyLoss() # set the loss

    train_losses = []
    train_accuracies = []
    val_losses = []
    val_accuracies = []

    highest_accuracy = 0.
    # start from infinity: a fixed bound such as 100 would prevent any model
    # from being saved when the validation loss never drops below it
    minimum_loss = float("inf")

    for epoch in range(epochs):
        # reference: https://learnopencv.com/image-classification-using-transfer-learning-in-pytorch/
        epoch_start = time.time()

        train_loss, train_correct = _run_epoch(model, train_loader, criterion, device, optimizer)
        val_loss, val_correct = _run_epoch(model, val_loader, criterion, device)

        avg_train_loss = train_loss / train_dataset_num # average training loss
        avg_train_acc = train_correct / train_dataset_num # average training accuracy
        avg_val_loss = val_loss / val_dataset_num # average validation loss
        avg_val_acc = val_correct / val_dataset_num # average validation accuracy
        train_losses.append(avg_train_loss)
        train_accuracies.append(avg_train_acc)
        val_losses.append(avg_val_loss)
        val_accuracies.append(avg_val_acc)

        epoch_end = time.time()
        print(f"epoch:{epoch+1}/{epochs}; train loss:{avg_train_loss:.4f}, accuracy:{avg_train_acc:.4f}; validation loss:{avg_val_loss:.4f}, accuracy:{avg_val_acc:.4f}; time:{epoch_end-epoch_start:.2f}s")

        if highest_accuracy < avg_val_acc and minimum_loss > avg_val_loss:
            torch.save(model.state_dict(), model_name)
            highest_accuracy = avg_val_acc
            minimum_loss = avg_val_loss

        if avg_val_loss < 0.00001 and avg_val_acc > 0.99999:
            print(colorama.Fore.YELLOW + "stop training early")
            torch.save(model.state_dict(), model_name)
            break

    draw_graph(train_losses, train_accuracies, val_losses, val_accuracies)
运行结果如下图所示:
3.预测:
(1).解析上面2.(3)中生成的文本文件(每行格式为:索引 标签):
python
def parse_labels_file(labels_file):
    """Parse a labels file into a {index: class name} dictionary.

    Every non-blank line must have the form "index class_name". Only the first
    space separates the two fields, so class names may contain spaces
    themselves. Blank lines are ignored.
    """
    classes = {}

    with open(labels_file, "r") as file:
        for line in file:
            line = line.rstrip("\r\n") # remove the line break at the end of the line
            if not line.strip():
                continue # tolerate empty lines, e.g. at the end of the file

            idx_value = line.split(" ", 1)
            assert len(idx_value) == 2, f"the length must be 2: {len(idx_value)}"

            classes[int(idx_value[0])] = idx_value[1]

    return classes
(2).保存features:供后期使用,长度为9216
python
def get_images_list(images_path):
    """Return the paths of all regular files found recursively below images_path."""
    return [entry for entry in Path(images_path).rglob("*") if entry.is_file()]
def save_features(model, input_batch, image_name):
    """Save the flattened feature vector of one image to tmp/<file name>.bin.

    model: network that exposes "features" and "avgpool" sub-modules
    input_batch: input batch holding a single image
    image_name: path of the image; only its file name is used for the output

    The vector is written as raw binary data with ndarray.tofile.
    """
    features = model.features(input_batch) # shape: torch.Size([1, 256, 6, 6])
    features = model.avgpool(features)
    features = torch.flatten(features, 1) # shape: torch.Size([1, 9216])

    # .cpu() is a no-op for tensors that already live on the CPU, so a single
    # code path serves both devices
    features = features.squeeze().detach().cpu().numpy() # shape: (9216,)

    dir_name = "tmp"
    os.makedirs(dir_name, exist_ok=True)

    file_name = Path(image_name).name
    features.tofile(dir_name+"/"+file_name+".bin")
(3).主体代码:
python
def predict(model_name, labels_file, images_path, mean, std):
    """Classify every image below images_path and print class and probability.

    model_name: state-dict file produced by training
    labels_file: index-to-label file produced by training
    images_path: directory that is searched recursively for images
    mean, std: string forms of the normalisation tuples used during training

    The feature vector of every image is additionally saved through save_features.
    """
    classes = parse_labels_file(labels_file)
    assert len(classes) != 0, "the number of categories can't be 0"

    image_names = get_images_list(images_path)
    assert len(image_names) != 0, "no images found"

    mean = ast.literal_eval(mean) # str to tuple
    std = ast.literal_eval(std)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    model = models.alexnet(weights=None)
    in_features = model.classifier[6].in_features
    model.classifier[6] = nn.Linear(in_features, len(classes)) # modify the number of categories

    # map_location lets a model trained on the GPU be loaded on a CPU-only machine
    model.load_state_dict(torch.load(model_name, map_location=device))
    model.to(device)

    # the preprocessing is identical for every image, so build it only once
    preprocess = transforms.Compose([
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=mean, std=std)
    ])

    print("image name\t\t\t\t\t\tclass\tprobability")
    model.eval()
    with torch.no_grad():
        for image_name in image_names:
            # force 3 channels: grayscale, palette or RGBA files would not
            # match the 3-value mean/std; the file is closed right after decoding
            with Image.open(image_name) as image_file:
                input_image = image_file.convert("RGB")

            input_tensor = preprocess(input_image) # (c,h,w)
            input_batch = input_tensor.unsqueeze(0) # create a mini-batch as expected by the model, (1,c,h,w)
            input_batch = input_batch.to(device)

            output = model(input_batch)
            probabilities = torch.nn.functional.softmax(output[0], dim=0) # the output has unnormalized scores, to get probabilities, you can run a softmax on it
            max_value, max_index = torch.max(probabilities, dim=0)
            print(f"{image_name}\t\t\t\t\t\t{classes[max_index.item()]}\t{max_value.item():.4f}")

            save_features(model, input_batch, image_name)
执行结果如下图所示:由结果可知,虽然数据集很少,训练次数也很少,但预测时能百分百预测准确
支持的输入参数及入口函数如下所示:
python
def parse_args():
    """Parse the command-line arguments of the split/train/predict tasks.

    Tuple-like options (--resize, --ratios, --mean, --std) are kept as strings
    such as "(256,256)"; the task functions convert them with ast.literal_eval.
    """
    parser = argparse.ArgumentParser(description="AlexNet image classification")
    parser.add_argument("--task", required=True, type=str, choices=["split", "train", "predict"], help="specify what kind of task")
    parser.add_argument("--src_dataset_path", type=str, help="source dataset path")
    parser.add_argument("--dst_dataset_path", type=str, help="the path of the destination dataset after split")
    # the defaults are strings, exactly like values given on the command line:
    # ast.literal_eval rejects objects that are already tuples
    parser.add_argument("--resize", type=str, default="(256,256)", help="the size to which images are resized when split the dataset")
    parser.add_argument("--ratios", type=str, default="(0.8,0.1,0.1)", help="the ratio of split the data set(train set, validation set, test set), the test set can be 0, but their sum must be 1")
    parser.add_argument("--epochs", type=int, help="number of training")
    parser.add_argument("--mean", type=str, help="the mean of the training set of images")
    parser.add_argument("--std", type=str, help="the standard deviation of the training set of images")
    parser.add_argument("--model_name", type=str, help="the model generated during training or the model loaded during prediction")
    parser.add_argument("--labels_file", type=str, help="one category per line, the format is: index class_name")
    parser.add_argument("--predict_images_path", type=str, help="predict images path")

    args = parser.parse_args()
    return args
if __name__ == "__main__":
    # Entry point: run the task selected with --task.
    colorama.init(autoreset=True)
    args = parse_args()
    task = args.task

    if task == "split":
        # python test_alexnet.py --task split --src_dataset_path ../../data/database/classify/melon --dst_dataset_path datasets/melon_new_classify --resize (256,256) --ratios (0.7,0.2,0.1)
        split_dataset(
            args.src_dataset_path,
            args.dst_dataset_path,
            args.resize,
            args.ratios,
        )
    elif task == "train":
        # python test_alexnet.py --task train --dst_dataset_path datasets/melon_new_classify --epochs 100 --mean (0.52817206,0.60931162,0.59818634) --std (0.2533697287956878,0.22790271847362834,0.2380239874816262) --model_name best.pth --labels_file classes.txt
        train(
            args.dst_dataset_path,
            args.epochs,
            args.mean,
            args.std,
            args.model_name,
            args.labels_file,
        )
    else: # predict
        # python test_alexnet.py --task predict --predict_images_path datasets/melon_new_classify/test --mean (0.52817206,0.60931162,0.59818634) --std (0.2533697287956878,0.22790271847362834,0.2380239874816262) --model_name best.pth --labels_file classes.txt
        predict(
            args.model_name,
            args.labels_file,
            args.predict_images_path,
            args.mean,
            args.std,
        )

    print(colorama.Fore.GREEN + "====== execution completed ======")