【问题标题】:Detecting context manager nesting检测上下文管理器嵌套
【发布时间】:2017-06-28 14:00:18
【问题描述】:

我最近一直想知道是否有办法检测上下文管理器是否嵌套。

我创建了 Timer 和 TimerGroup 类:

class Timer:
    def __init__(self, name="Timer"):
        self.name = name
        self.start_time = clock()

    @staticmethod
    def seconds_to_str(t):
        return str(timedelta(seconds=t))

    def end(self):
        return clock() - self.start_time

    def print(self, t):
        print(("{0:<" + str(line_width - 18) + "} >> {1}").format(self.name, self.seconds_to_str(t)))

    def __enter__(self):
        return self

    def __exit__(self, exc_type, value, traceback):
        self.print(self.end())


class TimerGroup(Timer):
    def __enter__(self):
        print(('= ' + self.name + ' ').ljust(line_width, '='))
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        total_time = self.seconds_to_str(self.end())
        print(" Total: {0}".format(total_time).rjust(line_width, '='))
        print()

此代码以可读格式打印计时:

with TimerGroup("Collecting child documents for %s context" % context_name):
    with Timer("Collecting context features"):
        # some code...
    with Timer("Collecting child documents"):
        # some code...


= Collecting child documents for Global context ============
Collecting context features                >> 0:00:00.001063
Collecting child documents                 >> 0:00:10.611130
====================================== Total: 0:00:10.612292

但是,当我嵌套 TimerGroups 时,事情就搞砸了:

with TimerGroup("Choosing the best classifier for %s context" % context_name):
    with Timer("Splitting datasets"):
        # some code...
    for cname, cparams in classifiers.items():
        with TimerGroup("%s classifier" % cname):
            with Timer("Training"):
                # some code...
            with Timer("Calculating accuracy on testing set"):
                # some code


= Choosing the best classifier for Global context ==========
Splitting datasets                         >> 0:00:00.002054
= Naive Bayes classifier ===================================
Training                                   >> 0:00:34.184903
Calculating accuracy on testing set        >> 0:05:08.481904
====================================== Total: 0:05:42.666949

====================================== Total: 0:05:42.669078

我需要做的就是以某种方式缩进嵌套的 Timers 和 TimerGroups。我应该将任何参数传递给他们的构造函数吗?或者我可以从课堂上检测到吗?

【问题讨论】:

  • 您可以更改TimerGroup 以接受另一个计时器组作为“父计时器组”,并在每个TimeGroup 实例中存储缩进(如果没有父则为0,如果没有父则为parent.indentation + 1有一个)

标签: python nested with-statement contextmanager


【解决方案1】:

没有检测嵌套上下文管理器的特殊工具,没有。你必须自己处理这件事。您可以在自己的上下文管理器中执行此操作:

import threading


class TimerGroup(Timer):
    _active_group = threading.local()

    def __enter__(self):
        if getattr(TimerGroup._active_group, 'current', False):
            raise RuntimeError("Can't nest TimerGroup context managers")
        TimerGroup._active_group.current = self
        print(('= ' + self.name + ' ').ljust(line_width, '='))
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        TimerGroup._active_group.current = None
        total_time = self.seconds_to_str(self.end())
        print(" Total: {0}".format(total_time).rjust(line_width, '='))
        print()

然后您可以在别处使用TimerGroup._active_group 属性来获取当前活动的组。我使用了thread-local object 来确保它可以跨多个执行线程使用。

或者,您可以将其设为堆栈计数器,并在嵌套的 __enter__ 调用中递增和递减,或者堆栈 list 并将 self 推送到该堆栈上,当您 @ 时再次弹出它987654326@:

import threading


class TimerGroup(Timer):
    _active_group = threading.local()

    def __enter__(self):
        if not hasattr(TimerGroup._active_group, 'current'):
            TimerGroup._active_group.current = []
        stack = TimerGroup._active_group.current
        if stack:
            # nested context manager.
            # do something with stack[-1] or stack[0]
        TimerGroup._active_group.current.append(self)

        print(('= ' + self.name + ' ').ljust(line_width, '='))
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        last = TimerGroup._active_group.current.pop()
        assert last == self, "Context managers being exited out of order"
        total_time = self.seconds_to_str(self.end())
        print(" Total: {0}".format(total_time).rjust(line_width, '='))
        print()

【讨论】:

  • 自然也可以让组嵌套,让_active_group变成list,也就是堆栈。
【解决方案2】:

如果您需要做的只是根据您正在执行的嵌套上下文管理器的数量调整缩进级别,那么有一个名为 indent_level 的类属性,并在每次进入和退出上下文管理器时调整它。类似于以下内容:

class Context:
    indent_level = 0

    def __init__(self, name):
        self.name = name

    def __enter__(self):
        print(' '*4*self.indent_level + 'Entering ' + self.name)
        self.adjust_indent_level(1)
        return self

    def __exit__(self, *a, **k):
        self.adjust_indent_level(-1)
        print(' '*4*self.indent_level + 'Exiting ' + self.name)

    @classmethod
    def adjust_indent_level(cls, val):
        cls.indent_level += val

并将其用作:

>>> with Context('Outer') as outer_context:
        with Context('Inner') as inner_context:
            print(' '*inner_context.indent_level*4 + 'In the inner context')


Entering Outer
    Entering Inner
        In the inner context
    Exiting Inner
Exiting Outer

【讨论】:

  • 您的代码似乎只适用于单个计时器类。但是,我在这里有TimerTimerGroup(Timer),并且缩进计数器似乎是为每个独立创建的。有没有办法创建一个在层次结构中的所有类之间共享的字段?我的想法是在adjust_indent_level() 方法中修改Timer.indent_level 并仅在TimerGroup 的__enter____exit__ 方法中调用它。
  • 如果你总是单线程的,这将工作。对于多线程方法,请参阅 Martijn-Pieters 对使用 threading.local 的回答。
【解决方案3】:

import this:

显式优于隐式

更简洁的设计将明确允许指定组:

with TimerGroup('Doing big task') as big_task_tg:
    with Timer('Foo', big_task_tg):
      foo_result = foo()
    with Timer('Bar', big_task_tg):
      bar(baz(foo_result))

另一方面,您始终可以使用traceback.extract_stack 并查找上游特定函数的调用。它对于日志记录和错误报告非常有用,并且对于确保仅在特定上下文中调用特定函数也很有用。但它往往会创建很难跟踪的依赖关系。

我会避免将它用于分组计时器,但您可以尝试。如果您非常需要自动分组,@Martijn-Pieters 的方法要好得多。

【讨论】:

  • 您的想法对于基本用法来说非常好,但是我有时需要从不同的方法调用计时器(一个方法调用另一个方法,并使用需要嵌套的计时器)。所以我认为@Billy 的回答更适合我的问题。
猜你喜欢
  • 1970-01-01
  • 1970-01-01
  • 1970-01-01
  • 1970-01-01
  • 1970-01-01
  • 1970-01-01
  • 1970-01-01
  • 2016-03-03
  • 2019-06-18
相关资源
最近更新 更多