深度学习:数据集处理简单记录

  1. 数据生成
python 复制代码
import torch


def synthetic_data(w, b, num_examples):
    X = torch.normal(0, 1, (num_examples, len(w)))
    y = torch.matmul(X, w) + b
    y += torch.normal(0, 0.01, y.shape)
    
    return X, y.reshape((-1, 1))

true_w = torch.tensor([2, -3.4])
true_b = 4.2
features, labels = synthetic_data(true_w, true_b, 1000)

print(features)
print(labels)
print(len(features), len(labels))
print(features.shape, labels.shape)
print(features[0], labels[0])

上述是生成一个简单数据集代码,最后的features的shape为torch.Size([1000, 2]),labels的shape为torch.Size([1000, 1])。注意数据集的所有数据都是tensor张量,用于在gpu上计算.

  1. 数据类型转换
python 复制代码
import torch

# 将Python中的标量转换为张量
scalar = 42
tensor_scalar = torch.tensor(scalar)
print(tensor_scalar)

# 将Python中的列表转换为张量
list_data = [1, 2, 3, 4, 5]
tensor_list = torch.tensor(list_data)
print(tensor_list)

# 将NumPy数组转换为张量
import numpy as np
numpy_array = np.array([1.0, 2.0, 3.0])
tensor_numpy = torch.tensor(numpy_array, dtype=torch.float16)
print(tensor_numpy)

上述是一个张量和numpy类型的转换代码

  1. 图片类数据增强(裁剪 + 翻转)
python 复制代码
from PIL import Image
from torchvision import transforms

imgs_path = ['data/000000039769.jpg', 'data/屏幕截图 2024-09-12 140912.png', 'data/屏幕截图 2024-09-12 140916.png','data/屏幕截图 2024-09-12 140919.png']

for i, img_path in enumerate(imgs_path):
    img = Image.open(img_path)
    transform = transforms.Compose([
        # 尺寸变换
        # transforms.Resize((224,244), interpolation=transforms.InterpolationMode.BICUBIC)  当尺寸变化导致图片模糊时,可以通过双三次插值方法,减轻模型效果
        transforms.Resize((256, 256)),
        # 随即裁剪
        transforms.RandomCrop((224, 224)),
        # 镜像翻转
        transforms.RandomHorizontalFlip()
        # 将图片从PIL image 或 ndrray 转化为c* h * w 张量并将相随范围从[0, 255]缩放到[0, 1]
        transforms.ToTensor(),
        # 对图片进行归一化,以提高训练稳定性,加速收敛效果,防止梯度爆炸/消失
	    transforms.Normalize(mean=feature_extractor.image_mean, std=feature_extractor.image_std)
    ])
    
    img = transform(img)
    img.save(f'{i+5}.' + img_path.split('.')[-1])

上述为训练模型时的数据增强代码(主要是裁剪+翻转),用于提高模型的鲁棒性

需要注意训练模型时候的模型输入的图片需要尺寸统一且为tensor格式

  1. DataSet与DataLoader应用于自建数据集(以图像数据集为例)
python 复制代码
from torch.utils.data import Dataset


# 所有自定义的数据集类都需要继承torch.utils.data.DataSet类
class CustomDataSet(Dataset):
    def __init__(self, images, labels, transform=None):
        self.images = images
        self.labels = labels
        self.transform = transform
    
    # 数据集大小
    def __len__(self):
        return len(self.images)
    
    # 返回每一项
    def __getitem__(self, index):
        image = self.images[index]
        label = self.labels[index]
        if self.transform:
            image = self.transform(image)
 
        return image, label

transform= transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=feature_extractor.image_mean, std=feature_extractor.image_std)
])
images_set = ...
labels_set = ...

dataset = CustomDataSet(images=images_set, labels=labels_set, transform=transform)
loader = DataLoader(train_dataset, batch_size=16, shuffle=True)

上述为调用torch.utils.data中的Dataset和DataLoader类应用于自己的数据集,来训练模型

  1. 调用3中的内容
python 复制代码
# 一次读取16个image和label
for images, labels in loader:
	...
  1. 问:为什么数据类型必须是tensor张量,numpy为什么不行?
    答:因为现代深度学习框架(PyTorch、Tensorflow)主要数据结构是张量(Tensor),它们的计算都是基于张量进行的,张量支持自动求导,这对于训练神经网络至关重要。在反向传播算法中,需要计算损失函数关于模型参数的梯度,张量的自动求导机制简化了这一过程。张量是一种多维数组,它可以在GPU上运行,支持自动微分,是深度学习中的基本数据结构。CPU张量存在内存中,而GPU张量存在显存中。NumPy 本身是专门为 CPU 设计的,并不直接支持 GPU 计算。NumPy 数组(即 numpy.ndarray 对象)默认在 CPU 上运行,不能直接在 GPU 上运行。
相关推荐
张人玉1 小时前
人工智能——猴子摘香蕉问题
人工智能
草莓屁屁我不吃1 小时前
Siri因ChatGPT-4o升级:我们的个人信息还安全吗?
人工智能·安全·chatgpt·chatgpt-4o
小言从不摸鱼1 小时前
【AI大模型】ChatGPT模型原理介绍(下)
人工智能·python·深度学习·机器学习·自然语言处理·chatgpt
AI科研视界1 小时前
ChatGPT+2:修订初始AI安全性和超级智能假设
人工智能·chatgpt
霍格沃兹测试开发学社测试人社区1 小时前
人工智能 | 基于ChatGPT开发人工智能服务平台
软件测试·人工智能·测试开发·chatgpt
小R资源2 小时前
3款免费的GPT类工具
人工智能·gpt·chatgpt·ai作画·ai模型·国内免费
artificiali5 小时前
Anaconda配置pytorch的基本操作
人工智能·pytorch·python
酱香编程,风雨兼程5 小时前
深度学习——基础知识
人工智能·深度学习
Lossya5 小时前
【机器学习】参数学习的基本概念以及贝叶斯网络的参数学习和马尔可夫随机场的参数学习
人工智能·学习·机器学习·贝叶斯网络·马尔科夫随机场·参数学习
#include<菜鸡>6 小时前
动手学深度学习(pytorch土堆)-04torchvision中数据集的使用
人工智能·pytorch·深度学习