【问题标题】:Check equality for two Spark DataFrames in Scala在 Scala 中检查两个 Spark DataFrame 的相等性
【发布时间】:2019-10-25 18:31:50
【问题描述】:

我是 Scala 新手,在编写单元测试时遇到问题。

我试图在 Scala 中比较和检查两个 Spark DataFrame 的相等性以进行单元测试,并意识到没有简单的方法来检查两个 Spark DataFrame 的相等性。

C++ 等效代码是(假设 DataFrame 在 C++ 中表示为双精度数组):

    int expected[10][2];
    int result[10][2];
    for (int row = 0; row < 10; row++) {
        for (int col = 0; col < 2; col++) {
            if (expected[row][col] != result[row][col]) return false;
        }
    }

实际测试将涉及基于 DataFrame 列的数据类型的相等性测试(对浮点数进行精度容差测试等)。

似乎没有一种简单的方法可以使用 Scala 对 DataFrame 中的所有元素进行迭代循环,而其他用于检查两个 DataFrame 相等性的解决方案(例如df1.except(df2))在我的情况下不起作用,因为我需要能够为测试相等性和浮点数和双精度数提供支持。

当然,我可以尝试事先对所有元素进行四舍五入,然后比较结果,但我想看看是否有任何其他解决方案可以让我遍历 DataFrame 以检查是否相等。

【问题讨论】:

  • 你的数据框有多大?如果它们不是那么大,您可以对它们进行排序/收集,然后轻松进行比较。
  • 因为这些是单元测试数据帧,所以应该很小。只需将它们收集到一个列表中并进行比较。
  • 是的,我的测试目前将数据帧收集到一个列表中并进行比较,但我希望创建也可以测试更大数据帧的工具。我猜没有简单的方法可以做到这一点?
  • *** 3 年 4 个月前询问过 5 个月前有效 已查看 7k 次 --- YET 仍然没有接受答案 ...

标签: scala unit-testing apache-spark spark-dataframe


【解决方案1】:
import org.scalatest.{BeforeAndAfterAll, FeatureSpec, Matchers}

outDf.collect() should contain theSameElementsAs (dfComparable.collect())
# or ( obs order matters ! )

// outDf.except(dfComparable).toDF().count should be(0)
outDf.except(dfComparable).count should be(0)   

【讨论】:

  • except 函数已经返回了一个数据框,所以不需要toDF
  • outDf.except(dfComparable).count should be(0) 不是一个好选择,因为.except 返回一个表格,其中包含左侧不在右侧的元素。如果左侧缺少某些元素,则测试不会失败。 assertSmallDataFrameEquality是一个更好的选择见stackoverflow.com/questions/31197353/…
  • assertDataFrameEquals 来自 spark-testing-base 也是一种替代方案。
【解决方案2】:

如果要检查两个数据框是否相等以进行测试,可以使用subtract()数据框方法(1.3及以上版本支持)

您可以检查两个数据帧的 diff 是否为空或 0。 例如df1.subtract(df2).count() == 0

【讨论】:

  • 感谢您的建议,但我在问题中提到的df1.except(df2) 具有与df1.subtract(df2) 相同的功能,并且在这种情况下实际上不起作用,我希望将这些值与精度公差。
【解决方案3】:

假设您有固定的列数和行数,一种解决方案可以是按行索引连接两个 Df(如果您没有记录的 id),然后直接在最终的 DF 中迭代 [所有两个 DF 的列]。 像这样的:

Schemas
DF1
root
 |-- col1: double (nullable = true)
 |-- col2: double (nullable = true)
 |-- col3: double (nullable = true)

DF2
root
 |-- col1: double (nullable = true)
 |-- col2: double (nullable = true)
 |-- col3: double (nullable = true)

df1
+----------+-----------+------+
|      col1|       col2|  col3|
+----------+-----------+------+
|1.20000001|       1.21|   1.2|
|    2.1111|        2.3|  22.2|
|       3.2|2.330000001| 2.333|
|    2.2444|      2.344|2.3331|
+----------+-----------+------+

df2
+------+-----+------+
|  col1| col2|  col3|
+------+-----+------+
|   1.2| 1.21|   1.2|
|2.1111|  2.3|  22.2|
|   3.2| 2.33| 2.333|
|2.2444|2.344|2.3331|
+------+-----+------+

Added row index
df1
+----------+-----------+------+---+
|      col1|       col2|  col3|row|
+----------+-----------+------+---+
|1.20000001|       1.21|   1.2|  0|
|    2.1111|        2.3|  22.2|  1|
|       3.2|2.330000001| 2.333|  2|
|    2.2444|      2.344|2.3331|  3|
+----------+-----------+------+---+

df2
+------+-----+------+---+
|  col1| col2|  col3|row|
+------+-----+------+---+
|   1.2| 1.21|   1.2|  0|
|2.1111|  2.3|  22.2|  1|
|   3.2| 2.33| 2.333|  2|
|2.2444|2.344|2.3331|  3|
+------+-----+------+---+

Combined DF
+---+----------+-----------+------+------+-----+------+
|row|      col1|       col2|  col3|  col1| col2|  col3|
+---+----------+-----------+------+------+-----+------+
|  0|1.20000001|       1.21|   1.2|   1.2| 1.21|   1.2|
|  1|    2.1111|        2.3|  22.2|2.1111|  2.3|  22.2|
|  2|       3.2|2.330000001| 2.333|   3.2| 2.33| 2.333|
|  3|    2.2444|      2.344|2.3331|2.2444|2.344|2.3331|
+---+----------+-----------+------+------+-----+------+

你可以这样做:

println("Schemas")
    println("DF1")
    df1.printSchema()
    println("DF2")
    df2.printSchema()
    println("df1")
    df1.show
    println("df2")
    df2.show
    val finaldf1 = df1.withColumn("row", monotonically_increasing_id())
    val finaldf2 = df2.withColumn("row", monotonically_increasing_id())
    println("Added row index")
    println("df1")
    finaldf1.show()
    println("df2")
    finaldf2.show()

    val joinedDfs = finaldf1.join(finaldf2, "row")
    println("Combined DF")
    joinedDfs.show()

    val tolerance = 0.001
    def isInValidRange(a: Double, b: Double): Boolean ={
      Math.abs(a-b)<=tolerance
    }
    joinedDfs.take(10).foreach(row => {
      assert( isInValidRange(row.getDouble(1), row.getDouble(4)) , "Col1 validation. Row %s".format(row.getLong(0)+1))
      assert( isInValidRange(row.getDouble(2), row.getDouble(5)) , "Col2 validation. Row %s".format(row.getLong(0)+1))
      assert( isInValidRange(row.getDouble(3), row.getDouble(6)) , "Col3 validation. Row %s".format(row.getLong(0)+1))
    })

注意:Assert 没有序列化,解决方法是使用 take() 来避免错误。

【讨论】:

    猜你喜欢
    • 1970-01-01
    • 2015-11-23
    • 2015-09-20
    • 2013-02-25
    • 1970-01-01
    • 2020-02-13
    • 2015-07-19
    • 1970-01-01
    • 2015-09-28
    相关资源
    最近更新 更多