【问题标题】:Tree classifier to graphviz ERROR树分类器到 graphviz 错误
【发布时间】:2017-08-14 19:22:41
【问题描述】:

我制作了一个名为 model 的树分类器,并尝试像这样使用 export graphviz 函数:

export_graphviz(decision_tree=model,
                    out_file='NT_model.dot',
                    feature_names=X_train.columns,
                    class_names=model.classes_,
                    leaves_parallel=True,
                    filled=True,
                    rotate=False,
                    rounded=True)

由于某种原因,我的运行引发了这个异常:


TypeError                                 Traceback (most recent call last)
<ipython-input-298-40fe56bb0c85> in <module>()
      6                     filled=True,
      7                     rotate=False,
----> 8                     rounded=True)

C:\Users\yonatanv\AppData\Local\Continuum\Anaconda3\lib\site-
packages\sklearn\tree\export.py in export_graphviz(decision_tree, out_file, 
max_depth, feature_names, class_names, label, filled, leaves_parallel, 
impurity, node_ids, proportion, rotate, rounded, special_characters)
    431             recurse(decision_tree, 0, criterion="impurity")
    432         else:
--> 433             recurse(decision_tree.tree_, 0, 
criterion=decision_tree.criterion)
    434
    435         # If required, draw leaf nodes at same depth as each other

C:\Users\yonatanv\AppData\Local\Continuum\Anaconda3\lib\site-
packages\sklearn\tree\export.py in recurse(tree, node_id, criterion, parent, 
depth)
    319             out_file.write('%d [label=%s'
    320                            % (node_id,
--> 321                               node_to_str(tree, node_id, 
criterion)))
    322
    323             if filled:

C:\Users\yonatanv\AppData\Local\Continuum\Anaconda3\lib\site-
packages\sklearn\tree\export.py in node_to_str(tree, node_id, criterion)
    289                                           np.argmax(value),
    290                                           characters[2])
--> 291             node_string += class_name
    292
    293         # Clean up any trailing newlines

TypeError: ufunc 'add' did not contain a loop with signature matching types 
dtype('<U90') dtype('<U90') dtype('<U90')

我的可视化超参数是:

print(model)
DecisionTreeClassifier(class_weight={1.0: 10, 0.0: 1}, criterion='gini',
        max_depth=7, max_features=None, max_leaf_nodes=None,
        min_impurity_split=1e-07, min_samples_leaf=50,
        min_samples_split=2, min_weight_fraction_leaf=0.0,
        presort=False, random_state=0, splitter='best')

print(model.classes_)
[ 0. , 1. ]

我们将不胜感激!

【问题讨论】:

  • 确保您使用的是 scikit-learn 的更新版本。如果仍然遇到问题,那么您需要提供更多详细信息以便我们提供帮助。从错误的完整堆栈跟踪开始。然后提供用于训练model 的代码以及一些示例数据。
  • 我使用的是安装在anaconda3上的版本
  • 为我的问题添加了更多描述,谢谢通知我!
  • 看起来与model.classes_ 中存在的类名有关的错误。你能打印出来吗?

标签: python-3.x scikit-learn


【解决方案1】:

正如您在documentation of export_graphviz 中指定的那样,参数class_names 适用于字符串,而不是浮点数或整数。

class_names : 字符串列表,bool 或 None,可选(默认=None)

尝试将model.classes_ 转换为字符串列表,然后再将它们传递给export_graphviz。

在对export_graphviz() 的调用中尝试class_names=['0', '1']class_names=['0.0', '1.0']

对于更通用的解决方案,请使用:

class_names=[str(x) for x in model.classes_]

但是,您在model.fit() 中将浮点值作为y 传递是否有特定原因?因为这在分类任务中大部分是不需要的。您是否有实际的 y 标签,或者您是否在拟合模型之前将字符串标签转换为数字?

【讨论】:

  • 这里的y标签原本是数字,作为二进制变量
猜你喜欢
  • 2011-02-07
  • 2017-05-03
  • 2019-10-12
  • 2020-09-22
  • 2021-10-30
  • 1970-01-01
  • 1970-01-01
  • 2019-10-08
  • 1970-01-01
相关资源
最近更新 更多