【问题标题】:pyspark sql functions instead of rdd distinctpyspark sql函数而不是rdd distinct
【发布时间】:2017-04-10 23:08:13
【问题描述】:

我一直在尝试替换数据集中特定列的字符串。 1 或 0,如果为 1,则为 'Y',否则为 0。

我已经设法确定要定位哪些列,使用数据帧到 rdd 转换和 lambda,但需要一段时间来处理。

为每一列切换到一个rdd,然后执行一个distinct,这需要一段时间!

如果不同结果集中存在“Y”,则将该列标识为需要转换。

我想知道是否有人可以建议我如何专门使用 pyspark sql 函数来获得相同的结果,而不必为每一列进行切换?

示例数据的代码如下:

    import pyspark.sql.types as typ
    import pyspark.sql.functions as func

    col_names = [
        ('ALIVE', typ.StringType()),
        ('AGE', typ.IntegerType()),
        ('CAGE', typ.IntegerType()),
        ('CNT1', typ.IntegerType()),
        ('CNT2', typ.IntegerType()),
        ('CNT3', typ.IntegerType()),
        ('HE', typ.IntegerType()),
        ('WE', typ.IntegerType()),
        ('WG', typ.IntegerType()),
        ('DBP', typ.StringType()),
        ('DBG', typ.StringType()),
        ('HT1', typ.StringType()),
        ('HT2', typ.StringType()),
        ('PREV', typ.StringType())
        ]

    schema = typ.StructType([typ.StructField(c[0], c[1], False) for c in col_names])
    df = spark.createDataFrame([('Y',22,56,4,3,65,180,198,18,'N','Y','N','N','N'),
                                ('N',38,79,3,4,63,155,167,12,'N','N','N','Y','N'),
                                ('Y',39,81,6,6,60,128,152,24,'N','N','N','N','Y')]
                               ,schema=schema)

    cols = [(col.name, col.dataType) for col in df.schema]

    transform_cols = []

    for s in cols:
      if s[1] == typ.StringType():
        distinct_result = df.select(s[0]).distinct().rdd.map(lambda row: row[0]).collect()
        if 'Y' in distinct_result:
          transform_cols.append(s[0])

    print(transform_cols)

输出是:

['ALIVE', 'DBG', 'HT2', 'PREV']

【问题讨论】:

    标签: python pyspark data-cleaning


    【解决方案1】:

    我设法使用udf 来完成任务。首先,选择带有YN 的列(这里我使用func.first 以便浏览第一行):

    cols_sel = df.select([func.first(col).alias(col) for col in df.columns]).collect()[0].asDict()
    cols = [col_name for (col_name, v) in cols_sel.items() if v in ['Y', 'N']]
    # return ['HT2', 'ALIVE', 'DBP', 'HT1', 'PREV', 'DBG']
    

    接下来,您可以创建udf 函数以将YN 映射到10

    def map_input(val):
        map_dict = dict(zip(['Y', 'N'], [1, 0]))
        return map_dict.get(val)
    udf_map_input = func.udf(map_input, returnType=typ.IntegerType())
    
    for col in cols:
        df = df.withColumn(col, udf_map_input(col))
    df.show()
    

    最后,您可以对列求和。然后我将输出转换为字典并检查哪些列的值大于 0(即包含 Y

    out = df.select([func.sum(col).alias(col) for col in cols]).collect()
    out = out[0]
    print([col_name for (col_name, val) in out.asDict().items() if val > 0])
    

    输出

    ['DBG', 'HT2', 'ALIVE', 'PREV']
    

    【讨论】:

    • 谢谢,它不一定更有效,但看到另一个解决方案很有用,因为我是 pyspark 的新手。
    猜你喜欢
    • 1970-01-01
    • 2017-11-28
    • 1970-01-01
    • 1970-01-01
    • 1970-01-01
    • 1970-01-01
    • 1970-01-01
    • 1970-01-01
    • 1970-01-01
    相关资源
    最近更新 更多