Pytorch使用torch.utils.data.random_split拆分数据集,拆分后的数据集状况

对于这个API,我最开始的预想是从 '猫1猫2猫3猫4狗1狗2狗3狗4' 中分割出 '猫1猫2狗4狗1' 和 '猫4猫3狗2狗3' ,但是打印结果和我预想的不一样

数据集文件的存放路径如下图

测试代码如下

python 复制代码
import torch
import torchvision

transform = torchvision.transforms.Compose([
    torchvision.transforms.Resize((512,512)),  # 调整图像大小为 224x224
    torchvision.transforms.ToTensor(),  # 转换为张量
    torchvision.transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])  # 归一化
])
dataset = torchvision.datasets.ImageFolder('C:\\Users\\ASUS\\PycharmProjects\\pythonProject1\\cats_and_dogs_train',
                                                 transform=transform)

val_ratio = 0.2
val_size = int(len(dataset) * val_ratio)
train_size = len(dataset) - val_size
train_dataset, val_dataset = torch.utils.data.random_split(dataset, [train_size, val_size])


cats_num = 0
dogs_num = 0
for x,y in train_dataset:
    if y == 0:
        cats_num += 1
    else:
        dogs_num += 1

print("cats_num: ",cats_num)
print("dogs_num: ",dogs_num)

cats_num2 = 0
dogs_num2 = 0
for x,y in val_dataset:
    if y == 0:
        cats_num2 += 1
    else:
        dogs_num2 += 1

print("cats_num2: ",cats_num2)
print("dogs_num2: ",dogs_num2)

输出如下

可以看到总共25000张图片的数据集,分割后并不是cats_num:10000,dogs_num:10000,cats_num2:2500,dogs_num2:2500

也就是说,分割后的状况是猫狗的数量并不一定相等,如结果为 '猫1猫2猫4狗1' 和 '狗4猫3狗2狗3'

相关推荐
www.02几秒前
Linux 终端守护神 Tmux :如何优雅地管理后台实验与恢复会话
linux·运维·服务器·人工智能·tmux
Agent手记5 分钟前
制造业物流延迟预警系统,从0到1落地实操指南 | 企业级AI Agent架构实战
人工智能·ai
u0110225127 分钟前
如何自定义查询历史记录面板的展示风格_时间轴样式设计
jvm·数据库·python
2301_769340679 分钟前
HTML怎么实现快捷跳转顶部_HTML固定悬浮锚点按钮【介绍】
jvm·数据库·python
老马952716 分钟前
opencode7-桌面应用实战2
java·人工智能·后端
DogDaoDao16 分钟前
【GitHub】Ruflo:面向 Claude Code 的企业级多智能体编排平台深度解析
人工智能·深度学习·大模型·github·ai编程·claude·ruflo
yuanpan18 分钟前
Python + PyAutoGUI 实战:Windows 自动化办公脚本开发入门
windows·python·自动化
m0_6091604921 分钟前
MySQL如何限制触发器递归调用的深度_防止触发器死循环方法
jvm·数据库·python
璞华Purvar22 分钟前
2026化工新材料PLM行业白皮书:璞华易研,以垂直深耕重构研发数智底座
人工智能
zhonghaoxincekj22 分钟前
轴距可调式元器件双边无损成形钳
经验分享·科技·深度学习·学习·测试工具·创业创新·制造