【问题标题】:Restricting function signatures while using ForwardDiff in Julia在 Julia 中使用 ForwardDiff 时限制函数签名
【发布时间】:2019-05-06 18:15:03
【问题描述】:

我正在尝试在几乎所有函数都被限制为只能接收浮点数的库中使用 ForwardDiff。我想概括这些函数签名,以便可以使用 ForwardDiff,同时仍然具有足够的限制性,因此函数只接受数值而不是日期之类的东西。我有很多具有相同名称但类型不同的函数(即,将“时间”作为浮点数或具有相同函数名称的日期的函数)并且不想从头到尾删除类型限定符。

最小的工作示例

using ForwardDiff
x = [1.0, 2.0, 3.0, 4.0 ,5.0]
typeof(x) # Array{Float64,1}
function G(x::Array{Real,1})
    return sum(exp.(x))
end
function grad_F(x::Array)
  return ForwardDiff.gradient(G, x)
end
G(x) # Method Error
grad_F(x) # Method error

function G(x::Array{Float64,1})
    return sum(exp.(x))
end
G(x) # This works
grad_F(x) # This has a method error

function G(x)
    return sum(exp.(x))
end
G(x) # This works
grad_F(x) # This works
# But now I cannot restrict the function G to only take numeric arrays and not for instance arrays of Dates.

有没有办法限制函数只接受数值(整数和浮点数)以及 ForwardDiff 使用但不允许符号、日期等的任何双数结构。

【问题讨论】:

    标签: julia automatic-differentiation


    【解决方案1】:

    ForwardDiff.Dual 是抽象类型Real 的子类型。但是,您遇到的问题是 Julia 的类型参数是不变的,而不是协变的。然后,以下内容返回 false。

    # check if `Array{Float64, 1}` is a subtype of `Array{Real, 1}`
    julia> Array{Float64, 1} <: Array{Real, 1}
    false
    

    这就是你的函数定义

    function G(x::Array{Real,1})
        return sum(exp.(x))
    end
    

    不正确(不适合您使用)。这就是您收到以下错误的原因。

    julia> G(x)
    ERROR: MethodError: no method matching G(::Array{Float64,1})
    

    正确的定义应该是

    function G(x::Array{<:Real,1})
        return sum(exp.(x))
    end
    

    或者如果您需要轻松访问数组的具体元素类型

     function G(x::Array{T,1}) where {T<:Real}
         return sum(exp.(x))
     end
    

    grad_F 函数也是如此。

    您可能会发现阅读 Julia 文档中的 the relevant section 以了解类型很有用。


    您可能还希望将您的函数类型注释为 AbstractArray{&lt;:Real,1} 类型而不是 Array{&lt;:Real, 1},以便您的函数可以处理其他类型的数组,例如 StaticArraysOffsetArrays 等,而无需重新定义。

    【讨论】:

      【解决方案2】:

      这将接受由任何类型的数字参数化的任何类型的数组:

      function foo(xs::AbstractArray{<:Number})
        @show typeof(xs)
      end
      

      或:

      function foo(xs::AbstractArray{T}) where T<:Number
        @show typeof(xs)
      end
      

      如果需要引用body函数内部的类型参数T

      x1 = [1.0, 2.0, 3.0, 4.0 ,5.0]
      x2 = [1, 2, 3,4, 5]
      x3 = 1:5
      x4 = 1.0:5.0
      x5 = [1//2, 1//4, 1//8]
      
      xss = [x1, x2, x3, x4, x5]
      
      function foo(xs::AbstractArray{T}) where T<:Number
        @show xs typeof(xs) T
        println()
      end
      
      for xs in xss
        foo(xs)
      end
      

      输出:

      xs = [1.0, 2.0, 3.0, 4.0, 5.0]
      typeof(xs) = Array{Float64,1}
      T = Float64
      
      xs = [1, 2, 3, 4, 5]
      typeof(xs) = Array{Int64,1}
      T = Int64
      
      xs = 1:5
      typeof(xs) = UnitRange{Int64}
      T = Int64
      
      xs = 1.0:1.0:5.0
      typeof(xs) = StepRangeLen{Float64,Base.TwicePrecision{Float64},Base.TwicePrecision{Float64}}
      T = Float64
      
      xs = Rational{Int64}[1//2, 1//4, 1//8]
      typeof(xs) = Array{Rational{Int64},1}
      T = Rational{Int64}
      

      您可以在此处运行示例代码:https://repl.it/@SalchiPapa/Restricting-function-signatures-in-Julia

      【讨论】:

        猜你喜欢
        • 1970-01-01
        • 1970-01-01
        • 1970-01-01
        • 1970-01-01
        • 2011-10-05
        • 1970-01-01
        • 2020-04-30
        • 1970-01-01
        • 1970-01-01
        相关资源
        最近更新 更多