数据集相关类代码回顾理解 | StratifiedShuffleSplit\transforms.ToTensor\Counter

【PyTorch】图像多分类项目

目录

StratifiedShuffleSplit

transforms.ToTensor

Counter


StratifiedShuffleSplit

复制代码
sss = StratifiedShuffleSplit(n_splits=1, test_size=0.2, random_state=0)

创建StratifiedShuffleSplit对象,用于将数据集划分为训练集和测试集。

  • n_splits=1:划分次数为1,大于1则多次划分,每次划分生成一组新训练集和新测试集。
  • test_size=0.2:测试集比例为0.2,即测试集的大小占总样本的20%
  • random_state=0:随机种子为0,类似random的种子,保证每次抽样到的数据一样

StratifiedShuffleSplit是scikit-learn库中的一个类,用于创建训练集和测试集的划分,同时保持每个类别中的样本比例一致。核心思想:分层抽样。

StratifiedShuffleSplit 类的工作原理:

先根据每个类别的样本数量将数据集划分为尽可能相等的子集(分层)

然后在这些子集中随机选择样本拆分创建训练集和测试集(随机拆分)

插入空格更好理解:Stratified Shuffle Split分层随机拆分类!

transforms.ToTensor

复制代码
data_transformer = transforms.Compose([transforms.ToTensor()])

transforms.ToTensor()的作用是将PIL图像或NumPy数组转换为PyTorch张量,并且将图像的像素值从[0, 255]范围缩放到[0.0, 1.0]范围,即在[0.0, 1.0]范围内对像素值进行归一化。转换后的张量形状为(C, H, W)

Compose是 torchvision.transforms 模块的一个类,创建一个Compose对象时,需要传入一个包含一个或多个变换操作的列表。Compose对象一般包含四个变换操作:调整图像大小、从中心裁剪图像、将图像转换为张量以及归一化。

Counter

复制代码
counter_train=collections.Counter(y_train)

用于统计图像标签,即每类标签图像数量,Counter是用于计数的子类字典。例如PyTorch torchvision包中STL-10数据集的训练数据集:

相关推荐
阿尔的代码屋2 小时前
[大模型实战 07] 基于 LlamaIndex ReAct 框架手搓全自动博客监控 Agent
人工智能·python
AI探索者20 小时前
LangGraph StateGraph 实战:状态机聊天机器人构建指南
python
AI探索者20 小时前
LangGraph 入门:构建带记忆功能的天气查询 Agent
python
FishCoderh21 小时前
Python自动化办公实战:批量重命名文件,告别手动操作
python
躺平大鹅21 小时前
Python函数入门详解(定义+调用+参数)
python
曲幽1 天前
我用FastAPI接ollama大模型,差点被asyncio整崩溃(附对话窗口实战)
python·fastapi·web·async·httpx·asyncio·ollama
两万五千个小时1 天前
落地实现 Anthropic Multi-Agent Research System
人工智能·python·架构
哈里谢顿1 天前
Python 高并发服务限流终极方案:从原理到生产落地(2026 实战指南)
python
用户8356290780512 天前
无需 Office:Python 批量转换 PPT 为图片
后端·python
markfeng82 天前
Python+Django+H5+MySQL项目搭建
python·django