使用PyTorch AlexNet预训练模型对新数据集进行训练及预测

https://blog.csdn.net/fengbingchun/article/details/112709281 中介绍了AlexNet网络,这里使用PyTorch中提供的AlexNet预训练模型对新数据集进行训练,然后使用生成的模型进行预测。主要包括三部分:新数据集自动拆分、训练、预测

1.新数据集自动拆分:这里使用从网络下载的西瓜、冬瓜图像作为新数据集,只有2类。西瓜存放在单独的目录下,名为watermelon;冬瓜存放在单独的目录下,名为wintermelon。图像总数为264张。

以下为自动拆分数据集的实现代码:

python 复制代码
import cv2
import os
import random
import shutil
import numpy as np

class SplitClassifyDataset:
	"""Split a classification dataset into train/val/test directories.

	path_src must contain one sub-directory per class; each sub-directory name
	is used as the class name. Images are randomly assigned to the train,
	validation and (when the test ratio is not 0) test sets under path_dst and
	are either copied unchanged or letterboxed to a fixed shape (see resize()).
	Afterwards the per-channel mean and standard deviation of the training set
	are computed (see get_mean_std()).
	"""

	def __init__(self, path_src, path_dst, ratios=(0.8, 0.1, 0.1)):
		"""
		path_src: source dataset path (one sub-directory per class)
		path_dst: the path to the split dataset
		ratios: they are the ratios of train set, validation set, and test set, respectively;
			the test ratio may be 0, in which case no test set is created
		"""
		assert len(ratios) == 3, f"the length of ratios is not 3: {len(ratios)}"
		assert abs(ratios[0] + ratios[1] + ratios[2] - 1) < 1e-05, f"ratios sum must be 1: {ratios[0]}, {ratios[1]}, {ratios[2]}"

		self.path_src = path_src
		self.path_dst = path_dst
		self.ratio_train = ratios[0]
		self.ratio_val = ratios[1]
		self.ratio_test = ratios[2]

		# letterbox settings, only used after resize() has been called
		self.is_resize = False
		self.fill_value = None
		self.shape = None

		self.length_total = None # number of source images, set by _get_images()
		self.classes = None # class (sub-directory) names, set by _create_dir()

		# training-set statistics in R,G,B order, set by _cal_mean_std()
		self.mean = None
		self.std = None

		# extensions are compared after lower-casing the file's extension
		self.supported_img_formats = (".bmp", ".jpeg", ".jpg", ".png", ".webp")

	def resize(self, value=(114,114,114), shape=(256,256)):
		"""Letterbox the images while splitting instead of copying them unchanged.

		value: fill value of the padding border (per channel of the OpenCV B,G,R image)
		shape: the scaled shape, (height, width)
		"""
		self.is_resize = True
		self.fill_value = value
		self.shape = shape

	def _create_dir(self):
		"""Discover the classes and create <path_dst>/<subset>/<class> directories.

		Raises ValueError if any target directory already exists, so a new split
		is never mixed with the result of a previous run.
		"""
		# every immediate sub-directory of path_src is one class
		self.classes = [name for name in os.listdir(self.path_src) if os.path.isdir(os.path.join(self.path_src, name))]

		subsets = ["train", "val"]
		if self.ratio_test != 0:
			subsets.append("test")

		for name in self.classes:
			for subset in subsets:
				directory = self.path_dst + "/" + subset + "/" + name
				if os.path.exists(directory):
					raise ValueError(f"{directory} directory already exists, delete it")
				os.makedirs(directory, exist_ok=True)

	def _get_images(self):
		"""Collect the supported image files of every class.

		Returns a dict mapping class name -> sorted list of image paths relative
		to that class directory. Sub-directories are searched recursively, so a
		relative path may contain directory components. Also sets length_total.
		"""
		image_names = {}
		self.length_total = 0

		for class_name in self.classes:
			class_dir = os.path.join(self.path_src, class_name)
			imgs = []
			for root, _, files in os.walk(class_dir):
				for file in files:
					_, extension = os.path.splitext(file)
					# case-insensitive, so e.g. "IMG_01.JPG" is accepted as well
					if extension.lower() in self.supported_img_formats:
						# keep the path relative to the class directory so that files
						# found in nested sub-directories can be located again later
						imgs.append(os.path.relpath(os.path.join(root, file), class_dir))
					else:
						print(f"Warning: {self.path_src+'/'+class_name+'/'+file} is an unsupported file")

			# os.walk order is arbitrary; sorting makes a seeded split reproducible
			imgs.sort()
			image_names[class_name] = imgs
			self.length_total += len(imgs)

		return image_names

	def _get_random_sequence(self, image_names):
		"""Randomly partition the image indices of every class.

		Returns a dict mapping class name -> [train_indices, val_indices, test_indices];
		test_indices is empty when the test ratio is 0 (then val gets the rest).
		"""
		image_sequences = {}

		for name in self.classes:
			length = len(image_names[name])
			numbers = list(range(length))
			train_sequence = random.sample(numbers, int(length*self.ratio_train))
			used = set(train_sequence) # set for O(1) membership tests

			# everything not used for training is the validation pool
			val_sequence = [x for x in numbers if x not in used]

			if self.ratio_test != 0:
				# clamp: float rounding must never ask for more items than remain
				val_count = min(int(length*self.ratio_val), len(val_sequence))
				val_sequence = random.sample(val_sequence, val_count)
				used.update(val_sequence)

				test_sequence = [x for x in numbers if x not in used]
			else:
				test_sequence = []

			image_sequences[name] = [train_sequence, val_sequence, test_sequence]

		return image_sequences

	def _letterbox(self, img):
		"""Scale img proportionally to fit self.shape and pad the rest with self.fill_value."""
		shape = img.shape[:2] # current shape: [height, width]
		new_shape = [self.shape[0], self.shape[1]]

		# scale ratio (new / old)
		r = min(new_shape[0] / shape[0], new_shape[1] / shape[1])

		# compute padding
		new_unpad = int(round(shape[1] * r)), int(round(shape[0] * r)) # (width, height) as cv2.resize expects
		dw, dh = new_shape[1] - new_unpad[0], new_shape[0] - new_unpad[1] # wh padding
		dw /= 2 # divide padding into 2 sides
		dh /= 2

		if shape[::-1] != new_unpad: # resize
			img = cv2.resize(img, new_unpad, interpolation=cv2.INTER_LINEAR)

		# +-0.1 splits an odd padding so that top+bottom / left+right add up exactly
		top, bottom = int(round(dh - 0.1)), int(round(dh + 0.1))
		left, right = int(round(dw - 0.1)), int(round(dw + 0.1))
		img = cv2.copyMakeBorder(img, top, bottom, left, right, cv2.BORDER_CONSTANT, value=self.fill_value) # add border

		return img

	def _copy_image(self):
		"""Copy (or letterbox) every source image into its randomly chosen subset."""
		image_names = self._get_images()
		image_sequences = self._get_random_sequence(image_names) # train, val, test

		total = sum(len(sequence) for name in self.classes for sequence in image_sequences[name])
		assert self.length_total == total, f"the length before and afeter the split must be equal: {self.length_total}:{total}"

		dirname = ["train", "val", "test"]
		index = [0, 1] if self.ratio_test == 0 else [0, 1, 2]

		for name in self.classes:
			for idx in index:
				dst_dir_name = self.path_dst + "/" + dirname[idx] + "/" + name
				for i in image_sequences[name][idx]:
					rel_name = image_names[name][i]
					image_name = self.path_src + "/" + name + "/" + rel_name
					# flatten nested paths ("sub/x.png" -> "sub_x.png") so files from
					# different sub-directories cannot overwrite each other
					dst_file_name = dst_dir_name + "/" + rel_name.replace(os.sep, "_")

					if not self.is_resize: # only copy
						shutil.copy(image_name, dst_file_name)
					else: # resize, scale the image proportionally
						img = cv2.imread(image_name) # BGR
						if img is None:
							raise FileNotFoundError(f"image not found: {image_name}")

						img = self._letterbox(img)
						cv2.imwrite(dst_file_name, img)

	def _cal_mean_std(self):
		"""Compute the per-channel mean and std of the training images in R,G,B order.

		OpenCV decodes images as B,G,R, whereas the training/prediction pipeline
		(torchvision ImageFolder / PIL) produces R,G,B tensors, so the statistics
		are reversed before being stored; otherwise transforms.Normalize would
		apply the blue statistics to the red channel.
		The mean is accumulated per pixel, so training images of different sizes
		(when resize() was not used) are supported. The std is the average of the
		per-image channel standard deviations. Values are scaled to [0, 1].

		Raises ValueError if the training set contains no images.
		"""
		channel_sums = np.zeros(3, dtype=np.float64) # running B,G,R sums
		pixel_count = 0
		image_stds = [] # per-image (B, G, R) standard deviations

		for name in self.classes:
			dst_dir_name = self.path_dst + "/train/" + name + "/"

			for file in sorted(os.listdir(dst_dir_name)):
				if not os.path.isfile(dst_dir_name+file):
					continue
				img = cv2.imread(dst_dir_name+file)
				if img is None:
					raise FileNotFoundError(f"image not found: {dst_dir_name}{file}")

				pixels = img.reshape(-1, 3)
				channel_sums += pixels.sum(axis=0, dtype=np.float64)
				pixel_count += pixels.shape[0]
				image_stds.append(pixels.std(axis=0))

		if pixel_count == 0:
			raise ValueError(f"no training images found under {self.path_dst}/train")

		# [::-1] turns B,G,R into R,G,B to match the torchvision channel order
		self.mean = (channel_sums / pixel_count / 255)[::-1].copy()
		self.std = [float(v) / 255 for v in np.mean(image_stds, axis=0)[::-1]] # R,G,B

	def __call__(self):
		"""Create the directories, split the images and compute the training statistics."""
		self._create_dir()
		self._copy_image()
		self._cal_mean_std()

	def get_mean_std(self):
		"""get the mean and standard deviation of the training set (R,G,B order, scaled to [0, 1])"""
		return self.mean, self.std

