深度学习:数据集处理简单记录

  1. 数据生成
python 复制代码
import torch


def synthetic_data(w, b, num_examples):
    X = torch.normal(0, 1, (num_examples, len(w)))
    y = torch.matmul(X, w) + b
    y += torch.normal(0, 0.01, y.shape)
    
    return X, y.reshape((-1, 1))

true_w = torch.tensor([2, -3.4])
true_b = 4.2
features, labels = synthetic_data(true_w, true_b, 1000)

print(features)
print(labels)
print(len(features), len(labels))
print(features.shape, labels.shape)
print(features[0], labels[0])

上述是生成一个简单数据集代码,最后的features的shape为torch.Size([1000, 2]),labels的shape为torch.Size([1000, 1])。注意数据集的所有数据都是tensor张量,用于在gpu上计算.

  1. 数据类型转换
python 复制代码
import torch

# 将Python中的标量转换为张量
scalar = 42
tensor_scalar = torch.tensor(scalar)
print(tensor_scalar)

# 将Python中的列表转换为张量
list_data = [1, 2, 3, 4, 5]
tensor_list = torch.tensor(list_data)
print(tensor_list)

# 将NumPy数组转换为张量
import numpy as np
numpy_array = np.array([1.0, 2.0, 3.0])
tensor_numpy = torch.tensor(numpy_array, dtype=torch.float16)
print(tensor_numpy)

上述是一个张量和numpy类型的转换代码

  1. 图片类数据增强(裁剪 + 翻转)
python 复制代码
from PIL import Image
from torchvision import transforms

imgs_path = ['data/000000039769.jpg', 'data/屏幕截图 2024-09-12 140912.png', 'data/屏幕截图 2024-09-12 140916.png','data/屏幕截图 2024-09-12 140919.png']

for i, img_path in enumerate(imgs_path):
    img = Image.open(img_path)
    transform = transforms.Compose([
        # 尺寸变换
        # transforms.Resize((224,244), interpolation=transforms.InterpolationMode.BICUBIC)  当尺寸变化导致图片模糊时,可以通过双三次插值方法,减轻模型效果
        transforms.Resize((256, 256)),
        # 随即裁剪
        transforms.RandomCrop((224, 224)),
        # 镜像翻转
        transforms.RandomHorizontalFlip()
        # 将图片从PIL image 或 ndrray 转化为c* h * w 张量并将相随范围从[0, 255]缩放到[0, 1]
        transforms.ToTensor(),
        # 对图片进行归一化,以提高训练稳定性,加速收敛效果,防止梯度爆炸/消失
	    transforms.Normalize(mean=feature_extractor.image_mean, std=feature_extractor.image_std)
    ])
    
    img = transform(img)
    img.save(f'{i+5}.' + img_path.split('.')[-1])

上述为训练模型时的数据增强代码(主要是裁剪+翻转),用于提高模型的鲁棒性

需要注意训练模型时候的模型输入的图片需要尺寸统一且为tensor格式

  1. DataSet与DataLoader应用于自建数据集(以图像数据集为例)
python 复制代码
from torch.utils.data import Dataset


# 所有自定义的数据集类都需要继承torch.utils.data.DataSet类
class CustomDataSet(Dataset):
    def __init__(self, images, labels, transform=None):
        self.images = images
        self.labels = labels
        self.transform = transform
    
    # 数据集大小
    def __len__(self):
        return len(self.images)
    
    # 返回每一项
    def __getitem__(self, index):
        image = self.images[index]
        label = self.labels[index]
        if self.transform:
            image = self.transform(image)
 
        return image, label

transform= transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=feature_extractor.image_mean, std=feature_extractor.image_std)
])
images_set = ...
labels_set = ...

dataset = CustomDataSet(images=images_set, labels=labels_set, transform=transform)
loader = DataLoader(train_dataset, batch_size=16, shuffle=True)

上述为调用torch.utils.data中的Dataset和DataLoader类应用于自己的数据集,来训练模型

  1. 调用3中的内容
python 复制代码
# 一次读取16个image和label
for images, labels in loader:
	...
  1. 问:为什么数据类型必须是tensor张量,numpy为什么不行?
    答:因为现代深度学习框架(PyTorch、Tensorflow)主要数据结构是张量(Tensor),它们的计算都是基于张量进行的,张量支持自动求导,这对于训练神经网络至关重要。在反向传播算法中,需要计算损失函数关于模型参数的梯度,张量的自动求导机制简化了这一过程。张量是一种多维数组,它可以在GPU上运行,支持自动微分,是深度学习中的基本数据结构。CPU张量存在内存中,而GPU张量存在显存中。NumPy 本身是专门为 CPU 设计的,并不直接支持 GPU 计算。NumPy 数组(即 numpy.ndarray 对象)默认在 CPU 上运行,不能直接在 GPU 上运行。
相关推荐
2403_8757368710 分钟前
道品科技智慧农业中的自动气象检测站
网络·人工智能·智慧城市
学术头条33 分钟前
AI 的「phone use」竟是这样练成的,清华、智谱团队发布 AutoGLM 技术报告
人工智能·科技·深度学习·语言模型
准橙考典34 分钟前
怎么能更好的通过驾考呢?
人工智能·笔记·自动驾驶·汽车·学习方法
ai_xiaogui37 分钟前
AIStarter教程:快速学会卸载AI项目【AI项目管理平台】
人工智能·ai作画·语音识别·ai写作·ai软件
孙同学要努力42 分钟前
《深度学习》——深度学习基础知识(全连接神经网络)
人工智能·深度学习·神经网络
喵~来学编程啦1 小时前
【论文精读】LPT: Long-tailed prompt tuning for image classification
人工智能·深度学习·机器学习·计算机视觉·论文笔记
深圳市青牛科技实业有限公司2 小时前
【青牛科技】应用方案|D2587A高压大电流DC-DC
人工智能·科技·单片机·嵌入式硬件·机器人·安防监控
水豚AI课代表2 小时前
分析报告、调研报告、工作方案等的提示词
大数据·人工智能·学习·chatgpt·aigc
几两春秋梦_2 小时前
符号回归概念
人工智能·数据挖掘·回归
用户691581141653 小时前
Ascend Extension for PyTorch的源码解析
人工智能