服务器之家

服务器之家 > 正文

Pytorch 实现计算分类器准确率(总分类及子分类)

时间:2020-04-15 09:52     来源/作者:疯狂的小猪oO

分类器平均准确率计算:

?
1
2
3
4
5
6
7
8
9
10
11
12
correct = torch.zeros(1).squeeze().cuda()
total = torch.zeros(1).squeeze().cuda()
for i, (images, labels) in enumerate(train_loader):
      images = Variable(images.cuda())
      labels = Variable(labels.cuda())
 
      output = model(images)
 
      prediction = torch.argmax(output, 1)
      correct += (prediction == labels).sum().float()
      total += len(labels)
acc_str = 'Accuracy: %f'%((correct/total).cpu().detach().data.numpy())

分类器各个子类准确率计算:

?
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
correct = list(0. for i in range(args.class_num))
total = list(0. for i in range(args.class_num))
for i, (images, labels) in enumerate(train_loader):
      images = Variable(images.cuda())
      labels = Variable(labels.cuda())
 
      output = model(images)
 
      prediction = torch.argmax(output, 1)
      res = prediction == labels
      for label_idx in range(len(labels)):
        label_single = label[label_idx]
        correct[label_single] += res[label_idx].item()
        total[label_single] += 1
 acc_str = 'Accuracy: %f'%(sum(correct)/sum(total))
 for acc_idx in range(len(train_class_correct)):
      try:
        acc = correct[acc_idx]/total[acc_idx]
      except:
        acc = 0
      finally:
        acc_str += '\tclassID:%d\tacc:%f\t'%(acc_idx+1, acc)

以上这篇Pytorch 实现计算分类器准确率(总分类及子分类)就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持服务器之家。

原文链接:https://blog.csdn.net/u014657795/article/details/86419197

相关文章

热门资讯

2022年最旺的微信头像大全 微信头像2022年最新版图片
2022年最旺的微信头像大全 微信头像2022年最新版图片 2022-01-10
蜘蛛侠3英雄无归3正片免费播放 蜘蛛侠3在线观看免费高清完整
蜘蛛侠3英雄无归3正片免费播放 蜘蛛侠3在线观看免费高清完整 2021-08-24
背刺什么意思 网络词语背刺是什么梗
背刺什么意思 网络词语背刺是什么梗 2020-05-22
yue是什么意思 网络流行语yue了是什么梗
yue是什么意思 网络流行语yue了是什么梗 2020-10-11
暖暖日本高清免费中文 暖暖在线观看免费完整版韩国
暖暖日本高清免费中文 暖暖在线观看免费完整版韩国 2021-05-08
返回顶部