我尝试使用torchmetrics
库的MetricTracker
记录不平衡分类数据集的几个指标。我发现Precision
Recall
和F1Score
的结果总是等于Accuracy
,尽管它不应该。
如何产生这种行为的最小示例如下所示:
import torch
import torchmetrics
from torchmetrics import MetricTracker, MetricCollection
from torchmetrics import Accuracy, F1Score, Precision, Recall, CohenKappa
num_classes = 3
list_of_metrics = [Accuracy(task="multiclass", num_classes=num_classes),
F1Score(task="multiclass", num_classes=num_classes),
Precision(task="multiclass",num_classes=num_classes),
Recall(task="multiclass",num_classes=num_classes),
CohenKappa(task="multiclass",num_classes=num_classes)
]
maximize_list=[True,True,True,True,True]
metric_coll = MetricCollection(list_of_metrics)
tracker = MetricTracker(metric_coll, maximize=maximize_list)
pred = torch.Tensor([[0,.1,.5], # 2
[0,.1,.5], # 2
[0,.1,.5], # 2
[0,.1,.5], # 2
[0,.1,.5], # 2
[0.9,.1,.5]]) # 0
label = torch.Tensor([2,2,2,2,2,1])
tracker.increment()
tracker.update(pred, label)
for key, val in tracker.compute_all().items():
print(key,val)
字符串
输出量:
MulticlassAccuracy tensor([0.8333])
MulticlassF1Score tensor([0.8333])
MulticlassPrecision tensor([0.8333])
MulticlassRecall tensor([0.8333])
MulticlassCohenKappa tensor([0.4545])
型
有人知道这里的问题是什么以及如何解决它吗?
我使用torchmetrics
库的0.11.1
版本。
1条答案
按热度按时间643ylb081#
显然有一个documentation bug。
解决办法是总是明确地说明你想要什么样的“平均值”。