使用pytorch处理自己的数据集

目录

[1 返回本地文件中的数据集](#1 返回本地文件中的数据集)

[2 根据当前已有的数据集创建每一个样本数据对应的标签](#2 根据当前已有的数据集创建每一个样本数据对应的标签)

[3 tensorboard的使用](#3 tensorboard的使用)

[4 transforms处理数据](#4 transforms处理数据)

tranfroms.Totensor的使用

transforms.Normalize的使用

transforms.Resize的使用

transforms.Compose使用

[5 dataset_transforms使用](#5 dataset_transforms使用)


1 返回本地文件中的数据集

在这个操作中,当前数据集的上一级目录就是当前所有同一数据的label

python 复制代码
import os
from torch.utils.data import Dataset
from PIL import Image

class MyDataset(Dataset):
    def __init__(self, root_dir, label_dir):
        """
        :param root_dir: 根目录文件
        :param label_dir: 分类标签目录
        """
        self.root_dir = root_dir
        self.label_dir = label_dir
        self.path = os.path.join(root_dir, label_dir)
        self.image_path_list = os.listdir(self.path)
    def __getitem__(self, idx):
        """
        :param idx: idx是自己文件夹下的每一个图片索引
        :return: 返回每一个图片对象和其对应的标签,对于返回类型可以直接调用image.show显示或者用于后续图像处理
        """
        img_name = self.image_path_list[idx]
        ever_image_path = os.path.join(self.root_dir, self.label_dir, img_name)
        image = Image.open(ever_image_path)
        label = self.label_dir
        return image, label
    def __len__(self):
        return len(self.image_path_list)

root_dir = 'G:\python_files\深度学习代码库\cats_and_dogs_small\\train'
label_dir = 'cats'
my_data = MyDataset(root_dir, label_dir)
first_pic, label = my_data[0]   # 自动调用__getitem__(self, idx)
first_pic.show()
print("当前图片中动物所属label", label)

F:\Anaconda\envs\py38\python.exe G:/python_files/深度学习代码库/dataset/MyDataSet.py

当前图片中动物所属label cats

2 根据当前已有的数据集创建每一个样本数据对应的标签

python 复制代码
import os
from torch.utils.data import Dataset
from PIL import Image

class MyLabelData:
    def __init__(self, root_dir, target_dir, label_dir, label_name):
        """
        :param root_dir: 根目录
        :param target_dir: 生成标签的目录
        :param label_dir: 要生成为标签目录名称
        :param label_name: 生成的标签名称
        """
        self.root_dir = root_dir
        self.target_dir = target_dir
        self.label_dir = label_dir
        self.label_name = label_name
        self.image_name_list = os.listdir(os.path.join(root_dir, target_dir))
    def label(self):
        for name in self.image_name_list:
            file_name = name.split(".jpg", 1)[0]
            label_path = os.path.join(self.root_dir, self.label_dir)
            if not os.path.exists(label_path):
                os.makedirs(label_path)
            with open(os.path.join(label_path, '{}'.format(file_name)), 'w') as f:
                f.write(self.label_name)
                f.close()
root_dir = 'G:\python_files\深度学习代码库\cats_and_dogs_small\\train'
target_dir = 'cats'
label_dir = 'cats_label'
label_name = 'cat'
label = MyLabelData(root_dir, target_dir, label_dir, label_name)
label.label()

这样上面的代码中的训练集目录下的每一个样本都会在train的cats_label目录下创建其对应的分类标签

每一个标签中文件中都有一个cat字符串或者其他动物的分类名称,以确定它到底是哪一个动物

3 tensorboard的使用

python 复制代码
# tensorboard --logdir=深度学习代码库/logs --port=2001
from torch.utils.tensorboard import SummaryWriter
writer = SummaryWriter('logs')
for i in range(100):
    writer.add_scalar('当前的函数表达式y=3*x',i*3,i)
writer.close()
#-----------------------------------------------------------
import numpy as np
from PIL import Image
image_PIL = Image.open('G:\python_files\深度学习代码库\cats_and_dogs_small\\train\cats\cat.1.jpg')
image_numpy = np.array(image_PIL)
print(type(image_numpy))
print(image_numpy.shape)
writer.add_image('cat图片', image_numpy,2, dataformats='HWC')

这里使用tensorboard的作用是为了更好的展示数据,但是对于函数的使用,比如上面的add_image中的参数,最好的方式是点击源码查看其对应的参数类型,然后根据实际需要将它所需的数据类型丢给add_image就好,而在源码中该函数的参数中所要求的图片类型必须是tensor类型或者是numpy,所以想要使用tensorboard展示数据就首先必须使用numpy或者使用transforms.Totensor将其转化为tensor,然后丢给add_image函数

还有一个需要注意的是,使用add_image函数,图片的tensor类型或者numpy类型必须和dataformats的默认数据类型一样,否则根据图片的数据类型修改后面的额dataformatas就好

4 transforms处理数据

tranfroms.Totensor的使用
python 复制代码
import numpy as np
from torchvision import transforms
from PIL import Image
tran = transforms.ToTensor()
PIL_image = Image.open('G:\python_files\深度学习代码库\\cats\cat\cat.11.jpg')
tensor_pic = tran(PIL_image)
print(tensor_pic)
print(tensor_pic.shape)
from torch.utils.tensorboard import SummaryWriter
write = SummaryWriter('logs')
write.add_image('Tensor_picture',tensor_pic)

tensor(\[\[0.9216, 0.9059, 0.8353, ..., 0.2392, 0.2275, 0.2078,

0.9765, 0.9216, 0.8118, ..., 0.2431, 0.2392, 0.2235,

0.9490, 0.8745, 0.7608, ..., 0.2471, 0.2471, 0.2314,

...,

0.3490, 0.4902, 0.6667, ..., 0.7804, 0.7804, 0.7804,

0.3412, 0.4431, 0.5216, ..., 0.7765, 0.7922, 0.7882,

0.3490, 0.4510, 0.5294, ..., 0.7765, 0.7922, 0.7882],

\[0.9451, 0.9294, 0.8706, ..., 0.2980, 0.2863, 0.2667,

1.0000, 0.9451, 0.8471, ..., 0.3020, 0.2980, 0.2824,

0.9725, 0.8980, 0.7961, ..., 0.2980, 0.2980, 0.2824,

...,

0.3725, 0.5137, 0.6902, ..., 0.8431, 0.8431, 0.8431,

0.3647, 0.4667, 0.5451, ..., 0.8392, 0.8549, 0.8510,

0.3608, 0.4627, 0.5412, ..., 0.8392, 0.8549, 0.8510],

\[0.9294, 0.9137, 0.8588, ..., 0.2235, 0.2118, 0.1922,

0.9922, 0.9373, 0.8353, ..., 0.2275, 0.2235, 0.2078,

0.9725, 0.8980, 0.7922, ..., 0.2275, 0.2275, 0.2118,

...,

0.4196, 0.5608, 0.7373, ..., 0.9412, 0.9412, 0.9333,

0.4196, 0.5216, 0.6000, ..., 0.9373, 0.9529, 0.9412,

0.4196, 0.5216, 0.6000, ..., 0.9373, 0.9529, 0.9412]])

torch.Size(3, 410, 431)

transforms.Normalize的使用
python 复制代码
# 对应三个通道,每一个通道一个平均值和方差
# output[channel] = (input[channel] - mean[channel]) / std[channel]
nor = transforms.Normalize([0.5, 0.5, 0.5],[10, 0.5, 0.5])
print(tensor_pic[0][0][0])
x_nor = nor(tensor_pic)
write.add_image('nor_picture:', x_nor)
print(tensor_pic[0][0][0])
write.close()

打开源码查看

复制代码
def forward(self, tensor: Tensor) -> Tensor:
    """
    Args:
        tensor (Tensor): Tensor image to be normalized.

    Returns:
        Tensor: Normalized Tensor image.
    """
    return F.normalize(tensor, self.mean, self.std, self.inplace)

必须传入的是tensor数据类型

transforms.Resize的使用
python 复制代码
size_tensor = transforms.Resize((512,512))
# 裁剪tensor
tensor_pic_size = size_tensor(tensor_pic)
# 裁剪Image
size_pic = transforms.Resize((512,512))
image_size = size_pic(PIL_image)
print(image_size)
write.add_image('tensor_pic_size',tensor_pic_size)
print(tensor_pic_size.shape)
np_image = np.array(image_size)
print('np_image.shape:', np_image.shape)
write.add_image('image_size', np_image, dataformats='HWC')

调用Resize的时候,需要传入的数据类型的要求,查看源码如下

复制代码
def forward(self, img):
    """
    Args:
        img (PIL Image or Tensor): Image to be scaled.

    Returns:
        PIL Image or Tensor: Rescaled image.
    """
    return F.resize(img, self.size, self.interpolation)

<PIL.Image.Image image mode=RGB size=512x512 at 0x1A72B1E7D00>

torch.Size(3, 512, 512)

np_image.shape: (512, 512, 3)

transforms.Compose使用
python 复制代码
nor = transforms.Normalize([0.5, 0.5, 0.5],[10, 0.5, 0.5])
trans_resize_2 = transforms.Resize((64,64))
trans_to_tensor = transforms.ToTensor()
trans_compose = transforms.Compose([trans_resize_2, trans_to_tensor])
tensor_pic_compose = trans_compose(PIL_image)
write.add_image('tensor_pic_compose',tensor_pic_compose,dataformats='CHW')
复制代码
class Compose:
    """Composes several transforms together. This transform does not support torchscript.
    Please, see the note below.

    Args:
        transforms (list of ``Transform`` objects): list of transforms to compose.

    Example:
        >>> transforms.Compose([
        >>>     transforms.CenterCrop(10),
        >>>     transforms.ToTensor(),
        >>> ])

    .. note::
        In order to script the transformations, please use ``torch.nn.Sequential`` as below.

        >>> transforms = torch.nn.Sequential(
        >>>     transforms.CenterCrop(10),
        >>>     transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
        >>> )
        >>> scripted_transforms = torch.jit.script(transforms)

        Make sure to use only scriptable transformations, i.e. that work with ``torch.Tensor``, does not require
        `lambda` functions or ``PIL.Image``.

    """

    def __init__(self, transforms):
        self.transforms = transforms

    def __call__(self, img):
        for t in self.transforms:
            img = t(img)
        return img

    def __repr__(self):
        format_string = self.__class__.__name__ + '('
        for t in self.transforms:
            format_string += '\n'
            format_string += '    {0}'.format(t)
        format_string += '\n)'
        return format_string

5 dataset_transforms使用

python 复制代码
from torch.utils.data import DataLoader
from torchvision import  transforms
import torchvision
data_transform = transforms.Compose([transforms.ToTensor()])
train_data = torchvision.datasets.CIFAR10('./data', train=True, download=True)
test_data = torchvision.datasets.CIFAR10('./data', train=False, download=True)
python 复制代码
print("train_data", train_data)
# 原始的数据集中每一条数据中包含以一张图片和该图片所属的类别
print("train_data[0]", train_data[0])
print("train_data.classes", train_data.classes)
image, label = train_data[0]
print("label ",label)
image.show()
print("train_data.classes[label]", train_data.classes[label])

train_data Dataset CIFAR10

Number of datapoints: 50000

Root location: ./data

Split: Train

train_data0 (<PIL.Image.Image image mode=RGB size=32x32 at 0x144ED58D970>, 6)

train_data.classes 'airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck'

label 6

train_data.classeslabel frog

python 复制代码
#%%
from torchvision import transforms
import torchvision
# 将整个数据集转化为tensor类型
data_transform1 = transforms.Compose([transforms.ToTensor()])
train_data = torchvision.datasets.CIFAR10('./data', train=True, transform=data_transform1, download=True)
test_data1 = torchvision.datasets.CIFAR10('./data', train=False, transform=data_transform1, download=True)
from torch.utils.tensorboard import SummaryWriter
write = SummaryWriter('batch_picture')
for i in range(10):
    tensor_pic, label = train_data[i]  # 经过前面的transforms成了tensor
    print(tensor_pic.shape)
    write.add_image('batch_picture', tensor_pic, i)
write.close()

Files already downloaded and verified

Files already downloaded and verified

torch.Size(3, 32, 32)

torch.Size(3, 32, 32)

torch.Size(3, 32, 32)

torch.Size(3, 32, 32)

torch.Size(3, 32, 32)

torch.Size(3, 32, 32)

torch.Size(3, 32, 32)

torch.Size(3, 32, 32)

torch.Size(3, 32, 32)

torch.Size(3, 32, 32)

复制代码
def add_image(self, tag, img_tensor, global_step=None, walltime=None, dataformats='CHW'):
    """Add image data to summary.

    Note that this requires the ``pillow`` package.

    Args:
        tag (string): Data identifier
        img_tensor (torch.Tensor, numpy.array, or string/blobname): Image data
        global_step (int): Global step value to record
        walltime (float): Optional override default walltime (time.time())
          seconds after epoch of event
    Shape:
        img_tensor: Default is :math:`(3, H, W)`. You can use ``torchvision.utils.make_grid()`` to
        convert a batch of tensor into 3xHxW format or call ``add_images`` and let us do the job.
        Tensor with :math:`(1, H, W)`, :math:`(H, W)`, :math:`(H, W, 3)` is also suitable as long as
        corresponding ``dataformats`` argument is passed, e.g. ``CHW``, ``HWC``, ``HW``.

    Examples::

        from torch.utils.tensorboard import SummaryWriter
        import numpy as np
        img = np.zeros((3, 100, 100))
        img[0] = np.arange(0, 10000).reshape(100, 100) / 10000
        img[1] = 1 - np.arange(0, 10000).reshape(100, 100) / 10000

        img_HWC = np.zeros((100, 100, 3))
        img_HWC[:, :, 0] = np.arange(0, 10000).reshape(100, 100) / 10000
        img_HWC[:, :, 1] = 1 - np.arange(0, 10000).reshape(100, 100) / 10000

        writer = SummaryWriter()
        writer.add_image('my_image', img, 0)

        # If you have non-default dimension setting, set the dataformats argument.
        writer.add_image('my_image_HWC', img_HWC, 0, dataformats='HWC')
        writer.close()

    Expected result:

    .. image:: _static/img/tensorboard/add_image.png
       :scale: 50 %

    """
    torch._C._log_api_usage_once("tensorboard.logging.add_image")
    if self._check_caffe2_blob(img_tensor):
        from caffe2.python import workspace
        img_tensor = workspace.FetchBlob(img_tensor)
    self._get_file_writer().add_summary(
        image(tag, img_tensor, dataformats=dataformats), global_step, walltime)
python 复制代码
from torchvision import transforms
import torchvision
# 将整个数据集转化为tensor类型
data_transform = transforms.Compose([transforms.ToTensor()])
train_data = torchvision.datasets.CIFAR10('./data', train=True, transform=data_transform, download=True)
test_data = torchvision.datasets.CIFAR10('./data', train=False, transform=data_transform, download=True)
# dataLoad会将原始数据中一个batch中的图片和图片的Label分别放在一起,形成对应
train_data_load = DataLoader(dataset=train_data, shuffle=True, batch_size=64,)
from torch.utils.tensorboard import SummaryWriter
write = SummaryWriter('dataLoad')
# 遍历整个load,一次遍历的图片是64个
for batch_id, data in enumerate(train_data_load):
    # 经过DataLoda之后,每一个批次返回一批图片和该图片对应的标签类别
    print('data',data)
    batch_image, batch_label = data
    print("batch_id",batch_id)
    print("image.shape", batch_image.shape)
    print("label.shape", batch_label.shape)
    write.add_images('batch_load_picture', batch_image, batch_id, dataformats='NCHW')
write.close()
复制代码
其中一个批次的输出结果展示
batch_id 646
image.shape torch.Size([64, 3, 32, 32])
label.shape torch.Size([64])
data [tensor([[[[0.2510, 0.3804, 0.5176,  ..., 0.5529, 0.5451, 0.2980],
          [0.2706, 0.6000, 0.6667,  ..., 0.5686, 0.3961, 0.1176],
          [0.2745, 0.6627, 0.6980,  ..., 0.3961, 0.1608, 0.0824],
          ...,
          [0.6863, 0.6824, 0.5333,  ..., 0.2941, 0.4863, 0.5059],
          [0.5804, 0.6784, 0.4902,  ..., 0.1451, 0.2824, 0.3451],
          [0.4353, 0.4353, 0.5098,  ..., 0.1373, 0.1529, 0.2902]],

         [[0.3020, 0.4549, 0.6078,  ..., 0.6627, 0.6353, 0.3608],
          [0.3451, 0.6980, 0.7765,  ..., 0.6745, 0.4706, 0.1647],
          [0.3490, 0.7529, 0.8039,  ..., 0.4667, 0.2000, 0.1137],
          ...,
          [0.8196, 0.8157, 0.6157,  ..., 0.3608, 0.5529, 0.5804],
          [0.7137, 0.8039, 0.5686,  ..., 0.1922, 0.3373, 0.4078],
          [0.5412, 0.5333, 0.5765,  ..., 0.1765, 0.2000, 0.3490]],

         [[0.3098, 0.5490, 0.7412,  ..., 0.8314, 0.7373, 0.3765],
          [0.3765, 0.8392, 0.9569,  ..., 0.7686, 0.4941, 0.1216],
          [0.3843, 0.9176, 1.0000,  ..., 0.4627, 0.1490, 0.0588],
          ...,
          [0.9843, 0.9922, 0.7373,  ..., 0.3882, 0.6353, 0.7255],
          [0.8039, 0.9373, 0.6745,  ..., 0.1804, 0.3647, 0.4941],
          [0.6471, 0.6549, 0.6980,  ..., 0.1569, 0.2000, 0.3961]]],


        [[[0.9608, 0.9490, 0.9529,  ..., 0.8314, 0.8196, 0.8235],
          [0.9255, 0.9216, 0.9333,  ..., 0.8275, 0.8196, 0.8235],
          [0.9137, 0.9137, 0.9294,  ..., 0.8392, 0.8314, 0.8353],
          ...,
          [0.4118, 0.4353, 0.4431,  ..., 0.4157, 0.4431, 0.4275],
          [0.4667, 0.4667, 0.4627,  ..., 0.3961, 0.3804, 0.3882],
          [0.4392, 0.4235, 0.4235,  ..., 0.5490, 0.4471, 0.4706]],

         [[0.9647, 0.9529, 0.9529,  ..., 0.8745, 0.8667, 0.8667],
          [0.9294, 0.9255, 0.9333,  ..., 0.8627, 0.8549, 0.8549],
          [0.9137, 0.9176, 0.9294,  ..., 0.8627, 0.8588, 0.8549],
          ...,
          [0.4196, 0.4392, 0.4471,  ..., 0.4314, 0.4627, 0.4510],
          [0.4745, 0.4745, 0.4706,  ..., 0.4078, 0.4039, 0.4118],
          [0.4471, 0.4314, 0.4314,  ..., 0.5608, 0.4667, 0.4863]],

         [[0.9765, 0.9686, 0.9647,  ..., 0.9412, 0.9373, 0.9569],
          [0.9451, 0.9412, 0.9529,  ..., 0.9216, 0.9216, 0.9373],
          [0.9451, 0.9451, 0.9569,  ..., 0.9176, 0.9176, 0.9333],
          ...,
          [0.4078, 0.4314, 0.4353,  ..., 0.4353, 0.4706, 0.4588],
          [0.4627, 0.4627, 0.4588,  ..., 0.4118, 0.4118, 0.4157],
          [0.4353, 0.4196, 0.4196,  ..., 0.5569, 0.4627, 0.4863]]],


        [[[0.9569, 0.9569, 0.9647,  ..., 0.8510, 0.8353, 0.8235],
          [0.9569, 0.9569, 0.9608,  ..., 0.8627, 0.8431, 0.8392],
          [0.9804, 0.9725, 0.9725,  ..., 0.8745, 0.8627, 0.8549],
          ...,
          [0.3725, 0.3882, 0.3922,  ..., 0.3647, 0.3725, 0.3686],
          [0.3882, 0.4000, 0.4157,  ..., 0.3882, 0.3804, 0.3608],
          [0.3882, 0.4000, 0.4118,  ..., 0.3725, 0.3608, 0.3490]],

         [[0.9608, 0.9608, 0.9686,  ..., 0.8706, 0.8549, 0.8392],
          [0.9608, 0.9608, 0.9686,  ..., 0.8784, 0.8549, 0.8510],
          [0.9843, 0.9765, 0.9804,  ..., 0.8863, 0.8745, 0.8627],
          ...,
          [0.3804, 0.3922, 0.3961,  ..., 0.3255, 0.3529, 0.3686],
          [0.3961, 0.4078, 0.4235,  ..., 0.3647, 0.3686, 0.3647],
          [0.3961, 0.4078, 0.4196,  ..., 0.3843, 0.3686, 0.3569]],

         [[0.9843, 0.9765, 0.9804,  ..., 0.9294, 0.9176, 0.9137],
          [0.9804, 0.9686, 0.9725,  ..., 0.9216, 0.9059, 0.9098],
          [0.9961, 0.9804, 0.9765,  ..., 0.9137, 0.9098, 0.9098],
          ...,
          [0.3725, 0.3882, 0.3922,  ..., 0.2902, 0.3255, 0.3686],
          [0.3922, 0.4039, 0.4196,  ..., 0.3412, 0.3490, 0.3608],
          [0.3922, 0.4039, 0.4157,  ..., 0.3843, 0.3686, 0.3529]]],


        ...,


        [[[0.8902, 0.8863, 0.8824,  ..., 0.8314, 0.8392, 0.8353],
          [0.8902, 0.8863, 0.8863,  ..., 0.8353, 0.8431, 0.8392],
          [0.8902, 0.8863, 0.8902,  ..., 0.8392, 0.8431, 0.8431],
          ...,
          [0.9569, 0.9529, 0.9569,  ..., 0.5765, 0.5843, 0.5961],
          [0.9686, 0.9647, 0.9608,  ..., 0.9412, 0.9255, 0.9255],
          [0.9804, 0.9765, 0.9725,  ..., 0.9255, 0.9176, 0.9176]],

         [[0.9176, 0.9137, 0.9098,  ..., 0.8667, 0.8745, 0.8706],
          [0.9176, 0.9137, 0.9137,  ..., 0.8706, 0.8784, 0.8745],
          [0.9176, 0.9137, 0.9176,  ..., 0.8784, 0.8824, 0.8784],
          ...,
          [0.9608, 0.9569, 0.9608,  ..., 0.6392, 0.6667, 0.6706],
          [0.9765, 0.9725, 0.9647,  ..., 0.9608, 0.9765, 0.9725],
          [0.9882, 0.9843, 0.9804,  ..., 0.9255, 0.9451, 0.9490]],

         [[0.9412, 0.9373, 0.9333,  ..., 0.9255, 0.9333, 0.9294],
          [0.9412, 0.9373, 0.9373,  ..., 0.9294, 0.9373, 0.9333],
          [0.9412, 0.9373, 0.9412,  ..., 0.9294, 0.9333, 0.9333],
          ...,
          [0.9686, 0.9647, 0.9686,  ..., 0.6667, 0.6824, 0.6863],
          [0.9725, 0.9686, 0.9647,  ..., 0.9804, 0.9804, 0.9804],
          [0.9843, 0.9804, 0.9765,  ..., 0.9373, 0.9451, 0.9490]]],


        [[[0.1725, 0.1725, 0.1804,  ..., 0.1255, 0.1255, 0.1255],
          [0.1922, 0.1882, 0.1843,  ..., 0.1333, 0.1373, 0.1333],
          [0.1961, 0.1922, 0.1882,  ..., 0.1412, 0.1412, 0.1333],
          ...,
          [0.4471, 0.4902, 0.5137,  ..., 0.5647, 0.5725, 0.5961],
          [0.4431, 0.4706, 0.4824,  ..., 0.5608, 0.5529, 0.5569],
          [0.4275, 0.4431, 0.4392,  ..., 0.6078, 0.5608, 0.5176]],

         [[0.0980, 0.0980, 0.1059,  ..., 0.0353, 0.0353, 0.0392],
          [0.1137, 0.1137, 0.1098,  ..., 0.0431, 0.0471, 0.0471],
          [0.1216, 0.1176, 0.1137,  ..., 0.0549, 0.0549, 0.0549],
          ...,
          [0.2471, 0.2824, 0.3529,  ..., 0.5490, 0.5451, 0.5608],
          [0.2510, 0.2980, 0.3765,  ..., 0.5569, 0.5294, 0.5255],
          [0.2471, 0.3059, 0.3765,  ..., 0.6078, 0.5451, 0.4902]],

         [[0.0431, 0.0431, 0.0510,  ..., 0.0118, 0.0118, 0.0118],
          [0.0588, 0.0588, 0.0549,  ..., 0.0118, 0.0118, 0.0118],
          [0.0667, 0.0627, 0.0588,  ..., 0.0118, 0.0118, 0.0118],
          ...,
          [0.2431, 0.2745, 0.3176,  ..., 0.5373, 0.5608, 0.5804],
          [0.2510, 0.2824, 0.3294,  ..., 0.5490, 0.5412, 0.5412],
          [0.2510, 0.2863, 0.3216,  ..., 0.6000, 0.5529, 0.4980]]],


        [[[0.6353, 0.6314, 0.6314,  ..., 0.6157, 0.6157, 0.6157],
          [0.6353, 0.6314, 0.6314,  ..., 0.6157, 0.6157, 0.6157],
          [0.6353, 0.6314, 0.6314,  ..., 0.6157, 0.6157, 0.6157],
          ...,
          [0.6471, 0.6431, 0.6431,  ..., 0.6392, 0.6392, 0.6392],
          [0.6471, 0.6431, 0.6431,  ..., 0.6392, 0.6392, 0.6392],
          [0.6471, 0.6431, 0.6431,  ..., 0.6392, 0.6392, 0.6392]],

         [[0.7804, 0.7765, 0.7765,  ..., 0.7725, 0.7725, 0.7686],
          [0.7804, 0.7765, 0.7765,  ..., 0.7725, 0.7725, 0.7686],
          [0.7804, 0.7765, 0.7765,  ..., 0.7725, 0.7725, 0.7686],
          ...,
          [0.7922, 0.7882, 0.7882,  ..., 0.7843, 0.7843, 0.7843],
          [0.7922, 0.7882, 0.7882,  ..., 0.7843, 0.7843, 0.7843],
          [0.7922, 0.7882, 0.7882,  ..., 0.7843, 0.7843, 0.7843]],

         [[0.9882, 0.9804, 0.9843,  ..., 0.9765, 0.9765, 0.9765],
          [0.9882, 0.9804, 0.9843,  ..., 0.9765, 0.9765, 0.9765],
          [0.9882, 0.9804, 0.9843,  ..., 0.9765, 0.9765, 0.9765],
          ...,
          [0.9961, 0.9882, 0.9922,  ..., 0.9882, 0.9882, 0.9882],
          [0.9961, 0.9882, 0.9922,  ..., 0.9882, 0.9882, 0.9882],
          [0.9961, 0.9882, 0.9922,  ..., 0.9882, 0.9882, 0.9882]]]]), tensor([2, 8, 9, 6, 9, 3, 8, 3, 7, 7, 7, 3, 9, 2, 3, 1, 0, 1, 9, 6, 7, 6, 7, 9,
        1, 1, 8, 9, 2, 7, 5, 0, 1, 5, 9, 4, 2, 5, 7, 6, 3, 2, 2, 9, 4, 2, 1, 1,
        9, 5, 2, 5, 0, 8, 1, 7, 3, 5, 8, 0, 5, 0, 5, 0])]

使用add_images对所有批次的数据进行展示

复制代码
def add_images(self, tag, img_tensor, global_step=None, walltime=None, dataformats='NCHW'):
    """Add batched image data to summary.

    Note that this requires the ``pillow`` package.

    Args:
        tag (string): Data identifier
        img_tensor (torch.Tensor, numpy.array, or string/blobname): Image data
        global_step (int): Global step value to record
        walltime (float): Optional override default walltime (time.time())
          seconds after epoch of event
        dataformats (string): Image data format specification of the form
          NCHW, NHWC, CHW, HWC, HW, WH, etc.
    Shape:
        img_tensor: Default is :math:`(N, 3, H, W)`. If ``dataformats`` is specified, other shape will be
        accepted. e.g. NCHW or NHWC.

    Examples::

        from torch.utils.tensorboard import SummaryWriter
        import numpy as np

        img_batch = np.zeros((16, 3, 100, 100))
        for i in range(16):
            img_batch[i, 0] = np.arange(0, 10000).reshape(100, 100) / 10000 / 16 * i
            img_batch[i, 1] = (1 - np.arange(0, 10000).reshape(100, 100) / 10000) / 16 * i

        writer = SummaryWriter()
        writer.add_images('my_image_batch', img_batch, 0)
        writer.close()

    Expected result:

    .. image:: _static/img/tensorboard/add_images.png
       :scale: 30 %

    """
    torch._C._log_api_usage_once("tensorboard.logging.add_images")
    if self._check_caffe2_blob(img_tensor):
        from caffe2.python import workspace
        img_tensor = workspace.FetchBlob(img_tensor)
    self._get_file_writer().add_summary(
        image(tag, img_tensor, dataformats=dataformats), global_step, walltime)

在使用add_images时要注意默认的通道数是3,如果经过卷积层以后的图片通道数大于3,那么是无法使用该函数进行显示的,会显示断言错误的信息,所以此时要使用torch.reshape将通道数变为3,然后可以正常调用

对于还未涉及的方法也是这样,查看其对应的参数类型(使用crtl+p,或者直接crtl+鼠标点击相应的函数查看源码),将所需要的参数类型丢给它使用就好

相关推荐
花酒锄作田3 小时前
[python]argparse 包在聊天机器人中的应用
python
久违 °5 小时前
【AI-Agent】TagMatrix 数据标注工具开发
人工智能·数据分析·go·agent·数据隐私
NiceCloud喜云5 小时前
Opus 4.8 的 Effort Control 怎么选:Low 到 Max 五档策略
android·java·大数据·前端·c++·python·spring
AI360labs_atyun5 小时前
腾讯推出电子牛马Marvis,好用吗?
人工智能·科技·ai
Dfreedom.5 小时前
Windows、虚拟机、开发板组网通信原理及调试通联步骤
人工智能·windows·部署·边缘计算·开发板·模型加速
3DVisionary5 小时前
蓝光三维扫描:医疗制造的精度焦虑怎么解
人工智能·算法·制造·蓝光三维扫描·医疗制造·三维检测·义齿检测
Are_You_Okkk_5 小时前
基于MonkeyCode解析AI研发新模式,根治开发低效痛点
大数据·人工智能·开源·ai编程
AI玫瑰助手5 小时前
Python函数:默认参数的定义与注意事项
开发语言·python·信息可视化
好评笔记6 小时前
机器学习面试八股——常用损失函数
人工智能·深度学习·算法·机器学习·校招
weixin_468466856 小时前
全局与局部注意力机制新手实战指南
人工智能·python·深度学习·算法·自然语言处理·transformer·注意力机制