【Question title】: How to set a custom loss function in Spark MLlib
【Posted】: 2018-04-27 18:23:44
【Question】:

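I want to use my own loss function instead of the squared loss in Spark MLlib's linear regression model. So far I have not found any part of the documentation that says whether this is possible.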

【Discussion】:

    Tags: scala apache-spark machine-learning regression apache-spark-mllib


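    【Answer 1】:

    TL;DR: Using a custom loss function is not straightforward, because you cannot simply pass a loss function to a Spark model. However, you can easily write a custom model yourself.

    Long answer:
    If you look at the code of LinearRegressionWithSGD, you will see: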

    class LinearRegressionWithSGD private[mllib] (
        private var stepSize: Double,
        private var numIterations: Int,
        private var regParam: Double,
        private var miniBatchFraction: Double)
      extends GeneralizedLinearAlgorithm[LinearRegressionModel] with Serializable {
    
      private val gradient = new LeastSquaresGradient() // loss function
      private val updater = new SimpleUpdater()
      @Since("0.8.0")
      override val optimizer = new GradientDescent(gradient, updater) // optimizer
        .setStepSize(stepSize)
        .setNumIterations(numIterations)
        .setRegParam(regParam)
        .setMiniBatchFraction(miniBatchFraction)
      // ... (rest of the class omitted)
    }

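    Next, let's look at how the least-squares loss is implemented in LeastSquaresGradient (in Spark's org.apache.spark.mllib.optimization package):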

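    package org.apache.spark.mllib.optimization

    import org.apache.spark.mllib.linalg.Vector
    // dot, axpy and scal come from Spark's internal BLAS object, which is
    // private[spark]: your own code outside Spark's packages cannot import it.
    import org.apache.spark.mllib.linalg.BLAS.{axpy, dot, scal}

    class LeastSquaresGradient extends Gradient {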
      override def compute(data: Vector, label: Double, weights: Vector): (Vector, Double) = {
        val diff = dot(data, weights) - label
        val loss = diff * diff / 2.0
        val gradient = data.copy
        scal(diff, gradient)
        (gradient, loss)
      }
    
      override def compute(
          data: Vector,
          label: Double,
          weights: Vector,
          cumGradient: Vector): Double = {
        val diff = dot(data, weights) - label
        axpy(diff, data, cumGradient)
        diff * diff / 2.0
      }
    }
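A worked restatement of what the two `compute` methods above evaluate for a single example, writing $x$ for `data`, $y$ for `label` and $w$ for `weights`:

```latex
\begin{aligned}
r &= w^\top x - y && (\texttt{diff}) \\
L(w) &= \tfrac{1}{2}\, r^2 && (\texttt{loss}) \\
\nabla_w L(w) &= r\, x && (\texttt{gradient}\text{; the 4-argument version adds it to } \texttt{cumGradient} \text{ via } \texttt{axpy})
\end{aligned}
```

A different loss $\ell(r)$ of the residual only changes the last two lines: by the chain rule the gradient becomes $\ell'(r)\,x$, so a custom `compute` returns $\ell(r)$ and adds $\ell'(r)\,x$ to `cumGradient`. For example, the absolute loss $\ell(r) = |r|$ gives $\operatorname{sign}(r)\,x$ (taking 0 as the subgradient at $r = 0$).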
    

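    So you can write a class like LeastSquaresGradient, implement the compute functions for your own loss, and train with it instead of LeastSquaresGradient. Because gradient is a private field of LinearRegressionWithSGD and its constructor is private[mllib], you cannot inject your class into LinearRegressionWithSGD itself; instead, pass it to the optimizer yourself, for example through the public GradientDescent.runMiniBatchSGD, and wrap the resulting weights in a LinearRegressionModel.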
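A minimal sketch of that approach, assuming the RDD-based `spark.mllib` API of Spark 2.x. `HuberGradient`, `HuberRegression.train` and the threshold `delta` are hypothetical names chosen for illustration, and the Huber loss (quadratic for small residuals, linear for large ones) is just one example of a custom loss. The sketch calls `GradientDescent.runMiniBatchSGD` (a `@DeveloperApi` method) directly, because `LinearRegressionWithSGD` does not accept a gradient:

```scala
import org.apache.spark.mllib.linalg.{DenseVector, Vector, Vectors}
import org.apache.spark.mllib.optimization.{Gradient, GradientDescent, SimpleUpdater}
import org.apache.spark.mllib.regression.{LabeledPoint, LinearRegressionModel}
import org.apache.spark.rdd.RDD

// Hypothetical custom loss: Huber loss with threshold `delta`.
class HuberGradient(delta: Double) extends Gradient {

  // Same contract as LeastSquaresGradient: return (gradient, loss) for one example.
  override def compute(data: Vector, label: Double, weights: Vector): (Vector, Double) = {
    val gradient = Vectors.zeros(weights.size) // dense zero accumulator
    val loss = compute(data, label, weights, gradient)
    (gradient, loss)
  }

  // Adds this example's gradient to cumGradient and returns its loss.
  override def compute(
      data: Vector,
      label: Double,
      weights: Vector,
      cumGradient: Vector): Double = {
    // Spark's BLAS helpers are private[spark], so compute the dot product by hand.
    val x = data.toArray
    val w = weights.toArray
    val diff = x.indices.map(i => x(i) * w(i)).sum - label
    // Derivative of the Huber loss w.r.t. diff: diff itself, clipped to [-delta, delta].
    val slope = math.max(-delta, math.min(delta, diff))
    // cumGradient += slope * x (assumes a dense accumulator, as created above).
    val acc = cumGradient.asInstanceOf[DenseVector].values
    x.indices.foreach(i => acc(i) += slope * x(i))
    if (math.abs(diff) <= delta) 0.5 * diff * diff
    else delta * (math.abs(diff) - 0.5 * delta)
  }
}

object HuberRegression {
  // Hypothetical helper: fits weights (no intercept) under the Huber loss.
  def train(
      points: RDD[LabeledPoint],
      delta: Double,
      stepSize: Double = 0.1,
      numIterations: Int = 100): LinearRegressionModel = {
    val numFeatures = points.first().features.size
    val data = points.map(p => (p.label, p.features)).cache()
    val (weights, _) = GradientDescent.runMiniBatchSGD(
      data,
      new HuberGradient(delta),
      new SimpleUpdater(),
      stepSize,
      numIterations,
      0.0, // regParam: no regularization
      1.0, // miniBatchFraction: use every point in each iteration
      Vectors.zeros(numFeatures))
    new LinearRegressionModel(weights, 0.0) // intercept fixed at 0
  }
}
```

`HuberRegression.train(points, delta = 1.0)` returns an ordinary `LinearRegressionModel`, so `predict` works as usual. When every residual stays within `delta`, the gradient is exactly the least-squares one from `LeastSquaresGradient`; unlike `LinearRegressionWithSGD`, this sketch never fits an intercept.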
