【问题标题】:Metrics mismatch between BertForSequenceClassification Class and my custom Bert ClassificationBertForSequenceClassification 类和我的自定义 Bert 分类之间的指标不匹配
【发布时间】:2021-03-22 08:43:36
【问题描述】:

我实现了我的自定义 Bert 二进制分类模型类,方法是在 Bert 模型(附在下面)之上添加一个分类器层。但是,当我使用官方的 BertForSequenceClassification 模型进行训练时,准确度/指标有很大不同,这让我怀疑我是否在课堂上遗漏了一些东西。

我有几个疑问:

在加载官方BertForSequenceClassificationfrom_pretrained 时,分类器的权重是从预训练模型中初始化的还是随机初始化的?因为在我的自定义类中它们是随机初始化的。

class MyCustomBertClassification(nn.Module):
    def __init__(self, encoder='bert-base-uncased',
                        num_labels,
                        hidden_dropout_prob):

    super(MyCustomBertClassification, self).__init__()
    self.config  = AutoConfig.from_pretrained(encoder)
    self.encoder = AutoModel.from_config(self.config)
    self.dropout = nn.Dropout(hidden_dropout_prob)
    self.classifier = nn.Linear(self.config.hidden_size, num_labels)

def forward(self, input_sent):
    outputs = self.encoder(input_ids=input_sent['input_ids'],
                          attention_mask=input_sent['attention_mask'],
                          token_type_ids=input_sent['token_type_ids'],
                          return_dict=True)
    
    pooled_output = self.dropout(outputs[1])
    # for both tasks
    logits = self.classifier(pooled_output)

    return logits

【问题讨论】:

    标签: pytorch huggingface-transformers


    【解决方案1】:

    当您使用from_pretrained 方法时,每个模型都会通过警告消息告诉您哪些层是随机初始化的:

    from transformers import BertForSequenceClassification
    
    b = BertForSequenceClassification.from_pretrained('bert-base-uncased')
    

    输出:

    Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['classifier.weight', 'classifier.bias']
    

    您的实现与BertForSequenceClassification 之间的区别在于您根本不使用任何预训练的权重。 from_config 方法不会从 state_dict 加载预训练的权重:

    import torch
    from transformers import AutoModelForSequenceClassification, AutoConfig
    
    b2 = AutoModelForSequenceClassification.from_config(AutoConfig.from_pretrained('bert-base-uncased'))
    b3 = AutoModelForSequenceClassification.from_pretrained('bert-base-uncased')
    
    print("Does from_config provides pretrained weights: {}".format(torch.equal(b.bert.embeddings.word_embeddings.weight, b2.base_model.embeddings.word_embeddings.weight)))
    print("Does from_pretrained provides pretrained weights: {}".format(torch.equal(b.bert.embeddings.word_embeddings.weight, b3.base_model.embeddings.word_embeddings.weight)))
    

    输出:

    Does from_config provides pretrained weights: False
    Does from_pretrained provides pretrained weights: True
    

    因此,您可能希望将班级更改为:

    class MyCustomBertClassification(nn.Module):
        def __init__(self, encoder='bert-base-uncased',
                           num_labels=2,
                           hidden_dropout_prob=0.1):
    
           super(MyCustomBertClassification, self).__init__()
           self.config  = AutoConfig.from_pretrained(encoder)
           self.encoder = AutoModel.from_pretrained(encoder)
           self.dropout = nn.Dropout(hidden_dropout_prob)
           self.classifier = nn.Linear(self.config.hidden_size, num_labels)
    
        def forward(self, input_sent):
           outputs = self.encoder(input_ids=input_sent['input_ids'],
                             attention_mask=input_sent['attention_mask'],
                             token_type_ids=input_sent['token_type_ids'],
                             return_dict=True)
        
           pooled_output = self.dropout(outputs[1])
           # for both tasks
           logits = self.classifier(pooled_output)
    
           return logits
    
    myB = MyCustomBertClassification()
    
    print(torch.equal(b.bert.embeddings.word_embeddings.weight, myB.encoder.embeddings.word_embeddings.weight))
    

    输出:

    True
    

    【讨论】:

    • 是的,这确实是主要问题之一,使用 from_config 而不是 from_pretrained。但是,我仍然看到官方课程存在一些性能滞后,这让我想知道在他们的分类器网络初始化中是否发生了一些特别的事情?想法?
    • 您可以查看代码here。但我不确定这是否解释了差异。 @pseudo_teetotaler
    猜你喜欢
    • 2021-11-10
    • 1970-01-01
    • 2021-11-04
    • 1970-01-01
    • 1970-01-01
    • 1970-01-01
    • 1970-01-01
    • 1970-01-01
    • 2015-06-03
    相关资源
    最近更新 更多