if __name__ == "__main__":
	# Split the melon dataset (one sub-directory per class) and letterbox every
	# image to 256x256, then report the training-set statistics.
	splitter = SplitClassifyDataset(path_src="../../data/database/classify/melon", path_dst="datasets/melon_new_classify")
	splitter.resize(shape=(256,256))
	splitter()

	train_mean, train_std = splitter.get_mean_std()
	print(f"mean: {train_mean}; std: {train_std}")
	print("====== execution completed ======")

说明如下:

(1).实现类名为SplitClassifyDataset,供外层调用;

(2).接收的参数包括:源数据集目录(watermelon目录和wintermelon目录所在的目录);结果存放目录;拆分时训练集、验证集、测试集的比例;图像resize后的大小(等比例缩放,图像不会变形,空白处默认使用(114,114,114)填充);

(3).图像会被随机拆分,即每次执行后各数据集中包含的图像会有所不同;

(4).拆分后会计算训练集的均值和标准差

以下为外层代码调用实现:

python 复制代码
def split_dataset(src_dataset_path, dst_dataset_path, resize, ratios):
	"""Split the dataset at src_dataset_path into train/val(/test) under dst_dataset_path.

	resize: target (height, width) of the letterboxed images, either a tuple or its
		string form such as "(256,256)" as received from the command line
	ratios: (train, val, test) ratios, either a tuple or its string form

	The training-set mean/std are printed in the "(a,b,c)" form accepted by the
	--mean/--std options of the train and predict tasks.
	"""
	def to_tuple(value):
		# user-supplied CLI values are strings, but argparse defaults may already
		# be tuples, and ast.literal_eval only accepts strings (or AST nodes)
		return ast.literal_eval(value) if isinstance(value, str) else tuple(value)

	def to_cli_text(values):
		# no spaces, so the value can be passed as a single command-line argument
		return "(" + ",".join(str(float(v)) for v in values) + ")"

	split = SplitClassifyDataset(path_src=src_dataset_path, path_dst=dst_dataset_path, ratios=to_tuple(ratios))
	split.resize(shape=to_tuple(resize))

	split()
	mean, std = split.get_mean_std()
	print(f"mean: {to_cli_text(mean)}; std: {to_cli_text(std)}")

