【问题标题】:extract weights from a RWeka SMOreg model从 RWeka SMOreg 模型中提取权重
【发布时间】:2018-04-21 00:00:37
【问题描述】:

我正在使用很棒的 RWeka 包来适应在 Weka 中实现的 SMOreg 模型。虽然一切正常,但我在从拟合模型中提取权重时遇到了一些问题。

作为所有 Weka 分类器对象,我的模型有一个很好的打印方法,可以显示所有特征及其相对权重。但是,我无法以任何方式提取这些权重。

你可以通过运行以下代码自己查看:

library(RWeka)
data("mtcars")
SMOreg_classifier <- make_Weka_classifier("weka/classifiers/functions/SMOreg")
model_SMOreg <- SMOreg_classifier(mpg ~ ., data = mtcars)

现在,如果您只是调用模型

model_SMOreg

您会看到它打印了模型中使用的所有特征及其相对权重。我想将这些权重作为向量访问,或者更好的是,作为 2 列的表,其中一列包含特征名称,另一列包含权重。

我正在使用 Windows 7 x64 系统,使用 RStudio 版本 1.0.153、R 3.4.2 Short Summer 和 RWeka 0.4-35。

有人知道怎么做吗?

【问题讨论】:

    标签: r svm weka rweka


    【解决方案1】:

    根据@knb 的建议,我编写了一个函数来从 SMOreg 模型中提取权重,并返回一个小标题,其中一列用于特征名称,一列用于特征权重,行排列在绝对值之后重量。

    请注意,此功能仅适用于 SMOreg 分类器,因为其他分类器的输出在布局方面略有不同。但是,我认为该功能可以轻松地适应其他分类器。

    library(stringr)
    library(tidyverse)
    
    extract_weights_from_SMOreg <- function(model) {
    
      oldw <- getOption("warn")
      options(warn = -1)
    
    
      raw_output <- capture.output(model)
      trimmed_output <- raw_output[-c(1:3,(length(raw_output) - 4): length(raw_output))]
      df <- data_frame(features_name = vector(length = length(trimmed_output) + 1, "character"), 
                       features_weight = vector(length = length(trimmed_output) + 1, "numeric"))
    
      for (line in 1:length(trimmed_output)) {
    
    
        string_as_vector <- trimmed_output[line] %>%
          str_split(string = ., pattern = " ") %>%
          unlist(.)
    
    
        numeric_element <- trimmed_output[line] %>%
          str_split(string = ., pattern = " ") %>%
          unlist(.) %>%
          as.numeric(.)
    
        position_mul <- string_as_vector[is.na(numeric_element)] %>%
          str_detect(string = ., pattern = "[*]") %>%
          which(.)
    
        numeric_element <- numeric_element %>%
          `[`(., c(1:position_mul))
    
        text_element <- string_as_vector[is.na(numeric_element)]
    
    
        there_is_plus <- string_as_vector[is.na(numeric_element)] %>%
          str_detect(string = ., pattern = "[+]") %>%
          sum(.)
    
        if (there_is_plus) { sign_is <- "+"} else { sign_is <- "-"}
    
    
    
        feature_weight <- numeric_element[!is.na(numeric_element)]
    
        if (sign_is == "-") {df[line, "features_weight"] <- feature_weight * -1} else {df[line, "features_weight"] <- numeric_element[!(is.na(numeric_element))]}
    
        df[line, "features_name"] <- paste(text_element[(position_mul + 1): length(text_element)], collapse = " ")
    
      }
    
      intercept_line <- raw_output[length(raw_output) - 4]
    
    
      there_is_plus_intercept <- intercept_line %>%
        str_detect(string = ., pattern = "[+]") %>%
        sum(.)
    
      if (there_is_plus_intercept) { intercept_sign_is <- "+"} else { intercept_sign_is <- "-"}
    
      numeric_intercept <- intercept_line %>%
        str_split(string = ., pattern = " ") %>%
        unlist(.) %>%
        as.numeric(.) %>%
        `[`(., length(.))
    
      df[nrow(df), "features_name"] <- "intercept"
    
      if (intercept_sign_is == "-") {df[nrow(df), "features_weight"] <- numeric_intercept * -1} else {df[nrow(df), "features_weight"] <- numeric_intercept}
    
      options(warn = oldw)
    
      df <- df %>%
        arrange(desc(abs(features_weight)))
    
      return(df)
    }
    

    这里是一个模型的示例

    library(RWeka)
    data("mtcars")
    SMOreg_classifier <- make_Weka_classifier("weka/classifiers/functions/SMOreg")
    mpg_model_weights <- extract_weights_from_SMOreg(SMOreg_classifier(data = mtcars, mpg ~ .))
    mpg_model_weights 
    

    【讨论】:

      【解决方案2】:

      我认为您无法以数字格式获取此信息。

      attr(model_SMOreg, "meta")$class                      #  "Weka_classifier"
      
      getAnywhere("print.Weka_classifier")
      

      结果:

      A single object matching ‘print.Weka_classifier’ was found
      It was found in the following places
        registered S3 method for print from namespace RWeka
        namespace:RWeka
      with value
      
      function (x, ...) 
      {
          writeLines(.jcall(x$classifier, "S", "toString"))
          invisible(x)
      }
      <bytecode: 0x8328630>
      <environment: namespace:RWeka>
      

      所以我们看到:print.Weka_classifier() 调用 .writeLines(),然后调用 rJava::.jcall,返回一个字符串。

      因此,我认为您需要自己解析权重,也许通过调用capture.output() 方法。

      【讨论】:

      • 谢谢!我不知道getAnywhere 命令也不知道capture.output,我知道它们有多么有用。
      猜你喜欢
      • 1970-01-01
      • 2017-06-18
      • 1970-01-01
      • 2016-07-24
      • 2021-04-18
      • 1970-01-01
      • 2020-06-19
      • 1970-01-01
      • 1970-01-01
      相关资源
      最近更新 更多