Tensorflow Keras model.fit使用生成器函数和steps_per_epoch

unguejic  于 8个月前  发布在  其他
关注(0)|答案(1)|浏览(109)

我有这个model.fit电话:

transformer.fit(
   x=data_generation.generate_dataset(batch_size, dontchange, train_indices, filenames),
   epochs=epochs,
   steps_per_epoch=len(train_indices),
   validation_data=data_generation.generate_dataset(batch_size, dontchange, val_indices, filenames),
   validation_steps=len(val_indices)
)

这是generate_dataset函数的一部分:

def generate_dataset(batch_size_in, dontChange, index_list_in, filenames):
    epoch = 0
    while True:
        epoch = epoch + 1
        raw_dataset = tf.data.TFRecordDataset(filenames)
        batch_size = batch_size_in
        index_list = []
        index_list = index_list_in
        cx = 1
        for index in index_list:
            tf.print("Epoch: {}, Batch: {}, Batches total: {}".format(epoch, cx, len(index_list_in)), summarize=32 *
            10 * 250, output_stream="file://logtest.txt")
            cx = cx + 1
            how_much_to_take = batch_size
            batch_of_records = raw_dataset.skip(index).take(how_much_to_take)
            max_number_of_tv_in_batch = 0
            batch = []

    [...]

            yield (tf.stack(input_batch), tf.stack(attention_mask_batch), tf.stack(padding_mask_batch)), tf.stack(output_batch)

因此,对于索引列表中的每个索引,其产生一个批次,因此steps_per_epoch和validation_steps等于通过val生成器和train生成器中的相应index_list_in的一个循环生成的批次的数量。所以一切都应该很顺利。
正如你所看到的,我使用tf.print将进度打印到一个日志文件中,我注意到一些我无法真正解释的事情。

Epoch: 1, Batch: 1, Batches total: 607
Epoch: 1, Batch: 2, Batches total: 607
Epoch: 1, Batch: 3, Batches total: 607
Epoch: 1, Batch: 4, Batches total: 607
Epoch: 1, Batch: 5, Batches total: 607
Epoch: 1, Batch: 6, Batches total: 607
Epoch: 1, Batch: 7, Batches total: 607 
[...]
Epoch: 1, Batch: 605, Batches total: 607
Epoch: 1, Batch: 606, Batches total: 607
Epoch: 1, Batch: 607, Batches total: 607
Epoch: 2, Batch: 1, Batches total: 607
Epoch: 1, Batch: 1, Batches total: 67
Epoch: 1, Batch: 2, Batches total: 67
Epoch: 1, Batch: 3, Batches total: 67
[...]
Epoch: 1, Batch: 66, Batches total: 67
Epoch: 1, Batch: 67, Batches total: 67
Epoch: 2, Batch: 1, Batches total: 67
Epoch: 2, Batch: 2, Batches total: 607
Epoch: 2, Batch: 3, Batches total: 607

因此,它基本上加载了607个批次,但随后它加载了一个a107批次,正如您可以从以下输出中看到的那样:

Epoch: 2, Batch: 1, Batches total: 607

就在它进入瓦尔批次的发生器之前。然后对于瓦尔生成器,它是相同的,它工作得很好,它产生了67个批次,因为它应该,但然后它进入下一个循环,出于某种原因,我看到这一点:

Epoch: 2, Batch: 1, Batches total: 67

然后它在train_generator中的第二个epoch开始:

Epoch: 2, Batch: 2, Batches total: 607

而不是:

Epoch: 2, Batch: 1, Batches total: 607

就应该这样
第三个时期也是一样

Epoch: 3, Batch: 2, Batches total: 607

而不是:

Epoch: 3, Batch: 1, Batches total: 607

但它为什么要这么做?正如我所说的,steps_per_epoch与通过index_list_in的一个循环产生的批次完全匹配。这是显而易见的,因为我使用len(train_indices)/len(瓦尔_indices)来确定批处理的数量。但它似乎需要一个多批次每时代。为什么?为什么?
我想要的是,它只加载和准确的所有607批次每一个时期,它应该开始与同一批每一个时期ofc。我需要改变什么?这是Keras的bug还是我做错了什么?

wnvonmuf

wnvonmuf1#

实际上我想我知道了。
我需要在第一个epoch中的火车生成器的开头添加第一批两次。(所有其他批次仅为steps_per_epoch批次)。
我需要在每个epoch的瓦尔生成器的开头添加第一批两次。
因为对于这些批次,实际上没有计算准确性和损失(模型不会在这些批次上训练)。我认为模型只需要这些批次的上下文。
它需要每个epoch中的瓦尔_generator中的context batch,因为它实际上期望一个完整的数据集在每个epoch中重复,而对于train_dataset,它希望以epoch步骤遍历data_set,因此它只需要一次context。
我在文档中没有找到任何关于这个的东西,所以它要么是一个bug,要么很难找到。

相关问题