【Posted】: 2020-02-10 09:04:31
【Question】:
GAN discriminator
I load the GAN's discriminator and run it with the following code:
import tensorflow as tf
import numpy as np
from IPython.display import display, Audio
tf.reset_default_graph()
saver = tf.train.import_meta_graph('./infer/infer.meta')
graph = tf.get_default_graph()
sess = tf.InteractiveSession()
saver.restore(sess, tf.train.latest_checkpoint('model/'))
# The variable is _z (with a leading underscore); it doesn't display correctly on Stack Overflow.
# I use random data to test this function.
_z = np.random.uniform(-1., 1., size=[5, 257])
x = graph.get_tensor_by_name('x:0')
D_z = graph.get_tensor_by_name('D_z:0')
D_z = sess.run(D_z, {x: _z})
print(D_z)
Custom Keras loss function
I want to write a custom Keras loss function that uses the discriminator:
# Load the graph
tf.reset_default_graph()
saver = tf.train.import_meta_graph('./infer/infer.meta')
graph = tf.get_default_graph()
sess = tf.InteractiveSession()
saver.restore(sess, tf.train.latest_checkpoint('model/'))
def gan_loss(y_true, y_pred):
    _z = y_pred
    x = graph.get_tensor_by_name('x:0')
    D_z = graph.get_tensor_by_name('D_z:0')
    # Fails: y_pred is a tf.Tensor, but feed_dict only accepts concrete data.
    D_z = sess.run(D_z, {x: _z})
    return D_z
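An alternative that avoids Session.run and meta graphs altogether, assuming TensorFlow 2.x, is to rebuild the discriminator as a Keras model (loading its trained weights is not shown and depends on how they were saved) and call it directly on y_pred inside the loss. This is a different technique from the meta-graph approach above. The one-layer model is a hypothetical stand-in, and the non-saturating generator loss is one common choice rather than necessarily what the original loss intends.

```python
import tensorflow as tf

# Hypothetical stand-in for the trained discriminator, rebuilt as a Keras model.
discriminator = tf.keras.Sequential([
    tf.keras.Input(shape=(257,)),
    tf.keras.layers.Dense(1, activation="sigmoid"),
])
# Freeze it so that, inside a combined generator model, only the generator trains.
discriminator.trainable = False


def gan_loss(y_true, y_pred):
    # y_true is ignored: the score comes from the discriminator.
    d_out = discriminator(y_pred)
    # Non-saturating generator loss: reward outputs the discriminator rates as real.
    return -tf.reduce_mean(tf.math.log(d_out + 1e-8))


# Random generator outputs, like the _z used in the question.
y_pred = tf.random.uniform([5, 257], -1.0, 1.0)
with tf.GradientTape() as tape:
    tape.watch(y_pred)
    loss = gan_loss(None, y_pred)
grad = tape.gradient(loss, y_pred)
print(tuple(grad.shape))  # (5, 257)
```

Because every step is a TensorFlow op, the tape can return a gradient with the same shape as y_pred, which a loss built on sess.run cannot provide.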
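The problem
The problem is that a tensor cannot be fed into feed_dict; the value has to be a NumPy array or another plain Python type:
TypeError: The value of a feed cannot be a tf.Tensor object. Acceptable feed values include Python scalars, strings, lists, or numpy ndarrays.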
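A minimal reproduction, assuming TensorFlow 1.x or 2.x with the tf.compat.v1 API: feeding NumPy data into x works, while feeding a symbolic tensor raises the TypeError quoted above. The one-layer stand-in discriminator and the random y_pred tensor are hypothetical.

```python
import numpy as np
import tensorflow as tf

# Tiny stand-in for the restored graph. The names "x" and "D_z" mirror the
# question; the single sigmoid layer is a hypothetical discriminator.
graph = tf.Graph()
with graph.as_default():
    x = tf.compat.v1.placeholder(tf.float32, [None, 257], name="x")
    w = tf.compat.v1.get_variable("w", [257, 1])
    D_z = tf.identity(tf.sigmoid(tf.matmul(x, w)), name="D_z")
    # A symbolic tensor in the same graph, playing the role of Keras' y_pred.
    y_pred = tf.random.uniform([5, 257], -1.0, 1.0)
    init = tf.compat.v1.global_variables_initializer()

sess = tf.compat.v1.Session(graph=graph)
sess.run(init)

# Concrete NumPy data is an acceptable feed value.
_z = np.random.uniform(-1., 1., size=[5, 257])
numpy_result = sess.run(D_z, {x: _z})

# A tf.Tensor is not: Session.run rejects it with a TypeError.
feed_error = None
try:
    sess.run(D_z, {x: y_pred})
except TypeError as err:
    feed_error = err

print(numpy_result.shape)         # (5, 1)
print(type(feed_error).__name__)  # TypeError
```

This is also why sess.run cannot sit inside a Keras loss: even converting y_pred to NumPy first would detach the result from the graph, so no gradient could flow back to the generator.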
Related questions I found on Stack Overflow:
Converting Tensor to np.array using K.eval() in Keras returns InvalidArgumentError
Tensorflow: How to feed a placeholder variable with a tensor?
How I exported the discriminator from the GAN:
import os
import tensorflow as tf

# Input placeholder named 'x', so it can be looked up later as 'x:0'.
X = tf.placeholder(tf.float32, [None, 257], name='x')
# discriminator() is the GAN's own model-building function (not shown).
D_z, h3 = discriminator(X)
D_z = tf.identity(D_z, name='D_z')
# Save only the discriminator's variables.
D_vars = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='GAN/Discriminator')
# global_step = tf.train.get_or_create_global_step()
saver = tf.train.Saver(D_vars)
infer_dir = './infer/'
tf.train.write_graph(tf.get_default_graph(), infer_dir, 'infer.pbtxt')
infer_metagraph_fp = os.path.join(infer_dir, 'infer.meta')
tf.train.export_meta_graph(
    filename=infer_metagraph_fp,
    clear_devices=True,
    saver_def=saver.as_saver_def())
tf.reset_default_graph()
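Since a Keras loss must return a tensor computed from y_pred, one possible workaround (a sketch, not tested against the original model) is to import the exported meta graph into the graph that already holds that tensor. tf.compat.v1.train.import_meta_graph forwards an input_map argument to graph import, so the placeholder x:0 is wired to the tensor instead of being fed. The toy discriminator, the generator variable g_w, the input z and the temporary file paths are all hypothetical.

```python
import os
import tempfile

import numpy as np
import tensorflow as tf

work_dir = tempfile.mkdtemp()
meta_path = os.path.join(work_dir, "infer.meta")
ckpt_prefix = os.path.join(work_dir, "disc.ckpt")

# 1) Export a discriminator the way the question does. The one-layer
#    sigmoid model is a hypothetical stand-in for discriminator(X).
with tf.Graph().as_default():
    X = tf.compat.v1.placeholder(tf.float32, [None, 257], name="x")
    with tf.compat.v1.variable_scope("GAN/Discriminator"):
        w = tf.compat.v1.get_variable("w", [257, 1])
    tf.identity(tf.sigmoid(tf.matmul(X, w)), name="D_z")
    saver = tf.compat.v1.train.Saver()
    with tf.compat.v1.Session() as sess:
        sess.run(tf.compat.v1.global_variables_initializer())
        saver.save(sess, ckpt_prefix)
    tf.compat.v1.train.export_meta_graph(
        filename=meta_path, clear_devices=True,
        saver_def=saver.as_saver_def())

# 2) In a new graph, build a hypothetical generator output tensor and import
#    the discriminator with input_map, so "x:0" is replaced by that tensor.
graph = tf.Graph()
with graph.as_default():
    z = tf.compat.v1.placeholder(tf.float32, [None, 10], name="z")
    with tf.compat.v1.variable_scope("generator"):
        g_w = tf.compat.v1.get_variable("g_w", [10, 257])
        y_pred = tf.tanh(tf.matmul(z, g_w))
    restorer = tf.compat.v1.train.import_meta_graph(
        meta_path, input_map={"x:0": y_pred})
    D_z = graph.get_tensor_by_name("D_z:0")
    # Loss built only from TF ops, so gradients reach the generator.
    loss = -tf.reduce_mean(tf.math.log(D_z + 1e-8))
    grad = tf.compat.v1.gradients(loss, [g_w])[0]

    with tf.compat.v1.Session() as sess:
        sess.run(g_w.initializer)
        restorer.restore(sess, ckpt_prefix)  # loads GAN/Discriminator/w
        loss_val, grad_val = sess.run(
            [loss, grad], {z: np.random.normal(size=[5, 10])})

print(grad_val.shape)  # (10, 257)
```

In the Keras setting, the import would presumably happen inside gan_loss with Keras' own y_pred, and the restore would have to run in the session Keras uses; that wiring is an assumption and depends on the Keras and TensorFlow versions involved.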
【Discussion】:
Tags: python tensorflow machine-learning keras deep-learning