我也遇到过同样的问题,我通过对基线代码稍作调整解决了这个问题。
baselines 中有两对方法用于保存和加载模型(save_state&load_state 对和 save_variables&loas_variables 对),您可以在 baselines/common/tf_util.py(line325~line372) 中看到它.
对于最新版本的baselines,以.ckpt.meta、.ckpt.index、.ckpt.data和checkpoint格式保存和加载模型的save_state&load_state对已被废弃,因此您需要重新启用save_state&load_state对。
以ppo2为例,在baselines/ppo2/model.py中进行如下替换:
在第 125 行,替换
self.save = functools.partial(save_variables, sess=sess)
self.load = functools.partial(load_variables, sess=sess)
与
self.save = functools.partial(save_state, sess=sess)
self.load = functools.partial(load_state, sess=sess)
在第 4 行,
替换
from baselines.common.tf_util import get_session, save_variables, load_variables
与
from baselines.common.tf_util import get_session, save_state, load_state
这会将 save_variables&loas_variables 对替换为 save_state&load_state 对。
希望这会对你有所帮助。