【问题标题】:Training/Test data with SparkML in Scala在 Scala 中使用 SparkML 训练/测试数据
【发布时间】:2021-12-30 15:42:09
【问题描述】:

在过去的几个小时里,我一直面临着一个问题。 理论上,当我们拆分数据进行训练和测试时,我们应该独立地对训练数据进行标准化,以免引入偏差,然后在训练完模型后,我们是否使用相同的“参数”值对测试集进行标准化训练集。

到目前为止,我只设法在没有管道的情况下做到了,看起来像这样:

val training = splitData(0)
val test = splitData(1)
val assemblerTraining = new VectorAssembler()
 .setInputCols(training.columns)
 .setOutputCol("features")   
val standardScaler = new StandardScaler()
 .setInputCol("features")
 .setOutputCol("normFeatures")
 .setWithStd(true)
 .setWithMean(true)
val scalerModel = standardScaler.fit(training)
val scaledTrainingData = scalerModel.transform(training)
val scaledTestData = scalerModel.transform(test)

我将如何使用管道实现这一点? 我的问题是,如果我创建这样的管道:

        val pipelineTraining = new Pipeline()
            .setStages(
                Array(
                    assemblerTraining,
                    standardScaler,
                    lr
                )
            )

其中 lr 是线性回归,则无法从管道内部实际访问缩放模型。

我还考虑过使用中间管道进行缩放,如下所示:

val pipelineScalingModel = new Pipeline()
 .setStages(Array(assemblerTraining, standardScaler))
 .fit(training)

val pipelineTraining = new Pipeline()
 .setStages(Array(pipelineScalingModel,lr))

val scaledTestData = pipelineScalingModel.transform(test)

但我不知道这是否是正确的做法。

任何建议将不胜感激。

【问题讨论】:

    标签: scala apache-spark pipeline apache-spark-mllib


    【解决方案1】:

    如果其他人遇到这个问题,我就是这样处理的:

    我意识到我不允许修改 [forbiddenColumnName] 变量。因此我放弃了在那个阶段尝试使用管道。 我创建了自己的标准化函数并为每个单独的功能调用它,如下所示:

    def standardizeColumn( dfTrain : DataFrame, dfTest : DataFrame, columnName : String) : Array[DataFrame] = {
       val withMeanStd = dfTrain.select(mean(col(columnName)), stddev(col(columnName))).collect
       val auxDFTrain = dfTrain.withColumn(columnName, (col(columnName) - withMeanStd(0).getDouble(0))/withMeanStd(0).getDouble(1))
       val auxDFTest = dfTest.withColumn(columnName, (col(columnName) - withMeanStd(0).getDouble(1))/withMeanStd(0).getDouble(1))
            Array(auxDFTrain, auxDFTest)
    }
    
    for (columnName <- training.columns){
       if ((columnName != [forbiddenColumnName]) && (columnExists(training, columnName))){
           val auxResult = standardizeColumn(training, test, columnName)
           training = auxResult(0)
           test = auxResult(1)
       }
    }
    

    [提及] 我的变量数量非常少~15,因此这不是一个非常漫长的过程。我严重怀疑这是否是处理更大数据集的正确方法。

    【讨论】:

      猜你喜欢
      • 2021-04-09
      • 1970-01-01
      • 2020-09-14
      • 1970-01-01
      • 1970-01-01
      • 1970-01-01
      • 2017-06-21
      • 2018-05-19
      • 2021-03-14
      相关资源
      最近更新 更多