PyTorch中flatten()函数详解以及与view()和 reshape()的对比和实战代码示例

在 PyTorch 中,flatten() 函数常用于将张量(tensor)展平成一维或多维结构,尤其在构建神经网络(如 CNN)时,从卷积层输出进入全连接层前经常使用它。


一、基本语法

python 复制代码
torch.flatten(input, start_dim=0, end_dim=-1)

参数说明:

参数 说明
input 输入张量
start_dim 开始展平的维度(包含该维)
end_dim 结束展平的维度(包含该维)

展平操作会把 start_dimend_dim 之间的维度合并成一维。


二、常见示例

示例 1:基本使用

python 复制代码
import torch

x = torch.tensor([[[1, 2],
                   [3, 4]],
                  [[5, 6],
                   [7, 8]]])  # shape = (2, 2, 2)

out = torch.flatten(x)
print(out)
print(out.shape)  # torch.Size([8])

等价于 x.view(-1),即将所有维度展平成一维。


示例 2:保留前维度(常见于 CNN)

python 复制代码
x = torch.randn(10, 3, 32, 32)  # 10张图片,3通道,32x32大小
out = torch.flatten(x, start_dim=1)

print(out.shape)  # torch.Size([10, 3072])

解释:

  • 展平从第 1 维开始(channel, height, width)→ 展平成一个维度
  • 第 0 维(batch size)保留,适合连接到 nn.Linear

示例 3:多维展开(指定 end_dim)

python 复制代码
x = torch.randn(2, 3, 4, 5)  # shape = (2, 3, 4, 5)
out = torch.flatten(x, start_dim=1, end_dim=2)

print(out.shape)  # torch.Size([2, 12, 5]) -> (3*4 = 12)

三、与 .view() 的区别

函数 说明
view() 更底层、需要张量是连续的,手动指定形状
flatten() 更高层、更安全、自动处理维度合并,常用于模型构建中

四、常见用法:在模型中使用

1、示例1

python 复制代码
import torch.nn as nn

class MyCNN(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 16, kernel_size=3, padding=1)
        self.pool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(16, 10)

    def forward(self, x):
        x = self.conv(x)
        x = self.pool(x)              # shape: (N, 16, 1, 1)
        x = torch.flatten(x, 1)       # shape: (N, 16)
        x = self.fc(x)
        return x

2、示例2

下面使用了 torch.flatten() 将卷积层的输出展平,并连接到全连接层。这个结构常见于 CNN 图像分类模型。


使用 flatten() 的 CNN 训练流程(以 CIFAR-10 为例)

python 复制代码
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader

# ==== 1. 定义 CNN 模型,使用 flatten() ====
class FlattenCNN(nn.Module):
    def __init__(self):
        super(FlattenCNN, self).__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(3, 16, 3, padding=1),  # 输入: [B, 3, 32, 32]
            nn.ReLU(),
            nn.MaxPool2d(2),                # 输出: [B, 16, 16, 16]

            nn.Conv2d(16, 32, 3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2)                 # 输出: [B, 32, 8, 8]
        )

        self.fc = nn.Sequential(
            nn.Linear(32 * 8 * 8, 128),
            nn.ReLU(),
            nn.Linear(128, 10)              # CIFAR-10 共 10 类
        )

    def forward(self, x):
        x = self.conv(x)
        x = torch.flatten(x, 1)  # 👈 仅展平通道和空间维度,保留 batch
        x = self.fc(x)
        return x

# ==== 2. 准备数据 ====
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])

train_dataset = datasets.CIFAR10(root="./data", train=True, download=True, transform=transform)
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)

# ==== 3. 模型训练设置 ====
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = FlattenCNN().to(device)

criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=1e-3)

# ==== 4. 训练过程 ====
def train(model, loader, epochs):
    model.train()
    for epoch in range(epochs):
        total_loss = 0.0
        for images, labels in loader:
            images, labels = images.to(device), labels.to(device)

            outputs = model(images)
            loss = criterion(outputs, labels)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            total_loss += loss.item()

        avg_loss = total_loss / len(loader)
        print(f"[Epoch {epoch+1}] Loss: {avg_loss:.4f}")

# ==== 5. 开始训练 ====
train(model, train_loader, epochs=5)

