【深度学习基础】`view` 和 `reshape` 的参数详解

目录

基本概念

viewreshape 都用于调整张量的形状,它们的参数是新的形状,每个维度的大小可以指定为具体的数值或者 -1-1 表示这个维度的大小由张量的总元素数量自动推断。

参数详解
  • new_shape:这是一个 tuple 或者一个 list,定义了新的形状。每个元素代表对应维度的大小。
  • -1:特殊值,表示该维度的大小由其他维度自动推断。

示例

假设有一个张量 tensor,形状为 [batch_size, seq_len, num_labels]

python 复制代码
import torch

tensor = torch.randn(4, 3, 5)  # 示例张量,形状为 (4, 3, 5)

要将其形状调整为 [12, 5],可以使用 viewreshape

python 复制代码
# 使用 view
reshaped_tensor_view = tensor.view(-1, 5)
print("View tensor shape:", reshaped_tensor_view.shape)  # 输出: torch.Size([12, 5])

# 使用 reshape
reshaped_tensor_reshape = tensor.reshape(-1, 5)
print("Reshape tensor shape:", reshaped_tensor_reshape.shape)  # 输出: torch.Size([12, 5])

viewreshape 在具体应用中的参数解释

在序列标记分类任务中,我们通常需要将 logits 和标签调整为适合计算损失的形状。

假设 logits 的形状为 [batch_size, seq_len, num_labels],我们希望将其调整为 [batch_size * seq_len, num_labels],以便与标签 [batch_size * seq_len] 对应。

以下是使用 viewreshape 的示例:

python 复制代码
import torch
import torch.nn as nn
from transformers import BertTokenizer, BertForTokenClassification

# 初始化模型和tokenizer
model_name = 'bert-base-uncased'
tokenizer = BertTokenizer.from_pretrained(model_name)
model = BertForTokenClassification.from_pretrained(model_name, num_labels=5)  # 假设有5个分类

# 假设输入文本
text = "I love natural language processing."
inputs = tokenizer(text, return_tensors="pt")

# 获取模型输出
outputs = model(**inputs)
seq_logits = outputs.logits

# 假设标签映射
tags_to_idx = {'O': 0, 'B-PER': 1, 'I-PER': 2, 'B-LOC': 3, 'I-LOC': 4}
tags = torch.tensor([[0, 0, 0, 0, 1, 2, 3, 4]])  # 示例标签,形状为 (batch_size, seq_len)

# 使用 reshape 调整形状
pred = seq_logits.reshape([-1, len(tags_to_idx)])
label = tags.reshape([-1])
ignore_index = tags_to_idx["O"]

# 计算损失
criterion = nn.CrossEntropyLoss(ignore_index=ignore_index)
loss = criterion(pred, label)
print("Loss with reshape:", loss.item())

# 使用 view 调整形状
pred_view = seq_logits.view(-1, len(tags_to_idx))
label_view = tags.view(-1)

# 计算损失
loss_view = criterion(pred_view, label_view)
print("Loss with view:", loss_view.item())
参数解释
  • seq_logits.reshape([-1, len(tags_to_idx)])seq_logits.view(-1, len(tags_to_idx)])
    • -1:表示这个维度的大小由其他维度自动推断。这里是将 [batch_size, seq_len, num_labels] 调整为 [batch_size * seq_len, num_labels]
    • len(tags_to_idx):表示 num_labels,即分类的数量。

更多示例

高维张量示例

假设有一个四维张量,形状为 [2, 2, 3, 4],我们希望将其调整为 [4, 3, 4]

python 复制代码
import torch

tensor = torch.randn(2, 2, 3, 4)
print("Original shape:", tensor.shape)  # 输出: torch.Size([2, 2, 3, 4])

# 使用 view 调整形状
view_tensor = tensor.view(4, 3, 4)
print("View tensor shape:", view_tensor.shape)  # 输出: torch.Size([4, 3, 4])

# 使用 reshape 调整形状
reshape_tensor = tensor.reshape(4, 3, 4)
print("Reshape tensor shape:", reshape_tensor.shape)  # 输出: torch.Size([4, 3, 4])
非连续内存示例
python 复制代码
import torch

tensor = torch.tensor([[1, 2, 3], [4, 5, 6]])
transpose_tensor = tensor.t()  # 转置张量
print("Transpose shape:", transpose_tensor.shape)  # 输出: torch.Size([3, 2])

# 使用 view(会报错,因为内存不连续)
try:
    view_tensor = transpose_tensor.view(-1)
except RuntimeError as e:
    print("Error using view:", e)

# 使用 contiguous 方法确保内存连续
contiguous_tensor = transpose_tensor.contiguous()
view_tensor = contiguous_tensor.view(-1)
print("Contiguous view tensor:", view_tensor)
print("Contiguous view tensor shape:", view_tensor.shape)  # 输出: torch.Size([6])

# 使用 reshape
reshape_tensor = transpose_tensor.reshape(-1)
print("Reshape tensor:", reshape_tensor)
print("Reshape tensor shape:", reshape_tensor.shape)  # 输出: torch.Size([6])

总结

  • viewreshape 参数
    • 参数是一个 tuple 或者 list,定义新的形状。
    • -1 表示该维度的大小由其他维度自动推断。
  • view 的限制:要求输入张量是连续的。
  • reshape 的灵活性:可以处理非连续内存的张量。

通过这些详细的例子和解释,你可以更好地理解如何使用 viewreshape 来调整张量的形状。

相关推荐
大连好光景几秒前
软件测试笔记(2)
人工智能·功能测试·模块测试
纪伊路上盛名在10 分钟前
机器学习中的固定随机种子方案
人工智能·机器学习·数据分析·随机种子
SteveSenna16 分钟前
项目:Trossen Arm MuJoCo
人工智能·学习·算法
兢谨网安17 分钟前
AI安全:从技术加固到体系化防御的实战演进
人工智能·安全·网络安全·渗透测试
CoderJia程序员甲29 分钟前
GitHub 热榜项目 - 日榜(2026-03-29)
人工智能·ai·大模型·github·ai教程
龙腾AI白云36 分钟前
什么是AI智能体(AI Agent)
人工智能·深度学习·自然语言处理·数据分析
Sagittarius_A*40 分钟前
监督学习(Supervised Learning)
人工智能·学习·机器学习·监督学习
向上的车轮1 小时前
AI智能体开发:需求分析要点与实战指南
人工智能·需求分析
fobwebs1 小时前
wordpress GEO插件指南
人工智能·wordpress·geo·ai搜索优化·geo优化
GMATG_LIU1 小时前
电子背散射衍射(EBSD)技术的优势
人工智能