【Question title】: Why does my GaussianProcessRegressor model return constant predictions?
【Posted】: 2020-09-12 23:41:42
【Question description】:

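I am using GaussianProcessRegressor from the scikit-learn library to make predictions. My X_train is a 2D array of x and y coordinates, and y_train is a vector of temperatures in degrees Fahrenheit (values between 30 and 60 F, with a mean of 42 F). Here is the model: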

from sklearn.gaussian_process import GaussianProcessRegressor
from sklearn.gaussian_process.kernels import Matern  # this import was missing

length_scale_param = 1.9
length_scale_bounds_param = (1e-05, 100000.0)
nu_param = 2.5
matern = Matern(length_scale=length_scale_param,
                length_scale_bounds=length_scale_bounds_param, nu=nu_param)
gpr = GaussianProcessRegressor(kernel=matern, normalize_y=True)

I set normalize_y to True so that the prior mean equals the actual mean of my data (42) rather than the default of 0.
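
To make the effect of normalize_y concrete, here is a minimal sketch on synthetic data. The arrays X_demo and y_demo, the fixed kernel settings, and the alpha value are assumptions for illustration, not the data or settings from this post. Far from every training point the kernel correlation vanishes, so the prediction falls back to the prior mean: 0 without normalization, the training mean with it.

```python
import numpy as np
from sklearn.gaussian_process import GaussianProcessRegressor
from sklearn.gaussian_process.kernels import Matern

# Synthetic stand-in data (assumption): 30 points with a mean near 42
rng = np.random.default_rng(0)
X_demo = rng.uniform(-0.5, 0.5, size=(30, 2))
y_demo = 42.0 + 5.0 * X_demo[:, 0] + rng.normal(0.0, 0.5, size=30)

# optimizer=None keeps the kernel fixed, so only normalize_y differs
common = dict(kernel=Matern(length_scale=0.2, nu=2.5), alpha=1e-2, optimizer=None)
gpr_raw = GaussianProcessRegressor(normalize_y=False, **common).fit(X_demo, y_demo)
gpr_norm = GaussianProcessRegressor(normalize_y=True, **common).fit(X_demo, y_demo)

# A query point far outside the training region
far_away = np.array([[100.0, 100.0]])
pred_raw = gpr_raw.predict(far_away)[0]
pred_norm = gpr_norm.predict(far_away)[0]

print("normalize_y=False:", pred_raw)
print("normalize_y=True: ", pred_norm)
print("training mean:    ", y_demo.mean())
```

The same fallback applies anywhere the query points are effectively uncorrelated with the training points, not only at far-away locations.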

I am predicting on a 2D grid:

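import numpy as np
# Grid over the bounding box of the coordinates X, with a step of 0.01
rx = np.arange(X[:, 0].min(), X[:, 0].max(), 0.01)
ry = np.arange(X[:, 1].min(), X[:, 1].max(), 0.01)
gx, gy = np.meshgrid(rx, ry)
X_2D = np.c_[gx.ravel(), gy.ravel()]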

I get the following surface plot:

As you can see in that plot, the predictions are constant and always equal to the mean.

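I tried changing the kernel and the kernel parameters, but I still have the same problem.

I also tried setting optimizer to None (instead of the default optimizer, which tunes the kernel parameters by maximizing the log marginal likelihood; with optimizer=None the kernel's initial parameters stay fixed), and I get the following result:

But in that case I have to implement a grid search to choose good initial kernel parameters, which is time-consuming given the size of my dataset.

My guess is that in the first case the optimizer is not working properly for some reason.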
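
One way to probe that guess is to look at what a very small length scale does. The sketch below uses synthetic data and a deliberately tiny, fixed length scale; both are assumptions for illustration, since the thread does not report the fitted value. With such a kernel, grid points are essentially uncorrelated with every training point, so the surface is flat at the training mean. After a normal fit, the optimized kernel is available as gpr.kernel_, which is the place to check whether the length scale ended up near one of its bounds.

```python
import numpy as np
from sklearn.gaussian_process import GaussianProcessRegressor
from sklearn.gaussian_process.kernels import Matern

# Synthetic stand-in data (assumption), not the data from the post
rng = np.random.default_rng(0)
X_demo = rng.uniform(-0.5, 0.5, size=(60, 2))
y_demo = 42.0 + 8.0 * X_demo[:, 0] + rng.normal(0.0, 2.0, size=60)

# Fixed, deliberately tiny length scale; optimizer=None prevents any tuning
tiny_kernel = Matern(length_scale=1e-5, nu=2.5)
gpr_tiny = GaussianProcessRegressor(kernel=tiny_kernel, optimizer=None,
                                    normalize_y=True)
gpr_tiny.fit(X_demo, y_demo)

# Coarse prediction grid over the same square as the training inputs
rx = np.arange(-0.5, 0.5, 0.05)
gx, gy = np.meshgrid(rx, rx)
X_grid = np.c_[gx.ravel(), gy.ravel()]
pred_grid = gpr_tiny.predict(X_grid)

print("kernel used for prediction:", gpr_tiny.kernel_)
print("spread of grid predictions:", pred_grid.max() - pred_grid.min())
print("training mean:", y_demo.mean())
```

If the kernel printed after an optimized fit shows a length scale far smaller than the spacing of the grid, a flat surface like the one described is the expected outcome.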

Any suggestions?

Here is my X_train:

array([[-0.07175708, -0.04827261],
       [ 0.20393194,  0.20058493],
       [ 0.3603364 ,  0.07715549],
       [ 0.17013275,  0.06315295],
       [ 0.09156826, -0.02107808],
       [-0.14215737,  0.01280404],
       [ 0.06130448, -0.13786868],
       [ 0.2392198 ,  0.1786702 ],
       [ 0.06257225, -0.00621065],
       [ 0.32712505,  0.25779511],
       [ 0.29779007, -0.08769269],
       [-0.14826638, -0.0370103 ],
       [ 0.41075394, -0.1100057 ],
       [ 0.34963454,  0.20687578],
       [ 0.4809849 , -0.20138262],
       [-0.19123097, -0.06000154],
       [-0.0335645 , -0.02598649],
       [ 0.47650189, -0.11234306],
       [ 0.35300743, -0.12135059],
       [ 0.15285929,  0.26463927],
       [ 0.25162424,  0.26882754],
       [-0.12485825, -0.02486853],
       [ 0.46869993,  0.01067606],
       [ 0.46410817, -0.17518689],
       [ 0.36756061,  0.1329964 ],
       [ 0.41387258,  0.06388724],
       [ 0.24489864,  0.1566825 ],
       [ 0.34972446,  0.22217119],
       [-0.10762011, -0.24574283],
       [ 0.43273621,  0.0916413 ],
       [ 0.39971044,  0.19253515],
       [ 0.35053608, -0.17008844],
       [ 0.02222162, -0.21485839],
       [ 0.30105785,  0.23001327],
       [ 0.05772036,  0.06681724],
       [-0.43849245,  0.1222685 ],
       [ 0.09869866,  0.02871409],
       [ 0.2033424 ,  0.1212952 ],
       [ 0.27993967,  0.22868547],
       [ 0.15177833,  0.23868958],
       [-0.21212757, -0.11004732],
       [ 0.44694002,  0.05587976],
       [ 0.21171764, -0.11056078],
       [ 0.02776326, -0.28147262],
       [ 0.44578859, -0.0587219 ],
       [ 0.29600242,  0.06741206],
       [ 0.27655553,  0.27980429],
       [ 0.20468395,  0.19475542],
       [ 0.38154889,  0.04721793],
       [ 0.01957093, -0.26531009],
       [ 0.05286766,  0.02185995],
       [ 0.3056768 ,  0.22414755],
       [ 0.16743847,  0.16073349],
       [ 0.05609711,  0.07843347],
       [ 0.41648273,  0.17360153],
       [ 0.18231324,  0.26745677],
       [ 0.14966242,  0.10538568],
       [ 0.02549186, -0.01958948],
       [-0.0352719 , -0.02737327],
       [ 0.16600666,  0.07729444],
       [-0.12564782, -0.12275318],
       [ 0.37777642,  0.24001348],
       [-0.27694849,  0.00378039],
       [ 0.44526109,  0.12339138],
       [ 0.3685266 , -0.09494673],
       [-0.1995266 , -0.02930646],
       [-0.12903661, -0.10557621],
       [ 0.1709348 , -0.01605571],
       [ 0.26204141,  0.00431368],
       [-0.07393948,  0.00719171],
       [ 0.25412697, -0.13938606],
       [ 0.21738421, -0.05103692],
       [-0.46865246,  0.11646383],
       [ 0.10859337, -0.24675289],
       [ 0.31137355, -0.01317134],
       [-0.32543566,  0.01758948],
       [ 0.1353631 ,  0.09693234],
       [ 0.22925417, -0.08178113],
       [ 0.19070138,  0.07616783],
       [ 0.35729195,  0.16464414],
       [-0.18762354, -0.1619709 ],
       [ 0.38675886, -0.05008602],
       [ 0.40249564,  0.18417801],
       [-0.26503112, -0.07816367],
       [-0.5       ,  0.1422947 ],
       [ 0.23234044,  0.15395552],
       [ 0.41635281,  0.28778189],
       [-0.00504366, -0.05262536],
       [-0.23091464, -0.15458275],
       [ 0.31935293,  0.15605484],
       [ 0.24921385, -0.05876454],
       [-0.39930397,  0.28697901],
       [ 0.05286766,  0.02185995],
       [ 0.12650071,  0.08691902],
       [-0.41328647,  0.11521126],
       [-0.02549319, -0.21558453],
       [ 0.38447761,  0.18176482],
       [-0.49606913,  0.04726729],
       [ 0.26226766,  0.09769927],
       [ 0.37959486,  0.16020508],
       [ 0.39688515,  0.28609912],
       [-0.21750272, -0.05315777],
       [-0.16742417,  0.31337447],
       [ 0.35049142,  0.16397509],
       [ 0.09923472, -0.05051281],
       [ 0.39039074, -0.00533958],
       [ 0.34954183,  0.070406  ],
       [-0.03250529, -0.09619029],
       [-0.02553826, -0.21512205],
       [ 0.32684651, -0.00806486],
       [-0.035674  , -0.10242529],
       [ 0.3840333 ,  0.19410431],
       [ 0.34593852,  0.03607444],
       [ 0.49294163, -0.19796509],
       [ 0.00115703, -0.10888053],
       [ 0.38564422, -0.05671838],
       [ 0.38633704,  0.15706933],
       [ 0.41442829,  0.07688914],
       [ 0.00182541, -0.18194074],
       [ 0.19541211,  0.19816678],
       [ 0.21203674,  0.03370675],
       [ 0.22605457, -0.0154448 ],
       [ 0.32304629,  0.04642338],
       [ 0.40787352,  0.12211336],
       [ 0.06104107, -0.26257386],
       [ 0.14581334,  0.17887325],
       [ 0.19600414, -0.0199909 ],
       [-0.11808573,  0.04732613],
       [ 0.42421385, -0.00113821],
       [ 0.23317682,  0.05307291],
       [ 0.07724509, -0.20107056],
       [ 0.05623529, -0.31337447],
       [-0.1586227 ,  0.29292413],
       [ 0.10418996,  0.01066445],
       [ 0.41380266, -0.07030375],
       [ 0.24685584,  0.10346794],
       [ 0.10166612,  0.13223216],
       [ 0.21053369,  0.02633374],
       [-0.35277745,  0.27849323],
       [-0.20414733, -0.0153229 ],
       [-0.26929086, -0.19337318],
       [ 0.26345883, -0.05154861],
       [ 0.13480402,  0.09701327],
       [ 0.2934898 ,  0.07205294],
       [-0.00824799,  0.03543839],
       [ 0.43831267,  0.21319967]])

Here is y_train:

array([[39.9],
       [45.7],
       [46.1],
       [42.5],
       [43.5],
       [39.7],
       [42.9],
       [45.8],
       [42.6],
       [44.2],
       [45.2],
       [23.4],
       [49.3],
       [45. ],
       [48.6],
       [41.1],
       [39.9],
       [48.3],
       [48.5],
       [46.1],
       [45.5],
       [28.7],
       [49.1],
       [48.2],
       [44.2],
       [45.3],
       [44.9],
       [45.1],
       [43.3],
       [46.5],
       [45.3],
       [48.3],
       [43.4],
       [45.3],
       [41.9],
       [37.5],
       [41.9],
       [47.3],
       [45.3],
       [46.3],
       [36.7],
       [47.1],
       [46.1],
       [46.8],
       [49.3],
       [45.9],
       [46. ],
       [45.9],
       [44.4],
       [45. ],
       [37.7],
       [45.2],
       [46. ],
       [42.8],
       [45.2],
       [47.7],
       [45.3],
       [39. ],
       [39. ],
       [43.6],
       [26.3],
       [46.2],
       [40.4],
       [46.6],
       [48.4],
       [42.4],
       [36.6],
       [44.9],
       [43.5],
       [42.3],
       [46.4],
       [45.8],
       [39.4],
       [44.3],
       [45.2],
       [40.8],
       [45.7],
       [45.4],
       [42.9],
       [44.8],
       [30.4],
       [47.1],
       [44.7],
       [38.4],
       [38.2],
       [45.3],
       [45. ],
       [38.1],
       [42.5],
       [45.4],
       [44.6],
       [41.1],
       [38.2],
       [45.3],
       [40.2],
       [41.5],
       [48. ],
       [36.1],
       [44.7],
       [46.8],
       [45.6],
       [40.6],
       [43.5],
       [44.8],
       [42.6],
       [44.9],
       [43.2],
       [40.6],
       [41.5],
       [46. ],
       [41.7],
       [48.7],
       [49.6],
       [48.4],
       [41.3],
       [47.8],
       [47.3],
       [46.2],
       [43.8],
       [46.2],
       [44.9],
       [46.1],
       [44.5],
       [46.3],
       [43.2],
       [46.1],
       [44.1],
       [40. ],
       [47.3],
       [41.4],
       [46. ],
       [46. ],
       [44.4],
       [40.7],
       [44.5],
       [45.2],
       [43.9],
       [44.1],
       [42.9],
       [42.4],
       [40.6],
       [42.7],
       [45.2],
       [45. ],
       [42.4],
       [46. ]])

【Question discussion】:

  • Could you share the data so we can reproduce the predictions?
  • I have added some data points to the post.

Tags: python scikit-learn regression gaussian


【Solution 1】:

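Given the data above, the model needs a noise term to perform well. By adding a WhiteKernel to the Matern kernel, I get predictions that differ from the mean. Here is a comparison of the two: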

import numpy as np  # needed for np.arange below
import matplotlib.pyplot as plt
from sklearn.gaussian_process import GaussianProcessRegressor
from sklearn.gaussian_process.kernels import Matern, WhiteKernel
from sklearn.model_selection import train_test_split

# X and y are the coordinate and temperature arrays from the question
length_scale_param = 1.9
length_scale_bounds_param = (1e-05, 100000.0)
nu_param = 2.5
matern = Matern(length_scale=length_scale_param,
                length_scale_bounds=length_scale_bounds_param, nu=nu_param)
kernel = matern + WhiteKernel()  # Matern plus a learnable noise term

# gpr_0 uses the Matern kernel only; gpr_1 adds the noise kernel
gpr_0 = GaussianProcessRegressor(kernel=matern, normalize_y=True,
                                 n_restarts_optimizer=0)
gpr_1 = GaussianProcessRegressor(kernel=kernel, normalize_y=True,
                                 n_restarts_optimizer=0)

X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.33,
                                                    random_state=42)

gpr_0.fit(X_train, y_train)
gpr_1.fit(X_train, y_train)

y_pred_0 = gpr_0.predict(X_test)
y_pred_1 = gpr_1.predict(X_test)

# Predicted vs. observed values; the dashed grey line marks y_pred == y_test
plt.scatter(y_test, y_pred_0, label='matern only')
plt.scatter(y_test, y_pred_1, label='matern + noise kernel')
plt.plot(np.arange(y.min(), y.max(), 1), np.arange(y.min(), y.max(), 1), '--',
         color='grey')
plt.xlabel('y_test')
plt.ylabel('y_pred')  # the original called xlabel twice
plt.legend(frameon=False)

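Here is the result:

I would also suggest setting n_restarts_optimizer=9 so that the optimizer is restarted from several additional starting points. With the default n_restarts_optimizer=0, the optimizer runs only once, starting from the kernel's initial parameters.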
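
As a rough illustration of what the restarts do, the sketch below fits the same Matern + WhiteKernel model with and without restarts and compares the log marginal likelihood each fit reached. The data is synthetic (an assumption; it is not the data from the question), and the names X_demo, y_demo, gpr_single and gpr_multi are made up for this example. The first optimizer run starts from the kernel's initial parameters and each restart begins from values drawn at random within the parameter bounds, so the restarted fit should end at a likelihood at least as high.

```python
import numpy as np
from sklearn.gaussian_process import GaussianProcessRegressor
from sklearn.gaussian_process.kernels import Matern, WhiteKernel

# Synthetic noisy data (assumption), standing in for the temperatures
rng = np.random.default_rng(0)
X_demo = rng.uniform(-0.5, 0.5, size=(40, 2))
y_demo = 42.0 + 8.0 * X_demo[:, 0] + rng.normal(0.0, 1.0, size=40)

kernel = Matern(length_scale=1.9, nu=2.5) + WhiteKernel()

# Single optimizer run from the initial kernel parameters
gpr_single = GaussianProcessRegressor(kernel=kernel, normalize_y=True,
                                      n_restarts_optimizer=0, random_state=0)
gpr_single.fit(X_demo, y_demo)

# Same first run plus 9 restarts from randomly drawn starting points
gpr_multi = GaussianProcessRegressor(kernel=kernel, normalize_y=True,
                                     n_restarts_optimizer=9, random_state=0)
gpr_multi.fit(X_demo, y_demo)

lml_single = gpr_single.log_marginal_likelihood_value_
lml_multi = gpr_multi.log_marginal_likelihood_value_
print("no restarts:", gpr_single.kernel_, lml_single)
print("9 restarts: ", gpr_multi.kernel_, lml_multi)
```

When both fits report the same kernel, the restarts found nothing better than the first run, which is itself useful to know.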

【Discussion】:

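  • Thanks, adding the WhiteKernel helped. I am using MSE to evaluate the model. With nu=2.5 I get MSE=6.39, but with nu=0.1 I get MSE=3.72. Is the nu parameter also optimized during training? According to the documentation, nu controls the smoothness of the function and typically takes values in [0.5, 1.5, 2.5, inf]. How do I choose the nu value that gives the best MSE score?
  • Did you get that improvement on the test set? Perhaps a smaller nu leads to overfitting? According to the docs, nu is not optimized.
  • Choosing hyperparameter values is a model selection problem. As a first step, I would suggest trying GridSearchCV. See also this question.
  • Yes, I got the improvement on the test set, and I am using grid search to choose the best nu. According to some of the literature, the Matern kernel is the best-suited kernel for regression problems. Perhaps I should instead implement a grid search that selects the kernel type as well as its parameters. Thanks, I will check the documentation.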
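
Since nu is held fixed during fitting, one way to compare values is to treat the whole kernel as a hyperparameter and let GridSearchCV pick among candidates by cross-validated MSE. The sketch below does that on synthetic data (an assumption, not the data from the question); the candidate list, the 5-fold setting, and the names X_demo, y_demo and nu_candidates are illustrative choices as well.

```python
import numpy as np
from sklearn.gaussian_process import GaussianProcessRegressor
from sklearn.gaussian_process.kernels import Matern, WhiteKernel
from sklearn.model_selection import GridSearchCV

# Synthetic noisy data (assumption), standing in for the temperatures
rng = np.random.default_rng(0)
X_demo = rng.uniform(-0.5, 0.5, size=(60, 2))
y_demo = 42.0 + 8.0 * X_demo[:, 0] + rng.normal(0.0, 1.0, size=60)

# One candidate kernel per nu; length scale and noise level are still
# tuned by the GP's own optimizer inside every fit
nu_candidates = (0.5, 1.5, 2.5)
param_grid = {
    "kernel": [Matern(length_scale=1.9, nu=nu) + WhiteKernel()
               for nu in nu_candidates]
}

search = GridSearchCV(
    GaussianProcessRegressor(normalize_y=True),
    param_grid,
    scoring="neg_mean_squared_error",  # higher is better, so MSE is negated
    cv=5,
)
search.fit(X_demo, y_demo)

best_nu = search.best_params_["kernel"].k1.nu  # k1 is the Matern part of the sum
best_mse = -search.best_score_
print("best nu:", best_nu, "cross-validated MSE:", best_mse)
```

The same grid can hold kernels of different types, which covers the idea of selecting the kernel type along with its parameters. Values of nu outside 0.5, 1.5, 2.5 and inf (such as 0.1) are accepted, but the scikit-learn documentation notes that they are considerably more expensive to evaluate.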