【Posted】: 2019-08-17 15:50:50
【Question】:
I am trying to create a confusion matrix.
My data looks like this:
     class    Growth  Negative   Neutral
1   Growth 0.3082588 0.2993632 0.3923780
2  Neutral 0.4696949 0.2918042 0.2385009
3 Negative 0.3608549 0.2679748 0.3711703
4  Neutral 0.3636836 0.2431433 0.3931730
5   Growth 0.4325862 0.2011520 0.3662619
6 Negative 0.2939859 0.2397171 0.4662970
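Here class is the "true" observed outcome, and Growth, Negative and Neutral are the model's predicted probabilities that the observation belongs to each of those classes. For example, in the first row Neutral has the highest probability (0.3923780), so the model would wrongly predict Neutral when the observation is actually Growth.
I would normally use the confusionMatrix() function from caret, but my data is laid out a little differently. Should I create a new column called pred_class that holds the name of the column with the highest value? For example: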
     class    Growth  Negative   Neutral pred_class
1   Growth 0.3082588 0.2993632 0.3923780    Neutral
2  Neutral 0.4696949 0.2918042 0.2385009     Growth
3 Negative 0.3608549 0.2679748 0.3711703    Neutral
4  Neutral 0.3636836 0.2431433 0.3931730    Neutral
5   Growth 0.4325862 0.2011520 0.3662619     Growth
6 Negative 0.2939859 0.2397171 0.4662970    Neutral
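Then I could do something like confusionMatrix(df$pred_class, df$class). How can I write a function that puts the name of the column with the highest probability into a new column?
Data: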
df <- structure(list(class = c("Growth", "Neutral", "Negative", "Neutral",
"Growth", "Negative", "Neutral", "Neutral", "Neutral", "Neutral",
"Neutral", "Negative", "Neutral", "Growth", "Growth", "Growth",
"Negative", "Negative", "Growth", "Negative"), Growth = c(0.308258818045192,
0.469694864370061, 0.360854910973552, 0.363683641698332, 0.43258619401693,
0.2939858517149, 0.397951949316298, 0.235376278828237, 0.3685791718903,
0.330295647415191, 0.212072592205125, 0.220703558050626, 0.389445269278106,
0.286933037813081, 0.315659629884986, 0.30185119811882, 0.273429057319956,
0.277357131556229, 0.339004410008943, 0.407114176119814), Negative = c(0.299363167088292,
0.291804233603859, 0.267974798034839, 0.243143322044808, 0.201151951415105,
0.239717129555608, 0.351629585705591, 0.258325790152011, 0.281660024058527,
0.189920159505041, 0.265058882513953, 0.433664278547707, 0.114765460651494,
0.402354633060689, 0.370370354887748, 0.3239536031819, 0.3279406609037,
0.327198131828346, 0.298583999674218, 0.337902573718712), Neutral = c(0.392378014866516,
0.23850090202608, 0.371170290991609, 0.39317303625686, 0.366261854567965,
0.466297018729492, 0.250418464978111, 0.506297931019752, 0.349760804051173,
0.479784193079769, 0.522868525280922, 0.345632163401667, 0.4957892700704,
0.31071232912623, 0.313970015227266, 0.374195198699279, 0.398630281776344,
0.395444736615424, 0.362411590316838, 0.254983250161474)), row.names = c(NA,
20L), class = "data.frame")
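For illustration, the pred_class idea described above could be sketched roughly as follows. This is a minimal sketch and not a confirmed solution: it assumes base R's max.col() is an acceptable way to pick the highest-probability column, it rebuilds only the first six rows (rounded as printed in the question) so that it runs on its own, and the names prob_cols, lv and conf_tab are made up for the example.

```r
# First six rows of the question's data, rounded as printed above.
df <- data.frame(
  class    = c("Growth", "Neutral", "Negative", "Neutral", "Growth", "Negative"),
  Growth   = c(0.3082588, 0.4696949, 0.3608549, 0.3636836, 0.4325862, 0.2939859),
  Negative = c(0.2993632, 0.2918042, 0.2679748, 0.2431433, 0.2011520, 0.2397171),
  Neutral  = c(0.3923780, 0.2385009, 0.3711703, 0.3931730, 0.3662619, 0.4662970),
  stringsAsFactors = FALSE
)

# Hypothetical helper name: the columns that hold class probabilities.
prob_cols <- c("Growth", "Negative", "Neutral")

# max.col() gives, for each row, the index of the column with the largest value.
# ties.method = "first" keeps tie-breaking deterministic (the default is "random").
df$pred_class <- prob_cols[max.col(as.matrix(df[prob_cols]), ties.method = "first")]

# Predicted and actual labels as factors sharing the same levels,
# cross-tabulated with base R as a plain confusion table.
lv <- prob_cols
conf_tab <- table(Predicted = factor(df$pred_class, levels = lv),
                  Actual    = factor(df$class, levels = lv))
print(conf_tab)

# With caret installed, the same two factors could presumably be passed on:
# caret::confusionMatrix(factor(df$pred_class, levels = lv),
#                        factor(df$class, levels = lv))
```

The same three statements after the data frame definition should apply unchanged to the full 20-row df given above. Giving both factors identical levels matters because a class that is never predicted (Negative in these six rows) would otherwise be missing from one side of the table.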
【Comments】:
Tags: r