Posted: 2021-10-24 06:07:22
Question:
I want to save an Actor-Critic model, but I get the following error.
```python
import os

import tensorflow as tf
from tensorflow.keras.layers import Flatten, Dense, LSTM, BatchNormalization
from tensorflow.keras import Model


class ActorCritic(Model):
    def __init__(self, action_size, state_size):
        super(ActorCritic, self).__init__()
        self.lstm1 = LSTM(16, return_sequences=True, input_shape=state_size)
        self.lstm2 = LSTM(8, return_sequences=True)
        self.flatten = Flatten()
        self.policy = Dense(action_size, activation='linear')  # actor head
        self.value = Dense(1, activation='linear')  # critic head

    def call(self, x):
        x = self.lstm1(x)
        x = self.lstm2(x)
        x = self.flatten(x)
        policy = self.policy(x)
        value = self.value(x)
        return policy, value


class A3CAgent():
    def __init__(self):
        self.state_size = (9, 23)
        self.action_size = 2
        self.save_path = os.path.join(os.getcwd(), 'model')
        self.global_model = ActorCritic(self.action_size, self.state_size)
        self.global_model.build((None, *self.state_size))
        self.global_model.save(self.save_path)  # this line raises the ValueError


if __name__ == "__main__":
    global_agent = A3CAgent()
```
Output:
```text
Traceback (most recent call last):
ValueError: Model <__main__.ActorCritic object at 0x000001933D7F1E10> cannot be saved because the input shapes have not been set.
Usually, input shapes are automatically determined from calling `.fit()` or `.predict()`.
To manually set the shapes, call `model.build(input_shape)`.
```
I did call `self.global_model.build((None, *self.state_size))`, but it doesn't help.
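The shape argument itself looks well-formed: `(None, *self.state_size)` unpacks to `(None, 9, 23)`, i.e. (batch, timesteps, features). The plain-Python sketch below (no TensorFlow needed) traces that shape by hand through each layer of `ActorCritic`. `trace_shapes` is a hypothetical helper, not a Keras API, and it assumes the usual Keras semantics for `LSTM(return_sequences=True)`, `Flatten` and `Dense`.

```python
def trace_shapes(state_size, action_size, lstm_units=(16, 8)):
    """Hypothetical helper: hand-trace shapes through the ActorCritic layers.

    Keras conventions assumed: None is the batch axis, an LSTM with
    return_sequences=True keeps the time axis, Flatten merges all
    non-batch axes, and Dense replaces the last axis with its unit count.
    """
    timesteps = state_size[0]
    shape = (None, *state_size)  # the tuple passed to build()
    shapes = {"input": shape}
    for i, units in enumerate(lstm_units, start=1):
        # Time axis is kept, feature axis becomes the LSTM unit count.
        shape = (None, timesteps, units)
        shapes[f"lstm{i}"] = shape
    # Flatten: multiply together every non-batch dimension.
    flat = 1
    for dim in shape[1:]:
        flat *= dim
    shapes["flatten"] = (None, flat)
    shapes["policy"] = (None, action_size)  # actor head
    shapes["value"] = (None, 1)  # critic head
    return shapes


# For state_size=(9, 23): flatten -> (None, 72), policy -> (None, 2), value -> (None, 1)
for name, shape in trace_shapes((9, 23), 2).items():
    print(f"{name:8}{shape}")
```

Every layer receives a shape it can handle, which suggests the error comes from how `save()` obtains the model's input signature rather than from the value passed to `build()`.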
How should I call `model.build(input_shape)`, or how else can I fix this?
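A commonly used workaround is to run the model once on a dummy batch before saving. The sketch below assumes TensorFlow 2.x with its bundled Keras 2 (the setup current when this was asked), where calling a model records the input signature that SavedModel export needs, whereas `build()` on a subclassed model may not record it. That is an assumption about TF behaviour across versions, not a documented guarantee. In Keras 3 (TF 2.16+), saving to a bare directory path is no longer supported, so the final `save()` call would need to change there. `dummy_state` and the temporary `save_path` are illustrative names only.

```python
import os
import tempfile

import tensorflow as tf
from tensorflow.keras import Model
from tensorflow.keras.layers import LSTM, Dense, Flatten


class ActorCritic(Model):
    """Same architecture as in the question (the input_shape argument is omitted)."""

    def __init__(self, action_size):
        super().__init__()
        self.lstm1 = LSTM(16, return_sequences=True)
        self.lstm2 = LSTM(8, return_sequences=True)
        self.flatten = Flatten()
        self.policy = Dense(action_size, activation='linear')
        self.value = Dense(1, activation='linear')

    def call(self, x):
        x = self.lstm1(x)
        x = self.lstm2(x)
        x = self.flatten(x)
        return self.policy(x), self.value(x)


state_size = (9, 23)
action_size = 2
model = ActorCritic(action_size)

# Forward pass on a dummy batch holding one all-zeros state. This builds the
# weights and, in the assumed TF 2.x versions, also records the input spec
# that SavedModel export looks for.
dummy_state = tf.zeros((1, *state_size), dtype=tf.float32)
policy, value = model(dummy_state)

# Illustrative save location; with Keras 2, a path without an .h5 extension
# is written as a SavedModel directory.
save_path = os.path.join(tempfile.mkdtemp(), 'model')
model.save(save_path)
```

In the question's `A3CAgent.__init__`, the equivalent change would be to call `self.global_model(tf.zeros((1, *self.state_size)))` after `build()` and before `save()`.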
Tags: tensorflow keras reinforcement-learning