【问题标题】:Writing summary.scalar with Dataset API and Keras使用 Dataset API 和 Keras 编写 summary.scalar
【发布时间】:2019-07-22 07:24:00
【问题描述】:

我使用 tensorflow Keras API 并尝试将自定义标量添加到张量板上,但除了显示损失之外什么都没有。

这是模型的代码:

embedding_in = Embedding(
    input_dim=vocab_size + 1 + 1,  
    output_dim=dim,
    mask_zero=True,
)

embedding_out = Embedding(
    input_dim=vocab_size + 1 + 1,  
    output_dim=dim,
    mask_zero=True,
)

input_a = Input((None,))
input_b = Input((None,))
input_c = Input((None, None))

emb_target = embedding_in(input_a)
emb_context = embedding_out(input_b)
emb_negatives = embedding_out(input_c)

emb_gru = GRU(dim, return_sequences=True)(emb_target)

num_negatives = tf.shape(input_c)[-1]


def make_logits(tensors):
    emb_gru, emb_context, emb_negatives = tensors
    true_logits = tf.reduce_sum(tf.multiply(emb_gru, emb_context), axis=2)
    true_logits = tf.expand_dims(true_logits, -1)
    sampled_logits = tf.squeeze(
        tf.matmul(emb_negatives, tf.expand_dims(emb_gru, axis=2),
                  transpose_b=True), axis=3)
    true_logits = true_logits*0
    sampled_logits = sampled_logits*0

    logits = K.concatenate([true_logits, sampled_logits], axis=-1)
    return logits


logits = Lambda(make_logits)([emb_gru, emb_context, emb_negatives])

mean = tf.reduce_mean(logits)
tf.summary.scalar('mean_logits', mean)

model = keras.models.Model(inputs=[input_a, input_b, input_c], outputs=[logits])

我特别想看看mean_logits 标量在每批之后的演变。

我这样创建和编译模型:

model = build_model(dim, vocab_size)
model.compile(loss='binary_crossentropy', optimizer='sgd')
callbacks = [
        keras.callbacks.TensorBoard(logdir, histogram_freq=1)
]

我对模型使用 tf Dataset API:

iterator = dataset.make_initializable_iterator()

with tf.Session() as sess:

        sess.run(iterator.initializer)
        sess.run(tf.tables_initializer())
        model.fit(iterator, steps_per_epoch=100, 
                  callbacks=callbacks,
                  validation_data=iterator,
                  validation_steps=1
                 )

但是,我在 tensorboard 中没有得到任何 mean_logits 图表,而且它不在图表中。

如何在每批之后在张量板中跟踪 mean_logits 标量?

我使用 tf 1.12 和 keras 2.1。

