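【Posted】: 2020-09-29 19:51:26
【Question】:
I moved both my model and my data to the same device, but I always get the following error: RuntimeError: Input type (torch.FloatTensor) and weight type (torch.cuda.FloatTensor) should be the same. Here is the training code: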
total_epoch = 1
best_epoch = 0
training_losses = []
val_losses = []

for epoch in range(total_epoch):
    # Training pass
    epoch_train_loss = 0
    for X, y in train_loader:
        X, y = X.cuda(), y.cuda()
        optimizer.zero_grad()
        result = model(X)
        loss = criterion(result, y)
        epoch_train_loss += loss.item()
        loss.backward()
        optimizer.step()
    training_losses.append(epoch_train_loss)

    # Validation pass (no gradients needed)
    epoch_val_loss = 0
    correct = 0
    total = 0
    with torch.no_grad():
        for X, y in val_loader:
            X, y = X.cuda(), y.cuda()
            result = model(X)
            loss = criterion(result, y)
            epoch_val_loss += loss.item()
            _, maximum = torch.max(result.data, 1)
            total += y.size(0)
            correct += (maximum == y).sum().item()
    val_losses.append(epoch_val_loss)
    accuracy = correct / total
    print("EPOCH:", epoch, ", Training Loss:", epoch_train_loss, ", Validation Loss:", epoch_val_loss, ", Accuracy: ", accuracy)

    # Save a checkpoint whenever the validation loss is the best so far
    if min(val_losses) == val_losses[-1]:
        best_epoch = epoch
        checkpoint = {'model': model,
                      'state_dict': model.state_dict(),
                      'optimizer': optimizer.state_dict()}
        torch.save(checkpoint, models_dir + '{}.pth'.format(epoch))
        print("Model saved")
The error appears when I run the following detection code, which reads frames from cv2.VideoCapture(0):
import cv2
import torch
import cvlib as cv
from PIL import Image

cap = cv2.VideoCapture(0)
font_scale = 1
thickness = 2
red = (0, 0, 255)
green = (0, 255, 0)
blue = (255, 0, 0)
font = cv2.FONT_HERSHEY_SIMPLEX
face_cascade = cv2.CascadeClassifier(cv2.data.haarcascades + 'haarcascade_frontalface_default.xml')

while cap.isOpened():
    ret, frame = cap.read()
    if ret:
        gray = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)
        faces = face_cascade.detectMultiScale(gray, 1.4, 4)
        for (x, y, w, h) in faces:
            cv2.rectangle(frame, (x, y), (x+w, y+h), blue, 2)
            croped_img = frame[y:y+h, x:x+w]
            pil_image = Image.fromarray(croped_img, mode="RGB")
            pil_image = train_transforms(pil_image)
            image = pil_image.unsqueeze(0)
            # This call raises the RuntimeError shown below
            result = loaded_model(image)
            _, maximum = torch.max(result.data, 1)
            prediction = maximum.item()
            if prediction == 0:
                cv2.putText(frame, "Masked", (x, y - 10), font, font_scale, green, thickness)
                cv2.rectangle(frame, (x, y), (x+w, y+h), green, 2)
            elif prediction == 1:
                cv2.putText(frame, "No Mask", (x, y - 10), font, font_scale, red, thickness)
                cv2.rectangle(frame, (x, y), (x+w, y+h), red, 2)
        cv2.imshow('frame', frame)
        if (cv2.waitKey(1) & 0xFF) == ord('q'):
            break
    else:
        break

cap.release()
cv2.destroyAllWindows()
loaded_model is created as follows:
def load_checkpoint(filepath):
    checkpoint = torch.load(filepath)
    model = checkpoint['model']
    model.load_state_dict(checkpoint['state_dict'])
    for parameter in model.parameters():
        parameter.requires_grad = False
    return model.eval()

filepath = models_dir + str(best_epoch) + ".pth"
loaded_model = load_checkpoint(filepath)
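By default, torch.load restores tensors to the device they were saved from, so a checkpoint written from a GPU model is likely to come back with its weights on the GPU. Below is a minimal sketch of one way to make the target device explicit, assuming only the state_dict is stored. The `build_model` function, the layer sizes, and the file name `example.pth` are hypothetical stand-ins, since the post does not show the real architecture.

```python
import os
import tempfile

import torch
import torch.nn as nn


def build_model():
    # Hypothetical stand-in architecture; the post does not show the real one.
    return nn.Sequential(nn.Flatten(), nn.Linear(12, 2))


def save_checkpoint(model, path):
    # Store only the weights rather than pickling the whole model object.
    torch.save({"state_dict": model.state_dict()}, path)


def load_checkpoint(path, device):
    # map_location remaps the saved tensors onto the requested device.
    checkpoint = torch.load(path, map_location=device)
    model = build_model()
    model.load_state_dict(checkpoint["state_dict"])
    for parameter in model.parameters():
        parameter.requires_grad = False
    return model.to(device).eval()


# Use the GPU when one is available, otherwise the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
path = os.path.join(tempfile.mkdtemp(), "example.pth")
save_checkpoint(build_model(), path)
loaded = load_checkpoint(path, device)
```

With the device chosen in one place, the same `device` value can be reused for every input tensor passed to the model.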
Error:
---------------------------------------------------------------------------
RuntimeError Traceback (most recent call last)
<ipython-input-29-b3a630684f44> in <module>()
43
44
---> 45 result = loaded_model(image)
46 _, maximum = torch.max(result.data, 1)
47 prediction = maximum.item()
5 frames
/usr/local/lib/python3.6/dist-packages/torch/nn/modules/conv.py in _conv_forward(self, input, weight)
344 _pair(0), self.dilation, self.groups)
345 return F.conv2d(input, weight, self.bias, self.stride,
--> 346 self.padding, self.dilation, self.groups)
347
348 def forward(self, input):
RuntimeError: Input type (torch.FloatTensor) and weight type (torch.cuda.FloatTensor) should be the same
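The message indicates that the input tensor is on the CPU (torch.FloatTensor) while the model's weights are on the GPU (torch.cuda.FloatTensor). In the training loop X and y are moved with .cuda(), but in the detection loop image is passed to loaded_model without being moved. Below is a minimal sketch of one way to keep the two aligned. `tiny_model` and the helper `to_model_device` are hypothetical stand-ins, not part of the original code, and the random tensor stands in for the transformed face crop.

```python
import torch
import torch.nn as nn

# Use the GPU when one is available, otherwise the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Hypothetical stand-in for loaded_model; the real network is not shown in the post.
tiny_model = nn.Sequential(
    nn.Conv2d(3, 4, kernel_size=3, padding=1),
    nn.AdaptiveAvgPool2d(1),
    nn.Flatten(),
    nn.Linear(4, 2),
).to(device).eval()


def to_model_device(tensor, model):
    """Move a tensor to the device that holds the model's parameters."""
    return tensor.to(next(model.parameters()).device)


# Stand-in for the transformed face crop: one 3x64x64 image created on the CPU.
image = torch.rand(1, 3, 64, 64)
image = to_model_device(image, tiny_model)

with torch.no_grad():
    result = tiny_model(image)

# Index of the highest-scoring class (0 or 1 for this two-class stand-in).
prediction = result.argmax(dim=1).item()
```

In the detection loop above, the equivalent change would probably be to move image right after unsqueeze(0), for example with image.cuda() or image.to(device), before calling loaded_model.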
I hope you can help. Thanks!
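【Comments】:
- Please don't post screenshots of errors. Paste them as text in a formatted code block, just like code, and include the full error. You have cut off part of the stack trace, which would show where the error originates in your code. So, does the error occur on the line result = loaded_model(image)?
- I have edited the question and added the code that creates loaded_model. Please take another look.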
Tags: python machine-learning deep-learning pytorch gpu