tfa.seq2seq.dynamic_decode生成的序列长度小于最大_迭代参数指定的最大长度

bbmckpt7  于 2021-09-08  发布在  Java
关注(0)|答案(0)|浏览(292)

我正在使用tensorflow和tensorflow_插件库为chatbot开发seq2seq模型。在解码器中,我使用 tfa.seq2seq.dynamic_decode 生成序列。我发现了,因为 maximum_iteration = 15 ,有时解码器生成一批序列,所有这些序列的长度为8(启用填充),有时一批序列的长度为7,有时为4,等等。下面是解码器的代码部分:

infer_decoder = tfa.seq2seq.BeamSearchDecoder(
    cell=self.decoder_cell,
    beam_width=args['beam_width'],
    output_layer=self.output_dense)

infer_output, _, _ = tfa.seq2seq.dynamic_decode(
    decoder=infer_decoder,
    swap_memory=True,
    maximum_iterations=args['max_len'],
    decoder_init_input=self.embedding,
    decoder_init_kwargs={
        'start_tokens': tf.tile(tf.constant([args['SOS_ID']], dtype=tf.int32), [tf.shape(context_with_latent)[1]]),
        'end_token': args['EOS_ID'],
        'initial_state': tfa.seq2seq.tile_batch(init_state_tuple, args['beam_width'])
    })
infer_predicted_ids = infer_output.predicted_ids[:, :, 0]

我还注意到,随着训练的进行,生成的序列的长度趋向于接近最大值_迭代,但长度略微徘徊在15左右,而不是每个批次都保持在15左右。
那么这种行为是预期的吗?
以下是一些其他信息:
tensorflow版本:2.3.1
tensorflow_插件版本:0.11.2

暂无答案!

目前还没有任何答案,快来回答吧!

相关问题