【问题标题】:How to fix the batch size in keras subclassing model?如何修复keras子类模型中的批量大小?
【发布时间】:2022-12-20 20:43:15
【问题描述】:

在 tf.keras 函数式 API 中,我可以像下面这样固定批量大小:

import tensorflow as tf

inputs = tf.keras.Input(shape=(64, 64, 3), batch_size=1)    # I can fix batch size like this
x = tf.keras.layers.Conv2DTranspose(3, 3, strides=2, padding="same", activation="relu")(inputs)
outputs = x
model = keras.Model(inputs=inputs, outputs=outputs, name="custom")

我的问题是,当我使用 keras 子类化方法时,如何修复批量大小?

【问题讨论】:

标签: python tensorflow keras tensorflow2.0


【解决方案1】:

间接处理参数的一种方法(当无法访问它时)是使用 tf.keras.backend 访问。在这种情况下,tf 通过调用函数来定义输入格式:

def call(self, inputs):
    z_mean, z_log_var = inputs
    batch = tf.shape(z_mean)[0]
    dim = tf.shape(z_mean)[1]
    epsilon = tf.keras.backend.random_normal(shape=(batch, dim))
    return z_mean + tf.exp(0.5 * z_log_var) * epsilon

然后遍历每个批次

for step, x_batch_train in enumerate(train_dataset):
    with tf.GradientTape() as tape:
        reconstructed = vae(x_batch_train)
        # Compute reconstruction loss
        loss = mse_loss_fn(x_batch_train, reconstructed)
        loss += sum(vae.losses)  # Add KLD regularization loss

    grads = tape.gradient(loss, vae.trainable_weights)
    optimizer.apply_gradients(zip(grads, vae.trainable_weights))

    loss_metric(loss)

    if step % 100 == 0:
        print("step %d: mean loss = %.4f" % (step, loss_metric.result())

【讨论】:

    猜你喜欢
    • 2021-05-22
    • 1970-01-01
    • 2020-08-07
    • 1970-01-01
    • 1970-01-01
    • 2016-05-05
    • 1970-01-01
    • 1970-01-01
    • 1970-01-01
    相关资源
    最近更新 更多