【发布时间】:2020-07-16 18:18:03
【问题描述】:
我在 PyTorch 中有二维张量,代表模型置信度。我要:
- 如果行中的第二个值大于或等于阈值,则所有其他值应更改为 0
- 其他值不应更改
简单的方法是:
- 遍历行
- 检查第二个值
- 如果值大于或等于,则创建零行,将第二个值从行更改为第二个值并替换行
- 否则什么都不做
但是,它效率低下。有没有一种矢量化/张量化的方式来做到这一点?
【问题讨论】:
标签: python machine-learning pytorch
我在 PyTorch 中有二维张量,代表模型置信度。我要:
简单的方法是:
但是,它效率低下。有没有一种矢量化/张量化的方式来做到这一点?
【问题讨论】:
标签: python machine-learning pytorch
我会首先构造一个新的零矩阵,然后根据需要将项目从矩阵移动到零矩阵。您复制第二个元素低于阈值的行中的所有行。对于所有其他行,您只需复制第二个元素。
import torch
threshold = .2
X = torch.rand((100, 10))
new = torch.zeros_like(X)
mask = X[:, 2] <= threshold
new[mask] = X[mask]
new[~mask, 2] = X[~mask, 2]
【讨论】:
试试这个
import numpy as np
x[(x[:,1] >= 0.5).nonzero(), np.r_[0, 2:x.shape[1]]] = 0.0
首先,使用(x[:,1] >= 0.5).nonzero()获取行索引,
然后取列索引np.r_[0, 2:x.shape[1]] 除了第二列。
【讨论】: