【发布时间】:2021-10-29 16:28:31
【问题描述】:
假设我有一个函数(为简单起见,两个系列之间的协方差,尽管问题更笼统):
def cov(x, y):
return jnp.dot((x-jnp.mean(x)), (y-jnp.mean(y)))
现在我有一个“数据框”D(一个二维数组,其列是我的系列),我想对cov 进行矢量化处理,使矢量化函数的应用产生协方差矩阵。现在,有一种明显的方法:
cov1 = jax.vmap(cov, in_axes=(None, 1))
cov2 = jax.vmap(cov1, in_axes=(1, None))
但这似乎有点笨拙。有这样做的“规范”方式吗?
【问题讨论】: