【问题标题】:How to get batch_size in call() function in TF2?如何在 TF2 的 call() 函数中获取 batch_size?
【发布时间】:2021-07-09 09:09:26
【问题描述】:

我正在尝试在 TF2 模型中的 call() 函数中获取 batch_size。 但是,我无法得到它,因为我知道的所有方法都返回 None 或 Tensor 而不是维度元组。

这是一个简短的例子

import numpy as np
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
from tensorflow.keras.models import Model

class MyModel(Model):
    def __init__(self):
        super(MyModel, self).__init__()
    
    def call(self, x):
        print(len(x))
        print(x.shape)
        print(tf.size(x))
        print(np.shape(x))
        print(x.get_shape())
        print(x.get_shape().as_list())
        print(tf.rank(x))
        print(tf.shape(x))
        print(tf.shape(x)[0])
        print(tf.shape(x)[1])        
        return tf.random.uniform((2, 10))


m = MyModel()
m.compile(optimizer="Adam", loss="sparse_categorical_crossentropy", metrics=['accuracy'])
m.fit(np.array([[1,2,3,4], [5,6,7,8]]), np.array([0, 1]), epochs=1)

输出是:

Tensor("my_model_26/strided_slice:0", shape=(), dtype=int32)
(None, 4)
Tensor("my_model_26/Size:0", shape=(), dtype=int32)
(None, 4)
(None, 4)
[None, 4]
Tensor("my_model_26/Rank:0", shape=(), dtype=int32)
Tensor("my_model_26/Shape_2:0", shape=(2,), dtype=int32)
Tensor("my_model_26/strided_slice_1:0", shape=(), dtype=int32)
Tensor("my_model_26/strided_slice_2:0", shape=(), dtype=int32)

1/1 [==============================] - 0s 1ms/step - loss: 3.1796 - accuracy: 0.0000e+00

在此示例中,我将 (2,4) numpy 数组作为输入并将 (2, ) 作为目标提供给模型。 但如您所见,我无法在call() 函数中获取batch_size

我需要它的原因是因为我必须为 batch_size 迭代张量,这在我的真实模型中是动态的。

例如,如果数据集大小为 10,batch size 为 3,则 last batch 中的最后一个 batch size 为 1。所以,我必须动态知道 batch size。

谁能帮帮我?


  • 张量流 2.3.3
  • CUDA 10.2
  • python 3.6.9

【问题讨论】:

  • call() 方法中你只需要创建你的模型。数据迭代,应该发生在train_step函数中。您确定在call() 中需要batch_szie 吗?
  • @Kaveh 是的,我需要像here 那样做 3d 稀疏张量批量乘法。为此,据我所知,我必须知道 call() 函数中的批量大小(实际上我是 pytorch 用户,所以可能是错误的。请告诉我)。

标签: tensorflow keras tensorflow2.0 tensorflow-datasets batchsize


【解决方案1】:

这是因为您使用的是 TensorFlow(这是强制性的,因为 Keras 现在在 TensorFlow 中),并且通过使用 TensorFlow,您需要了解将动态图“编译”为静态图。

简而言之,您的 call 方法(在后台)使用 @tf.function 装饰器进行装饰。

这个装饰器:

  1. 跟踪 python 函数执行
  2. 在 TensorFlow 操作中转换 python 操作(例如,if a > b 变为 tf.cond(tf.greater(a,b), something, something_else)
  3. 创建一个tf.Graph(静态图)
  4. 执行刚刚创建的静态图。

您的所有print 调用都是在第一步(python 执行跟踪)中执行的,这就是为什么即使您训练您的模型,您也只能看到 1 次输出。

也就是说,要获得张量的运行时(动态形状),您必须使用tf.shape(x),批量大小正好是batch_size = tf.shape(x)[0]

请注意,如果要查看形状(使用打印),则不能使用打印,但必须使用tf.print

import numpy as np
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras.models import Model


class MyModel(Model):
    def __init__(self):
        super(MyModel, self).__init__()

    def call(self, x):

        shape = tf.shape(x)
        batch_size = shape[0]

        tf.print(shape, batch_size)

        return tf.random.uniform((2, 10))


m = MyModel()
m.compile(
    optimizer="Adam", loss="sparse_categorical_crossentropy", metrics=["accuracy"]
)
m.fit(np.array([[1, 2, 3, 4], [5, 6, 7, 8]]), np.array([0, 1]), epochs=1)

有关静态和动态形状的更多信息:https://pgaleone.eu/tensorflow/2018/07/28/understanding-tensorflow-tensors-shape-static-dynamic/

有关 tf.function 行为的更多信息:https://pgaleone.eu/tensorflow/tf.function/2019/03/21/dissecting-tf-function-part-1/

注意:这些文章是我写的。

【讨论】:

    【解决方案2】:

    如果您想准确地获取数据和形状,您可以将 eager run 变为 true,但这不是一个好的解决方案,因为它会使训练变慢。

    这样设置:

    m.compile(optimizer="Adam", loss="sparse_categorical_crossentropy",
                               metrics=['accuracy'], run_eagerly=True)
    

    那么输出将是:

    (2, 4)
    tf.Tensor(8, shape=(), dtype=int32)
    (2, 4)
    (2, 4)
    [2, 4]
    tf.Tensor(2, shape=(), dtype=int32)
    tf.Tensor([2 4], shape=(2,), dtype=int32)
    tf.Tensor(2, shape=(), dtype=int32)
    tf.Tensor(4, shape=(), dtype=int32)
    

    【讨论】:

    • 感谢您的回复。顺便说一句,我将 batch_size (int) 传递给 call() 函数,我发现 batch_size (int) 已更改为“常量”张量。我想知道为什么会这样。你有什么线索吗?
    猜你喜欢
    • 2021-08-07
    • 1970-01-01
    • 2022-06-13
    • 2020-11-09
    • 1970-01-01
    • 1970-01-01
    • 1970-01-01
    • 2015-08-21
    • 1970-01-01
    相关资源
    最近更新 更多