数据集相关类代码回顾理解 | StratifiedShuffleSplit\transforms.ToTensor\Counter

【PyTorch】图像多分类项目

目录

StratifiedShuffleSplit

transforms.ToTensor

Counter


StratifiedShuffleSplit

复制代码
sss = StratifiedShuffleSplit(n_splits=1, test_size=0.2, random_state=0)

创建StratifiedShuffleSplit对象,用于将数据集划分为训练集和测试集。

  • n_splits=1:划分次数为1,大于1则多次划分,每次划分生成一组新训练集和新测试集。
  • test_size=0.2:测试集比例为0.2,即测试集的大小占总样本的20%
  • random_state=0:随机种子为0,类似random的种子,保证每次抽样到的数据一样

StratifiedShuffleSplit是scikit-learn库中的一个类,用于创建训练集和测试集的划分,同时保持每个类别中的样本比例一致。核心思想:分层抽样。

StratifiedShuffleSplit 类的工作原理:

先根据每个类别的样本数量将数据集划分为尽可能相等的子集(分层)

然后在这些子集中随机选择样本拆分创建训练集和测试集(随机拆分)

插入空格更好理解:Stratified Shuffle Split分层随机拆分类!

transforms.ToTensor

复制代码
data_transformer = transforms.Compose([transforms.ToTensor()])

transforms.ToTensor()的作用是将PIL图像或NumPy数组转换为PyTorch张量,并且将图像的像素值从[0, 255]范围缩放到[0.0, 1.0]范围,即在[0.0, 1.0]范围内对像素值进行归一化。转换后的张量形状为(C, H, W)

Compose是 torchvision.transforms 模块的一个类,创建一个Compose对象时,需要传入一个包含一个或多个变换操作的列表。Compose对象一般包含四个变换操作:调整图像大小、从中心裁剪图像、将图像转换为张量以及归一化。

Counter

复制代码
counter_train=collections.Counter(y_train)

用于统计图像标签,即每类标签图像数量,Counter是用于计数的子类字典。例如PyTorch torchvision包中STL-10数据集的训练数据集:

相关推荐
Learn-Python2 小时前
MongoDB-only方法
python·sql
小途软件2 小时前
用于机器人电池电量预测的Sarsa强化学习混合集成方法
java·人工智能·pytorch·python·深度学习·语言模型
扫地的小何尚3 小时前
NVIDIA RTX PC开源AI工具升级:加速LLM和扩散模型的性能革命
人工智能·python·算法·开源·nvidia·1024程序员节
wanglei2007083 小时前
生产者消费者
开发语言·python
清水白石0084 小时前
《从零到进阶:Pydantic v1 与 v2 的核心差异与零成本校验实现原理》
数据库·python
昵称已被吞噬~‘(*@﹏@*)’~4 小时前
【RL+空战】学习记录03:基于JSBSim构造简易空空导弹模型,并结合python接口调用测试
开发语言·人工智能·python·学习·深度强化学习·jsbsim·空战
2501_941877984 小时前
从配置热更新到运行时自适应的互联网工程语法演进与多语言实践随笔分享
开发语言·前端·python
酩酊仙人4 小时前
fastmcp构建mcp server和client
python·ai·mcp
且去填词5 小时前
DeepSeek API 深度解析:从流式输出、Function Calling 到构建拥有“手脚”的 AI 应用
人工智能·python·语言模型·llm·agent·deepseek
rgeshfgreh5 小时前
Python条件与循环实战指南
python