我正在使用Keras构建CNN,它显示错误。代码如下:
input_shape = (21,) # For your tabular data, add a channel dimension of 1
model = models.Sequential()
# Convolutional layers
model.add(layers.Conv1D(32, 3, activation='relu', input_shape=(21, 1)))
model.add(layers.MaxPooling1D(2))
model.add(layers.Conv1D(64, 3, activation='relu'))
model.add(layers.MaxPooling1D(2))
model.add(layers.Conv1D(64, 3, activation='relu'))
# Flatten the output and add dense layers
model.add(layers.Flatten())
model.add(layers.Dense(64, activation='relu'))
model.add(layers.Dense(10, activation='softmax')) # 10 output units for 10 classes
model.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['accuracy'])
# Train the model
model.fit(X_train_reshaped, Y_train, epochs=10, validation_data=(X_test, Y_test))
# Predict using the model
y_pred = model.predict(X_test_reshaped)
y_pred_classes = np.argmax(y_pred, axis=1)
字符串
输入有500万行,有21个特征。代码应该对10个不同的类进行多类分类。代码有什么问题?
谢谢你的帮助。
1条答案
按热度按时间am46iovg1#
您的
X_train_reshaped
的形状是(32, 1, 21)
,但您已经使用input_shape=(21, 1)
定义了模型。因此您需要将X_train_reshaped
重新塑造为形状(32, 21, 1)
以匹配模型的输入形状。这里有一个简单的例子供您参考:
字符串