Every so often I forget whether ToTensor transposes its input, so I am writing down here what ToTensor actually does.
What ToTensor does
This transform converts the input (a numpy.ndarray or a PIL Image) to a tensor: the shape goes from [H, W, C] to [C, H, W], and the value range goes from [0, 255] to [0.0, 1.0]. The official docstring spells out exactly when the scaling is applied:
"""Convert a ``PIL Image`` or ``numpy.ndarray`` to tensor. This transform does not support torchscript.
Converts a PIL Image or numpy.ndarray (H x W x C) in the range
[0, 255] to a torch.FloatTensor of shape (C x H x W) in the range [0.0, 1.0]
if the PIL Image belongs to one of the modes (L, LA, P, I, F, RGB, YCbCr, RGBA, CMYK, 1)
or if the numpy.ndarray has dtype = np.uint8
In the other cases, tensors are returned without scaling.
.. note::
Because the input image is scaled to [0.0, 1.0], this transformation should not be used when transforming target image masks. See the `references`_ for implementing the transforms for image masks.
.. _references: https://github.com/pytorch/vision/tree/main/references/segmentation
"""
Code examples
```python
def to_tensor(pic) -> Tensor:
    """Convert a ``PIL Image`` or ``numpy.ndarray`` to tensor.
    This function does not support torchscript.

    See :class:`~torchvision.transforms.ToTensor` for more details.

    Args:
        pic (PIL Image or numpy.ndarray): Image to be converted to tensor.

    Returns:
        Tensor: Converted image.
    """
    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        _log_api_usage_once(to_tensor)
    # check that pic is a PIL Image or a numpy.ndarray
    if not (F_pil._is_pil_image(pic) or _is_numpy(pic)):
        raise TypeError(f"pic should be PIL Image or ndarray. Got {type(pic)}")

    if _is_numpy(pic) and not _is_numpy_image(pic):
        raise ValueError(f"pic should be 2/3 dimensional. Got {pic.ndim} dimensions.")

    default_float_dtype = torch.get_default_dtype()

    # pic is a np.ndarray
    if isinstance(pic, np.ndarray):
        # a 2D (grayscale) array gets an extra channel dimension
        if pic.ndim == 2:
            pic = pic[:, :, None]

        # [H, W, C] --> [C, H, W]
        img = torch.from_numpy(pic.transpose((2, 0, 1))).contiguous()
        # backward compatibility: scale uint8 values into [0, 1] by dividing by 255
        if isinstance(img, torch.ByteTensor):
            return img.to(dtype=default_float_dtype).div(255)
        else:
            return img

    if accimage is not None and isinstance(pic, accimage.Image):
        nppic = np.zeros([pic.channels, pic.height, pic.width], dtype=np.float32)
        pic.copyto(nppic)
        return torch.from_numpy(nppic).to(dtype=default_float_dtype)

    # handle PIL Image
    mode_to_nptype = {"I": np.int32, "I;16": np.int16, "F": np.float32}
    img = torch.from_numpy(np.array(pic, mode_to_nptype.get(pic.mode, np.uint8), copy=True))

    if pic.mode == "1":
        img = 255 * img
    img = img.view(pic.size[1], pic.size[0], len(pic.getbands()))
    # put it from HWC to CHW format
    img = img.permute((2, 0, 1)).contiguous()
    if isinstance(img, torch.ByteTensor):
        return img.to(dtype=default_float_dtype).div(255)
    else:
        return img
```
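Because of that division by 255, ToTensor should not be applied to segmentation masks, whose pixel values are class indices rather than intensities. A sketch of the usual workaround using torchvision's PILToTensor, which converts without scaling (the mask below is made up for illustration):

```python
import numpy as np
from PIL import Image
from torchvision.transforms import PILToTensor, ToTensor

# a fake segmentation mask whose pixel values are class indices 0..3
mask = Image.fromarray(np.random.randint(0, 4, (8, 8), dtype=np.uint8))

print(ToTensor()(mask).unique())     # floats in [0, 1]; class ids are destroyed
print(PILToTensor()(mask).unique())  # uint8 tensor; class ids preserved
```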
```python
from torchvision.transforms import ToTensor
import numpy as np
x = np.arange(0, 81, dtype="uint8").reshape(9, 9)
print(x.shape)
print(x)
x_tensor = ToTensor()(x)
print(x_tensor.shape)
print(x_tensor)
"""
(9, 9)
[[ 0 1 2 3 4 5 6 7 8]
[ 9 10 11 12 13 14 15 16 17]
[18 19 20 21 22 23 24 25 26]
[27 28 29 30 31 32 33 34 35]
[36 37 38 39 40 41 42 43 44]
[45 46 47 48 49 50 51 52 53]
[54 55 56 57 58 59 60 61 62]
[63 64 65 66 67 68 69 70 71]
[72 73 74 75 76 77 78 79 80]]
torch.Size([1, 9, 9])
tensor([[[0.0000, 0.0039, 0.0078, 0.0118, 0.0157, 0.0196, 0.0235, 0.0275,
0.0314],
[0.0353, 0.0392, 0.0431, 0.0471, 0.0510, 0.0549, 0.0588, 0.0627,
0.0667],
[0.0706, 0.0745, 0.0784, 0.0824, 0.0863, 0.0902, 0.0941, 0.0980,
0.1020],
[0.1059, 0.1098, 0.1137, 0.1176, 0.1216, 0.1255, 0.1294, 0.1333,
0.1373],
[0.1412, 0.1451, 0.1490, 0.1529, 0.1569, 0.1608, 0.1647, 0.1686,
0.1725],
[0.1765, 0.1804, 0.1843, 0.1882, 0.1922, 0.1961, 0.2000, 0.2039,
0.2078],
[0.2118, 0.2157, 0.2196, 0.2235, 0.2275, 0.2314, 0.2353, 0.2392,
0.2431],
[0.2471, 0.2510, 0.2549, 0.2588, 0.2627, 0.2667, 0.2706, 0.2745,
0.2784],
[0.2824, 0.2863, 0.2902, 0.2941, 0.2980, 0.3020, 0.3059, 0.3098,
0.3137]]])
"""