【发布时间】:2016-11-15 09:23:09
【问题描述】:
假设我已经安装了scikit's SGDC,从文档中我读到predict_proba() 函数返回一个概率估计向量,因此我做了以下操作:
In:
proba = clf.predict_proba(X_test)
print('proba:',proba.shape)
print(type(prediction))
Out:
proba: (292683, 39)
<class 'numpy.ndarray'>
但是,我不明白为什么proba 有那个维度
(292683, 39),安装于(292683,)。 那么,我的问题是我应该如何返回每个分类实例的概率? 例如,一个包含每个分类实例概率的向量:
.9098
.6789
.2346
.4545
...
.9076
更新
这是我的实际输出:
0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,32,33,34,35,36,37,38
1.6032895251736538e-09,0.0027001605689774967,1.3127275209812045e-05,0.0004133169272159469,6.421335538574734e-05,0.01244940641130727,4.971270475822253e-05,0.06927362982555345,0.05447770875726582,0.0002585581503775057,1.30512865257421e-05,0.00015347845576367026,0.004231831363568738,0.003134713706992086,0.00017618959500039568,0.004525087952898131,0.07230938415776024,0.004255936398577753,0.0006231217282368267,0.07381737590135892,1.7062740932146373e-05,0.04873946029933614,2.2579270275470988e-05,0.04738213671381574,0.011041250070307537,0.06786077438113797,0.008012001696580576,0.0009697583063038865,0.002640793732663328,0.00041955324710243576,0.005333452308762462,0.0023973060671898918,0.24386456744298726,1.2930500605063882e-05,0.010271860113445061,0.10478318644646997,0.1096803752152842,0.029709960729470408,0.0039009845913073
...
2.70775531177066e-05,0.056826721550724914,0.00021452452508401623,0.005773421211249144,0.03601322253697087,0.03387846954273534,0.0002233544773721261,0.0009621520077239175,0.005573279378280768,0.0011059321386392307,0.00014906386779747047,0.0007207742574711379,0.018149812871977058,0.017479374046348212,0.0004917497325634417,0.009446560753589354,0.37652447022205116,0.008895752894288417,0.00136242543496297,0.1961349850670937,0.011158949542858676,0.0010422870520728268,4.0487954942671204e-05,0.013908461124574075,0.005521009748034979,0.019087261334748272,0.00355886145992077,0.0054657023293853595,0.004395464092632666,0.00018729724505224616,0.0015209690844465442,0.003930224604070839,0.03922346296961368,2.1100171629256666e-05,0.001026959174556334,0.09177893762051553,0.021131552685297615,0.0007056741594152797,0.006342213576191516
【问题讨论】:
标签: python python-3.x numpy machine-learning scikit-learn