【发布时间】:2017-11-25 10:42:18
【问题描述】:
最近想实现 GAN 模型,使用 tf.Dataset 和 Iterator 读取人脸图像作为训练数据。
数据集和迭代器对象的代码是:
self.dataset = tf.data.Dataset.from_tensor_slices(convert_to_tensor(self.data_ob.train_data_list, dtype=tf.string))
self.dataset = self.dataset.map(self._parse_function)
#self.dataset = self.dataset.shuffle(buffer_size=10000)
self.dataset = self.dataset.apply(tf.contrib.data.batch_and_drop_remainder(batch_size))
self.iterator = tf.data.Iterator.from_structure(self.dataset.output_types, self.dataset.output_shapes)
self.next_x = self.iterator.get_next()
我的新 GAN 模型是:
self.z_mean, self.z_sigm = self.Encode(self.next_x)
self.z_x = tf.add(self.z_mean, tf.sqrt(tf.exp(self.z_sigm))*self.ep)
self.x_tilde = self.generate(self.z_x, reuse=False)
#the feature
self.l_x_tilde, self.De_pro_tilde = self.discriminate(self.x_tilde)
#for Gan generator
self.x_p = self.generate(self.zp, reuse=True)
# the loss of dis network
self.l_x, self.D_pro_logits = self.discriminate(self.next_x, True)
所以,问题是我两次使用 self.next_x 作为输入张量。每次的数据集都不一样。那么,如何解决这个问题以保留第一批以供重复使用呢?
【问题讨论】:
标签: python tensorflow tensorboard