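【Posted】: 2021-09-23 18:15:41
【Question】:
I am using a BERT model for sentiment analysis on a Steam reviews dataset with 2 labels: positive and negative. I fine-tuned the model with 2 additional linear layers on top; the code is below.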
import torch
import torch.nn as nn
from transformers import BertForSequenceClassification

bert = BertForSequenceClassification.from_pretrained("bert-base-uncased",
                                                      num_labels = len(label_dict),
                                                      output_attentions = False,
                                                      output_hidden_states = False)
class bertModel(nn.Module):
    def __init__(self, bert):
        super(bertModel, self).__init__()
        self.bert = bert
        self.dropout1 = nn.Dropout(0.1)
        self.relu = nn.ReLU()
        self.fc1 = nn.Linear(768, 512)
        self.fc2 = nn.Linear(512, 2)
        self.softmax = nn.LogSoftmax(dim = 1)

    def forward(self, **inputs):
        _, x = self.bert(**inputs)
        x = self.fc1(x)
        x = self.relu(x)
        x = self.dropout1(x)
        x = self.fc2(x)
        x = self.softmax(x)
        return x
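For reference, the head defined above expects a 768-dimensional vector per review at `fc1`. The following is a minimal sketch of the shapes flowing through that head, under the assumption that a random tensor stands in for a BERT sentence vector of size 768 (the name `pooled` and the batch size of 4 are illustrative, not from the original code):

```python
import torch
import torch.nn as nn

# Stand-in for the head in bertModel: fc1 -> ReLU -> Dropout -> fc2 -> LogSoftmax.
head = nn.Sequential(
    nn.Linear(768, 512),
    nn.ReLU(),
    nn.Dropout(0.1),
    nn.Linear(512, 2),
    nn.LogSoftmax(dim=1),
)
head.eval()  # disable dropout so the forward pass is deterministic

# Assumption: a random tensor replaces the real 768-dimensional BERT output.
batch_size = 4
pooled = torch.randn(batch_size, 768)

with torch.no_grad():
    log_probs = head(pooled)

print(log_probs.shape)             # (batch_size, 2): one log-probability per label
print(log_probs.exp().sum(dim=1))  # each row sums to 1, up to rounding
```

Any input whose last dimension is not 768 will be rejected by `fc1`, so it helps to check the shape of whatever `self.bert` returns before passing it on.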
Here is my training function:
def model_train(model, device, criterion, scheduler, optimizer, n_epochs):
    train_loss = []
    model.train()
    for epoch in range(1, epochs+1):
        total_train_loss, training_loss = 0,0
        for idx, batch in enumerate(dataloader_train):
            model.zero_grad()
            data = tuple(b.to(device) for b in batch)
            inputs = {'input_ids': data[0],'attention_mask': data[1],'labels':data[2]}
            outputs = model(**inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            #update the weights
            optimizer.step()
            scheduler.step()
            training_loss += loss.item()
            total_train_loss += training_loss
            if idx % 25 == 0:
                print('Epoch: {}, Batch: {}, Training Loss: {}'.format(epoch, idx, training_loss/10))
                training_loss = 0
        #avg training loss
        avg_train_loss = total_train_loss/len(dataloader_train)
        #validation data loss
        avg_pred_loss = model_evaluate(dataloader_val)
        #print for every end of epoch
        print('End of Epoch {}, Avg. Training Loss: {}, Avg. validation Loss: {} \n'.format(epoch, avg_train_loss, avg_pred_loss))
I am running this code on Google Colab. When I run the train function, I get the following error. I have tried batch sizes of 32, 256, and 512.
RuntimeError: CUDA error: CUBLAS_STATUS_INVALID_VALUE when calling `cublasSgemm( handle, opa, opb, m, n, k, &alpha, a, lda, b, ldb, &beta, c, ldc)`
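Can anyone help me solve this? Thanks.
Code update: I tried running the code on the CPU, and there the error is a matrix shape mismatch. The input shape and the shape after self.bert are printed in the image. Because the first linear layer (fc1) never executes, the shapes after it are not printed.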
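The shape mismatch seen on the CPU can be reproduced without loading BERT. This is a minimal sketch, assuming that `self.bert(**inputs)` returns a `(loss, logits)` tuple, which is what `BertForSequenceClassification` produces when `labels` are passed and tuple outputs are in use. Under that assumption `x` has shape `(batch_size, 2)` instead of `(batch_size, 768)`. The tensor `logits` below is a random stand-in, not the real model output:

```python
import torch
import torch.nn as nn

fc1 = nn.Linear(768, 512)

# Assumption: stand-in for the second element returned by self.bert(**inputs),
# i.e. logits of shape (batch_size, num_labels) with num_labels = 2.
batch_size = 4
logits = torch.randn(batch_size, 2)

error_message = None
try:
    fc1(logits)  # 2 input features, but fc1 expects 768
except RuntimeError as exc:
    error_message = str(exc)

print(logits.shape)   # shape of the tensor that reaches fc1
print(error_message)  # the CPU error names the two incompatible shapes
```

If this is indeed the cause, one possible direction (not confirmed by the post) would be to wrap `BertModel` rather than `BertForSequenceClassification`, so that a 768-dimensional pooled output reaches `fc1`, or to drop the extra layers and use the logits directly.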
【Comments】:
Tags: python pytorch sentiment-analysis bert-language-model