【Question title】: Apply OneHotEncoder to several categorical columns in Spark MLlib
【Posted】: 2016-06-18 16:50:58
【Question】:

I have several categorical features and would like to transform them all using OneHotEncoder. However, when I tried to apply StringIndexer, I got an error:

stringIndexer = StringIndexer(
    inputCol = ['a', 'b','c','d'],
    outputCol = ['a_index', 'b_index','c_index','d_index']
)  

model = stringIndexer.fit(Data)
An error occurred while calling o328.fit.
: java.lang.ClassCastException: java.util.ArrayList cannot be cast to java.lang.String
    at org.apache.spark.ml.feature.StringIndexer.fit(StringIndexer.scala:79)
    at sun.reflect.NativeMethodAccessorImpl.invoke0(Native Method)
    at sun.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:57)
    at sun.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)
    at java.lang.reflect.Method.invoke(Method.java:606)
    at py4j.reflection.MethodInvoker.invoke(MethodInvoker.java:231)
    at py4j.reflection.ReflectionEngine.invoke(ReflectionEngine.java:379)
    at py4j.Gateway.invoke(Gateway.java:259)
    at py4j.commands.AbstractCommand.invokeMethod(AbstractCommand.java:133)
    at py4j.commands.CallCommand.execute(CallCommand.java:79)
    at py4j.GatewayConnection.run(GatewayConnection.java:207)
    at java.lang.Thread.run(Thread.java:745)

Traceback (most recent call last):
Py4JJavaError: An error occurred while calling o328.fit.

【Discussion】:

    Tags: python apache-spark pyspark apache-spark-mllib apache-spark-ml


    【Answer 1】:

    Spark >= 3.0

    In Spark 3.0, OneHotEncoderEstimator was renamed to OneHotEncoder, so replace

    from pyspark.ml.feature import OneHotEncoderEstimator, OneHotEncoderModel
    
    encoder = OneHotEncoderEstimator(...)
    
    with

    from pyspark.ml.feature import OneHotEncoder, OneHotEncoderModel
    
    encoder = OneHotEncoder(...)
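
On Spark 3.0 and later, StringIndexer also accepts `inputCols`/`outputCols`, so the multi-column call attempted in the question can be written with a single indexer stage. The following is a minimal sketch under that assumption; the helper names `indexed_names`, `encoded_names`, and `build_pipeline` are hypothetical and not part of the original answer.

```python
def indexed_names(cols):
    """Output column names for the indexing stage (naming is an assumption)."""
    return ["{0}_indexed".format(c) for c in cols]


def encoded_names(cols):
    """Output column names for the one-hot encoding stage."""
    return ["{0}_encoded".format(name) for name in indexed_names(cols)]


def build_pipeline(cols):
    """Build an index -> one-hot -> assemble pipeline (Spark >= 3.0 only)."""
    # Imported here so the naming helpers above can be used without Spark.
    from pyspark.ml import Pipeline
    from pyspark.ml.feature import OneHotEncoder, StringIndexer, VectorAssembler

    # One multi-column indexer replaces the list of per-column indexers.
    indexer = StringIndexer(inputCols=cols, outputCols=indexed_names(cols))
    encoder = OneHotEncoder(
        inputCols=indexer.getOutputCols(),
        outputCols=encoded_names(cols),
    )
    assembler = VectorAssembler(
        inputCols=encoder.getOutputCols(),
        outputCol="features",
    )
    return Pipeline(stages=[indexer, encoder, assembler])
```

With an active SparkSession and a DataFrame `df` containing the columns from the question, this would be used as `build_pipeline(['a', 'b', 'c', 'd']).fit(df).transform(df)`.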
    

    Spark >= 2.3

    You can use the newly added OneHotEncoderEstimator:

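    from pyspark.ml import Pipeline
    from pyspark.ml.feature import (
        OneHotEncoderEstimator, OneHotEncoderModel, VectorAssembler)
    
    # `indexers` is the list of per-column StringIndexer stages defined in the
    # last example of this answer.
    encoder = OneHotEncoderEstimator(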
        inputCols=[indexer.getOutputCol() for indexer in indexers],
        outputCols=[
            "{0}_encoded".format(indexer.getOutputCol()) for indexer in indexers]
    )
    
    assembler = VectorAssembler(
        inputCols=encoder.getOutputCols(),
        outputCol="features"
    )
    
    pipeline = Pipeline(stages=indexers + [encoder, assembler])
    pipeline.fit(df).transform(df)
    

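    Spark < 2.3

    It is not possible. The StringIndexer transformer operates on only one column at a time, so you need one indexer and one encoder for each column you want to transform.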

    from pyspark.ml import Pipeline
    from pyspark.ml.feature import StringIndexer, OneHotEncoder, VectorAssembler
    
    cols = ['a', 'b', 'c', 'd']
    
    indexers = [
        StringIndexer(inputCol=c, outputCol="{0}_indexed".format(c))
        for c in cols
    ]
    
    encoders = [
        OneHotEncoder(
            inputCol=indexer.getOutputCol(),
            outputCol="{0}_encoded".format(indexer.getOutputCol())) 
        for indexer in indexers
    ]
    
    assembler = VectorAssembler(
        inputCols=[encoder.getOutputCol() for encoder in encoders],
        outputCol="features"
    )
    
    
    pipeline = Pipeline(stages=indexers + encoders + [assembler])
    pipeline.fit(df).transform(df).show()
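
To make the "one indexer per column" point concrete, here is a plain-Python sketch of what each fitted StringIndexer computes: a separate label-to-index mapping per column, with the most frequent label receiving index 0.0. It is an illustration only, not Spark's implementation; the function name `fit_string_index`, the sample data, and the alphabetical tie-breaking are assumptions.

```python
from collections import Counter


def fit_string_index(values):
    """Map each label to a float index, most frequent label first."""
    counts = Counter(values)
    # Ties are broken alphabetically here; this is an assumption of the sketch.
    ordered = sorted(counts, key=lambda label: (-counts[label], label))
    return {label: float(i) for i, label in enumerate(ordered)}


# Hypothetical data: two categorical columns stored as plain lists.
rows = {
    "a": ["x", "y", "x"],
    "b": ["u", "u", "v"],
}

# Each column gets its own mapping, just as each column needs its own indexer.
mappings = {col: fit_string_index(vals) for col, vals in rows.items()}

# Apply every mapping to its own column, mirroring the "<col>_indexed" naming.
indexed = {
    "{0}_indexed".format(col): [mappings[col][v] for v in vals]
    for col, vals in rows.items()
}
```

Because the mappings are fitted independently, index 0.0 in `a_indexed` and index 0.0 in `b_indexed` refer to unrelated labels.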
    

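    【Comments】:

    • With a column of strings, do you have to run both StringIndexer() and OneHotEncoderEstimator()?
    【Answer 2】:

    I don't think the code above gives the required result. The encoder section needs a change: StringIndexer is applied a second time, to the indexers' output, so the result is just indexed values again rather than one-hot vectors.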

    #In the following section:
    encoders = [
        StringIndexer(
            inputCol=indexer.getOutputCol(),
            outputCol="{0}_encoded".format(indexer.getOutputCol())) 
        for indexer in indexers
    ]
    
    #Replace the StringIndexer with OneHotEncoder as follows:
    encoders = [OneHotEncoder(dropLast=False,inputCol=indexer.getOutputCol(),
                outputCol="{0}_encoded".format(indexer.getOutputCol())) 
                for indexer in indexers
    ]
    

    The complete code now looks like this:

    from pyspark.ml import Pipeline
    from pyspark.ml.feature import StringIndexer, OneHotEncoder, VectorAssembler
    
    categorical_columns= ['Gender', 'Age', 'Occupation', 'City_Category','Marital_Status']
    
    # Index the string values of each categorical column
    indexers = [
        StringIndexer(inputCol=c, outputCol="{0}_indexed".format(c))
        for c in categorical_columns
    ]
    
    # One-hot encode the indexed values of each column
    encoders = [OneHotEncoder(dropLast=False,inputCol=indexer.getOutputCol(),
                outputCol="{0}_encoded".format(indexer.getOutputCol())) 
        for indexer in indexers
    ]
    
    # Vectorizing encoded values
    assembler = VectorAssembler(inputCols=[encoder.getOutputCol() for encoder in encoders],outputCol="features")
    
    pipeline = Pipeline(stages=indexers + encoders+[assembler])
    model=pipeline.fit(data_df)
    transformed = model.transform(data_df)
    transformed.show(5)
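
The `dropLast=False` argument used above changes the width of each encoded vector. As a rough illustration in plain Python (dense lists instead of Spark's sparse vectors; the function name `one_hot` is hypothetical), the default `dropLast=True` leaves out the slot for the last category, which is then represented by an all-zero vector:

```python
def one_hot(index, num_categories, drop_last=True):
    """Dense one-hot vector for a category index in [0, num_categories)."""
    size = num_categories - 1 if drop_last else num_categories
    vector = [0.0] * size
    if index < size:  # the dropped last category stays all zeros
        vector[index] = 1.0
    return vector


num_categories = 3

# Default behaviour: two slots for three categories.
with_drop = [one_hot(i, num_categories) for i in range(num_categories)]

# dropLast=False: one slot per category.
without_drop = [
    one_hot(i, num_categories, drop_last=False) for i in range(num_categories)
]
```

Keeping every slot makes the vectors easier to read, at the cost of one redundant column per feature.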
    

    For more details, see:
    [1] https://spark.apache.org/docs/2.0.2/api/python/pyspark.ml.html#pyspark.ml.feature.StringIndexer
    [2] https://spark.apache.org/docs/2.0.2/api/python/pyspark.ml.html#pyspark.ml.feature.OneHotEncoder

    【Comments】:

    • With a column of strings, do you have to run both StringIndexer() and OneHotEncoderEstimator()? Or can you run only the latter?