我的想法是在训练/测试数据集 D 上做这样的事情(混合使用 python 和纯伪代码):
- 做类似的事情
# Before: D.schema == {num_col_1: int, cat_col_1: str, cat_col_2: str, ...}
# assign unique index for each distinct label for categorical column annd store in a new column
# http://spark.apache.org/docs/latest/ml-features.html#stringindexer
label_indexer = StringIndexer(inputCol="cat_col_i", outputCol="cat_col_i_index").fit(D)
D = label_indexer.transform(D)
# After: D.schema == {num_col_1: int, cat_col_1: str, cat_col_2: str, ..., cat_col_1_index: int, cat_col_2_index: int, ...}
对于所有分类列
- 然后为 D 中的所有这些分类名称和索引列,制作一个表单的映射
map = {}
for all categorical column names colname in D:
map[colname] = []
# create mapping dict for all categorical values for all
# see https://spark.apache.org/docs/latest/sql-programming-guide.html#untyped-dataset-operations-aka-dataframe-operations
for all rows r in D.select(colname, '%s_index' % colname).drop_duplicates():
enc_from = r['%s' % colname]
enc_to = r['%s_index' % colname]
map[colname].append((enc_from, enc_to))
# for cats that may appear later that have yet to be seen
# (IDK if this is best practice, may be another way, see https://medium.com/@vaibhavshukla182/how-to-solve-mismatch-in-train-and-test-set-after-categorical-encoding-8320ed03552f)
map[colname].append(('NOVEL_CAT', map[colname].len))
# sort by index encoding
map[colname].sort(key = lamdba pair: pair[1])
最终得到类似的东西
{
'cat_col_1': [('orig_label_11', 0), ('orig_label_12', 1), ...],
'cat_col_2': [(), (), ...],
...
'cat_col_n': [(orig_label_n1, 0), ...]
}
然后可用于为任何后续数据样本行 ds 中的每个分类列生成 1-hot-encoded 向量。例如。
for all categorical column names colname in ds:
enc_from = ds[colname]
# make zero vector for 1-hot for category
col_onehot = zeros.(size = map[colname].len)
for label, index in map[colname]:
if (label == enc_from):
col_onehot[index] = 1
# make new column in sample for 1-hot vector
ds['%s_onehot' % colname] = col_onehot
break
- 然后可以将此结构保存为 pickle
pickle.dump( map, open( "cats_map.pkl", "wb" ) ),以便在以后进行实际预测时与分类列值进行比较。
** 可能有更好的方法,但我认为需要更好地理解这篇文章(https://medium.com/@satnalikamayank12/on-learning-embeddings-for-categorical-data-using-keras-165ff2773fc9)。如果有的话会更新答案。