【问题标题】:Is there a function for picking the max value in a pytorch tensor which follows a requirenment是否有一个函数可以在遵循要求的 pytorch 张量中选择最大值
【发布时间】:2021-10-06 13:20:54
【问题描述】:

我正在使用强化学习和 pytorch 在 python 中编写一个黑白棋机器人。在程序中,我扫描棋盘以寻找合法动作。 AI应该选择概率最高的棋步,并且根据之前的计算是合法的。这里我需要一个像这样工作的函数:

a = torch.tensor([1,2,3,4,5])
b = torch.tensor([True, True, False, True, False], dtype=bool)
print(torch.somefunction(a,b))

输出应该是a中最大值的id,在本例中为3。这个函数存在吗?如果没有,还有其他原因吗?

【问题讨论】:

  • 我从来没有玩过奥赛罗。虽然我认为更简单的方法是在棋盘上生成一组[0, 1] 掩码,以将非法动作归零。然后只需执行argmax 即可获得最佳移动的索引。

标签: python pytorch tensor


【解决方案1】:

假设你的张量中至少有一个非负值,你将它乘以掩码本身以删除排序中的排除值:

>>> torch.argmax(a*b)
tensor(3)

如果不是这种情况,您仍然可以使用 torch.where 将排除的值替换为将被 argmax 忽略的某个值(例如 a.min()),从而摆脱它:

>>> torch.where(b, a, a.min()).argmax()
tensor(3)

【讨论】:

    猜你喜欢
    • 2020-11-16
    • 1970-01-01
    • 1970-01-01
    • 2013-10-03
    • 1970-01-01
    • 2021-10-17
    • 2020-10-21
    • 1970-01-01
    • 2020-08-31
    相关资源
    最近更新 更多