【问题标题】:How to detect which values are allowable in a parameter grid?如何检测参数网格中允许哪些值?
【发布时间】:2017-12-04 03:39:21
【问题描述】:

我已经开始从事一个项目,其中我需要检测给定 scikit-learn 估计器的 可训练 参数,并且如果可能的话,找到分类变量的允许值(以及连续变量的合理间隔)。

我可以使用estimator.get_params()获取带有参数的字典,然后使用estimator.set_params(**{'var1':val1, 'var2':val2})设置一个值,依此类推。

例如,对于 KNN 分类器,我们有以下参数字典: {'metric': 'minkowski', 'algorithm': 'auto', 'n_neighbors': 10, 'n_jobs': 1, 'p': 2, 'metric_params': None, 'weights': 'uniform', 'leaf_size': 30}.

现在,我可以使用值的类型来推断是分类的(str 类型)、连续的(float)、离散的(int)等等。一个可能相关的问题是默认设置为NoneType 的参数,但出于充分的理由,我可能不会碰这些。

现在的挑战变成了推断和定义参数网格以用于例如RandomizedSearchCV。对于离散和连续变量,问题可以使用例如try-except 块与 scipy.stats 模块的组合,可能将间隔限制在默认值附近的“附近”(但同时注意不要将例如 n_jobs 设置为一些疯狂的价值——可能需要硬编码,或者稍后明确设置)。如果您有类似的经验,并且有一些技巧/窍门,我很想听听。

但现在真正的问题是:如何推断例如algorithm 允许的值实际上是{‘auto’, ‘ball_tree’, ‘kd_tree’, ‘brute’}??

我刚刚开始研究这个问题,如果我们尝试将它设置为一些不允许的值,也许我们可以解析得到的错误消息?我在这里寻找好主意,因为我想避免手动执行此操作(如果必须,我会这样做,但这似乎相当不雅......)

【问题讨论】:

  • 自我注意:这可能是一个非常困难/无法解决的问题。我已经浏览了 api 和源代码,并查看了例如auto-sklearn 解决了这个问题。目前看来,手动(硬编码)解决方案是可行的方法。
  • 您遇到了有趣的问题。除了parsing the signature and default parameters,我想我会尝试解析 scikit-learn 的文档字符串,如this。要尝试的另一件事是解析字符串化函数,例如__init__ 估计器,但这是一个 - 乱七八糟的远景,因为我没有看到在那里进行任何检查&你可能需要查看整个层次结构。
  • 您好!很高兴你发现这个主题很有趣。是的,那是/是我考虑/正在考虑的选项之一(解析文档)。但是让我担心的是文档字符串编写方式的一致性,并且没有可以利用的强制约定(但我可能错了)。我可能会花一点时间来实现一个解析器并在一堆文档字符串上对其进行测试......
  • 是的,在查看了一些文档字符串后,我意识到这不是一件容易的事。有一些一致性,但不足以使这变得容易。祝你好运!让我们知道这是怎么回事!
  • 谢谢,我会保持这个线程打开并报告任何进展。周末愉快!

标签: python scikit-learn parameter-passing cross-validation


【解决方案1】:

我找到了针对我正在查看的特定示例的解决方案,但是,它并不能很好地推广到其他文档字符串,因为对于 sklearn 中的每个估计器的编写方式没有设置约定。

因此,我发布我的“解决方案”,以便其他人可以接管并可能对其进行改进。看下面的sn-p:

import re
from pprint import pprint 
from sklearn.neighbors import KNeighborsClassifier
knn = KNeighborsClassifier()
doc = knn.__doc__ # Get the doc string
#from sklearn.svm import SVC
#svc = SVC()
#doc = svc.__doc__
pattern = "([a-zA-Z_]+\s:\s)|(-\s*)'([a-zA-Z_]+)'" # Define search pattern
re.compile(pattern)
matches = re.findall(pattern, doc)

clf_params = {}
previous_param = ''
for param, _, value in matches:
    if ":" in param and param[-4]!="_": # 'Hack-y'
        if param not in clf_params.keys():
            clf_params[param] = list()
            previous_param = param
        else:
            if len(value)>0:
                clf_params[previous_param].append(value)
pprint(clf_params)

这个 sn-p 打印

{'algorithm : ': ['ball_tree', 'kd_tree', 'brute', 'auto'],
 'leaf_size : ': [],
 'metric : ': [],
 'metric_params : ': [],
 'n_jobs : ': [],
 'n_neighbors : ': [],
 'p : ': [],
 'weights : ': ['uniform', 'distance']}

哪个是正确的。

但是,如果我们对SVC().__doc__ 重复相同的过程,我们会看到它失败了。

我希望有人觉得这有点用处。

【讨论】:

  • 这真的是 hack-y,很遗憾我发现这近三年就像 17 年一样贫瘠。无赖。
  • 好的,这是我从文档字符串中获取所有这些的尝试: str(Algorithm().__doc__).split('Parameters\n ----------\n') [1].split('\n\n Attributes\n')[0].replace('\n ', '\n') 这不会创建字典,但很简单,只需提取解释的“参数" 文档字符串中的部分,其中解释了所有参数并列出了所有可能/预期/接受的值输入,这些输入很好地被一个制表符缩进,现在剩下的就是从中获取缩进的行字符串,我相信我们可以管理。
【解决方案2】:

我尝试从文档字符串(LinearSVC 作为示例算法)中获取所有这些,这得到了splitlines() 的极大帮助:

liner = str(LinearSVC().__doc__).split('Parameters\n    ----------\n')[1].split('\n\n    Attributes\n')[0].replace('\n        ', '\n').splitlines()

这不会创建字典,但很简单,可以从文档字符串中仅提取解释过的“参数”部分,该部分解释了所有参数并列出了所有可能/预期/接受的值输入,这很好缩进一个,制表符,现在我们可以使用带有条件的简单循环,使用“:”作为我们的锚来识别可能/预期/接受的值输入的行:

for i in liner:
   ...:     if " : " in i: #<<< the key is to use " : " as our anchor
   ...:         print(i)

最终结果,打印到:

    penalty : str, 'l1' or 'l2' (default='l2')
    loss : str, 'hinge' or 'squared_hinge' (default='squared_hinge')
    dual : bool, (default=True)
    tol : float, optional (default=1e-4)
    C : float, optional (default=1.0)
    multi_class : str, 'ovr' or 'crammer_singer' (default='ovr')
    fit_intercept : bool, optional (default=True)
    intercept_scaling : float, optional (default=1)
    class_weight : {dict, 'balanced'}, optional
    verbose : int, (default=0)
    random_state : int, RandomState instance or None, optional (default=None)
    max_iter : int, (default=1000)

很高兴我可以分享,如果其他人需要完整的文档字符串参数打印输出,只需使用:

print(str(LinearSVC().__doc__).split('Parameters\n    ----------\n')[1].split('\n\n    Attributes\n')[0].replace('\n        ', '\n'))

编辑: 如果这不打算打印出来 - 将它作为字符串对象的最佳方法是使用列表推导,但它需要一些难看的替换,因为文档字符串中有 extensive 符号:

docstring_short = str([i for i in liner.splitlines() if " : " in i]).replace('["    ', '').replace('    ', ',\n').replace('", "', '').replace('", \'', '').replace("', '", '').replace("', \"", '').replace(']', '')

【讨论】:

    猜你喜欢
    • 2019-11-03
    • 2020-10-13
    • 2021-03-27
    • 1970-01-01
    • 2019-08-26
    • 1970-01-01
    • 2014-09-03
    • 2021-05-02
    • 2015-04-28
    相关资源
    最近更新 更多