PyTorch torch.where 指南

📝 前言

在深度学习的日常开发中,我们经常会遇到这样的场景:需要根据某些条件对张量中的元素进行筛选或替换。如果用传统的循环写法,不仅代码臃肿,更重要的是会错失GPU加速的黄金机会

今天,让我们一起来探索PyTorch中一个既优雅又高效的工具 ------ torch.where

🎨 基础入门

基本语法

python 复制代码
torch.where(condition, x, y)

📌 关键说明:

  • condition:布尔张量(True/False),决定每个位置的选择

  • x、y:与 condition 形状一致(或可广播)的张量,数据类型可兼容

  • 返回值:与 condition 形状相同的新张量

快速认识 torch.where

python 复制代码
import torch

# 一个最简单的例子
condition = torch.tensor([True, False, True])
x = torch.tensor([1, 2, 3])
y = torch.tensor([4, 5, 6])

result = torch.where(condition, x, y)
print(result)  # tensor([1, 5, 3])

代码对比

python 复制代码
# ❌ 繁琐的循环写法
result = torch.empty_like(x)
for i in range(len(x)):
    if condition[i]:
        result[i] = x[i]
    else:
        result[i] = y[i]

# ✅ 优雅的 torch.where
result = torch.where(condition, x, y)

性能对比

性能测试代码:

python 复制代码
import os
os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"
import torch
import time

print("🚀 torch.where 性能测试 (100万数据)")
print("-"*40)

# 准备数据(固定随机种子,保证测试可复现)
torch.manual_seed(42)
cond = torch.rand(1_000_000) > 0.5
x = torch.randn(1_000_000)
y = torch.randn(1_000_000)

# 测试 torch.where(循环100次取平均,增加同步确保时间准确)
start = time.time()
for _ in range(100):
    res_where = torch.where(cond, x, y)
# 强制同步,避免PyTorch异步执行导致时间不准
torch.cuda.synchronize() if torch.cuda.is_available() else None
total_seconds = time.time() - start
avg_seconds = total_seconds / 100
torch_time = avg_seconds * 1000  # 转换为毫秒

# 测试 Python循环(为了公平,也循环10次取平均,避免单次波动)
start = time.time()
result = torch.empty_like(x)
# 循环10次,和torch.where的测试量级匹配
for _ in range(10):
    for i in range(1_000_000):
        result[i] = x[i] if cond[i] else y[i]
loop_total_seconds = time.time() - start
loop_avg_seconds = loop_total_seconds / 10
loop_time = loop_avg_seconds * 1000  # 转换为毫秒

# 计算加速比
speedup = loop_time / torch_time

# 输出结果
print(f"⚡ torch.where:  {torch_time:.3f} ms")
print(f"🐍 Python循环:    {loop_time:.3f} ms")
print(f"🚀 加速比:        {speedup:.1f}倍")
print("-"*40)

bars = 40
scale_factor = 100  # 放大100倍,避免torch_bars为0
torch_bars_scaled = int(bars * (torch_time * scale_factor) / loop_time)
# 限制最大长度不超过总bars数
torch_bars_scaled = min(torch_bars_scaled, bars)

print(f"\n🐍 Python循环:  {'█' * bars} {loop_time:.0f}ms")
print(f"⚡ torch.where: {'█' * torch_bars_scaled}{'░' * (bars - torch_bars_scaled)} {torch_time:.0f}ms (放大{scale_factor}倍显示)")
print(f"\n✨ torch.where 快 {speedup:.1f} 倍!")

# 验证结果正确性(确保两种方式结果一致)
assert torch.allclose(res_where, result), "两种方法结果不一致!"
print("\n✅ 结果验证通过:torch.where和循环计算结果一致")

🔥 实战应用场景

📊 数据清洗:标记异常值

python 复制代码
# 传感器数据,包含异常值(-999)
sensor_data = torch.tensor([23.5, -999, 24.8, 22.1, -999, 25.3])

# 清洗数据:将-999标记为None
clean_data = torch.where(
    sensor_data == -999,
    torch.tensor(float('nan')),
    sensor_data
)

print(clean_data)
# tensor([23.5, nan, 24.8, 22.1, nan, 25.3])

🖼️ 图像处理:一键二值化

python 复制代码
# 模拟灰度图像 (4x4)
image = torch.tensor([
    [0.1, 0.8, 0.3, 0.9],
    [0.6, 0.2, 0.7, 0.4],
    [0.5, 0.5, 0.8, 0.1],
    [0.3, 0.9, 0.2, 0.6]
])

# 二值化处理
binary = torch.where(image > 0.5, 1.0, 0.0)

print("二值化结果:")
print(binary)

🎯 关键点检测:坐标掩码

python 复制代码
# 检测到的关键点坐标和置信度
x_coords = torch.tensor([120, 245, 178, 320])
confidence = torch.tensor([0.95, 0.12, 0.88, 0.05])

# 只保留高置信度的关键点
valid_coords = torch.where(
    confidence > 0.5,
    x_coords,
    torch.tensor(-1)  # -1 表示无效点
)

print(valid_coords)  # tensor([120, -1, 178, -1])
相关推荐
*Lisen1 小时前
从零手写 FlashAttention(PyTorch实现 + 原理推导)
人工智能·pytorch·python
Jmayday3 小时前
Pytorch:CNN理论基础
人工智能·pytorch·cnn
Jmayday4 小时前
Pytorch:AI歌词生成器
人工智能·pytorch·python
AI技术增长4 小时前
Pytorch图像去噪实战(八):Noise2Void盲点网络图像去噪实战,只有单张带噪图也能训练
人工智能·pytorch·python
隔壁大炮4 小时前
Day07-RNN层(循环网络层)
人工智能·pytorch·python·rnn·深度学习·神经网络·计算机视觉
带电的小王6 小时前
【动手学深度学习】8.4. 循环神经网络
人工智能·pytorch·rnn·深度学习
ting94520007 小时前
动手学深度学习(PyTorch版)深度详解(4):深度学习计算实战详解
人工智能·pytorch·深度学习
kishu_iOS&AI9 小时前
NLP —— LSTM/GRU模型
人工智能·pytorch·深度学习·自然语言处理·gru·lstm
AI技术增长9 小时前
Pytorch图像去噪实战(九):SwinIR图像去噪实战,用Transformer解决CNN纹理恢复不足问题
pytorch·cnn·transformer
Jmayday10 小时前
Pytorch:CNN进行图象分类案例
人工智能·pytorch·cnn