执行后结果如下图所示:输出均值和标准差(后面训练和预测时都需要);新生成的目录组织结构(每个目录下存放一类)满足PyTorch的要求

2.训练:

(1).下载AlexNet预训练模型:仅第一次执行时会从网络下载

python 复制代码
def load_pretraind_model():
	"""Return torchvision's AlexNet initialised with the IMAGENET1K_V1 weights.

	The first call downloads alexnet-owt-7be5be79.pth into the torch hub cache
	(e.g. ~/.cache/torch/hub/checkpoints); later calls reuse the cached file.
	"""
	pretrained_weights = models.AlexNet_Weights.IMAGENET1K_V1
	return models.alexnet(weights=pretrained_weights)

(2).加载拆分后的数据集:

python 复制代码
def load_dataset(dataset_path, mean, std, labels_file):
	"""Build the train/validation data loaders of a split dataset.

	dataset_path: directory containing "train" and "val", each holding one
		sub-directory per class
	mean, std: per-channel normalisation statistics as strings such as
		"(0.52,0.60,0.59)", parsed with ast.literal_eval
	labels_file: path of the "index class_name" file written for prediction

	Returns (number of classes, train set size, val set size, train_loader, val_loader).
	"""
	mean = ast.literal_eval(mean) # str to tuple
	std = ast.literal_eval(std)

	normalize = transforms.Normalize(mean=mean, std=std)

	train_transform = transforms.Compose([
		transforms.RandomHorizontalFlip(p=0.5),
		transforms.CenterCrop(224),
		transforms.ToTensor(),
		normalize,
	])

	train_dataset = ImageFolder(root=dataset_path+"/train", transform=train_transform)
	print(f"train dataset length: {len(train_dataset)}; classes: {train_dataset.class_to_idx}; number of categories: {len(train_dataset.class_to_idx)}")

	train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True, num_workers=0)

	val_transform = transforms.Compose([
		transforms.CenterCrop(224),
		transforms.ToTensor(),
		normalize,
	])

	val_dataset = ImageFolder(root=dataset_path+"/val", transform=val_transform)
	print(f"val dataset length: {len(val_dataset)}; classes: {val_dataset.class_to_idx}")
	# ImageFolder derives the label indices from the class directory names, so
	# both sets must contain exactly the same classes (not merely the same number
	# of classes) for an index to denote the same class in both
	assert train_dataset.class_to_idx == val_dataset.class_to_idx, f"the categories in the train set must be equal to the categories in the validation set: {train_dataset.class_to_idx} : {val_dataset.class_to_idx}"

	# the validation metrics do not depend on sample order, so no shuffling
	val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False, num_workers=0)

	write_labels(train_dataset.class_to_idx, labels_file)

	return len(train_dataset.class_to_idx), len(train_dataset), len(val_dataset), train_loader, val_loader

