「小土堆」pytorch DataSet

「小土堆」pytorch DataSet

python 复制代码
from cProfile import label

from torch.utils.data import Dataset
from PIL import Image
import os

class MyData(Dataset):
    def __init__(self, root_dir, label_dir):

        # root_dir = "hymenoptera_data/train"
        # label_dir = "ants_img"
        # 这两个值是由后面的实例传递过来的
        self.root_dir = root_dir
        self.label_dir = label_dir

        # 将其整合
        self.path = os.path.join(root_dir, label_dir)

        # os.listdir() 方法用于返回指定的文件夹包含的文件或文件夹的名字的列表。
        self.img_name_path = os.listdir(self.path)

    def __getitem__(self, index):

        img_name = self.img_name_path[index]
        img_item_path = os.path.join(self.path, img_name)
        img = Image.open(img_item_path)

        label = self.label_dir
        # 返回的是图片和名字
        return img, label

    def __len__(self):
        return len(self.img_path)


ants_dataset = MyData('hymenoptera_data/train', 'ants_img')
# for item in ants_dataset.img_path:
#     print(item)
# print(ants_dataset.__len__())
# print(ants_dataset.__getitem__(0))


bees_dataset = MyData('hymenoptera_data/train', 'bees_img')
# print(bees_dataset.__len__())
# print(bees_dataset.__getitem__(0))
#
# print(len(ants_dataset+bees_dataset))

# ant_dataset中包含两个值,一个img一个label
img, label = ants_dataset[0]
img.show()

​ 视频中一开始是先写class的以至于一开始没有弄懂 'root_dir' 和 'label_dir' 是干什么的,在创建实例之后进行传参就可以很好的理解了,前者是指文件夹的路径,后者是文件夹下的分类,由于文件夹下面分别有两个类别的例子,所以分为root和label两类。

dataset提供了访问和处理大量自然语言处理(NLP)数据集的工具,简单来说就是对数据集中的图片进行操作的一个简单的库。

python 复制代码
 def __init__(self, root_dir, label_dir):

        # root_dir = "hymenoptera_data/train"
        # label_dir = "ants_img"
        # 这两个值是由后面的实例传递过来的
        self.root_dir = root_dir
        self.label_dir = label_dir

        # 将其整合
        self.path = os.path.join(root_dir, label_dir)

        # os.listdir() 方法用于返回指定的文件夹包含的文件或文件夹的名字的列表。
        self.img_name_path = os.listdir(self.path)

​ 上述代码中从上到下来看,首先是MyData库,其继承了Dataset这个类,第一个函数就是对数据的初始化,可以理解成java中的构造器一样的功能。

​ os.path.join()此函数是将路径整合在一起赋值给self.path

​ os.listdir()此函数是返回self.path路径下包含的文件夹或文件夹的名字的列表,重点是它返回的是一个列表,这个列表中包含了文件夹下面的文件的信息

python 复制代码
    def __getitem__(self, index):

        img_name = self.img_name_path[index]
        img_item_path = os.path.join(self.path, img_name)
        img = Image.open(img_item_path)

        label = self.label_dir
        # 返回的是图片和名字
        return img, label

    def __len__(self):
        return len(self.img_path)

​ getitem这个函数的功能主要是对列表中的图片信息进行整合和赋值

相关推荐
陈鋆10 分钟前
智慧城市初探与解决方案
人工智能·智慧城市
qdprobot10 分钟前
ESP32桌面天气摆件加文心一言AI大模型对话Mixly图形化编程STEAM创客教育
网络·人工智能·百度·文心一言·arduino
QQ395753323711 分钟前
金融量化交易模型的突破与前景分析
人工智能·金融
QQ395753323711 分钟前
金融量化交易:技术突破与模型优化
人工智能·金融
The_Ticker24 分钟前
CFD平台如何接入实时行情源
java·大数据·数据库·人工智能·算法·区块链·软件工程
Elastic 中国社区官方博客30 分钟前
Elasticsearch 开放推理 API 增加了对 IBM watsonx.ai Slate 嵌入模型的支持
大数据·数据库·人工智能·elasticsearch·搜索引擎·ai·全文检索
jwolf230 分钟前
摸一下elasticsearch8的AI能力:语义搜索/vector向量搜索案例
人工智能·搜索引擎
有Li39 分钟前
跨视角差异-依赖网络用于体积医学图像分割|文献速递-生成式模型与transformer在医学影像中的应用
人工智能·计算机视觉
傻啦嘿哟42 分钟前
如何使用 Python 开发一个简单的文本数据转换为 Excel 工具
开发语言·python·excel
B站计算机毕业设计超人1 小时前
计算机毕业设计SparkStreaming+Kafka旅游推荐系统 旅游景点客流量预测 旅游可视化 旅游大数据 Hive数据仓库 机器学习 深度学习
大数据·数据仓库·hadoop·python·kafka·课程设计·数据可视化