【问题标题】:Correct usage of inheritance for wrapper classes正确使用包装类的继承
【发布时间】:2015-09-30 16:01:47
【问题描述】:

我正在尝试使用 ejml 库为矩阵运算编写 scala 包装器。基本上我只使用 SimpleMatrix。但是,我想要矩阵和向量的不同类,例如只能反转矩阵或明确声明函数返回向量,而不是矩阵。目前,我无法返回具体类而不是特征。

我从一个特征开始,MLMatrixLike:

trait MLMatrixLike {
  def data: SimpleMatrix
  protected def internalMult(implicit that: MLMatrixLike): SimpleMatrix = {
    data.mult(that.data)
  }
  def *(implicit that: MLMatrixLike): MLVector = MLVector(internalMult)
}

我的矩阵类和向量类都在扩展这个特征:

case class MLMatrix(data: SimpleMatrix) extends MLMatrixLike {

  def this(rawData: Array[Array[Double]]) = this(new SimpleMatrix(rawData))

  def apply(row: Int, col:Int): Double = data.get(row, col)

  def transpose(): MLMatrix = MLMatrix(data.transpose())

  def invert(): MLMatrix = MLMatrix(data.invert())

  def *(implicit that: MLMatrix): MLMatrix = MLMatrix(internalMult)

  def *(that: Double): MLMatrix = MLMatrix(data.scale(that))

  def -(that: MLMatrix): MLMatrix = MLMatrix(data.minus(that.data))
}

object MLMatrix {
  def apply(rawData: Array[Array[Double]]) = new MLMatrix(rawData)
}

case class MLVector(data: SimpleMatrix) extends MLMatrixLike {

  def this(rawData: Array[Double]) = {
    this(new SimpleMatrix(Array(rawData)).transpose())
  }

  def apply(index: Int): Double = data.get(index)

  def transpose(): MLVector = MLVector(data.transpose())

  def -(that: MLVector): MLVector = MLVector(data.minus(that.data))
}

object MLVector {
  def apply(rawData: Array[Double]) = new MLVector(rawData)
}

在我看来,这个设置不是很好。我只想定义一次乘法 (*),因为 SimpleMatrix 调用总是相同的,我可以从参数“that”的类型推断返回类型应该是矩阵还是向量。因此,我想在 MLMatrixLike 中按照这个(不工作的)函数定义一个函数:

def *[T <: MLMatrixLike](that :T): T = {
  new T(data.mult(that.data))
}

当然,这是行不通的,因为没有这样的构造函数 T,但目前我看不到,我怎样才能得到类似的东西。返回 MLMatrixLike 在我看来是不正确的,因为这样我无法在编译期间检查是否返回了正确的类型。

类似的问题适用于转置和减号 - 这里的返回类型始终是自己的类。

非常感谢!

【问题讨论】:

    标签: scala


    【解决方案1】:

    我不确定将SimpleMatrix 包装在另外两个类中有什么好处。但是,您可以通过将 MLMatrixLike 设为具有自身类型的泛型并定义抽象构造函数来解决重复问题。

    trait MLMatrixLike[Self <: MLMatrixLike[Self]] {
      this: Self =>
      def data: SimpleMatrix
    
      def createNew(data: SimpleMatrix): Self
    
      def *[T <: MLMatrixLike[T]](that: T): T = that.createNew(data.mult(that.data))
    
      def *(that: Double): Self = createNew(data.scale(that))
    
      def -(that: Self): Self = createNew(data.minus(that.data))
    
      def transpose: Self = createNew(data.transpose())
    }
    
    case class MLMatrix(data: SimpleMatrix) extends MLMatrixLike[MLMatrix] {
      this: MLMatrix =>
    
      def this(rawData: Array[Array[Double]]) = this(new SimpleMatrix(rawData))
    
      override def createNew(data: SimpleMatrix): MLMatrix = MLMatrix(data)
    
      def apply(row: Int, col: Int): Double = data.get(row, col)
    
      def invert(): MLMatrix = MLMatrix(data.invert())
    
    }
    
    object MLMatrix {
      def apply(rawData: Array[Array[Double]]) = new MLMatrix(rawData)
    }
    
    case class MLVector(data: SimpleMatrix) extends MLMatrixLike[MLVector] {
      this: MLVector =>
    
      def this(rawData: Array[Double]) = {
        this(new SimpleMatrix(Array(rawData)).transpose())
      }
    
      override def createNew(data: SimpleMatrix): MLVector = MLVector(data)
    
      def apply(index: Int): Double = data.get(index)
    
    }
    
    object MLVector {
      def apply(rawData: Array[Double]) = new MLVector(rawData)
    }
    

    顺便提一下,列向量乘以行向量是一个矩阵,所以乘法的签名可能不应该返回that的类型。但是,基于静态信息(您需要知道两个参数的维度),您无法判断乘法是返回向量还是矩阵,因此您也可以只返回 MLMatrixLike

    【讨论】:

    • 非常感谢您的回答!
    • 不客气。有趣的是,由于案例类带有copy 方法,您甚至不需要在MLMatrixLike 中定义抽象构造函数。不幸的是,您需要一些黑魔法才能让MLMatrixLike 知道copy 方法,因此它更像是一个玩物。见this answer
    猜你喜欢
    • 1970-01-01
    • 2017-04-14
    • 2013-11-17
    • 2012-06-22
    • 2015-12-04
    • 1970-01-01
    • 1970-01-01
    • 2020-11-13
    • 2014-08-15
    相关资源
    最近更新 更多