模型预处理的ToTensor和Normalize
flyfish
py
import torch
import numpy as np
from torchvision import transforms
mean = (0.485, 0.456, 0.406)
std = (0.229, 0.224, 0.225)
# data0 =np.random.randint(0,255,size = [4,5,3],dtype='uint8')
# data0 = data0.astype(np.float64)
data0 = np.random.random((4, 5, 3)) # H x W x C
data0 = np.round(data0,4)
print(data0.shape)
print(data0)
data1 = transforms.ToTensor()(data0)
print(data1.shape) # C x H x W
print(data1)
data2 = transforms.Normalize(mean, std)(data1)
print(data2)
ToTensor
是数据维度发生变化H x W x C
变为 C x H x W
,数值没有变化
Normalize
是 (data - mean) / std
使用numpy实现验证
py
data1 = np.transpose(data0, (2, 0, 1))
print(data1.shape)
_std = np.array(std).reshape((3, 1, 1))
_mean = np.array(mean).reshape((3, 1, 1))
data2 = (data1 - _mean) / _std
print(data2)
原始数据的形状和内容 可以是图像的高度,宽度,通道
(4, 5, 3)
[[[0.8284 0.3419 0.6621]
[0.59 0.2306 0.4112]
[0.0636 0.406 0.2778]
[0.9551 0.2097 0.7681]
[0.3097 0.642 0.1968]]
[[0.722 0.9844 0.4942]
[0.1847 0.2435 0.3691]
[0.658 0.5643 0.9468]
[0.4002 0.7807 0.4393]
[0.2461 0.9049 0.0585]]
[[0.2606 0.067 0.6186]
[0.284 0.8524 0.2102]
[0.0447 0.0209 0.1313]
[0.0587 0.594 0.1016]
[0.6942 0.4514 0.7125]]
[[0.8787 0.7917 0.1181]
[0.9044 0.7948 0.3599]
[0.1706 0.7463 0.899 ]
[0.0758 0.2224 0.5447]
[0.3336 0.6096 0.3065]]]
ToTensor 后的形状和内容
torch.Size([3, 4, 5])
tensor([[[0.8284, 0.5900, 0.0636, 0.9551, 0.3097],
[0.7220, 0.1847, 0.6580, 0.4002, 0.2461],
[0.2606, 0.2840, 0.0447, 0.0587, 0.6942],
[0.8787, 0.9044, 0.1706, 0.0758, 0.3336]],
[[0.3419, 0.2306, 0.4060, 0.2097, 0.6420],
[0.9844, 0.2435, 0.5643, 0.7807, 0.9049],
[0.0670, 0.8524, 0.0209, 0.5940, 0.4514],
[0.7917, 0.7948, 0.7463, 0.2224, 0.6096]],
[[0.6621, 0.4112, 0.2778, 0.7681, 0.1968],
[0.4942, 0.3691, 0.9468, 0.4393, 0.0585],
[0.6186, 0.2102, 0.1313, 0.1016, 0.7125],
[0.1181, 0.3599, 0.8990, 0.5447, 0.3065]]], dtype=torch.float64)
Normalize 后的形状和内容
tensor([[[ 1.4996, 0.4585, -1.8402, 2.0528, -0.7655],
[ 1.0349, -1.3114, 0.7555, -0.3703, -1.0432],
[-0.9799, -0.8777, -1.9227, -1.8616, 0.9135],
[ 1.7192, 1.8314, -1.3729, -1.7869, -0.6611]],
[[-0.5094, -1.0063, -0.2232, -1.0996, 0.8304],
[ 2.3589, -0.9487, 0.4835, 1.4496, 2.0040],
[-1.7366, 1.7696, -1.9424, 0.6161, -0.0205],
[ 1.4987, 1.5125, 1.2960, -1.0429, 0.6857]],
[[ 1.1382, 0.0231, -0.5698, 1.6093, -0.9298],
[ 0.3920, -0.1640, 2.4036, 0.1480, -1.5444],
[ 0.9449, -0.8702, -1.2209, -1.3529, 1.3622],
[-1.2796, -0.2049, 2.1911, 0.6164, -0.4422]]], dtype=torch.float64)
使用numpy实现验证的结果
(3, 4, 5)
[[[ 1.49956332 0.45851528 -1.84017467 2.05283843 -0.76550218]
[ 1.0349345 -1.31135371 0.75545852 -0.37030568 -1.04323144]
[-0.97991266 -0.87772926 -1.92270742 -1.86157205 0.91353712]
[ 1.71921397 1.83144105 -1.37292576 -1.78689956 -0.66113537]]
[[-0.509375 -1.00625 -0.22321429 -1.09955357 0.83035714]
[ 2.35892857 -0.94866071 0.48348214 1.44955357 2.00401786]
[-1.73660714 1.76964286 -1.94241071 0.61607143 -0.02053571]
[ 1.49866071 1.5125 1.29598214 -1.04285714 0.68571429]]
[[ 1.13822222 0.02311111 -0.56977778 1.60933333 -0.92977778]
[ 0.392 -0.164 2.40355556 0.148 -1.54444444]
[ 0.94488889 -0.87022222 -1.22088889 -1.35288889 1.36222222]
[-1.27955556 -0.20488889 2.19111111 0.61644444 -0.44222222]]]
两者除了保留小数位数不同外,其他一致