【问题标题】:How to "remember" categorical encodings for actual predictions after training?训练后如何“记住”实际预测的分类编码?
【发布时间】:2021-03-03 21:18:39
【问题描述】:

假设想在一些数据集上训练机器学习算法,包括一些分类参数。 (机器学习新手,但我的想法是……)即使将所有分类数据转换为 1-hot-encoded 向量,训练后如何“记住”这个编码图?

例如。在训练之前将初始数据集转换为使用 1-hot 编码,比如

universe of categories for some column c is {"good","bad","ok"}, so convert rows to
[1, 2, "good"] ---> [1, 2, [1, 0, 0]],
[3, 4, "bad"]  ---> [3, 4, [0, 1, 0]],
... 

,在训练模型之后,所有未来的预测输入都需要对 c 列使用相同的编码方案。

那么在未来的预测中,数据输入如何记住该映射(“好”映射到索引 0 等)(特别是在计划使用keras RNN 或 LSTM 模型时)?我是否需要将它保存在某个地方(例如 python pickle)(如果是,我如何获得显式映射)?或者有没有办法让模型在内部自动处理分类输入,以便在训练和未来使用期间只输入原始标签数据?

如果这个问题中的任何内容表明我对某事有任何严重的困惑,请告诉我(再次强调,对于 ML 来说非常陌生)。

** 不确定这是否属于https://stats.stackexchange.com/,但在这里发布是因为特别想知道如何处理这个问题的实际代码实现。

【问题讨论】:

  • 记住这个映射并将其用于未来的预测
  • @rvinas 正计划这样做。 Jut 想确保不会错过其他一些最佳实践方法。已经提供了我自己打算做的草稿的答案。如果那里的基本策略有问题,请告诉我。

标签: machine-learning keras


【解决方案1】:

我一直在做的事情如下:

使用 StringIndexer.fit() 后,您可以保存其元数据(包括实际的编码器映射,例如“好”是第一列)

这是我使用的以下代码(使用java,但可以调整为python):

StringIndexerModel sim = new StringIndexer()
        .setInputCol(field)
        .setOutputCol(field + "_INDEX")
        .setHandleInvalid("skip")
        .fit(dataset);

sim.write().overwrite().save("IndexMappingModels/" + field + "_INDEX");

之后,当尝试对新数据集进行预测时,您可以加载存储的元数据:

StringIndexerModel sim = StringIndexerModel.load("IndexMappingModels/" + field + "_INDEX");

dataset = sim.transform(dataset);

我想你已经解决了这个问题,因为它是在 2018 年发布的,但是我在其他任何地方都没有找到这个解决方案,所以我相信它值得分享。

【讨论】:

    【解决方案2】:

    我的想法是在训练/测试数据集 D 上做这样的事情(混合使用 python 和纯伪代码):

    1. 做类似的事情
        # Before: D.schema == {num_col_1: int, cat_col_1: str, cat_col_2: str, ...}
    
        # assign unique index for each distinct label for categorical column annd store in a new column
        # http://spark.apache.org/docs/latest/ml-features.html#stringindexer
        label_indexer = StringIndexer(inputCol="cat_col_i", outputCol="cat_col_i_index").fit(D)
        D = label_indexer.transform(D)
    
        # After: D.schema == {num_col_1: int, cat_col_1: str, cat_col_2: str, ..., cat_col_1_index: int, cat_col_2_index: int, ...}
    

    对于所有分类列

    1. 然后为 D 中的所有这些分类名称和索引列,制作一个表单的映射
        map = {}
        for all categorical column names colname in D:
            map[colname] = []
            # create mapping dict for all categorical values for all 
            # see https://spark.apache.org/docs/latest/sql-programming-guide.html#untyped-dataset-operations-aka-dataframe-operations
            for all rows r in D.select(colname, '%s_index' % colname).drop_duplicates():
                enc_from = r['%s' % colname]
                enc_to = r['%s_index' % colname]
                map[colname].append((enc_from, enc_to))
    
            # for cats that may appear later that have yet to be seen 
            # (IDK if this is best practice, may be another way, see https://medium.com/@vaibhavshukla182/how-to-solve-mismatch-in-train-and-test-set-after-categorical-encoding-8320ed03552f)
            map[colname].append(('NOVEL_CAT', map[colname].len))
            # sort by index encoding
            map[colname].sort(key = lamdba pair: pair[1])    
    

    最终得到类似的东西

        {
            'cat_col_1': [('orig_label_11', 0), ('orig_label_12', 1), ...],
            'cat_col_2': [(), (), ...],
            ...
            'cat_col_n': [(orig_label_n1, 0), ...]
        }
    

    然后可用于为任何后续数据样本行 ds 中的每个分类列生成 1-hot-encoded 向量。例如。

        for all categorical column names colname in ds:
            enc_from = ds[colname]
            # make zero vector for 1-hot for category 
            col_onehot = zeros.(size = map[colname].len)
            for label, index in map[colname]:
                if (label == enc_from):
                    col_onehot[index] = 1
                    # make new column in sample for 1-hot vector
                    ds['%s_onehot' % colname] = col_onehot
                    break
    
    1. 然后可以将此结构保存为 pickle pickle.dump( map, open( "cats_map.pkl", "wb" ) ),以便在以后进行实际预测时与分类列值进行比较。

    ** 可能有更好的方法,但我认为需要更好地理解这篇文章(https://medium.com/@satnalikamayank12/on-learning-embeddings-for-categorical-data-using-keras-165ff2773fc9)。如果有的话会更新答案。

    【讨论】:

      猜你喜欢
      • 2021-02-22
      • 2022-01-21
      • 2018-10-23
      • 2021-10-16
      • 2021-11-24
      • 2020-01-15
      • 1970-01-01
      • 2018-04-15
      • 2021-09-07
      相关资源
      最近更新 更多