(3).将对应的索引和标签写入文件:索引从0开始,依次加1;标签即拆分后的目录名;后面预测时会需要此文件

python 复制代码
def write_labels(class_to_idx, labels_file):
	"""Write one "index class_name" line per entry of class_to_idx.

	The file is read back by parse_labels_file at prediction time.
	"""
	lines = [f"{int(idx)} {label}\n" for label, idx in class_to_idx.items()]
	with open(labels_file, "w") as out:
		out.writelines(lines)

(4).可视化训练过程中训练集和验证集的Loss和Accuracy:

python 复制代码
def draw_graph(train_losses, train_accuracies, val_losses, val_accuracies):
	"""Plot per-epoch loss (left panel) and accuracy (right panel) and show the figure.

	Training curves are drawn in blue, validation curves in red.
	"""
	panels = (
		("Loss curve", "Loss", train_losses, val_losses),
		("Accuracy curve", "Accuracy", train_accuracies, val_accuracies),
	)

	for position, (title, metric, train_values, val_values) in enumerate(panels, start=1):
		plt.subplot(1, 2, position)
		plt.title(title)
		plt.xlabel("Epoch Number")
		plt.ylabel(metric)
		plt.plot(train_values, color="blue")
		plt.plot(val_values, color="red")
		plt.legend([f"Train {metric}", f"Val {metric}"])

	plt.show()

某次的执行结果如下图所示:

(5).主体代码:代码中加入了提前终止训练的判断条件;生成的最终模型名为best.pth

