【Question title】: sklearn RidgeCV with sample_weight
【Posted】: 2015-10-28 06:29:12
【Question】:

I am trying to run a weighted ridge regression with sklearn. However, the code breaks when I call the fit method. The exception I get is:

Exception: Data must be 1-dimensional

But I am sure (I checked with print statements) that the data I pass in has the right shape.

print(temp1.shape)       # (781, 21)
print(temp2.shape)       # (781,)
print(weights.shape)     # (781,)

result=RidgeCV(normalize=True).fit(temp1,temp2,sample_weight=weights)

What could be going wrong?

Here is the full output:

---------------------------------------------------------------------------
Exception                                 Traceback (most recent call last)
<ipython-input-65-a5b1eba5d9cf> in <module>()
     22 
     23 
---> 24     result=RidgeCV(normalize=True).fit(temp2,temp1, sample_weight=weights)
     25 
     26 

/usr/local/lib/python2.7/dist-packages/sklearn/linear_model/ridge.pyc in fit(self, X, y, sample_weight)
    868                                   gcv_mode=self.gcv_mode,
    869                                   store_cv_values=self.store_cv_values)
--> 870             estimator.fit(X, y, sample_weight=sample_weight)
    871             self.alpha_ = estimator.alpha_
    872             if self.store_cv_values:

/usr/local/lib/python2.7/dist-packages/sklearn/linear_model/ridge.pyc in fit(self, X, y, sample_weight)
    793                               else alpha)
    794             if error:
--> 795                 out, c = _errors(weighted_alpha, y, v, Q, QT_y)
    796             else:
    797                 out, c = _values(weighted_alpha, y, v, Q, QT_y)

/usr/local/lib/python2.7/dist-packages/sklearn/linear_model/ridge.pyc in _errors(self, alpha, y, v, Q, QT_y)
    685         w = 1.0 / (v + alpha)
    686         c = np.dot(Q, self._diag_dot(w, QT_y))
--> 687         G_diag = self._decomp_diag(w, Q)
    688         # handle case where y is 2-d
    689         if len(y.shape) != 1:

/usr/local/lib/python2.7/dist-packages/sklearn/linear_model/ridge.pyc in _decomp_diag(self, v_prime, Q)
    672     def _decomp_diag(self, v_prime, Q):
    673         # compute diagonal of the matrix: dot(Q, dot(diag(v_prime), Q^T))
--> 674         return (v_prime * Q ** 2).sum(axis=-1)
    675 
    676     def _diag_dot(self, D, B):

/usr/local/lib/python2.7/dist-packages/pandas/core/ops.pyc in wrapper(left, right, name)
    531             return left._constructor(wrap_results(na_op(lvalues, rvalues)),
    532                                      index=left.index, name=left.name,
--> 533                                      dtype=dtype)
    534     return wrapper
    535 

/usr/local/lib/python2.7/dist-packages/pandas/core/series.pyc in __init__(self, data, index, dtype, name, copy, fastpath)
    209             else:
    210                 data = _sanitize_array(data, index, dtype, copy,
--> 211                                        raise_cast_failure=True)
    212 
    213                 data = SingleBlockManager(data, index, fastpath=True)

/usr/local/lib/python2.7/dist-packages/pandas/core/series.pyc in _sanitize_array(data, index, dtype, copy, raise_cast_failure)
   2683     elif subarr.ndim > 1:
   2684         if isinstance(data, np.ndarray):
-> 2685             raise Exception('Data must be 1-dimensional')
   2686         else:
   2687             subarr = _asarray_tuplesafe(data, dtype=dtype)

Exception: Data must be 1-dimensional

【Comments】:

    Tags: scikit-learn


    【Solution 1】:

    The error seems to occur because sample_weight is a pandas Series rather than a numpy array:

    import numpy as np
    import pandas as pd
    from sklearn.linear_model import RidgeCV
    
    temp1 = pd.DataFrame(np.random.rand(781, 21))
    temp2 = pd.Series(temp1.sum(1))
    weights = pd.Series(1 + 0.1 * np.random.rand(781))
    
    result = RidgeCV(normalize=True).fit(temp1, temp2, 
                                         sample_weight=weights)
    # Exception: Data must be 1-dimensional
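
The shape printout in the question cannot reveal this, because a Series and its underlying array report the same shape. A minimal sketch of a type check, assuming only numpy and pandas (the variable name `weights` mirrors the question; the data is random and purely illustrative):

```python
import numpy as np
import pandas as pd

# Illustrative weights, built the same way as in the reproduction above.
weights = pd.Series(1 + 0.1 * np.random.rand(781))

# Both objects report the same shape, so printing .shape cannot tell them apart.
print(weights.shape)         # (781,)
print(weights.values.shape)  # (781,)

# The type is what differs.
print(type(weights).__name__)         # Series
print(type(weights.values).__name__)  # ndarray
```

Printing `type(...)` next to `.shape` when debugging makes this kind of mismatch visible right away.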
    

    If you pass a numpy array instead, the error goes away:

    result = RidgeCV(normalize=True).fit(temp1, temp2,
                                         sample_weight=weights.values)
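
To apply the same workaround in one place, the inputs can be converted before they reach `fit`. The following is a minimal sketch, not part of scikit-learn: the helper name `as_numpy_fit_args` is hypothetical, and it assumes all inputs are numeric. It uses `np.asarray`, which accepts pandas objects and plain arrays alike.

```python
import numpy as np
import pandas as pd


def as_numpy_fit_args(X, y, sample_weight=None):
    """Hypothetical helper: coerce pandas or array-like inputs to NumPy arrays."""
    X_arr = np.asarray(X, dtype=float)
    y_arr = np.asarray(y, dtype=float)
    if sample_weight is None:
        return X_arr, y_arr, None

    w_arr = np.asarray(sample_weight, dtype=float)
    # Basic sanity checks: weights must be 1-D with one entry per sample.
    if w_arr.ndim != 1:
        raise ValueError("sample_weight must be 1-dimensional")
    if w_arr.shape[0] != X_arr.shape[0]:
        raise ValueError("sample_weight must have one entry per row of X")
    return X_arr, y_arr, w_arr


# Same kind of data as in the reproduction above (random, illustrative only).
rng = np.random.RandomState(0)
temp1 = pd.DataFrame(rng.rand(781, 21))
temp2 = temp1.sum(axis=1)
weights = pd.Series(1 + 0.1 * rng.rand(781))

X_arr, y_arr, w_arr = as_numpy_fit_args(temp1, temp2, weights)
print(type(w_arr).__name__, X_arr.shape, y_arr.shape, w_arr.shape)
# ndarray (781, 21) (781,) (781,)
```

The returned arrays can then be passed on as `fit(X_arr, y_arr, sample_weight=w_arr)`, which is equivalent to the `weights.values` call above.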
    

    This looks like a bug; I have opened a scikit-learn issue to report it.
