【问题标题】:Use dictionary in tf.function input_signature in Tensorflow 2.0在 Tensorflow 2.0 的 tf.function input_signature 中使用字典
【发布时间】:2020-03-24 09:14:19
【问题描述】:

我正在使用 Tensorflow 2.0 并面临以下情况:

@tf.function
def my_fn(items):
    .... #do stuff
    return

如果 items 是张量的字典,例如:

item1 = tf.zeros([1, 1])
item2 = tf.zeros(1)
items = {"item1": item1, "item2": item2}

有没有办法使用 tf.function 的 input_signature 参数,这样我可以强制 tf2 在 item1 为 tf.zeros([2,1]) 时避免创建多个图形?

【问题讨论】:

    标签: python tensorflow2.0


    【解决方案1】:

    输入签名必须是一个列表,但列表中的元素可以是字典或张量规格列表。在您的情况下,我会尝试:(name 属性是可选的)

    signature_dict = { "item1": tf.TensorSpec(shape=[2], dtype=tf.int32, name="item1"),
                       "item2": tf.TensorSpec(shape=[], dtype=tf.int32, name="item2") } 
                  
    
    # don't forget the brackets around the 'signature_dict'
    @tf.function(input_signature = [signature_dict])
    def my_fn(items):
        .... # do stuff
        return
    
    # calling the TensorFlow function
    my_fun(items)
    

    但是,如果你想调用my_fn 创建的特定具体函数,你必须解包字典。您还必须在tf.TensorSpec 中提供name 属性。

    # creating a concrete function with an input signature as before but without
    # brackets and with mandatory 'name' attributes in the TensorSpecs 
    my_concrete_fn = my_fn.get_concrete_function(signature_dict)
                                                 
    # calling the concrete function with the unpacking operator
    my_concrete_fn(**items)
    

    这很烦人,但应该在 TensorFlow 2.3 中解决。 (参见 TF Guide to 'Concrete functions' 的结尾)

    【讨论】:

      猜你喜欢
      • 2020-08-07
      • 1970-01-01
      • 1970-01-01
      • 2019-10-30
      • 1970-01-01
      • 1970-01-01
      • 2021-07-25
      • 1970-01-01
      • 2020-06-27
      相关资源
      最近更新 更多