【Question title】: Python unittest mock pyspark chain
【Posted】: 2021-12-28 21:09:01
【Question description】:

I want to write some unit tests for a simple method that contains pyspark code.

def do_stuff(self, df1: DataFrame, df2_path: str, df1_key: str, df2_key: str) -> DataFrame:
    df2 = self.spark.read.format('parquet').load(df2_path)
    return df1.join(df2, [f.col(df1_key) == f.col(df2_key)], 'left')

How can I mock the spark read part? I tried this:

@patch("class_to_test.SparkSession")
def test_do_stuff(self, mock_spark: MagicMock) -> None:
    spark = MagicMock()
    spark.read.return_value.format.return_value.load.return_value = \
        self.spark.createDataFrame([(1, 2)], ["key2", "c2"])
    mock_spark.return_value = spark

    input_df = self.spark.createDataFrame([(1, 1)], ["key1", "c1"])
    actual_df = ClassToTest().do_stuff(input_df, "df2", "key1", "key2")
    expected_df = self.spark.createDataFrame([(1, 1, 1, 2)], ["key1", "c1", "key2", "c2"])
    assert_pyspark_df_equal(actual_df, expected_df)

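But it fails with this error:
py4j.Py4JException: Method join([class java.util.ArrayList, class org.apache.spark.sql.Column, class java.lang.String]) does not exist
It looks like the mock is not working the way I expected. What should I do so that spark.read.load returns the test DataFrame I specified?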

Edit: full code here

【Question discussion】:

    Tags: python pyspark mocking python-unittest python-unittest.mock


    【Solution 1】:

    You can do this with PropertyMock. Here is an example:

    # test.py
    import unittest
    from unittest.mock import patch, PropertyMock, Mock
    
    from pyspark.sql import SparkSession, DataFrame, functions as f
    from pyspark_test import assert_pyspark_df_equal
    
    
    class ClassToTest:
        def __init__(self) -> None:
            self._spark = SparkSession.builder.getOrCreate()
    
        @property
        def spark(self):
            return self._spark
    
        def do_stuff(self, df1: DataFrame, df2_path: str, df1_key: str, df2_key: str) -> DataFrame:
            df2 = self.spark.read.format('parquet').load(df2_path)
            return df1.join(df2, [f.col(df1_key) == f.col(df2_key)], 'left')
    
    
    class TestClassToTest(unittest.TestCase):
        def setUp(self) -> None:
            self.spark = SparkSession.builder.getOrCreate()
    
        def test_do_stuff(self) -> None:
            # let's say ClassToTest().spark.read.format().load() will return a DataFrame
            with patch(
                # change __main__ to your module...
                '__main__.ClassToTest.spark',
                new_callable=PropertyMock,
                return_value=Mock(
                    # read property
                    read=Mock(
                        # format() method
                        format=Mock(
                            return_value=Mock(
                                # load() method result:
                                load=Mock(return_value=self.spark.createDataFrame([(1, 2)], ['key2', 'c2']))))))
            ):
                input_df = self.spark.createDataFrame([(1, 1)], ['key1', 'c1'])
                df = ClassToTest().do_stuff(input_df, 'df2_path', 'key1', 'key2')
                assert_pyspark_df_equal(
                    df,
                    self.spark.createDataFrame([(1, 1, 1, 2)], ['key1', 'c1', 'key2', 'c2'])
                )
    
    
    if __name__ == '__main__':
        unittest.main()
    

    Let's check it:

    python test.py
    # result:
    ----------------------------------------------------------------------
    Ran 1 test in 7.460s
    
    OK
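
One detail in the question's attempt is worth spelling out: `do_stuff` reads `spark.read` as an attribute and only calls `format()` and `load()`, so a `return_value` configured on `read` is never consulted. The sketch below shows the difference with plain `unittest.mock` and no Spark session. `Reader` and `FAKE_DF` are hypothetical stand-ins (they do not appear in the question), and this may not be the only problem in the original test, since the full code is not shown here.

```python
from unittest.mock import MagicMock, PropertyMock, patch

# Hypothetical stand-in for a real DataFrame so the sketch runs without pyspark.
FAKE_DF = object()


class Reader:
    """Hypothetical class mirroring ClassToTest: `spark` is a read-only property."""

    @property
    def spark(self):
        raise RuntimeError("a real SparkSession would be created here")

    def load(self, path):
        # Same chain as do_stuff: `read` is an attribute, `format` and `load` are calls.
        return self.spark.read.format("parquet").load(path)


# Chain from the question: `read.return_value` configures the call `spark.read()`,
# which the code under test never makes, so load() yields a fresh MagicMock.
wrong = MagicMock()
wrong.read.return_value.format.return_value.load.return_value = FAKE_DF
wrong_result = wrong.read.format("parquet").load("df2")

# Chain matching the code under test: no `return_value` after `read`.
right = MagicMock()
right.read.format.return_value.load.return_value = FAKE_DF
right_result = right.read.format("parquet").load("df2")

# Patch the property on the class, as the answer does, so that
# `self.spark` returns the configured mock inside the `with` block.
with patch.object(Reader, "spark", new_callable=PropertyMock, return_value=right):
    patched_result = Reader().load("df2")
```

With the first chain, `load()` hands back an auto-created `MagicMock`, and passing such an object to `DataFrame.join` would be consistent with the py4j error in the question. With the second chain, the stand-in object comes back unchanged.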
    

    【Discussion】:
