1. 程式人生 > 程式設計 >Pytorch 實現計算分類器準確率(總分類及子分類)

Pytorch 實現計算分類器準確率(總分類及子分類)

分類器平均準確率計算:

correct = torch.zeros(1).squeeze().cuda()
total = torch.zeros(1).squeeze().cuda()
for i,(images,labels) in enumerate(train_loader):
      images = Variable(images.cuda())
      labels = Variable(labels.cuda())

      output = model(images)

      prediction = torch.argmax(output,1)
      correct += (prediction == labels).sum().float()
      total += len(labels)
acc_str = 'Accuracy: %f'%((correct/total).cpu().detach().data.numpy())

分類器各個子類準確率計算:

correct = list(0. for i in range(args.class_num))
total = list(0. for i in range(args.class_num))
for i,1)
      res = prediction == labels
      for label_idx in range(len(labels)):
        label_single = label[label_idx]
        correct[label_single] += res[label_idx].item()
        total[label_single] += 1
 acc_str = 'Accuracy: %f'%(sum(correct)/sum(total))
 for acc_idx in range(len(train_class_correct)):
      try:
        acc = correct[acc_idx]/total[acc_idx]
      except:
        acc = 0
      finally:
        acc_str += '\tclassID:%d\tacc:%f\t'%(acc_idx+1,acc)

以上這篇Pytorch 實現計算分類器準確率(總分類及子分類)就是小編分享給大家的全部內容了,希望能給大家一個參考,也希望大家多多支援我們。