python
import torch
test_data = torch.randn(1, 12, 32, 32)
batchsize, num_channels, height, width = test_data.data.size()
在提供的代码中,test_data
是一个形状为 (1, 12, 32, 32)
的随机张量,表示一个批次(batch)中有 1 张图像,每张图像有 12 个通道,图像的高度和宽度均为 32 像素。
注意事项
test_data.size()
返回一个包含张量各个维度大小的元组,可以直接解包到多个变量中。- 确保在调用
.size()
时,使用的是size()
而不是data.size()
,后者在新的 PyTorch 版本中已不推荐使用。
打印内容:
python
print(f"Batch Size: {batchsize}, Channels: {num_channels}, Height: {height}, Width: {width}")
#Batch Size: 1, Channels: 12, Height: 32, Width: 32