【问题讨论】:

    标签: python tensorflow keras tensorboard


    【解决方案1】:

    我也遇到了同样的问题。似乎 Keras TensorBoard 回调不会自动写入所有现有摘要,而只会写入那些 registered as metrics(并出现在 logs 字典中)。更新logs 对象是一个不错的技巧,因为它允许在其他回调中使用这些值,请参阅Early stopping and learning rate schedule based on custom metric in Keras。我可以看到几种可能性:

    1.使用 Lambda 回调

    类似这样的:

    eval_callback = LambdaCallback(
        on_epoch_end=lambda epoch, logs: logs.update(
            {'mean_logits': K.eval(mean)}
        ))
    

    2。自定义 TensorBoard 回调

    您还可以对回调进行子类化并定义自己的逻辑。例如,我的学习率监控解决方法:

    class Tensorboard(Callback):                                                                                                                                                                                                                                          
        def __init__(self,                                                                                                                                                                                                                                                
                     log_dir='./log',                                                                                                                                                                                                                                     
                     write_graph=True):                                                                                                                                                                                                                                   
            self.write_graph = write_graph                                                                                                                                                                                                                                
            self.log_dir = log_dir                                                                                                                                                                                                                                        
    
        def set_model(self, model):                                                                                                                                                                                                                                       
            self.model = model                                                                                                                                                                                                                                            
            self.sess = K.get_session()                                                                                                                                                                                                                                   
            if self.write_graph:                                                                                                                                                                                                                                          
                self.writer = tf.summary.FileWriter(self.log_dir, self.sess.graph)                                                                                                                                                                                        
            else:                                                                                                                                                                                                                                                         
                self.writer = tf.summary.FileWriter(self.log_dir)                                                                                                                                                                                                         
    
        def on_epoch_end(self, epoch, logs={}):                                                                                                                                                                                                                           
            logs.update({'learning_rate': float(K.get_value(self.model.optimizer.lr))})                                                                                                                                                                                   
            self._write_logs(logs, epoch)                                                                                                                                                                                                                                 
    
        def _write_logs(self, logs, index):                                                                                                                                                                                                                               
            for name, value in logs.items():                                                                                                                                                                                                                              
                if name in ['batch', 'size']:                                                                                                                                                                                                                             
                    continue                                                                                                                                                                                                                                              
                summary = tf.Summary()                                                                                                                                                                                                                                    
                summary_value = summary.value.add()                                                                                                                                                                                                                       
                if isinstance(value, np.ndarray):                                                                                                                                                                                                                         
                    summary_value.simple_value = value.item()                                                                                                                                                                                                             
                else:                                                                                                                                                                                                                                                     
                    summary_value.simple_value = value                                                                                                                                                                                                                    
                summary_value.tag = name                                                                                                                                                                                                                                  
                self.writer.add_summary(summary, index)                                                                                                                                                                                                                   
    
            self.writer.flush()                                                                                                                                                                                                                                           
    
        def on_train_end(self, _):                                                                                                                                                                                                                                        
            self.writer.close() 
    

    在这里,我只是将“learning_rate”明确添加到logs。但这种方式可以更加灵活和强大。

    3.指标技巧

    Here 是另一个有趣的解决方法。您需要做的是将自定义度量函数传递给模型的 compile() 调用,该调用返回聚合的汇总张量。这个想法是让 Keras 将您的汇总汇总操作传递给每个 session.run 调用并将其结果作为指标返回:

    x_entropy_t = K.sum(p_t * K.log(K.epsilon() + p_t), axis=-1, keepdims=True)
    full_policy_loss_t = -res_t + X_ENTROPY_BETA * x_entropy_t
    tf.summary.scalar("loss_entropy", K.sum(x_entropy_t))
    tf.summary.scalar("loss_policy", K.sum(-res_t))
    tf.summary.scalar("loss_full", K.sum(full_policy_loss_t))
    
    summary_writer = tf.summary.FileWriter("logs/" + args.name)
    
    def summary(y_true, y_pred):
        return tf.summary.merge_all()
    
    value_policy_model.compile(optimizer=Adagrad(), loss=loss_dict, metrics=[summary])
    l = value_policy_model.train_on_batch(x_batch, y_batch)
    l_dict = dict(zip(value_policy_model.metrics_names, l))
    
    summary_writer.add_summary(l_dict['value_summary'], global_step=iter_idx)
    summary_writer.flush()
    

    【讨论】:

    • 感谢您提供了许多解决方案,但令人惊讶的是,它们都不起作用。特别是,第一个需要输入数据以在mean 内提供。由于我使用 Dataset API 并且没有明确的批次,我不能这样做 K.eval
    • 第二个我认为有类似的问题,除了不清楚如何检索mean变量。
    • 我之前看到并尝试过的第三个,但问题是它需要iter_idx,这基本上意味着它不是使用.fit方法,而是使用.train_on_batch方法并独立提供批次在 for 循环中。而且我不想用train_on_batch 更改fit 方法,因为它的参数更少。所以我猜如果您使用 Dataset API,则不支持自定义 TB。
    • iter_idx 只是批号的整数值。我会尝试找到一些解决方法。例如,将summary_writer.add_summary 放在on_batch_end 的自定义回调中。然后你得到了 batch 参数的值。我无法运行您的代码,因此很难提供这样的有效解决方案。
    猜你喜欢
    • 2018-02-18
    • 1970-01-01
    • 1970-01-01
    • 2019-09-24
    • 1970-01-01
    • 2018-08-11
    • 1970-01-01
    • 1970-01-01
    • 1970-01-01
    相关资源
    最近更新 更多