PyTorch torch.where 指南

📝 前言

在深度学习的日常开发中,我们经常会遇到这样的场景:需要根据某些条件对张量中的元素进行筛选或替换。如果用传统的循环写法,不仅代码臃肿,更重要的是会错失GPU加速的黄金机会

今天,让我们一起来探索PyTorch中一个既优雅又高效的工具 ------ torch.where

🎨 基础入门

基本语法

python 复制代码
torch.where(condition, x, y)

📌 关键说明:

  • condition:布尔张量(True/False),决定每个位置的选择

  • x、y:与 condition 形状一致(或可广播)的张量,数据类型可兼容

  • 返回值:与 condition 形状相同的新张量

快速认识 torch.where

python 复制代码
import torch

# 一个最简单的例子
condition = torch.tensor([True, False, True])
x = torch.tensor([1, 2, 3])
y = torch.tensor([4, 5, 6])

result = torch.where(condition, x, y)
print(result)  # tensor([1, 5, 3])

代码对比

python 复制代码
# ❌ 繁琐的循环写法
result = torch.empty_like(x)
for i in range(len(x)):
    if condition[i]:
        result[i] = x[i]
    else:
        result[i] = y[i]

# ✅ 优雅的 torch.where
result = torch.where(condition, x, y)

性能对比

性能测试代码:

python 复制代码
import os
os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"
import torch
import time

print("🚀 torch.where 性能测试 (100万数据)")
print("-"*40)

# 准备数据(固定随机种子,保证测试可复现)
torch.manual_seed(42)
cond = torch.rand(1_000_000) > 0.5
x = torch.randn(1_000_000)
y = torch.randn(1_000_000)

# 测试 torch.where(循环100次取平均,增加同步确保时间准确)
start = time.time()
for _ in range(100):
    res_where = torch.where(cond, x, y)
# 强制同步,避免PyTorch异步执行导致时间不准
torch.cuda.synchronize() if torch.cuda.is_available() else None
total_seconds = time.time() - start
avg_seconds = total_seconds / 100
torch_time = avg_seconds * 1000  # 转换为毫秒

# 测试 Python循环(为了公平,也循环10次取平均,避免单次波动)
start = time.time()
result = torch.empty_like(x)
# 循环10次,和torch.where的测试量级匹配
for _ in range(10):
    for i in range(1_000_000):
        result[i] = x[i] if cond[i] else y[i]
loop_total_seconds = time.time() - start
loop_avg_seconds = loop_total_seconds / 10
loop_time = loop_avg_seconds * 1000  # 转换为毫秒

# 计算加速比
speedup = loop_time / torch_time

# 输出结果
print(f"⚡ torch.where:  {torch_time:.3f} ms")
print(f"🐍 Python循环:    {loop_time:.3f} ms")
print(f"🚀 加速比:        {speedup:.1f}倍")
print("-"*40)

bars = 40
scale_factor = 100  # 放大100倍,避免torch_bars为0
torch_bars_scaled = int(bars * (torch_time * scale_factor) / loop_time)
# 限制最大长度不超过总bars数
torch_bars_scaled = min(torch_bars_scaled, bars)

print(f"\n🐍 Python循环:  {'█' * bars} {loop_time:.0f}ms")
print(f"⚡ torch.where: {'█' * torch_bars_scaled}{'░' * (bars - torch_bars_scaled)} {torch_time:.0f}ms (放大{scale_factor}倍显示)")
print(f"\n✨ torch.where 快 {speedup:.1f} 倍!")

# 验证结果正确性(确保两种方式结果一致)
assert torch.allclose(res_where, result), "两种方法结果不一致!"
print("\n✅ 结果验证通过:torch.where和循环计算结果一致")

🔥 实战应用场景

📊 数据清洗:标记异常值

python 复制代码
# 传感器数据,包含异常值(-999)
sensor_data = torch.tensor([23.5, -999, 24.8, 22.1, -999, 25.3])

# 清洗数据:将-999标记为None
clean_data = torch.where(
    sensor_data == -999,
    torch.tensor(float('nan')),
    sensor_data
)

print(clean_data)
# tensor([23.5, nan, 24.8, 22.1, nan, 25.3])

🖼️ 图像处理:一键二值化

python 复制代码
# 模拟灰度图像 (4x4)
image = torch.tensor([
    [0.1, 0.8, 0.3, 0.9],
    [0.6, 0.2, 0.7, 0.4],
    [0.5, 0.5, 0.8, 0.1],
    [0.3, 0.9, 0.2, 0.6]
])

# 二值化处理
binary = torch.where(image > 0.5, 1.0, 0.0)

print("二值化结果:")
print(binary)

🎯 关键点检测:坐标掩码

python 复制代码
# 检测到的关键点坐标和置信度
x_coords = torch.tensor([120, 245, 178, 320])
confidence = torch.tensor([0.95, 0.12, 0.88, 0.05])

# 只保留高置信度的关键点
valid_coords = torch.where(
    confidence > 0.5,
    x_coords,
    torch.tensor(-1)  # -1 表示无效点
)

print(valid_coords)  # tensor([120, -1, 178, -1])
相关推荐
hboot2 天前
AI工程师第四课 - 深度学习入门
pytorch·python·神经网络
weiwei228445 天前
神经网络模型导出及开放标准格式ONNX
pytorch·onnx
程序猿追14 天前
那个右下角的小数字怎么“卡”住我打字——我用 HarmonyOS 自己写了一个字数限制输入框
pytorch·华为·harmonyos
闵孚龙14 天前
《PyTorch 深度修炼》Dataset 和 DataLoader:数据如何喂给模型
人工智能·pytorch·python
bryant_meng15 天前
【VAE】From Pixels to Faces: Building a VAE from Scratch
pytorch·vae·log-sigma2·重参数
装不满的克莱因瓶15 天前
了解多标签图像分类方法——从Sigmoid输出到真实世界复杂视觉理解
人工智能·pytorch·python·深度学习·机器学习·分类·数据挖掘
冷小鱼15 天前
TensorFlow 2.21 进阶实战:从训练优化到生产部署的完整指南
人工智能·pytorch·python·tensorflow
冷小鱼15 天前
PyTorch 2.12 完全指南:从动态图到编译优化的深度学习框架演进
人工智能·pytorch·深度学习
IRevers15 天前
【大模型】Gemma4在ROCm和vLLM部署
人工智能·pytorch·深度学习·大模型·datawhale·vllm·amdev
盼小辉丶15 天前
PyTorch强化学习实战(14)——优先经验回放机制
pytorch·python·深度学习·强化学习