python 复制代码
def train(dataset_path, epochs, mean, std, model_name, labels_file):
	"""Fine-tune the ImageNet-pretrained AlexNet on a split dataset.

	dataset_path: split dataset directory containing "train" and "val"
	epochs: maximum number of epochs
	mean, std: normalisation statistics as strings, see load_dataset
	model_name: path the best state_dict is saved to
	labels_file: path of the "index class_name" file to write

	The model is saved whenever the validation accuracy AND the validation loss
	both improve on the best values so far (the first epoch is always saved).
	Training stops early once the validation loss is ~0 and the accuracy ~1.
	The loss/accuracy curves are plotted at the end.
	"""
	# reference: https://learnopencv.com/image-classification-using-transfer-learning-in-pytorch/
	classes_num, train_dataset_num, val_dataset_num, train_loader, val_loader = load_dataset(dataset_path, mean, std, labels_file)

	model = load_pretraind_model()

	# replace the last classifier layer so it outputs one score per class
	in_features = model.classifier[6].in_features
	model.classifier[6] = nn.Linear(in_features, classes_num)

	device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
	model.to(device)

	optimizer = optim.Adam(model.parameters(), lr=0.0002) # set the optimizer
	criterion = nn.CrossEntropyLoss() # set the loss

	train_losses = []
	train_accuracies = []
	val_losses = []
	val_accuracies = []

	# start from the worst possible values so the first epoch is always saved
	# and model_name is guaranteed to exist once training has run
	highest_accuracy = float("-inf")
	minimum_loss = float("inf")

	for epoch in range(epochs):
		epoch_start = time.time()

		train_loss, train_correct = _run_epoch(model, train_loader, criterion, device, optimizer)
		val_loss, val_correct = _run_epoch(model, val_loader, criterion, device)

		avg_train_loss = train_loss / train_dataset_num # average training loss
		avg_train_acc = train_correct / train_dataset_num # average training accuracy
		avg_val_loss = val_loss / val_dataset_num # average validation loss
		avg_val_acc = val_correct / val_dataset_num # average validation accuracy
		train_losses.append(avg_train_loss)
		train_accuracies.append(avg_train_acc)
		val_losses.append(avg_val_loss)
		val_accuracies.append(avg_val_acc)
		epoch_end = time.time()
		print(f"epoch:{epoch+1}/{epochs}; train loss:{avg_train_loss:.4f}, accuracy:{avg_train_acc:.4f}; validation loss:{avg_val_loss:.4f}, accuracy:{avg_val_acc:.4f}; time:{epoch_end-epoch_start:.2f}s")

		if highest_accuracy < avg_val_acc and minimum_loss > avg_val_loss:
			torch.save(model.state_dict(), model_name)
			highest_accuracy = avg_val_acc
			minimum_loss = avg_val_loss

		if avg_val_loss < 0.00001 and avg_val_acc > 0.99999:
			print(colorama.Fore.YELLOW + "stop training early")
			torch.save(model.state_dict(), model_name)
			break

	draw_graph(train_losses, train_accuracies, val_losses, val_accuracies)


def _run_epoch(model, loader, criterion, device, optimizer=None):
	"""Run one pass over loader: train when optimizer is given, otherwise evaluate.

	Returns (summed loss, number of correct predictions) over all samples, so
	the caller divides by the dataset size to obtain per-sample averages.
	"""
	training = optimizer is not None
	model.train(training) # model.train(False) is the same as model.eval()
	total_loss = 0.0
	total_correct = 0

	with torch.set_grad_enabled(training):
		for inputs, labels in loader:
			inputs = inputs.to(device)
			labels = labels.to(device)

			if training:
				optimizer.zero_grad() # clean existing gradients
			outputs = model(inputs) # forward pass
			loss = criterion(outputs, labels)
			if training:
				loss.backward() # backpropagate the gradients
				optimizer.step() # update the parameters

			# the loss is a batch mean, so weight it by the batch size
			total_loss += loss.item() * inputs.size(0)
			total_correct += (outputs.argmax(dim=1) == labels).sum().item()

	return total_loss, total_correct

运行结果如下图所示:

3.预测:

(1).解析上面2.(3)中生成的标签文件:

python 复制代码
def parse_labels_file(labels_file):
	"""Read an "index class_name" labels file (as written by write_labels).

	Returns a dict mapping the integer index to the class name. Only the first
	space separates the index from the name, so class names containing spaces
	are preserved. Blank lines are ignored.
	"""
	classes = {}

	with open(labels_file, "r") as file:
		for line in file:
			line = line.rstrip("\r\n") # remove the line break at the end of the line
			if not line.strip():
				continue # e.g. an empty trailing line added by an editor
			idx_value = line.split(" ", 1)
			assert len(idx_value) == 2, f"the length must be 2: {len(idx_value)}"
			classes[int(idx_value[0])] = idx_value[1]

	return classes

