【PyTorch】PyTorch之包装容器

文章目录


前言

介绍pytorch关于model的包装容器。

Containers

1. torch.nn.Sequential(arg: OrderedDictstr, Module)

Sequential是一个顺序容器。模块将按照它们在构造函数中传递的顺序添加到其中。另外,可以传递一个包含模块的 OrderedDict。

Sequential 的 forward() 方法接受任何输入,并将其转发到它包含的第一个模块。然后,对于每个后续模块,它将输出顺序链接到输入,最终返回最后一个模块的输出。

与手动调用一系列模块相比,Sequential 提供的价值在于它允许将整个容器视为单个模块,这样在 Sequential 上执行转换会应用于它存储的每个模块(它们各自是 Sequential 的已注册子模块)。

Sequential 和 torch.nn.ModuleList 之间有什么区别?ModuleList 正是其字面意思 - 用于存储模块的列表!另一方面,Sequential 中的层以级联的方式连接。

python 复制代码
# 使用 Sequential 创建一个小模型。当运行 model 时,
# 输入首先会传递给 Conv2d(1,20,5)。然后,Conv2d(1,20,5) 的输出将用作第一个 ReLU 的输入;
# 第一个 ReLU 的输出将成为 Conv2d(20,64,5) 的输入。最后,Conv2d(20,64,5) 的输出将用作第二个 # ReLU 的输入
model = nn.Sequential(
          nn.Conv2d(1,20,5),
          nn.ReLU(),
          nn.Conv2d(20,64,5),
          nn.ReLU()
        )

# 使用带有 OrderedDict 的 Sequential。这在功能上与上面的代码相同
model = nn.Sequential(OrderedDict([
          ('conv1', nn.Conv2d(1,20,5)),
          ('relu1', nn.ReLU()),
          ('conv2', nn.Conv2d(20,64,5)),
          ('relu2', nn.ReLU())
        ]))

Sequential类的方法:
append(module)
Parameters:
module (nn.Module) -- 要添加的module
Return type:
Sequential

将给定的模块追加到末尾。

2. torch.nn.ModuleList(modules=None)

Parameters:
modules (iterable, optional) -- 要添加的模块的可迭代对象。

