【发布时间】:2017-03-30 02:16:36
【问题描述】:
我正在编写一个 Spark 作业,它接收来自多个来源的数据、过滤错误的输入行并输出稍微修改过的输入版本。这项工作有两个额外的要求:
- 我必须跟踪每个源的错误输入行数,以通知上游提供者。
- 我必须支持每个源的输出限制。
这项工作看起来很简单,我使用累加器来跟踪每个源的过滤行数来解决这个问题。然而,当我实现最终的.limit(N) 时,我的累加器行为发生了变化。以下是一些在单一来源上触发行为的精简示例代码:
from pyspark.sql import Row, SparkSession
from pyspark.sql.types import *
from random import randint
def filter_and_transform_parts(rows, filter_int, accum):
for r in rows:
if r[0] == filter_int:
accum.add(1)
continue
yield r[0], r[1] + 1, r[2] + 1
def main():
spark= SparkSession \
.builder \
.appName("Test") \
.getOrCreate()
sc = spark.sparkContext
accum = sc.accumulator(0)
# 20 inputs w/ tuple having 4 as first element
inputs = [(4, randint(1, 10), randint(1, 10)) if x % 5 == 0 else (randint(6, 10), randint(6, 10), randint(6, 10)) for x in xrange(100)]
rdd = sc.parallelize(inputs)
# filter out tuples where 4 is first element
rdd = rdd.mapPartitions(lambda r: filter_and_transform_parts(r, 4, accum))
# if not limit, accumulator value is 20
# if limit and limit_count <= 63, accumulator value is 0
# if limit and limit_count >= 64, accumulator value is 20
limit = True
limit_count = 63
if limit:
rdd = rdd.map(lambda r: Row(r[0], r[1], r[2]))
df_schema = StructType([StructField("val1", IntegerType(), False),
StructField("val2", IntegerType(), False),
StructField("val3", IntegerType(), False)])
df = spark.createDataFrame(rdd, schema=df_schema)
df = df.limit(limit_count)
df.write.mode("overwrite").csv('foo/')
else:
rdd.saveAsTextFile('foo/')
print "Accum value: {}".format(accum.value)
if __name__ == "__main__":
main()
问题是我的累加器有时会报告过滤的行数,有时不报告,具体取决于指定的限制和源的输入数。但是,在所有情况下,过滤后的行都不会进入输出,这意味着过滤器发生了,累加器应该有一个值。
如果您能对此有所了解,那将非常有帮助,谢谢!
更新:
- 在
mapPartitions之后添加rdd.persist()调用使累加器行为一致。
【问题讨论】:
标签: apache-spark pyspark spark-dataframe