【PyTorch】PyTorch之包装容器

文章目录


前言

介绍pytorch关于model的包装容器。

Containers

1. torch.nn.Sequential(arg: OrderedDict[str, Module])

Sequential是一个顺序容器。模块将按照它们在构造函数中传递的顺序添加到其中。另外,可以传递一个包含模块的 OrderedDict。

Sequential 的 forward() 方法接受任何输入,并将其转发到它包含的第一个模块。然后,对于每个后续模块,它将输出顺序链接到输入,最终返回最后一个模块的输出。

与手动调用一系列模块相比,Sequential 提供的价值在于它允许将整个容器视为单个模块,这样在 Sequential 上执行转换会应用于它存储的每个模块(它们各自是 Sequential 的已注册子模块)。

Sequential 和 torch.nn.ModuleList 之间有什么区别?ModuleList 正是其字面意思 - 用于存储模块的列表!另一方面,Sequential 中的层以级联的方式连接。

python 复制代码
# 使用 Sequential 创建一个小模型。当运行 model 时,
# 输入首先会传递给 Conv2d(1,20,5)。然后,Conv2d(1,20,5) 的输出将用作第一个 ReLU 的输入;
# 第一个 ReLU 的输出将成为 Conv2d(20,64,5) 的输入。最后,Conv2d(20,64,5) 的输出将用作第二个 # ReLU 的输入
model = nn.Sequential(
          nn.Conv2d(1,20,5),
          nn.ReLU(),
          nn.Conv2d(20,64,5),
          nn.ReLU()
        )

# 使用带有 OrderedDict 的 Sequential。这在功能上与上面的代码相同
model = nn.Sequential(OrderedDict([
          ('conv1', nn.Conv2d(1,20,5)),
          ('relu1', nn.ReLU()),
          ('conv2', nn.Conv2d(20,64,5)),
          ('relu2', nn.ReLU())
        ]))

Sequential类的方法:
append(module)
Parameters:
module (nn.Module) -- 要添加的module
Return type:
Sequential

将给定的模块追加到末尾。

2. torch.nn.ModuleList(modules=None)

Parameters:
modules (iterable, optional) -- 要添加的模块的可迭代对象。

