【发布时间】:2018-04-05 13:34:28
【问题描述】:
我在developer guide 之后实现了几个自定义估算器,因此它们都继承自BaseEstimator。其中一些使用其他 scikit-learn 估计器或转换器作为属性(例如,构建一个集成)。从 BaseEstimator 继承应该让我可以方便地通过 get_params() 访问参数并通过 set_params() 设置它们,如here 所述,以 component__parameter 形式,例如用于网格搜索。在下面找到一个最小的示例。
from sklearn.base import BaseEstimator
from sklearn.linear_model import LinearRegression
class MyForecaster(BaseEstimator):
def __init__(self, base_estimator=LinearRegression()):
self.base_estimator = base_estimator
def fit(self, X, y):
pass
def predict(self, X, y):
pass
# instantiate forecaster and set parameters
mf = MyForecaster()
mf.set_params(**{"base_estimator" : "ElasticNet", "base_estimator__alpha": 0.05})
这失败了:
ValueError: Invalid parameter alpha for estimator LinearRegression. Check the list of available parameters with `estimator.get_params().keys()`.
这表明它尝试为嵌套属性设置参数 first,而不是先检查是否要覆盖“更高级别”属性(ElasticNet 具有属性 alpha,而不是 LinearRegression)。
处理此问题的一种方法是为每个估算器覆盖 set_params(),以确保正确处理它。
是否有任何“内置”方法来实现这一点,而我只是忽略了另一种解决方案?这真的是 scikit-learn 的预期行为吗?
编辑:
确实由于一些非常大的巧合,一个非常相似的问题似乎已在 0.19.1 版本中得到修复。但是,我的特殊情况仍然失败,只有 Pipelines 的情况是固定的!
为了使其可重现,我将 set_params() 的当前代码复制到我的最小示例中(仅在第 20 行添加了注释)
1 def set_params(self, **params):
2 if not params:
3 # Simple optimization to gain speed (inspect is slow)
4 return self
5 valid_params = self.get_params(deep=True)
6
7 nested_params = defaultdict(dict) # grouped by prefix
8 for key, value in params.items():
9 key, delim, sub_key = key.partition('__')
10 if key not in valid_params:
11 raise ValueError('Invalid parameter %s for estimator %s. '
12 'Check the list of available parameters '
13 'with `estimator.get_params().keys()`.' %
14 (key, self))
15
16 if delim:
17 nested_params[key][sub_key] = value
18 else:
19 setattr(self, key, value)
20 #valid_params[key] = value
21
22 for key, sub_params in nested_params.items():
23 valid_params[key].set_params(**sub_params)
24
25 return self
它失败了,因为它会在第 19 行设置属性,但由于它没有更新 valid_params,它仍然会在下一次迭代中尝试设置属性时失败。所以我添加了第 20 行来解决这个问题。 它确实可以在当前的 0.19.1 修复中进行测试,因为它仅针对 Pipelines 进行了测试。 Here,set_param() 被覆盖为第一次调用 _BaseComposition 的 _set_param(),显然这是处理的。
我应该在 scikit-learn github 中提出这个问题还是重新打开其他问题?
【问题讨论】:
标签: scikit-learn