【问题标题】:tqdm not updating new set_postfix after last iterationtqdm 在最后一次迭代后没有更新新的 set_postfix
【发布时间】:2021-07-03 21:29:21
【问题描述】:

我想为 Pytorch 训练创建一个类似于 tensorflow.keras 的 tqdm 进度条。 这是我的要求:

  1. 对于每个训练步骤,都会显示进度和训练损失
  2. 在最后一次迭代中,它将提供验证损失的附加信息

我正在关注本教程https://towardsdatascience.com/training-models-with-a-progress-a-bar-2b664de3e13e,并且我设法满足了第一个要求。

唯一缺少的功能是在每次训练后给出验证损失。
这是我的代码:

for epoch in range(EPOCH):
    with tqdm(train_dataloader, unit=" batch") as tepoch:
        train_loss = 0
        val_loss = 0
        
        # Training part
        for idx,batch in enumerate(tepoch) :
            tepoch.set_description(f"Epoch {epoch}")
            <do training stuff>
            train_loss += loss.item()
            tepoch.set_postfix({'Train loss': loss.item()})
         train_loss /= (idx+1)

         # Evaluation part
         with torch.no_grad():
            for idx,batch in enumerate(val_dataloader) :
            <do inference stuff>
            val_loss += loss.item()
         val_loss /= (idx+1)

    tepoch.set_postfix({'Train loss': train_loss,"Val loss":val_loss})

这段代码给出这个:

Epoch 0: 100%|██████████| 188/188 [01:22<00:00,  2.28 batch/s, Train loss=0.511]
Epoch 1: 100%|██████████| 188/188 [01:22<00:00,  2.28 batch/s, Train loss=0.298]

但我想要的是:

Epoch 0: 100%|██████████| 188/188 [01:22<00:00,  2.28 batch/s, Train loss=0.511, Val loss={number}]
Epoch 1: 100%|██████████| 188/188 [01:22<00:00,  2.28 batch/s, Train loss=0.298, Val loss={number}]

我已经看到了这个 SO tqdm update after last iteration,但对我来说这似乎不可行,因为验证损失是在所有训练完成后计算的。

【问题讨论】:

    标签: python tqdm


    【解决方案1】:

    工作示例

    import random
    import time
    EPOCH = 100
    BATCH_SIZE = 10
    for epoch in range(EPOCH):
      with tqdm(total=BATCH_SIZE, unit=" batch") as tepoch:
            tepoch.set_description(f"Epoch {epoch+1}")
            train_loss = 0
            val_loss = 0
            
            # Training part
            for idx,batch in enumerate(range(BATCH_SIZE)) :
                tepoch.update(1)
                # do training stuff
                time.sleep(0.5)
                loss = random.choice(range(10))
                train_loss += loss
                tepoch.set_postfix({'Batch': idx+1, 'Train loss (in progress)': loss})
    
            train_loss /= (idx+1)
    
            # Evaluation part
            time.sleep(0.5)
            val_loss += random.choice(range(10))
    
            val_loss /= (idx+1)
    
            tepoch.set_postfix({'Train loss (final)': train_loss, 'Val loss': val_loss})
            tepoch.close()
    

    输出

    Epoch 1: 100% 10/10 [00:11<00:00, 1.18s/ batch, Train loss (final)=4.4, Val loss=0.5]
    Epoch 2: 100% 10/10 [00:06<00:00, 1.62 batch/s, Train loss (final)=4.7, Val loss=0.3]
    Epoch 3: 80% 8/10 [00:03<00:00, 2.07 batch/s, Batch=7, Train loss (in progress)=9]
    

    【讨论】:

    • 谢谢,使用 OrderedDict 代替字典有什么好处?我正在使用常用的字典,您的解决方案仍然有效
    • 键不保证与声明的顺序相同
    • 只是偶然发现这个 SO,似乎从 Python 3.7 开始,字典的排序就像 OrderedDict stackoverflow.com/questions/39980323/…
    • 不错。我不知道。更新了不再使用它的答案
    猜你喜欢
    • 2021-06-27
    • 2023-04-03
    • 2015-08-26
    • 1970-01-01
    • 1970-01-01
    • 1970-01-01
    • 2023-02-13
    • 2023-04-03
    • 1970-01-01
    相关资源
    最近更新 更多