keras 层sequential 8的输入0与层不兼容:输入形状的预期轴-1具有值1,但接收到具有形状(32,1,21)的输入

6za6bjd0  于 7个月前  发布在  其他
关注(0)|答案(1)|浏览(86)

我正在使用Keras构建CNN,它显示错误。代码如下:

input_shape = (21,)  # For your tabular data, add a channel dimension of 1

model = models.Sequential()

# Convolutional layers
model.add(layers.Conv1D(32, 3, activation='relu', input_shape=(21, 1)))
model.add(layers.MaxPooling1D(2))
model.add(layers.Conv1D(64, 3, activation='relu'))
model.add(layers.MaxPooling1D(2))
model.add(layers.Conv1D(64, 3, activation='relu'))

# Flatten the output and add dense layers
model.add(layers.Flatten())
model.add(layers.Dense(64, activation='relu'))
model.add(layers.Dense(10, activation='softmax'))  # 10 output units for 10 classes

model.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['accuracy'])

# Train the model
model.fit(X_train_reshaped, Y_train, epochs=10, validation_data=(X_test, Y_test))

# Predict using the model
y_pred = model.predict(X_test_reshaped)
y_pred_classes = np.argmax(y_pred, axis=1)

字符串
输入有500万行,有21个特征。代码应该对10个不同的类进行多类分类。代码有什么问题?
谢谢你的帮助。

am46iovg

am46iovg1#

您的X_train_reshaped的形状是(32, 1, 21),但您已经使用input_shape=(21, 1)定义了模型。因此您需要将X_train_reshaped重新塑造为形状(32, 21, 1)以匹配模型的输入形状。
这里有一个简单的例子供您参考:

model = models.Sequential()

# Convolutional layers
model.add(layers.Conv1D(32, 3, activation='relu', input_shape=(21, 1)))
model.add(layers.MaxPooling1D(2))
model.add(layers.Conv1D(64, 3, activation='relu'))
model.add(layers.MaxPooling1D(2))
model.add(layers.Conv1D(64, 3, activation='relu'))

# Flatten the output and add dense layers
model.add(layers.Flatten())
model.add(layers.Dense(64, activation='relu'))
model.add(layers.Dense(10, activation='softmax'))  # 10 output units for 10 classes

model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])
X_train_reshaped = tf.random.normal([32,21,1])
Y_train = tf.random.normal([32,10])
X_test = tf.random.normal([32,21,1])
Y_test = tf.random.normal([32,10])
# Train the model
model.fit(X_train_reshaped, Y_train, epochs=10, validation_data=(X_test, Y_test))

字符串

相关问题