PyTorch中flatten()函数详解以及与view()和 reshape()的对比和实战代码示例

在 PyTorch 中,flatten() 函数常用于将张量(tensor)展平成一维或多维结构,尤其在构建神经网络(如 CNN)时,从卷积层输出进入全连接层前经常使用它。


一、基本语法

python 复制代码
torch.flatten(input, start_dim=0, end_dim=-1)

参数说明:

参数 说明
input 输入张量
start_dim 开始展平的维度(包含该维)
end_dim 结束展平的维度(包含该维)

展平操作会把 start_dimend_dim 之间的维度合并成一维。


二、常见示例

示例 1:基本使用

python 复制代码
import torch

x = torch.tensor([[[1, 2],
                   [3, 4]],
                  [[5, 6],
                   [7, 8]]])  # shape = (2, 2, 2)

out = torch.flatten(x)
print(out)
print(out.shape)  # torch.Size([8])

等价于 x.view(-1),即将所有维度展平成一维。


示例 2:保留前维度(常见于 CNN)

python 复制代码
x = torch.randn(10, 3, 32, 32)  # 10张图片,3通道,32x32大小
out = torch.flatten(x, start_dim=1)

print(out.shape)  # torch.Size([10, 3072])

解释:

  • 展平从第 1 维开始(channel, height, width)→ 展平成一个维度
  • 第 0 维(batch size)保留,适合连接到 nn.Linear

示例 3:多维展开(指定 end_dim)

python 复制代码
x = torch.randn(2, 3, 4, 5)  # shape = (2, 3, 4, 5)
out = torch.flatten(x, start_dim=1, end_dim=2)

print(out.shape)  # torch.Size([2, 12, 5]) -> (3*4 = 12)

三、与 .view() 的区别

函数 说明
view() 更底层、需要张量是连续的,手动指定形状
flatten() 更高层、更安全、自动处理维度合并,常用于模型构建中

四、常见用法:在模型中使用

1、示例1

python 复制代码
import torch.nn as nn

class MyCNN(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 16, kernel_size=3, padding=1)
        self.pool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(16, 10)

    def forward(self, x):
        x = self.conv(x)
        x = self.pool(x)              # shape: (N, 16, 1, 1)
        x = torch.flatten(x, 1)       # shape: (N, 16)
        x = self.fc(x)
        return x

2、示例2

下面使用了 torch.flatten() 将卷积层的输出展平,并连接到全连接层。这个结构常见于 CNN 图像分类模型。


使用 flatten() 的 CNN 训练流程(以 CIFAR-10 为例)

python 复制代码
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader

# ==== 1. 定义 CNN 模型,使用 flatten() ====
class FlattenCNN(nn.Module):
    def __init__(self):
        super(FlattenCNN, self).__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(3, 16, 3, padding=1),  # 输入: [B, 3, 32, 32]
            nn.ReLU(),
            nn.MaxPool2d(2),                # 输出: [B, 16, 16, 16]

            nn.Conv2d(16, 32, 3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2)                 # 输出: [B, 32, 8, 8]
        )

        self.fc = nn.Sequential(
            nn.Linear(32 * 8 * 8, 128),
            nn.ReLU(),
            nn.Linear(128, 10)              # CIFAR-10 共 10 类
        )

    def forward(self, x):
        x = self.conv(x)
        x = torch.flatten(x, 1)  # 👈 仅展平通道和空间维度,保留 batch
        x = self.fc(x)
        return x

# ==== 2. 准备数据 ====
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])

train_dataset = datasets.CIFAR10(root="./data", train=True, download=True, transform=transform)
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)

# ==== 3. 模型训练设置 ====
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = FlattenCNN().to(device)

criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=1e-3)

# ==== 4. 训练过程 ====
def train(model, loader, epochs):
    model.train()
    for epoch in range(epochs):
        total_loss = 0.0
        for images, labels in loader:
            images, labels = images.to(device), labels.to(device)

            outputs = model(images)
            loss = criterion(outputs, labels)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            total_loss += loss.item()

        avg_loss = total_loss / len(loader)
        print(f"[Epoch {epoch+1}] Loss: {avg_loss:.4f}")

