【Question title】: tf.io.GFile with Tensor String Input
【Posted】: 2023-01-07 18:07:36
【Question】:

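I want to retrieve a GCS object (or any S3 object) as part of the model. Its first layer would fetch the features based on the file name, since this reduces network overhead. I am trying to wrap the download in a tf.function, but without success. Here is the MWE: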

import tensorflow as tf
@tf.function
def load_file(a):
    if tf.is_tensor(a):
        a_path = tf.strings.substr(a, 0, 2) + "/" + a
    else:
        a_path = a[0:2] + "/" + a
    with tf.io.gfile.GFile("gs://some_bucket" + a_path) as f:
        return f.read()
load_file(tf.constant("file3"))

This raises the following error:

---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
Cell In [22], line 9
      7     with tf.io.gfile.GFile("gs://some_bucket" + a_path) as f:
      8         return f.read()
----> 9 load_file(tf.constant("file3"))

File /opt/conda/envs/wanna-hmic/lib/python3.9/site-packages/tensorflow/python/util/traceback_utils.py:153, in filter_traceback.<locals>.error_handler(*args, **kwargs)
    151 except Exception as e:
    152   filtered_tb = _process_traceback_frames(e.__traceback__)
--> 153   raise e.with_traceback(filtered_tb) from None
    154 finally:
    155   del filtered_tb

File /opt/conda/envs/wanna-hmic/lib/python3.9/site-packages/tensorflow/python/framework/func_graph.py:1147, in func_graph_from_py_func.<locals>.autograph_handler(*args, **kwargs)
   1145 except Exception as e:  # pylint:disable=broad-except
   1146   if hasattr(e, "ag_error_metadata"):
-> 1147     raise e.ag_error_metadata.to_exception(e)
   1148   else:
   1149     raise

TypeError: in user code:

File "/tmp/ipykernel_4006/3877294148.py", line 8, in load_file  *
    return f.read()

TypeError: __init__(): incompatible constructor arguments. The following argument types are supported:
    1. tensorflow.python.lib.io._pywrap_file_io.BufferedInputStream(filename: str, buffer_size: int, token: tensorflow.python.lib.io._pywrap_file_io.TransactionToken = None)

Invoked with: <tf.Tensor 'add_2:0' shape=() dtype=string>, 524288

The code works fine in eager mode, e.g. load_file("file3"), but for good performance I need it to work in graph mode as well.
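A likely explanation, sketched below under the assumption of TensorFlow 2.x with tf.config.run_functions_eagerly left off: while a tf.function is traced, a tf.constant argument arrives as a symbolic tf.Tensor, whereas a plain Python string is passed through as a Python str (and triggers its own trace). GFile needs a real Python str (the BufferedInputStream signature in the traceback lists `filename: str`), so load_file("file3") most likely works because a str reaches GFile, not because that call runs eagerly. The probe function inspect_arg and the traces list are hypothetical, for illustration only.

```python
import tensorflow as tf

# Hypothetical probe: records what the function body sees during tracing.
traces = []


@tf.function
def inspect_arg(a):
    # These Python-side checks run once per trace, not on every call.
    traces.append({
        "eager": tf.executing_eagerly(),  # False while tf.function traces
        "is_str": isinstance(a, str),     # True only for plain Python strings
        "is_tensor": tf.is_tensor(a),     # True for tf.constant inputs
    })
    # tf.strings ops accept both a symbolic tensor and a Python str.
    return tf.strings.length(a)


tensor_len = inspect_arg(tf.constant("file3"))  # traced with a symbolic tensor
str_len = inspect_arg("file3")                  # retraced with a Python str

for t in traces:
    print(t)
```

With the tensor input, the recorded trace should show is_str False and is_tensor True, which is exactly the case where GFile's constructor rejects the argument; the Python string input produces a second trace with is_str True.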

【Question discussion】:

    Tags: python tensorflow


    【Answer 1】:

    tf.io.read_file does the trick. Change the whole code to:

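    import tensorflow as tf

    @tf.function
    def load_file(a):
        # Accept both Python strings and tf.string tensors.
        a = tf.convert_to_tensor(a)
        # Build "<first two chars>/<name>" with graph-safe TF string ops.
        a_path = tf.strings.substr(a, 0, 2) + "/" + a
        # tf.io.read_file is a TF op, so it accepts a string tensor in graph mode.
        return tf.io.read_file("gs://some_bucket/" + a_path)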
    

    This makes it work in both eager and graph mode, and ensures the input is always a tf tensor.
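    To try the answer without GCS access, here is a minimal sketch that swaps the bucket for a local temporary directory with the same `<first two chars>/<name>` layout. The directory, the file contents, and the extra `prefix` argument are hypothetical stand-ins; with a TensorFlow build that has GCS filesystem support (as the answer assumes), tf.io.read_file should accept a gs:// prefix the same way. The last lines show one possible use of the same function in a tf.data pipeline over file names.

```python
import os
import tempfile

import tensorflow as tf

# Hypothetical local stand-in for "gs://some_bucket", laid out as
# <root>/<first two chars of name>/<name>.
root = tempfile.mkdtemp()
os.makedirs(os.path.join(root, "fi"))
with open(os.path.join(root, "fi", "file3"), "wb") as f:
    f.write(b"hello from file3")


@tf.function
def load_file(a, prefix):
    # Same logic as the answer; `prefix` replaces the hard-coded bucket.
    a = tf.convert_to_tensor(a)
    a_path = tf.strings.substr(a, 0, 2) + "/" + a
    # tf.io.read_file is a graph op, so a symbolic string tensor is fine here.
    return tf.io.read_file(prefix + "/" + a_path)


from_tensor = load_file(tf.constant("file3"), root)  # tensor input
from_str = load_file("file3", root)                  # Python str input

# The same function inside a tf.data pipeline over file names.
names = tf.data.Dataset.from_tensor_slices(["file3"])
contents = names.map(lambda name: load_file(name, root))
print(list(contents.as_numpy_iterator()))
```

    Both direct calls and the dataset map go through tf.io.read_file, so no Python-level file API ever has to receive a symbolic tensor.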

    【Discussion】:
