【Question title】: OperatorNotAllowedInGraphError: using a `tf.Tensor` as a Python `bool` is not allowed: AutoGraph did convert this function
【Posted】: 2020-12-29 07:16:30
【Question】:

I'm trying to filter a `tf.data.Dataset` by index:

    dataset = tf.data.Dataset.from_tensor_slices((sequences_matrix, label_data.astype(np.int8)))
    dataset = dataset.cache()
    dataset = dataset.enumerate()

    @tf.function
    def filter_function(i, data):
        return i in train_index # train_index is a list of integers

    train_dataset = dataset.filter(filter_function)

But I get the following error:

Traceback (most recent call last):
  File "/home/marzi/workspace/nlp_classification/src/main.py", line 355, in <module>
    if __name__ == '__main__': main()
  File "/home/marzi/workspace/nlp_classification/src/main.py", line 320, in main
    deep_learning_algo(THE_DATA, HYPER_DICT)
  File "/home/marzi/workspace/nlp_classification/src/main.py", line 226, in deep_learning_algo
    tokenizer_name=tokenizer_name
  File "/home/marzi/workspace/nlp_classification/src/train.py", line 118, in fit_normal
    train_dataset = dataset.filter(filter_function)
  File "/home/marzi/anaconda3/envs/nlp_classification/lib/python3.7/site-packages/tensorflow/python/data/ops/dataset_ops.py", line 1862, in filter
    return FilterDataset(self, predicate)
  File "/home/marzi/anaconda3/envs/nlp_classification/lib/python3.7/site-packages/tensorflow/python/data/ops/dataset_ops.py", line 4264, in __init__
    use_legacy_function=use_legacy_function)
  File "/home/marzi/anaconda3/envs/nlp_classification/lib/python3.7/site-packages/tensorflow/python/data/ops/dataset_ops.py", line 3371, in __init__
    self._function = wrapper_fn.get_concrete_function()
  File "/home/marzi/anaconda3/envs/nlp_classification/lib/python3.7/site-packages/tensorflow/python/eager/function.py", line 2939, in get_concrete_function
    *args, **kwargs)
  File "/home/marzi/anaconda3/envs/nlp_classification/lib/python3.7/site-packages/tensorflow/python/eager/function.py", line 2906, in _get_concrete_function_garbage_collected
    graph_function, args, kwargs = self._maybe_define_function(args, kwargs)
  File "/home/marzi/anaconda3/envs/nlp_classification/lib/python3.7/site-packages/tensorflow/python/eager/function.py", line 3213, in _maybe_define_function
    graph_function = self._create_graph_function(args, kwargs)
  File "/home/marzi/anaconda3/envs/nlp_classification/lib/python3.7/site-packages/tensorflow/python/eager/function.py", line 3075, in _create_graph_function
    capture_by_value=self._capture_by_value),
  File "/home/marzi/anaconda3/envs/nlp_classification/lib/python3.7/site-packages/tensorflow/python/framework/func_graph.py", line 986, in func_graph_from_py_func
    func_outputs = python_func(*func_args, **func_kwargs)
  File "/home/marzi/anaconda3/envs/nlp_classification/lib/python3.7/site-packages/tensorflow/python/data/ops/dataset_ops.py", line 3364, in wrapper_fn
    ret = _wrapper_helper(*args)
  File "/home/marzi/anaconda3/envs/nlp_classification/lib/python3.7/site-packages/tensorflow/python/data/ops/dataset_ops.py", line 3299, in _wrapper_helper
    ret = autograph.tf_convert(func, ag_ctx)(*nested_args)
  File "/home/marzi/anaconda3/envs/nlp_classification/lib/python3.7/site-packages/tensorflow/python/eager/def_function.py", line 780, in __call__
    result = self._call(*args, **kwds)
  File "/home/marzi/anaconda3/envs/nlp_classification/lib/python3.7/site-packages/tensorflow/python/eager/def_function.py", line 823, in _call
    self._initialize(args, kwds, add_initializers_to=initializers)
  File "/home/marzi/anaconda3/envs/nlp_classification/lib/python3.7/site-packages/tensorflow/python/eager/def_function.py", line 697, in _initialize
    *args, **kwds))
  File "/home/marzi/anaconda3/envs/nlp_classification/lib/python3.7/site-packages/tensorflow/python/eager/function.py", line 2855, in _get_concrete_function_internal_garbage_collected
    graph_function, _, _ = self._maybe_define_function(args, kwargs)
  File "/home/marzi/anaconda3/envs/nlp_classification/lib/python3.7/site-packages/tensorflow/python/eager/function.py", line 3213, in _maybe_define_function
    graph_function = self._create_graph_function(args, kwargs)
  File "/home/marzi/anaconda3/envs/nlp_classification/lib/python3.7/site-packages/tensorflow/python/eager/function.py", line 3075, in _create_graph_function
    capture_by_value=self._capture_by_value),
  File "/home/marzi/anaconda3/envs/nlp_classification/lib/python3.7/site-packages/tensorflow/python/framework/func_graph.py", line 986, in func_graph_from_py_func
    func_outputs = python_func(*func_args, **func_kwargs)
  File "/home/marzi/anaconda3/envs/nlp_classification/lib/python3.7/site-packages/tensorflow/python/eager/def_function.py", line 600, in wrapped_fn
    return weak_wrapped_fn().__wrapped__(*args, **kwds)
  File "/home/marzi/anaconda3/envs/nlp_classification/lib/python3.7/site-packages/tensorflow/python/framework/func_graph.py", line 973, in wrapper
    raise e.ag_error_metadata.to_exception(e)
tensorflow.python.framework.errors_impl.OperatorNotAllowedInGraphError: in user code:

    /home/marzi/workspace/nlp_classification/src/train.py:116 filter_function  *
        return i in train_index
    /home/marzi/anaconda3/envs/nlp_classification/lib/python3.7/site-packages/tensorflow/python/framework/ops.py:877 __bool__
        self._disallow_bool_casting()
    /home/marzi/anaconda3/envs/nlp_classification/lib/python3.7/site-packages/tensorflow/python/framework/ops.py:487 _disallow_bool_casting
        "using a `tf.Tensor` as a Python `bool`")
    /home/marzi/anaconda3/envs/nlp_classification/lib/python3.7/site-packages/tensorflow/python/framework/ops.py:474 _disallow_when_autograph_enabled
        " indicate you are trying to use an unsupported feature.".format(task))

    OperatorNotAllowedInGraphError: using a `tf.Tensor` as a Python `bool` is not allowed: AutoGraph did convert this function. This might indicate you are trying to use an unsupported feature.

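However, if I change the condition in the filter function from `i in train_index` to `i > 10`, it works fine. What is the difference between these two conditions that makes one raise an error and the other not?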

【Question comments】:

    Tags: python tensorflow tensorflow-datasets


    【Solution 1】:

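    `@tf.function` (and `Dataset.filter` itself, which always traces its predicate into a graph) runs your function in graph mode, where `i` is a symbolic tensor. `i in train_index` is a Python list membership test: it compares `i` with each list element using `==` and converts the result of each comparison to a Python `bool`, which is not allowed for a symbolic tensor (note the `__bool__` frame in the traceback). `i > 10`, by contrast, just returns a boolean tensor that `filter` can consume directly, so no `bool` conversion ever happens. You can use `tf.map_fn` or `tf.py_function` instead: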

    @tf.function
    def filter_function(i, data):
        return tf.py_function(lambda x: x in train_index, inp=[i], Tout=tf.bool)
    

    For example:

    import tensorflow as tf
    
    train_index = [i for i in range(25) if i > 10]
    
    dataset = tf.data.Dataset.from_tensor_slices(list(range(25)))
    dataset = dataset.cache()
    dataset = dataset.enumerate()
    
    
    @tf.function
    def filter_function(i, data):
        return tf.py_function(lambda x: x in train_index, inp=[i], Tout=tf.bool)
    
    
    train_dataset = dataset.filter(filter_function)
    
    for i in train_dataset:
        print(i[0].numpy())
    
    11
    12
    13
    14
    15
    16
    17
    18
    19
    20
    21
    22
    23
    24
    
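To make the `in` versus `>` difference concrete without TensorFlow, here is a pure-Python analogy. `SymbolicValue` is a hypothetical stand-in for a graph-mode tensor (it is not a TensorFlow class): comparisons return another symbolic object instead of `True`/`False`, and converting it to a Python `bool` raises, roughly mimicking `OperatorNotAllowedInGraphError`.

```python
class SymbolicValue:
    """Hypothetical stand-in for a graph-mode tf.Tensor (not a TensorFlow class)."""

    def __eq__(self, other):
        return SymbolicValue()  # like tf.equal: builds a new symbolic value

    def __gt__(self, other):
        return SymbolicValue()  # like tf.greater: builds a new symbolic value

    def __bool__(self):
        # Graph tensors have no concrete value yet, so bool() must fail.
        raise TypeError("using a symbolic value as a Python `bool` is not allowed")


def membership_raises(value, items):
    """Return True if `value in items` fails because bool() was called."""
    try:
        _ = value in items  # list.__contains__ calls bool() on each == result
    except TypeError:
        return True
    return False


i = SymbolicValue()
train_index = [11, 12, 13]

gt_result = i > 10  # returned as-is: no bool() conversion is needed
print(type(gt_result).__name__)           # SymbolicValue
print(membership_raises(i, train_index))  # True: `in` forces bool()
print(membership_raises(12, train_index)) # False: plain ints compare fine
```

This is why `return i > 10` works inside `filter` (the predicate returns a tensor), while `return i in train_index` fails (Python itself needs a concrete `True`/`False` for every comparison).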

    Further reading: Better performance with tf.function
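If you would rather keep the predicate entirely in graph mode (the `tf.py_function` body is plain Python, so it is not serialized with the graph), a minimal sketch is to compare the index against a tensor of allowed indices. This assumes `train_index` fits comfortably in memory as a constant; `train_index_t` is a hypothetical helper name, and the toy data mirrors the example above.

```python
import tensorflow as tf

train_index = [i for i in range(25) if i > 10]

# Dataset.enumerate() yields int64 indices, so store the allowed ones as int64.
train_index_t = tf.constant(train_index, dtype=tf.int64)

dataset = tf.data.Dataset.from_tensor_slices(list(range(25)))
dataset = dataset.enumerate()


def filter_function(i, data):
    # tf.equal broadcasts the scalar index against the vector of allowed
    # indices; tf.reduce_any collapses that to the scalar tf.bool that
    # Dataset.filter expects. No Python bool conversion is involved.
    return tf.reduce_any(tf.equal(i, train_index_t))


train_dataset = dataset.filter(filter_function)

kept = [int(i.numpy()) for i, _ in train_dataset]
print(kept)
```

Each element costs one comparison per entry in `train_index`; for a very large index list, a hash-table lookup (for example with `tf.lookup.StaticHashTable`) may scale better, though that variant is not shown here.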

    【Comments】:
