【发布时间】:2017-12-04 03:39:21
【问题描述】:
我已经开始从事一个项目,其中我需要检测给定 scikit-learn 估计器的 可训练 参数,并且如果可能的话,找到分类变量的允许值(以及连续变量的合理间隔)。
我可以使用estimator.get_params()获取带有参数的字典,然后使用estimator.set_params(**{'var1':val1, 'var2':val2})设置一个值,依此类推。
例如,对于 KNN 分类器,我们有以下参数字典:
{'metric': 'minkowski', 'algorithm': 'auto', 'n_neighbors': 10, 'n_jobs': 1, 'p': 2, 'metric_params': None, 'weights': 'uniform', 'leaf_size': 30}.
现在,我可以使用值的类型来推断是分类的(str 类型)、连续的(float)、离散的(int)等等。一个可能相关的问题是默认设置为NoneType 的参数,但出于充分的理由,我可能不会碰这些。
现在的挑战变成了推断和定义参数网格以用于例如RandomizedSearchCV。对于离散和连续变量,问题可以使用例如try-except 块与 scipy.stats 模块的组合,可能将间隔限制在默认值附近的“附近”(但同时注意不要将例如 n_jobs 设置为一些疯狂的价值——可能需要硬编码,或者稍后明确设置)。如果您有类似的经验,并且有一些技巧/窍门,我很想听听。
但现在真正的问题是:如何推断例如algorithm 允许的值实际上是{‘auto’, ‘ball_tree’, ‘kd_tree’, ‘brute’}??
我刚刚开始研究这个问题,如果我们尝试将它设置为一些不允许的值,也许我们可以解析得到的错误消息?我在这里寻找好主意,因为我想避免手动执行此操作(如果必须,我会这样做,但这似乎相当不雅......)
【问题讨论】:
-
自我注意:这可能是一个非常困难/无法解决的问题。我已经浏览了 api 和源代码,并查看了例如auto-sklearn 解决了这个问题。目前看来,手动(硬编码)解决方案是可行的方法。
-
您遇到了有趣的问题。除了parsing the signature and default parameters,我想我会尝试解析 scikit-learn 的文档字符串,如this。要尝试的另一件事是解析字符串化函数,例如
__init__估计器,但这是一个 - 乱七八糟的远景,因为我没有看到在那里进行任何检查&你可能需要查看整个层次结构。 -
您好!很高兴你发现这个主题很有趣。是的,那是/是我考虑/正在考虑的选项之一(解析文档)。但是让我担心的是文档字符串编写方式的一致性,并且没有可以利用的强制约定(但我可能错了)。我可能会花一点时间来实现一个解析器并在一堆文档字符串上对其进行测试......
-
是的,在查看了一些文档字符串后,我意识到这不是一件容易的事。有一些一致性,但不足以使这变得容易。祝你好运!让我们知道这是怎么回事!
-
谢谢,我会保持这个线程打开并报告任何进展。周末愉快!
标签: python scikit-learn parameter-passing cross-validation