重点说明

使用 torch.flatten(x, 1) 的原因:

  • 只展平通道、高、宽三维(保留 batch size)
  • 替代 x.view(x.size(0), -1) 更安全,避免非连续张量报错
  • 推荐在模型中构建更加模块化、清晰

五、三种张量展平方式:flatten()view()reshape() 的对比

下面从功能差异使用限制和**性能对比(benchmark)**进行三者的比较。


1、三者功能对比

函数 特点说明
flatten() 高级 API,自动处理维度合并,不要求张量连续。推荐模型中使用。
view() 底层操作,速度快,但要求张量是连续(tensor.is_contiguous()True
reshape() 更灵活,如果张量不连续,会自动复制为连续版本。性能略慢但更安全

2、代码功能对比

python 复制代码
x = torch.randn(32, 3, 64, 64)  # batch of images

# flatten
f1 = torch.flatten(x, 1)

# view
f2 = x.view(32, -1)

# reshape
f3 = x.reshape(32, -1)

print(f1.shape, f2.shape, f3.shape)

输出一致:torch.Size([32, 12288])


3、非连续张量对比(view 会报错)

python 复制代码
x = torch.randn(2, 3, 4)
y = x.permute(0, 2, 1)  # 非连续张量

try:
    y.view(-1)  # 会报错
except RuntimeError as e:
    print("view error:", e)

print("reshape:", y.reshape(-1).shape)   # reshape 正常
print("flatten:", torch.flatten(y).shape)  # flatten 正常

4、性能测试(benchmark)

python 复制代码
import torch
import time

x = torch.randn(1024, 512, 28, 28)

# 保证是连续的
x_contig = x.contiguous()

N = 1000

def benchmark(op, name):
    torch.cuda.synchronize()
    start = time.time()
    for _ in range(N):
        _ = op(x_contig)
    torch.cuda.synchronize()
    end = time.time()
    print(f"{name}: {(end - start)*1000:.2f} ms")

benchmark(lambda x: torch.flatten(x, 1), "flatten()")
benchmark(lambda x: x.view(x.size(0), -1), "view()")
benchmark(lambda x: x.reshape(x.size(0), -1), "reshape()")

示例结果(A100 GPU):

复制代码
flatten(): 58.12 ms
view():    41.76 ms
reshape(): 47.32 ms

总结view()最快,但要求张量连续;flatten()最安全但稍慢;reshape()是折中方案。


5、 建议总结

场景 推荐方式 原因
模型中展平 CNN 输出 flatten() 简洁、安全,尤其在复杂网络中
确保连续张量、追求速度 view() 性能最佳
张量可能非连续 reshape() 自动处理不连续情况,代码更鲁棒

六、小结

用法 效果
torch.flatten(x) 将所有维展平成一维
torch.flatten(x, 1) 保留 batch 维,常用于 CNN
torch.flatten(x, 1, 2) 展平指定维度区间

相关推荐
得赢科技几秒前
智能菜谱研发公司推荐 适配中小型餐饮
大数据·运维·人工智能
一个无名的炼丹师4 分钟前
多模态RAG系统进阶:从零掌握olmOCR与MinerU的部署与应用
python·大模型·ocr·多模态·rag
lovod8 分钟前
视觉SLAM十四讲合集
计算机视觉·slam·视觉slam·g2o·ba·位姿图
victory043117 分钟前
Gradio实现中英文切换,不影响页面状态,不得刷新页面情况下
人工智能
u01092727119 分钟前
使用XGBoost赢得Kaggle比赛
jvm·数据库·python
MediaTea26 分钟前
<span class=“js_title_inner“>Python:实例对象</span>
开发语言·前端·javascript·python·ecmascript
微光闪现33 分钟前
践行“科技向善”,微乐播捐赠108,888元助力唇腭裂儿童绽放笑容
人工智能
闵帆41 分钟前
反演学习器面临的鸿沟
人工智能·学习·机器学习
feasibility.43 分钟前
多模态模型Qwen3-VL在Llama-Factory中断LoRA微调训练+测试+导出+部署全流程--以具身智能数据集open-eqa为例
人工智能·python·大模型·nlp·llama·多模态·具身智能
我需要一个支点43 分钟前
douyin无水印视频下载
爬虫·python