【问题标题】:How can I extend Scala collections with an argmax method?如何使用 argmax 方法扩展 Scala 集合?
【发布时间】:2011-03-04 07:04:09
【问题描述】:

我想在所有有意义的集合中添加一个argMax 方法。 怎么做?使用隐式?

【问题讨论】:

  • 集合不是表达式。它怎么会有一个 argmax 方法?您是否试图找到集合中最大的元素?

标签: scala scala-2.8 extend


【解决方案1】:

又好又容易? :

val l = List(1,0,10,2)
l.zipWithIndex.maxBy(x => x._1)._2

【讨论】:

    【解决方案2】:

    根据其他答案,您可以很容易地结合每个人的优势(最少拨打f() 等)。在这里,我们对所有 Iterables 进行了隐式转换(因此它们可以透明地调用 .argmax()),以及出于某种原因首选的独立方法。 ScalaTest 测试启动。

    class Argmax[A](col: Iterable[A]) {
      def argmax[B](f: A => B)(implicit ord: Ordering[B]): Iterable[A] = {
        val mapped = col map f
        val max = mapped max ord
        (mapped zip col) filter (_._1 == max) map (_._2)
      }
    }
    
    object MathOps {
      implicit def addArgmax[A](col: Iterable[A]) = new Argmax(col)
    
      def argmax[A, B](col: Iterable[A])(f: A => B)(implicit ord: Ordering[B]) = {
        new Argmax(col) argmax f
      }
    }
    
    class MathUtilsTests extends FunSuite {
      import MathOps._
    
      test("Can argmax with unique") {
        assert((-10 to 0).argmax(_ * -1).toSet === Set(-10))
        // or alternate calling syntax
        assert(argmax(-10 to 0)(_ * -1).toSet === Set(-10))
      }
    
      test("Can argmax with multiple") {
        assert((-10 to 10).argmax(math.pow(_, 2)).toSet === Set(-10, 10))
      }
    }
    

    【讨论】:

      【解决方案3】:

      这是一个大致基于@Daniel 接受的答案的变体,它也适用于 Set。

      def argMax[A, B: Ordering](input: GenIterable[A], f: A => B) : GenSet[A] = argMaxZip(input, f) map (_._1) toSet
      
      def argMaxZip[A, B: Ordering](input: GenIterable[A], f: A => B): GenIterable[(A, B)] = {
        if (input.isEmpty) Nil
        else {
          val fPairs = input map (x => (x, f(x)))
          val maxF = fPairs.map(_._2).max
          fPairs filter (_._2 == maxF)
        }
      }
      

      当然,也可以做一个产生 (B, I​​terable[A]) 的变体。

      【讨论】:

        【解决方案4】:

        这是一种使用隐式构建器模式的方法。与以前的解决方案相比,它具有与任何 Traversable 一起工作的优势,并返回类似的 Traversable。可悲的是,这是非常必要的。如果有人愿意,它可能会变成一个相当丑陋的折叠。

        object RichTraversable {
          implicit def traversable2RichTraversable[A](t: Traversable[A]) = new RichTraversable[A](t)
        }
        
        class RichTraversable[A](t: Traversable[A]) { 
          def argMax[That, C](g: A => C)(implicit bf : scala.collection.generic.CanBuildFrom[Traversable[A], A, That], ord:Ordering[C]): That = {
            var minimum:C = null.asInstanceOf[C]
            val repr = t.repr
            val builder = bf(repr)
            for(a<-t){
              val test: C = g(a)
              if(test == minimum || minimum == null){
                builder += a
                minimum = test
              }else if (ord.gt(test, minimum)){
                builder.clear
                builder += a
                minimum = test
              }
            }
            builder.result
          }
        }
        
        Set(-2, -1, 0, 1, 2).argmax(x=>x*x) == Set(-2, 2)
        List(-2, -1, 0, 1, 2).argmax(x=>x*x) == List(-2, 2)
        

        【讨论】:

          【解决方案5】:

          在 Scala 2.8 上,这有效:

          val list = List(1, 2, 3)
          def f(x: Int) = -x
          val argMax = list max (Ordering by f)
          

          正如mkneissl 所指出的,这不会返回最大点的。这是一个替代实现,它试图减少对f 的调用次数。如果对f 的调用没有那么重要,请参阅mkneissl's answer。另外,请注意他的答案是 curried,这提供了卓越的类型推断。

          def argMax[A, B: Ordering](input: Iterable[A], f: A => B) = {
            val fList = input map f
            val maxFList = fList.max
            input.view zip fList filter (_._2 == maxFList) map (_._1) toSet
          }
          
          scala> argMax(-2 to 2, (x: Int) => x * x)
          res15: scala.collection.immutable.Set[Int] = Set(-2, 2)
          

          【讨论】:

          • 如果 f 达到其域的多个元素的最大值,则失败。考虑 x=>x*x 和 x 从 -2 到 2
          • 集合是可迭代的,但这不适用于集合,最明显的是当 f 为多个元素产生相同的结果时。原因是input map f 产生了另一个 Set,它现在小于输入。因此 zip 是不同步的,并且产生了错误的结果。即使 f 的结果是唯一的,迭代顺序也不一致,因此 zip 最终会被打乱。要求输入是 Seq 可以修复它。如果不需要执行 mySet.toSeq,则需要另一种解决方案(例如,执行 input map {a=&gt;(a,f(a)} 等的解决方案)。
          【解决方案6】:

          argmax 函数(我从Wikipedia 了解到)

          def argMax[A,B](c: Traversable[A])(f: A=>B)(implicit o: Ordering[B]): Traversable[A] = {
            val max = (c map f).max(o)
            c filter { f(_) == max }
          }
          

          如果你真的想要,你可以把它拉到收藏中

          implicit def enhanceWithArgMax[A](c: Traversable[A]) = new {
            def argMax[B](f: A=>B)(implicit o: Ordering[B]): Traversable[A] = ArgMax.argMax(c)(f)(o)
          }
          

          并像这样使用它

          val l = -2 to 2
          assert (argMax(l)(x => x*x) == List(-2,2))
          assert (l.argMax(x => x*x) == List(-2,2))
          

          (Scala 2.8)

          【讨论】:

            【解决方案7】:

            是的,通常的方法是使用“pimp my library”模式来装饰您的收藏。例如(N.B. 仅作为说明,并非正确或有效的示例):

            
            trait PimpedList[A] {
               val l: List[A]
            
              //example argMax, not meant to be correct
              def argMax[T <% Ordered[T]](f:T => T) = {error("your definition here")}
            }
            
            implicit def toPimpedList[A](xs: List[A]) = new PimpedList[A] { 
              val l = xs
            }
            
            scala> def f(i:Int):Int = 10
            f: (i: Int) Int
            
            scala> val l = List(1,2,3)
            l: List[Int] = List(1, 2, 3)
            
            scala> l.argMax(f)
            java.lang.RuntimeException: your definition here
                at scala.Predef$.error(Predef.scala:60)
                at PimpedList$class.argMax(:12)
                    //etc etc...
            

            【讨论】:

              【解决方案8】:

              您可以使用Pimp my Library 模式将函数添加到Scala 中的现有API。您可以通过定义一个隐式转换函数来做到这一点。例如,我有一个类 Vector3 来表示 3D 向量:

              class Vector3 (val x: Float, val y: Float, val z: Float)
              

              假设我希望能够通过编写如下内容来缩放矢量:2.5f * v。我当然不能直接向Float 类添加* 方法,但我可以提供这样的隐式转换函数:

              implicit def scaleVector3WithFloat(f: Float) = new {
                  def *(v: Vector3) = new Vector3(f * v.x, f * v.y, f * v.z)
              }
              

              请注意,这会返回一个结构类型的对象(new { ... } 构造),其中包含 * 方法。

              我还没有测试过,但我想你可以这样做:

              implicit def argMaxImplicit[A](t: Traversable[A]) = new {
                  def argMax() = ...
              }
              

              【讨论】:

                猜你喜欢
                • 1970-01-01
                • 2015-04-10
                • 2016-07-08
                • 1970-01-01
                • 2012-04-03
                • 2014-06-02
                • 1970-01-01
                • 1970-01-01
                相关资源
                最近更新 更多