【问题标题】:ActorCritic - model cannot be saved because the input shapes have not been setActorCritic - 无法保存模型,因为尚未设置输入形状
【发布时间】:2021-10-24 06:07:22
【问题描述】:

我想保存一个 Actor-Critic 模型,但出现了这个问题。

import os
import tensorflow as tf
from tensorflow.keras.layers import Flatten, Dense, LSTM, BatchNormalization
from tensorflow.keras import Model


class ActorCritic(Model):
    def __init__(self, action_size, state_size):
        super(ActorCritic, self).__init__()
        self.lstm1 = LSTM(16, return_sequences=True, input_shape=state_size)
        self.lstm2 = LSTM(8, return_sequences=True)
        self.flatten = Flatten()
        self.policy = Dense(action_size, activation='linear')
        self.value = Dense(1, activation='linear')

    def call(self, x):
        x = self.lstm1(x)
        x = self.lstm2(x)
        x = self.flatten(x)
        policy = self.policy(x)
        value = self.value(x)

        return policy, value


class A3CAgent():

    def __init__(self):
        self.state_size = (9, 23)
        self.action_size = 2
        self.save_path = os.path.join(os.getcwd(), 'model')

        self.global_model = ActorCritic(self.action_size, self.state_size)
        self.global_model.build((None, *self.state_size))

        self.global_model.save(self.save_path)


if __name__ == "__main__":
    global_agent = A3CAgent()

输出:

Traceback (most recent call last):
ValueError: Model <__main__.ActorCritic object at 0x000001933D7F1E10> cannot be saved because the input shapes have not been set.
Usually, input shapes are automatically determined from calling `.fit()` or `.predict()`.
To manually set the shapes, call `model.build(input_shape)`.

我写了 'self.global_model.build((None, *self.state_size))',但它不起作用。

如何拨打model.build(input_shape)或解决?

【问题讨论】:

    标签: tensorflow keras reinforcement-learning


    【解决方案1】:

    在我看来,问题不在于构建方法没问题,而是模型类的保存方法不适用于自定义模型。相反,您可以使用 tensorflow 检查点格式,它允许保存模型以及优化器等其他内容(对于给定的迭代)。 你可以在这里找到信息: https://www.tensorflow.org/api_docs/python/tf/train/Checkpoint, https://www.tensorflow.org/guide/checkpoint

    你可以像这样修改你的代码(删除self.global_model.save(self.save_path)):

    if __name__ == "__main__":
        state_size = (9, 23)
        action_size = 2
        save_path = os.path.join(os.getcwd(), 'model')
        print(save_path)
    
        global_agent = A3CAgent()
        global_model = global_agent.global_model
        
        # init check point
        checkpoint = tf.train.Checkpoint(model=global_model)
        #  restore latest model
        checkpoint.restore(tf.train.latest_checkpoint(save_path)).assert_consumed()
        
        # do training
        # train the model and save each xxx iterations
        checkpoint.save(os.path.join(save_path, 'ckpt'))
        
        # you can reload the latest model :
        checkpoint.restore(tf.train.latest_checkpoint(save_path)).assert_consumed()
    

    【讨论】:

    • 我见过'save_weights',但不知道'checkpoint.restore'功能。衷心感谢。
    【解决方案2】:

    运行compute_output_shape 以完全构建模型,如下所示。

    self.global_model.build((None, *self.state_size))
    self.global_model.compute_output_shape(input_shape=(None,9, 23)) # added this line
    

    只是将最后一行添加到您的代码中,没有任何修改。请查看gist here

    【讨论】:

      猜你喜欢
      • 2021-11-17
      • 2022-06-30
      • 1970-01-01
      • 1970-01-01
      • 2019-08-31
      • 2021-06-03
      • 1970-01-01
      • 2013-10-02
      • 1970-01-01
      相关资源
      最近更新 更多