【发布时间】:2019-12-07 01:29:47
【问题描述】:
对于这个 Humanpose TensorFlow 网络,network_cmu 和 base,它只接受 NHWC 输入格式。 如果我以 NCHW 格式构建网络,则会出现错误
Depth of input (32) is not a multiple of input depth of filter (3) for 'conv1_1/Conv2D' (op: 'Conv2D') with input shapes: [1,3,24,32], [3,3,3,64].
我构建网络的代码是
import tensorflow as tf
import numpy as np
from network_cmu import CmuNetwork
def main():
#print(tensor_util.MakeNdarray(n.attr['value'].tensor))
placeholder_input = tf.placeholder(dtype=tf.float32, shape=(1, 3, 24, 32), name="image")
net = CmuNetwork({'image': placeholder_input}, trainable=False)
# Add an op to initialize the variables.
init_op = tf.global_variables_initializer()
saver = tf.train.Saver()
init_op = tf.global_variables_initializer()
with tf.Session() as sess:
sess.run(init_op)
#for n in tf.get_default_graph().as_graph_def().node:
# print(n.name)
save_path = saver.save(sess, "cmuThreeOutputs/model.ckpt")
if __name__ == '__main__':
main()
我应该改变什么以拥有 NCHW 格式的网络?
【问题讨论】:
-
这能回答你的问题吗? Convert between NHWC and NCHW in TensorFlow
标签: python tensorflow