Concatenating ragged inputs in Keras

Posted by b0zn9rqh on 2021-09-29 in Java

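The code below creates a dummy model that concatenates two inputs. One input feeds an embedding layer with an output size of 5, while the second input is simply concatenated with the embedding layer's output: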

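import numpy as np
import tensorflow as tf
from tensorflow.keras.layers import Input, Embedding, Concatenate, Dense, Reshape
from tensorflow.keras.models import Model

# Dummy data: integer IDs for the embedding input and binary targets.
x = np.random.randint(0, 50, size=(10, 5, 1))
# The upper bound is exclusive, so (0, 1) would produce only zeros.
y = np.random.randint(0, 2, size=(10, 1))

def get_model():
  # Variable-length sequences with 11 features per time step.
  input1 = Input(shape=(None, 11), name='timeseries_input')
  # Variable-length sequences with one integer ID per time step.
  input2 = Input(shape=(None, 1), name='embedding_input')
  # Note: input_dim normally has to cover the largest ID (x.max() + 1);
  # len(np.unique(x)) can be smaller than that.
  emb = Embedding(input_dim=len(np.unique(x)), output_dim=5)(input2)
  # Drop the singleton axis: (batch, steps, 1, 5) -> (batch, steps, 5).
  emb = Reshape(target_shape=(-1, 5))(emb)
  # Join embedding and time-series features along the last axis.
  merged = Concatenate(axis=2)([emb, input1])
  out = Dense(1)(merged)
  model = Model([input1, input2], out)
  model.summary()
  return model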

m = get_model()
tf.keras.utils.plot_model(
    m,
    show_shapes=True,
    show_dtype=True,
    show_layer_names=True,
    rankdir="TB",
)

The code works and produces the following structure:

__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
==================================================================================================
embedding_input (InputLayer)    [(None, None, 1)]    0                                            
__________________________________________________________________________________________________
embedding_17 (Embedding)        (None, None, 1, 5)   155         embedding_input[0][0]            
__________________________________________________________________________________________________
tf.compat.v1.shape_15 (TFOpLamb (4,)                 0           embedding_17[0][0]               
__________________________________________________________________________________________________
tf.__operators__.getitem_12 (Sl ()                   0           tf.compat.v1.shape_15[0][0]      
__________________________________________________________________________________________________
tf.reshape_12 (TFOpLambda)      (None, None, 5)      0           embedding_17[0][0]               
                                                                 tf.__operators__.getitem_12[0][0]
__________________________________________________________________________________________________
timeseries_input (InputLayer)   [(None, None, 11)]   0                                            
__________________________________________________________________________________________________
concatenate_11 (Concatenate)    (None, None, 16)     0           tf.reshape_12[0][0]              
                                                                 timeseries_input[0][0]           
__________________________________________________________________________________________________
dense_10 (Dense)                (None, None, 1)      17          concatenate_11[0][0]             
==================================================================================================
Total params: 172
Trainable params: 172
Non-trainable params: 0
__________________________________________________________________________________________________
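
The shapes and parameter counts in this summary can be traced by hand. The following is a minimal NumPy sketch, assuming the embedding behaves like a plain table lookup; `table`, `dense_w`, and `dense_b` are illustrative stand-ins for the layer weights, not Keras objects.

```python
import numpy as np

rng = np.random.default_rng(0)
batch, steps, vocab, emb_dim, n_feat = 10, 5, 50, 5, 11

ids = rng.integers(0, vocab, size=(batch, steps, 1))  # embedding_input
series = rng.normal(size=(batch, steps, n_feat))      # timeseries_input
table = rng.normal(size=(vocab, emb_dim))             # stand-in embedding weights

# Looking up IDs of shape (10, 5, 1) adds the embedding axis at the end.
emb = table[ids]                                      # (10, 5, 1, 5)

# The reshape removes the singleton axis left over from the ID input.
emb_flat = emb.reshape(batch, -1, emb_dim)            # (10, 5, 5)

# Concatenating on axis 2 stacks 5 embedding values and 11 features.
merged = np.concatenate([emb_flat, series], axis=2)   # (10, 5, 16)

# A Dense(1) layer has 16 weights plus 1 bias, i.e. 17 parameters.
dense_w = rng.normal(size=(merged.shape[-1], 1))
dense_b = np.zeros(1)
out = merged @ dense_w + dense_b                      # (10, 5, 1)
n_dense_params = dense_w.size + dense_b.size
```

The 155 embedding parameters in the summary correspond to 31 rows of width 5, so `len(np.unique(x))` was 31 in that particular run.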


However, when I add ragged=True to my inputs:

  input1 = Input(shape=(None, 11), name='timeseries_input', ragged=True)
  input2 = Input(shape=(None, 1), name='embedding_input', ragged=True)

The code breaks with the following error:

TypeError: Failed to convert object of type <class 'tensorflow.python.ops.ragged.ragged_tensor.RaggedTensor'> to Tensor. Contents: tf.RaggedTensor(values=Tensor("Placeholder:0", shape=(None, 1, 5), dtype=float32), row_splits=Tensor("Placeholder_1:0", shape=(None,), dtype=int64)). Consider casting elements to a supported type.

How can I concatenate ragged inputs? What am I missing?
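
The error text shows the embedding output, a RaggedTensor, being converted to a dense Tensor. That suggests some step after the embedding does not accept ragged input, plausibly the reshape. The following is a minimal sketch of one possible direction, using plain TensorFlow ops rather than the Keras layers from the question. It assumes the ID input has no trailing axis of size 1, so the lookup yields (batch, steps, 5) directly and no reshape is needed. The names `ids`, `features`, and `table` are illustrative, and the sketch has not been checked against the TensorFlow version used in the question.

```python
import tensorflow as tf

# Two sequences with 3 and 1 time steps, 2 features per step.
features = tf.ragged.constant(
    [[[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]], [[7.0, 8.0]]],
    ragged_rank=1,
)                                             # shape [2, None, 2]

# One integer ID per time step, with the same row lengths as `features`.
ids = tf.ragged.constant([[0, 1, 2], [3]])    # shape [2, None]

# Stand-in embedding weights: 4 IDs, embedding size 5.
table = tf.random.normal([4, 5])

# tf.gather accepts ragged indices and keeps the row partitioning.
emb = tf.gather(table, ids)                   # shape [2, None, 5]

# Both tensors are ragged with matching row lengths, so they can be
# concatenated along the dense feature axis.
merged = tf.concat([emb, features], axis=2)   # shape [2, None, 7]

# Pad to a dense tensor if a later step requires one.
dense = merged.to_tensor()                    # shape [2, 3, 7]
```

Whether a given Keras layer accepts ragged input depends on the layer and the TensorFlow version, so the same idea inside a `Model` may still need `to_tensor()` before layers that only handle dense tensors.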
