【问题标题】:Vectorization guidelnes for jaxjax 的矢量化指南
【发布时间】:2021-10-29 16:28:31
【问题描述】:

假设我有一个函数(为简单起见,两个系列之间的协方差,尽管问题更笼统):

def cov(x, y):
   return jnp.dot((x-jnp.mean(x)), (y-jnp.mean(y)))

现在我有一个“数据框”D(一个二维数组,其列是我的系列),我想对cov 进行矢量化处理,使矢量化函数的应用产生协方差矩阵。现在,有一种明显的方法:

cov1 = jax.vmap(cov, in_axes=(None, 1))
cov2 = jax.vmap(cov1, in_axes=(1, None))

但这似乎有点笨拙。有这样做的“规范”方式吗?

【问题讨论】:

    标签: python jax


    【解决方案1】:

    如果您想用vmap 表达与嵌套for 循环等效的逻辑,那么是的,它需要嵌套的vmap。我认为您所写的内容可能与这样的操作一样规范,尽管如果使用装饰器编写可能会更清晰:

    from functools import partial
    
    @partial(jax.vmap, in_axes=(1, None))
    @partial(jax.vmap, in_axes=(None, 1))
    def cov(x, y):
       return jnp.dot((x-jnp.mean(x)), (y-jnp.mean(y)))
    

    不过,对于这个特定的函数,请注意,如果您愿意,可以使用单点积来表达相同的内容:

    result = jnp.dot((x - x.mean(0)).T, (y - y.mean(0)))
    

    【讨论】:

    • 谢谢!特定的函数实际上是内置的,所以这不是一个很好的例子(我实际上是在写一个健壮的协方差,其中最里面的函数是不同的)。
    猜你喜欢
    • 2021-11-05
    • 1970-01-01
    • 2023-03-16
    • 2017-06-18
    • 1970-01-01
    • 1970-01-01
    • 2020-05-18
    • 1970-01-01
    • 1970-01-01
    相关资源
    最近更新 更多