【问题标题】:Computing the mean of a list efficiently in Haskell在 Haskell 中有效计算列表的平均值
【发布时间】:2011-03-19 01:51:22
【问题描述】:

我设计了一个函数来计算列表的平均值。虽然它工作得很好,但我认为它可能不是最好的解决方案,因为它需要两个功能而不是一个。是否可以仅使用一个递归函数完成这项工作?

calcMeanList (x:xs) = doCalcMeanList (x:xs) 0 0

doCalcMeanList (x:xs) sum length =  doCalcMeanList xs (sum+x) (length+1)
doCalcMeanList [] sum length = sum/length

【问题讨论】:

  • 请记住,对于这个问题的任何解决方案(相当于简单除法)都会为空列表生成 NaN。不一定是问题,只是我认为值得注意的事情。
  • 抱歉提交了重复的问题。下次我会更仔细地搜索。
  • @snowmantw:你不可能知道,该问题的标题中没有任何内容表明这是一个关于计算平均值的问题。 @Don Stewart:我不认为这是一个骗局。代码非常相似,但关于代码的问题却大不相同。

标签: performance list haskell


【解决方案1】:

虽然我不确定将它写在一个函数中是否“最好”,但可以按如下方式完成:

如果您提前知道长度(这里称其为“n”)很容易 - 您可以计算每个值“添加”到平均值的多少;这将是价值/长度。自avg(x1, x2, x3) = sum(x1, x2, x3)/length = (x1 + x2 + x3)/3 = x1/3 + x2/3 + x2/3

如果你事先不知道长度,那就有点棘手了:

假设我们使用列表 {x1,x2,x3} 而不知道它的 n=3。

第一次迭代只是 x1(因为我们假设它只有 n=1) 第二次迭代将添加 x2/2 并将现有平均值除以 2,所以现在我们有 x1/2 + x2/2

在第三次迭代之后,我们有 n=3,我们希望有 x1/3 +x2/3 + x3/3,但我们有 x1/2 + x2/2

所以我们需要乘以 (n-1) 并除以 n 得到 x1/3 + x2/3,然后我们只需将当前值 (x3) 除以 n 即可得到 x1/3 + x2/3 + x3/3

一般:

给定 n-1 项的平均值(算术平均值 - avg),如果您想将一项(newval)添加到平均值中,您的等式将是:

avg*(n-1)/n + newval/n。这个方程可以用归纳法在数学上证明。

希望这会有所帮助。

*请注意,此解决方案的效率低于您在示例中所做的简单地对变量求和并除以总长度。

