【问题标题】:PySpark LinearRegressionWithSGD, model predict dimensions mismatchPySpark LinearRegressionWithSGD,模型预测维度不匹配
【发布时间】:2017-12-27 16:30:07
【问题描述】:

我遇到了以下错误:

AssertionError: 维度不匹配

我使用 PySpark 的 LinearRegressionWithSGD 训练了一个线性回归模型。 但是,当我尝试对训练集进行预测时,出现“维度不匹配”错误。

值得一提:

  1. 使用 StandardScaler 缩放数据,但未使用预测值。
  2. 从代码中可以看出,用于训练的特征是由 PCA 生成的。

一些代码:

pca_transformed = pca_model.transform(data_std)
X = pca_transformed.map(lambda x: (x[0], x[1]))
data = train_votes.zip(pca_transformed)
labeled_data = data.map(lambda x : LabeledPoint(x[0], x[1:]))
linear_regression_model = LinearRegressionWithSGD.train(labeled_data, iterations=10)

预测是错误的来源,这些是我尝试过的变体:

pred = linear_regression_model.predict(pca_transformed.collect())
pred = linear_regression_model.predict([pca_transformed.collect()])    
pred = linear_regression_model.predict(X.collect())
pred = linear_regression_model.predict([X.collect()])

回归权重:

DenseVector([1.8509, 81435.7615])

使用的向量:

pca_transformed.take(1)
[DenseVector([-0.1745, -1.8936])]

X.take(1)
[(-0.17449817243564397, -1.8935926689554488)]

labeled_data.take(1)
[LabeledPoint(22221.0, [-0.174498172436,-1.89359266896])]

【问题讨论】:

    标签: apache-spark machine-learning pyspark apache-spark-mllib


    【解决方案1】:

    这行得通:

    pred = linear_regression_model.predict(pca_transformed)
    

    pca_transformed 是 RDD 类型。

    function 处理 RDD 和数组的方式不同:

    def predict(self, x):
        """
        Predict the value of the dependent variable given a vector or
        an RDD of vectors containing values for the independent variables.
        """
        if isinstance(x, RDD):
            return x.map(self.predict)
        x = _convert_to_vector(x)
        return self.weights.dot(x) + self.intercept
    

    使用简单数组时,可能会出现维度不匹配的问题(如上题中的错误)。

    可以看出,如果 x 不是 RDD,它会被转换为向量。问题是除非你取 x[0],否则点积将不起作用。

    这是重现的错误:

    j = _convert_to_vector(pca_transformed.take(1))
    linear_regression_model.weights.dot(j) + linear_regression_model.intercept
    

    这很好用:

    j = _convert_to_vector(pca_transformed.take(1))
    linear_regression_model.weights.dot(j[0]) + linear_regression_model.intercept
    

    【讨论】:

      猜你喜欢
      • 2021-03-30
      • 2020-07-26
      • 2020-06-26
      • 2018-08-25
      • 2018-06-04
      • 2020-09-29
      • 1970-01-01
      • 1970-01-01
      • 2020-01-17
      相关资源
      最近更新 更多