python 复制代码
class MyModule(nn.Module):
    def __init__(self):
        super().__init__()
        self.linears = nn.ModuleList([nn.Linear(10, 10) for i in range(10)])

    def forward(self, x):
        # ModuleList can act as an iterable, or be indexed using ints
        for i, l in enumerate(self.linears):
            x = self.linears[i // 2](x) + l(x)
        return x

ModuleList的方法列表:

append(module)

Parameters:
module (nn.Module) -- 要添加的模块
Return type:
ModuleList

将给定的模块追加到列表的末尾。

extend(module)

Parameters:
modules (iterable) -- i要追加的模块的可迭代对象。
Return type:
ModuleList

从 Python 可迭代对象中追加模块到列表的末尾。

insert(index, module)

Parameters:
index (int) -- 要插入的索引。
module (nn.Module) -- 要插入的模块。

在列表的给定索引之前插入给定的模块。

3. torch.nn.ModuleDict(modules=None)

Parameters:
modules (iterable, optional) -- 一个映射(字典),格式为 (string: module) 或一个类型为 (string, module) 的键值对的可迭代对象。

ModuleDict 可以像常规的 Python 字典一样进行索引,但它包含的模块已经被正确注册,并且将被所有 Module 方法看到。

ModuleDict 是一个有序字典,它遵循以下顺序:

  1. 插入的顺序
  2. 在 update() 中,遵循合并的 OrderedDict、dict(从 Python 3.6 开始)或另一个
    ModuleDict(update() 的参数)

请注意,对于其他无序映射类型的 update()(例如,在 Python 版本 3.6 之前的普通 dict),不保留合并映射的顺序。

python 复制代码
class MyModule(nn.Module):
    def __init__(self):
        super().__init__()
        self.choices = nn.ModuleDict({
                'conv': nn.Conv2d(10, 10, 3),
                'pool': nn.MaxPool2d(3)
        })
        self.activations = nn.ModuleDict([
                ['lrelu', nn.LeakyReLU()],
                ['prelu', nn.PReLU()]
        ])

    def forward(self, x, choice, act):
        x = self.choices[choice](x)
        x = self.activations[act](x)
        return x

ModuleDict类的方法:

clear()

移除ModuleDict中的所有项目。

items()

Return type:
Iterable[Tuple[str, Module]]

返回ModuleDict键/值对的可迭代对象。

keys()

Return type:
Iterable[str]

返回ModuleDict键的可迭代对象。

pop(key)

Parameters:
key (str) -- key to pop from the ModuleDict
Return type:
Module

从ModuleDict中移除key并返回它的模块。

update(modules)

Parameters:
modules (iterable) -- 一个从字符串到 Module 的映射(字典),或者一个键值对类型为 (string, Module) 的可迭代对象。

使用映射或可迭代对象的键值对来更新 ModuleDict,覆盖现有的键。

注意:

如果 modules 是 OrderedDict、ModuleDict 或键值对的可迭代对象,其中新元素的顺序将被保留。

values()

Return type:
Iterable[Module]

返回ModuleDict值的可迭代对象。

4. torch.nn.ParameterList(values=None)

Parameters:
parameters (iterable, optional) -- 要添加到列表中的元素的可迭代对象.

ParameterList 可以像常规的 Python 列表一样使用,但作为 Parameter 的 Tensor 已经被正确注册,并且将被所有 Module 方法看到。

请注意,构造函数、分配列表的元素、append() 方法和 extend() 方法将把任何 Tensor 转换为 Parameter。

python 复制代码
class MyModule(nn.Module):
    def __init__(self):
        super().__init__()
        self.params = nn.ParameterList([nn.Parameter(torch.randn(10, 10)) for i in range(10)])

    def forward(self, x):
        # ParameterList can act as an iterable, or be indexed using ints
        for i, p in enumerate(self.params):
            x = self.params[i // 2].mm(x) + p.mm(x)
        return x

ParameterList类的方法:

append(value)

Parameters:
value (Any) -- 将要添加的value
Return type:
ParameterList

将给定的值追加到列表的末尾。

extend(values)

Parameters:
values (iterable) -- iterable of values to append
Return type:
ParameterList

将Python可迭代对象中的值附加到列表的末尾。

5. torch.nn.ParameterDict(parameters=None)

Parameters:
values (iterable, optional) -- 一个映射(字典),其元素为 (string : Any) 或一个类型为 (string, Any) 的键值对的可迭代对象。

ParameterDict 在一个字典中保存参数。

ParameterDict 可以像常规的 Python 字典一样进行索引,但它包含的 Parameters 已经被正确注册,并且将被所有 Module 方法看到。其他对象将被处理得像常规的 Python 字典一样。

ParameterDict 是一个有序字典。使用其他无序映射类型(例如 Python 的普通 dict)进行的 update() 操作不会保留合并映射的顺序。另一方面,OrderedDict 或另一个 ParameterDict 将保留它们的顺序。

请注意,构造函数、分配字典的元素和 update() 方法将把任何 Tensor 转换为 Parameter。

python 复制代码
class MyModule(nn.Module):
    def __init__(self):
        super().__init__()
        self.params = nn.ParameterDict({
                'left': nn.Parameter(torch.randn(5, 10)),
                'right': nn.Parameter(torch.randn(5, 10))
        })

    def forward(self, x, choice):
        x = self.params[choice].mm(x)
        return x

ParameterDict类的方法:

clear()

移除ParameterDict容器中的所有项目。

copy()

Return type:
ParameterDict

返回这个ParameterDict实例的副本。

fromkeys(keys, default=None)

Parameters:
keys (iterable, string) -- 用于创建新 ParameterDict 的键
default (Parameter, optional) -- 为所有键设置的值
Return type:
ParameterDict

返回一个带有参数keys的新的 ParameterDict。

get(key, default=None)

Parameters:
key (str) -- 要从 ParameterDict 获取的键
default (Parameter, optional) -- 如果键不存在时返回的值
Return type:
Any

如果存在与 key 关联的参数,则返回该参数。否则,如果提供了 default,则返回 default,如果没有提供则返回 None。

items()

Return type:
Iterable[Tuple[str, Any]]

返回 ParameterDict 键/值对的可迭代对象。

keys()

Return type:
Iterable[str]

返回 ParameterDict 的键的可迭代对象。

pop(key)

Parameters:
key (str) -- 要从 ParameterDict 弹出的键
Return type:
Any

从 ParameterDict 中删除键并返回其参数。

popitem()

Return type:
Tuple[str, Any]

从 ParameterDict 中删除并返回最后插入的(键,参数)对。

setdefault(key, default=None)

Parameters:
key (str) -- 要为其设置默认值的键
default (Any) -- 设置为键的参数
Return type:
Any

如果 key 在 ParameterDict 中,则返回其值。如果不在,则插入 key 并将其参数设置为 default,并返回 default。default 默认为 None。

update(parameters)

Parameters:
parameters (iterable) -- 从字符串到 Parameter 的映射(字典),或类型为(字符串,Parameter)的键值对的可迭代对象。

使用映射或可迭代对象中的键值对更新 ParameterDict,覆盖现有键。

注意:

如果 parameters 是 OrderedDict、ParameterDict 或键值对的可迭代对象,则其中新元素的顺序将被保留。

values()

Return type:
Iterable[Any]

返回 ParameterDict 值的可迭代对象。

相关推荐
大写-凌祁3 小时前
零基础入门深度学习:从理论到实战,GitHub+开源资源全指南(2025最新版)
人工智能·深度学习·开源·github
焦耳加热3 小时前
阿德莱德大学Nat. Commun.:盐模板策略实现废弃塑料到单原子催化剂的高值转化,推动环境与能源催化应用
人工智能·算法·机器学习·能源·材料工程
CodeCraft Studio3 小时前
PDF处理控件Aspose.PDF教程:使用 Python 将 PDF 转换为 Base64
开发语言·python·pdf·base64·aspose·aspose.pdf
深空数字孪生3 小时前
储能调峰新实践:智慧能源平台如何保障风电消纳与电网稳定?
大数据·人工智能·物联网
wan5555cn3 小时前
多张图片生成视频模型技术深度解析
人工智能·笔记·深度学习·算法·音视频
格林威4 小时前
机器视觉检测的光源基础知识及光源选型
人工智能·深度学习·数码相机·yolo·计算机视觉·视觉检测
困鲲鲲4 小时前
Python中内置装饰器
python
摩羯座-185690305945 小时前
Python数据可视化基础:使用Matplotlib绘制图表
大数据·python·信息可视化·matplotlib
今天也要学习吖5 小时前
谷歌nano banana官方Prompt模板发布,解锁六大图像生成风格
人工智能·学习·ai·prompt·nano banana·谷歌ai
Hello123网站5 小时前
glean-企业级AI搜索和知识发现平台
人工智能·产品运营·ai工具