【问题标题】:pytorch cnn model stop at loss.backward() without any prompt?pytorch cnn 模型停在 loss.backward() 没有任何提示?
【发布时间】:2020-04-02 08:52:29
【问题描述】:

我的目标是做一个五类文本分类

我正在使用 cnnbase 模型运行 bert 微调,但我的项目在 loss.backward() 停止,cmd 没有任何提示。

我的程序在rnn baselstmrcnn中运行成功。

但是当我运行一些cnnbase 模型时,会出现一个奇怪的错误。

我的cnn型号代码:

import torch
import torch.nn as nn
import torch.nn.functional as F
# from ..Models.Conv import Conv1d
from transformers.modeling_bert import BertPreTrainedModel, BertModel
n_filters = 200
filter_sizes = [2,3,4]
class BertCNN(BertPreTrainedModel):
    def __init__(self, config):
        super(BertPreTrainedModel, self).__init__(config)
        self.num_filters = n_filters
        self.filter_sizes = filter_sizes
        self.bert = BertModel(config)
        for param in self.bert.parameters():
            param.requires_grad = True
        self.convs = nn.ModuleList(
            [nn.Conv2d(1, self.num_filters, (k, config.hidden_size))
                for k in self.filter_sizes])
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.fc_cnn = nn.Linear(self.num_filters *
                                len(self.filter_sizes), config.num_labels)

    def conv_and_pool(self, x, conv):
        x = F.relu(conv(x)).squeeze(3)
        x = F.max_pool1d(x, x.size(2)).squeeze(2)
        return x

    def forward(self, input_ids,
                attention_mask=None, token_type_ids=None, head_mask=None):
        outputs = self.bert(input_ids,
                            attention_mask=attention_mask,
                            token_type_ids=token_type_ids,
                            head_mask=head_mask)
        encoder_out, text_cls = outputs
        out = encoder_out.unsqueeze(1)
        out = torch.cat([self.conv_and_pool(out, conv)
                         for conv in self.convs], 1)
        out = self.dropout(out)
        out = self.fc_cnn(out)
        return out

我的火车代码:

        for step, batch in enumerate(data):
            self.model.train()
            batch = tuple(t.to(self.device) for t in batch)
            input_ids, input_mask, segment_ids, label_ids = batch
            print("input_ids, input_mask, segment_ids, label_ids SIZE: \n")   
            print(input_ids.size(), input_mask.size(),segment_ids.size(), label_ids.size()) 
            # torch.Size([2, 80]) torch.Size([2, 80]) torch.Size([2, 80]) torch.Size([2])
            logits = self.model(input_ids, segment_ids, input_mask)
            print("logits and label ids size: ",logits.size(), label_ids.size())
            # torch.Size([2, 5]) torch.Size([2])
            loss = self.criterion(output=logits, target=label_ids)
            if len(self.n_gpu) >= 2:
                loss = loss.mean()
            if self.gradient_accumulation_steps > 1:
                loss = loss / self.gradient_accumulation_steps
            if self.fp16:
                with amp.scale_loss(loss, self.optimizer) as scaled_loss:
                    scaled_loss.backward()
                clip_grad_norm_(amp.master_params(self.optimizer), self.grad_clip)
            else:
                loss.backward() # I debug find that the program stop at this line without any error prompt

将批大小更改为 1 该错误仍然发生

第一步登录:

logits 张量([[ 0.8831, -0.0368, -0.2206, -2.3484, -1.3595]], device='cuda:1', grad_fn=)

step1 loss:

张量(1.5489, device='cuda:1', grad_fn=NllLossBackward>)

但是为什么 loss.backward() 不能呢?

【问题讨论】:

  • 请澄清并编辑您的问题以添加有关错误或错误的信息。
  • 我的程序停在第一个火车步骤 loss.backward() 。在训练模型时没有任何错误提示。如果有错误提示,我会贴出来,但是……
  • 这个bug只有在使用cnnbase模型时才会出现。
  • 你的程序到底停在哪里?在执行loss.backward() 之前或之后或执行该行之后?你能添加一些打印语句来找到它吗?
  • 感谢您的回复。我的程序在 loss.backward() 之后停止。我尝试在 loss.backward() 之后添加一些打印功能,但没有成功。并且损失值可以按照我上面提到的来计算。我已经调试并确认了。

标签: python nlp pytorch


【解决方案1】:

我尝试在linux平台上运行我的程序,运行成功。

所以很有可能是不同操作系统造成的

上一个操作系统:win 10

【讨论】:

  • 我经历了完全相同的事情 {model: CNN-GRU;以前的操作系统:win10,当前操作系统:linux}。我不明白为什么。
【解决方案2】:

我也遇到了同样的问题。 就我而言,这个问题源于 pytorch 的版本兼容性。当我将我的 pytorch 升级到最新版本(1.5.1 -> 1.8.x)时,问题得到了解决。我认为这种问题来自 pytorch 的 nn.Conv 类......因为我发现我的脚本在删除它们时运行良好。

【讨论】:

    猜你喜欢
    • 2021-12-31
    • 2021-08-23
    • 2020-10-31
    • 2022-10-15
    • 2021-03-28
    • 1970-01-01
    • 2020-11-09
    • 1970-01-01
    • 2021-11-24
    相关资源
    最近更新 更多