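I'm trying to practice classification with a neural network, using the classic Iris dataset imported from scikit-learn, but I've run into a dimension problem that I don't know how to solve.
I'm aware there are other ways to load the Iris data, including through TensorFlow itself, and they may already provide it in a format better suited to this use. For the sake of understanding, though, I'd like to keep using the data imported from scikit-learn.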
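scikit-learn's `load_iris()` returns the labels as a 1-D array of integer class indices (0, 1, 2), one per sample. The other common label format is one-hot rows, with one column per class. Below is a minimal NumPy sketch of both formats. It uses a handful of made-up labels rather than the actual dataset, and the variable names are illustrative only.

```python
import numpy as np

# Integer ("sparse") labels, the same format as iris.target: one class index per sample.
y_sparse = np.array([0, 2, 1, 1, 0])

# One-hot labels: one row per sample, one column per class (Iris has 3 classes).
num_classes = 3
y_one_hot = np.eye(num_classes)[y_sparse]

print(y_sparse.shape)   # (5,)
print(y_one_hot.shape)  # (5, 3)
print(y_one_hot[1])     # [0. 0. 1.]

# Taking argmax over the class axis recovers the integer labels.
assert (y_one_hot.argmax(axis=1) == y_sparse).all()
```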
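import tensorflow as tf
from sklearn import datasets
from sklearn.model_selection import train_test_split
from tensorflow.keras.layers import Dense, Dropout, Input
from tensorflow.keras.models import Model
# import some data to play with
iris = datasets.load_iris()
X = iris.data  # all four features (sepal/petal length and width)
y = iris.target  # integer class labels 0, 1, 2, shape (150,)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.33)
# number of input features
T = X_train.shape[1]
print(X_train.shape)  # (100, 4)
i = Input(shape=(T,))
x = Dense(32, activation='swish')(i)
x = Dropout(0.40)(x)
x = Dense(64, activation='swish')(x)
x = Dropout(0.40)(x)
x = Dense(32, activation='swish')(x)
x = Dropout(0.40)(x)
x = Dense(3, activation="softmax")(x)
model_1 = Model(i, x)
# `precision` is not defined in the snippet as posted; judging from the traceback
# (update_confusion_matrix_variables), it is assumed to be tf.keras.metrics.Precision
precision = tf.keras.metrics.Precision()
# Compile the model
model_1.compile(loss=tf.keras.losses.SparseCategoricalCrossentropy(),
                optimizer=tf.keras.optimizers.Adam(),
                metrics=["accuracy", precision])
# Fit the model (the features are not normalized in this snippet)
r = model_1.fit(X_train,
                y_train,
                epochs=40,
                validation_data=(X_test, y_test))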
The code above produces the following error message:
Epoch 1/40
---------------------------------------------------------------------------
ValueError Traceback (most recent call last)
<ipython-input-19-262061005876> in <module>()
8 y_train,
9 epochs=40,
---> 10 validation_data=(X_test, y_test))
9 frames
/usr/local/lib/python3.7/dist-packages/tensorflow/python/framework/func_graph.py in wrapper(*args, **kwargs)
992 except Exception as e: # pylint:disable=broad-except
993 if hasattr(e, "ag_error_metadata"):
--> 994 raise e.ag_error_metadata.to_exception(e)
995 else:
996 raise
ValueError: in user code:
/usr/local/lib/python3.7/dist-packages/keras/engine/training.py:853 train_function *
return step_function(self, iterator)
/usr/local/lib/python3.7/dist-packages/keras/engine/training.py:842 step_function **
outputs = model.distribute_strategy.run(run_step, args=(data,))
/usr/local/lib/python3.7/dist-packages/tensorflow/python/distribute/distribute_lib.py:1286 run
return self._extended.call_for_each_replica(fn, args=args, kwargs=kwargs)
/usr/local/lib/python3.7/dist-packages/tensorflow/python/distribute/distribute_lib.py:2849 call_for_each_replica
return self._call_for_each_replica(fn, args, kwargs)
/usr/local/lib/python3.7/dist-packages/tensorflow/python/distribute/distribute_lib.py:3632 _call_for_each_replica
return fn(*args, **kwargs)
/usr/local/lib/python3.7/dist-packages/keras/engine/training.py:835 run_step **
outputs = model.train_step(data)
/usr/local/lib/python3.7/dist-packages/keras/engine/training.py:792 train_step
self.compiled_metrics.update_state(y, y_pred, sample_weight)
/usr/local/lib/python3.7/dist-packages/keras/engine/compile_utils.py:457 update_state
metric_obj.update_state(y_t, y_p, sample_weight=mask)
/usr/local/lib/python3.7/dist-packages/keras/utils/metrics_utils.py:73 decorated
update_op = update_state_fn(*args, **kwargs)
/usr/local/lib/python3.7/dist-packages/keras/metrics.py:177 update_state_fn
return ag_update_state(*args, **kwargs)
/usr/local/lib/python3.7/dist-packages/keras/metrics.py:1366 update_state **
sample_weight=sample_weight)
/usr/local/lib/python3.7/dist-packages/keras/utils/metrics_utils.py:623 update_confusion_matrix_variables
y_pred.shape.assert_is_compatible_with(y_true.shape)
/usr/local/lib/python3.7/dist-packages/tensorflow/python/framework/tensor_shape.py:1161 assert_is_compatible_with
raise ValueError("Shapes %s and %s are incompatible" % (self, other))
ValueError: Shapes (None, 3) and (None, 1) are incompatible
I'm not sure how to fix this.
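The two shapes in the last line of the traceback can be illustrated with NumPy. The softmax layer outputs one probability per class, `(batch, 3)`, while, as the error line shows, the integer labels reach the metric as a single column, `(batch, 1)`. The sketch below uses a hypothetical helper, `shapes_compatible`, that mimics the rule TensorFlow's shape check applies: same rank, and each dimension either equal or unknown. It only reproduces the mismatch with made-up labels and is not TensorFlow's implementation.

```python
import numpy as np


def shapes_compatible(a, b):
    """Hypothetical helper mimicking TensorShape compatibility for known ranks.

    Two shapes are compatible if they have the same rank and every pair of
    dimensions is equal or at least one of them is unknown (None).
    """
    if len(a) != len(b):
        return False
    return all(x is None or y is None or x == y for x, y in zip(a, b))


labels = np.array([0, 2, 1, 1])            # made-up integer labels for 4 samples

y_pred = np.full((len(labels), 3), 1 / 3)  # softmax-style output: one probability per class
y_true_column = labels.reshape(-1, 1)      # integer labels as a column, shape (4, 1)
y_true_one_hot = np.eye(3)[labels]         # the same labels one-hot encoded, shape (4, 3)

# The pair from the error message, with the batch dimension unknown:
print(shapes_compatible((None, 3), (None, 1)))                # False
print(shapes_compatible(y_pred.shape, y_true_column.shape))   # False
print(shapes_compatible(y_pred.shape, y_true_one_hot.shape))  # True
```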