【发布时间】:2021-07-21 05:52:03
【问题描述】:
如何以与 JAX 兼容的方式实现以下功能(例如,使用 jax.numpy)?
def actions(state: tuple[int, ...]) -> list[tuple[int, ...]]:
l = []
iterables = [range(1, i+1) for i in state]
ns = list(range(len(iterables)))
for i, iterable in enumerate(iterables):
for value in iterable:
action = tuple(value if n == i else 0 for n in ns)
l.append(action)
return l
>>> state = (3, 1, 2)
>>> actions(state)
[(1, 0, 0), (2, 0, 0), (3, 0, 0), (0, 1, 0), (0, 0, 1), (0, 0, 2)]
【问题讨论】:
-
Jax 和 numpy 一样,不能有效地对元组和列表进行操作——输出一个二维数组是否足以满足您的用例?
-
当然,可以将数组作为输入(1D ... n)和输出(2D ... m x n)。元组只是纯 Python 等价物(因为我需要它们是不可变的)。