【问题标题】:In databricks, using unittest.mock.patch on function in a different notebook在 databricks 中,在不同的笔记本中使用 unittest.mock.patch 函数
【发布时间】:2021-07-07 16:39:52
【问题描述】:

在数据块上,我有一个代码笔记本和一个单元测试笔记本。

使用“%run”命令将代码“导入”到单元测试笔记本中。

如何从单元测试笔记本中制作代码笔记本中某个函数的模拟对象?我通常会为此使用补丁上下文管理器。

这里是要修补的函数(get_name)的代码笔记本:

# Databricks notebook source
def get_name_func():
  return 'name1'

单元测试代码如下:

# Databricks notebook source:

from unittest.mock import patch
import inspect


# COMMAND ----------


# MAGIC %run ./get_name


# COMMAND ----------


def local_get_name():
  return 'name_local'


# COMMAND ----------


get_name_func()


# COMMAND ----------


print(inspect.getmodule(get_name_func))
print(inspect.getsourcefile(get_name_func))


# COMMAND ----------


inspect.unwrap(get_name_func)


# COMMAND ----------


with patch('get_name_func') as mock_func:
  print(mock_func)


# COMMAND ----------
with patch('local_get_name') as mock_func:
  print(mock_func)

对于本地函数和代码笔记本中的函数,两次修补尝试都给出相同的错误:

TypeError: Need a valid target to patch. You supplied: 'get_name_func'

检查命令返回:

<module '__main__' from '/local_disk0/tmp/1625490167313-0/PythonShell.py'>
<command-6807918>

Out[38]: <function __main__.get_name_func()>

我尝试了模块路径的各种组合,但没有成功。

奇怪的是,__name__ 返回'__main__'。但是在补丁调用中使用路径'__main__.get_name_func'是行不通的。

我的信念是,如果该对象存在于笔记本中(它确实存在),那么它一定是可修补的。

有什么建议吗?

【问题讨论】:

    标签: python-3.x unit-testing databricks


    【解决方案1】:

    我必须自己制作补丁功能:

    class FunctionPatch():
      '''
      This class is a context manager that allows patching of functions "imported" from another notebook using %run.
      
      The patch function must be at global scope (i.e. top level)
      '''
      def __init__(self, real_func_name: str, patch_func: Callable):
        self._real_func_name = real_func_name
        self._patch_func = patch_func
        self._backup_real_func = None
        
      def __enter__(self):
        self._backup_real_func = globals()[self._real_func_name]
        globals()[self._real_func_name] = self._patch_func
        
      def __exit__(self, exc_type, exc_value, tb):
        if exc_type is not None:
          traceback.print_exception(exc_type, exc_value, tb)
          
        globals()[self._real_func_name] = self._backup_real_func
    

    用法:

    def test_function_patch_real_func():
      return 'real1'
    
    def test_function_patch():
      
      assert test_function_patch_real_func() == 'real1'
      
      def mock_func():
        return 'mock1'
      
      with FunctionPatch('test_function_patch_real_func', mock_func):
        assert test_function_patch_real_func() == 'mock1' 
        
      assert test_function_patch_real_func() == 'real1'
    

    【讨论】:

      猜你喜欢
      • 1970-01-01
      • 2020-08-19
      • 2020-12-23
      • 1970-01-01
      • 2019-04-18
      • 2021-07-08
      • 1970-01-01
      • 2022-11-07
      • 1970-01-01
      相关资源
      最近更新 更多