【问题标题】:`tf.function` decorator causes batch shape to be `NoneType` (Tensorflow2, Python)`tf.function` 装饰器导致批处理形状为`NoneType`(Tensorflow2,Python)
【发布时间】:2020-09-16 16:19:18
【问题描述】:

在以下代码中:

@tf.function
def get_x_y(dataset,count=1):
  X = tf.TensorArray(tf.float32,count)
  Y = tf.TensorArray(tf.float32,count)
  idx = tf.Variable(0,dtype=tf.int32)
  for batch in dataset.take(count):
    max_x,max_y,max_z = batch.shape
    x = tf.slice(batch,[0,0,0],[-1,max_y-1,-1])
    y = tf.slice(batch,[0,1,0],[-1,-1,-1 ]) 
    y = tf.argmax(y,-1)
    y = tf.cast(y,tf.float32)
    X = X.write(idx.numpy(),x)
    Y = Y.write(idx.numpy(),y)
    idx.assign_add(1)
  return X.stack(),Y.stack()

使用 tf.function 装饰器时的输出符合预期,但是使用 tf.function 装饰器时,会出现以下错误:

---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
<ipython-input-56-2306b6173e70> in <module>()
----> 1 print(get_x_y(dataset))nt 

8 frames
/usr/local/lib/python3.6/dist-packages/tensorflow/python/framework/func_graph.py in wrapper(*args, **kwargs)
    966           except Exception as e:  # pylint:disable=broad-except
    967             if hasattr(e, "ag_error_metadata"):
--> 968               raise e.ag_error_metadata.to_exception(e)
    969             else:
    970               raise

TypeError: in user code:

    <ipython-input-55-4a32428fa07d>:8 get_x_y  *
        x = tf.slice(batch,[0,0,0],[-1,max_y-1,-1])

    TypeError: unsupported operand type(s) for -: 'NoneType' and 'int'

预期的输出是(张量A,张量B)形状的二元组(count,batch_size,batch_seq_len,vocab_size)(count,batch_size,batch_seq_len),分别是

count = argument provided to the function
batch_size = 128
batch_seq_len = ?, max_len(seq_i) for all seq in batch
vocab_size = 78

例如,函数未使用tf.function 修饰时的预期输出:

(<tf.Tensor: shape=(1, 128, 262, 78), dtype=float32, numpy=
array([[[[0., 1., 0., ..., 0., 0., 0.],
         [0., 0., 0., ..., 0., 0., 0.],
         [0., 0., 0., ..., 0., 0., 0.],
         ...,
         [1., 0., 0., ..., 0., 0., 0.],
         [1., 0., 0., ..., 0., 0., 0.],
         [1., 0., 0., ..., 0., 0., 0.]],

        [[0., 1., 0., ..., 0., 0., 0.],
         [0., 0., 0., ..., 0., 0., 0.],
         [0., 0., 0., ..., 0., 0., 0.],
         ...,
         [1., 0., 0., ..., 0., 0., 0.],
         [1., 0., 0., ..., 0., 0., 0.],
         [1., 0., 0., ..., 0., 0., 0.]],

        [[0., 1., 0., ..., 0., 0., 0.],
         [0., 0., 0., ..., 0., 0., 0.],
         [0., 0., 0., ..., 0., 0., 1.],
         ...,
         [1., 0., 0., ..., 0., 0., 0.],
         [1., 0., 0., ..., 0., 0., 0.],
         [1., 0., 0., ..., 0., 0., 0.]],

        ...,

        [[0., 1., 0., ..., 0., 0., 0.],
         [0., 0., 0., ..., 0., 0., 0.],
         [0., 0., 0., ..., 0., 0., 1.],
         ...,
         [1., 0., 0., ..., 0., 0., 0.],
         [1., 0., 0., ..., 0., 0., 0.],
         [1., 0., 0., ..., 0., 0., 0.]],

        [[0., 1., 0., ..., 0., 0., 0.],
         [0., 0., 0., ..., 0., 0., 0.],
         [0., 0., 0., ..., 0., 0., 1.],
         ...,
         [1., 0., 0., ..., 0., 0., 0.],
         [1., 0., 0., ..., 0., 0., 0.],
         [1., 0., 0., ..., 0., 0., 0.]],

        [[0., 1., 0., ..., 0., 0., 0.],
         [0., 0., 0., ..., 0., 0., 0.],
         [0., 0., 0., ..., 0., 0., 1.],
         ...,
         [1., 0., 0., ..., 0., 0., 0.],
         [1., 0., 0., ..., 0., 0., 0.],
         [1., 0., 0., ..., 0., 0., 0.]]]], dtype=float32)>, <tf.Tensor: shape=(1, 128, 262), dtype=float32, numpy=
array([[[ 7., 53., 69., ...,  0.,  0.,  0.],
        [15., 21.,  5., ...,  0.,  0.,  0.],
        [15., 77.,  5., ...,  0.,  0.,  0.],
        ...,
        [15., 77.,  5., ...,  0.,  0.,  0.],
        [15., 77.,  5., ...,  0.,  0.,  0.],
        [15., 77.,  5., ...,  0.,  0.,  0.]]], dtype=float32)>)

有人知道为什么吗?我怀疑这是因为数据集的形状为&lt;MapDataset shapes: (128, None, 78), types: tf.float32&gt;,但是我不明白为什么max_yNone,因为批处理形状在循环之前是已知的?

【问题讨论】:

    标签: tensorflow2.0 tensorflow-datasets tensorflow2.x


    【解决方案1】:

    我希望我能很好地理解你: max_y 存在问题,因为它没有问题,它是 tf.slice 的大小参数的一部分引发异常。

    strided_slice() 需要 4 个参数 input_、begin、end、strides。

    该方法的功能非常简单: 它的工作原理类似于循环遍历,其中 begin 是循环开始的张量中元素的位置,end 是循环停止的位置。

    所以请尝试: 这里的步骤或步幅可以是dataset.take(count)

    tf.strided_slice(input, [start1, start2, ..., startN],
        [end1, end2, ..., endN], [step1, step2, ..., stepN])
    
    
    max_y = tf.Variable(np.array(batch) 
    s = tf.strided_slice(max_y, begin, end, max_y, name='var_slice')
    

    【讨论】:

    • 您好 Mahsa,谢谢您的回答。您已经正确理解问题是,max_yNone。但是,max_y 不应该是批次,只是批次的第二个轴的长度。此外,我不想从批处理中获取一个 numpy 数组,这将迫使计算脱离 GPU。问题是当我尝试执行_,max_y,_ = batch.shape 时,如果tf.function 装饰器用于禁用急切执行,max_y 最终会成为None
    猜你喜欢
    • 1970-01-01
    • 2021-12-01
    • 2021-07-11
    • 2021-07-15
    • 1970-01-01
    • 1970-01-01
    • 2019-09-17
    • 1970-01-01
    • 2016-09-27
    相关资源
    最近更新 更多