# 模型打印每层shape — print the output shape of every layer of the model
import torch
from models.enet import ENet
from torchsummary import summary #pip install torchsummary
# Build ENet, load its trained weights from ./save/ENet, and print a per-layer
# summary (output shape and parameter count of each layer) with torchsummary.
NUM_CLASSES = 12  # argument to ENet(); presumably the number of output classes — confirm
INPUT_SIZE = (3, 256, 256)  # per-sample input shape handed to summary(); it adds the batch dim

model = ENet(NUM_CLASSES)
# map_location="cpu" lets a checkpoint saved from a GPU be loaded on any machine.
checkpoint = torch.load("./save/ENet", map_location="cpu")
# The checkpoint stores the model weights under the "state_dict" key.
model.load_state_dict(checkpoint["state_dict"])
# Run the dummy forward pass in inference mode so any BatchNorm/Dropout layers
# (if ENet has them) do not update running statistics or drop activations.
model.eval()
# torchsummary defaults to device="cuda": on a CUDA-capable machine it would feed
# CUDA tensors to this CPU-resident model and the forward pass would fail with an
# input/weight type mismatch. Keep everything on the CPU, matching map_location.
summary(model, INPUT_SIZE, device="cpu")