PyTorch中flatten()函数详解以及与view()和 reshape()的对比和实战代码示例

在 PyTorch 中,flatten() 函数常用于将张量(tensor)展平成一维或多维结构,尤其在构建神经网络(如 CNN)时,从卷积层输出进入全连接层前经常使用它。


一、基本语法

python 复制代码
torch.flatten(input, start_dim=0, end_dim=-1)

参数说明:

参数 说明
input 输入张量
start_dim 开始展平的维度(包含该维)
end_dim 结束展平的维度(包含该维)

展平操作会把 start_dimend_dim 之间的维度合并成一维。


二、常见示例

示例 1:基本使用

python 复制代码
import torch

x = torch.tensor([[[1, 2],
                   [3, 4]],
                  [[5, 6],
                   [7, 8]]])  # shape = (2, 2, 2)

out = torch.flatten(x)
print(out)
print(out.shape)  # torch.Size([8])

等价于 x.view(-1),即将所有维度展平成一维。


示例 2:保留前维度(常见于 CNN)

python 复制代码
x = torch.randn(10, 3, 32, 32)  # 10张图片,3通道,32x32大小
out = torch.flatten(x, start_dim=1)

print(out.shape)  # torch.Size([10, 3072])

解释:

  • 展平从第 1 维开始(channel, height, width)→ 展平成一个维度
  • 第 0 维(batch size)保留,适合连接到 nn.Linear

示例 3:多维展开(指定 end_dim)

python 复制代码
x = torch.randn(2, 3, 4, 5)  # shape = (2, 3, 4, 5)
out = torch.flatten(x, start_dim=1, end_dim=2)

print(out.shape)  # torch.Size([2, 12, 5]) -> (3*4 = 12)

三、与 .view() 的区别

函数 说明
view() 更底层、需要张量是连续的,手动指定形状
flatten() 更高层、更安全、自动处理维度合并,常用于模型构建中

四、常见用法:在模型中使用

1、示例1

python 复制代码
import torch.nn as nn

class MyCNN(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 16, kernel_size=3, padding=1)
        self.pool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(16, 10)

    def forward(self, x):
        x = self.conv(x)
        x = self.pool(x)              # shape: (N, 16, 1, 1)
        x = torch.flatten(x, 1)       # shape: (N, 16)
        x = self.fc(x)
        return x

2、示例2

下面使用了 torch.flatten() 将卷积层的输出展平,并连接到全连接层。这个结构常见于 CNN 图像分类模型。


使用 flatten() 的 CNN 训练流程(以 CIFAR-10 为例)

python 复制代码
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader

# ==== 1. 定义 CNN 模型,使用 flatten() ====
class FlattenCNN(nn.Module):
    def __init__(self):
        super(FlattenCNN, self).__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(3, 16, 3, padding=1),  # 输入: [B, 3, 32, 32]
            nn.ReLU(),
            nn.MaxPool2d(2),                # 输出: [B, 16, 16, 16]

            nn.Conv2d(16, 32, 3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2)                 # 输出: [B, 32, 8, 8]
        )

        self.fc = nn.Sequential(
            nn.Linear(32 * 8 * 8, 128),
            nn.ReLU(),
            nn.Linear(128, 10)              # CIFAR-10 共 10 类
        )

    def forward(self, x):
        x = self.conv(x)
        x = torch.flatten(x, 1)  # 👈 仅展平通道和空间维度,保留 batch
        x = self.fc(x)
        return x

# ==== 2. 准备数据 ====
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])

train_dataset = datasets.CIFAR10(root="./data", train=True, download=True, transform=transform)
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)

# ==== 3. 模型训练设置 ====
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = FlattenCNN().to(device)

criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=1e-3)

# ==== 4. 训练过程 ====
def train(model, loader, epochs):
    model.train()
    for epoch in range(epochs):
        total_loss = 0.0
        for images, labels in loader:
            images, labels = images.to(device), labels.to(device)

            outputs = model(images)
            loss = criterion(outputs, labels)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            total_loss += loss.item()

        avg_loss = total_loss / len(loader)
        print(f"[Epoch {epoch+1}] Loss: {avg_loss:.4f}")

# ==== 5. 开始训练 ====
train(model, train_loader, epochs=5)

重点说明