(2).保存features:供后期使用,长度为9216

python 复制代码
def get_images_list(images_path):
	"""Return every regular file below images_path (searched recursively) as a Path.

	No extension filtering is performed: every file found is returned.
	"""
	return [entry for entry in Path(images_path).rglob("*") if entry.is_file()]

def save_features(model, input_batch, image_name, dir_name="tmp"):
	"""Save the flattened feature vector of one image as a headerless binary file.

	The vector is model.features -> model.avgpool -> flatten (9216 values for
	AlexNet: 256 x 6 x 6). It is written with numpy's ndarray.tofile to
	<dir_name>/<image file name>.bin.

	model: an AlexNet-like module exposing .features and .avgpool
	input_batch: a batch holding a single preprocessed image, (1, c, h, w)
	image_name: path of the source image; only its file name is used
	dir_name: output directory, created if missing
	"""
	features = model.features(input_batch) # shape: torch.Size([1, 256, 6, 6])
	features = model.avgpool(features)
	features = torch.flatten(features, 1) # shape: torch.Size([1, 9216])

	# .cpu() is a no-op for a tensor already on the CPU, so one path serves both devices
	features = features.squeeze().detach().cpu().numpy() # shape: (9216,)

	os.makedirs(dir_name, exist_ok=True)

	file_name = Path(image_name).name
	features.tofile(os.path.join(dir_name, file_name + ".bin"))

(3).主体代码:

python 复制代码
def predict(model_name, labels_file, images_path, mean, std):
	"""Classify every file under images_path with the fine-tuned AlexNet.

	model_name: state_dict saved by train()
	labels_file: "index class_name" file written during training
	images_path: directory searched recursively for images
	mean, std: normalisation statistics as strings such as "(0.52,0.60,0.59)"

	Prints the image path, the predicted class and its softmax probability,
	and saves each image's feature vector via save_features().
	"""
	classes = parse_labels_file(labels_file)
	assert len(classes) != 0, "the number of categories can't be 0"

	image_names = get_images_list(images_path)
	assert len(image_names) != 0, "no images found"

	mean = ast.literal_eval(mean) # str to tuple
	std = ast.literal_eval(std)

	device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

	model = models.alexnet(weights=None)
	in_features = model.classifier[6].in_features
	model.classifier[6] = nn.Linear(in_features, len(classes)) # modify the number of categories
	# map_location lets a checkpoint saved on a GPU be loaded on a CPU-only machine
	model.load_state_dict(torch.load(model_name, map_location=device))
	model.to(device)

	# same preprocessing as the validation transform used during training
	preprocess = transforms.Compose([
		transforms.CenterCrop(224),
		transforms.ToTensor(),
		transforms.Normalize(mean=mean, std=std)
	])

	print("image name\t\t\t\t\t\tclass\tprobability")

	model.eval()
	with torch.no_grad():
		for image_name in image_names:
			# convert to RGB so grayscale, palette or RGBA files also yield the
			# 3 channels expected by the 3-value mean/std; "with" closes the file
			with Image.open(image_name) as img:
				input_image = img.convert("RGB")

			input_tensor = preprocess(input_image) # (c,h,w)
			input_batch = input_tensor.unsqueeze(0).to(device) # mini-batch of one, (1,c,h,w)

			output = model(input_batch)
			# the model outputs unnormalized scores; softmax turns them into probabilities
			probabilities = torch.nn.functional.softmax(output[0], dim=0)
			max_value, max_index = torch.max(probabilities, dim=0)
			print(f"{image_name}\t\t\t\t\t\t{classes[max_index.item()]}\t{max_value.item():.4f}")

			save_features(model, input_batch, image_name)

执行结果如下图所示:由结果可知,虽然数据集很小,训练轮数也很少,但在测试集上的预测全部正确

支持的输入参数及入口函数如下所示:

python 复制代码
def parse_args():
	"""Parse the command-line options of the split/train/predict tasks.

	--resize, --ratios, --mean and --std are tuple-valued and are passed as
	strings such as "(256,256)"; the consumers parse them with ast.literal_eval,
	so the defaults are strings too (a tuple default would make literal_eval fail).
	"""
	parser = argparse.ArgumentParser(description="AlexNet image classification")
	parser.add_argument("--task", required=True, type=str, choices=["split", "train", "predict"], help="specify what kind of task")
	parser.add_argument("--src_dataset_path", type=str, help="source dataset path")
	parser.add_argument("--dst_dataset_path", type=str, help="the path of the destination dataset after split")
	parser.add_argument("--resize", type=str, default="(256,256)", help="the size to which images are resized when split the dataset")
	parser.add_argument("--ratios", type=str, default="(0.8,0.1,0.1)", help="the ratio of split the data set(train set, validation set, test set), the test set can be 0, but their sum must be 1")
	parser.add_argument("--epochs", type=int, help="number of training")
	parser.add_argument("--mean", type=str, help="the mean of the training set of images")
	parser.add_argument("--std", type=str, help="the standard deviation of the training set of images")
	parser.add_argument("--model_name", type=str, help="the model generated during training or the model loaded during prediction")
	parser.add_argument("--labels_file", type=str, help="one category per line, the format is: index class_name")
	parser.add_argument("--predict_images_path", type=str, help="predict images path")

	args = parser.parse_args()
	return args

if __name__ == "__main__":
	colorama.init(autoreset=True)
	args = parse_args()

	# Example command lines (in bash/zsh quote the tuple values, e.g. --resize "(256,256)"):
	#   python test_alexnet.py --task split --src_dataset_path ../../data/database/classify/melon --dst_dataset_path datasets/melon_new_classify --resize "(256,256)" --ratios "(0.7,0.2,0.1)"
	#   python test_alexnet.py --task train --dst_dataset_path datasets/melon_new_classify --epochs 100 --mean "(0.52817206,0.60931162,0.59818634)" --std "(0.2533697287956878,0.22790271847362834,0.2380239874816262)" --model_name best.pth --labels_file classes.txt
	#   python test_alexnet.py --task predict --predict_images_path datasets/melon_new_classify/test --mean "(0.52817206,0.60931162,0.59818634)" --std "(0.2533697287956878,0.22790271847362834,0.2380239874816262)" --model_name best.pth --labels_file classes.txt
	# argparse restricts --task to these three keys
	tasks = {
		"split": lambda: split_dataset(args.src_dataset_path, args.dst_dataset_path, args.resize, args.ratios),
		"train": lambda: train(args.dst_dataset_path, args.epochs, args.mean, args.std, args.model_name, args.labels_file),
		"predict": lambda: predict(args.model_name, args.labels_file, args.predict_images_path, args.mean, args.std),
	}
	tasks[args.task]()

	print(colorama.Fore.GREEN + "====== execution completed ======")

GitHub: https://github.com/fengbingchun/NN_Test

相关推荐
四口鲸鱼爱吃盐10 小时前
Pytorch | 从零构建GoogleNet对CIFAR10进行分类
人工智能·pytorch·分类
leaf_leaves_leaf10 小时前
win11用一条命令给anaconda环境安装GPU版本pytorch,并检查是否为GPU版本
人工智能·pytorch·python
夜雨飘零110 小时前
基于Pytorch实现的说话人日志(说话人分离)
人工智能·pytorch·python·声纹识别·说话人分离·说话人日志
四口鲸鱼爱吃盐11 小时前
Pytorch | 从零构建MobileNet对CIFAR10进行分类
人工智能·pytorch·分类
苏言の狗11 小时前
Pytorch中关于Tensor的操作
人工智能·pytorch·python·深度学习·机器学习
四口鲸鱼爱吃盐16 小时前
Pytorch | 利用VMI-FGSM针对CIFAR10上的ResNet分类器进行对抗攻击
人工智能·pytorch·python
四口鲸鱼爱吃盐16 小时前
Pytorch | 利用PI-FGSM针对CIFAR10上的ResNet分类器进行对抗攻击
人工智能·pytorch·python
love you joyfully1 天前
目标检测与R-CNN——pytorch与paddle实现目标检测与R-CNN
人工智能·pytorch·目标检测·cnn·paddle
这个男人是小帅2 天前
【AutoDL】通过【SSH远程连接】【vscode】
运维·人工智能·pytorch·vscode·深度学习·ssh
四口鲸鱼爱吃盐2 天前
Pytorch | 利用MI-FGSM针对CIFAR10上的ResNet分类器进行对抗攻击
人工智能·pytorch·python