# ==== 5. 开始训练 ====
train(model, train_loader, epochs=5)

重点说明

使用 torch.flatten(x, 1) 的原因:

  • 只展平通道、高、宽三维(保留 batch size)
  • 替代 x.view(x.size(0), -1) 更安全,避免非连续张量报错
  • 推荐在模型中构建更加模块化、清晰

五、三种张量展平方式:flatten()view()reshape() 的对比

下面从功能差异使用限制和**性能对比(benchmark)**进行三者的比较。


1、三者功能对比

函数 特点说明
flatten() 高级 API,自动处理维度合并,不要求张量连续。推荐模型中使用。
view() 底层操作,速度快,但要求张量是连续(tensor.is_contiguous()True
reshape() 更灵活,如果张量不连续,会自动复制为连续版本。性能略慢但更安全

2、代码功能对比

python 复制代码
x = torch.randn(32, 3, 64, 64)  # batch of images

# flatten
f1 = torch.flatten(x, 1)

# view
f2 = x.view(32, -1)

# reshape
f3 = x.reshape(32, -1)

print(f1.shape, f2.shape, f3.shape)

输出一致:torch.Size([32, 12288])


3、非连续张量对比(view 会报错)

python 复制代码
x = torch.randn(2, 3, 4)
y = x.permute(0, 2, 1)  # 非连续张量

try:
    y.view(-1)  # 会报错
except RuntimeError as e:
    print("view error:", e)

print("reshape:", y.reshape(-1).shape)   # reshape 正常
print("flatten:", torch.flatten(y).shape)  # flatten 正常

4、性能测试(benchmark)

python 复制代码
import torch
import time

x = torch.randn(1024, 512, 28, 28)

# 保证是连续的
x_contig = x.contiguous()

N = 1000

def benchmark(op, name):
    torch.cuda.synchronize()
    start = time.time()
    for _ in range(N):
        _ = op(x_contig)
    torch.cuda.synchronize()
    end = time.time()
    print(f"{name}: {(end - start)*1000:.2f} ms")

benchmark(lambda x: torch.flatten(x, 1), "flatten()")
benchmark(lambda x: x.view(x.size(0), -1), "view()")
benchmark(lambda x: x.reshape(x.size(0), -1), "reshape()")

示例结果(A100 GPU):

复制代码
flatten(): 58.12 ms
view():    41.76 ms
reshape(): 47.32 ms

总结view()最快,但要求张量连续;flatten()最安全但稍慢;reshape()是折中方案。


5、 建议总结

场景 推荐方式 原因
模型中展平 CNN 输出 flatten() 简洁、安全,尤其在复杂网络中
确保连续张量、追求速度 view() 性能最佳
张量可能非连续 reshape() 自动处理不连续情况,代码更鲁棒

六、小结

用法 效果
torch.flatten(x) 将所有维展平成一维
torch.flatten(x, 1) 保留 batch 维,常用于 CNN
torch.flatten(x, 1, 2) 展平指定维度区间

相关推荐
XIAO·宝17 分钟前
机器学习--数据预处理
人工智能·机器学习·数据预处理
爱喝奶茶的企鹅22 分钟前
Ethan独立开发新品速递 | 2025-08-21
人工智能
爱喝奶茶的企鹅24 分钟前
Ethan开发者创新项目日报 | 2025-08-21
人工智能
算家计算40 分钟前
字节跳动开源Seed-OSS-36B:512K上下文,代理与长上下文基准新SOTA
人工智能·开源·资讯
前端小趴菜0541 分钟前
python - 条件判断
python
THMAIL42 分钟前
大模型“知识”的外挂:RAG检索增强生成详解
人工智能
汀丶人工智能44 分钟前
AI Compass前沿速览:DINOv3-Meta视觉基础模型、DeepSeek-V3.1、Qwen-Image、Seed-OSS、CombatVLA-3D动
人工智能
范男1 小时前
基于Pytochvideo训练自己的的视频分类模型
人工智能·pytorch·python·深度学习·计算机视觉·3d·视频
hui函数1 小时前
Flask-WTF表单验证全攻略
后端·python·flask·web·表单验证
二向箔reverse1 小时前
机器学习算法核心总结
人工智能·算法·机器学习