torch.where
torch.where
1. 基本语法
torch.where(condition, x, y)
condition:布尔条件(torch.BoolTensor),决定选择x还是y。x:如果condition为True,则选择x对应位置的元素。y:如果condition为False,则选择y对应位置的元素。
返回值:
- 一个新的张量,其元素来自
x或y,取决于condition。
2. 使用示例
示例 1:基本条件选择
import torch
a = torch.tensor([1, 2, 3, 4])
b = torch.tensor([10, 20, 30, 40])
condition = torch.tensor([True, False, True, False])
result = torch.where(condition, a, b)
print(result) # 输出: tensor([ 1, 20, 3, 40])
解释:
condition为True的位置选择a的值(1和3)。condition为False的位置选择b的值(20和40)。
示例 2:基于张量的条件
x = torch.tensor([[1, 2], [3, 4]])
y = torch.tensor([[10, 20], [30, 40]])
condition = x > 2 # 返回布尔张量
print(condition)
# tensor([[False, False],
# [ True, True]])
result = torch.where(condition, x, y)
print(result)
# tensor([[10, 20],
# [ 3, 4]])
解释:
x > 2的位置选择x的值(3和4)。- 其余位置选择
y的值(10和20)。
示例 3:仅用 condition(类似 np.where(condition))
如果只传入 condition,torch.where() 会返回满足条件的 索引(类似于 np.where(condition)):
a = torch.tensor([1, 2, 3, 4, 5])
indices = torch.where(a > 3)
print(indices) # 输出: (tensor([3, 4]),)
解释:
- 返回
a > 3的索引位置(3和4)。
示例 4:用于替换特定值
x = torch.tensor([1, -2, 3, -4])
# 将负数替换为 0
result = torch.where(x < 0, torch.tensor(0), x)
print(result) # 输出: tensor([1, 0, 3, 0])
解释:
x < 0的位置替换为0,否则保留原值。
3. 实际应用场景
(1) 三元组损失(Triplet Loss)中的难样本挖掘
在 TripletLoss 中,可以用 torch.where 替代 mask 操作:
# 原始代码:
dist_ap = dist[i][mask[i]].max()
dist_an = dist[i][mask[i] == 0].min()
# 用 torch.where 实现:
dist_ap = torch.where(mask[i], dist[i], -torch.inf).max() # 仅计算正样本
dist_an = torch.where(~mask[i], dist[i], torch.inf).min() # 仅计算负样本
(2) 梯度裁剪
grad = torch.randn(5)
clipped_grad = torch.where(grad > 1.0, torch.tensor(1.0), grad)
(3) 动态调整学习率
lr = torch.where(epoch < 10, 0.1, 0.01) # 前 10 轮 lr=0.1,之后 lr=0.01
4. 注意事项
x和y的形状必须相同(或可广播)。condition必须是torch.BoolTensor(可以用>,<,==等生成)。torch.where(condition)返回的是索引,而不是值。
总结
| 用法 | 说明 |
|---|---|
torch.where(cond, x, y) | 类似 if-else,选择 x 或 y |
torch.where(cond) | 返回 cond 为 True 的索引 |
| 适用场景 | 条件选择、掩码操作、动态调整参数 |
torch.where() 在 PyTorch 中非常灵活,可以用于 条件计算、索引查找、动态调整张量值 等任务。