keras model.fit()是否会在每个纪元后重置指标?如何手动重置指标?

gv8xihay  于 7个月前  发布在  其他
关注(0)|答案(2)|浏览(79)

据我所知,model.fit(epochs=NUM_EPOCHS)不会为每个epoch重置指标。我的指标和model.fit()代码看起来像这样(简化):

import tensorflow as tf
from tensorflow.keras import applications

NUM_CLASSES = 4
INPUT_SHAPE = (256, 256, 3)
MODELS = {
    'DenseNet121': applications.DenseNet121,
    'DenseNet169': applications.DenseNet169
}
REDUCE_LR_PATIENCE = 2
REDUCE_LR_FACTOR = 0.7
EARLY_STOPPING_PATIENCE = 4

for modelName, model in MODELS.items():

    loadedModel = model(include_top=False, weights='imagenet',
                        pooling='avg', input_shape=INPUT_SHAPE)

    sequentialModel = tf.keras.models.Sequential()
    sequentialModel.add(loadedModel)
    sequentialModel.add(tf.keras.layers.Dense(NUM_CLASSES, activation='softmax'))

    aucCurve = tf.keras.metrics.AUC(curve = 'ROC', multi_label = True)
    categoricalAccuracy = tf.keras.metrics.CategoricalAccuracy()
    F1Score  = tfa.metrics.F1Score(num_classes = NUM_CLASSES, average = 'macro', threshold = None)
    metrics = [aucCurve, categoricalAccuracy, F1Score]

    sequentialModel.compile(metrics=metrics)

    callbacks = [
    tf.keras.callbacks.ReduceLROnPlateau(monitor='val_loss', patience=REDUCE_LR_PATIENCE, verbose=1, factor=REDUCE_LR_FACTOR),
    tf.keras.callbacks.EarlyStopping(monitor='val_loss', verbose=1, patience=EARLY_STOPPING_PATIENCE),
    tf.keras.callbacks.ModelCheckpoint(filepath=modelName + '_epoch-{epoch:02d}.h5', monitor='val_loss', save_best_only=False, verbose=1),
    tf.keras.callbacks.CSVLogger(modelName + '_training.csv')]

    sequentialModel.fit(epochs=NUM_EPOCHS)

字符串
也许我可以通过在NUM_EPOCHS范围内执行for循环来重置指标,并在for循环中初始化指标,但我不确定这是否是一个好的解决方案。此外,我有ModelCheckpoint和CSVLogger回调,它们需要来自model.fit()的epoch编号,所以如果我执行for循环,它不会真正工作。
你对如何重置每个epoch的指标有什么建议吗?在NUM_EPOCHS范围内进行for循环是唯一的解决方案吗?谢谢。

91zkwejq

91zkwejq1#

不,指标是按epoch计算的。它们不是在epoch上平均的,而是在每个epoch的批次上平均的。你会看到指标在epoch之后不断提高,因为你的模型正在训练。

gr8qqesn

gr8qqesn2#

**通过方法reset_state**控制行为。

通常情况下,它看起来像

def reset_state(self):
    # Reset the metric state at the start of each epoch.
    self.my_state_variable.assign(0.0)

字符串
但是如果需要的话,可以用不同的方式定义(例如,在子类化时,在您自己的度量中定义)。

相关问题