抱歉,您的浏览器无法访问本站
本页面需要浏览器支持(启用)JavaScript
了解详情 >

Tensorflow 自带的指标不够用,尤其是处理不平衡数据集时,macro recall 可以指示出不被 classes weight 所影响的平均recall

定义

自定义指标 class 需要继承 tf.keras.metrics.Metric

解释一下其中的 function

  • reset_states 每个epoch结束后清零
  • update_state 在一个epoch中每个step update的方式
  • result 返回的值

我在 class 中维护了一张 confusion_matrix,叫做 total_cm,因此能很方便地计算 macro recall、macro f1 等数值。

通过confusion_matrix可以计算其他很多的数值,可以通过修改process_confusion_matrix 来实现

用法

下面定义了两个自定义指标 MacroRecall 和 MacroF1,使用的时候只要在 model.compile 时加进 metrics 即可。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
class MacroRecall(tf.keras.metrics.Metric):
    """Macro-averaged recall computed from a running confusion matrix.

    The metric accumulates a ``(num_classes, num_classes)`` confusion matrix
    in ``total_cm`` across ``update_state`` calls and derives the unweighted
    mean of the per-class recalls from it in ``result``.

    Args:
        num_classes: Number of classes; fixes the confusion-matrix shape.
        name: Metric name reported by Keras. Defaults to ``'macro_recall'``.
        **kwargs: Forwarded to ``tf.keras.metrics.Metric`` (e.g. ``dtype``).
    """

    def __init__(self, num_classes=2, name='macro_recall', **kwargs):
        # ``name`` is a real parameter so that a caller-supplied name (or
        # ``from_config``) does not collide with a hard-coded keyword.
        super(MacroRecall, self).__init__(name=name, **kwargs)
        self.num_classes = num_classes
        # Keyword arguments only: the positional order of ``add_weight``
        # differs between Keras versions.
        self.total_cm = self.add_weight(
            name="total",
            shape=(num_classes, num_classes),
            initializer="zeros",
        )

    def reset_state(self):
        """Zero every variable of the metric."""
        for var in self.variables:
            var.assign(tf.zeros(shape=var.shape, dtype=var.dtype))

    def reset_states(self):
        """Older spelling of ``reset_state``; kept for older Keras versions."""
        self.reset_state()

    def update_state(self, y_true, y_pred, sample_weight=None):
        """Add the confusion matrix of one batch to the running total.

        Args:
            y_true: Labels; the class index is taken with ``argmax`` on axis 1.
            y_pred: Predictions; the class index is taken with ``argmax`` on
                axis 1.
            sample_weight: Optional weights; assumed to hold one weight per
                sample (flattened to 1-D before use).

        Returns:
            The accumulated confusion-matrix variable.
        """
        self.total_cm.assign_add(
            self.confusion_matrix(y_true, y_pred, sample_weight)
        )
        return self.total_cm

    def result(self):
        """Return the macro recall of everything accumulated so far."""
        return self.process_confusion_matrix()

    def confusion_matrix(self, y_true, y_pred, sample_weight=None):
        """Build the confusion matrix of a single batch.

        Rows index the true class and columns the predicted class. The
        matrix uses the dtype of ``total_cm`` so it can be added to it.
        """
        y_true = tf.argmax(y_true, 1)
        y_pred = tf.argmax(y_pred, 1)
        if sample_weight is not None:
            # Flatten so a (batch, 1) weight column lines up with the labels.
            sample_weight = tf.reshape(sample_weight, [-1])
        return tf.math.confusion_matrix(
            y_true,
            y_pred,
            num_classes=self.num_classes,
            weights=sample_weight,
            dtype=self.total_cm.dtype,
        )

    def process_confusion_matrix(self):
        """Reduce the accumulated confusion matrix to the macro recall."""
        cm = self.total_cm
        diag_part = tf.linalg.diag_part(cm)
        # Row sums are the per-class counts of true samples. divide_no_nan
        # yields 0 for a class with no samples, independent of the dtype.
        recall = tf.math.divide_no_nan(diag_part, tf.reduce_sum(cm, 1))
        return tf.reduce_sum(recall) / self.num_classes

    def get_config(self):
        """Return the config including ``num_classes`` for serialization."""
        config = super(MacroRecall, self).get_config()
        config["num_classes"] = self.num_classes
        return config


