【问题标题】:Python 3.x: how to use ast to search for a print statementPython 3.x:如何使用 ast 搜索打印语句
【发布时间】:2019-09-05 13:27:38
【问题描述】:

我正在创建一个测试,它应该检查函数是否包含print 语句(Python 3.x,我使用的是 3.7.4)。我一直在使用ast 来检查类似的事情(参考this 问题中的答案),例如return 或列表推导,但我被困在print 上。

online AST explorer 在正文中列出了 Print 子类,它采用 Python 3 prints,所以我知道这不是 Python 2 的东西。

Green Tree Snakes ast 文档说 Print 在 Python 2 中只有一个 ast 节点。这更接近我所经历的。这是我将用来进行断言的函数:

def printsSomething(func):
    return any(isinstance(node, ast.Print) for node in ast.walk(ast.parse(inspect.getsource(func))))

返回:

TypeError: isinstance() arg 2 must be a type or tuple of types

我假设这与print 是 Python 3.x 中的一个函数有关,但我不知道如何利用这些知识来发挥我的优势。我如何使用ast 来查明print 是否已被调用?

我想重申,我已经让这段代码适用于其他 ast 节点,例如 return,所以我应该确信这不是我的代码特有的错误。

谢谢!

【问题讨论】:

    标签: python python-3.x module abstract-syntax-tree


    【解决方案1】:

    如果调用了任何命名对象(包括print 等函数),那么您的nodes 中将至少有一个_ast.Name 对象。对象的名称('print')存储在该节点的id 属性下。

    我相信您知道,print 在 python 版本 2 和 3 之间从语句变为函数,这可能解释了您遇到问题的原因。

    尝试以下方法:

    import ast
    import inspect
    
    def do_print():
        print('hello')
    
    def dont_print():
        pass
    
    def prints_something(func):
        is_print = False
        for node in ast.walk(ast.parse(inspect.getsource(func))):
            try:
                is_print = (node.id == 'print')
            except AttributeError:  # only expect id to exist for Name objs
                pass
            if is_print:
                break
        return is_print
    
    prints_something(do_print), prints_something(dont_print)
    
    >>> True, False
    

    ...或者,如果您是单线的粉丝(func 是您要测试的函数):

    any(hasattr(node,'id') and node.id == 'print' 
        for node in ast.walk(ast.parse(inspect.getsource(func))))
    

    【讨论】:

    • 我真的希望有一个单行,并且花了最后一个小时左右尝试重构,但我无法解决对 try/except 的需求。我真的很喜欢这个答案,因为它适用于问题中的风格。谢谢!
    • 如果有帮助,我已经添加了一个单行字!
    【解决方案2】:

    print 是 python 3 中的一个函数,因此您需要检查一个 ast.Expr,其中包含一个 ast.Call,其中 ast.Name 的 ID 为 print

    这是一个简单的函数:

    def bar(x: str) -> None:
        string = f"Hello {x}!"  # ast.Assign
        print(string)           # ast.Expr
    

    这是完整的 ast 转储:

    Module(body=[FunctionDef(name='bar', args=arguments(args=[arg(arg='x', annotation=Name(id='str', ctx=Load()))], vararg=None, kwonlyargs=[], kw_defaults=[], kwarg=None, defaults=[]), body=[Assign(targets=[Name(id='string', ctx=Store())], value=JoinedStr(values=[Str(s='Hello '), FormattedValue(value=Name(id='x', ctx=Load()), conversion=-1, format_spec=None), Str(s='!')])), Expr(value=Call(func=Name(id='print', ctx=Load()), args=[Name(id='string', ctx=Load())], keywords=[]))], decorator_list=[], returns=NameConstant(value=None))])
    

    相关部分(print)是:

    Expr(value=Call(func=Name(id='print', ctx=Load())
    

    下面是一个带有节点访问者的简单示例(sublcasing ast.NodeVisitor):

    #!/usr/bin/env python3
    # -*- coding: utf-8 -*-
    import ast
    import inspect
    from typing import Callable
    
    
    class MyNodeVisitor(ast.NodeVisitor):
        def visit_Expr(self, node: ast.Expr):
            """Called when the visitor visits an ast.Expr"""
            print(f"Found expression node at: line: {node.lineno}; col: {node.col_offset}")
    
            # check "value" which must be an instance of "Call" for a 'print'
            if not isinstance(node.value, ast.Call):
                return
    
            # now check the function itself.
            func = node.value.func  # ast.Name
            if func.id == "print":
                print("found a print")
    
    
    def contains_print(f: Callable):
        source = inspect.getsource(f)
        node = ast.parse(source)
        func_name = [_def.name for _def in node.body if isinstance(_def, ast.FunctionDef)][0]
        print(f"{'-' * 79}\nvisiting function: {func_name}")
        print(f"node dump: {ast.dump(node)}")
        node_visitor = MyNodeVisitor()
        node_visitor.visit(node)
    
    
    def foo(x: int) -> int:
        return x + 1
    
    def bar(x: str) -> None:
        string = f"Hello {x}!"  # ast.Assign
        print(string)           # ast.Expr
    
    def baz(x: float) -> float:
        if x == 0.0:
            print("oh noes!")
            raise ValueError
    
        return 10 / x
    
    
    if __name__ == "__main__":
        contains_print(bar)
        contains_print(foo)
        contains_print(baz)
    

    这里是输出(减去 ast 转储):

    -------------------------------------------------------------------------------
    visiting function: bar
    Found expression node at: line: 3; col: 4
    found a print
    -------------------------------------------------------------------------------
    visiting function: foo
    -------------------------------------------------------------------------------
    visiting function: baz
    Found expression node at: line: 3; col: 8
    found a print
    

    【讨论】:

      猜你喜欢
      • 1970-01-01
      • 1970-01-01
      • 1970-01-01
      • 1970-01-01
      • 1970-01-01
      • 2017-06-16
      • 1970-01-01
      • 1970-01-01
      • 1970-01-01
      相关资源
      最近更新 更多