【问题标题】:Find gradient of a function: Sympy vs. Jax查找函数的梯度:Sympy vs. Jax
【发布时间】:2020-08-03 01:46:09
【问题描述】:

我有一个函数Black_Cox() 调用其他函数,如下所示:

import numpy as np
from scipy import stats

# Parameters
D = 100
r = 0.05
γ = 0.1

# Normal CDF
N = lambda x: stats.norm.cdf(x)

H = lambda V, T, L, σ: np.exp(-r*T) * N( (np.log(V/L) + (r-0.5*σ**2)*T) / (σ*np.sqrt(T)) )

# Black-Scholes
def C_BS(V, K, T, σ):
    d1 = (np.log(V/K) + (r + 0.5*σ**2)*T ) / ( σ*np.sqrt(T) )
    d2 = d1 - σ*np.sqrt(T)
    return V*N(d1) - np.exp(-r*T)*K*N(d2)

def BL(V, T, D, L, σ):
    return L * H(V, T, L, σ) - L * (L/V)**(2*r/σ**2-1) * H(L**2/V, T, L, σ) \
              + C_BS(V, L, T, σ) - (L/V)**(2*r/σ**2-1) * C_BS(L**2/V, L, T, σ) \
              - C_BS(V, D, T, σ) + (L/V)**(2*r/σ**2-1) * C_BS(L**2/V, D, T, σ)

def Bb(V, T, C, γ, σ, a):
    b = (np.log(C/V) - γ*T) / σ
    μ = (r - a - 0.5*σ**2 - γ) / σ
    m = np.sqrt(μ**2 + 2*r)
    return C*np.exp(b*(μ-m)) * ( N((b-m*T)/np.sqrt(T)) + np.exp(2*m*b)*N((b+m*T)/np.sqrt(T)) )

def Black_Cox(V, T, C=160, σ=0.1, a=0):
    return np.exp(γ*T)*BL(V*np.exp(-γ*T), T, D*np.exp(-γ*T), C*np.exp(-γ*T), σ) + Bb(V, T, C, γ, σ, a)

我需要使用 Black_Cox 函数 w.r.t 的导数。 V。更准确地说,我需要在更改其他参数的数千条路径中评估此导数,找到导数并在某个 V 处进行评估。

最好的方法是什么?

  • 我是否应该使用 sympy 来找到这个导数,然后在我选择的 V 上进行评估,就像我在 Mathematica 中所做的那样:D[BlackCox[V, 10, 100, 160], V] /. V -> 180,或者

  • 我应该只使用jax吗?

如果sympy,你会如何建议我这样做?

对于jax,我了解我需要执行以下导入:

import jax.numpy as np
from jax.scipy import stats
from jax import grad

并在获得渐变之前重新评估我的功能:

func = lambda x: Black_Cox(x,10,160,0.1)
grad(func)(180.0)

如果我仍然需要使用 numpy 版本的函数,我是否必须为每个函数创建 2 个实例,或者是否有一种优雅的方法可以为 jax 目的复制函数?

【问题讨论】:

    标签: python sympy autodiff jax


    【解决方案1】:

    Jax 不提供任何内置方法来使用 numpy 和 scipy 的 jax 版本重新编译 numpy 函数。但是您可以使用如下所示的 sn-p 自动执行此操作:

    import inspect
    from functools import wraps
    import numpy as np
    import jax.numpy
    
    def replace_globals(func, globals_):
      """Recompile a function with replaced global values."""
      namespace = func.__globals__.copy()
      namespace.update(globals_)
      source = inspect.getsource(func)
      exec(source, namespace)
      return wraps(func)(namespace[func.__name__])
    

    它是这样工作的:

    def numpy_func(N):
      return np.arange(N) ** 2
    
    jax_func = replace_globals(numpy_func, {"np": jax.numpy})
    

    现在您可以评估 numpy 版本了:

    numpy_func(10)
    # array([ 0,  1,  4,  9, 16, 25, 36, 49, 64, 81])
    

    以及 jax 版本:

    jax_func(10)
    # DeviceArray([ 0,  1,  4,  9, 16, 25, 36, 49, 64, 81], dtype=int32)
    

    只需确保在包装更复杂的函数时替换所有相关的全局变量即可。

    【讨论】:

      猜你喜欢
      • 2018-05-18
      • 1970-01-01
      • 1970-01-01
      • 2022-06-12
      • 2021-09-02
      • 1970-01-01
      • 2011-10-28
      • 1970-01-01
      • 1970-01-01
      相关资源
      最近更新 更多