【发布时间】:2018-08-06 22:37:42
【问题描述】:
我想从数据框中删除一些列,然后应用 ML 算法。我通过构建 2 个单独的管道来做到这一点。我的问题是如何将两条管道合并到一个管道中?
#######################
from typing import Iterable
import pandas as pd
import pyspark.sql.functions as F
from pyspark.ml import Pipeline, Transformer
from pyspark.sql import DataFrame
from pyspark.ml.classification import DecisionTreeClassifier
from pyspark.ml.feature import VectorAssembler
#######################
#Custom Class
#######################
class ColumnDropper_test(Transformer):
def __init__(self, banned_list: Iterable[str]):
super().__init__()
self.banned_list = banned_list
def _transform(self, df: DataFrame) -> DataFrame:
df = df.drop(
*[x for x in df.columns if any(y in x for y in self.banned_list)])
return df
#######################
#Sample Data
#######################
data = pd.DataFrame({
'ball_column': [0, 1, 2, 3],
'keep_column': [7, 8, 9, 10],
'hall_column': [14, 15, 16, 17],
'banned_me': [14, 15, 16, 17],
'target': [21, 31, 41, 51]
})
df = spark.createDataFrame(data)
#######################
# First Pipeline
#######################
column_dropper = ColumnDropper_test(banned_list=['banned_me'])
model = Pipeline(stages=[column_dropper]).fit(df).transform(df)
#######################
#Second Pipeline(Question: Add the block of code below to the above pipeline)
#########################
ready = [col for col in model.columns if col != 'target']
assembler = VectorAssembler(inputCols=ready, outputCol='features')
dtc = DecisionTreeClassifier(featuresCol='features', labelCol='target')
model_2 = Pipeline(stages=[assembler,dtc])
train_data, test_data = model.randomSplit([0.5,0.5])
fit_model = model_2.fit(train_data)
results = fit_model.transform(test_data)
results.select('features','Prediction').show()
我发现的挑战在于上述代码中的变量ready。由于调用column_dropper 后model.columns 会有所不同(列数更少),因此使用(df.columns)将其添加到同一管道将导致以下错误,因为banned_me 已被原始数据删除.
#Combining both Pipelines failed attempt
model = Pipeline(stages=[column_dropper,assembler,dtc]).fit(df).transform(df)
调用 o188.transform 时出错。 : java.lang.IllegalArgumentException:字段“banned_me”不存在。 可用字段:ball_column、keep_column、hall_column、target
我最初的建议是创建一个从ColumnDropper_testclass 继承df.columns 的新变量的新类。我们如何才能让Pipeline 中的assembler 阶段从column_dropper 阶段看到新的df,而不是查看原来的df?
【问题讨论】:
-
非常有趣的问题。我想出的一个丑陋的解决方案是嵌套
Transformers,即将汇编器放在ColumnDropper类中。但必须有更好的解决方案。我确实偶然发现了this,但是,自定义Transformer没有getOutputCols方法。因此,对于一个好的解决方案,我们必须找到一种方法来实现它(我认为)。我还没有找到如何做到这一点,很想看看其他人的解决方法。
标签: python apache-spark pyspark