【讨论】:

    【解决方案2】:

    为了跟进 Don 2010 年的回复,在 GHC 8.0.2 上我们可以做得更好。首先让我们试试他的版本。

    module Main (main) where
    
    import System.CPUTime.Rdtsc (rdtsc)
    import Text.Printf (printf)
    import qualified Data.Vector.Unboxed as U
    
    data Pair = Pair {-# UNPACK #-}!Int {-# UNPACK #-}!Double
    
    mean' :: U.Vector Double -> Double
    mean' xs = s / fromIntegral n
      where
        Pair n s       = U.foldl' k (Pair 0 0) xs
        k (Pair n s) x = Pair (n+1) (s+x)
    
    main :: IO ()
    main = do
      s <- rdtsc
      let r = mean' (U.enumFromN 1 30000000)
      e <- seq r rdtsc
      print (e - s, r)
    

    这给了我们

    [nix-shell:/tmp]$ ghc -fforce-recomp -O2 MeanD.hs -o MeanD && ./MeanD +RTS -s
    [1 of 1] Compiling Main             ( MeanD.hs, MeanD.o )
    Linking MeanD ...
    (372877482,1.50000005e7)
         240,104,176 bytes allocated in the heap
               6,832 bytes copied during GC
              44,384 bytes maximum residency (1 sample(s))
              25,248 bytes maximum slop
                 230 MB total memory in use (0 MB lost due to fragmentation)
    
                                         Tot time (elapsed)  Avg pause  Max pause
      Gen  0         1 colls,     0 par    0.000s   0.000s     0.0000s    0.0000s
      Gen  1         1 colls,     0 par    0.006s   0.006s     0.0062s    0.0062s
    
      INIT    time    0.000s  (  0.000s elapsed)
      MUT     time    0.087s  (  0.087s elapsed)
      GC      time    0.006s  (  0.006s elapsed)
      EXIT    time    0.006s  (  0.006s elapsed)
      Total   time    0.100s  (  0.099s elapsed)
    
      %GC     time       6.2%  (6.2% elapsed)
    
      Alloc rate    2,761,447,559 bytes per MUT second
    
      Productivity  93.8% of total user, 93.8% of total elapsed
    

    但是代码很简单:理想情况下应该不需要向量:优化代码应该可以通过内联列表生成来实现。幸运的是,GHC 可以为我们做到这一点[0]。

    module Main (main) where
    
    import System.CPUTime.Rdtsc (rdtsc)
    import Text.Printf (printf)
    import Data.List (foldl')
    
    data Pair = Pair {-# UNPACK #-}!Int {-# UNPACK #-}!Double
    
    mean' :: [Double] -> Double
    mean' xs = v / fromIntegral l
      where
        Pair l v = foldl' f (Pair 0 0) xs
        f (Pair l' v') x = Pair (l' + 1) (v' + x)
    
    main :: IO ()
    main = do
      s <- rdtsc
      let r = mean' $ fromIntegral <$> [1 :: Int .. 30000000]
          -- This is slow!
          -- r = mean' [1 .. 30000000]
      e <- seq r rdtsc
      print (e - s, r)
    

    这给了我们:

    [nix-shell:/tmp]$ ghc -fforce-recomp -O2 MeanD.hs -o MeanD && ./MeanD +RTS -s
    [1 of 1] Compiling Main             ( MeanD.hs, MeanD.o )
    Linking MeanD ...
    (128434754,1.50000005e7)
             104,064 bytes allocated in the heap
               3,480 bytes copied during GC
              44,384 bytes maximum residency (1 sample(s))
              17,056 bytes maximum slop
                   1 MB total memory in use (0 MB lost due to fragmentation)
    
                                         Tot time (elapsed)  Avg pause  Max pause
      Gen  0         0 colls,     0 par    0.000s   0.000s     0.0000s    0.0000s
      Gen  1         1 colls,     0 par    0.000s   0.000s     0.0000s    0.0000s
    
      INIT    time    0.000s  (  0.000s elapsed)
      MUT     time    0.032s  (  0.032s elapsed)
      GC      time    0.000s  (  0.000s elapsed)
      EXIT    time    0.000s  (  0.000s elapsed)
      Total   time    0.033s  (  0.032s elapsed)
    
      %GC     time       0.1%  (0.1% elapsed)
    
      Alloc rate    3,244,739 bytes per MUT second
    
      Productivity  99.8% of total user, 99.8% of total elapsed
    

    [0]:注意我必须如何映射fromIntegral:没有这个,GHC 无法消除[Double],并且解决方案要慢得多。这有点可悲:我不明白为什么 GHC 无法内联/决定它不需要没有这个。如果你确实有真正的分数集合,那么这个技巧对你不起作用,向量可能仍然是必要的。

    【讨论】:

    • 还有一个有趣的提示:如果我们在[Int] 上工作并使用-fllvm,在这种情况下我们能够得到几乎恒定时间的答案。
    【解决方案3】:

    关于你能做的最好的事情是this version

    import qualified Data.Vector.Unboxed as U
    
    data Pair = Pair {-# UNPACK #-}!Int {-# UNPACK #-}!Double
    
    mean :: U.Vector Double -> Double
    mean xs = s / fromIntegral n
      where
        Pair n s       = U.foldl' k (Pair 0 0) xs
        k (Pair n s) x = Pair (n+1) (s+x)
    
    main = print (mean $ U.enumFromN 1 (10^7))
    

    它融合到 Core 中的最佳循环(您可以编写的最好的 Haskell):

    main_$s$wfoldlM'_loop :: Int#
                                  -> Double#
                                  -> Double#
                                  -> Int#
                                  -> (# Int#, Double# #)    
    main_$s$wfoldlM'_loop =
      \ (sc_s1nH :: Int#)
        (sc1_s1nI :: Double#)
        (sc2_s1nJ :: Double#)
        (sc3_s1nK :: Int#) ->
        case ># sc_s1nH 0 of _ {
          False -> (# sc3_s1nK, sc2_s1nJ #);
          True ->
            main_$s$wfoldlM'_loop
              (-# sc_s1nH 1)
              (+## sc1_s1nI 1.0)
              (+## sc2_s1nJ sc1_s1nI)
              (+# sc3_s1nK 1)
        }
    

    还有以下组件:

    Main_mainzuzdszdwfoldlMzqzuloop_info:
    .Lc1pN:
            testq %r14,%r14
            jg .Lc1pQ
            movq %rsi,%rbx
            movsd %xmm6,%xmm5
            jmp *(%rbp)
    .Lc1pQ:
            leaq 1(%rsi),%rax
            movsd %xmm6,%xmm0
            addsd %xmm5,%xmm0
            movsd %xmm5,%xmm7
            addsd .Ln1pS(%rip),%xmm7
            decq %r14
            movsd %xmm7,%xmm5
            movsd %xmm0,%xmm6
            movq %rax,%rsi
            jmp Main_mainzuzdszdwfoldlMzqzuloop_info
    

    基于 Data.Vector。例如,

    $ ghc -Odph --make A.hs -fforce-recomp
    [1 of 1] Compiling Main             ( A.hs, A.o )
    Linking A ...
    $ time ./A
    5000000.5
    ./A  0.04s user 0.00s system 93% cpu 0.046 total
    

    查看the statistics package中的高效实现。

    【讨论】:

      【解决方案4】:

      对于那些想知道在 Haskell 中 gloomcoder 和 Assaf 的方法会是什么样子的人,这里有一个翻译:

      avg [] = 0
      avg x@(t:ts) = let xlen = toRational $ length x
                         tslen = toRational $ length ts
                         prevAvg = avg ts
                     in (toRational t) / xlen + prevAvg * tslen / xlen
      

      这种方式可确保正确计算每个步骤的“迄今为止的平均值”,但这样做的代价是大量冗余的长度乘/除,并且每个步骤的长度计算效率非常低。没有经验丰富的 Haskeller 会这样写。

      一个稍微好一点的方法是:

      avg2 [] = 0
      avg2 x = fst $ avg_ x
          where 
            avg_ [] = (toRational 0, toRational 0)
            avg_ (t:ts) = let
                 (prevAvg, prevLen) = avg_ ts
                 curLen = prevLen + 1
                 curAvg = (toRational t) / curLen + prevAvg * prevLen / curLen
              in (curAvg, curLen)
      

      这避免了重复的长度计算。但它需要一个辅助函数,而这正是原始发布者试图避免的。而且它仍然需要大量取消超出长度的术语。

      为了避免长度被抵消,我们可以将和和长度相加,最后除:

      avg3 [] = 0
      avg3 x = (toRational total) / (toRational len)
          where 
            (total, len) = avg_ x
            avg_ [] = (0, 0)
            avg_ (t:ts) = let 
                (prevSum, prevLen) = avg_ ts
             in (prevSum + t, prevLen + 1)
      

      这可以更简洁地写成一个 foldr:

      avg4 [] = 0
      avg4 x = (toRational total) / (toRational len)
          where
            (total, len) = foldr avg_ (0,0) x
            avg_ t (prevSum, prevLen) = (prevSum + t, prevLen + 1)
      

      可以根据上面的帖子进一步简化。

      Fold 确实是通往这里的路。

      【讨论】:

        【解决方案5】:

        当我看到你的问题时,我立刻想到“你想要一个fold!”

        果然,a similar question 之前在 StackOverflow 上被问过,this answer 有一个非常高效的解决方案,您可以在 GHCi 等交互式环境中进行测试:

        import Data.List
        
        let avg l = let (t,n) = foldl' (\(b,c) a -> (a+b,c+1)) (0,0) l 
                    in realToFrac(t)/realToFrac(n)
        
        avg ([1,2,3,4]::[Int])
        2.5
        avg ([1,2,3,4]::[Double])
        2.5
        

        【讨论】:

          【解决方案6】:

          您的解决方案很好,使用两个功能并不比一个差。不过,您可以将尾递归函数放在 where 子句中。

          但如果你想在一行中完成:

          calcMeanList = uncurry (/) . foldr (\e (s,c) -> (e+s,c+1)) (0,0)
          

          【讨论】:

          • 为什么是 foldr 而不是 foldl?似乎更适合我。
          • foldl、foldl' 或 foldr 可以在这里使用,因为无论如何你都必须遍历整个列表(这是我选择的那个)...我认为如果性能很重要,可以在这里使用 foldl'
          猜你喜欢
          • 1970-01-01
          • 1970-01-01
          • 1970-01-01
          • 1970-01-01
          • 1970-01-01
          • 1970-01-01
          • 1970-01-01
          • 2012-06-03
          • 1970-01-01
          相关资源
          最近更新 更多