【Question title】: How to manually clear the tf.function caches (or manage the max size) in TensorFlow 2.0?
【Posted】: 2019-09-12 13:35:07
【Question】:

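The example below shows a simple way to clear the cache manually. Is there a more standard or stable way to manage the cache? Or perhaps a pattern that avoids the problem in the first place?

In some cases the batch size varies a lot, and we run into memory problems because my def_fun never goes out of scope and its cache apparently never gets cleared.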
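A minimal sketch of one pattern that avoids per-batch-size retracing in the first place, assuming a TF 2.x release that provides `experimental_get_tracing_count` (the answer below uses it too): give `tf.function` an `input_signature` whose batch dimension is `None`, so a single concrete function accepts every batch size. The batch sizes in the loop are arbitrary.

```python
import numpy as np
import tensorflow as tf

# Leaving the batch dimension as None lets one concrete function accept
# any batch size, so a new batch size does not add a new cache entry.
@tf.function(input_signature=[tf.TensorSpec(shape=[None, 3], dtype=tf.float32)])
def f(x):
    return dict(something=x ** 2)

for batch_size in (109, 111, 4096):
    _ = f(tf.convert_to_tensor(np.random.randn(batch_size, 3).astype(np.float32)))

# Number of times f has been traced so far.
print(f.experimental_get_tracing_count())
```

The trade-off is that the traced graph no longer knows the batch size statically, and inputs that do not match the signature (for example float64, or a second dimension other than 3) raise an error instead of triggering a new trace. Newer releases also accept `reduce_retracing=True` (formerly `experimental_relax_shapes=True`) as a looser alternative to a fixed signature.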

In [164]: @tf.function
     ...: def f(x):
     ...:     return dict(something=x ** 2)
     ...:
     ...:
     ...:

In [165]: f._list_all_concrete_functions_for_serialization()
Out[165]: []

In [166]: _ = f(tf.convert_to_tensor(np.random.randn(109, 3).astype(np.float32)))

In [167]: _ = f(tf.convert_to_tensor(np.random.randn(111, 3).astype(np.float32)))

In [168]: f._list_all_concrete_functions_for_serialization()
Out[168]:
[<tensorflow.python.eager.function.ConcreteFunction at 0x7fac73e0d358>,
 <tensorflow.python.eager.function.ConcreteFunction at 0x7fac71d41a58>]

In [169]: f._stateful_fn._function_cache._garbage_collectors
Out[169]:
[<tensorflow.python.eager.function._FunctionGarbageCollector at 0x7fac94252390>,
 <tensorflow.python.eager.function._FunctionGarbageCollector at 0x7fac7b0c6048>,
 <tensorflow.python.eager.function._FunctionGarbageCollector at 0x7fac7b0c6d68>]

In [170]: f._stateful_fn._function_cache._garbage_collectors[0]
Out[170]: <tensorflow.python.eager.function._FunctionGarbageCollector at 0x7fac94252390>

In [171]: f._stateful_fn._function_cache._garbage_collectors[0]._cache
Out[171]:
OrderedDict([(CacheKey(input_signature=('UTd1s109-3-u', None), parent_graph=None, device_functions=(), colocation_stack=(), in_cross_replica_context=False),
              <tensorflow.python.eager.function.ConcreteFunction at 0x7fac7371def0>),
             (CacheKey(input_signature=('UTd1s111-3-u', None), parent_graph=None, device_functions=(), colocation_stack=(), in_cross_replica_context=False),
              <tensorflow.python.eager.function.ConcreteFunction at 0x7fac77a514a8>),
             (CacheKey(input_signature=('URu', (TensorSpec(shape=(111, 3), dtype=tf.float32, name='x'),)), parent_graph=None, device_functions=(), colocation_stack=(), in_cross_replica_context=False),
              <tensorflow.python.eager.function.ConcreteFunction at 0x7fac73e0d358>),
             (CacheKey(input_signature=('URu', (TensorSpec(shape=(109, 3), dtype=tf.float32, name='x'),)), parent_graph=None, device_functions=(), colocation_stack=(), in_cross_replica_context=False),
              <tensorflow.python.eager.function.ConcreteFunction at 0x7fac71d41a58>)])

In [172]: f._stateful_fn._function_cache._garbage_collectors[0]._cache.popitem()
Out[172]:
(CacheKey(input_signature=('URu', (TensorSpec(shape=(109, 3), dtype=tf.float32, name='x'),)), parent_graph=None, device_functions=(), colocation_stack=(), in_cross_replica_context=False),
 <tensorflow.python.eager.function.ConcreteFunction at 0x7fac71d41a58>)

In [173]: f._stateful_fn._function_cache._garbage_collectors[0]._cache.popitem()
Out[173]:
(CacheKey(input_signature=('URu', (TensorSpec(shape=(111, 3), dtype=tf.float32, name='x'),)), parent_graph=None, device_functions=(), colocation_stack=(), in_cross_replica_context=False),
 <tensorflow.python.eager.function.ConcreteFunction at 0x7fac73e0d358>)

In [174]: f._stateful_fn._function_cache._garbage_collectors[0]._cache.popitem()
Out[174]:
(CacheKey(input_signature=('UTd1s111-3-u', None), parent_graph=None, device_functions=(), colocation_stack=(), in_cross_replica_context=False),
 <tensorflow.python.eager.function.ConcreteFunction at 0x7fac77a514a8>)

In [175]: f._stateful_fn._function_cache._garbage_collectors[0]._cache.popitem()
Out[175]:
(CacheKey(input_signature=('UTd1s109-3-u', None), parent_graph=None, device_functions=(), colocation_stack=(), in_cross_replica_context=False),
 <tensorflow.python.eager.function.ConcreteFunction at 0x7fac7371def0>)

In [176]: f._stateful_fn._function_cache._garbage_collectors[0]._cache.popitem()
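The session above pops entries out of private attributes (`_stateful_fn`, `_function_cache`, `_garbage_collectors`), which may change between releases. A sketch of "managing the max size" through public calls only: `BoundedFunction` below is a hypothetical helper, not a TensorFlow API, that throws away the whole `tf.function` wrapper (and with it the trace cache) once it has traced `max_traces` times. It assumes the default behaviour of tracing once per distinct input shape.

```python
import tensorflow as tf

class BoundedFunction:
    """Hypothetical helper (not part of TensorFlow): rebuilds the tf.function
    wrapper, discarding its trace cache, after max_traces traces."""

    def __init__(self, python_fn, max_traces=8, **tf_function_kwargs):
        self._python_fn = python_fn
        self._kwargs = tf_function_kwargs
        self._max_traces = max_traces
        self._fn = tf.function(python_fn, **tf_function_kwargs)
        self.resets = 0  # how many times the cache has been thrown away

    def tracing_count(self):
        return self._fn.experimental_get_tracing_count()

    def __call__(self, *args, **kwargs):
        if self.tracing_count() >= self._max_traces:
            # A fresh tf.function starts with an empty cache; the old wrapper and
            # its concrete functions can be garbage-collected once unreferenced.
            self._fn = tf.function(self._python_fn, **self._kwargs)
            self.resets += 1
        return self._fn(*args, **kwargs)


def square_dict(x):
    return dict(something=x ** 2)

g = BoundedFunction(square_dict, max_traces=2)
for batch_size in (109, 111, 113):
    g(tf.zeros([batch_size, 3]))
print(g.resets, g.tracing_count())
```

Resetting wholesale is cruder than an LRU policy, but it relies only on `tf.function` and `experimental_get_tracing_count`, so it should not break when the internal cache layout changes.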

【Comments】:

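  • Did you ever find an answer? I'm running a meta-learning loop where I have to periodically re-wrap my methods in tf.function, and I'm hitting memory leaks. I think it's related to the tf.function cache, but I can't find a good way to clear it.
  • I think I did find something, but I can't locate it now. I basically grabbed a tf.function-decorated function in the REPL, looked at underscore methods such as f._get_tracing_count, and then read the code. I think there was something in there you could pop.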

Tags: tensorflow tensorflow2.0


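【Solution 1】:

Make a copy of the tf.function object before calling it. As the output below shows, the copy (test2) keeps its own tracing count and traces again on its first call, so it does not reuse the trace cached by test1.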

import tensorflow as tf
import copy

@tf.function
def test1(a):
    print('trace')  # Python side effect: runs only while tracing
    return a * a

test2 = copy.copy(test1)  # copy taken before test1 is ever called
print(test1(1))
print(test1.experimental_get_tracing_count())

print(test1(1))
print(test1.experimental_get_tracing_count())

print(test2(1))
print(test2.experimental_get_tracing_count())

Result:

trace
tf.Tensor(1, shape=(), dtype=int32)
1
tf.Tensor(1, shape=(), dtype=int32)
1
trace
tf.Tensor(1, shape=(), dtype=int32)
1