python 复制代码
class MyModule(nn.Module):
    def __init__(self):
        super().__init__()
        self.linears = nn.ModuleList([nn.Linear(10, 10) for i in range(10)])

    def forward(self, x):
        # ModuleList can act as an iterable, or be indexed using ints
        for i, l in enumerate(self.linears):
            x = self.linears[i // 2](x) + l(x)
        return x

ModuleList的方法列表:

append(module)

Parameters:
module (nn.Module) -- 要添加的模块
Return type:
ModuleList

将给定的模块追加到列表的末尾。

extend(module)

Parameters:
modules (iterable) -- i要追加的模块的可迭代对象。
Return type:
ModuleList

从 Python 可迭代对象中追加模块到列表的末尾。

insert(index, module)

Parameters:
index (int) -- 要插入的索引。
module (nn.Module) -- 要插入的模块。

在列表的给定索引之前插入给定的模块。

3. torch.nn.ModuleDict(modules=None)

Parameters:
modules (iterable, optional) -- 一个映射(字典),格式为 (string: module) 或一个类型为 (string, module) 的键值对的可迭代对象。

ModuleDict 可以像常规的 Python 字典一样进行索引,但它包含的模块已经被正确注册,并且将被所有 Module 方法看到。

ModuleDict 是一个有序字典,它遵循以下顺序:

  1. 插入的顺序
  2. 在 update() 中,遵循合并的 OrderedDict、dict(从 Python 3.6 开始)或另一个
    ModuleDict(update() 的参数)

请注意,对于其他无序映射类型的 update()(例如,在 Python 版本 3.6 之前的普通 dict),不保留合并映射的顺序。

python 复制代码
class MyModule(nn.Module):
    def __init__(self):
        super().__init__()
        self.choices = nn.ModuleDict({
                'conv': nn.Conv2d(10, 10, 3),
                'pool': nn.MaxPool2d(3)
        })
        self.activations = nn.ModuleDict([
                ['lrelu', nn.LeakyReLU()],
                ['prelu', nn.PReLU()]
        ])

    def forward(self, x, choice, act):
        x = self.choices[choice](x)
        x = self.activations[act](x)
        return x

ModuleDict类的方法:

clear()

移除ModuleDict中的所有项目。

items()

Return type:
IterableTuple\[str, Module]

返回ModuleDict键/值对的可迭代对象。

keys()

Return type:
Iterablestr

返回ModuleDict键的可迭代对象。

pop(key)

Parameters:
key (str) -- key to pop from the ModuleDict
Return type:
Module

从ModuleDict中移除key并返回它的模块。

update(modules)

Parameters:
modules (iterable) -- 一个从字符串到 Module 的映射(字典),或者一个键值对类型为 (string, Module) 的可迭代对象。

使用映射或可迭代对象的键值对来更新 ModuleDict,覆盖现有的键。

注意:

如果 modules 是 OrderedDict、ModuleDict 或键值对的可迭代对象,其中新元素的顺序将被保留。

values()

Return type:
IterableModule

返回ModuleDict值的可迭代对象。

4. torch.nn.ParameterList(values=None)

Parameters:
parameters (iterable, optional) -- 要添加到列表中的元素的可迭代对象.

ParameterList 可以像常规的 Python 列表一样使用,但作为 Parameter 的 Tensor 已经被正确注册,并且将被所有 Module 方法看到。

请注意,构造函数、分配列表的元素、append() 方法和 extend() 方法将把任何 Tensor 转换为 Parameter。

python 复制代码
class MyModule(nn.Module):
    def __init__(self):
        super().__init__()
        self.params = nn.ParameterList([nn.Parameter(torch.randn(10, 10)) for i in range(10)])

    def forward(self, x):
        # ParameterList can act as an iterable, or be indexed using ints
        for i, p in enumerate(self.params):
            x = self.params[i // 2].mm(x) + p.mm(x)
        return x

ParameterList类的方法:

append(value)

Parameters:
value (Any) -- 将要添加的value
Return type:
ParameterList

将给定的值追加到列表的末尾。

extend(values)

Parameters:
values (iterable) -- iterable of values to append
Return type:
ParameterList

将Python可迭代对象中的值附加到列表的末尾。

5. torch.nn.ParameterDict(parameters=None)

Parameters:
values (iterable, optional) -- 一个映射(字典),其元素为 (string : Any) 或一个类型为 (string, Any) 的键值对的可迭代对象。

ParameterDict 在一个字典中保存参数。

ParameterDict 可以像常规的 Python 字典一样进行索引,但它包含的 Parameters 已经被正确注册,并且将被所有 Module 方法看到。其他对象将被处理得像常规的 Python 字典一样。

ParameterDict 是一个有序字典。使用其他无序映射类型(例如 Python 的普通 dict)进行的 update() 操作不会保留合并映射的顺序。另一方面,OrderedDict 或另一个 ParameterDict 将保留它们的顺序。

请注意,构造函数、分配字典的元素和 update() 方法将把任何 Tensor 转换为 Parameter。

python 复制代码
class MyModule(nn.Module):
    def __init__(self):
        super().__init__()
        self.params = nn.ParameterDict({
                'left': nn.Parameter(torch.randn(5, 10)),
                'right': nn.Parameter(torch.randn(5, 10))
        })

    def forward(self, x, choice):
        x = self.params[choice].mm(x)
        return x

ParameterDict类的方法:

clear()

移除ParameterDict容器中的所有项目。

copy()

Return type:
ParameterDict

返回这个ParameterDict实例的副本。

fromkeys(keys, default=None)

Parameters:
keys (iterable, string) -- 用于创建新 ParameterDict 的键
default (Parameter, optional) -- 为所有键设置的值
Return type:
ParameterDict

返回一个带有参数keys的新的 ParameterDict。

get(key, default=None)

Parameters:
key (str) -- 要从 ParameterDict 获取的键
default (Parameter, optional) -- 如果键不存在时返回的值
Return type:
Any

如果存在与 key 关联的参数,则返回该参数。否则,如果提供了 default,则返回 default,如果没有提供则返回 None。

items()

Return type:
IterableTuple\[str, Any]

返回 ParameterDict 键/值对的可迭代对象。

keys()

Return type:
Iterablestr

返回 ParameterDict 的键的可迭代对象。

pop(key)

Parameters:
key (str) -- 要从 ParameterDict 弹出的键
Return type:
Any

从 ParameterDict 中删除键并返回其参数。

popitem()

Return type:
Tuplestr, Any

从 ParameterDict 中删除并返回最后插入的(键,参数)对。

setdefault(key, default=None)

Parameters:
key (str) -- 要为其设置默认值的键
default (Any) -- 设置为键的参数
Return type:
Any

如果 key 在 ParameterDict 中,则返回其值。如果不在,则插入 key 并将其参数设置为 default,并返回 default。default 默认为 None。

update(parameters)

Parameters:
parameters (iterable) -- 从字符串到 Parameter 的映射(字典),或类型为(字符串,Parameter)的键值对的可迭代对象。

使用映射或可迭代对象中的键值对更新 ParameterDict,覆盖现有键。

注意:

如果 parameters 是 OrderedDict、ParameterDict 或键值对的可迭代对象,则其中新元素的顺序将被保留。

values()

Return type:
IterableAny

返回 ParameterDict 值的可迭代对象。

相关推荐
Lei活在当下3 小时前
【AI手记系列-2026/6/18】iSparto & Harness,Caveman 以及AI时代的生存指南
人工智能·llm·openai
冬奇Lab5 小时前
每日一个开源项目(第134篇):Zvec - 阿里开源的嵌入式向量数据库,向量搜索界的 SQLite
数据库·人工智能·llm
冬奇Lab5 小时前
Agent 系列(22):Context Engineering 深度——三种上下文管理策略的量化对比
人工智能·agent
hboot5 小时前
AI工程师第二课 - 数据处理
人工智能·python·数据分析
程序员cxuan5 小时前
DeepSeek 杀入多模态,识图功能正式上线!
人工智能·后端·程序员
米小虾7 小时前
告别单打独斗:2026年多Agent协作架构实战指南
人工智能·agent
IT_陈寒8 小时前
SpringBoot这个自动配置坑我跳了三次
前端·人工智能·后端
Larcher9 小时前
AI Loop:让AI像人一样自主完成任务的核心机制
javascript·人工智能·设计模式
牧艺9 小时前
从零到协同:构建类飞书在线文档系统的五个技术重难点
前端·人工智能