Pytorch(三)

一、经典网络架构图像分类模型

数据预处理部分:

  • 数据增强
  • 数据预处理
  • DataLoader模块直接读取batch数据

网络模块设置:

  • 加载预训练模型,torchvision中有很多经典网络架构,可以直接调用
  • 注意别人训练好的任务跟咱们的并不完全一样,需要把最后的head层改一改,一般也就是最后的全连接层,改成自己的任务
  • 续联时可以全部重头训练,也可以只训练最后咱们任务的层,因为前几层都是做特征提取的,本质任务目标是一致的

网络模型保存与测试:

  • 模型保存的时候可以带有选择性,例如在验证集中如果当前效果好则保存
  • 读取模型进行实际测试

二、迁移学习

利用别人训练好的模型来训练自己的模型

注:两种物体尽可能相似

迁移学习网站: Start Locally | PyTorch

三、花图像分类案例

未完结

python 复制代码
#数据读取与预处理操作
data_dir = './a/'
# 训练集
train_dir = data_dir + '/train'
#验证集
valid_ir = data_dir + '/valid'

#制作数据源
data_transfroms = {
    'train':transforms.Compose([transforms.RandomRotation(45), #随机旋转(-45~45)
    transforms.CenterCrop(224), #从中心开始裁剪
    transforms.RandomHorizontalFlip(p = 0.5), #随机水平翻转
    transforms.RandomVerticalFlip(p = 0.5), #随机垂直翻转
    transforms.ColorJitter(brightness=0.2,contrast=0.1,saturation=0.1,hue = 0.1),
    transforms.RandomGrayscale(p = 0.025), #概率转换成灰度率,3通道就是R=G=B
    transforms.ToTensor(),
    transforms.Normalize([0.485,0.456,0.406],[0.229,0.224,0.225])
    ]),
    'valid':transforms.Compose([transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize([0.485,0.456,0.406],[0.229,0.224,0.225])
    ]),
}

#batch数据制作
batch_size = 8
image_datasets = {x: datasets.ImageFolder(os.path.join(data_dir,x),data_transfroms[x]) for x in ['train','valid']}
dataloaders = {x: torch.utils.data.DataLoader(image_datasets[x],batch_size = batch_size,shuffle = True) for x in ['train','valid']}
dataset_sizes = {x: len(image_datasets[x]) for x in ['train','valid']}
class_names = image_datasets['train'].classes


#读取标签对应的实际名字
with open('cat_to_name.json','r') as f:
    cat_to_name = json.load(f)

#加载model中提供的模型,并且直接用训练好的权重当做初始化参数
model_name = 'resnet'
#是否用人家训练好的特征来做
feature_extract = True

#是否用GPU来训练
train_on_gpu = torch.cuda.is_available()

if not train_on_gpu:
    print('cuda is not available. Training on CPU')
else:
    print('cuda is available. Training on GPU')

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

def set_parameter_requires_grad(model,feature_extracting):
    if feature_extracting:
        for param in model.parameter():
            param.requires_grad = False

model_ft = models.resnet152()
相关推荐
꧁꫞꯭零꯭点꯭꫞꧂5 分钟前
LangChain 提示词模板与链式调用笔记
人工智能·笔记·langchain
xingyuzhisuan7 分钟前
从零精通GPU服务器模型部署:标准化流程与性能调优实战
运维·服务器·人工智能
一起聊电气14 分钟前
告别盲目制冷!AI空调自控,让建筑自主呼吸、按需耗能
人工智能
java1234_小锋15 分钟前
什么是 RAG(检索增强生成)?请简述 Spring AI 实现 RAG 的完整流程,包括涉及的核心组件。
java·人工智能·spring·rag
小真zzz15 分钟前
9.8分登顶:搜极星如何以绝对中立与专业,定义AI时代品牌洞察新范式
大数据·人工智能·搜索引擎·ai
weixin_3975740918 分钟前
Agent推理可视化打破AI黑盒,让思考过程透明可见
人工智能
Saniffer_SH22 分钟前
【每日一题】不只是点亮画面:UniGraf 如何把 HDMI/DP 接口问题拆成可定位、可复现、可自动化验证的测试流程?
运维·人工智能·测试工具·fpga开发·性能优化·自动化·压力测试
ai产品老杨25 分钟前
解耦异构算力与多协议接入:基于 Docker 与 GB28181 的企业级 AI 视频管理平台架构演进与源码交付实践
人工智能·docker·音视频
郑寿昌25 分钟前
2026 全球 AI 工厂市场格局与发展趋势
大数据·人工智能·microsoft
HackTwoHub28 分钟前
AI赋能Chrome MCP × JS逆向Skill自动化JS逆向挖洞
javascript·人工智能·chrome·安全·web安全·网络安全·自动化