可以编写一个自定义函数,对指定列应用独热编码(One-Hot Encoding,OHE),并且只对出现频率最高的前 N 个值(例如 N = 3)进行编码。
其思路与 Python 中的做法类似:1)根据出现频率最高的前 N 个值构建一个 DataFrame(或字典);2)对这个前 N 个值的 DataFrame 做透视(pivot),即生成 OHE 向量;3)将给定的 DataFrame 与透视后的 DataFrame 做左连接,并用 0 替换 null,即得到默认的 OHE 向量。
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.functions.{col, lit, when}
import org.apache.spark.sql.Column
import spark.implicits._
// Sample input: 29 strings in a single column named "value".
// "a" (8 occurrences), "b" (7) and "c" (5) are the three most frequent values.
val df = {
  val sampleValues = Seq(
    "a", "b", "c", "a", "b", "c", "d", "e", "a", "b",
    "f", "a", "g", "a", "b", "c", "a", "d", "e", "f",
    "a", "b", "g", "b", "c", "f", "a", "b", "c"
  )
  val sampleRdd = spark.sparkContext.parallelize(sampleValues)
  sampleRdd.toDF("value")
}

// One-hot encode only the 3 most frequent values of "value".
// NOTE(review): oheEncoding is defined after this call; if the script is fed to a
// REPL one statement at a time the definition must be loaded first — confirm how
// this file is meant to be run.
val oheEncodedDF = oheEncoding(df, "value", 3)
/** One-hot encodes `colName`, restricted to its `n` most frequent values.
  *
  * Steps:
  *  1. Count the occurrences of each value of `colName` and keep the `n` most frequent
  *     (ties are broken by the value itself, ascending, so the result is deterministic).
  *  2. Pivot those values into one indicator column per value and add a `default`
  *     marker column.
  *  3. Left-join the indicators back onto `df`, drop `colName`, and replace the nulls
  *     in the generated columns with 0.
  *
  * The result contains every column of `df` except `colName`, one column per top-`n`
  * value (1 when the row holds that value, 0 otherwise) and a `default` column that is
  * 1 exactly for rows whose value is not among the top `n`.
  *
  * Unlike a SQL-string implementation, this does not register any temporary view and
  * does not splice `colName` into SQL text. Nulls in the caller's own columns are left
  * untouched; only the generated columns are zero-filled.
  *
  * NOTE(review): generated column names come from the data values; a value equal to
  * "default" or to the name of another column of `df` would collide — confirm inputs.
  *
  * @param df      input DataFrame
  * @param colName name of the column to encode
  * @param n       number of most frequent values that get their own indicator column
  * @return the encoded DataFrame as described above
  */
def oheEncoding(df: DataFrame, colName: String, n: Int): DataFrame = {
  // Local import keeps the block self-contained without touching file-level imports.
  import org.apache.spark.sql.functions.count

  // Private alias for the frequency so a column literally named "count" cannot clash.
  val freqCol = "__ohe_freq"

  val topNDF = df
    .groupBy(colName)
    .agg(count(lit(1)).as(freqCol))
    .orderBy(col(freqCol).desc, col(colName).asc)
    .limit(n)
    .select(colName)

  // Each top value appears once, so its own pivot cell is 1 and every other cell is null.
  val pivotTopNDF = topNDF
    .groupBy(colName)
    .pivot(colName)
    .count()
    .withColumn("default", lit(1))

  // Only the columns produced here may be zero-filled, never the caller's data.
  val oheColumns = pivotTopNDF.columns.filterNot(_ == colName)

  df.join(pivotTopNDF, Seq(colName), "left")
    .drop(colName)
    .na.fill(0L, oheColumns)
    // After the fill, `default` is 1 for top-n rows and 0 otherwise; invert it.
    .withColumn("default", flip(col("default")))
}
/** Inverts a 0/1 indicator column.
  *
  * Yields 0 where `col` equals 1 and 1 in every other case (including null,
  * because a null comparison falls through to the `otherwise` branch).
  *
  * @param col indicator column to invert
  * @return a column holding 0 where `col` is 1, and 1 elsewhere
  */
def flip(col: Column): Column = {
  val isOne = col.equalTo(1)
  when(isOne, 0).otherwise(1)
}