class MacroF1(tf.keras.metrics.Metric):
    """Macro-averaged F1 score computed from a running confusion matrix.

    The metric accumulates a ``(num_classes, num_classes)`` confusion matrix
    in ``total_cm`` across ``update_state`` calls and derives the unweighted
    mean of the per-class F1 scores from it in ``result``.

    Args:
        num_classes: Number of classes; fixes the confusion-matrix shape.
        name: Metric name reported by Keras. Defaults to ``'macro_f1'``.
        **kwargs: Forwarded to ``tf.keras.metrics.Metric`` (e.g. ``dtype``).
    """

    def __init__(self, num_classes=2, name='macro_f1', **kwargs):
        # ``name`` is a real parameter so that a caller-supplied name (or
        # ``from_config``) does not collide with a hard-coded keyword.
        super(MacroF1, self).__init__(name=name, **kwargs)
        self.num_classes = num_classes
        # Keyword arguments only: the positional order of ``add_weight``
        # differs between Keras versions.
        self.total_cm = self.add_weight(
            name="total",
            shape=(num_classes, num_classes),
            initializer="zeros",
        )

    def reset_state(self):
        """Zero every variable of the metric."""
        for var in self.variables:
            var.assign(tf.zeros(shape=var.shape, dtype=var.dtype))

    def reset_states(self):
        """Older spelling of ``reset_state``; kept for older Keras versions."""
        self.reset_state()

    def update_state(self, y_true, y_pred, sample_weight=None):
        """Add the confusion matrix of one batch to the running total.

        Args:
            y_true: Labels; the class index is taken with ``argmax`` on axis 1.
            y_pred: Predictions; the class index is taken with ``argmax`` on
                axis 1.
            sample_weight: Optional weights; assumed to hold one weight per
                sample (flattened to 1-D before use).

        Returns:
            The accumulated confusion-matrix variable.
        """
        self.total_cm.assign_add(
            self.confusion_matrix(y_true, y_pred, sample_weight)
        )
        return self.total_cm

    def result(self):
        """Return the macro F1 of everything accumulated so far."""
        return self.process_confusion_matrix()

    def confusion_matrix(self, y_true, y_pred, sample_weight=None):
        """Build the confusion matrix of a single batch.

        Rows index the true class and columns the predicted class. The
        matrix uses the dtype of ``total_cm`` so it can be added to it.
        """
        y_true = tf.argmax(y_true, 1)
        y_pred = tf.argmax(y_pred, 1)
        if sample_weight is not None:
            # Flatten so a (batch, 1) weight column lines up with the labels.
            sample_weight = tf.reshape(sample_weight, [-1])
        return tf.math.confusion_matrix(
            y_true,
            y_pred,
            num_classes=self.num_classes,
            weights=sample_weight,
            dtype=self.total_cm.dtype,
        )

    def process_confusion_matrix(self):
        """Reduce the accumulated confusion matrix to the macro F1 score."""
        cm = self.total_cm
        diag_part = tf.linalg.diag_part(cm)
        # Column sums count predictions per class, row sums count true
        # samples per class. divide_no_nan yields 0 instead of NaN for empty
        # classes, independent of the dtype.
        precision = tf.math.divide_no_nan(diag_part, tf.reduce_sum(cm, 0))
        recall = tf.math.divide_no_nan(diag_part, tf.reduce_sum(cm, 1))
        f1 = tf.math.divide_no_nan(2 * precision * recall, precision + recall)
        return tf.reduce_sum(f1) / self.num_classes

    def get_config(self):
        """Return the config including ``num_classes`` for serialization."""
        config = super(MacroF1, self).get_config()
        config["num_classes"] = self.num_classes
        return config

# Usage: register both custom metrics next to the built-in 'accuracy' metric
# when compiling the model. `model` and `num_classes` are not defined in this
# snippet and must be provided by the surrounding code.
model.compile(metrics=['accuracy',MacroRecall(num_classes),MacroF1(num_classes)])

评论