使用 torch.flatten(x, 1) 的原因:

  • 只展平通道、高、宽三维(保留 batch size)
  • 替代 x.view(x.size(0), -1) 更安全,避免非连续张量报错
  • 推荐在模型中构建更加模块化、清晰

五、三种张量展平方式:flatten()view()reshape() 的对比

下面从功能差异使用限制和**性能对比(benchmark)**进行三者的比较。


1、三者功能对比

函数 特点说明
flatten() 高级 API,自动处理维度合并,不要求张量连续。推荐模型中使用。
view() 底层操作,速度快,但要求张量是连续(tensor.is_contiguous()True
reshape() 更灵活,如果张量不连续,会自动复制为连续版本。性能略慢但更安全

2、代码功能对比

python 复制代码
x = torch.randn(32, 3, 64, 64)  # batch of images

# flatten
f1 = torch.flatten(x, 1)

# view
f2 = x.view(32, -1)

# reshape
f3 = x.reshape(32, -1)

print(f1.shape, f2.shape, f3.shape)

输出一致:torch.Size([32, 12288])


3、非连续张量对比(view 会报错)

python 复制代码
x = torch.randn(2, 3, 4)
y = x.permute(0, 2, 1)  # 非连续张量

try:
    y.view(-1)  # 会报错
except RuntimeError as e:
    print("view error:", e)

print("reshape:", y.reshape(-1).shape)   # reshape 正常
print("flatten:", torch.flatten(y).shape)  # flatten 正常

4、性能测试(benchmark)

python 复制代码
import torch
import time

x = torch.randn(1024, 512, 28, 28)

# 保证是连续的
x_contig = x.contiguous()

N = 1000

def benchmark(op, name):
    torch.cuda.synchronize()
    start = time.time()
    for _ in range(N):
        _ = op(x_contig)
    torch.cuda.synchronize()
    end = time.time()
    print(f"{name}: {(end - start)*1000:.2f} ms")

benchmark(lambda x: torch.flatten(x, 1), "flatten()")
benchmark(lambda x: x.view(x.size(0), -1), "view()")
benchmark(lambda x: x.reshape(x.size(0), -1), "reshape()")

示例结果(A100 GPU):

复制代码
flatten(): 58.12 ms
view():    41.76 ms
reshape(): 47.32 ms

总结view()最快,但要求张量连续;flatten()最安全但稍慢;reshape()是折中方案。


5、 建议总结

场景 推荐方式 原因
模型中展平 CNN 输出 flatten() 简洁、安全,尤其在复杂网络中
确保连续张量、追求速度 view() 性能最佳
张量可能非连续 reshape() 自动处理不连续情况,代码更鲁棒

六、小结

用法 效果
torch.flatten(x) 将所有维展平成一维
torch.flatten(x, 1) 保留 batch 维,常用于 CNN
torch.flatten(x, 1, 2) 展平指定维度区间

相关推荐
研梦非凡20 分钟前
CVPR 2025|基于视觉语言模型的零样本3D视觉定位
人工智能·深度学习·计算机视觉·3d·ai·语言模型·自然语言处理
Monkey的自我迭代24 分钟前
多目标轮廓匹配
人工智能·opencv·计算机视觉
每日新鲜事24 分钟前
Saucony索康尼推出全新 WOOOLLY 运动生活羊毛系列 生动无理由,从专业跑步延展运动生活的每一刻
大数据·人工智能
空白到白30 分钟前
机器学习-聚类
人工智能·算法·机器学习·聚类
中新赛克1 小时前
双引擎驱动!中新赛克AI安全方案入选网安创新大赛优胜榜单
人工智能·安全
飞哥数智坊1 小时前
解决AI幻觉,只能死磕模型?OpenAI给出不一样的思路
人工智能·openai
聚客AI1 小时前
🌈多感官AI革命:解密多模态对齐与融合的底层逻辑
人工智能·llm·掘金·日新计划
大佬,救命!!!1 小时前
整理python快速构建数据可视化前端的Dash库
python·信息可视化·学习笔记·dash·记录成长
孔丘闻言1 小时前
python调用mysql
android·python·mysql
zzywxc7871 小时前
AI在金融、医疗、教育、制造业等领域的落地案例
人工智能·机器学习·金融·prompt·流程图