【问题标题】:printing a gridsearch in 3D for hyperparameter visualization打印 3D 网格搜索以实现超参数可视化
【发布时间】:2020-03-08 07:13:30
【问题描述】:

试图可视化超参数及其结果,我无法将它们绘制在 3d 图中

我尝试构建一个函数为:

PlotGridSearch(grid,xparam,yparam,zlabels):

被称为

gs= GridSearchCV(DecisionTreeClassifier()
                 ,HyperParams
                 , scoring='accuracy'
                 , cv=50).fit(train_data,train_labels)

 PlotGridSearch(gs
                ,'param_max_depth'
                ,'param_max_leaf_nodes'
                ,'mean_test_score')

但我可以使用从 param_max_depth y param_max_leaf_nodes 提取的正确标签将 mean_test_score 列转换为必要的矩阵(二维数组)

有什么建议吗?

【问题讨论】:

    标签: python numpy matplotlib scikit-learn mplot3d


    【解决方案1】:

    有一个关于3D surface matplotlib plot 的不错的官方文档。

    # This import registers the 3D projection, but is otherwise unused.
    from mpl_toolkits.mplot3d import Axes3D  # noqa: F401 unused import
    
    import matplotlib.pyplot as plt
    from matplotlib import cm
    from matplotlib.ticker import LinearLocator, FormatStrFormatter
    import numpy as np
    
    
    fig = plt.figure()
    ax = fig.gca(projection='3d')
    
    # Make data.
    X = np.arange(-5, 5, 0.25)
    Y = np.arange(-5, 5, 0.25)
    X, Y = np.meshgrid(X, Y)
    R = np.sqrt(X**2 + Y**2)
    Z = np.sin(R)
    
    # Plot the surface.
    surf = ax.plot_surface(X, Y, Z, cmap=cm.coolwarm,
                           linewidth=0, antialiased=False)
    
    # Customize the z axis.
    ax.set_zlim(-1.01, 1.01)
    ax.zaxis.set_major_locator(LinearLocator(10))
    ax.zaxis.set_major_formatter(FormatStrFormatter('%.02f'))
    
    # Add a color bar which maps values to colors.
    fig.colorbar(surf, shrink=0.5, aspect=5)
    
    plt.show()
    

    如果你在 Jupyter Notebook 中运行 python,你可以使用%matplotlib notebook 使其交互。

    在您的情况下,如果您想获得正确的 X、Y 和 Z,您可以从 gs.cv_results_ 获取值,即:

    m = len(param_grid['max_depth'])
    n = len(param_grid['max_leaf_nodes'])
    X = np.reshape(gs.cv_results_['param_max_depth'].data,[n,m]) # do mind the order of reshape, it might diff
    Y = np.reshape(gs.cv_results_['param_max_leaf_nodes'].data,[n,m])
    Z = np.reshape(gs.cv_results_['mean_test_score'],[n,m])
    

    【讨论】:

    • 是的,我知道,我遇到的问题是将数据设为您在此处的 X、Y、Z、R、Z 示例
    • gs= GridSearchCV 只返回一个包含所有列的表,不按轴分组
    • @figarom。您可以使用numpy.reshpe 来获取正确的值,但请记住检查参数的顺序。我添加了另一个示例来展示如何做到这一点。
    猜你喜欢
    • 2023-03-09
    • 2016-07-12
    • 2019-04-17
    • 2019-10-15
    • 1970-01-01
    • 2019-04-17
    • 1970-01-01
    • 2021-08-24
    • 2015-07-06
    相关资源
    最近更新 更多