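【Posted】: 2021-03-25 12:22:21
【Question】:
I'm getting this error and don't know how to fix it. I want to use two x variables in my regression, so I put both of them in the code. That raises the error below, and I don't know how to reshape my array to resolve it.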
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeRegressor
from sklearn.metrics import r2_score,mean_squared_error
X = maindf[['Graduate Degree','Asian American Population']].values.reshape(-1,1)
Y = maindf["Democrats 2016"].values.reshape(-1,1)
x_train, x_test, y_train, y_test, = train_test_split(X, Y,train_size=49, random_state=np.random)
DecisionTreeRegModel = DecisionTreeRegressor(max_depth=3).fit(x_train, y_train)
y_pred = DecisionTreeRegModel.predict(x_test)
from sklearn import tree
Here is the error.
ValueError Traceback (most recent call last)
<ipython-input-85-9aaccff5b23d> in <module>
5 X = maindf[['Graduate Degree','Asian American Population']].values.reshape(-1,1)
6 Y = maindf["Democrats 2016"].values.reshape(-1,1)
----> 7 x_train, x_test, y_train, y_test, = train_test_split(X, Y,train_size=49, random_state=np.random)
8 DecisionTreeRegModel = DecisionTreeRegressor(max_depth=3).fit(x_train, y_train)
9 y_pred = DecisionTreeRegModel.predict(x_test)
~\anaconda3\lib\site-packages\sklearn\model_selection\_split.py in train_test_split(*arrays, **options)
2125 raise TypeError("Invalid parameters passed: %s" % str(options))
2126
-> 2127 arrays = indexable(*arrays)
2128
2129 n_samples = _num_samples(arrays[0])
~\anaconda3\lib\site-packages\sklearn\utils\validation.py in indexable(*iterables)
291 """
292 result = [_make_indexable(X) for X in iterables]
--> 293 check_consistent_length(*result)
294 return result
295
~\anaconda3\lib\site-packages\sklearn\utils\validation.py in check_consistent_length(*arrays)
254 uniques = np.unique(lengths)
255 if len(uniques) > 1:
--> 256 raise ValueError("Found input variables with inconsistent numbers of"
257 " samples: %r" % [int(l) for l in lengths])
258
ValueError: Found input variables with inconsistent numbers of samples: [100, 50]
【Comments】:
- Have you checked X.shape and Y.shape to see why train_test_split sees one input with 100 rows and the other with only 50?
- Aside: .to_numpy() is recommended over .values. See the docs.
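Following up on the shape check suggested in the comments, here is a minimal sketch of what is probably going on. It assumes maindf has 50 rows, which matches the [100, 50] in the error message; the DataFrame below is a hypothetical stand-in filled with random numbers, not the asker's data, and random_state=0 is used only to make the split repeatable. Calling reshape(-1, 1) on a two-column selection stacks all 50 x 2 values into one column of 100 rows, so X no longer lines up with Y. Dropping the reshape keeps X as (n_samples, n_features).

```python
import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeRegressor

# Hypothetical stand-in for maindf: 50 rows of random values (assumption).
rng = np.random.default_rng(0)
maindf = pd.DataFrame({
    "Graduate Degree": rng.random(50),
    "Asian American Population": rng.random(50),
    "Democrats 2016": rng.random(50),
})

features = ["Graduate Degree", "Asian American Population"]

# reshape(-1, 1) turns the 50x2 feature block into a single column of 100 rows.
X_bad = maindf[features].to_numpy().reshape(-1, 1)
print(X_bad.shape)  # (100, 1)

# A two-column selection is already 2-D, so no reshape is needed.
X = maindf[features].to_numpy()
y = maindf["Democrats 2016"].to_numpy()
print(X.shape, y.shape)  # (50, 2) (50,)

# Same split and model settings as in the question, apart from the fixed seed.
x_train, x_test, y_train, y_test = train_test_split(
    X, y, train_size=49, random_state=0
)
model = DecisionTreeRegressor(max_depth=3).fit(x_train, y_train)
y_pred = model.predict(x_test)
```

With train_size=49 out of 50 rows, only one row is left for testing, so any score computed on x_test will be based on a single prediction.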
Tags: python arrays numpy scikit-learn jupyter-notebook