【发布时间】:2018-03-26 13:44:44
【问题描述】:
我有 n 个网络,其中所有输入都带有占位符,我想将所有这些网络链接到另一个占位符(之后创建)作为公共输入。
class GroupOfNetworks(object):
def __init__(self,subtask_nets,ob_space):
self.x_inputs = [st_net.x for st_net in subtask_nets] #list of network inputs
其中st_net.x 是一个占位符,声明如下。
class Network(object):
def __init__(self, ob_space):
self.x = tf.placeholder(tf.float32, [None] + list(ob_space)) `#single network input
我希望对所有这些网络都有一个共同的输入,因此我只需要在我的feed_dict 中有一个键值对。我尝试对占位符进行分配操作(下面的代码 sn-p),但这会引发错误,因为它们是张量而不是变量。
#in class GroupOfNetworks...
common_x = tf.placeholder(tf.float32, [None] + list(ob_space),"common_input")
set_input = tf.assign(self.x_inputs[0].x,common_x,"link_subtask_input") # DOES NOT WORK
到目前为止,我一直使用以编程方式生成的 feed_dict(如下所示),但这不在图表上,并且在从 .meta 文件加载图表时无法导入。
def make_common_feed_dict(self,x):
return {placeholder:x for placeholder in self.x_inputs}
有人知道更好的解决方案吗?
【问题讨论】:
标签: python tensorflow