【发布时间】:2017-08-14 19:22:41
【问题描述】:
我制作了一个名为 model 的树分类器,并尝试像这样使用 export graphviz 函数:
export_graphviz(decision_tree=model,
out_file='NT_model.dot',
feature_names=X_train.columns,
class_names=model.classes_,
leaves_parallel=True,
filled=True,
rotate=False,
rounded=True)
由于某种原因,我的运行引发了这个异常:
TypeError Traceback (most recent call last) <ipython-input-298-40fe56bb0c85> in <module>() 6 filled=True, 7 rotate=False, ----> 8 rounded=True) C:\Users\yonatanv\AppData\Local\Continuum\Anaconda3\lib\site- packages\sklearn\tree\export.py in export_graphviz(decision_tree, out_file, max_depth, feature_names, class_names, label, filled, leaves_parallel, impurity, node_ids, proportion, rotate, rounded, special_characters) 431 recurse(decision_tree, 0, criterion="impurity") 432 else: --> 433 recurse(decision_tree.tree_, 0, criterion=decision_tree.criterion) 434 435 # If required, draw leaf nodes at same depth as each other C:\Users\yonatanv\AppData\Local\Continuum\Anaconda3\lib\site- packages\sklearn\tree\export.py in recurse(tree, node_id, criterion, parent, depth) 319 out_file.write('%d [label=%s' 320 % (node_id, --> 321 node_to_str(tree, node_id, criterion))) 322 323 if filled: C:\Users\yonatanv\AppData\Local\Continuum\Anaconda3\lib\site- packages\sklearn\tree\export.py in node_to_str(tree, node_id, criterion) 289 np.argmax(value), 290 characters[2]) --> 291 node_string += class_name 292 293 # Clean up any trailing newlines TypeError: ufunc 'add' did not contain a loop with signature matching types dtype('<U90') dtype('<U90') dtype('<U90')
我的可视化超参数是:
print(model)
DecisionTreeClassifier(class_weight={1.0: 10, 0.0: 1}, criterion='gini',
max_depth=7, max_features=None, max_leaf_nodes=None,
min_impurity_split=1e-07, min_samples_leaf=50,
min_samples_split=2, min_weight_fraction_leaf=0.0,
presort=False, random_state=0, splitter='best')
print(model.classes_)
[ 0. , 1. ]
我们将不胜感激!
【问题讨论】:
-
确保您使用的是 scikit-learn 的更新版本。如果仍然遇到问题,那么您需要提供更多详细信息以便我们提供帮助。从错误的完整堆栈跟踪开始。然后提供用于训练
model的代码以及一些示例数据。 -
我使用的是安装在anaconda3上的版本
-
为我的问题添加了更多描述,谢谢通知我!
-
看起来与
model.classes_中存在的类名有关的错误。你能打印出来吗?