【问题标题】:Can't load tensorflow (tf-agent) saved model无法加载 tensorflow(tf-agent)保存的模型
【发布时间】:2019-06-11 00:35:32
【问题描述】:

我正在以下代码中创建一个 tf-agent DqnAgent:

tf_agent = dqn_agent.DqnAgent(
    train_env.time_step_spec(),
    train_env.action_spec(),
    q_network=q_net,
    optimizer=optimizer,
    td_errors_loss_fn=dqn_agent.element_wise_squared_loss,
    train_step_counter=train_step_counter

)

在训练循环中,我将这个模型保存为

tf.saved_model.save(tf_agent, saved_models_path)

训练完成后,我想加载保存的模型

if tf.saved_model.contains_saved_model(saved_models_path):
    tf_agent = tf.saved_model.load(saved_models_path)

此代码仅在saved_path中的文件夹包含一个时才会加载保存的模型,函数contains_saved_model(saved_models_path)返回True,因此模型已加载,但出现异常并且程序崩溃:

Traceback (most recent call last):
    File "/home/claudino/Projetos/dino-tf-agents/dino_ia/model/agent.py", line 50, in <module>
        tf_agent = tf.saved_model.load(saved_models_path)
    File "/home/claudino/Projetos/dino-tf-agents/venv/lib/python3.6/site-packages/tensorflow/python/saved_model/load.py", line 408, in load
        return load_internal(export_dir, tags)
    File "/home/claudino/Projetos/dino-tf-agents/venv/lib/python3.6/site-packages/tensorflow/python/saved_model/load.py", line 432, in load_internal
        export_dir)
    File "/home/claudino/Projetos/dino-tf-agents/venv/lib/python3.6/site-packages/tensorflow/python/saved_model/load.py", line 58, in __init__
        self._load_all()
    File "/home/claudino/Projetos/dino-tf-agents/venv/lib/python3.6/site-packages/tensorflow/python/saved_model/load.py", line 168, in _load_all
        slot_variable = optimizer_object.add_slot(
    AttributeError: '_UserObject' object has no attribute 'add_slot'

    Process finished with exit code 1

我浏览了 tensorflow 代码,但找不到问题所在。谁能帮帮我?

我正在使用tf-agents-nightly,因为谷歌的协作源代码不适用于tf-agents“稳定”版本(我不确定tf-agents 是否真的稳定),并使用tensorflow 1.3 和@ 尝试了代码987654333@,同样的问题。

【问题讨论】:

  • 这个问题已经有一段时间了,但是对于未来的查询; PolicySaver 类现在用于保存 TF-Agents 策略。开发者是planning to add it to a tutorial,但还没有完成。
  • 检查点保存、模型保存和策略保存有区别吗?

标签: python tensorflow


【解决方案1】:

您尝试过 TensorFlow 2.7 吗?这通常有助于解决这个问题。

对我有用的其他方法是以这种方式加载模型(假设模型是 keras/tf.keras 模型):

try:
    model = tf.keras.models.load_model(model_dir)
except:
  load_options = tf.saved_model.LoadOptions(experimental_io_device= '/job:localhost')
  model = tf.saved_model.load(model_dir, options= load_options)

try 子句会导致异常,因为load_model() 需要一个keras_metadata.pb 文件,而使用saved_model.save() 保存模型时,该文件不存在。

但是,运行该子句会以某种方式使tf.saved_model.load() 运行而不会出现任何问题。可能在后台发生了某种我不太了解的交互,但它对我有用,并且不会出现“no attribute add_slot”错误。

【讨论】:

    猜你喜欢
    • 1970-01-01
    • 1970-01-01
    • 1970-01-01
    • 1970-01-01
    • 2020-03-13
    • 2019-04-12
    • 1970-01-01
    • 1970-01-01
    • 1970-01-01
    相关资源
    最近更新 更多